diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-02-06 14:44:18 +0100 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-02-19 14:15:07 +0100 |
| commit | 4d26c8489b09acc98128715e9a2ed67558eb8105 (patch) | |
| tree | 0ed8f4c95c12e501bc1a78c646c707ed6618936b | |
| parent | 3fd29f3e1c963196cac69fcbb9d68116f7ea80ec (diff) | |
Util + OAuth: Delete internal OAuth implementation
Preparing to move to github.com/jwijenbergh/eduoauth-go
| -rw-r--r-- | internal/oauth/oauth.go | 551 | ||||
| -rw-r--r-- | internal/oauth/oauth_test.go | 222 | ||||
| -rw-r--r-- | internal/oauth/token.go | 162 | ||||
| -rw-r--r-- | internal/util/util.go | 11 | ||||
| -rw-r--r-- | internal/util/util_test.go | 24 |
5 files changed, 1 insertions, 969 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index d7da299..0000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,551 +0,0 @@ -// Package oauth implement an oauth client defined in e.g. rfc 6749 -// However, we try to follow some recommendations from the v2.1 oauth draft RFC -// Some specific things we implement here: -// - PKCE (RFC 7636) -// - ISS (RFC 9207) -package oauth - -import ( - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "html/template" - "net" - "net/http" - "net/url" - "sync" - "time" - - httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/util" - "github.com/go-errors/errors" -) - -// genState generates a random base64 string to be used for state -// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 -// "state": OPTIONAL. An opaque value used by the client to maintain -// state between the request and callback. The authorization server -// includes this value when redirecting the user agent back to the -// client. -// We implement it similarly to the verifier. -func genState() (string, error) { - bs, err := util.MakeRandomByteSlice(32) - if err != nil { - return "", err - } - - // For consistency, we also use raw url encoding here - return base64.RawURLEncoding.EncodeToString(bs), nil -} - -// genChallengeS256 generates a sha256 base64 challenge from a verifier -// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8 -func genChallengeS256(verifier string) string { - hash := sha256.Sum256([]byte(verifier)) - - // We use raw url encoding as the challenge does not accept padding - return base64.RawURLEncoding.EncodeToString(hash[:]) -} - -// genVerifier generates a verifier -// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 -// The code_verifier is a unique high-entropy cryptographically random -// string generated for each authorization request, using the unreserved -// characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", with a -// minimum length of 43 characters and a maximum length of 128 -// characters. -// We implement it according to the note: -// -// NOTE: The code verifier SHOULD have enough entropy to make it -// impractical to guess the value. It is RECOMMENDED that the output of -// a suitable random number generator be used to create a 32-octet -// sequence. The octet sequence is then base64url-encoded to produce a -// 43-octet URL safe string to use as the code verifier. -// -// See: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 -func genVerifier() (string, error) { - random, err := util.MakeRandomByteSlice(32) - if err != nil { - return "", err - } - - return base64.RawURLEncoding.EncodeToString(random), nil -} - -// OAuth defines the main structure for this package. -type OAuth struct { - // The cached client id so we don't have to pass it around - ClientID string `json:"client_id"` - - // The HTTP client that is used - httpClient *httpw.Client - - // ISS indicates the issuer identifier of the authorization server as defined in RFC 9207 - ISS string `json:"iss"` - - // BaseAuthorizationURL is the URL where authorization should take place - BaseAuthorizationURL string `json:"base_authorization_url"` - - // TokenURL is the URL where tokens should be obtained - TokenURL string `json:"token_url"` - - // session is the internal in progress OAuth session - session exchangeSession - - // Token is where the access and refresh tokens are stored along with the timestamps - // It is protected by a lock - token *tokenLock -} - -// exchangeSession is a structure that gets passed to the callback for easy access to the current state. -type exchangeSession struct { - // ISS indicates the issuer identifier - ISS string - - // State is the expected URL state parameter - State string - - // Verifier is the preimage of the challenge - Verifier string - - // RedirectURI is the passed redirect URI - RedirectURI string - - // Listener is the listener where the servers 'listens' on - Listener net.Listener - - // ErrChan is used to send the error from the handler - ErrChan chan error -} - -// AccessToken gets the OAuth access token used for contacting the server API -// It returns the access token as a string, possibly obtained fresh using the Refresh Token -// If the token cannot be obtained, an error is returned and the token is an empty string. -func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) { - tl := oauth.token - if tl == nil { - return "", errors.New("No token structure available") - } - return tl.Access(ctx) -} - -// setupListener sets up an OAuth listener -// If it was unsuccessful it returns an error. -// @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2 -// "Loopback Interface Redirection". -func (oauth *OAuth) setupListener() (net.Listener, error) { - // create a listener - lst, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, errors.WrapPrefix(err, "net.Listen failed", 0) - } - return lst, nil -} - -// tokensWithCallback gets the OAuth tokens using a local web server -// If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithCallback(ctx context.Context) error { - if oauth.session.Listener == nil { - return errors.New("failed getting tokens with callback: no listener") - } - mux := http.NewServeMux() - // server /callback over the listener address - s := &http.Server{ - Handler: mux, - // Define a default 60 second header read timeout to protect against a Slowloris Attack - // A bit overkill maybe for a local server but good to define anyways - ReadHeaderTimeout: 60 * time.Second, - } - defer s.Shutdown(ctx) //nolint:errcheck - - // Use a sync.Once to only handle one request up until we shutdown the server - var once sync.Once - mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { - once.Do(func() { - oauth.Handler(w, r) - }) - }) - - go func() { - if err := s.Serve(oauth.session.Listener); err != http.ErrServerClosed { - oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0) - } - }() - select { - case err := <-oauth.session.ErrChan: - return err - case <-ctx.Done(): - return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0) - } -} - -// tokenResponse fills the OAuth token response structure by the response -// The URL that is input here is used for additional context -// It returns this structure and an error if there is one -func (oauth *OAuth) tokenResponse(response []byte, url string) (*TokenResponse, error) { - if oauth.token == nil { - return nil, errors.New("No oauth structure when filling token") - } - res := TokenResponse{} - - err := json.Unmarshal(response, &res) - if err != nil { - return nil, errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0) - } - - return &res, nil -} - -// SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time. -func (oauth *OAuth) SetTokenExpired() { - if oauth.token != nil { - oauth.token.SetExpired() - } -} - -// SetTokenRenew sets the tokens for renewal by completely clearing the structure. -func (oauth *OAuth) SetTokenRenew() { - if oauth.token != nil { - oauth.token.Update(Token{}) - } -} - -func (oauth *OAuth) Token() Token { - t := Token{} - if oauth.token != nil { - t = oauth.token.Get() - } - - return t -} - -// tokensWithAuthCode gets the access and refresh tokens using the authorization code -// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 -// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 -// If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error { - // Make sure the verifier is set as the parameter - // so that the server can verify that we are the actual owner of the authorization code - u := oauth.TokenURL - - data := url.Values{ - "client_id": {oauth.ClientID}, - "code": {authCode}, - "code_verifier": {oauth.session.Verifier}, - "grant_type": {"authorization_code"}, - "redirect_uri": {oauth.session.RedirectURI}, - } - h := http.Header{ - "content-type": {"application/x-www-form-urlencoded"}, - } - opts := &httpw.OptionalParams{Headers: h, Body: data} - now := time.Now() - - // We are sure that we have a http client because we have initialized it when starting the exchange - _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) - if err != nil { - return err - } - - tr, err := oauth.tokenResponse(body, u) - if err != nil { - return err - } - if tr == nil { - return errors.New("No token response after authorization code") - } - - oauth.token.UpdateResponse(*tr, now) - return nil -} - -func (oauth *OAuth) UpdateTokens(t Token) { - if oauth.token == nil { - oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}} - } - oauth.token.Update(t) -} - -// refreshResponse gets the refresh token response with a refresh token -// This response contains the access and refresh tokens, together with a timestamp -// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 -// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 -// If it was unsuccessful it returns an error. -func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) { - u := oauth.TokenURL - if oauth.token == nil { - return nil, time.Time{}, errors.New("No oauth token structure in refresh") - } - if oauth.ClientID == "" { - return nil, time.Time{}, errors.New("No client ID was cached for refresh") - } - data := url.Values{ - "client_id": {oauth.ClientID}, - "refresh_token": {r}, - "grant_type": {"refresh_token"}, - } - h := http.Header{ - "content-type": {"application/x-www-form-urlencoded"}, - } - opts := &httpw.OptionalParams{Headers: h, Body: data} - now := time.Now() - - // Test if we have a http client and if not recreate one - if oauth.httpClient == nil { - oauth.httpClient = httpw.NewClient() - } - - _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) - if err != nil { - return nil, time.Time{}, err - } - - tr, err := oauth.tokenResponse(body, u) - return tr, now, err -} - -// responseTemplate is the HTML template for the OAuth authorized response -// this template was dapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25 -const responseTemplate string = ` -<!DOCTYPE html> -<html dir="ltr" xmlns="http://www.w3.org/1999/xhtml" lang="en"><head> -<meta http-equiv="content-type" content="text/html; charset=UTF-8"> -<meta charset="utf-8"> -<title>{{.Title}}</title> -<style> -body { - font-family: arial; - margin: 0; - height: 100vh; - display: flex; - align-items: center; - justify-content: center; - background: #ccc; - color: #252622; -} -main { - padding: 1em 2em; - text-align: center; - border: 2pt solid #666; - box-shadow: rgba(0, 0, 0, 0.6) 0px 1px 4px; - border-color: #aaa; - background: #ddd; -} -</style> -</head> -<body> - <main> - <h1>{{.Title}}</h1> - <p>{{.Message}}</p> - </main> -</body> -</html> -` - -// oauthResponseHTML is a structure that is used to give back the OAuth response. -type oauthResponseHTML struct { - Title string - Message string -} - -// writeResponseHTML writes the OAuth response using a response writer and the title + message -// If it was unsuccessful it returns an error. -func writeResponseHTML(w http.ResponseWriter, title string, message string) error { - t, err := template.New("oauth-response").Parse(responseTemplate) - if err != nil { - return errors.WrapPrefix(err, "failed writing response HTML", 0) - } - - return t.Execute(w, oauthResponseHTML{Title: title, Message: message}) -} - -// Authcode gets the authorization code from the url -// It returns the code and an error if there is one -func (s *exchangeSession) Authcode(url *url.URL) (string, error) { - // ISS: https://www.rfc-editor.org/rfc/rfc9207.html - q := url.Query() - - // first check ISS - iss := q.Get("iss") - if s.ISS != "" && s.ISS != iss { - return "", errors.Errorf("failed matching ISS; expected '%s' got '%s'", s.ISS, iss) - } - // Make sure the state is present and matches to protect against cross-site request forgeries - // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 - state := q.Get("state") - if state == "" { - return "", errors.Errorf("failed retrieving parameter 'state' from '%s'", url) - } - // The state is the first entry - if state != s.State { - return "", errors.Errorf("failed matching state; expected '%s' got '%s'", s.State, state) - } - - // check if an error is present - // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-09#name-authorization-response (error response) - errc := q.Get("error") - if errc != "" { - // these are optional but let's include them - errdesc := q.Get("error_description") - erruri := q.Get("error_uri") - return "", errors.Errorf("failed obtaining oauthorization code, error code '%s', error description '%s', error uri '%s'", errc, errdesc, erruri) - } - - // No authorization code - code := q.Get("code") - if code == "" { - return "", errors.Errorf("failed retrieving parameter 'code' from '%s'", url) - } - - return code, nil -} - -// tokenHandler gets the tokens using the authorization code that is obtained through the url -// This function is called by the http handler and returns an error if the tokens cannot be obtained -func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error { - // Get the authorization code - c, err := oauth.session.Authcode(url) - if err != nil { - return err - } - // Now that we have obtained the authorization code, we can move to the next step: - // Obtaining the access and refresh tokens - return oauth.tokensWithAuthCode(ctx, c) -} - -// Handler is the function used to get the OAuth tokens using an authorization code callback -// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 -// It sends an error to the session channel (can be nil) -func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) { - // TODO: should this be something else than context background? - err := oauth.tokenHandler(context.Background(), req.URL) - if err != nil { - _ = writeResponseHTML( - w, - "Authorization Failed", - "The authorization has failed. See the log file for more information.", - ) - } else { - _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") - } - oauth.session.ErrChan <- err -} - -// Init initializes OAuth with the following parameters: -// - OAuth server issuer identification -// - The URL used for authorization -// - The URL to obtain new tokens. -func (oauth *OAuth) Init(clientID string, iss string, baseAuthorizationURL string, tokenURL string) { - oauth.ClientID = clientID - oauth.ISS = iss - oauth.BaseAuthorizationURL = baseAuthorizationURL - oauth.TokenURL = tokenURL -} - -// AuthURL gets the authorization url to start the OAuth procedure. -func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string, cr string) (string, error) { - // Update the client ID - oauth.ClientID = name - - // Generate the verifier and challenge - v, err := genVerifier() - if err != nil { - return "", errors.WrapPrefix(err, "genVerifier error", 0) - } - - // Generate the state - state, err := genState() - if err != nil { - return "", errors.WrapPrefix(err, "genState error", 0) - } - - // Re-initialize the token structure - oauth.UpdateTokens(Token{}) - - // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - red := cr - - // no custom redirect URI defined, we setup our own - var l net.Listener - if cr == "" { - // set up the listener to get the redirect URI - l, err = oauth.setupListener() - if err != nil { - return "", errors.WrapPrefix(err, "oauth.setupListener error", 0) - } - port := l.Addr().(*net.TCPAddr).Port - // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php - red = fmt.Sprintf("http://127.0.0.1:%d/callback", port) - } - oauth.session = exchangeSession{ - ISS: oauth.ISS, - State: state, - Verifier: v, - ErrChan: make(chan error), - RedirectURI: red, - Listener: l, - } - - params := map[string]string{ - "client_id": name, - "code_challenge_method": "S256", - "code_challenge": genChallengeS256(v), - "response_type": "code", - "scope": "config", - "state": state, - "redirect_uri": red, - } - - p, err := url.Parse(oauth.BaseAuthorizationURL) - if err != nil { - return "", errors.WrapPrefix(err, fmt.Sprintf("failed to parse OAuth base URL '%s'", oauth.BaseAuthorizationURL), 0) - } - // Make sure the scheme is HTTPS - p.Scheme = "https" - - u, err := httpw.ConstructURL(p, params) - if err != nil { - return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0) - } - - // Return the url processed - return postProcessAuth(u), nil -} - -func (oauth *OAuth) tokensWithURI(ctx context.Context, uri string) error { - // parse URI - p, err := url.Parse(uri) - if err != nil { - return err - } - return oauth.tokenHandler(ctx, p) -} - -// Exchange starts the OAuth exchange by getting the tokens with the redirect callback -// If it was unsuccessful it returns an error. -func (oauth *OAuth) Exchange(ctx context.Context, uri string) error { - // If there is no HTTP client defined, create a new one - if oauth.httpClient == nil { - oauth.httpClient = httpw.NewClient() - } - if uri != "" { - return oauth.tokensWithURI(ctx, uri) - } - return oauth.tokensWithCallback(ctx) -} - -type CancelledCallbackError struct{} - -func (e *CancelledCallbackError) Error() string { - return "client cancelled OAuth" -} - -type TokensInvalidError struct { - Cause string -} - -func (e *TokensInvalidError) Error() string { - return fmt.Sprintf("tokens are invalid due to: %s", e.Cause) -} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go deleted file mode 100644 index 1181b5d..0000000 --- a/internal/oauth/oauth_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package oauth - -import ( - "context" - "encoding/json" - "net/url" - "strings" - "testing" - "time" -) - -func Test_verifiergen(t *testing.T) { - v, err := genVerifier() - if err != nil { - t.Fatalf("Gen verifier error: %v", err) - } - - // Verifier must be at minimum 43 and at max 128 characters... - // However... Our verifier is exactly 43! - if len(v) != 43 { - t.Fatalf( - "Got verifier length: %d, want a verifier with at least 43 characters", - len(v), - ) - } - - _, err = url.QueryUnescape(v) - if err != nil { - t.Fatalf("Verifier: %s can not be unescaped", v) - } -} - -func Test_stategen(t *testing.T) { - s1, err := genState() - if err != nil { - t.Fatalf("Error when generating state 1: %v", err) - } - - s2, err := genState() - if err != nil { - t.Fatalf("Error when generating state 2: %v", err) - } - - if s1 == s2 { - t.Fatalf("State: %v, equal to: %v", s1, s2) - } -} - -func Test_challengergen(t *testing.T) { - verifier := "test" - // Calculated using: base64.urlsafe_b64encode(hashlib.sha256("test".encode("utf-8")).digest()).decode("utf-8").replace("=", "") in Python - // This test might not be the best because we're now comparing two different implementations, but at least it gives us a way to see if we messed something up in a commit - want := "n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg" - got := genChallengeS256(verifier) - - if got != want { - t.Fatalf("Challenger not equal, got: %v, want: %v", got, want) - } -} - -func Test_accessToken(t *testing.T) { - o := OAuth{} - _, err := o.AccessToken(context.Background()) - if err == nil { - t.Fatalf("No error when getting access token on empty structure") - } - - // Here we should get no error because the access token is set and is not expired - want := "test" - expired := time.Now().Add(1 * time.Hour) - o = OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: want, ExpiredTimestamp: expired}}}} - got, err := o.AccessToken(context.Background()) - if err != nil { - t.Fatalf("Got error when getting access token on non-empty structure: %v", err) - } - if got != want { - t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want) - } - - // Set the tokens as expired - o.SetTokenExpired() - - // We should get an error because expired and no refresh token - _, err = o.AccessToken(context.Background()) - if err == nil { - t.Fatal("Got no error when getting access token on non-empty structure and expired") - } - - want = "test2" - // Now we internally update the refresh function and refresh token, we should get new tokens - refresh := "refresh" - o.token.t.Refresh = refresh - o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) { - if refreshToken != refresh { - t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh) - } - // Only the access and refresh fields are really important - r := &TokenResponse{Access: want, Refresh: "test2"} - return r, expired, nil - } - - got, err = o.AccessToken(context.Background()) - if err != nil { - t.Fatalf("Got error when getting access token on non-empty expired structure and with a 'valid' refresh token: %v", err) - } - if got != want { - t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want) - } - - - // Set the tokens as expired - o.SetTokenExpired() - want = "test3" - - // Now let's act like a 2.x server, we give no refresh token back. When we refresh the previous refresh token should be gotten - o.token.t.Refresh = refresh - prevRefresh := refresh - o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) { - if refreshToken != refresh { - t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh) - } - // Only the access token is returned now - r := &TokenResponse{Access: want} - return r, expired, nil - } - - got, err = o.AccessToken(context.Background()) - if err != nil { - t.Fatalf("Got error when getting access token on non-empty expired structure and with an empty refresh response: %v", err) - } - if got != want { - t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want) - } - if o.token.t.Refresh == "" { - t.Fatalf("Refresh token is empty after refreshing and getting back an empty refresh") - } - if o.token.t.Refresh != prevRefresh { - t.Fatalf("Refresh token is not equal to previous refresh token after refreshing and getting back an empty refresh token, got: %v, want: %v", o.token.t.Refresh, prevRefresh) - } -} - -func Test_secretJSON(t *testing.T) { - // Access and refresh tokens should not be present in marshalled JSON - a := "ineedtobesecret_access" - r := "ineedtobesecret_refresh" - o := OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: a, Refresh: r}}}} - b, err := json.Marshal(o) - if err != nil { - t.Fatalf("Error when marshalling OAuth JSON: %v", err) - } - s := string(b) - // Of course this is a very dumb check, it could be that we are writing in some other serialized format. However, we simply marshal the structure directly. Go just serializes this as a simple string - if strings.Contains(s, a) { - t.Fatalf("Serialized OAuth contains Access Token! Serialized: %v, Access Token: %v", s, a) - } - - if strings.Contains(s, r) { - t.Fatalf("Serialized OAuth contains Refresh Token! Serialized: %v, Refresh Token: %v", s, a) - } -} - -func Test_AuthURL(t *testing.T) { - iss := "local" - auth := "https://127.0.0.1/auth" - token := "https://127.0.0.1/token" - id := "client_id" - o := OAuth{ISS: iss, BaseAuthorizationURL: auth, TokenURL: token} - s, err := o.AuthURL(id, func(s string) string { - // We do nothing here are this function is for skipping WAYF - return s - }, "") - if err != nil { - t.Fatalf("Error in getting OAuth URL: %v", err) - } - - // Check if the OAuth session has valid values - if o.ClientID != id { - t.Fatalf("OAuth ClientID not equal, want: %v, got: %v", o.ClientID, id) - } - if o.session.ISS != iss { - t.Fatalf("OAuth ISS not equal, want: %v, got: %v", o.session.ISS, iss) - } - if o.session.State == "" { - t.Fatal("No OAuth session state paremeter found") - } - if o.session.Verifier == "" { - t.Fatal("No OAuth session state paremeter found") - } - if o.session.ErrChan == nil { - t.Fatal("No OAuth session error channel found") - } - - u, err := url.Parse(s) - if err != nil { - t.Fatalf("Returned Auth URL cannot be parsed with error: %v", err) - } - - c := []struct { - query string - want string - }{ - {query: "client_id", want: id}, - {query: "code_challenge_method", want: "S256"}, - {query: "response_type", want: "code"}, - {query: "scope", want: "config"}, - {query: "redirect_uri", want: o.session.RedirectURI}, - } - - q := u.Query() - - // We should have 7 parameters: client_id, challenge method, challenge, response type, scope, state and redirect uri - if len(q) != 7 { - t.Fatalf("Total query parameters is not 7, url: %v, total params: %v", u, len(q)) - } - - for _, v := range c { - p := q.Get(v.query) - if p != v.want { - t.Fatalf("Parameter: %v, not equal, want: %v, got: %v", v.query, v.want, p) - } - } -} diff --git a/internal/oauth/token.go b/internal/oauth/token.go deleted file mode 100644 index 251b689..0000000 --- a/internal/oauth/token.go +++ /dev/null @@ -1,162 +0,0 @@ -package oauth - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/types/server" - "github.com/go-errors/errors" -) - -// TokenResponse defines the OAuth response from the server that includes the tokens. -type TokenResponse struct { - // Access is the access token returned by the server - Access string `json:"access_token"` - - // Refresh token is the refresh token returned by the server - Refresh string `json:"refresh_token"` - - // Type indicates which type of tokens we have - Type string `json:"token_type"` - - // Expires is the expires time returned by the server - Expires int64 `json:"expires_in"` -} - -// The public type that can be passed to an update function -// It contains our access and refresh tokens with a timestamp -type Token struct { - // Access is the Access token returned by the server - Access string - - // Refresh token is the Refresh token returned by the server - Refresh string - - // ExpiredTimestamp is the Expires field but converted to a Go timestamp - ExpiredTimestamp time.Time -} - -func (t *Token) Public() server.Tokens { - return server.Tokens{ - Access: t.Access, - Refresh: t.Refresh, - Expires: t.ExpiredTimestamp.Unix(), - } -} - -// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire. -// Additionally, it contains the refresher to get new tokens -type tokenRefresher struct { - Token - // Refresher is the function that refreshes the token - Refresher func(context.Context, string) (*TokenResponse, time.Time, error) -} - -// tokenLock is a wrapper around token that protects it with a lock -type tokenLock struct { - // Protects t - mu sync.Mutex - - // The token fields protected by the lock - // This token struct contains a refresher - t *tokenRefresher -} - -// Access gets the OAuth access token used for contacting the server API -// It returns the access token as a string, possibly obtained fresh using the refresher -// If the token cannot be obtained, an error is returned and the token is an empty string. -func (l *tokenLock) Access(ctx context.Context) (string, error) { - log.Logger.Debugf("Getting access token") - l.mu.Lock() - defer l.mu.Unlock() - - // The tokens are not expired yet - // So they should be valid, re-login not neede - if !l.expired() { - log.Logger.Debugf("Access token is not expired, returning") - return l.t.Access, nil - } - - // Check if refresh is even possible by doing a simple check if the refresh token is empty - // This is not needed but reduces API calls to the server - if l.t.Refresh == "" { - log.Logger.Debugf("Refresh token is empty, returning error") - return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) - } - - // Otherwise refresh and then later return the access token if we are successful - tr, s, err := l.t.Refresher(ctx, l.t.Refresh) - if err != nil { - log.Logger.Debugf("Got a refresh token error: %v", err) - // We have failed to ensure the tokens due to refresh not working - return "", errors.Wrap( - &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0) - } - if tr == nil { - log.Logger.Debugf("No token response after refreshing") - return "", errors.New("No token response after refreshing") - } - // store the previous refresh token - pr := l.t.Refresh - // get the response as a non-pointer - r := *tr - e := s.Add(time.Second * time.Duration(r.Expires)) - t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} - l.updateInternal(t) - // set the previous refresh token if the new one is empty - // This is for 2.x servers - if l.t.Refresh == "" { - log.Logger.Debugf("The previous refresh token is set as the response had no refresh token") - l.t.Refresh = pr - } - return l.t.Access, nil -} - -// UpdateResponse updates the structure using the server response and locks -func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) { - l.mu.Lock() - e := s.Add(time.Second * time.Duration(r.Expires)) - t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} - l.updateInternal(t) - l.mu.Unlock() -} - -// updateInternal updates the token structure internally but does not lock -func (l *tokenLock) updateInternal(r Token) { - l.t.Access = r.Access - l.t.Refresh = r.Refresh - l.t.ExpiredTimestamp = r.ExpiredTimestamp -} - -// Update updates the token structure using the internal function but locks -func (l *tokenLock) Update(r Token) { - l.mu.Lock() - l.updateInternal(r) - l.mu.Unlock() -} - -// Get gets the tokens into a public struct -func (l *tokenLock) Get() Token { - // TODO: Check nil? - l.mu.Lock() - defer l.mu.Unlock() - return l.t.Token -} - -// SetExpired overrides the timestamp to the current time -// This marks the tokens as expired -func (l *tokenLock) SetExpired() { - l.mu.Lock() - l.t.ExpiredTimestamp = time.Now() - l.mu.Unlock() -} - -// expired checks if the access token is expired. -// This is only called internally and thus does not lock -func (l *tokenLock) expired() bool { - now := time.Now() - return !now.Before(l.t.ExpiredTimestamp) -} diff --git a/internal/util/util.go b/internal/util/util.go index 85c4b37..97b4151 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -2,23 +2,12 @@ package util import ( - "crypto/rand" "fmt" "net/url" "os" "strings" ) -// MakeRandomByteSlice creates a cryptographically random bytes slice of `size` -// It returns the byte slice (or nil if error) and an error if it could not be generated. -func MakeRandomByteSlice(n int) ([]byte, error) { - bs := make([]byte, n) - if _, err := rand.Read(bs); err != nil { - return nil, errors.WrapPrefix(err, "failed reading random", 0) - } - return bs, nil -} - // EnsureDirectory creates a directory with permission 700. func EnsureDirectory(dir string) error { // Create with 700 permissions, read, write, execute only for the owner diff --git a/internal/util/util_test.go b/internal/util/util_test.go index 5e19f57..827fbe1 100644 --- a/internal/util/util_test.go +++ b/internal/util/util_test.go @@ -1,28 +1,6 @@ package util -import ( - "bytes" - "testing" -) - -func TestMakeRandomByteSlice(t *testing.T) { - random, randomErr := MakeRandomByteSlice(32) - if randomErr != nil { - t.Fatalf("Got: %v, want: nil", randomErr) - } - if len(random) != 32 { - t.Fatalf("Got length: %d, want length: 32", len(random)) - } - - random2, randomErr2 := MakeRandomByteSlice(32) - if randomErr2 != nil { - t.Fatalf("2, Got: %v, want: nil", randomErr) - } - - if bytes.Equal(random2, random) { - t.Fatalf("Two random byteslices are the same: %v, %v", random2, random) - } -} +import "testing" func TestReplaceWAYF(t *testing.T) { // We expect url encoding but the spaces to be correctly replace with a + instead of a %20 |
