diff options
| -rw-r--r-- | internal/oauth/oauth.go | 34 | ||||
| -rw-r--r-- | internal/server/common.go | 8 | ||||
| -rw-r--r-- | state.go | 31 |
3 files changed, 34 insertions, 39 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 59d0061..89854f7 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -254,18 +254,18 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) { } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error { +func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) { errorMessage := "failed starting OAuth exchange" // Generate the state state, stateErr := genState() if stateErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} } // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -282,23 +282,19 @@ func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAu authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession - // Run the auth callback with the authurl processed - doAuthErr := doAuth(postProcessAuth(authURL)) - if doAuthErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} - } - return nil + // Return the url processed + return postProcessAuth(authURL), nil } // Error definitions -func (oauth *OAuth) Finish() error { +func (oauth *OAuth) Exchange() error { tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { @@ -315,22 +311,6 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error { - errorMessage := "failed OAuth login" - authInitializeErr := oauth.start(name, postprocessAuth, doAuth) - - if authInitializeErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} - } - - oauthErr := oauth.Finish() - - if oauthErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} - } - return nil -} - func (oauth *OAuth) EnsureTokens() error { errorMessage := "failed ensuring OAuth tokens" // Access Token or Refresh Tokens empty, we can not ensure the tokens diff --git a/internal/server/common.go b/internal/server/common.go index 801c778..64b8079 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -334,8 +334,12 @@ func ShouldRenewButton(server Server) bool { return true } -func Login(server Server, doAuth func(string) error) error { - return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth) +func GetOAuthURL(server Server, name string) (string, error) { + return server.GetOAuth().GetAuthURL(name, server.GetTemplateAuth()) +} + +func OAuthExchange(server Server) error { + return server.GetOAuth().Exchange() } func GetHeaderToken(server Server) string { @@ -131,16 +131,24 @@ func (state *VPNState) doAuth(authURL string) error { } func (state *VPNState) ensureLogin(chosenServer server.Server) error { + errorMessage := "failed ensuring login" // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { - loginErr := server.Login(chosenServer, state.doAuth) + url, urlErr := server.GetOAuthURL(chosenServer, state.FSM.Name) - if loginErr != nil { - // We are possibly in oauth started - // Go back + state.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url, true) + + if urlErr != nil { state.GoBack() - return &types.WrappedErrorMessage{Message: "failed ensuring login", Err: loginErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + } + + exchangeErr := server.OAuthExchange(chosenServer) + + if exchangeErr != nil { + state.GoBack() + return &types.WrappedErrorMessage{Message: errorMessage, Err: exchangeErr} } } // OAuth was valid, ensure we are in the authorized state @@ -208,7 +216,6 @@ func (state *VPNState) getConfig( if configErr != nil { // Go back - state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } @@ -255,7 +262,7 @@ func (state *VPNState) askSecureLocation() error { // The state has changed, meaning setting the secure location was not successful if state.FSM.Current != STATE_ASK_LOCATION { // TODO: maybe a custom type for this errors.new? - return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed setting secure location due to state change")} + return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed loading secure location")} } return nil } @@ -352,6 +359,7 @@ func (state *VPNState) GetConfigSecureInternet( if serverErr != nil { state.RemoveSecureInternet() + state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } @@ -397,7 +405,6 @@ func (state *VPNState) addCustomServer(url string) (server.Server, error) { server, serverErr := state.Servers.AddCustomServer(customServer) if serverErr != nil { - state.RemoveCustomServer(url) return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } @@ -413,6 +420,7 @@ func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (stri if serverErr != nil { state.RemoveInstituteAccess(url) + state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } @@ -488,12 +496,14 @@ func (state *VPNState) SetProfileID(profileID string) error { errorMessage := "failed to set the profile ID for the current server" server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { + state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } base, baseErr := server.GetBase() if baseErr != nil { + state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } base.Profiles.Current = profileID @@ -618,10 +628,11 @@ func (state *VPNState) RenewSession() error { return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} } - loginErr := server.Login(currentServer, state.doAuth) + // FIXME: Delete tokens? + + loginErr := state.ensureLogin(currentServer) if loginErr != nil { // Go back - state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} } |
