diff options
Diffstat (limited to 'client/client_test.go')
| -rw-r--r-- | client/client_test.go | 84 |
1 files changed, 44 insertions, 40 deletions
diff --git a/client/client_test.go b/client/client_test.go index 56c38ff..7077ce4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "net/http" "net/url" @@ -12,6 +13,7 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/types/cookie" "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" @@ -22,7 +24,7 @@ func getServerURI(t *testing.T) string { if serverURI == "" { t.Skip("Skipping server test as no SERVER_URI env var has been passed") } - serverURI, parseErr := httpw.EnsureValidURL(serverURI) + serverURI, parseErr := httpw.EnsureValidURL(serverURI, true) if parseErr != nil { t.Skip("Skipping server test as the server uri is not valid") } @@ -41,13 +43,13 @@ func runCommand(errBuffer *strings.Builder, name string, args ...string) error { return cmd.Wait() } -func loginOAuthSelenium(url string, state *Client) { +func loginOAuthSelenium(ck *cookie.Cookie, url string) { // We could use the go selenium library // But it does not support the latest selenium v4 just yet var errBuffer strings.Builder err := runCommand(&errBuffer, "python3", "../selenium_eduvpn.py", url) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() panic(fmt.Sprintf( "Login OAuth with selenium script failed with error %v and stderr %s", err, @@ -58,10 +60,10 @@ func loginOAuthSelenium(url string, state *Client) { func stateCallback( t *testing.T, + ck *cookie.Cookie, _ FSMStateID, newState FSMStateID, data interface{}, - state *Client, ) { if newState == StateOAuthStarted { url, ok := data.(string) @@ -69,20 +71,20 @@ func stateCallback( if !ok { t.Fatalf("data is not a string for OAuth URL") } - loginOAuthSelenium(url, state) + loginOAuthSelenium(ck, url) } } func TestServer(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configstest", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -95,12 +97,11 @@ func TestServer(t *testing.T) { t.Fatalf("Registering error: %v", err) } - - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) } @@ -112,33 +113,36 @@ func testConnectOAuthParameter( errPrefix string, ) { serverURI := getServerURI(t) - state := &Client{} configDirectory := "test_oauth_parameters" + state := &Client{} + + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", configDirectory, func(oldState FSMStateID, newState FSMStateID, data interface{}) bool { if newState == StateOAuthStarted { - server, serverErr := state.Servers.GetCustomServer(serverURI) + server, serverErr := state.Servers.CustomServer(serverURI) if serverErr != nil { t.Fatalf("No server with error: %v", serverErr) } port, portErr := server.OAuth().ListenerPort() if portErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("No port with error: %v", portErr) } baseURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) p, err := url.Parse(baseURL) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("Failed to parse URL with error: %v", err) } url, err := httpw.ConstructURL(p, parameters) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf( "Error: Constructing url %s with parameters %s", baseURL, @@ -148,7 +152,7 @@ func testConnectOAuthParameter( go func() { _, getErr := http.Get(url) if getErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Logf("HTTP GET error: %v", getErr) } }() @@ -165,7 +169,7 @@ func testConnectOAuthParameter( t.Fatalf("Registering error: %v", err) } - err = state.AddCustomServer(serverURI) + err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if errPrefix == "" { if err != nil { @@ -247,14 +251,14 @@ func TestTokenExpired(t *testing.T) { } // Get a vpn state - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsexpired", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -267,25 +271,25 @@ func TestTokenExpired(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error before expired: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } serverOAuth := currentServer.OAuth() - accessToken, accessTokenErr := serverOAuth.AccessToken() + accessToken, accessTokenErr := serverOAuth.AccessToken(ck.Context()) if accessTokenErr != nil { t.Fatalf("Failed to get token: %v", accessTokenErr) } @@ -293,14 +297,14 @@ func TestTokenExpired(t *testing.T) { // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error after expiry: %v", configErr) } // Check if tokens have changed - accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken() + accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken(ck.Context()) if accessTokenAfterErr != nil { t.Fatalf("Failed to get token: %v", accessTokenAfterErr) } @@ -313,14 +317,14 @@ func TestTokenExpired(t *testing.T) { // Test if an invalid profile will be corrected. func TestInvalidProfileCorrected(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configscancelprofile", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -333,18 +337,18 @@ func TestInvalidProfileCorrected(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("First connect error: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } @@ -357,7 +361,7 @@ func TestInvalidProfileCorrected(t *testing.T) { previousProfile := base.Profiles.Current base.Profiles.Current = "IDONOTEXIST" - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Second connect error: %v", configErr) } @@ -374,14 +378,14 @@ func TestInvalidProfileCorrected(t *testing.T) { // Test if prefer tcp is handled correctly by checking the returned config and config type. func TestPreferTCP(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsprefertcp", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -394,13 +398,13 @@ func TestPreferTCP(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } // get a config with preferTCP set to true - config, configErr := state.GetConfigCustomServer(serverURI, true, srvtypes.Tokens{}) + config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true) // Test server should accept prefer TCP! if config.Protocol != protocol.OpenVPN { @@ -417,7 +421,7 @@ func TestPreferTCP(t *testing.T) { } // get a config with preferTCP set to false - config, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Config error: %v", configErr) } |
