summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
authorAleksandar Pesic <peske.nis@gmail.com>2022-12-04 21:48:20 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-12 13:26:51 +0100
commit3ac1d35257b56cca92ad0eb7f4d18abb366cf105 (patch)
tree432db14d1f92a252518f371be420fa0d3ef044c8 /internal/oauth
parent37bca013bd4405548b274ac473acf959ad661ee6 (diff)
simplify error handling
fixes #6 Signed-off-by: Aleksandar Pesic <peske.nis@gmail.com>
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/oauth.go307
1 files changed, 105 insertions, 202 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 802295d..3dcd3d3 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -1,4 +1,4 @@
-// package oauth implement an oauth client defined in e.g. rfc 6749
+// Package oauth implement an oauth client defined in e.g. rfc 6749
// However, we try to follow some recommendations from the v2.1 oauth draft RFC
// Some specific things we implement here:
// - PKCE (RFC 7636)
@@ -10,7 +10,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
- "errors"
"fmt"
"html/template"
"net"
@@ -20,7 +19,7 @@ import (
httpw "github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/util"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
// genState generates a random base64 string to be used for state
@@ -31,13 +30,13 @@ import (
// client.
// We implement it similarly to the verifier.
func genState() (string, error) {
- randomBytes, err := util.MakeRandomByteSlice(32)
+ bts, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", types.NewWrappedError("failed generating an OAuth state", err)
+ return "", err
}
- // For consistency we also use raw url encoding here
- return base64.RawURLEncoding.EncodeToString(randomBytes), nil
+ // For consistency, we also use raw url encoding here
+ return base64.RawURLEncoding.EncodeToString(bts), nil
}
// genChallengeS256 generates a sha256 base64 challenge from a verifier
@@ -68,10 +67,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", types.NewWrappedError(
- "failed generating an OAuth verifier",
- err,
- )
+ return "", err
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -89,10 +85,10 @@ type OAuth struct {
TokenURL string `json:"token_url"`
// session is the internal in progress OAuth session
- session ExchangeSession `json:"-"`
+ session ExchangeSession
// Token is where the access and refresh tokens are stored along with the timestamps
- token Token `json:"-"`
+ token Token
}
// ExchangeSession is a structure that gets passed to the callback for easy access to the current state.
@@ -126,39 +122,31 @@ type ExchangeSession struct {
// It returns the access token as a string, possibly obtained fresh using the Refresh Token
// If the token cannot be obtained, an error is returned and the token is an empty string.
func (oauth *OAuth) AccessToken() (string, error) {
- errorMessage := "failed getting access token"
- tokens := oauth.token
+ ts := oauth.token
// We have tokens...
// The tokens are not expired yet
- // So they should be valid, re-authorization not needed
- if !tokens.Expired() {
- return tokens.access, nil
+ // So they should be valid, re-login not needed
+ if !ts.Expired() {
+ return ts.access, nil
}
// Check if refresh is even possible by doing a simple check if the refresh token is empty
// This is not needed but reduces API calls to the server
- if tokens.refresh == "" {
- return "", types.NewWrappedError(
- errorMessage,
- &TokensInvalidError{Cause: "no refresh token is present"},
- )
+ if ts.refresh == "" {
+ return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
}
// Otherwise refresh and then later return the access token if we are successful
- refreshErr := oauth.tokensWithRefresh()
- if refreshErr != nil {
+ err := oauth.tokensWithRefresh()
+ if err != nil {
// We have failed to ensure the tokens due to refresh not working
- return "", types.NewWrappedError(
- errorMessage,
- &TokensInvalidError{
- Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
- },
- )
+ return "", errors.Wrap(
+ &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0)
}
// We have obtained new tokens with refresh
- return tokens.access, nil
+ return ts.access, nil
}
// setupListener sets up an OAuth listener
@@ -166,24 +154,22 @@ func (oauth *OAuth) AccessToken() (string, error) {
// @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2
// "Loopback Interface Redirection".
func (oauth *OAuth) setupListener() error {
- errorMessage := "failed setting up listener"
oauth.session.Context = context.Background()
// create a listener
- listener, listenerErr := net.Listen("tcp", "127.0.0.1:0")
- if listenerErr != nil {
- return types.NewWrappedError(errorMessage, listenerErr)
+ lst, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return errors.WrapPrefix(err, "net.Listen failed", 0)
}
- oauth.session.Listener = listener
+ oauth.session.Listener = lst
return nil
}
// tokensWithCallback gets the OAuth tokens using a local web server
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithCallback() error {
- errorMessage := "failed getting tokens with callback"
if oauth.session.Listener == nil {
- return types.NewWrappedError(errorMessage, errors.New("no listener"))
+ return errors.Errorf("failed getting tokens with callback: no listener")
}
mux := http.NewServeMux()
// server /callback over the listener address
@@ -196,7 +182,7 @@ func (oauth *OAuth) tokensWithCallback() error {
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed {
- return types.NewWrappedError(errorMessage, err)
+ return errors.WrapPrefix(err, "failed getting tokens with callback", 0)
}
return oauth.session.CallbackError
}
@@ -205,23 +191,18 @@ func (oauth *OAuth) tokensWithCallback() error {
// It calculates the expired timestamp by having a 'startTime' passed to it
// The URL that is input here is used for additional context.
func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error {
- responseStructure := TokenResponse{}
-
- jsonErr := json.Unmarshal(response, &responseStructure)
- if jsonErr != nil {
- return types.NewWrappedError(
- "failed filling OAuth tokens",
- &httpw.ParseJSONError{URL: url, Body: string(response), Err: jsonErr},
- )
- }
-
- internalStructure := Token{}
- internalStructure.expiredTimestamp = startTime.Add(
- time.Second * time.Duration(responseStructure.Expires),
- )
- internalStructure.access = responseStructure.Access
- internalStructure.refresh = responseStructure.Refresh
- oauth.token = internalStructure
+ res := TokenResponse{}
+
+ err := json.Unmarshal(response, &res)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0)
+ }
+
+ oauth.token = Token{
+ access: res.Access,
+ refresh: res.Refresh,
+ expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)),
+ }
return nil
}
@@ -240,14 +221,13 @@ func (oauth *OAuth) SetTokenRenew() {
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
- errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
- reqURL := oauth.TokenURL
+ u := oauth.TokenURL
- port, portErr := oauth.ListenerPort()
- if portErr != nil {
- return types.NewWrappedError(errorMessage, portErr)
+ port, err := oauth.ListenerPort()
+ if err != nil {
+ return err
}
data := url.Values{
@@ -257,21 +237,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
"grant_type": {"authorization_code"},
"redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
}
- headers := http.Header{
+ h := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
- opts := &httpw.OptionalParams{Headers: headers, Body: data}
- currentTime := time.Now()
- _, body, bodyErr := httpw.PostWithOpts(reqURL, opts)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ opts := &httpw.OptionalParams{Headers: h, Body: data}
+ now := time.Now()
+ _, body, err := httpw.PostWithOpts(u, opts)
+ if err != nil {
+ return err
}
- fillErr := oauth.fillToken(body, currentTime, reqURL)
- if fillErr != nil {
- return types.NewWrappedError(errorMessage, fillErr)
- }
- return nil
+ return oauth.fillToken(body, now, u)
}
// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token
@@ -279,27 +255,22 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithRefresh() error {
- errorMessage := "failed getting tokens with the refresh token"
- reqURL := oauth.TokenURL
+ u := oauth.TokenURL
data := url.Values{
"refresh_token": {oauth.token.refresh},
"grant_type": {"refresh_token"},
}
- headers := http.Header{
+ h := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
- opts := &httpw.OptionalParams{Headers: headers, Body: data}
- currentTime := time.Now()
- _, body, bodyErr := httpw.PostWithOpts(reqURL, opts)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ opts := &httpw.OptionalParams{Headers: h, Body: data}
+ now := time.Now()
+ _, body, err := httpw.PostWithOpts(u, opts)
+ if err != nil {
+ return err
}
- fillErr := oauth.fillToken(body, currentTime, reqURL)
- if fillErr != nil {
- return types.NewWrappedError(errorMessage, fillErr)
- }
- return nil
+ return oauth.fillToken(body, now, u)
}
// responseTemplate is the HTML template for the OAuth authorized response
@@ -349,27 +320,17 @@ type oauthResponseHTML struct {
// writeResponseHTML writes the OAuth response using a response writer and the title + message
// If it was unsuccessful it returns an error.
func writeResponseHTML(w http.ResponseWriter, title string, message string) error {
- errorMessage := "failed writing response HTML"
- template, templateErr := template.New("oauth-response").Parse(responseTemplate)
- if templateErr != nil {
- return types.NewWrappedError(errorMessage, templateErr)
+ t, err := template.New("oauth-response").Parse(responseTemplate)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed writing response HTML", 0)
}
- executeErr := template.Execute(w, oauthResponseHTML{
- Title: title,
- Message: message,
- })
- if executeErr != nil {
- return types.NewWrappedError(errorMessage, executeErr)
- }
- return nil
+ return t.Execute(w, oauthResponseHTML{Title: title, Message: message})
}
// Callback is the public function used to get the OAuth tokens using an authorization code callback
// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
- errorMessage := "failed callback to retrieve the authorization code"
-
// Shutdown after we're done
defer func() {
// writing the html is best effort
@@ -383,64 +344,49 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
_ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
}
if oauth.session.Server != nil {
- go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ go func() {
+ _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ }()
}
}()
// ISS: https://www.rfc-editor.org/rfc/rfc9207.html
// TODO: Make this a required parameter in the future
- urlQuery := req.URL.Query()
- extractedISS := urlQuery.Get("iss")
- if extractedISS != "" {
- if oauth.session.ISS != extractedISS {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS},
- )
+ q := req.URL.Query()
+ iss := q.Get("iss")
+ if iss != "" {
+ if oauth.session.ISS != iss {
+ oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'",
+ oauth.session.ISS, iss)
return
}
}
// Make sure the state is present and matches to protect against cross-site request forgeries
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
- extractedState := urlQuery.Get("state")
- if extractedState == "" {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackParameterError{Parameter: "state", URL: req.URL.String()},
- )
+ state := q.Get("state")
+ if state == "" {
+ oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL)
return
}
// The state is the first entry
- if extractedState != oauth.session.State {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackStateMatchError{
- State: extractedState,
- ExpectedState: oauth.session.State,
- },
- )
+ if state != oauth.session.State {
+ oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'",
+ oauth.session.State, state)
return
}
// No authorization code
- extractedCode := urlQuery.Get("code")
- if extractedCode == "" {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackParameterError{Parameter: "code", URL: req.URL.String()},
- )
+ code := q.Get("code")
+ if code == "" {
+ oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL)
return
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- getTokensErr := oauth.tokensWithAuthCode(extractedCode)
- if getTokensErr != nil {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- getTokensErr,
- )
+ if err := oauth.tokensWithAuthCode(code); err != nil {
+ oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0)
return
}
}
@@ -457,94 +403,78 @@ func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL strin
// ListenerPort gets the listener for the OAuth web server
// It returns the port as an integer and an error if there is any.
-func (oauth OAuth) ListenerPort() (int, error) {
- errorMessage := "failed to get listener port"
-
+func (oauth *OAuth) ListenerPort() (int, error) {
if oauth.session.Listener == nil {
- return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener"))
+ return 0, errors.Errorf("failed to get listener port")
}
return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil
}
// AuthURL gets the authorization url to start the OAuth procedure.
func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (string, error) {
- errorMessage := "failed starting OAuth exchange"
-
// Generate the verifier and challenge
- verifier, verifierErr := genVerifier()
- if verifierErr != nil {
- return "", types.NewWrappedError(errorMessage, verifierErr)
+ v, err := genVerifier()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "genVerifier error", 0)
}
- challenge := genChallengeS256(verifier)
// Generate the state
- state, stateErr := genState()
- if stateErr != nil {
- return "", types.NewWrappedError(errorMessage, stateErr)
+ state, err := genState()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "genState error", 0)
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
- oauthSession := ExchangeSession{
+ oauth.session = ExchangeSession{
ClientID: name,
ISS: oauth.ISS,
State: state,
- Verifier: verifier,
+ Verifier: v,
}
- oauth.session = oauthSession
// set up the listener to get the redirect URI
- listenerErr := oauth.setupListener()
- if listenerErr != nil {
- return "", types.NewWrappedError(errorMessage, stateErr)
+ if err = oauth.setupListener(); err != nil {
+ return "", errors.WrapPrefix(err, "oauth.setupListener error", 0)
}
// Get the listener port
- port, portErr := oauth.ListenerPort()
- if portErr != nil {
- return "", types.NewWrappedError(errorMessage, portErr)
+ port, err := oauth.ListenerPort()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "oauth.ListenerPort error", 0)
}
- parameters := map[string]string{
+ params := map[string]string{
"client_id": name,
"code_challenge_method": "S256",
- "code_challenge": challenge,
+ "code_challenge": genChallengeS256(v),
"response_type": "code",
"scope": "config",
"state": state,
"redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port),
}
- authURL, urlErr := httpw.ConstructURL(oauth.BaseAuthorizationURL, parameters)
+ u, err := httpw.ConstructURL(oauth.BaseAuthorizationURL, params)
- if urlErr != nil {
- return "", types.NewWrappedError(errorMessage, urlErr)
+ if err != nil {
+ return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0)
}
// Return the url processed
- return postProcessAuth(authURL), nil
+ return postProcessAuth(u), nil
}
// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
// If it was unsuccessful it returns an error.
func (oauth *OAuth) Exchange() error {
- tokenErr := oauth.tokensWithCallback()
-
- if tokenErr != nil {
- return types.NewWrappedError("failed finishing OAuth", tokenErr)
- }
- return nil
+ return oauth.tokensWithCallback()
}
// Cancel cancels the existing OAuth
// TODO: Use context for this.
func (oauth *OAuth) Cancel() {
- oauth.session.CallbackError = types.NewWrappedErrorLevel(
- types.ErrInfo,
- "cancelled OAuth",
- &CancelledCallbackError{},
- )
+ oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0)
if oauth.session.Server != nil {
- oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}
@@ -554,33 +484,6 @@ func (e *CancelledCallbackError) Error() string {
return "client cancelled OAuth"
}
-type CallbackParameterError struct {
- Parameter string
- URL string
-}
-
-func (e *CallbackParameterError) Error() string {
- return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL)
-}
-
-type CallbackStateMatchError struct {
- State string
- ExpectedState string
-}
-
-func (e *CallbackStateMatchError) Error() string {
- return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
-}
-
-type CallbackISSMatchError struct {
- ISS string
- ExpectedISS string
-}
-
-func (e *CallbackISSMatchError) Error() string {
- return fmt.Sprintf("failed matching ISS, got: %s, want: %s", e.ISS, e.ExpectedISS)
-}
-
type TokensInvalidError struct {
Cause string
}