diff options
| author | Jeroen Wijenbergh <jeroenwijenbergh@protonmail.com> | 2022-03-28 23:29:43 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-03-28 23:29:43 +0200 |
| commit | 6192f9ab54a805c1fabe6a2c5b8eca622b565082 (patch) | |
| tree | f889404ebca573c8ecc886ea1858dc6822158f6e /src | |
| parent | 785e34a4ebacee7dea16af6d16725647b7f6fd7d (diff) | |
OAuth: Token refresh changes and tests
Diffstat (limited to 'src')
| -rw-r--r-- | src/api.go | 24 | ||||
| -rw-r--r-- | src/http.go | 6 | ||||
| -rw-r--r-- | src/oauth.go | 40 | ||||
| -rw-r--r-- | src/server_test.go | 86 | ||||
| -rw-r--r-- | src/state.go | 16 |
5 files changed, 136 insertions, 36 deletions
@@ -2,13 +2,14 @@ package eduvpn import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" ) // Authenticated wrappers on top of HTTP -func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { +func (server *Server) apiAuthenticated(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &HTTPOptionalParams{} @@ -29,15 +30,26 @@ func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, o } else { opts.Headers = &http.Header{headerKey: {headerValue}} } - header, body, bodyErr := HTTPMethodWithOpts(method, url, opts) + return HTTPMethodWithOpts(method, url, opts) +} + +func (server *Server) apiAuthenticatedRetry(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { + header, body, bodyErr := server.apiAuthenticated(method, endpoint, opts) if bodyErr != nil { + var error *HTTPStatusError + + if errors.As(bodyErr, &error) { + // Tell the method that the token is expired + server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds() + return server.apiAuthenticated(method, endpoint, opts) + } return header, nil, bodyErr } - return header, body, nil + return header, body, bodyErr } func (server *Server) APIInfo() error { - _, body, bodyErr := server.apiAuthenticatedWithOpts(http.MethodGet, "/info", nil) + _, body, bodyErr := server.apiAuthenticatedRetry(http.MethodGet, "/info", nil) if bodyErr != nil { return bodyErr } @@ -65,7 +77,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str "profile_id": {"default"}, "public_key": {pubkey}, } - header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", connectErr } @@ -83,7 +95,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro urlForm := url.Values{ "profile_id": {"default"}, } - header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", connectErr } diff --git a/src/http.go b/src/http.go index bc2cc20..bd06342 100644 --- a/src/http.go +++ b/src/http.go @@ -160,6 +160,10 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht return resp.Header, nil, &HTTPReadError{URL: url, Err: readErr} } - // Return the body in bytes and signal that there was no error + if resp.StatusCode != 200 { + return resp.Header, body, &HTTPStatusError{URL: url, Status: resp.StatusCode} + } + + // Return the body in bytes and signal the status error if there was one return resp.Header, body, nil } diff --git a/src/oauth.go b/src/oauth.go index 92c64d4..dd86279 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -73,7 +73,7 @@ type OAuthExchangeSession struct { Server *http.Server } -func generateTimeSeconds() int64 { +func GenerateTimeSeconds() int64 { current := time.Now() return current.Unix() } @@ -122,13 +122,14 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { "content-type": {"application/x-www-form-urlencoded"}, } opts := &HTTPOptionalParams{Headers: headers, Body: data} - current_time := generateTimeSeconds() + current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return bodyErr } tokenStructure := &OAuthToken{} + jsonErr := json.Unmarshal(body, tokenStructure) if jsonErr != nil { @@ -143,7 +144,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { func (oauth *OAuth) isTokensExpired() bool { expired_time := oauth.Token.ExpiredTimestamp - current_time := generateTimeSeconds() + current_time := GenerateTimeSeconds() return current_time >= expired_time } @@ -160,7 +161,7 @@ func (oauth *OAuth) getTokensWithRefresh() error { "content-type": {"application/x-www-form-urlencoded"}, } opts := &HTTPOptionalParams{Headers: headers, Body: data} - current_time := generateTimeSeconds() + current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return bodyErr @@ -269,9 +270,38 @@ func (eduvpn *VPNState) FinishOAuth() error { return oauth.getTokensWithCallback() } +func (state *VPNState) LoginOAuth() error { + authURL, authInitializeErr := state.InitializeOAuth() + + if authInitializeErr != nil { + return authInitializeErr + } + + go state.StateCallback("Registered", "OAuthInitialized", authURL) + oauthErr := state.FinishOAuth() + + if oauthErr != nil { + return oauthErr + } + + state.StateCallback("OAuthInitialized", "OAuthFinished", "finished oauth") + state.WriteConfig() + return nil +} + +func (oauth *OAuth) Login() error { + // FIXME: Find a better way + state := GetVPNState() + return state.LoginOAuth() +} + func (oauth *OAuth) EnsureTokens() error { if oauth.isTokensExpired() { - return oauth.getTokensWithRefresh() + err := oauth.getTokensWithRefresh() + if err != nil { + // log that we're getting tokens using login + return oauth.Login() + } } return nil } diff --git a/src/server_test.go b/src/server_test.go index 3492e01..618c3b6 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -1,13 +1,13 @@ package eduvpn import ( + "crypto/tls" "errors" "fmt" - "testing" "net/http" - "crypto/tls" "os/exec" "strings" + "testing" ) func runCommand(t *testing.T, errBuffer *strings.Builder, name string, args ...string) error { @@ -22,12 +22,11 @@ func runCommand(t *testing.T, errBuffer *strings.Builder, name string, args ...s return cmd.Wait() } -func LoginOAuthSelenium(t* testing.T, url string) { +func LoginOAuthSelenium(t *testing.T, url string) { // We could use the go selenium library // But it does not support the latest selenium v4 just yet var errBuffer strings.Builder err := runCommand(t, &errBuffer, "python3", "../selenium_eduvpn.py", url) - if err != nil { t.Errorf("Login OAuth with selenium script failed with error %v and stderr %s", err, errBuffer.String()) } @@ -45,9 +44,10 @@ func Test_server(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) { + state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) { StateCallback(t, old, new, data) }) + _, configErr := state.Connect("https://eduvpnserver") if configErr != nil { @@ -55,7 +55,7 @@ func Test_server(t *testing.T) { } } -func test_connect_oauth_parameter(t* testing.T, parameters URLParameters, expectedErr interface{}) { +func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expectedErr interface{}) { state := &VPNState{} // Do not verify because during testing, the cert is self-signed @@ -78,23 +78,87 @@ func test_connect_oauth_parameter(t* testing.T, parameters URLParameters, expect } } -func Test_connect_oauth_parameters(t* testing.T) { - +func Test_connect_oauth_parameters(t *testing.T) { var ( - failedCallbackParameterError *OAuthFailedCallbackParameterError + failedCallbackParameterError *OAuthFailedCallbackParameterError failedCallbackStateMatchError *OAuthFailedCallbackStateMatchError ) tests := []struct { expectedErr interface{} - parameters URLParameters + parameters URLParameters }{ {&failedCallbackParameterError, URLParameters{}}, {&failedCallbackParameterError, URLParameters{"code": "42"}}, - {&failedCallbackStateMatchError, URLParameters{"code": "42", "state": "21",}}, + {&failedCallbackStateMatchError, URLParameters{"code": "42", "state": "21"}}, } for _, test := range tests { test_connect_oauth_parameter(t, test.parameters, test.expectedErr) } } + +func Test_token_refresh(t *testing.T) { + state := GetVPNState() + + // Do not verify because during testing, the cert is self-signed + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + state.Register("org.eduvpn.app.linux", "configsrefresh", func(old string, new string, data string) { + StateCallback(t, old, new, data) + }) + + // Fake expiry + state.Server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds() + accessToken := state.Server.OAuth.Token.Access + refreshToken := state.Server.OAuth.Token.Refresh + + _, configErr := state.Connect("https://eduvpnserver") + + if configErr != nil { + t.Errorf("Connect error: %v", configErr) + } + + // Check if tokens have changed + accessTokenAfter := state.Server.OAuth.Token.Access + refreshTokenAfter := state.Server.OAuth.Token.Refresh + + if accessToken == accessTokenAfter { + t.Errorf("Access token is the same after refresh") + } + + if refreshToken == refreshTokenAfter { + t.Errorf("Refresh token is the same after refresh") + } +} + +func Test_token_invalid(t *testing.T) { + state := GetVPNState() + + // Do not verify because during testing, the cert is self-signed + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + state.Register("org.eduvpn.app.linux", "configsinvalid", func(old string, new string, data string) { + StateCallback(t, old, new, data) + }) + + dummy_value := "37" + + // Override tokens with invalid values + state.Server.OAuth.Token.Access = dummy_value + state.Server.OAuth.Token.Refresh = dummy_value + + _, configErr := state.Connect("https://eduvpnserver") + + if configErr != nil { + t.Errorf("Connect error: %v", configErr) + } + + if state.Server.OAuth.Token.Access == dummy_value { + t.Errorf("Access token is equal to dummy value: %s", dummy_value) + } + + if state.Server.OAuth.Token.Refresh == dummy_value { + t.Errorf("Refresh token is equal to dummy value: %s", dummy_value) + } +} diff --git a/src/state.go b/src/state.go index fcc5930..12ad57a 100644 --- a/src/state.go +++ b/src/state.go @@ -39,21 +39,11 @@ func (state *VPNState) Connect(url string) (string, error) { } if !state.Server.IsAuthenticated() { - authURL, authInitializeErr := state.InitializeOAuth() + loginErr := state.LoginOAuth() - if authInitializeErr != nil { - return "", authInitializeErr + if loginErr != nil { + return "", loginErr } - - go state.StateCallback("Registered", "OAuthInitialized", authURL) - oauthErr := state.FinishOAuth() - - if oauthErr != nil { - return "", oauthErr - } - - state.StateCallback("OAuthInitialized", "OAuthFinished", "finished oauth") - state.WriteConfig() } return state.Server.GetConfig() |
