diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-24 14:05:45 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-24 14:05:45 +0200 |
| commit | 01cbb80b300f92c3456d3b2965630c0783607905 (patch) | |
| tree | ad162c4bc4bf5ebafac0d4ba44147437604c82d3 | |
| parent | 56f084389a3eb6b34df86af347ce60acdeb6106b (diff) | |
Client + Server + Exports: Implement optional WireGuard support
| -rw-r--r-- | client/client.go | 17 | ||||
| -rw-r--r-- | client/server.go | 4 | ||||
| -rw-r--r-- | exports/exports.go | 11 | ||||
| -rw-r--r-- | exports/servers.go | 3 | ||||
| -rw-r--r-- | internal/server/common.go | 58 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 4 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 12 |
7 files changed, 95 insertions, 14 deletions
diff --git a/client/client.go b/client/client.go index 5df2255..541883f 100644 --- a/client/client.go +++ b/client/client.go @@ -57,6 +57,9 @@ type Client struct { // The config Config config.Config `json:"-"` + // Whether or not this client supports WireGuard + SupportsWireguard bool `json:"-"` + // Whether to enable debugging Debug bool `json:"-"` } @@ -99,6 +102,11 @@ func (client *Client) Register( // Initialize the FSM client.FSM = newFSM(stateCallback, directory, debug) + + // By default we support wireguard + client.SupportsWireguard = true + + // Debug only if given client.Debug = debug // Initialize the Config @@ -148,11 +156,11 @@ func (client *Client) Deregister() { // askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. func (client *Client) askProfile(chosenServer server.Server) error { - base, baseErr := chosenServer.GetBase() - if baseErr != nil { - return types.NewWrappedError("failed asking for profiles", baseErr) + profiles, profilesErr := server.GetValidProfiles(chosenServer, client.SupportsWireguard) + if profilesErr != nil { + return types.NewWrappedError("failed asking for profiles", profilesErr) } - client.FSM.GoTransitionWithData(STATE_ASK_PROFILE, &base.Profiles, false) + client.FSM.GoTransitionWithData(STATE_ASK_PROFILE, profiles, false) return nil } @@ -209,3 +217,4 @@ type LetsConnectNotSupportedError struct{} func (e LetsConnectNotSupportedError) Error() string { return "Any operation that involves discovery is not allowed with the Let's Connect! client" } + diff --git a/client/server.go b/client/server.go index 9ff895a..2274378 100644 --- a/client/server.go +++ b/client/server.go @@ -22,7 +22,7 @@ func (client *Client) getConfigAuth( } client.FSM.GoTransition(STATE_REQUEST_CONFIG) - validProfile, profileErr := server.HasValidProfile(chosenServer) + validProfile, profileErr := server.HasValidProfile(chosenServer, client.SupportsWireguard) if profileErr != nil { return "", "", profileErr } @@ -36,7 +36,7 @@ func (client *Client) getConfigAuth( } // We return the error otherwise we wrap it too much - return server.GetConfig(chosenServer, preferTCP) + return server.GetConfig(chosenServer, client.SupportsWireguard, preferTCP) } // retryConfigAuth retries the getConfigAuth function if the tokens are invalid. diff --git a/exports/exports.go b/exports/exports.go index 6407da4..05f2462 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -407,6 +407,17 @@ func InFSMState(name *C.char, checkState C.int) C.int { return C.int(0) } +//export SetSupportWireguard +func SetSupportWireguard(name *C.char, support C.int) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + state.SupportsWireguard = support == 1 + return nil +} + //export FreeString func FreeString(addr *C.char) { C.free(unsafe.Pointer(addr)) diff --git a/exports/servers.go b/exports/servers.go index 57e1be1..36763b4 100644 --- a/exports/servers.go +++ b/exports/servers.go @@ -181,7 +181,8 @@ func getCPtrServer(state *client.Client, base *client.ServerBase) *C.server { server.total_support_contact, server.support_contact = getCPtrListStrings( base.SupportContact, ) - server.profiles = getCPtrProfiles(&base.Profiles) + profiles := base.GetValidProfiles(state.SupportsWireguard) + server.profiles = getCPtrProfiles(&profiles) // No endtime is given if we get servers when it has been partially initialised if base.EndTime.IsZero() { server.expire_time = C.ulonglong(0) diff --git a/internal/server/common.go b/internal/server/common.go index e70bee0..443a925 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -1,6 +1,7 @@ package server import ( + "errors" "fmt" "time" @@ -56,11 +57,13 @@ type ServerProfile struct { DefaultGateway bool `json:"default_gateway"` } +type ServerProfileListInfo struct { + ProfileList []ServerProfile `json:"profile_list"` +} + type ServerProfileInfo struct { Current string `json:"current_profile"` - Info struct { - ProfileList []ServerProfile `json:"profile_list"` - } `json:"info"` + Info ServerProfileListInfo `json:"info"` } func (info ServerProfileInfo) GetCurrentProfileIndex() int { @@ -325,6 +328,33 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { ) } +func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { + var validProfiles []ServerProfile + for _, profile := range base.Profiles.Info.ProfileList { + // Not a valid profile because it does not support openvpn + // Also the client does not support wireguard + if !profile.supportsOpenVPN() && !clientSupportsWireguard { + continue + } + validProfiles = append(validProfiles, profile) + } + return ServerProfileInfo{Current: base.Profiles.Current, Info: ServerProfileListInfo{ProfileList: validProfiles}} +} + +func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { + errorMessage := "failed to get valid profiles" + // No error wrapping here otherwise we wrap it too much + base, baseErr := server.GetBase() + if baseErr != nil { + return nil, types.NewWrappedError(errorMessage, baseErr) + } + profiles := base.GetValidProfiles(clientSupportsWireguard) + if len(profiles.Info.ProfileList) == 0 { + return nil, types.NewWrappedError(errorMessage, errors.New("no profiles found with supported protocols")) + } + return &profiles, nil +} + func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) { errorMessage := "failed getting server WireGuard configuration" base, baseErr := server.GetBase() @@ -389,7 +419,7 @@ func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { return configOpenVPN, "openvpn", nil } -func HasValidProfile(server Server) (bool, error) { +func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) { errorMessage := "failed has valid profile check" // Get new profiles using the info call @@ -418,13 +448,22 @@ func HasValidProfile(server Server) (bool, error) { if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } + profile, profileErr := getCurrentProfile(server) + // shouldn't happen + if profileErr != nil { + return false, types.NewWrappedError(errorMessage, profileErr) + } + // Profile does not support OpenVPN but the client also doesn't support WireGuard + if !profile.supportsOpenVPN() && !clientSupportsWireguard { + return false, nil + } return true, nil } return false, nil } -func GetConfig(server Server, preferTCP bool) (string, string, error) { +func GetConfig(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { errorMessage := "failed getting an OpenVPN/WireGuard configuration" profile, profileErr := getCurrentProfile(server) @@ -433,18 +472,23 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) { } supportsOpenVPN := profile.supportsOpenVPN() - supportsWireguard := profile.supportsWireguard() + supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard var config string var configType string var configErr error + // The config supports wireguard, do a specialized request with a public key if supportsWireguard { // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN) - } else { + // The config only supports OpenVPN + } else if supportsOpenVPN { config, configType, configErr = openVPNGetConfig(server, preferTCP) + // The config supports no available protocol because the profile only supports WireGuard but the client doesn't + } else { + return "", "", types.NewWrappedError(errorMessage, errors.New("No supported protocol found")) } if configErr != nil { diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index fc3d090..9463ab1 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -152,4 +152,8 @@ def initialize_functions(lib: CDLL) -> None: c_char_p, c_char_p, ], c_void_p + lib.SetSupportWireguard.argtypes, lib.SetSupportWireguard.restype = [ + c_char_p, + c_int, + ], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 69d00db..1467adb 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -440,6 +440,18 @@ class EduVPN(object): if renew_err: raise renew_err + def set_support_wireguard(self, support: bool) -> None: + """Indicates whether or not the OS supports WireGuard connections. + + :param support: bool: whether or not wireguard is supported + + :raises WrappedError: An error by the Go library + """ + support_err = self.go_function(self.lib.SetSupportWireguard, support) + + if support_err: + raise support_err + def should_renew_button(self) -> bool: """Whether or not the UI should show the renew button |
