summaryrefslogtreecommitdiff
path: root/internal/oauth.go
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-02 14:34:35 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-02 14:34:35 +0200
commit466450f0c47bdc614e66326d90e5fc6fb56ae732 (patch)
treea01518a58d50d2f8449d37dadecc40e35c9f1fe1 /internal/oauth.go
parenta2a8efdcaad3d9b1852b1367a7cd7e8c5860cecf (diff)
Refactor: Wrap most errors in a custom type
Diffstat (limited to 'internal/oauth.go')
-rw-r--r--internal/oauth.go113
1 files changed, 72 insertions, 41 deletions
diff --git a/internal/oauth.go b/internal/oauth.go
index 98af5a4..c13ea99 100644
--- a/internal/oauth.go
+++ b/internal/oauth.go
@@ -5,7 +5,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
- "errors"
"fmt"
"net/http"
"net/url"
@@ -20,7 +19,7 @@ import (
func genState() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenStateUnableError{Err: err}
+ return "", &OAuthGenStateError{Err: err}
}
// For consistency we also use raw url encoding here
@@ -46,7 +45,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenVerifierUnableError{Err: err}
+ return "", &OAuthGenVerifierError{Err: err}
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -95,7 +94,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
}
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed {
- return &OAuthFailedCallbackError{Addr: addr, Err: err}
+ return &OAuthCallbackError{Addr: addr, Err: err}
}
return oauth.Session.CallbackError
}
@@ -122,7 +121,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return bodyErr
+ return &OAuthAuthError{Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -160,7 +159,7 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return bodyErr
+ return &OAuthRefreshError{Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -181,7 +180,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Extract the authorization code
code, success := req.URL.Query()["code"]
if !success {
- oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()}
+ oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -192,14 +191,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state, success := req.URL.Query()["state"]
if !success {
- oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()}
+ oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
// The state is the first entry
extractedState := state[0]
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
+ oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -208,7 +207,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
err := oauth.getTokensWithAuthCode(extractedCode)
if err != nil {
- oauth.Session.CallbackError = &OAuthFailedCallbackGetTokensError{Err: err}
+ oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -225,18 +224,18 @@ func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) {
// Starts the OAuth exchange for eduvpn.
func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error {
if !oauth.FSM.HasTransition(OAUTH_STARTED) {
- return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", oauth.FSM.Current.String()))
+ return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED}
}
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return &OAuthFailedInitializeError{Err: stateErr}
+ return &OAuthInitializeError{Err: stateErr}
}
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return &OAuthFailedInitializeError{Err: verifierErr}
+ return &OAuthInitializeError{Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -252,8 +251,8 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
authURL, urlErr := HTTPConstructURL(authorizationURL, parameters)
- if urlErr != nil { // shouldn't happen
- panic(urlErr)
+ if urlErr != nil {
+ return &OAuthInitializeError{Err: urlErr}
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -268,12 +267,12 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
// Error definitions
func (oauth *OAuth) Finish() error {
if !oauth.FSM.HasTransition(AUTHORIZED) {
- return errors.New("invalid state to finish oauth")
+ return &FSMWrongStateError{Got: oauth.FSM.Current, Want: AUTHORIZED}
}
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return tokenErr
+ return &OAuthFinishError{Err: tokenErr}
}
oauth.FSM.GoTransition(AUTHORIZED)
return nil
@@ -288,13 +287,13 @@ func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string)
authInitializeErr := oauth.start(name, authorizationURL, tokenURL)
if authInitializeErr != nil {
- return authInitializeErr
+ return &OAuthLoginError{Err: authInitializeErr}
}
oauthErr := oauth.Finish()
if oauthErr != nil {
- return oauthErr
+ return &OAuthLoginError{Err: oauthErr}
}
return nil
}
@@ -329,64 +328,96 @@ func (oauth *OAuth) NeedsRelogin() bool {
type OAuthCancelledCallbackError struct{}
func (e *OAuthCancelledCallbackError) Error() string {
- return fmt.Sprintf("Client cancelled OAuth")
+ return fmt.Sprintf("client cancelled OAuth")
}
-type OAuthGenStateUnableError struct {
+type OAuthGenStateError struct {
Err error
}
-func (e *OAuthGenStateUnableError) Error() string {
- return fmt.Sprintf("failed generating state with error %v", e.Err)
+func (e *OAuthGenStateError) Error() string {
+ return fmt.Sprintf("failed generating state with error: %v", e.Err)
}
-type OAuthGenVerifierUnableError struct {
+type OAuthGenVerifierError struct {
Err error
}
-func (e *OAuthGenVerifierUnableError) Error() string {
- return fmt.Sprintf("failed generating verifier with error %v", e.Err)
+func (e *OAuthGenVerifierError) Error() string {
+ return fmt.Sprintf("failed generating verifier with error: %v", e.Err)
}
-type OAuthFailedCallbackError struct {
+type OAuthCallbackError struct {
Addr string
Err error
}
-func (e *OAuthFailedCallbackError) Error() string {
- return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err)
+func (e *OAuthCallbackError) Error() string {
+ return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err)
}
-type OAuthFailedCallbackParameterError struct {
+type OAuthCallbackParameterError struct {
Parameter string
URL string
}
-func (e *OAuthFailedCallbackParameterError) Error() string {
- return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL)
+func (e *OAuthCallbackParameterError) Error() string {
+ return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL)
}
-type OAuthFailedCallbackStateMatchError struct {
+type OAuthCallbackStateMatchError struct {
State string
ExpectedState string
}
-func (e *OAuthFailedCallbackStateMatchError) Error() string {
- return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState)
+func (e *OAuthCallbackStateMatchError) Error() string {
+ return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
-type OAuthFailedCallbackGetTokensError struct {
+type OAuthCallbackGetTokensError struct {
Err error
}
-func (e *OAuthFailedCallbackGetTokensError) Error() string {
- return fmt.Sprintf("failed getting tokens with error %v", e.Err)
+func (e *OAuthCallbackGetTokensError) Error() string {
+ return fmt.Sprintf("failed getting tokens with error: %v", e.Err)
}
-type OAuthFailedInitializeError struct {
+type OAuthFinishError struct {
Err error
}
-func (e *OAuthFailedInitializeError) Error() string {
- return fmt.Sprintf("failed initializing OAuth with error %v", e.Err)
+func (e *OAuthFinishError) Error() string {
+ return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err)
+}
+
+type OAuthLoginError struct {
+ Err error
+}
+
+func (e *OAuthLoginError) Error() string {
+ return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err)
+}
+
+type OAuthInitializeError struct {
+ Err error
+}
+
+func (e *OAuthInitializeError) Error() string {
+ return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err)
+}
+
+type OAuthAuthError struct {
+ Err error
+}
+
+func (e *OAuthAuthError) Error() string {
+ return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err)
+}
+
+type OAuthRefreshError struct {
+ Err error
+}
+
+func (e *OAuthRefreshError) Error() string {
+ return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err)
}