diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-07 17:44:07 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-07 17:44:07 +0200 |
| commit | e1bd5ec1c939f5431925ab3bb83352d0a275ebd9 (patch) | |
| tree | 5272a8592b52757ca288e20a759c244ecb962a3b | |
| parent | 9be031fda160f7bb8e3294ab6620a1510828bd97 (diff) | |
Refactor: Remove the usage of the FSM in other internal packages
This removes the FSM from being imported and thus used in other
internal packages such as `oauth` or `server`. The benefit is that it
becomes much easier now to reason about the FSM as it's only used in
the public package. Additionally, we do not have to re-initialize the
server and the oauth structure with the FSM pointer.
| -rw-r--r-- | internal/oauth/oauth.go | 72 | ||||
| -rw-r--r-- | internal/server/api.go | 15 | ||||
| -rw-r--r-- | internal/server/common.go | 155 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 5 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 10 | ||||
| -rw-r--r-- | state.go | 128 | ||||
| -rw-r--r-- | state_test.go | 12 |
7 files changed, 173 insertions, 224 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index bab1de2..59d0061 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -10,7 +10,6 @@ import ( "net/url" "time" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -65,7 +64,6 @@ type OAuth struct { Token OAuthToken `json:"token"` BaseAuthorizationURL string `json:"base_authorization_url"` TokenURL string `json:"token_url"` - FSM *fsm.FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -250,24 +248,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { } } -func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM) { +func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) { oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL - oauth.FSM = fsm } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error { +func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error { errorMessage := "failed starting OAuth exchange" - if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: oauth.FSM.Current, - Want: fsm.OAUTH_STARTED, - }.CustomError(), - } - } // Generate the state state, stateErr := genState() if stateErr != nil { @@ -300,29 +288,22 @@ func (oauth *OAuth) start(name string, postprocessAuth func(string) string) erro // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession - // Run the state callback in the background so that the user can login while we start the callback server - oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true) + + // Run the auth callback with the authurl processed + doAuthErr := doAuth(postProcessAuth(authURL)) + if doAuthErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + } return nil } // Error definitions func (oauth *OAuth) Finish() error { - errorMessage := "failed finishing OAuth" - if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: oauth.FSM.Current, - Want: fsm.AUTHORIZED, - }.CustomError(), - } - } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr} + return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr} } - oauth.FSM.GoTransition(fsm.AUTHORIZED) return nil } @@ -334,9 +315,9 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error { +func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error { errorMessage := "failed OAuth login" - authInitializeErr := oauth.start(name, postprocessAuth) + authInitializeErr := oauth.start(name, postprocessAuth, doAuth) if authInitializeErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} @@ -350,28 +331,29 @@ func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) erro return nil } -func (oauth *OAuth) NeedsRelogin() bool { - // Access Token or Refresh Tokens empty, definitely needs a relogin - if oauth.Token.Access == "" || oauth.Token.Refresh == "" { - return true +func (oauth *OAuth) EnsureTokens() error { + errorMessage := "failed ensuring OAuth tokens" + // Access Token or Refresh Tokens empty, we can not ensure the tokens + if oauth.Token.Access == "" && oauth.Token.Refresh == "" { + return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}} } // We have tokens... - // The tokens are not expired yet - // No relogin is needed + // So they should be valid, re-login not needed if !oauth.isTokensExpired() { - return false + return nil } + // Otherwise try to refresh them and return if successful refreshErr := oauth.getTokensWithRefresh() // We have obtained new tokens with refresh - if refreshErr == nil { - return false + if refreshErr != nil { + // We have failed to ensure the tokens due to refresh not working + return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}} } - // Otherwise relogin is really needed - return true + return nil } type OAuthCancelledCallbackError struct{} @@ -397,3 +379,11 @@ type OAuthCallbackStateMatchError struct { func (e *OAuthCallbackStateMatchError) Error() string { return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } + +type OAuthTokensInvalidError struct { + Cause string +} + +func (e *OAuthTokensInvalidError) Error() string { + return fmt.Sprintf("tokens are invalid due to: %s", e.Cause) +} diff --git a/internal/server/api.go b/internal/server/api.go index 57d91c6..80ecf2e 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -10,7 +10,6 @@ import ( httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/types" - "github.com/jwijenbergh/eduvpn-common/internal/util" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { @@ -51,19 +50,14 @@ func apiAuthorized( url := base.Endpoints.API.V3.API + endpoint - // Ensure we have valid tokens - stateBefore := base.FSM.Current + // Make sure the tokens are valid, this will return an error if re-login is needed oauthErr := EnsureTokens(server) - - // we reset the state so that we go from the authorized state to the state we want - base.FSM.Current = stateBefore - if oauthErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access) + headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server)) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -86,8 +80,8 @@ func apiAuthorizedRetry( // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { - // Tell the method that the token is expired - server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() + // Mark the token as expired and retry so we trigger the refresh flow + MarkTokenExpired(server) retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} @@ -205,6 +199,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, err } // This needs no further return value as it's best effort +// FIXME: doAuth should not be needed here func APIDisconnect(server Server) { apiAuthorized(server, http.MethodPost, "/disconnect", nil) } diff --git a/internal/server/common.go b/internal/server/common.go index c1ce074..801c778 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -22,7 +21,6 @@ type ServerBase struct { StartTime time.Time `json:"start_time"` EndTime time.Time `json:"expire_time"` Type string `json:"server_type"` - FSM *fsm.FSM `json:"-"` } type ServerType int8 @@ -214,7 +212,6 @@ func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) { func (servers *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, - fsm *fsm.FSM, ) (Server, error) { url := discoServer.BaseURL errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) @@ -244,7 +241,6 @@ func (servers *Servers) addInstituteAndCustom( discoServer.DisplayName, discoServer.Type, discoServer.SupportContact, - fsm, ) if instituteInitErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} @@ -256,16 +252,14 @@ func (servers *Servers) addInstituteAndCustom( func (servers *Servers) AddInstituteAccessServer( instituteServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(instituteServer, false, fsm) + return servers.addInstituteAndCustom(instituteServer, false) } func (servers *Servers) AddCustomServer( customServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(customServer, true, fsm) + return servers.addInstituteAndCustom(customServer, true) } func (servers *Servers) GetSecureLocation() string { @@ -274,11 +268,10 @@ func (servers *Servers) GetSecureLocation() string { func (servers *Servers) SetSecureLocation( chosenLocationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed to set secure location" // Make sure to add the current location - _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm) + _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) if addLocationErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} @@ -291,12 +284,11 @@ func (servers *Servers) SetSecureLocation( func (servers *Servers) AddSecureInternet( secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { errorMessage := "failed adding secure internet server" // If we have specified an organization ID // We also need to get an authorization template - initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm) + initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) if initErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} @@ -342,24 +334,28 @@ func ShouldRenewButton(server Server) bool { return true } -func Login(server Server) error { - return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth()) +func Login(server Server, doAuth func(string) error) error { + return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth) } -func EnsureTokens(server Server) error { - errorMessage := "failed ensuring server tokens" - if server.GetOAuth().NeedsRelogin() { - loginErr := Login(server) +func GetHeaderToken(server Server) string { + return server.GetOAuth().Token.Access +} - if loginErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} - } +func MarkTokenExpired(server Server) { + server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() +} + +func EnsureTokens(server Server) error { + ensureErr := server.GetOAuth().EnsureTokens() + if ensureErr != nil { + return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr} } return nil } func NeedsRelogin(server Server) bool { - return server.GetOAuth().NeedsRelogin() + return EnsureTokens(server) != nil } func CancelOAuth(server Server) { @@ -466,24 +462,45 @@ func openVPNGetConfig(server Server) (string, string, error) { return configOpenVPN, "openvpn", nil } -func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile" - base, baseErr := server.GetBase() +func HasValidProfile(server Server) (bool, error) { + errorMessage := "failed has valid profile check" + // Get new profiles using the info call + // This does not override the current profile + infoErr := APIInfo(server) + if infoErr != nil { + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} + } + + base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } - if !base.FSM.HasTransition(fsm.DISCONNECTED) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.DISCONNECTED, - }.CustomError(), + + // If there was a profile chosen and it doesn't exist anymore, reset it + if base.Profiles.Current != "" { + _, existsProfileErr := getCurrentProfile(server) + if existsProfileErr != nil { + base.Profiles.Current = "" } } - profile, profileErr := getCurrentProfile(server) + // Set the current profile if there is only one profile or profile is already selected + if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { + // Set the first profile if none is selected + if base.Profiles.Current == "" { + base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + } + return true, nil + } + + return false, nil +} + +func GetConfig(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration" + + profile, profileErr := getCurrentProfile(server) if profileErr != nil { return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } @@ -518,76 +535,6 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) return config, configType, nil } -func askForProfileID(server Server) error { - errorMessage := "failed asking for a server profile ID" - base, baseErr := server.GetBase() - - if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.HasTransition(fsm.ASK_PROFILE) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.ASK_PROFILE, - }.CustomError(), - } - } - base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false) - return nil -} - -func GetConfig(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - base, baseErr := server.GetBase() - - if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.InState(fsm.REQUEST_CONFIG) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateError{ - Got: base.FSM.Current, - Want: fsm.REQUEST_CONFIG, - }.CustomError(), - } - } - - // Get new profiles using the info call - // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} - } - - // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := getCurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" - } - } - - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - return getConfigWithProfile(server, forceTCP) - } - - profileErr := askForProfileID(server) - - if profileErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} - } - - return getConfigWithProfile(server, forceTCP) -} - func Disconnect(server Server) { APIDisconnect(server) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index e948480..0cad158 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" ) @@ -56,19 +55,17 @@ func (institute *InstituteAccessServer) init( displayName map[string]string, serverType string, supportContact []string, - fsm *fsm.FSM, ) error { errorMessage := fmt.Sprintf("failed initializing institute server %s", url) institute.Base.URL = url institute.Base.DisplayName = displayName institute.Base.SupportContact = supportContact - institute.Base.FSM = fsm institute.Base.Type = serverType endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } - institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm) + institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token) institute.Base.Endpoints = *endpoints return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 40c429b..d5689a8 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -70,7 +69,6 @@ func (servers *Servers) HasSecureLocation() bool { func (secure *SecureInternetHomeServer) addLocation( locationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (*ServerBase, error) { errorMessage := "failed adding a location" // Initialize the base map if it is non-nil @@ -95,9 +93,6 @@ func (secure *SecureInternetHomeServer) addLocation( base.Endpoints = *endpoints } - // Pass the fsm - base.FSM = fsm - // Ensure it is in the map secure.BaseMap[locationServer.CountryCode] = base return base, nil @@ -107,7 +102,6 @@ func (secure *SecureInternetHomeServer) addLocation( func (secure *SecureInternetHomeServer) init( homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed initializing secure internet home server" @@ -123,14 +117,14 @@ func (secure *SecureInternetHomeServer) init( // Make sure to set the authorization URL template secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate - base, baseErr := secure.addLocation(homeLocation, fsm) + base, baseErr := secure.addLocation(homeLocation) if baseErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } // Make sure oauth contains our endpoints - secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm) + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) return nil } @@ -126,36 +126,86 @@ func (state *VPNState) GoBack() error { return nil } -func (state *VPNState) getConfig( - chosenServer server.Server, - forceTCP bool, -) (string, string, error) { - errorMessage := "failed to get a configuration for OpenVPN/Wireguard" - if state.InFSMState(fsm.DEREGISTERED) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.DeregisteredError{}.CustomError(), - } - } +func (state *VPNState) doAuth(authURL string) error { + state.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true) + return nil +} +func (state *VPNState) ensureLogin(chosenServer server.Server) error { // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { - loginErr := server.Login(chosenServer) + loginErr := server.Login(chosenServer, state.doAuth) if loginErr != nil { // We are possibly in oauth started // Go back state.GoBack() - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} + return &types.WrappedErrorMessage{Message: "failed ensuring login", Err: loginErr} } - } else { // OAuth was valid, ensure we are in the authorized state - state.FSM.GoTransition(fsm.AUTHORIZED) } + // OAuth was valid, ensure we are in the authorized state + state.FSM.GoTransition(fsm.AUTHORIZED) + return nil +} +func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) { + loginErr := state.ensureLogin(chosenServer) + if loginErr != nil { + return "", "", loginErr + } state.FSM.GoTransition(fsm.REQUEST_CONFIG) - config, configType, configErr := server.GetConfig(chosenServer, forceTCP) + validProfile, profileErr := server.HasValidProfile(chosenServer) + if profileErr != nil { + return "", "", profileErr + } + + // No valid profile, ask for one + if !validProfile { + askProfileErr := state.askProfile(chosenServer) + if askProfileErr != nil { + return "", "", askProfileErr + } + } + + // We return the error otherwise we wrap it too much + return server.GetConfig(chosenServer, forceTCP) +} + +func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) { + errorMessage := "failed authorized config retry" + config, configType, configErr := state.getConfigAuth(chosenServer, forceTCP) + if configErr != nil { + var error *oauth.OAuthTokensInvalidError + + // Only retry if the error is that the tokens are invalid + if errors.As(configErr, &error) { + retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth(chosenServer, forceTCP) + if retryConfigErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: retryConfigErr} + } + return retryConfig, retryConfigType, nil + } + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + } + return config, configType, nil +} + + +func (state *VPNState) getConfig( + chosenServer server.Server, + forceTCP bool, +) (string, string, error) { + errorMessage := "failed to get a configuration for OpenVPN/Wireguard" + if state.InFSMState(fsm.DEREGISTERED) { + return "", "", &types.WrappedErrorMessage{ + Message: errorMessage, + Err: fsm.DeregisteredError{}.CustomError(), + } + } + + config, configType, configErr := state.retryConfigAuth(chosenServer, forceTCP) if configErr != nil { // Go back @@ -180,7 +230,7 @@ func (state *VPNState) SetSecureLocation(countryCode string) error { return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } - setLocationErr := state.Servers.SetSecureLocation(server, &state.FSM) + setLocationErr := state.Servers.SetSecureLocation(server) if setLocationErr != nil { state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: setLocationErr} @@ -188,6 +238,15 @@ func (state *VPNState) SetSecureLocation(countryCode string) error { return nil } +func (state *VPNState) askProfile(chosenServer server.Server) error { + base, baseErr := chosenServer.GetBase() + if baseErr != nil { + return &types.WrappedErrorMessage{Message: "failed asking for profiles", Err: baseErr} + } + state.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false) + return nil +} + func (state *VPNState) askSecureLocation() error { locations := state.Discovery.GetSecureLocationList() @@ -214,7 +273,7 @@ func (state *VPNState) addSecureInternetHomeServer(orgID string) (server.Server, } // Add the secure internet server - server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer, &state.FSM) + server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer) if serverErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} @@ -309,7 +368,7 @@ func (state *VPNState) addInstituteServer(url string) (server.Server, error) { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} } // Add the secure internet server - server, serverErr := state.Servers.AddInstituteAccessServer(instituteServer, &state.FSM) + server, serverErr := state.Servers.AddInstituteAccessServer(instituteServer) if serverErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} @@ -336,7 +395,7 @@ func (state *VPNState) addCustomServer(url string) (server.Server, error) { } // A custom server is just an institute access server under the hood - server, serverErr := state.Servers.AddCustomServer(customServer, &state.FSM) + server, serverErr := state.Servers.AddCustomServer(customServer) if serverErr != nil { state.RemoveCustomServer(url) @@ -543,15 +602,6 @@ func (state *VPNState) SetDisconnected(cleanup bool) error { return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} } - oauthStructure := currentServer.GetOAuth() - - // Make sure the FSM is initialized - oauthStructure.FSM = &state.FSM - base, baseErr := currentServer.GetBase() - if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - base.FSM = &state.FSM server.Disconnect(currentServer) } @@ -569,25 +619,7 @@ func (state *VPNState) RenewSession() error { return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} } - oauthStructure := currentServer.GetOAuth() - oauthStructure.Token = oauth.OAuthToken{ - Access: "", - Refresh: "", - Type: "", - Expires: 0, - ExpiredTimestamp: util.GetCurrentTime(), - } - - // Make sure the FSM is initialized - oauthStructure.FSM = &state.FSM - base, baseErr := currentServer.GetBase() - if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - base.FSM = &state.FSM - - loginErr := server.Login(currentServer) - + loginErr := server.Login(currentServer, state.doAuth) if loginErr != nil { // Go back state.GoBack() diff --git a/state_test.go b/state_test.go index 20a7064..2106f40 100644 --- a/state_test.go +++ b/state_test.go @@ -123,7 +123,6 @@ func test_connect_oauth_parameter( ) } go http.Get(url) - } }, false, @@ -216,10 +215,10 @@ func Test_token_expired(t *testing.T) { // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - infoErr := server.APIInfo(currentServer) + _, _, configErr = state.GetConfigCustomServer(serverURI, false) - if infoErr != nil { - t.Fatalf("Info error after expired: %v", infoErr) + if configErr != nil { + t.Fatalf("Connect error after expiry: %v", configErr) } // Check if tokens have changed @@ -256,11 +255,6 @@ func Test_token_invalid(t *testing.T) { t.Fatalf("Connect error before invalid: %v", configErr) } - // Go to request_config so we can re-authorize - // This is needed as the only actual authenticated requests we do in request_config (for profiles) and /connect - // /disconnect is best effort so this does not need re-auth - state.FSM.GoTransition(fsm.REQUEST_CONFIG) - dummy_value := "37" currentServer, serverErr := state.Servers.GetCurrentServer() |
