diff options
Diffstat (limited to 'src/oauth.go')
| -rw-r--r-- | src/oauth.go | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/src/oauth.go b/src/oauth.go new file mode 100644 index 0000000..9697ac0 --- /dev/null +++ b/src/oauth.go @@ -0,0 +1,218 @@ +package eduvpn + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "golang.org/x/oauth2" +) + +// Generates a random base64 string to be used for state +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 +// "state": OPTIONAL. An opaque value used by the client to maintain +// state between the request and callback. The authorization server +// includes this value when redirecting the user agent back to the +// client. +func genState() (string, error) { + randomBytes, err := MakeRandomByteSlice(32) + + if err != nil { + return "", err + } + + // For consistency we also use raw url encoding here + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// Generates a sha256 base64 challenge from a verifier +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8 +func genChallengeS256(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + + // We use raw url encoding as the challenge does not accept padding + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + + +// Generates a verifier +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 +// The code_verifier is a unique high-entropy cryptographically random +// string generated for each authorization request, using the unreserved +// characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", with a +// minimum length of 43 characters and a maximum length of 128 +// characters. +func genVerifier() (string, error) { + randomBytes, err := MakeRandomByteSlice(32) + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + + +// This structure gets passed to the callback for easy access to the current state +type EduVPNOauth struct { + // Public + AuthURL string + Config *oauth2.Config + + // private + callbackError error + client *http.Client + context context.Context + state string + server *http.Server + verifier string +} + +// Initializes the OAuth eduvpn class. It returns a tuple of the class and error. +// If the error is non-nil, the class will be nil. +func InitializeOAuth(config *oauth2.Config) (*EduVPNOauth, error) { + // Generate the state + state, stateErr := genState() + if stateErr != nil { + return nil, detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} + } + + // Generate the verifier and challenge + verifier, err := genVerifier() + if err != nil { + return nil, detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} + } + challenge := genChallengeS256(verifier) + + // Update the auth url with the challenge and state + codeChallengeMethod := oauth2.SetAuthURLParam("code_challenge_method", "S256") + codeChallenge := oauth2.SetAuthURLParam("code_challenge", challenge) + authURL := config.AuthCodeURL(state, codeChallengeMethod, codeChallenge) + + // Return the struct with the necessary fields filled for the next call to getting the HTTP client + return &EduVPNOauth{AuthURL: authURL, Config: config, state: state, verifier: verifier}, nil +} + +// Gets an authenticated HTTP client by obtaining refresh and access tokens +func (eduvpn* EduVPNOauth) GetHTTPTokenClient() (*http.Client, error) { + eduvpn.context = context.Background() + mux := http.NewServeMux() + eduvpn.server = &http.Server { + Addr: "127.0.0.1:8000", + Handler: mux, + } + mux.HandleFunc("/callback", eduvpn.oauthCallback) + if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed { + return nil, detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} + } + return eduvpn.client, eduvpn.callbackError +} + + + +// Get the access and refresh tokens +// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 +// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +func (eduvpn *EduVPNOauth) getTokens(authCode string) error { + // Make sure the verifier is set as the parameter + // so that the server can verify that we are the actual owner of the authorization code + codeVerifier := oauth2.SetAuthURLParam("code_verifier", eduvpn.verifier) + + // This is from the oauth library that gets the tokens from the authcode + // https://pkg.go.dev/golang.org/x/oauth2#Config.Exchange + // We pass the verifier as an additional parameter + tok, err := eduvpn.Config.Exchange(eduvpn.context, authCode, codeVerifier) + if err != nil { + return err + } + + // Fill the http client for future use in the struct + eduvpn.client = eduvpn.Config.Client(eduvpn.context, tok) + return nil +} + +// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 +func (eduvpn *EduVPNOauth) oauthCallback(w http.ResponseWriter, req *http.Request) { + // Extract the authorization code + code, success := req.URL.Query()["code"] + if !success { + eduvpn.callbackError = detailedOAuthError{errCallbackGetAuthCodeError, fmt.Sprintf("oauth auth code cannot be retrieved"), nil} + go eduvpn.server.Shutdown(eduvpn.context) + return + } + // The code is the first entry + extractedCode := code[0] + + + // Make sure the state is present and matches to protect against cross-site request forgeries + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 + state, success := req.URL.Query()["state"] + if !success { + eduvpn.callbackError = detailedOAuthError{errCallbackGetStateError, fmt.Sprintf("oauth state cannot be retrieved"), nil} + go eduvpn.server.Shutdown(eduvpn.context) + return + } + // The state is the first entry + extractedState := state[0] + if extractedState != eduvpn.state { + eduvpn.callbackError = detailedOAuthError{errCallbackVerifyStateMatchError, fmt.Sprintf("oauth state does not match"), nil} + go eduvpn.server.Shutdown(eduvpn.context) + return + } + + // Now that we have obtained the authorization code, we can move to the next step: + // Obtaining the access and refresh tokens + err := eduvpn.getTokens(extractedCode) + + if err != nil { + eduvpn.callbackError = detailedOAuthError{errCallbackGetTokenExchangeError, fmt.Sprintf("oauth failed to get token in exchange"), err} + go eduvpn.server.Shutdown(eduvpn.context) + return + } + + // Shutdown the server as we're done listening + go eduvpn.server.Shutdown(eduvpn.context) +} + +// OAuthErrorCode Simplified error code for public interface. +type OAuthErrorCode = VPNErrorCode +type OAuthError = VPNError + +// detailedOAuthErrorCode used for unit tests. +type detailedOAuthErrorCode = detailedVPNErrorCode +type detailedOAuthError = detailedVPNError + +const ( + ErrGenError OAuthErrorCode = iota + 1 + ErrCallbackTokenError +) + +const ( + errGenStateError detailedOAuthErrorCode = iota + 1 + errGenVerifierError + errCallbackServerError + errCallbackGetAuthCodeError + errCallbackGetStateError + errCallbackVerifyStateMatchError + errCallbackGetTokenExchangeError +) + +func (err detailedOAuthError) ToOAuthError() OAuthError { + return RequestError{err.Code.ToOAuthErrorCode(), err} +} + +func (code detailedOAuthErrorCode) ToOAuthErrorCode() OAuthErrorCode { + switch code { + case errGenStateError: + case errGenVerifierError: + return ErrGenError + + case errCallbackServerError: + case errCallbackGetAuthCodeError: + case errCallbackGetStateError: + case errCallbackVerifyStateMatchError: + case errCallbackGetTokenExchangeError: + return ErrCallbackTokenError + } + panic("invalid detailedOAuthErrorCode") +} |
