summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-10 11:53:08 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-10 11:53:08 +0200
commit9e3e7f22892c3504e6de9827af0fabd9b4b098ea (patch)
tree0d16b63dae17f73bccfe720d9c03c25166856497
parent113de64ac73f529af14da3e0aff12b05c2edd3a7 (diff)
API/Server: Correctly handle multiple protocol preference
-rw-r--r--internal/api.go17
-rw-r--r--internal/server.go40
-rw-r--r--internal/wireguard.go18
-rw-r--r--state.go2
4 files changed, 57 insertions, 20 deletions
diff --git a/internal/api.go b/internal/api.go
index b615976..45e025b 100644
--- a/internal/api.go
+++ b/internal/api.go
@@ -96,23 +96,33 @@ func APIInfo(server Server) error {
return nil
}
-func APIConnectWireguard(server Server, profile_id string, pubkey string) (string, string, error) {
+func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, string, error) {
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
}
+ if supportsOpenVPN {
+ headers.Add("accept", "application/x-openvpn-profile")
+ }
+
urlForm := url.Values{
"profile_id": {profile_id},
"public_key": {pubkey},
}
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", "", &APIConnectWireguardError{Err: connectErr}
+ return "", "", "", &APIConnectWireguardError{Err: connectErr}
}
expires := header.Get("expires")
- return string(connectBody), expires, nil
+ contentType := header.Get("content-type")
+
+ content := "openvpn"
+ if contentType == "application/x-wireguard-profile" {
+ content = "wireguard"
+ }
+ return string(connectBody), content, expires, nil
}
func APIConnectOpenVPN(server Server, profile_id string) (string, string, error) {
@@ -124,6 +134,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, string, error)
urlForm := url.Values{
"profile_id": {profile_id},
}
+
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
return "", "", &APIConnectOpenVPNError{Err: connectErr}
diff --git a/internal/server.go b/internal/server.go
index 4452c6d..c9e31af 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -273,15 +273,23 @@ func getEndpoints(baseURL string) (*ServerEndpoints, error) {
return endpoints, nil
}
-func (profile *ServerProfile) supportsWireguard() bool {
+func (profile *ServerProfile) supportsProtocol(protocol string) bool {
for _, proto := range profile.VPNProtoList {
- if proto == "wireguard" {
+ if proto == protocol {
return true
}
}
return false
}
+func (profile *ServerProfile) supportsWireguard() bool {
+ return profile.supportsProtocol("wireguard")
+}
+
+func (profile *ServerProfile) supportsOpenVPN() bool {
+ return profile.supportsProtocol("openvpn")
+}
+
func getCurrentProfile(server Server) (*ServerProfile, error) {
base, baseErr := server.GetBase()
@@ -297,7 +305,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) {
return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}
}
-func getConfigWithProfile(server Server) (string, error) {
+func getConfigWithProfile(server Server, forceTCP bool) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
@@ -312,8 +320,18 @@ func getConfigWithProfile(server Server) (string, error) {
return "", &ServerGetConfigWithProfileError{Err: profileErr}
}
- if profile.supportsWireguard() {
- return WireguardGetConfig(server)
+ supportsOpenVPN := profile.supportsOpenVPN()
+ supportsWireguard := profile.supportsWireguard()
+
+ // If forceTCP we must be able to get a config with OpenVPN
+ if forceTCP && supportsOpenVPN {
+ return "", &ServerGetConfigForceTCPError{}
+ }
+
+ if supportsWireguard {
+ // A wireguard connect call needs to generate a wireguard key and add it to the config
+ // Also the server could send back an OpenVPN config if it supports OpenVPN
+ return WireguardGetConfig(server, supportsOpenVPN)
}
return OpenVPNGetConfig(server)
}
@@ -331,7 +349,7 @@ func askForProfileID(server Server) error {
return nil
}
-func GetConfig(server Server) (string, error) {
+func GetConfig(server Server, forceTCP bool) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
@@ -363,7 +381,7 @@ func GetConfig(server Server) (string, error) {
if base.Profiles.Current == "" {
base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
}
- return getConfigWithProfile(server)
+ return getConfigWithProfile(server, forceTCP)
}
profileErr := askForProfileID(server)
@@ -372,7 +390,7 @@ func GetConfig(server Server) (string, error) {
return "", &ServerGetConfigError{Err: profileErr}
}
- return getConfigWithProfile(server)
+ return getConfigWithProfile(server, forceTCP)
}
type ServerGetCurrentProfileNotFoundError struct {
@@ -391,6 +409,12 @@ func (e *ServerGetConfigWithProfileError) Error() string {
return fmt.Sprintf("failed to get config including profile with error %v", e.Err)
}
+type ServerGetConfigForceTCPError struct{}
+
+func (e *ServerGetConfigForceTCPError) Error() string {
+ return fmt.Sprintf("failed to get config, force TCP is on but the server does not support OpenVPN")
+}
+
type ServerGetEndpointsError struct {
Err error
}
diff --git a/internal/wireguard.go b/internal/wireguard.go
index 318e0dc..7f8da38 100644
--- a/internal/wireguard.go
+++ b/internal/wireguard.go
@@ -30,7 +30,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string {
return interface_re.ReplaceAllString(config, to_replace)
}
-func WireguardGetConfig(server Server) (string, error) {
+func WireguardGetConfig(server Server, supportsOpenVPN bool) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
@@ -45,20 +45,22 @@ func WireguardGetConfig(server Server) (string, error) {
}
wireguardPublicKey := wireguardKey.PublicKey().String()
- configWireguard, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey)
+ config, content, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN)
if configErr != nil {
return "", &WireguardGetConfigError{Err: wireguardErr}
}
- // FIXME: Store expiry
- // This needs the go code a way to identify a connection
- // Use the uuid of the connection e.g. on Linux
- // This needs the client code to call the go code
+ if content == "wireguard" {
+ // FIXME: Store expiry
+ // This needs the go code a way to identify a connection
+ // Use the uuid of the connection e.g. on Linux
+ // This needs the client code to call the go code
- configWireguardKey := wireguardConfigAddKey(configWireguard, wireguardKey)
+ config = wireguardConfigAddKey(config, wireguardKey)
+ }
- return configWireguardKey, nil
+ return config, nil
}
type WireguardGenerateKeyError struct {
diff --git a/state.go b/state.go
index 7f383d9..64146c8 100644
--- a/state.go
+++ b/state.go
@@ -134,7 +134,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f
state.FSM.GoTransition(internal.REQUEST_CONFIG)
- config, configErr := internal.GetConfig(server)
+ config, configErr := internal.GetConfig(server, forceTCP)
if configErr != nil {
// Go back to no server if possible