summaryrefslogtreecommitdiff
path: root/state.go
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-07 17:44:07 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-07 17:44:07 +0200
commite1bd5ec1c939f5431925ab3bb83352d0a275ebd9 (patch)
tree5272a8592b52757ca288e20a759c244ecb962a3b /state.go
parent9be031fda160f7bb8e3294ab6620a1510828bd97 (diff)
Refactor: Remove the usage of the FSM in other internal packages
This removes the FSM from being imported and thus used in other internal packages such as `oauth` or `server`. The benefit is that it becomes much easier now to reason about the FSM as it's only used in the public package. Additionally, we do not have to re-initialize the server and the oauth structure with the FSM pointer.
Diffstat (limited to 'state.go')
-rw-r--r--state.go128
1 files changed, 80 insertions, 48 deletions
diff --git a/state.go b/state.go
index 9a077f6..7f2691a 100644
--- a/state.go
+++ b/state.go
@@ -126,36 +126,86 @@ func (state *VPNState) GoBack() error {
return nil
}
-func (state *VPNState) getConfig(
- chosenServer server.Server,
- forceTCP bool,
-) (string, string, error) {
- errorMessage := "failed to get a configuration for OpenVPN/Wireguard"
- if state.InFSMState(fsm.DEREGISTERED) {
- return "", "", &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.DeregisteredError{}.CustomError(),
- }
- }
+func (state *VPNState) doAuth(authURL string) error {
+ state.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true)
+ return nil
+}
+func (state *VPNState) ensureLogin(chosenServer server.Server) error {
// Relogin with oauth
// This moves the state to authorized
if server.NeedsRelogin(chosenServer) {
- loginErr := server.Login(chosenServer)
+ loginErr := server.Login(chosenServer, state.doAuth)
if loginErr != nil {
// We are possibly in oauth started
// Go back
state.GoBack()
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
+ return &types.WrappedErrorMessage{Message: "failed ensuring login", Err: loginErr}
}
- } else { // OAuth was valid, ensure we are in the authorized state
- state.FSM.GoTransition(fsm.AUTHORIZED)
}
+ // OAuth was valid, ensure we are in the authorized state
+ state.FSM.GoTransition(fsm.AUTHORIZED)
+ return nil
+}
+func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) {
+ loginErr := state.ensureLogin(chosenServer)
+ if loginErr != nil {
+ return "", "", loginErr
+ }
state.FSM.GoTransition(fsm.REQUEST_CONFIG)
- config, configType, configErr := server.GetConfig(chosenServer, forceTCP)
+ validProfile, profileErr := server.HasValidProfile(chosenServer)
+ if profileErr != nil {
+ return "", "", profileErr
+ }
+
+ // No valid profile, ask for one
+ if !validProfile {
+ askProfileErr := state.askProfile(chosenServer)
+ if askProfileErr != nil {
+ return "", "", askProfileErr
+ }
+ }
+
+ // We return the error otherwise we wrap it too much
+ return server.GetConfig(chosenServer, forceTCP)
+}
+
+func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) {
+ errorMessage := "failed authorized config retry"
+ config, configType, configErr := state.getConfigAuth(chosenServer, forceTCP)
+ if configErr != nil {
+ var error *oauth.OAuthTokensInvalidError
+
+ // Only retry if the error is that the tokens are invalid
+ if errors.As(configErr, &error) {
+ retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth(chosenServer, forceTCP)
+ if retryConfigErr != nil {
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: retryConfigErr}
+ }
+ return retryConfig, retryConfigType, nil
+ }
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
+ }
+ return config, configType, nil
+}
+
+
+func (state *VPNState) getConfig(
+ chosenServer server.Server,
+ forceTCP bool,
+) (string, string, error) {
+ errorMessage := "failed to get a configuration for OpenVPN/Wireguard"
+ if state.InFSMState(fsm.DEREGISTERED) {
+ return "", "", &types.WrappedErrorMessage{
+ Message: errorMessage,
+ Err: fsm.DeregisteredError{}.CustomError(),
+ }
+ }
+
+ config, configType, configErr := state.retryConfigAuth(chosenServer, forceTCP)
if configErr != nil {
// Go back
@@ -180,7 +230,7 @@ func (state *VPNState) SetSecureLocation(countryCode string) error {
return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
- setLocationErr := state.Servers.SetSecureLocation(server, &state.FSM)
+ setLocationErr := state.Servers.SetSecureLocation(server)
if setLocationErr != nil {
state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: setLocationErr}
@@ -188,6 +238,15 @@ func (state *VPNState) SetSecureLocation(countryCode string) error {
return nil
}
+func (state *VPNState) askProfile(chosenServer server.Server) error {
+ base, baseErr := chosenServer.GetBase()
+ if baseErr != nil {
+ return &types.WrappedErrorMessage{Message: "failed asking for profiles", Err: baseErr}
+ }
+ state.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false)
+ return nil
+}
+
func (state *VPNState) askSecureLocation() error {
locations := state.Discovery.GetSecureLocationList()
@@ -214,7 +273,7 @@ func (state *VPNState) addSecureInternetHomeServer(orgID string) (server.Server,
}
// Add the secure internet server
- server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer, &state.FSM)
+ server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer)
if serverErr != nil {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
@@ -309,7 +368,7 @@ func (state *VPNState) addInstituteServer(url string) (server.Server, error) {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr}
}
// Add the secure internet server
- server, serverErr := state.Servers.AddInstituteAccessServer(instituteServer, &state.FSM)
+ server, serverErr := state.Servers.AddInstituteAccessServer(instituteServer)
if serverErr != nil {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
@@ -336,7 +395,7 @@ func (state *VPNState) addCustomServer(url string) (server.Server, error) {
}
// A custom server is just an institute access server under the hood
- server, serverErr := state.Servers.AddCustomServer(customServer, &state.FSM)
+ server, serverErr := state.Servers.AddCustomServer(customServer)
if serverErr != nil {
state.RemoveCustomServer(url)
@@ -543,15 +602,6 @@ func (state *VPNState) SetDisconnected(cleanup bool) error {
return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
}
- oauthStructure := currentServer.GetOAuth()
-
- // Make sure the FSM is initialized
- oauthStructure.FSM = &state.FSM
- base, baseErr := currentServer.GetBase()
- if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
- }
- base.FSM = &state.FSM
server.Disconnect(currentServer)
}
@@ -569,25 +619,7 @@ func (state *VPNState) RenewSession() error {
return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
}
- oauthStructure := currentServer.GetOAuth()
- oauthStructure.Token = oauth.OAuthToken{
- Access: "",
- Refresh: "",
- Type: "",
- Expires: 0,
- ExpiredTimestamp: util.GetCurrentTime(),
- }
-
- // Make sure the FSM is initialized
- oauthStructure.FSM = &state.FSM
- base, baseErr := currentServer.GetBase()
- if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
- }
- base.FSM = &state.FSM
-
- loginErr := server.Login(currentServer)
-
+ loginErr := server.Login(currentServer, state.doAuth)
if loginErr != nil {
// Go back
state.GoBack()