From 7af07c596166bf93b79a9d0816b1950dde360fb9 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Fri, 17 Jun 2022 14:00:40 +0200 Subject: Server: Implement function for checking renewal button visibility --- internal/oauth/oauth.go | 434 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 434 insertions(+) create mode 100644 internal/oauth/oauth.go (limited to 'internal/oauth') diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 0000000..f6ed916 --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,434 @@ +package oauth + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" + "github.com/jwijenbergh/eduvpn-common/internal/util" + "github.com/jwijenbergh/eduvpn-common/internal/log" +) + +// Generates a random base64 string to be used for state +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 +// "state": OPTIONAL. An opaque value used by the client to maintain +// state between the request and callback. The authorization server +// includes this value when redirecting the user agent back to the +// client. +func genState() (string, error) { + randomBytes, err := util.MakeRandomByteSlice(32) + if err != nil { + return "", &OAuthGenStateError{Err: err} + } + + // For consistency we also use raw url encoding here + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// Generates a sha256 base64 challenge from a verifier +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8 +func genChallengeS256(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + + // We use raw url encoding as the challenge does not accept padding + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// Generates a verifier +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 +// The code_verifier is a unique high-entropy cryptographically random +// string generated for each authorization request, using the unreserved +// characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", with a +// minimum length of 43 characters and a maximum length of 128 +// characters. +func genVerifier() (string, error) { + randomBytes, err := util.MakeRandomByteSlice(32) + if err != nil { + return "", &OAuthGenVerifierError{Err: err} + } + + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +type OAuth struct { + Session OAuthExchangeSession `json:"-"` + Token OAuthToken `json:"token"` + BaseAuthorizationURL string `json:"base_authorization_url"` + TokenURL string `json:"token_url"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` +} + +// This structure gets passed to the callback for easy access to the current state +type OAuthExchangeSession struct { + // returned from the callback + CallbackError error + + // filled in in initialize + ClientID string + State string + Verifier string + + // filled in when constructing the callback + Context context.Context + Server http.Server +} + +// Struct that defines the json format for /.well-known/vpn-user-portal" +type OAuthToken struct { + Access string `json:"access_token"` + Refresh string `json:"refresh_token"` + Type string `json:"token_type"` + Expires int64 `json:"expires_in"` + ExpiredTimestamp int64 `json:"expires_in_timestamp"` +} + +// Gets an authorized HTTP client by obtaining refresh and access tokens +func (oauth *OAuth) getTokensWithCallback() error { + oauth.Session.Context = context.Background() + mux := http.NewServeMux() + addr := "127.0.0.1:8000" + oauth.Session.Server = http.Server{ + Addr: addr, + Handler: mux, + } + mux.HandleFunc("/callback", oauth.Callback) + if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { + return &OAuthCallbackError{Addr: addr, Err: err} + } + return oauth.Session.CallbackError +} + +// Get the access and refresh tokens +// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 +// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { + // Make sure the verifier is set as the parameter + // so that the server can verify that we are the actual owner of the authorization code + + reqURL := oauth.TokenURL + data := url.Values{ + "client_id": {oauth.Session.ClientID}, + "code": {authCode}, + "code_verifier": {oauth.Session.Verifier}, + "grant_type": {"authorization_code"}, + "redirect_uri": {"http://127.0.0.1:8000/callback"}, + } + headers := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + current_time := util.GenerateTimeSeconds() + _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) + if bodyErr != nil { + return &OAuthAuthError{Err: bodyErr} + } + + tokenStructure := OAuthToken{} + + jsonErr := json.Unmarshal(body, &tokenStructure) + + if jsonErr != nil { + return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + } + + tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires + oauth.Token = tokenStructure + return nil +} + +func (oauth *OAuth) isTokensExpired() bool { + expired_time := oauth.Token.ExpiredTimestamp + current_time := util.GenerateTimeSeconds() + return current_time >= expired_time +} + +// Get the access and refresh tokens with a previously received refresh token +// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 +// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +func (oauth *OAuth) getTokensWithRefresh() error { + reqURL := oauth.TokenURL + data := url.Values{ + "refresh_token": {oauth.Token.Refresh}, + "grant_type": {"refresh_token"}, + } + headers := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + current_time := util.GenerateTimeSeconds() + _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) + if bodyErr != nil { + return &OAuthRefreshError{Err: bodyErr} + } + + tokenStructure := OAuthToken{} + jsonErr := json.Unmarshal(body, &tokenStructure) + + if jsonErr != nil { + return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + } + + tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires + oauth.Token = tokenStructure + return nil +} + +// +//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 +func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { + // Extract the authorization code + code, success := req.URL.Query()["code"] + if !success { + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()} + go oauth.Session.Server.Shutdown(oauth.Session.Context) + return + } + // The code is the first entry + extractedCode := code[0] + + // Make sure the state is present and matches to protect against cross-site request forgeries + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 + state, success := req.URL.Query()["state"] + if !success { + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()} + go oauth.Session.Server.Shutdown(oauth.Session.Context) + return + } + // The state is the first entry + extractedState := state[0] + if extractedState != oauth.Session.State { + oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} + go oauth.Session.Server.Shutdown(oauth.Session.Context) + return + } + + // Now that we have obtained the authorization code, we can move to the next step: + // Obtaining the access and refresh tokens + err := oauth.getTokensWithAuthCode(extractedCode) + if err != nil { + oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err} + go oauth.Session.Server.Shutdown(oauth.Session.Context) + return + } + + // Shutdown the server as we're done listening + go oauth.Session.Server.Shutdown(oauth.Session.Context) +} + +func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { + oauth.FSM = fsm + oauth.Logger = logger +} + +func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) { + oauth.BaseAuthorizationURL = baseAuthorizationURL + oauth.TokenURL = tokenURL + oauth.FSM = fsm + oauth.Logger = logger +} + +// Starts the OAuth exchange for eduvpn. +func (oauth *OAuth) start(name string) error { + if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { + return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED} + } + // Generate the state + state, stateErr := genState() + if stateErr != nil { + return &OAuthInitializeError{Err: stateErr} + } + + // Generate the verifier and challenge + verifier, verifierErr := genVerifier() + if verifierErr != nil { + return &OAuthInitializeError{Err: verifierErr} + } + challenge := genChallengeS256(verifier) + + parameters := map[string]string{ + "client_id": name, + "code_challenge_method": "S256", + "code_challenge": challenge, + "response_type": "code", + "scope": "config", + "state": state, + "redirect_uri": "http://127.0.0.1:8000/callback", + } + + authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) + + if urlErr != nil { + return &OAuthInitializeError{Err: urlErr} + } + + // Fill the struct with the necessary fields filled for the next call to getting the HTTP client + oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} + oauth.Session = oauthSession + // Run the state callback in the background so that the user can login while we start the callback server + oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true) + return nil +} + +// Error definitions +func (oauth *OAuth) Finish() error { + if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { + return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED} + } + tokenErr := oauth.getTokensWithCallback() + + if tokenErr != nil { + return &OAuthFinishError{Err: tokenErr} + } + oauth.FSM.GoTransition(fsm.AUTHORIZED) + return nil +} + +func (oauth *OAuth) Cancel() { + oauth.Session.CallbackError = &OAuthCancelledCallbackError{} + oauth.Session.Server.Shutdown(oauth.Session.Context) +} + +func (oauth *OAuth) Login(name string) error { + authInitializeErr := oauth.start(name) + + if authInitializeErr != nil { + return &OAuthLoginError{Err: authInitializeErr} + } + + oauthErr := oauth.Finish() + + if oauthErr != nil { + return &OAuthLoginError{Err: oauthErr} + } + return nil +} + +func (oauth *OAuth) NeedsRelogin() bool { + // Access Token or Refresh Tokens empty, definitely needs a relogin + if oauth.Token.Access == "" || oauth.Token.Refresh == "" { + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens are empty") + return true + } + + // We have tokens... + + // The tokens are not expired yet + // No relogin is needed + if !oauth.isTokensExpired() { + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens are not expired, re-login not needed") + return false + } + + refreshErr := oauth.getTokensWithRefresh() + // We have obtained new tokens with refresh + if refreshErr == nil { + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens could be re-acquired using the refresh token, re-login not needed") + return false + } + + // Otherwise relogin is really needed + return true +} + +type OAuthCancelledCallbackError struct{} + +func (e *OAuthCancelledCallbackError) Error() string { + return fmt.Sprintf("client cancelled OAuth") +} + +type OAuthGenStateError struct { + Err error +} + +func (e *OAuthGenStateError) Error() string { + return fmt.Sprintf("failed generating state with error: %v", e.Err) +} + +type OAuthGenVerifierError struct { + Err error +} + +func (e *OAuthGenVerifierError) Error() string { + return fmt.Sprintf("failed generating verifier with error: %v", e.Err) +} + +type OAuthCallbackError struct { + Addr string + Err error +} + +func (e *OAuthCallbackError) Error() string { + return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err) +} + +type OAuthCallbackParameterError struct { + Parameter string + URL string +} + +func (e *OAuthCallbackParameterError) Error() string { + return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) +} + +type OAuthCallbackStateMatchError struct { + State string + ExpectedState string +} + +func (e *OAuthCallbackStateMatchError) Error() string { + return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) +} + +type OAuthCallbackGetTokensError struct { + Err error +} + +func (e *OAuthCallbackGetTokensError) Error() string { + return fmt.Sprintf("failed getting tokens with error: %v", e.Err) +} + +type OAuthFinishError struct { + Err error +} + +func (e *OAuthFinishError) Error() string { + return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err) +} + +type OAuthLoginError struct { + Err error +} + +func (e *OAuthLoginError) Error() string { + return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err) +} + +type OAuthInitializeError struct { + Err error +} + +func (e *OAuthInitializeError) Error() string { + return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err) +} + +type OAuthAuthError struct { + Err error +} + +func (e *OAuthAuthError) Error() string { + return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err) +} + +type OAuthRefreshError struct { + Err error +} + +func (e *OAuthRefreshError) Error() string { + return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err) +} -- cgit v1.2.3