diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-10 11:53:08 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-10 11:53:08 +0200 |
| commit | 9e3e7f22892c3504e6de9827af0fabd9b4b098ea (patch) | |
| tree | 0d16b63dae17f73bccfe720d9c03c25166856497 /internal/server.go | |
| parent | 113de64ac73f529af14da3e0aff12b05c2edd3a7 (diff) | |
API/Server: Correctly handle multiple protocol preference
Diffstat (limited to 'internal/server.go')
| -rw-r--r-- | internal/server.go | 40 |
1 files changed, 32 insertions, 8 deletions
diff --git a/internal/server.go b/internal/server.go index 4452c6d..c9e31af 100644 --- a/internal/server.go +++ b/internal/server.go @@ -273,15 +273,23 @@ func getEndpoints(baseURL string) (*ServerEndpoints, error) { return endpoints, nil } -func (profile *ServerProfile) supportsWireguard() bool { +func (profile *ServerProfile) supportsProtocol(protocol string) bool { for _, proto := range profile.VPNProtoList { - if proto == "wireguard" { + if proto == protocol { return true } } return false } +func (profile *ServerProfile) supportsWireguard() bool { + return profile.supportsProtocol("wireguard") +} + +func (profile *ServerProfile) supportsOpenVPN() bool { + return profile.supportsProtocol("openvpn") +} + func getCurrentProfile(server Server) (*ServerProfile, error) { base, baseErr := server.GetBase() @@ -297,7 +305,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} } -func getConfigWithProfile(server Server) (string, error) { +func getConfigWithProfile(server Server, forceTCP bool) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { @@ -312,8 +320,18 @@ func getConfigWithProfile(server Server) (string, error) { return "", &ServerGetConfigWithProfileError{Err: profileErr} } - if profile.supportsWireguard() { - return WireguardGetConfig(server) + supportsOpenVPN := profile.supportsOpenVPN() + supportsWireguard := profile.supportsWireguard() + + // If forceTCP we must be able to get a config with OpenVPN + if forceTCP && supportsOpenVPN { + return "", &ServerGetConfigForceTCPError{} + } + + if supportsWireguard { + // A wireguard connect call needs to generate a wireguard key and add it to the config + // Also the server could send back an OpenVPN config if it supports OpenVPN + return WireguardGetConfig(server, supportsOpenVPN) } return OpenVPNGetConfig(server) } @@ -331,7 +349,7 @@ func askForProfileID(server Server) error { return nil } -func GetConfig(server Server) (string, error) { +func GetConfig(server Server, forceTCP bool) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { @@ -363,7 +381,7 @@ func GetConfig(server Server) (string, error) { if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } - return getConfigWithProfile(server) + return getConfigWithProfile(server, forceTCP) } profileErr := askForProfileID(server) @@ -372,7 +390,7 @@ func GetConfig(server Server) (string, error) { return "", &ServerGetConfigError{Err: profileErr} } - return getConfigWithProfile(server) + return getConfigWithProfile(server, forceTCP) } type ServerGetCurrentProfileNotFoundError struct { @@ -391,6 +409,12 @@ func (e *ServerGetConfigWithProfileError) Error() string { return fmt.Sprintf("failed to get config including profile with error %v", e.Err) } +type ServerGetConfigForceTCPError struct{} + +func (e *ServerGetConfigForceTCPError) Error() string { + return fmt.Sprintf("failed to get config, force TCP is on but the server does not support OpenVPN") +} + type ServerGetEndpointsError struct { Err error } |
