From 6192f9ab54a805c1fabe6a2c5b8eca622b565082 Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Mon, 28 Mar 2022 23:29:43 +0200 Subject: OAuth: Token refresh changes and tests --- src/api.go | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) (limited to 'src/api.go') diff --git a/src/api.go b/src/api.go index 5a6ba7d..08e3819 100644 --- a/src/api.go +++ b/src/api.go @@ -2,13 +2,14 @@ package eduvpn import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" ) // Authenticated wrappers on top of HTTP -func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { +func (server *Server) apiAuthenticated(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &HTTPOptionalParams{} @@ -29,15 +30,26 @@ func (server *Server) apiAuthenticatedWithOpts(method string, endpoint string, o } else { opts.Headers = &http.Header{headerKey: {headerValue}} } - header, body, bodyErr := HTTPMethodWithOpts(method, url, opts) + return HTTPMethodWithOpts(method, url, opts) +} + +func (server *Server) apiAuthenticatedRetry(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { + header, body, bodyErr := server.apiAuthenticated(method, endpoint, opts) if bodyErr != nil { + var error *HTTPStatusError + + if errors.As(bodyErr, &error) { + // Tell the method that the token is expired + server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds() + return server.apiAuthenticated(method, endpoint, opts) + } return header, nil, bodyErr } - return header, body, nil + return header, body, bodyErr } func (server *Server) APIInfo() error { - _, body, bodyErr := server.apiAuthenticatedWithOpts(http.MethodGet, "/info", nil) + _, body, bodyErr := server.apiAuthenticatedRetry(http.MethodGet, "/info", nil) if bodyErr != nil { return bodyErr } @@ -65,7 +77,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str "profile_id": {"default"}, "public_key": {pubkey}, } - header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", connectErr } @@ -83,7 +95,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro urlForm := url.Values{ "profile_id": {"default"}, } - header, connectBody, connectErr := server.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", connectErr } -- cgit v1.2.3