summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJeroen Wijenbergh <jeroenwijenbergh@protonmail.com>2022-03-28 23:29:43 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-03-28 23:29:43 +0200
commit6192f9ab54a805c1fabe6a2c5b8eca622b565082 (patch)
treef889404ebca573c8ecc886ea1858dc6822158f6e /src
parent785e34a4ebacee7dea16af6d16725647b7f6fd7d (diff)
OAuth: Token refresh changes and tests
Diffstat (limited to 'src')
-rw-r--r--src/api.go24
-rw-r--r--src/http.go6
-rw-r--r--src/oauth.go40
-rw-r--r--src/server_test.go86
-rw-r--r--src/state.go16
5 files changed, 136 insertions, 36 deletions
diff --git a/src/api.go b/src/api.go
index 5a6ba7d..08e3819 100644
--- a/src/api.go
+++ b/src/api.go
@@ -2,13 +2,14 @@ package eduvpn
import (
"encoding/json"
+ "errors"
"fmt"
"net/http"
"net/url"
)
// Authenticated wrappers on top of HTTP
-func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
+func (server *Server) apiAuthenticated(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
opts = &HTTPOptionalParams{}
@@ -29,15 +30,26 @@ func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, o
} else {
opts.Headers = &http.Header{headerKey: {headerValue}}
}
- header, body, bodyErr := HTTPMethodWithOpts(method, url, opts)
+ return HTTPMethodWithOpts(method, url, opts)
+}
+
+func (server *Server) apiAuthenticatedRetry(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
+ header, body, bodyErr := server.apiAuthenticated(method, endpoint, opts)
if bodyErr != nil {
+ var error *HTTPStatusError
+
+ if errors.As(bodyErr, &error) {
+ // Tell the method that the token is expired
+ server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds()
+ return server.apiAuthenticated(method, endpoint, opts)
+ }
return header, nil, bodyErr
}
- return header, body, nil
+ return header, body, bodyErr
}
func (server *Server) APIInfo() error {
- _, body, bodyErr := server.apiAuthenticatedWithOpts(http.MethodGet, "/info", nil)
+ _, body, bodyErr := server.apiAuthenticatedRetry(http.MethodGet, "/info", nil)
if bodyErr != nil {
return bodyErr
}
@@ -65,7 +77,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
"profile_id": {"default"},
"public_key": {pubkey},
}
- header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
+ header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
return "", "", connectErr
}
@@ -83,7 +95,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
urlForm := url.Values{
"profile_id": {"default"},
}
- header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
+ header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
return "", "", connectErr
}
diff --git a/src/http.go b/src/http.go
index bc2cc20..bd06342 100644
--- a/src/http.go
+++ b/src/http.go
@@ -160,6 +160,10 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
return resp.Header, nil, &HTTPReadError{URL: url, Err: readErr}
}
- // Return the body in bytes and signal that there was no error
+ if resp.StatusCode != 200 {
+ return resp.Header, body, &HTTPStatusError{URL: url, Status: resp.StatusCode}
+ }
+
+ // Return the body in bytes and signal the status error if there was one
return resp.Header, body, nil
}
diff --git a/src/oauth.go b/src/oauth.go
index 92c64d4..dd86279 100644
--- a/src/oauth.go
+++ b/src/oauth.go
@@ -73,7 +73,7 @@ type OAuthExchangeSession struct {
Server *http.Server
}
-func generateTimeSeconds() int64 {
+func GenerateTimeSeconds() int64 {
current := time.Now()
return current.Unix()
}
@@ -122,13 +122,14 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
"content-type": {"application/x-www-form-urlencoded"},
}
opts := &HTTPOptionalParams{Headers: headers, Body: data}
- current_time := generateTimeSeconds()
+ current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
return bodyErr
}
tokenStructure := &OAuthToken{}
+
jsonErr := json.Unmarshal(body, tokenStructure)
if jsonErr != nil {
@@ -143,7 +144,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
func (oauth *OAuth) isTokensExpired() bool {
expired_time := oauth.Token.ExpiredTimestamp
- current_time := generateTimeSeconds()
+ current_time := GenerateTimeSeconds()
return current_time >= expired_time
}
@@ -160,7 +161,7 @@ func (oauth *OAuth) getTokensWithRefresh() error {
"content-type": {"application/x-www-form-urlencoded"},
}
opts := &HTTPOptionalParams{Headers: headers, Body: data}
- current_time := generateTimeSeconds()
+ current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
return bodyErr
@@ -269,9 +270,38 @@ func (eduvpn *VPNState) FinishOAuth() error {
return oauth.getTokensWithCallback()
}
+func (state *VPNState) LoginOAuth() error {
+ authURL, authInitializeErr := state.InitializeOAuth()
+
+ if authInitializeErr != nil {
+ return authInitializeErr
+ }
+
+ go state.StateCallback("Registered", "OAuthInitialized", authURL)
+ oauthErr := state.FinishOAuth()
+
+ if oauthErr != nil {
+ return oauthErr
+ }
+
+ state.StateCallback("OAuthInitialized", "OAuthFinished", "finished oauth")
+ state.WriteConfig()
+ return nil
+}
+
+func (oauth *OAuth) Login() error {
+ // FIXME: Find a better way
+ state := GetVPNState()
+ return state.LoginOAuth()
+}
+
func (oauth *OAuth) EnsureTokens() error {
if oauth.isTokensExpired() {
- return oauth.getTokensWithRefresh()
+ err := oauth.getTokensWithRefresh()
+ if err != nil {
+ // log that we're getting tokens using login
+ return oauth.Login()
+ }
}
return nil
}
diff --git a/src/server_test.go b/src/server_test.go
index 3492e01..618c3b6 100644
--- a/src/server_test.go
+++ b/src/server_test.go
@@ -1,13 +1,13 @@
package eduvpn
import (
+ "crypto/tls"
"errors"
"fmt"
- "testing"
"net/http"
- "crypto/tls"
"os/exec"
"strings"
+ "testing"
)
func runCommand(t *testing.T, errBuffer *strings.Builder, name string, args ...string) error {
@@ -22,12 +22,11 @@ func runCommand(t *testing.T, errBuffer *strings.Builder, name string, args ...s
return cmd.Wait()
}
-func LoginOAuthSelenium(t* testing.T, url string) {
+func LoginOAuthSelenium(t *testing.T, url string) {
// We could use the go selenium library
// But it does not support the latest selenium v4 just yet
var errBuffer strings.Builder
err := runCommand(t, &errBuffer, "python3", "../selenium_eduvpn.py", url)
-
if err != nil {
t.Errorf("Login OAuth with selenium script failed with error %v and stderr %s", err, errBuffer.String())
}
@@ -45,9 +44,10 @@ func Test_server(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) {
+ state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) {
StateCallback(t, old, new, data)
})
+
_, configErr := state.Connect("https://eduvpnserver")
if configErr != nil {
@@ -55,7 +55,7 @@ func Test_server(t *testing.T) {
}
}
-func test_connect_oauth_parameter(t* testing.T, parameters URLParameters, expectedErr interface{}) {
+func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expectedErr interface{}) {
state := &VPNState{}
// Do not verify because during testing, the cert is self-signed
@@ -78,23 +78,87 @@ func test_connect_oauth_parameter(t* testing.T, parameters URLParameters, expect
}
}
-func Test_connect_oauth_parameters(t* testing.T) {
-
+func Test_connect_oauth_parameters(t *testing.T) {
var (
- failedCallbackParameterError *OAuthFailedCallbackParameterError
+ failedCallbackParameterError *OAuthFailedCallbackParameterError
failedCallbackStateMatchError *OAuthFailedCallbackStateMatchError
)
tests := []struct {
expectedErr interface{}
- parameters URLParameters
+ parameters URLParameters
}{
{&failedCallbackParameterError, URLParameters{}},
{&failedCallbackParameterError, URLParameters{"code": "42"}},
- {&failedCallbackStateMatchError, URLParameters{"code": "42", "state": "21",}},
+ {&failedCallbackStateMatchError, URLParameters{"code": "42", "state": "21"}},
}
for _, test := range tests {
test_connect_oauth_parameter(t, test.parameters, test.expectedErr)
}
}
+
+func Test_token_refresh(t *testing.T) {
+ state := GetVPNState()
+
+ // Do not verify because during testing, the cert is self-signed
+ http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+
+ state.Register("org.eduvpn.app.linux", "configsrefresh", func(old string, new string, data string) {
+ StateCallback(t, old, new, data)
+ })
+
+ // Fake expiry
+ state.Server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds()
+ accessToken := state.Server.OAuth.Token.Access
+ refreshToken := state.Server.OAuth.Token.Refresh
+
+ _, configErr := state.Connect("https://eduvpnserver")
+
+ if configErr != nil {
+ t.Errorf("Connect error: %v", configErr)
+ }
+
+ // Check if tokens have changed
+ accessTokenAfter := state.Server.OAuth.Token.Access
+ refreshTokenAfter := state.Server.OAuth.Token.Refresh
+
+ if accessToken == accessTokenAfter {
+ t.Errorf("Access token is the same after refresh")
+ }
+
+ if refreshToken == refreshTokenAfter {
+ t.Errorf("Refresh token is the same after refresh")
+ }
+}
+
+func Test_token_invalid(t *testing.T) {
+ state := GetVPNState()
+
+ // Do not verify because during testing, the cert is self-signed
+ http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+
+ state.Register("org.eduvpn.app.linux", "configsinvalid", func(old string, new string, data string) {
+ StateCallback(t, old, new, data)
+ })
+
+ dummy_value := "37"
+
+ // Override tokens with invalid values
+ state.Server.OAuth.Token.Access = dummy_value
+ state.Server.OAuth.Token.Refresh = dummy_value
+
+ _, configErr := state.Connect("https://eduvpnserver")
+
+ if configErr != nil {
+ t.Errorf("Connect error: %v", configErr)
+ }
+
+ if state.Server.OAuth.Token.Access == dummy_value {
+ t.Errorf("Access token is equal to dummy value: %s", dummy_value)
+ }
+
+ if state.Server.OAuth.Token.Refresh == dummy_value {
+ t.Errorf("Refresh token is equal to dummy value: %s", dummy_value)
+ }
+}
diff --git a/src/state.go b/src/state.go
index fcc5930..12ad57a 100644
--- a/src/state.go
+++ b/src/state.go
@@ -39,21 +39,11 @@ func (state *VPNState) Connect(url string) (string, error) {
}
if !state.Server.IsAuthenticated() {
- authURL, authInitializeErr := state.InitializeOAuth()
+ loginErr := state.LoginOAuth()
- if authInitializeErr != nil {
- return "", authInitializeErr
+ if loginErr != nil {
+ return "", loginErr
}
-
- go state.StateCallback("Registered", "OAuthInitialized", authURL)
- oauthErr := state.FinishOAuth()
-
- if oauthErr != nil {
- return "", oauthErr
- }
-
- state.StateCallback("OAuthInitialized", "OAuthFinished", "finished oauth")
- state.WriteConfig()
}
return state.Server.GetConfig()