summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-06-20 15:20:18 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-06-20 15:20:18 +0200
commit2252135fadb8c579ad27345e3203be755130e3cd (patch)
treeed5a530e85b43736fc0bc28c927cfa8488f9199b /internal/oauth
parent7af07c596166bf93b79a9d0816b1950dde360fb9 (diff)
Refactor: Errors to have one custom type that is to be wrapped
- For this an `internal/types` package is created with a custom error type - This custom error type can give back the cause and traceback of an error
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/oauth.go144
1 files changed, 38 insertions, 106 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index f6ed916..824db90 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -8,10 +8,12 @@ import (
"fmt"
"net/http"
"net/url"
+
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
- "github.com/jwijenbergh/eduvpn-common/internal/util"
"github.com/jwijenbergh/eduvpn-common/internal/log"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
+ "github.com/jwijenbergh/eduvpn-common/internal/util"
)
// Generates a random base64 string to be used for state
@@ -23,7 +25,7 @@ import (
func genState() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenStateError{Err: err}
+ return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err}
}
// For consistency we also use raw url encoding here
@@ -49,7 +51,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenVerifierError{Err: err}
+ return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth verifier", Err: err}
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -60,8 +62,8 @@ type OAuth struct {
Token OAuthToken `json:"token"`
BaseAuthorizationURL string `json:"base_authorization_url"`
TokenURL string `json:"token_url"`
- Logger *log.FileLogger `json:"-"`
- FSM *fsm.FSM `json:"-"`
+ Logger *log.FileLogger `json:"-"`
+ FSM *fsm.FSM `json:"-"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -99,7 +101,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
}
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed {
- return &OAuthCallbackError{Addr: addr, Err: err}
+ return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err}
}
return oauth.Session.CallbackError
}
@@ -108,9 +110,9 @@ func (oauth *OAuth) getTokensWithCallback() error {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
+ errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
-
reqURL := oauth.TokenURL
data := url.Values{
"client_id": {oauth.Session.ClientID},
@@ -126,7 +128,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := util.GenerateTimeSeconds()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &OAuthAuthError{Err: bodyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -134,7 +136,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}}
}
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
@@ -152,6 +154,7 @@ func (oauth *OAuth) isTokensExpired() bool {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
func (oauth *OAuth) getTokensWithRefresh() error {
+ errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
data := url.Values{
"refresh_token": {oauth.Token.Refresh},
@@ -164,14 +167,14 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := util.GenerateTimeSeconds()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &OAuthRefreshError{Err: bodyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
tokenStructure := OAuthToken{}
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}}
}
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
@@ -182,11 +185,15 @@ func (oauth *OAuth) getTokensWithRefresh() error {
//
//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
+ errorMessage := "failed callback to retrieve the authorization code"
// Extract the authorization code
code, success := req.URL.Query()["code"]
- if !success {
- oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}
+ // Shutdown after we're done
+ defer func() {
go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ }()
+ if !success {
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}}
return
}
// The code is the first entry
@@ -195,30 +202,25 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Make sure the state is present and matches to protect against cross-site request forgeries
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state, success := req.URL.Query()["state"]
+
if !success {
- oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}}
return
}
// The state is the first entry
extractedState := state[0]
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}}
return
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- err := oauth.getTokensWithAuthCode(extractedCode)
- if err != nil {
- oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
+ if getTokensErr != nil {
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: getTokensErr}
return
}
-
- // Shutdown the server as we're done listening
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
}
func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) {
@@ -235,19 +237,20 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.
// Starts the OAuth exchange for eduvpn.
func (oauth *OAuth) start(name string) error {
+ errorMessage := "failed starting OAuth exchange"
if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
- return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
}
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return &OAuthInitializeError{Err: stateErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
}
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return &OAuthInitializeError{Err: verifierErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -264,7 +267,7 @@ func (oauth *OAuth) start(name string) error {
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
- return &OAuthInitializeError{Err: urlErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -277,34 +280,36 @@ func (oauth *OAuth) start(name string) error {
// Error definitions
func (oauth *OAuth) Finish() error {
+ errorMessage := "failed finishing OAuth"
if !oauth.FSM.HasTransition(fsm.AUTHORIZED) {
- return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}.CustomError()}
}
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &OAuthFinishError{Err: tokenErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr}
}
oauth.FSM.GoTransition(fsm.AUTHORIZED)
return nil
}
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = &OAuthCancelledCallbackError{}
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: "failed cancelling OAuth", Err: &OAuthCancelledCallbackError{}}
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
func (oauth *OAuth) Login(name string) error {
+ errorMessage := "failed OAuth login"
authInitializeErr := oauth.start(name)
if authInitializeErr != nil {
- return &OAuthLoginError{Err: authInitializeErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
}
oauthErr := oauth.Finish()
if oauthErr != nil {
- return &OAuthLoginError{Err: oauthErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
}
return nil
}
@@ -342,31 +347,6 @@ func (e *OAuthCancelledCallbackError) Error() string {
return fmt.Sprintf("client cancelled OAuth")
}
-type OAuthGenStateError struct {
- Err error
-}
-
-func (e *OAuthGenStateError) Error() string {
- return fmt.Sprintf("failed generating state with error: %v", e.Err)
-}
-
-type OAuthGenVerifierError struct {
- Err error
-}
-
-func (e *OAuthGenVerifierError) Error() string {
- return fmt.Sprintf("failed generating verifier with error: %v", e.Err)
-}
-
-type OAuthCallbackError struct {
- Addr string
- Err error
-}
-
-func (e *OAuthCallbackError) Error() string {
- return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err)
-}
-
type OAuthCallbackParameterError struct {
Parameter string
URL string
@@ -384,51 +364,3 @@ type OAuthCallbackStateMatchError struct {
func (e *OAuthCallbackStateMatchError) Error() string {
return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
-
-type OAuthCallbackGetTokensError struct {
- Err error
-}
-
-func (e *OAuthCallbackGetTokensError) Error() string {
- return fmt.Sprintf("failed getting tokens with error: %v", e.Err)
-}
-
-type OAuthFinishError struct {
- Err error
-}
-
-func (e *OAuthFinishError) Error() string {
- return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err)
-}
-
-type OAuthLoginError struct {
- Err error
-}
-
-func (e *OAuthLoginError) Error() string {
- return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err)
-}
-
-type OAuthInitializeError struct {
- Err error
-}
-
-func (e *OAuthInitializeError) Error() string {
- return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err)
-}
-
-type OAuthAuthError struct {
- Err error
-}
-
-func (e *OAuthAuthError) Error() string {
- return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err)
-}
-
-type OAuthRefreshError struct {
- Err error
-}
-
-func (e *OAuthRefreshError) Error() string {
- return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err)
-}