diff options
| -rw-r--r-- | cmd/cli/main.go | 2 | ||||
| -rw-r--r-- | internal/server/api.go | 15 | ||||
| -rw-r--r-- | internal/server/common.go | 19 | ||||
| -rw-r--r-- | state_test.go | 46 |
4 files changed, 67 insertions, 15 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go index ab5d659..8bc083d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -95,7 +95,7 @@ func getConfig(state *eduvpn.VPNState, url string, serverType ServerTypes) (stri if !strings.HasPrefix(url, "http") { url = "https://" + url } - // Force TCP is set to False + // Prefer TCP is set to False if serverType == ServerTypeInstituteAccess { return state.GetConfigInstituteAccess(url, false) } else if serverType == ServerTypeCustom { diff --git a/internal/server/api.go b/internal/server/api.go index 4648a8f..0c1a0f5 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -131,10 +131,19 @@ func APIInfo(server Server) error { return nil } +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func GetPreferTCPString(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + func APIConnectWireguard( server Server, profile_id string, pubkey string, + preferTCP bool, supportsOpenVPN bool, ) (string, string, time.Time, error) { errorMessage := "failed obtaining a WireGuard configuration" @@ -143,6 +152,8 @@ func APIConnectWireguard( "accept": {"application/x-wireguard-profile"}, } + // This profile also supports OpenVPN + // Indicate that we also accept OpenVPN profiles if supportsOpenVPN { headers.Add("accept", "application/x-openvpn-profile") } @@ -150,6 +161,7 @@ func APIConnectWireguard( urlForm := url.Values{ "profile_id": {profile_id}, "public_key": {pubkey}, + "prefer_tcp": {GetPreferTCPString(preferTCP)}, } header, connectBody, connectErr := apiAuthorizedRetry( server, @@ -180,7 +192,7 @@ func APIConnectWireguard( return string(connectBody), content, pTime, nil } -func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, error) { +func APIConnectOpenVPN(server Server, profile_id string, preferTCP bool) (string, time.Time, error) { errorMessage := "failed obtaining an OpenVPN configuration" headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, @@ -189,6 +201,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, err urlForm := url.Values{ "profile_id": {profile_id}, + "prefer_tcp": {GetPreferTCPString(preferTCP)}, } header, connectBody, connectErr := apiAuthorizedRetry( diff --git a/internal/server/common.go b/internal/server/common.go index 36dba32..6f57c7f 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -324,7 +324,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { } } -func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) { +func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) { errorMessage := "failed getting server WireGuard configuration" base, baseErr := server.GetBase() @@ -344,6 +344,7 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er server, profile_id, wireguardPublicKey, + preferTCP, supportsOpenVPN, ) @@ -366,7 +367,7 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er return config, content, nil } -func openVPNGetConfig(server Server) (string, string, error) { +func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { errorMessage := "failed getting server OpenVPN configuration" base, baseErr := server.GetBase() @@ -374,7 +375,7 @@ func openVPNGetConfig(server Server) (string, string, error) { return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profile_id := base.Profiles.Current - configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id) + configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id, preferTCP) // Store start and end time base.StartTime = util.GetCurrentTime() @@ -433,14 +434,6 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) { supportsOpenVPN := profile.supportsOpenVPN() supportsWireguard := profile.supportsWireguard() - // If preferTCP we must be able to get a config with OpenVPN - if preferTCP && supportsOpenVPN { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerGetConfigForceTCPError{}, - } - } - var config string var configType string var configErr error @@ -448,9 +441,9 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) { if supportsWireguard { // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - config, configType, configErr = wireguardGetConfig(server, supportsOpenVPN) + config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN) } else { - config, configType, configErr = openVPNGetConfig(server) + config, configType, configErr = openVPNGetConfig(server, preferTCP) } if configErr != nil { diff --git a/state_test.go b/state_test.go index 5647b12..82d88e2 100644 --- a/state_test.go +++ b/state_test.go @@ -350,3 +350,49 @@ func Test_invalid_profile_corrected(t *testing.T) { ) } } + +// Test if an invalid profile will be corrected +func Test_prefer_tcp(t *testing.T) { + serverURI := getServerURI(t) + state := &VPNState{} + + ensureLocalWellKnown() + + registerErr := state.Register( + "org.eduvpn.app.linux", + "configsprefertcp", + func(old FSMStateID, new FSMStateID, data interface{}) { + stateCallback(t, old, new, data, state) + }, + false, + ) + if registerErr != nil { + t.Fatalf("Register error: %v", registerErr) + } + + // get a config with preferTCP set to true + config, configType, configErr := state.GetConfigCustomServer(serverURI, true) + + // Test server should accept prefer TCP! + if configType != "openvpn" { + t.Fatalf("Invalid protocol for prefer TCP, got: WireGuard, want: OpenVPN") + } + + if configErr != nil { + t.Fatalf("Config error: %v", configErr) + } + + if !strings.HasSuffix(config, "remote eduvpnserver 1194 tcp\nremote eduvpnserver 1194 udp") { + t.Fatalf("Suffix for prefer TCP is not in the right order for config: %s", config) + } + + // get a config with preferTCP set to false + config, configType, configErr = state.GetConfigCustomServer(serverURI, false) + if configErr != nil { + t.Fatalf("Config error: %v", configErr) + } + + if configType == "openvpn" && !strings.HasSuffix(config, "remote eduvpnserver 1194 udp\nremote eduvpnserver 1194 tcp") { + t.Fatalf("Suffix for disable prefer TCP is not in the right order for config: %s", config) + } +} |
