summaryrefslogtreecommitdiff
path: root/internal/oauth/oauth.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth/oauth.go')
-rw-r--r--internal/oauth/oauth.go141
1 files changed, 98 insertions, 43 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 1f3b719..232f68c 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -1,3 +1,8 @@
+// package oauth implement an oauth client defined in e.g. rfc 6749
+// However, we try to follow some recommendations from the v2.1 oauth draft RFC
+// Some specific things we implement here:
+// - PKCE (RFC 7636)
+// - ISS (RFC 9207)
package oauth
import (
@@ -18,7 +23,7 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
-// Generates a random base64 string to be used for state
+// genState generates a random base64 string to be used for state
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
// "state": OPTIONAL. An opaque value used by the client to maintain
// state between the request and callback. The authorization server
@@ -35,7 +40,7 @@ func genState() (string, error) {
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
}
-// Generates a sha256 base64 challenge from a verifier
+// genChallengeS256 generates a sha256 base64 challenge from a verifier
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8
func genChallengeS256(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
@@ -44,7 +49,7 @@ func genChallengeS256(verifier string) string {
return base64.RawURLEncoding.EncodeToString(hash[:])
}
-// Generates a verifier
+// genVerifier generates a verifier
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
// The code_verifier is a unique high-entropy cryptographically random
// string generated for each authorization request, using the unreserved
@@ -70,75 +75,108 @@ func genVerifier() (string, error) {
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
}
+// OAuth defines the main structure for this package
type OAuth struct {
+ // ISS indicates the issuer indentifier of the authorization server as defined in RFC 9207
ISS string `json:"iss"`
- Session OAuthExchangeSession `json:"-"`
+
+ // Token is where the access and refresh tokens are stored along with the timestamps
Token OAuthToken `json:"token"`
+
+ // BaseAuthorizationURL is the URL where authorization should take place
BaseAuthorizationURL string `json:"base_authorization_url"`
+
+ // TokenURL is the URL where tokens should be obtained
TokenURL string `json:"token_url"`
+
+ // session is the internal in progress OAuth session
+ session OAuthExchangeSession `json:"-"`
}
-// This structure gets passed to the callback for easy access to the current state
+// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state
type OAuthExchangeSession struct {
- // returned from the callback
+ // CallbackError indicates an error returned by the server
CallbackError error
- // filled in in initialize
+ // ClientID is the ID of the OAuth client
ClientID string
+
+ // ISS indicates the issuer inditifer
ISS string
+
+ // State is the expected URL state paremeter
State string
+
+ // Verifier is the preimage of the challenge
Verifier string
- // filled in when constructing the callback
+ // Context is the context used for cancellation
Context context.Context
+
+ // Server is the server of the session
Server *http.Server
+
+ // Listener is the listener where the servers 'listens' on
Listener net.Listener
}
-// Struct that defines the json format for /.well-known/vpn-user-portal"
+// OAuthToken is a structure that defines the json format for /.well-known/vpn-user-portal"
type OAuthToken struct {
+ // Access is the access token returned by the server
Access string `json:"access_token"`
+
+ // Refresh token is the refresh token returned by the server
Refresh string `json:"refresh_token"`
+
+ // Type indicates which type of tokens we have
Type string `json:"token_type"`
+
+ // Expires is the expires time returned by the server
Expires int64 `json:"expires_in"`
+
+ // ExpiredTimestamp is the Expires field but converted to a Go timestamp
ExpiredTimestamp time.Time `json:"expires_in_timestamp"`
}
-// Sets up a listener
+// setupListener sets up an OAuth listener
+// If it was unsuccessful it returns an error
func (oauth *OAuth) setupListener() error {
errorMessage := "failed setting up listener"
- oauth.Session.Context = context.Background()
+ oauth.session.Context = context.Background()
// create a listener
listener, listenerErr := net.Listen("tcp", ":0")
if listenerErr != nil {
return types.NewWrappedError(errorMessage, listenerErr)
}
- oauth.Session.Listener = listener
+ oauth.session.Listener = listener
return nil
}
+// getTokensWithCallback gets the OAuth tokens using a local web server
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithCallback() error {
errorMessage := "failed getting tokens with callback"
- if oauth.Session.Listener == nil {
+ if oauth.session.Listener == nil {
return types.NewWrappedError(errorMessage, errors.New("no listener"))
}
mux := http.NewServeMux()
// server /callback over the listener address
- oauth.Session.Server = &http.Server{
+ oauth.session.Server = &http.Server{
Handler: mux,
}
mux.HandleFunc("/callback", oauth.Callback)
- if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed {
+ if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed {
return types.NewWrappedError(errorMessage, err)
}
- return oauth.Session.CallbackError
+ return oauth.session.CallbackError
}
-// Get the access and refresh tokens
+// getTokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
@@ -151,9 +189,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
}
data := url.Values{
- "client_id": {oauth.Session.ClientID},
+ "client_id": {oauth.session.ClientID},
"code": {authCode},
- "code_verifier": {oauth.Session.Verifier},
+ "code_verifier": {oauth.session.Verifier},
"grant_type": {"authorization_code"},
"redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
}
@@ -185,15 +223,17 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
return nil
}
+// isTokensExpired returns if the OAuth tokens are expired using the expired timestamp
func (oauth *OAuth) isTokensExpired() bool {
expiredTime := oauth.Token.ExpiredTimestamp
currentTime := time.Now()
return !currentTime.Before(expiredTime)
}
-// Get the access and refresh tokens with a previously received refresh token
+// getTokensWithRefresh gets the access and refresh tokens with a previously received refresh token
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithRefresh() error {
errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
@@ -228,7 +268,8 @@ func (oauth *OAuth) getTokensWithRefresh() error {
return nil
}
-// Adapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25
+// responseTemplate is the HTML template for the OAuth authorized response
+// this template was dapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25
const responseTemplate string = `
<!DOCTYPE html>
<html dir="ltr" xmlns="http://www.w3.org/1999/xhtml" lang="en"><head>
@@ -265,11 +306,14 @@ main {
</html>
`
+// oauthResponseHTML is a structure that is used to give back the OAuth response
type oauthResponseHTML struct {
Title string
Message string
}
+// writeResponseHTML writes the OAuth response using a response writer and the title + message
+// If it was unsuccessful it returns an error
func writeResponseHTML(w http.ResponseWriter, title string, message string) error {
errorMessage := "failed writing response HTML"
template, templateErr := template.New("oauth-response").Parse(responseTemplate)
@@ -287,21 +331,21 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
return nil
}
-//
-//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
+// Callback is the public function used to get the OAuth tokens using an authorization code callback
+// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
errorMessage := "failed callback to retrieve the authorization code"
// Shutdown after we're done
defer func() {
// writing the html is best effort
- if oauth.Session.CallbackError != nil {
+ if oauth.session.CallbackError != nil {
_ = writeResponseHTML(w, "Authorization Failed", "The authorization has failed. See the log file for more information.")
} else {
_ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
}
- if oauth.Session.Server != nil {
- go oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
+ if oauth.session.Server != nil {
+ go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}()
@@ -310,10 +354,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
urlQuery := req.URL.Query()
extractedISS := urlQuery.Get("iss")
if extractedISS != "" {
- if oauth.Session.ISS != extractedISS {
- oauth.Session.CallbackError = types.NewWrappedError(
+ if oauth.session.ISS != extractedISS {
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
- &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
+ &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS},
)
return
}
@@ -324,19 +368,19 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
extractedState := urlQuery.Get("state")
if extractedState == "" {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
)
return
}
// The state is the first entry
- if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = types.NewWrappedError(
+ if extractedState != oauth.session.State {
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackStateMatchError{
State: extractedState,
- ExpectedState: oauth.Session.State,
+ ExpectedState: oauth.session.State,
},
)
return
@@ -345,7 +389,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// No authorization code
extractedCode := urlQuery.Get("code")
if extractedCode == "" {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
)
@@ -356,7 +400,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
if getTokensErr != nil {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
getTokensErr,
)
@@ -364,22 +408,28 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
+// Init initializes OAuth with the following parameters:
+// - OAuth server issuer identification
+// - The URL used for authorization
+// - The URL to obtain new tokens
func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL string) {
oauth.ISS = iss
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
}
+// GetListenerPort gets the listener for the OAuth web server
+// It returns the port as an integer and an error if there is any
func (oauth OAuth) GetListenerPort() (int, error) {
errorMessage := "failed to get listener port"
- if oauth.Session.Listener == nil {
+ if oauth.session.Listener == nil {
return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener"))
}
- return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil
+ return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil
}
-// Starts the OAuth exchange for eduvpn.
+// GetAuthURL gets the authorization url to start the OAuth procedure
func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) {
errorMessage := "failed starting OAuth exchange"
@@ -398,7 +448,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, ISS: oauth.ISS, State: state, Verifier: verifier}
- oauth.Session = oauthSession
+ oauth.session = oauthSession
// set up the listener to get the redirect URI
listenerErr := oauth.setupListener()
@@ -432,7 +482,8 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
return postProcessAuth(authURL), nil
}
-// Error definitions
+// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
+// If it was unsuccessful it returns an error
func (oauth *OAuth) Exchange() error {
tokenErr := oauth.getTokensWithCallback()
@@ -442,17 +493,21 @@ func (oauth *OAuth) Exchange() error {
return nil
}
+// Cancel cancels the existing OAuth
+// TODO: Use context for this
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = types.NewWrappedErrorLevel(
+ oauth.session.CallbackError = types.NewWrappedErrorLevel(
types.ErrInfo,
"cancelled OAuth",
&OAuthCancelledCallbackError{},
)
- if oauth.Session.Server != nil {
- oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
+ if oauth.session.Server != nil {
+ oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}
+// EnsureTokens makes sure the OAuth tokens are still valid
+// if this cannot be guaranteed, it returns an error
func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens