diff options
Diffstat (limited to 'src/oauth.go')
| -rw-r--r-- | src/oauth.go | 142 |
1 files changed, 80 insertions, 62 deletions
diff --git a/src/oauth.go b/src/oauth.go index 6212124..5b7a5c4 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -12,6 +12,22 @@ import ( "strings" ) +type OAuthGenStateUnableError struct { + Err error +} + +func (e *OAuthGenStateUnableError) Error() string { + return fmt.Sprintf("failed generating state with error %v", e.Err) +} + +type OAuthGenVerifierUnableError struct { + Err error +} + +func (e *OAuthGenVerifierUnableError) Error() string { + return fmt.Sprintf("failed generating verifier with error %v", e.Err) +} + // Generates a random base64 string to be used for state // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 // "state": OPTIONAL. An opaque value used by the client to maintain @@ -22,7 +38,7 @@ func genState() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", err + return "", &OAuthGenStateUnableError{Err: err} } // For consistency we also use raw url encoding here @@ -48,7 +64,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", err + return "", &OAuthGenVerifierUnableError{Err: err} } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -76,17 +92,27 @@ type EduVPNOAuthToken struct { Expires int `json:"expires_in"` } +type OAuthFailedCallbackError struct { + Addr string + Err error +} + +func (e *OAuthFailedCallbackError) Error() string { + return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err) +} + // Gets an authenticated HTTP client by obtaining refresh and access tokens func (eduvpn *EduVPNOAuthSession) getHTTPTokenClient() error { eduvpn.context = context.Background() mux := http.NewServeMux() + addr := "127.0.0.1:8000" eduvpn.server = &http.Server{ - Addr: "127.0.0.1:8000", + Addr: addr, Handler: mux, } mux.HandleFunc("/callback", eduvpn.oauthCallback) if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed { - return detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} + return &OAuthFailedCallbackError{Addr: addr, Err: err} } return eduvpn.callbackError } @@ -106,15 +132,16 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error { "redirect_uri": {"http://127.0.0.1:8000/callback"}, } client := &http.Client{} - req, reqErr := http.NewRequest(http.MethodPost, eduvpn.VPNState.Endpoints.API.V3.Token, strings.NewReader(data.Encode())) - if reqErr != nil { - return reqErr + url := eduvpn.VPNState.Endpoints.API.V3.Token + req, reqErr := http.NewRequest(http.MethodPost, url, strings.NewReader(data.Encode())) + if reqErr != nil { // shouldn't happen + panic(reqErr) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") resp, reqErr := client.Do(req) if reqErr != nil { - return reqErr + return &HTTPResourceError{URL: url, Err: reqErr} } // Close the response body at the end @@ -123,14 +150,14 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error { // Read the body body, readErr := ioutil.ReadAll(resp.Body) if readErr != nil { - return readErr + return &HTTPReadError{URL: url, Err: readErr} } tokenStructure := &EduVPNOAuthToken{} jsonErr := json.Unmarshal(body, tokenStructure) if jsonErr != nil { - return jsonErr + return &HTTPParseJsonError{URL: url, Body: string(body), Err: jsonErr} } eduvpn.VPNState.OAuthToken = tokenStructure @@ -138,13 +165,39 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error { return nil } +type OAuthFailedCallbackParameterError struct { + Parameter string + URL string +} + +func (e *OAuthFailedCallbackParameterError) Error() string { + return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL) +} + +type OAuthFailedCallbackStateMatchError struct { + State string + ExpectedState string +} + +func (e *OAuthFailedCallbackStateMatchError) Error() string { + return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState) +} + +type OAuthFailedCallbackGetTokensError struct { + Err error +} + +func (e *OAuthFailedCallbackGetTokensError) Error() string { + return fmt.Sprintf("failed getting tokens with error %v", e.Err) +} + // //// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http.Request) { // Extract the authorization code code, success := req.URL.Query()["code"] if !success { - eduvpn.callbackError = detailedOAuthError{errCallbackGetAuthCodeError, fmt.Sprintf("oauth auth code cannot be retrieved"), nil} + eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()} go eduvpn.server.Shutdown(eduvpn.context) return } @@ -155,14 +208,14 @@ func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] if !success { - eduvpn.callbackError = detailedOAuthError{errCallbackGetStateError, fmt.Sprintf("oauth state cannot be retrieved"), nil} + eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()} go eduvpn.server.Shutdown(eduvpn.context) return } // The state is the first entry extractedState := state[0] if extractedState != eduvpn.state { - eduvpn.callbackError = detailedOAuthError{errCallbackVerifyStateMatchError, fmt.Sprintf("oauth state does not match"), nil} + eduvpn.callbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: eduvpn.state} go eduvpn.server.Shutdown(eduvpn.context) return } @@ -172,7 +225,7 @@ func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http err := eduvpn.getTokens(extractedCode) if err != nil { - eduvpn.callbackError = detailedOAuthError{errCallbackGetTokenExchangeError, fmt.Sprintf("oauth failed to get token in exchange"), err} + eduvpn.callbackError = &OAuthFailedCallbackGetTokensError{Err: err} go eduvpn.server.Shutdown(eduvpn.context) return } @@ -197,6 +250,14 @@ func constructURL(baseURL string, parameters map[string]string) (string, error) return url.String(), nil } +type OAuthFailedInitializeError struct { + Err error +} + +func (e *OAuthFailedInitializeError) Error() string { + return fmt.Sprintf("failed initializing OAuth with error %v", e.Err) +} + // Initializes the OAuth for eduvpn. // It needs a vpn state that was gotten from `Register` // It returns the authurl for the browser and an error if present @@ -208,13 +269,13 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) { // Generate the state state, stateErr := genState() if stateErr != nil { - return "", detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} + return "", &OAuthFailedInitializeError{Err: stateErr} } // Generate the verifier and challenge - verifier, err := genVerifier() - if err != nil { - return "", detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} + verifier, verifierErr := genVerifier() + if verifierErr != nil { + return "", &OAuthFailedInitializeError{Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -230,7 +291,7 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) { authURL, urlErr := constructURL(vpnState.Endpoints.API.V3.Authorization, parameters) - if urlErr != nil { + if urlErr != nil { // shouldn't happen panic(urlErr) } @@ -249,46 +310,3 @@ func FinishOAuth(vpnState *EduVPNState) error { } return vpnState.OAuthSession.getHTTPTokenClient() } - -// OAuthErrorCode Simplified error code for public interface. -type OAuthErrorCode = VPNErrorCode -type OAuthError = VPNError - -// detailedOAuthErrorCode used for unit tests. -type detailedOAuthErrorCode = detailedVPNErrorCode -type detailedOAuthError = detailedVPNError - -const ( - ErrGenError OAuthErrorCode = iota + 1 - ErrCallbackTokenError -) - -const ( - errGenStateError detailedOAuthErrorCode = iota + 1 - errGenVerifierError - errCallbackServerError - errCallbackGetAuthCodeError - errCallbackGetStateError - errCallbackVerifyStateMatchError - errCallbackGetTokenExchangeError -) - -func (err detailedOAuthError) ToOAuthError() OAuthError { - return RequestError{err.Code.ToOAuthErrorCode(), err} -} - -func (code detailedOAuthErrorCode) ToOAuthErrorCode() OAuthErrorCode { - switch code { - case errGenStateError: - case errGenVerifierError: - return ErrGenError - - case errCallbackServerError: - case errCallbackGetAuthCodeError: - case errCallbackGetStateError: - case errCallbackVerifyStateMatchError: - case errCallbackGetTokenExchangeError: - return ErrCallbackTokenError - } - panic("invalid detailedOAuthErrorCode") -} |
