summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/oauth.go72
1 files changed, 31 insertions, 41 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index bab1de2..59d0061 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -10,7 +10,6 @@ import (
"net/url"
"time"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
@@ -65,7 +64,6 @@ type OAuth struct {
Token OAuthToken `json:"token"`
BaseAuthorizationURL string `json:"base_authorization_url"`
TokenURL string `json:"token_url"`
- FSM *fsm.FSM `json:"-"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -250,24 +248,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
-func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM) {
+func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) {
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
- oauth.FSM = fsm
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error {
+func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error {
errorMessage := "failed starting OAuth exchange"
- if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: oauth.FSM.Current,
- Want: fsm.OAUTH_STARTED,
- }.CustomError(),
- }
- }
// Generate the state
state, stateErr := genState()
if stateErr != nil {
@@ -300,29 +288,22 @@ func (oauth *OAuth) start(name string, postprocessAuth func(string) string) erro
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
oauth.Session = oauthSession
- // Run the state callback in the background so that the user can login while we start the callback server
- oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true)
+
+ // Run the auth callback with the authurl processed
+ doAuthErr := doAuth(postProcessAuth(authURL))
+ if doAuthErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ }
return nil
}
// Error definitions
func (oauth *OAuth) Finish() error {
- errorMessage := "failed finishing OAuth"
- if !oauth.FSM.HasTransition(fsm.AUTHORIZED) {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: oauth.FSM.Current,
- Want: fsm.AUTHORIZED,
- }.CustomError(),
- }
- }
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr}
+ return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr}
}
- oauth.FSM.GoTransition(fsm.AUTHORIZED)
return nil
}
@@ -334,9 +315,9 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error {
+func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error {
errorMessage := "failed OAuth login"
- authInitializeErr := oauth.start(name, postprocessAuth)
+ authInitializeErr := oauth.start(name, postprocessAuth, doAuth)
if authInitializeErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
@@ -350,28 +331,29 @@ func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) erro
return nil
}
-func (oauth *OAuth) NeedsRelogin() bool {
- // Access Token or Refresh Tokens empty, definitely needs a relogin
- if oauth.Token.Access == "" || oauth.Token.Refresh == "" {
- return true
+func (oauth *OAuth) EnsureTokens() error {
+ errorMessage := "failed ensuring OAuth tokens"
+ // Access Token or Refresh Tokens empty, we can not ensure the tokens
+ if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}}
}
// We have tokens...
-
// The tokens are not expired yet
- // No relogin is needed
+ // So they should be valid, re-login not needed
if !oauth.isTokensExpired() {
- return false
+ return nil
}
+ // Otherwise try to refresh them and return if successful
refreshErr := oauth.getTokensWithRefresh()
// We have obtained new tokens with refresh
- if refreshErr == nil {
- return false
+ if refreshErr != nil {
+ // We have failed to ensure the tokens due to refresh not working
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}}
}
- // Otherwise relogin is really needed
- return true
+ return nil
}
type OAuthCancelledCallbackError struct{}
@@ -397,3 +379,11 @@ type OAuthCallbackStateMatchError struct {
func (e *OAuthCallbackStateMatchError) Error() string {
return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
+
+type OAuthTokensInvalidError struct {
+ Cause string
+}
+
+func (e *OAuthTokensInvalidError) Error() string {
+ return fmt.Sprintf("tokens are invalid due to: %s", e.Cause)
+}