summaryrefslogtreecommitdiff
path: root/src/oauth.go
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-03-11 13:52:49 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-05 12:26:14 +0200
commita019e95fdbaea3d7af2d8ad10903fd656bfc4466 (patch)
tree4e852e36da327823a02678dfb766aca58fa0a23f /src/oauth.go
parent5065de4cff907b70ea3446888a7bad243744a8ab (diff)
Refactor: Simplify errors for wrapping
Diffstat (limited to 'src/oauth.go')
-rw-r--r--src/oauth.go142
1 files changed, 80 insertions, 62 deletions
diff --git a/src/oauth.go b/src/oauth.go
index 6212124..5b7a5c4 100644
--- a/src/oauth.go
+++ b/src/oauth.go
@@ -12,6 +12,22 @@ import (
"strings"
)
+type OAuthGenStateUnableError struct {
+ Err error
+}
+
+func (e *OAuthGenStateUnableError) Error() string {
+ return fmt.Sprintf("failed generating state with error %v", e.Err)
+}
+
+type OAuthGenVerifierUnableError struct {
+ Err error
+}
+
+func (e *OAuthGenVerifierUnableError) Error() string {
+ return fmt.Sprintf("failed generating verifier with error %v", e.Err)
+}
+
// Generates a random base64 string to be used for state
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
// "state": OPTIONAL. An opaque value used by the client to maintain
@@ -22,7 +38,7 @@ func genState() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", err
+ return "", &OAuthGenStateUnableError{Err: err}
}
// For consistency we also use raw url encoding here
@@ -48,7 +64,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", err
+ return "", &OAuthGenVerifierUnableError{Err: err}
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -76,17 +92,27 @@ type EduVPNOAuthToken struct {
Expires int `json:"expires_in"`
}
+type OAuthFailedCallbackError struct {
+ Addr string
+ Err error
+}
+
+func (e *OAuthFailedCallbackError) Error() string {
+ return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err)
+}
+
// Gets an authenticated HTTP client by obtaining refresh and access tokens
func (eduvpn *EduVPNOAuthSession) getHTTPTokenClient() error {
eduvpn.context = context.Background()
mux := http.NewServeMux()
+ addr := "127.0.0.1:8000"
eduvpn.server = &http.Server{
- Addr: "127.0.0.1:8000",
+ Addr: addr,
Handler: mux,
}
mux.HandleFunc("/callback", eduvpn.oauthCallback)
if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed {
- return detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err}
+ return &OAuthFailedCallbackError{Addr: addr, Err: err}
}
return eduvpn.callbackError
}
@@ -106,15 +132,16 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error {
"redirect_uri": {"http://127.0.0.1:8000/callback"},
}
client := &http.Client{}
- req, reqErr := http.NewRequest(http.MethodPost, eduvpn.VPNState.Endpoints.API.V3.Token, strings.NewReader(data.Encode()))
- if reqErr != nil {
- return reqErr
+ url := eduvpn.VPNState.Endpoints.API.V3.Token
+ req, reqErr := http.NewRequest(http.MethodPost, url, strings.NewReader(data.Encode()))
+ if reqErr != nil { // shouldn't happen
+ panic(reqErr)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, reqErr := client.Do(req)
if reqErr != nil {
- return reqErr
+ return &HTTPResourceError{URL: url, Err: reqErr}
}
// Close the response body at the end
@@ -123,14 +150,14 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error {
// Read the body
body, readErr := ioutil.ReadAll(resp.Body)
if readErr != nil {
- return readErr
+ return &HTTPReadError{URL: url, Err: readErr}
}
tokenStructure := &EduVPNOAuthToken{}
jsonErr := json.Unmarshal(body, tokenStructure)
if jsonErr != nil {
- return jsonErr
+ return &HTTPParseJsonError{URL: url, Body: string(body), Err: jsonErr}
}
eduvpn.VPNState.OAuthToken = tokenStructure
@@ -138,13 +165,39 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error {
return nil
}
+type OAuthFailedCallbackParameterError struct {
+ Parameter string
+ URL string
+}
+
+func (e *OAuthFailedCallbackParameterError) Error() string {
+ return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL)
+}
+
+type OAuthFailedCallbackStateMatchError struct {
+ State string
+ ExpectedState string
+}
+
+func (e *OAuthFailedCallbackStateMatchError) Error() string {
+ return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState)
+}
+
+type OAuthFailedCallbackGetTokensError struct {
+ Err error
+}
+
+func (e *OAuthFailedCallbackGetTokensError) Error() string {
+ return fmt.Sprintf("failed getting tokens with error %v", e.Err)
+}
+
//
//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http.Request) {
// Extract the authorization code
code, success := req.URL.Query()["code"]
if !success {
- eduvpn.callbackError = detailedOAuthError{errCallbackGetAuthCodeError, fmt.Sprintf("oauth auth code cannot be retrieved"), nil}
+ eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()}
go eduvpn.server.Shutdown(eduvpn.context)
return
}
@@ -155,14 +208,14 @@ func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state, success := req.URL.Query()["state"]
if !success {
- eduvpn.callbackError = detailedOAuthError{errCallbackGetStateError, fmt.Sprintf("oauth state cannot be retrieved"), nil}
+ eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()}
go eduvpn.server.Shutdown(eduvpn.context)
return
}
// The state is the first entry
extractedState := state[0]
if extractedState != eduvpn.state {
- eduvpn.callbackError = detailedOAuthError{errCallbackVerifyStateMatchError, fmt.Sprintf("oauth state does not match"), nil}
+ eduvpn.callbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: eduvpn.state}
go eduvpn.server.Shutdown(eduvpn.context)
return
}
@@ -172,7 +225,7 @@ func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http
err := eduvpn.getTokens(extractedCode)
if err != nil {
- eduvpn.callbackError = detailedOAuthError{errCallbackGetTokenExchangeError, fmt.Sprintf("oauth failed to get token in exchange"), err}
+ eduvpn.callbackError = &OAuthFailedCallbackGetTokensError{Err: err}
go eduvpn.server.Shutdown(eduvpn.context)
return
}
@@ -197,6 +250,14 @@ func constructURL(baseURL string, parameters map[string]string) (string, error)
return url.String(), nil
}
+type OAuthFailedInitializeError struct {
+ Err error
+}
+
+func (e *OAuthFailedInitializeError) Error() string {
+ return fmt.Sprintf("failed initializing OAuth with error %v", e.Err)
+}
+
// Initializes the OAuth for eduvpn.
// It needs a vpn state that was gotten from `Register`
// It returns the authurl for the browser and an error if present
@@ -208,13 +269,13 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) {
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return "", detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr}
+ return "", &OAuthFailedInitializeError{Err: stateErr}
}
// Generate the verifier and challenge
- verifier, err := genVerifier()
- if err != nil {
- return "", detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err}
+ verifier, verifierErr := genVerifier()
+ if verifierErr != nil {
+ return "", &OAuthFailedInitializeError{Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -230,7 +291,7 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) {
authURL, urlErr := constructURL(vpnState.Endpoints.API.V3.Authorization, parameters)
- if urlErr != nil {
+ if urlErr != nil { // shouldn't happen
panic(urlErr)
}
@@ -249,46 +310,3 @@ func FinishOAuth(vpnState *EduVPNState) error {
}
return vpnState.OAuthSession.getHTTPTokenClient()
}
-
-// OAuthErrorCode Simplified error code for public interface.
-type OAuthErrorCode = VPNErrorCode
-type OAuthError = VPNError
-
-// detailedOAuthErrorCode used for unit tests.
-type detailedOAuthErrorCode = detailedVPNErrorCode
-type detailedOAuthError = detailedVPNError
-
-const (
- ErrGenError OAuthErrorCode = iota + 1
- ErrCallbackTokenError
-)
-
-const (
- errGenStateError detailedOAuthErrorCode = iota + 1
- errGenVerifierError
- errCallbackServerError
- errCallbackGetAuthCodeError
- errCallbackGetStateError
- errCallbackVerifyStateMatchError
- errCallbackGetTokenExchangeError
-)
-
-func (err detailedOAuthError) ToOAuthError() OAuthError {
- return RequestError{err.Code.ToOAuthErrorCode(), err}
-}
-
-func (code detailedOAuthErrorCode) ToOAuthErrorCode() OAuthErrorCode {
- switch code {
- case errGenStateError:
- case errGenVerifierError:
- return ErrGenError
-
- case errCallbackServerError:
- case errCallbackGetAuthCodeError:
- case errCallbackGetStateError:
- case errCallbackVerifyStateMatchError:
- case errCallbackGetTokenExchangeError:
- return ErrCallbackTokenError
- }
- panic("invalid detailedOAuthErrorCode")
-}