summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-02-06 14:44:18 +0100
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-02-19 14:15:07 +0100
commit4d26c8489b09acc98128715e9a2ed67558eb8105 (patch)
tree0ed8f4c95c12e501bc1a78c646c707ed6618936b
parent3fd29f3e1c963196cac69fcbb9d68116f7ea80ec (diff)
Util + OAuth: Delete internal OAuth implementation
Preparing to move to github.com/jwijenbergh/eduoauth-go
-rw-r--r--internal/oauth/oauth.go551
-rw-r--r--internal/oauth/oauth_test.go222
-rw-r--r--internal/oauth/token.go162
-rw-r--r--internal/util/util.go11
-rw-r--r--internal/util/util_test.go24
5 files changed, 1 insertions, 969 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
deleted file mode 100644
index d7da299..0000000
--- a/internal/oauth/oauth.go
+++ /dev/null
@@ -1,551 +0,0 @@
-// Package oauth implement an oauth client defined in e.g. rfc 6749
-// However, we try to follow some recommendations from the v2.1 oauth draft RFC
-// Some specific things we implement here:
-// - PKCE (RFC 7636)
-// - ISS (RFC 9207)
-package oauth
-
-import (
- "context"
- "crypto/sha256"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "html/template"
- "net"
- "net/http"
- "net/url"
- "sync"
- "time"
-
- httpw "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/internal/util"
- "github.com/go-errors/errors"
-)
-
-// genState generates a random base64 string to be used for state
-// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
-// "state": OPTIONAL. An opaque value used by the client to maintain
-// state between the request and callback. The authorization server
-// includes this value when redirecting the user agent back to the
-// client.
-// We implement it similarly to the verifier.
-func genState() (string, error) {
- bs, err := util.MakeRandomByteSlice(32)
- if err != nil {
- return "", err
- }
-
- // For consistency, we also use raw url encoding here
- return base64.RawURLEncoding.EncodeToString(bs), nil
-}
-
-// genChallengeS256 generates a sha256 base64 challenge from a verifier
-// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8
-func genChallengeS256(verifier string) string {
- hash := sha256.Sum256([]byte(verifier))
-
- // We use raw url encoding as the challenge does not accept padding
- return base64.RawURLEncoding.EncodeToString(hash[:])
-}
-
-// genVerifier generates a verifier
-// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
-// The code_verifier is a unique high-entropy cryptographically random
-// string generated for each authorization request, using the unreserved
-// characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", with a
-// minimum length of 43 characters and a maximum length of 128
-// characters.
-// We implement it according to the note:
-//
-// NOTE: The code verifier SHOULD have enough entropy to make it
-// impractical to guess the value. It is RECOMMENDED that the output of
-// a suitable random number generator be used to create a 32-octet
-// sequence. The octet sequence is then base64url-encoded to produce a
-// 43-octet URL safe string to use as the code verifier.
-//
-// See: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
-func genVerifier() (string, error) {
- random, err := util.MakeRandomByteSlice(32)
- if err != nil {
- return "", err
- }
-
- return base64.RawURLEncoding.EncodeToString(random), nil
-}
-
-// OAuth defines the main structure for this package.
-type OAuth struct {
- // The cached client id so we don't have to pass it around
- ClientID string `json:"client_id"`
-
- // The HTTP client that is used
- httpClient *httpw.Client
-
- // ISS indicates the issuer identifier of the authorization server as defined in RFC 9207
- ISS string `json:"iss"`
-
- // BaseAuthorizationURL is the URL where authorization should take place
- BaseAuthorizationURL string `json:"base_authorization_url"`
-
- // TokenURL is the URL where tokens should be obtained
- TokenURL string `json:"token_url"`
-
- // session is the internal in progress OAuth session
- session exchangeSession
-
- // Token is where the access and refresh tokens are stored along with the timestamps
- // It is protected by a lock
- token *tokenLock
-}
-
-// exchangeSession is a structure that gets passed to the callback for easy access to the current state.
-type exchangeSession struct {
- // ISS indicates the issuer identifier
- ISS string
-
- // State is the expected URL state parameter
- State string
-
- // Verifier is the preimage of the challenge
- Verifier string
-
- // RedirectURI is the passed redirect URI
- RedirectURI string
-
- // Listener is the listener where the servers 'listens' on
- Listener net.Listener
-
- // ErrChan is used to send the error from the handler
- ErrChan chan error
-}
-
-// AccessToken gets the OAuth access token used for contacting the server API
-// It returns the access token as a string, possibly obtained fresh using the Refresh Token
-// If the token cannot be obtained, an error is returned and the token is an empty string.
-func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) {
- tl := oauth.token
- if tl == nil {
- return "", errors.New("No token structure available")
- }
- return tl.Access(ctx)
-}
-
-// setupListener sets up an OAuth listener
-// If it was unsuccessful it returns an error.
-// @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2
-// "Loopback Interface Redirection".
-func (oauth *OAuth) setupListener() (net.Listener, error) {
- // create a listener
- lst, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return nil, errors.WrapPrefix(err, "net.Listen failed", 0)
- }
- return lst, nil
-}
-
-// tokensWithCallback gets the OAuth tokens using a local web server
-// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithCallback(ctx context.Context) error {
- if oauth.session.Listener == nil {
- return errors.New("failed getting tokens with callback: no listener")
- }
- mux := http.NewServeMux()
- // server /callback over the listener address
- s := &http.Server{
- Handler: mux,
- // Define a default 60 second header read timeout to protect against a Slowloris Attack
- // A bit overkill maybe for a local server but good to define anyways
- ReadHeaderTimeout: 60 * time.Second,
- }
- defer s.Shutdown(ctx) //nolint:errcheck
-
- // Use a sync.Once to only handle one request up until we shutdown the server
- var once sync.Once
- mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
- once.Do(func() {
- oauth.Handler(w, r)
- })
- })
-
- go func() {
- if err := s.Serve(oauth.session.Listener); err != http.ErrServerClosed {
- oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0)
- }
- }()
- select {
- case err := <-oauth.session.ErrChan:
- return err
- case <-ctx.Done():
- return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0)
- }
-}
-
-// tokenResponse fills the OAuth token response structure by the response
-// The URL that is input here is used for additional context
-// It returns this structure and an error if there is one
-func (oauth *OAuth) tokenResponse(response []byte, url string) (*TokenResponse, error) {
- if oauth.token == nil {
- return nil, errors.New("No oauth structure when filling token")
- }
- res := TokenResponse{}
-
- err := json.Unmarshal(response, &res)
- if err != nil {
- return nil, errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0)
- }
-
- return &res, nil
-}
-
-// SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time.
-func (oauth *OAuth) SetTokenExpired() {
- if oauth.token != nil {
- oauth.token.SetExpired()
- }
-}
-
-// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
-func (oauth *OAuth) SetTokenRenew() {
- if oauth.token != nil {
- oauth.token.Update(Token{})
- }
-}
-
-func (oauth *OAuth) Token() Token {
- t := Token{}
- if oauth.token != nil {
- t = oauth.token.Get()
- }
-
- return t
-}
-
-// tokensWithAuthCode gets the access and refresh tokens using the authorization code
-// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
-// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
-// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error {
- // Make sure the verifier is set as the parameter
- // so that the server can verify that we are the actual owner of the authorization code
- u := oauth.TokenURL
-
- data := url.Values{
- "client_id": {oauth.ClientID},
- "code": {authCode},
- "code_verifier": {oauth.session.Verifier},
- "grant_type": {"authorization_code"},
- "redirect_uri": {oauth.session.RedirectURI},
- }
- h := http.Header{
- "content-type": {"application/x-www-form-urlencoded"},
- }
- opts := &httpw.OptionalParams{Headers: h, Body: data}
- now := time.Now()
-
- // We are sure that we have a http client because we have initialized it when starting the exchange
- _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
- if err != nil {
- return err
- }
-
- tr, err := oauth.tokenResponse(body, u)
- if err != nil {
- return err
- }
- if tr == nil {
- return errors.New("No token response after authorization code")
- }
-
- oauth.token.UpdateResponse(*tr, now)
- return nil
-}
-
-func (oauth *OAuth) UpdateTokens(t Token) {
- if oauth.token == nil {
- oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}}
- }
- oauth.token.Update(t)
-}
-
-// refreshResponse gets the refresh token response with a refresh token
-// This response contains the access and refresh tokens, together with a timestamp
-// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
-// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
-// If it was unsuccessful it returns an error.
-func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) {
- u := oauth.TokenURL
- if oauth.token == nil {
- return nil, time.Time{}, errors.New("No oauth token structure in refresh")
- }
- if oauth.ClientID == "" {
- return nil, time.Time{}, errors.New("No client ID was cached for refresh")
- }
- data := url.Values{
- "client_id": {oauth.ClientID},
- "refresh_token": {r},
- "grant_type": {"refresh_token"},
- }
- h := http.Header{
- "content-type": {"application/x-www-form-urlencoded"},
- }
- opts := &httpw.OptionalParams{Headers: h, Body: data}
- now := time.Now()
-
- // Test if we have a http client and if not recreate one
- if oauth.httpClient == nil {
- oauth.httpClient = httpw.NewClient()
- }
-
- _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
- if err != nil {
- return nil, time.Time{}, err
- }
-
- tr, err := oauth.tokenResponse(body, u)
- return tr, now, err
-}
-
-// responseTemplate is the HTML template for the OAuth authorized response
-// this template was dapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25
-const responseTemplate string = `
-<!DOCTYPE html>
-<html dir="ltr" xmlns="http://www.w3.org/1999/xhtml" lang="en"><head>
-<meta http-equiv="content-type" content="text/html; charset=UTF-8">
-<meta charset="utf-8">
-<title>{{.Title}}</title>
-<style>
-body {
- font-family: arial;
- margin: 0;
- height: 100vh;
- display: flex;
- align-items: center;
- justify-content: center;
- background: #ccc;
- color: #252622;
-}
-main {
- padding: 1em 2em;
- text-align: center;
- border: 2pt solid #666;
- box-shadow: rgba(0, 0, 0, 0.6) 0px 1px 4px;
- border-color: #aaa;
- background: #ddd;
-}
-</style>
-</head>
-<body>
- <main>
- <h1>{{.Title}}</h1>
- <p>{{.Message}}</p>
- </main>
-</body>
-</html>
-`
-
-// oauthResponseHTML is a structure that is used to give back the OAuth response.
-type oauthResponseHTML struct {
- Title string
- Message string
-}
-
-// writeResponseHTML writes the OAuth response using a response writer and the title + message
-// If it was unsuccessful it returns an error.
-func writeResponseHTML(w http.ResponseWriter, title string, message string) error {
- t, err := template.New("oauth-response").Parse(responseTemplate)
- if err != nil {
- return errors.WrapPrefix(err, "failed writing response HTML", 0)
- }
-
- return t.Execute(w, oauthResponseHTML{Title: title, Message: message})
-}
-
-// Authcode gets the authorization code from the url
-// It returns the code and an error if there is one
-func (s *exchangeSession) Authcode(url *url.URL) (string, error) {
- // ISS: https://www.rfc-editor.org/rfc/rfc9207.html
- q := url.Query()
-
- // first check ISS
- iss := q.Get("iss")
- if s.ISS != "" && s.ISS != iss {
- return "", errors.Errorf("failed matching ISS; expected '%s' got '%s'", s.ISS, iss)
- }
- // Make sure the state is present and matches to protect against cross-site request forgeries
- // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
- state := q.Get("state")
- if state == "" {
- return "", errors.Errorf("failed retrieving parameter 'state' from '%s'", url)
- }
- // The state is the first entry
- if state != s.State {
- return "", errors.Errorf("failed matching state; expected '%s' got '%s'", s.State, state)
- }
-
- // check if an error is present
- // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-09#name-authorization-response (error response)
- errc := q.Get("error")
- if errc != "" {
- // these are optional but let's include them
- errdesc := q.Get("error_description")
- erruri := q.Get("error_uri")
- return "", errors.Errorf("failed obtaining oauthorization code, error code '%s', error description '%s', error uri '%s'", errc, errdesc, erruri)
- }
-
- // No authorization code
- code := q.Get("code")
- if code == "" {
- return "", errors.Errorf("failed retrieving parameter 'code' from '%s'", url)
- }
-
- return code, nil
-}
-
-// tokenHandler gets the tokens using the authorization code that is obtained through the url
-// This function is called by the http handler and returns an error if the tokens cannot be obtained
-func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error {
- // Get the authorization code
- c, err := oauth.session.Authcode(url)
- if err != nil {
- return err
- }
- // Now that we have obtained the authorization code, we can move to the next step:
- // Obtaining the access and refresh tokens
- return oauth.tokensWithAuthCode(ctx, c)
-}
-
-// Handler is the function used to get the OAuth tokens using an authorization code callback
-// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
-// It sends an error to the session channel (can be nil)
-func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) {
- // TODO: should this be something else than context background?
- err := oauth.tokenHandler(context.Background(), req.URL)
- if err != nil {
- _ = writeResponseHTML(
- w,
- "Authorization Failed",
- "The authorization has failed. See the log file for more information.",
- )
- } else {
- _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
- }
- oauth.session.ErrChan <- err
-}
-
-// Init initializes OAuth with the following parameters:
-// - OAuth server issuer identification
-// - The URL used for authorization
-// - The URL to obtain new tokens.
-func (oauth *OAuth) Init(clientID string, iss string, baseAuthorizationURL string, tokenURL string) {
- oauth.ClientID = clientID
- oauth.ISS = iss
- oauth.BaseAuthorizationURL = baseAuthorizationURL
- oauth.TokenURL = tokenURL
-}
-
-// AuthURL gets the authorization url to start the OAuth procedure.
-func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string, cr string) (string, error) {
- // Update the client ID
- oauth.ClientID = name
-
- // Generate the verifier and challenge
- v, err := genVerifier()
- if err != nil {
- return "", errors.WrapPrefix(err, "genVerifier error", 0)
- }
-
- // Generate the state
- state, err := genState()
- if err != nil {
- return "", errors.WrapPrefix(err, "genState error", 0)
- }
-
- // Re-initialize the token structure
- oauth.UpdateTokens(Token{})
-
- // Fill the struct with the necessary fields filled for the next call to getting the HTTP client
- red := cr
-
- // no custom redirect URI defined, we setup our own
- var l net.Listener
- if cr == "" {
- // set up the listener to get the redirect URI
- l, err = oauth.setupListener()
- if err != nil {
- return "", errors.WrapPrefix(err, "oauth.setupListener error", 0)
- }
- port := l.Addr().(*net.TCPAddr).Port
- // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php
- red = fmt.Sprintf("http://127.0.0.1:%d/callback", port)
- }
- oauth.session = exchangeSession{
- ISS: oauth.ISS,
- State: state,
- Verifier: v,
- ErrChan: make(chan error),
- RedirectURI: red,
- Listener: l,
- }
-
- params := map[string]string{
- "client_id": name,
- "code_challenge_method": "S256",
- "code_challenge": genChallengeS256(v),
- "response_type": "code",
- "scope": "config",
- "state": state,
- "redirect_uri": red,
- }
-
- p, err := url.Parse(oauth.BaseAuthorizationURL)
- if err != nil {
- return "", errors.WrapPrefix(err, fmt.Sprintf("failed to parse OAuth base URL '%s'", oauth.BaseAuthorizationURL), 0)
- }
- // Make sure the scheme is HTTPS
- p.Scheme = "https"
-
- u, err := httpw.ConstructURL(p, params)
- if err != nil {
- return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0)
- }
-
- // Return the url processed
- return postProcessAuth(u), nil
-}
-
-func (oauth *OAuth) tokensWithURI(ctx context.Context, uri string) error {
- // parse URI
- p, err := url.Parse(uri)
- if err != nil {
- return err
- }
- return oauth.tokenHandler(ctx, p)
-}
-
-// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
-// If it was unsuccessful it returns an error.
-func (oauth *OAuth) Exchange(ctx context.Context, uri string) error {
- // If there is no HTTP client defined, create a new one
- if oauth.httpClient == nil {
- oauth.httpClient = httpw.NewClient()
- }
- if uri != "" {
- return oauth.tokensWithURI(ctx, uri)
- }
- return oauth.tokensWithCallback(ctx)
-}
-
-type CancelledCallbackError struct{}
-
-func (e *CancelledCallbackError) Error() string {
- return "client cancelled OAuth"
-}
-
-type TokensInvalidError struct {
- Cause string
-}
-
-func (e *TokensInvalidError) Error() string {
- return fmt.Sprintf("tokens are invalid due to: %s", e.Cause)
-}
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
deleted file mode 100644
index 1181b5d..0000000
--- a/internal/oauth/oauth_test.go
+++ /dev/null
@@ -1,222 +0,0 @@
-package oauth
-
-import (
- "context"
- "encoding/json"
- "net/url"
- "strings"
- "testing"
- "time"
-)
-
-func Test_verifiergen(t *testing.T) {
- v, err := genVerifier()
- if err != nil {
- t.Fatalf("Gen verifier error: %v", err)
- }
-
- // Verifier must be at minimum 43 and at max 128 characters...
- // However... Our verifier is exactly 43!
- if len(v) != 43 {
- t.Fatalf(
- "Got verifier length: %d, want a verifier with at least 43 characters",
- len(v),
- )
- }
-
- _, err = url.QueryUnescape(v)
- if err != nil {
- t.Fatalf("Verifier: %s can not be unescaped", v)
- }
-}
-
-func Test_stategen(t *testing.T) {
- s1, err := genState()
- if err != nil {
- t.Fatalf("Error when generating state 1: %v", err)
- }
-
- s2, err := genState()
- if err != nil {
- t.Fatalf("Error when generating state 2: %v", err)
- }
-
- if s1 == s2 {
- t.Fatalf("State: %v, equal to: %v", s1, s2)
- }
-}
-
-func Test_challengergen(t *testing.T) {
- verifier := "test"
- // Calculated using: base64.urlsafe_b64encode(hashlib.sha256("test".encode("utf-8")).digest()).decode("utf-8").replace("=", "") in Python
- // This test might not be the best because we're now comparing two different implementations, but at least it gives us a way to see if we messed something up in a commit
- want := "n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg"
- got := genChallengeS256(verifier)
-
- if got != want {
- t.Fatalf("Challenger not equal, got: %v, want: %v", got, want)
- }
-}
-
-func Test_accessToken(t *testing.T) {
- o := OAuth{}
- _, err := o.AccessToken(context.Background())
- if err == nil {
- t.Fatalf("No error when getting access token on empty structure")
- }
-
- // Here we should get no error because the access token is set and is not expired
- want := "test"
- expired := time.Now().Add(1 * time.Hour)
- o = OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: want, ExpiredTimestamp: expired}}}}
- got, err := o.AccessToken(context.Background())
- if err != nil {
- t.Fatalf("Got error when getting access token on non-empty structure: %v", err)
- }
- if got != want {
- t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want)
- }
-
- // Set the tokens as expired
- o.SetTokenExpired()
-
- // We should get an error because expired and no refresh token
- _, err = o.AccessToken(context.Background())
- if err == nil {
- t.Fatal("Got no error when getting access token on non-empty structure and expired")
- }
-
- want = "test2"
- // Now we internally update the refresh function and refresh token, we should get new tokens
- refresh := "refresh"
- o.token.t.Refresh = refresh
- o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) {
- if refreshToken != refresh {
- t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh)
- }
- // Only the access and refresh fields are really important
- r := &TokenResponse{Access: want, Refresh: "test2"}
- return r, expired, nil
- }
-
- got, err = o.AccessToken(context.Background())
- if err != nil {
- t.Fatalf("Got error when getting access token on non-empty expired structure and with a 'valid' refresh token: %v", err)
- }
- if got != want {
- t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want)
- }
-
-
- // Set the tokens as expired
- o.SetTokenExpired()
- want = "test3"
-
- // Now let's act like a 2.x server, we give no refresh token back. When we refresh the previous refresh token should be gotten
- o.token.t.Refresh = refresh
- prevRefresh := refresh
- o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) {
- if refreshToken != refresh {
- t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh)
- }
- // Only the access token is returned now
- r := &TokenResponse{Access: want}
- return r, expired, nil
- }
-
- got, err = o.AccessToken(context.Background())
- if err != nil {
- t.Fatalf("Got error when getting access token on non-empty expired structure and with an empty refresh response: %v", err)
- }
- if got != want {
- t.Fatalf("Access token not equal, Got: %v, Want: %v", got, want)
- }
- if o.token.t.Refresh == "" {
- t.Fatalf("Refresh token is empty after refreshing and getting back an empty refresh")
- }
- if o.token.t.Refresh != prevRefresh {
- t.Fatalf("Refresh token is not equal to previous refresh token after refreshing and getting back an empty refresh token, got: %v, want: %v", o.token.t.Refresh, prevRefresh)
- }
-}
-
-func Test_secretJSON(t *testing.T) {
- // Access and refresh tokens should not be present in marshalled JSON
- a := "ineedtobesecret_access"
- r := "ineedtobesecret_refresh"
- o := OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: a, Refresh: r}}}}
- b, err := json.Marshal(o)
- if err != nil {
- t.Fatalf("Error when marshalling OAuth JSON: %v", err)
- }
- s := string(b)
- // Of course this is a very dumb check, it could be that we are writing in some other serialized format. However, we simply marshal the structure directly. Go just serializes this as a simple string
- if strings.Contains(s, a) {
- t.Fatalf("Serialized OAuth contains Access Token! Serialized: %v, Access Token: %v", s, a)
- }
-
- if strings.Contains(s, r) {
- t.Fatalf("Serialized OAuth contains Refresh Token! Serialized: %v, Refresh Token: %v", s, a)
- }
-}
-
-func Test_AuthURL(t *testing.T) {
- iss := "local"
- auth := "https://127.0.0.1/auth"
- token := "https://127.0.0.1/token"
- id := "client_id"
- o := OAuth{ISS: iss, BaseAuthorizationURL: auth, TokenURL: token}
- s, err := o.AuthURL(id, func(s string) string {
- // We do nothing here are this function is for skipping WAYF
- return s
- }, "")
- if err != nil {
- t.Fatalf("Error in getting OAuth URL: %v", err)
- }
-
- // Check if the OAuth session has valid values
- if o.ClientID != id {
- t.Fatalf("OAuth ClientID not equal, want: %v, got: %v", o.ClientID, id)
- }
- if o.session.ISS != iss {
- t.Fatalf("OAuth ISS not equal, want: %v, got: %v", o.session.ISS, iss)
- }
- if o.session.State == "" {
- t.Fatal("No OAuth session state paremeter found")
- }
- if o.session.Verifier == "" {
- t.Fatal("No OAuth session state paremeter found")
- }
- if o.session.ErrChan == nil {
- t.Fatal("No OAuth session error channel found")
- }
-
- u, err := url.Parse(s)
- if err != nil {
- t.Fatalf("Returned Auth URL cannot be parsed with error: %v", err)
- }
-
- c := []struct {
- query string
- want string
- }{
- {query: "client_id", want: id},
- {query: "code_challenge_method", want: "S256"},
- {query: "response_type", want: "code"},
- {query: "scope", want: "config"},
- {query: "redirect_uri", want: o.session.RedirectURI},
- }
-
- q := u.Query()
-
- // We should have 7 parameters: client_id, challenge method, challenge, response type, scope, state and redirect uri
- if len(q) != 7 {
- t.Fatalf("Total query parameters is not 7, url: %v, total params: %v", u, len(q))
- }
-
- for _, v := range c {
- p := q.Get(v.query)
- if p != v.want {
- t.Fatalf("Parameter: %v, not equal, want: %v, got: %v", v.query, v.want, p)
- }
- }
-}
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
deleted file mode 100644
index 251b689..0000000
--- a/internal/oauth/token.go
+++ /dev/null
@@ -1,162 +0,0 @@
-package oauth
-
-import (
- "context"
- "fmt"
- "sync"
- "time"
-
- "github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/types/server"
- "github.com/go-errors/errors"
-)
-
-// TokenResponse defines the OAuth response from the server that includes the tokens.
-type TokenResponse struct {
- // Access is the access token returned by the server
- Access string `json:"access_token"`
-
- // Refresh token is the refresh token returned by the server
- Refresh string `json:"refresh_token"`
-
- // Type indicates which type of tokens we have
- Type string `json:"token_type"`
-
- // Expires is the expires time returned by the server
- Expires int64 `json:"expires_in"`
-}
-
-// The public type that can be passed to an update function
-// It contains our access and refresh tokens with a timestamp
-type Token struct {
- // Access is the Access token returned by the server
- Access string
-
- // Refresh token is the Refresh token returned by the server
- Refresh string
-
- // ExpiredTimestamp is the Expires field but converted to a Go timestamp
- ExpiredTimestamp time.Time
-}
-
-func (t *Token) Public() server.Tokens {
- return server.Tokens{
- Access: t.Access,
- Refresh: t.Refresh,
- Expires: t.ExpiredTimestamp.Unix(),
- }
-}
-
-// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire.
-// Additionally, it contains the refresher to get new tokens
-type tokenRefresher struct {
- Token
- // Refresher is the function that refreshes the token
- Refresher func(context.Context, string) (*TokenResponse, time.Time, error)
-}
-
-// tokenLock is a wrapper around token that protects it with a lock
-type tokenLock struct {
- // Protects t
- mu sync.Mutex
-
- // The token fields protected by the lock
- // This token struct contains a refresher
- t *tokenRefresher
-}
-
-// Access gets the OAuth access token used for contacting the server API
-// It returns the access token as a string, possibly obtained fresh using the refresher
-// If the token cannot be obtained, an error is returned and the token is an empty string.
-func (l *tokenLock) Access(ctx context.Context) (string, error) {
- log.Logger.Debugf("Getting access token")
- l.mu.Lock()
- defer l.mu.Unlock()
-
- // The tokens are not expired yet
- // So they should be valid, re-login not neede
- if !l.expired() {
- log.Logger.Debugf("Access token is not expired, returning")
- return l.t.Access, nil
- }
-
- // Check if refresh is even possible by doing a simple check if the refresh token is empty
- // This is not needed but reduces API calls to the server
- if l.t.Refresh == "" {
- log.Logger.Debugf("Refresh token is empty, returning error")
- return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
- }
-
- // Otherwise refresh and then later return the access token if we are successful
- tr, s, err := l.t.Refresher(ctx, l.t.Refresh)
- if err != nil {
- log.Logger.Debugf("Got a refresh token error: %v", err)
- // We have failed to ensure the tokens due to refresh not working
- return "", errors.Wrap(
- &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0)
- }
- if tr == nil {
- log.Logger.Debugf("No token response after refreshing")
- return "", errors.New("No token response after refreshing")
- }
- // store the previous refresh token
- pr := l.t.Refresh
- // get the response as a non-pointer
- r := *tr
- e := s.Add(time.Second * time.Duration(r.Expires))
- t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
- l.updateInternal(t)
- // set the previous refresh token if the new one is empty
- // This is for 2.x servers
- if l.t.Refresh == "" {
- log.Logger.Debugf("The previous refresh token is set as the response had no refresh token")
- l.t.Refresh = pr
- }
- return l.t.Access, nil
-}
-
-// UpdateResponse updates the structure using the server response and locks
-func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) {
- l.mu.Lock()
- e := s.Add(time.Second * time.Duration(r.Expires))
- t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
- l.updateInternal(t)
- l.mu.Unlock()
-}
-
-// updateInternal updates the token structure internally but does not lock
-func (l *tokenLock) updateInternal(r Token) {
- l.t.Access = r.Access
- l.t.Refresh = r.Refresh
- l.t.ExpiredTimestamp = r.ExpiredTimestamp
-}
-
-// Update updates the token structure using the internal function but locks
-func (l *tokenLock) Update(r Token) {
- l.mu.Lock()
- l.updateInternal(r)
- l.mu.Unlock()
-}
-
-// Get gets the tokens into a public struct
-func (l *tokenLock) Get() Token {
- // TODO: Check nil?
- l.mu.Lock()
- defer l.mu.Unlock()
- return l.t.Token
-}
-
-// SetExpired overrides the timestamp to the current time
-// This marks the tokens as expired
-func (l *tokenLock) SetExpired() {
- l.mu.Lock()
- l.t.ExpiredTimestamp = time.Now()
- l.mu.Unlock()
-}
-
-// expired checks if the access token is expired.
-// This is only called internally and thus does not lock
-func (l *tokenLock) expired() bool {
- now := time.Now()
- return !now.Before(l.t.ExpiredTimestamp)
-}
diff --git a/internal/util/util.go b/internal/util/util.go
index 85c4b37..97b4151 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -2,23 +2,12 @@
package util
import (
- "crypto/rand"
"fmt"
"net/url"
"os"
"strings"
)
-// MakeRandomByteSlice creates a cryptographically random bytes slice of `size`
-// It returns the byte slice (or nil if error) and an error if it could not be generated.
-func MakeRandomByteSlice(n int) ([]byte, error) {
- bs := make([]byte, n)
- if _, err := rand.Read(bs); err != nil {
- return nil, errors.WrapPrefix(err, "failed reading random", 0)
- }
- return bs, nil
-}
-
// EnsureDirectory creates a directory with permission 700.
func EnsureDirectory(dir string) error {
// Create with 700 permissions, read, write, execute only for the owner
diff --git a/internal/util/util_test.go b/internal/util/util_test.go
index 5e19f57..827fbe1 100644
--- a/internal/util/util_test.go
+++ b/internal/util/util_test.go
@@ -1,28 +1,6 @@
package util
-import (
- "bytes"
- "testing"
-)
-
-func TestMakeRandomByteSlice(t *testing.T) {
- random, randomErr := MakeRandomByteSlice(32)
- if randomErr != nil {
- t.Fatalf("Got: %v, want: nil", randomErr)
- }
- if len(random) != 32 {
- t.Fatalf("Got length: %d, want length: 32", len(random))
- }
-
- random2, randomErr2 := MakeRandomByteSlice(32)
- if randomErr2 != nil {
- t.Fatalf("2, Got: %v, want: nil", randomErr)
- }
-
- if bytes.Equal(random2, random) {
- t.Fatalf("Two random byteslices are the same: %v, %v", random2, random)
- }
-}
+import "testing"
func TestReplaceWAYF(t *testing.T) {
// We expect url encoding but the spaces to be correctly replace with a + instead of a %20