summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/oauth/oauth.go34
-rw-r--r--internal/server/common.go8
-rw-r--r--state.go31
3 files changed, 34 insertions, 39 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 59d0061..89854f7 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -254,18 +254,18 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) {
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error {
+func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) {
errorMessage := "failed starting OAuth exchange"
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
}
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -282,23 +282,19 @@ func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAu
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
oauth.Session = oauthSession
- // Run the auth callback with the authurl processed
- doAuthErr := doAuth(postProcessAuth(authURL))
- if doAuthErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
- }
- return nil
+ // Return the url processed
+ return postProcessAuth(authURL), nil
}
// Error definitions
-func (oauth *OAuth) Finish() error {
+func (oauth *OAuth) Exchange() error {
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
@@ -315,22 +311,6 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error {
- errorMessage := "failed OAuth login"
- authInitializeErr := oauth.start(name, postprocessAuth, doAuth)
-
- if authInitializeErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
- }
-
- oauthErr := oauth.Finish()
-
- if oauthErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
- }
- return nil
-}
-
func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens
diff --git a/internal/server/common.go b/internal/server/common.go
index 801c778..64b8079 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -334,8 +334,12 @@ func ShouldRenewButton(server Server) bool {
return true
}
-func Login(server Server, doAuth func(string) error) error {
- return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth)
+func GetOAuthURL(server Server, name string) (string, error) {
+ return server.GetOAuth().GetAuthURL(name, server.GetTemplateAuth())
+}
+
+func OAuthExchange(server Server) error {
+ return server.GetOAuth().Exchange()
}
func GetHeaderToken(server Server) string {
diff --git a/state.go b/state.go
index d77d1b6..965c934 100644
--- a/state.go
+++ b/state.go
@@ -131,16 +131,24 @@ func (state *VPNState) doAuth(authURL string) error {
}
func (state *VPNState) ensureLogin(chosenServer server.Server) error {
+ errorMessage := "failed ensuring login"
// Relogin with oauth
// This moves the state to authorized
if server.NeedsRelogin(chosenServer) {
- loginErr := server.Login(chosenServer, state.doAuth)
+ url, urlErr := server.GetOAuthURL(chosenServer, state.FSM.Name)
- if loginErr != nil {
- // We are possibly in oauth started
- // Go back
+ state.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url, true)
+
+ if urlErr != nil {
state.GoBack()
- return &types.WrappedErrorMessage{Message: "failed ensuring login", Err: loginErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ }
+
+ exchangeErr := server.OAuthExchange(chosenServer)
+
+ if exchangeErr != nil {
+ state.GoBack()
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: exchangeErr}
}
}
// OAuth was valid, ensure we are in the authorized state
@@ -208,7 +216,6 @@ func (state *VPNState) getConfig(
if configErr != nil {
// Go back
- state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
@@ -255,7 +262,7 @@ func (state *VPNState) askSecureLocation() error {
// The state has changed, meaning setting the secure location was not successful
if state.FSM.Current != STATE_ASK_LOCATION {
// TODO: maybe a custom type for this errors.new?
- return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed setting secure location due to state change")}
+ return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed loading secure location")}
}
return nil
}
@@ -352,6 +359,7 @@ func (state *VPNState) GetConfigSecureInternet(
if serverErr != nil {
state.RemoveSecureInternet()
+ state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
@@ -397,7 +405,6 @@ func (state *VPNState) addCustomServer(url string) (server.Server, error) {
server, serverErr := state.Servers.AddCustomServer(customServer)
if serverErr != nil {
- state.RemoveCustomServer(url)
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
@@ -413,6 +420,7 @@ func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (stri
if serverErr != nil {
state.RemoveInstituteAccess(url)
+ state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
@@ -488,12 +496,14 @@ func (state *VPNState) SetProfileID(profileID string) error {
errorMessage := "failed to set the profile ID for the current server"
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
+ state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
base, baseErr := server.GetBase()
if baseErr != nil {
+ state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
base.Profiles.Current = profileID
@@ -618,10 +628,11 @@ func (state *VPNState) RenewSession() error {
return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
}
- loginErr := server.Login(currentServer, state.doAuth)
+ // FIXME: Delete tokens?
+
+ loginErr := state.ensureLogin(currentServer)
if loginErr != nil {
// Go back
- state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
}