diff options
| -rw-r--r-- | internal/api.go | 17 | ||||
| -rw-r--r-- | internal/server.go | 40 | ||||
| -rw-r--r-- | internal/wireguard.go | 18 | ||||
| -rw-r--r-- | state.go | 2 |
4 files changed, 57 insertions, 20 deletions
diff --git a/internal/api.go b/internal/api.go index b615976..45e025b 100644 --- a/internal/api.go +++ b/internal/api.go @@ -96,23 +96,33 @@ func APIInfo(server Server) error { return nil } -func APIConnectWireguard(server Server, profile_id string, pubkey string) (string, string, error) { +func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, string, error) { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, } + if supportsOpenVPN { + headers.Add("accept", "application/x-openvpn-profile") + } + urlForm := url.Values{ "profile_id": {profile_id}, "public_key": {pubkey}, } header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", "", &APIConnectWireguardError{Err: connectErr} + return "", "", "", &APIConnectWireguardError{Err: connectErr} } expires := header.Get("expires") - return string(connectBody), expires, nil + contentType := header.Get("content-type") + + content := "openvpn" + if contentType == "application/x-wireguard-profile" { + content = "wireguard" + } + return string(connectBody), content, expires, nil } func APIConnectOpenVPN(server Server, profile_id string) (string, string, error) { @@ -124,6 +134,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, string, error) urlForm := url.Values{ "profile_id": {profile_id}, } + header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", &APIConnectOpenVPNError{Err: connectErr} diff --git a/internal/server.go b/internal/server.go index 4452c6d..c9e31af 100644 --- a/internal/server.go +++ b/internal/server.go @@ -273,15 +273,23 @@ func getEndpoints(baseURL string) (*ServerEndpoints, error) { return endpoints, nil } -func (profile *ServerProfile) supportsWireguard() bool { +func (profile *ServerProfile) supportsProtocol(protocol string) bool { for _, proto := range profile.VPNProtoList { - if proto == "wireguard" { + if proto == protocol { return true } } return false } +func (profile *ServerProfile) supportsWireguard() bool { + return profile.supportsProtocol("wireguard") +} + +func (profile *ServerProfile) supportsOpenVPN() bool { + return profile.supportsProtocol("openvpn") +} + func getCurrentProfile(server Server) (*ServerProfile, error) { base, baseErr := server.GetBase() @@ -297,7 +305,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} } -func getConfigWithProfile(server Server) (string, error) { +func getConfigWithProfile(server Server, forceTCP bool) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { @@ -312,8 +320,18 @@ func getConfigWithProfile(server Server) (string, error) { return "", &ServerGetConfigWithProfileError{Err: profileErr} } - if profile.supportsWireguard() { - return WireguardGetConfig(server) + supportsOpenVPN := profile.supportsOpenVPN() + supportsWireguard := profile.supportsWireguard() + + // If forceTCP we must be able to get a config with OpenVPN + if forceTCP && supportsOpenVPN { + return "", &ServerGetConfigForceTCPError{} + } + + if supportsWireguard { + // A wireguard connect call needs to generate a wireguard key and add it to the config + // Also the server could send back an OpenVPN config if it supports OpenVPN + return WireguardGetConfig(server, supportsOpenVPN) } return OpenVPNGetConfig(server) } @@ -331,7 +349,7 @@ func askForProfileID(server Server) error { return nil } -func GetConfig(server Server) (string, error) { +func GetConfig(server Server, forceTCP bool) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { @@ -363,7 +381,7 @@ func GetConfig(server Server) (string, error) { if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } - return getConfigWithProfile(server) + return getConfigWithProfile(server, forceTCP) } profileErr := askForProfileID(server) @@ -372,7 +390,7 @@ func GetConfig(server Server) (string, error) { return "", &ServerGetConfigError{Err: profileErr} } - return getConfigWithProfile(server) + return getConfigWithProfile(server, forceTCP) } type ServerGetCurrentProfileNotFoundError struct { @@ -391,6 +409,12 @@ func (e *ServerGetConfigWithProfileError) Error() string { return fmt.Sprintf("failed to get config including profile with error %v", e.Err) } +type ServerGetConfigForceTCPError struct{} + +func (e *ServerGetConfigForceTCPError) Error() string { + return fmt.Sprintf("failed to get config, force TCP is on but the server does not support OpenVPN") +} + type ServerGetEndpointsError struct { Err error } diff --git a/internal/wireguard.go b/internal/wireguard.go index 318e0dc..7f8da38 100644 --- a/internal/wireguard.go +++ b/internal/wireguard.go @@ -30,7 +30,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func WireguardGetConfig(server Server) (string, error) { +func WireguardGetConfig(server Server, supportsOpenVPN bool) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { @@ -45,20 +45,22 @@ func WireguardGetConfig(server Server) (string, error) { } wireguardPublicKey := wireguardKey.PublicKey().String() - configWireguard, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey) + config, content, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN) if configErr != nil { return "", &WireguardGetConfigError{Err: wireguardErr} } - // FIXME: Store expiry - // This needs the go code a way to identify a connection - // Use the uuid of the connection e.g. on Linux - // This needs the client code to call the go code + if content == "wireguard" { + // FIXME: Store expiry + // This needs the go code a way to identify a connection + // Use the uuid of the connection e.g. on Linux + // This needs the client code to call the go code - configWireguardKey := wireguardConfigAddKey(configWireguard, wireguardKey) + config = wireguardConfigAddKey(config, wireguardKey) + } - return configWireguardKey, nil + return config, nil } type WireguardGenerateKeyError struct { @@ -134,7 +134,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f state.FSM.GoTransition(internal.REQUEST_CONFIG) - config, configErr := internal.GetConfig(server) + config, configErr := internal.GetConfig(server, forceTCP) if configErr != nil { // Go back to no server if possible |
