From 56548c511163b4dd22d9a96a2f5ae647f1627a7b Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Mon, 7 Mar 2022 15:43:07 +0100 Subject: Refactor: Simplify API by using a state as context --- src/oauth.go | 93 ++++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 28 deletions(-) (limited to 'src/oauth.go') diff --git a/src/oauth.go b/src/oauth.go index 9eb2272..87284a2 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -66,33 +66,8 @@ type EduVPNOauth struct { verifier string } -// Initializes the OAuth eduvpn class. It returns a tuple of the class and error. -// If the error is non-nil, the class will be nil. -func InitializeOAuth(config *oauth2.Config) (*EduVPNOauth, error) { - // Generate the state - state, stateErr := genState() - if stateErr != nil { - return nil, detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} - } - - // Generate the verifier and challenge - verifier, err := genVerifier() - if err != nil { - return nil, detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} - } - challenge := genChallengeS256(verifier) - - // Update the auth url with the challenge and state - codeChallengeMethod := oauth2.SetAuthURLParam("code_challenge_method", "S256") - codeChallenge := oauth2.SetAuthURLParam("code_challenge", challenge) - authURL := config.AuthCodeURL(state, codeChallengeMethod, codeChallenge) - - // Return the struct with the necessary fields filled for the next call to getting the HTTP client - return &EduVPNOauth{AuthURL: authURL, Config: config, state: state, verifier: verifier}, nil -} - // Gets an authenticated HTTP client by obtaining refresh and access tokens -func (eduvpn *EduVPNOauth) GetHTTPTokenClient() (*http.Client, error) { +func (eduvpn *EduVPNOauth) getHTTPTokenClient() error { eduvpn.context = context.Background() mux := http.NewServeMux() eduvpn.server = &http.Server{ @@ -101,9 +76,9 @@ func (eduvpn *EduVPNOauth) GetHTTPTokenClient() (*http.Client, error) { } mux.HandleFunc("/callback", eduvpn.oauthCallback) if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed { - return nil, detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} + return detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} } - return eduvpn.client, eduvpn.callbackError + return eduvpn.callbackError } // Get the access and refresh tokens @@ -169,6 +144,68 @@ func (eduvpn *EduVPNOauth) oauthCallback(w http.ResponseWriter, req *http.Reques go eduvpn.server.Shutdown(eduvpn.context) } +// Generate a config for oauth +// It uses the state to get the server and the name +func genConfig(vpnState *EduVPNState) (*oauth2.Config, error) { + config := &oauth2.Config{ + RedirectURL: "http://127.0.0.1:8000/callback", + ClientID: vpnState.Name, + Scopes: []string{"config"}, + Endpoint: oauth2.Endpoint{ + AuthURL: vpnState.Endpoints.API.V3.AuthorizationEndpoint, + TokenURL: vpnState.Endpoints.API.V3.TokenEndpoint, + }, + } + return config, nil +} + +// Initializes the OAuth for eduvpn. +// It needs a vpn state that was gotten from `Register` +// It returns the authurl for the browser and an error if present +func InitializeOAuth(vpnState *EduVPNState) (string, error) { + if vpnState == nil { + panic("invalid state") + } + + config, configErr := genConfig(vpnState) + if configErr != nil { + return "", configErr + } + + // Generate the state + state, stateErr := genState() + if stateErr != nil { + return "", detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} + } + + // Generate the verifier and challenge + verifier, err := genVerifier() + if err != nil { + return "", detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} + } + challenge := genChallengeS256(verifier) + + // Update the auth url with the challenge and state + codeChallengeMethod := oauth2.SetAuthURLParam("code_challenge_method", "S256") + codeChallenge := oauth2.SetAuthURLParam("code_challenge", challenge) + authURL := config.AuthCodeURL(state, codeChallengeMethod, codeChallenge) + + // Fill the struct with the necessary fields filled for the next call to getting the HTTP client + vpnState.OAuth = &EduVPNOauth{AuthURL: authURL, Config: config, state: state, verifier: verifier} + return authURL, nil +} + +func FinishOAuth(vpnState *EduVPNState) error { + if vpnState == nil { + panic("invalid state") + } + + if vpnState.OAuth == nil { + panic("invalid oauth state") + } + return vpnState.OAuth.getHTTPTokenClient() +} + // OAuthErrorCode Simplified error code for public interface. type OAuthErrorCode = VPNErrorCode type OAuthError = VPNError -- cgit v1.2.3