diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-07 17:44:07 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-07 17:44:07 +0200 |
| commit | e1bd5ec1c939f5431925ab3bb83352d0a275ebd9 (patch) | |
| tree | 5272a8592b52757ca288e20a759c244ecb962a3b /internal/server | |
| parent | 9be031fda160f7bb8e3294ab6620a1510828bd97 (diff) | |
Refactor: Remove the usage of the FSM in other internal packages
This removes the FSM from being imported and thus used in other
internal packages such as `oauth` or `server`. The benefit is that it
becomes much easier now to reason about the FSM as it's only used in
the public package. Additionally, we do not have to re-initialize the
server and the oauth structure with the FSM pointer.
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/api.go | 15 | ||||
| -rw-r--r-- | internal/server/common.go | 155 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 5 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 10 |
4 files changed, 59 insertions, 126 deletions
diff --git a/internal/server/api.go b/internal/server/api.go index 57d91c6..80ecf2e 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -10,7 +10,6 @@ import ( httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/types" - "github.com/jwijenbergh/eduvpn-common/internal/util" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { @@ -51,19 +50,14 @@ func apiAuthorized( url := base.Endpoints.API.V3.API + endpoint - // Ensure we have valid tokens - stateBefore := base.FSM.Current + // Make sure the tokens are valid, this will return an error if re-login is needed oauthErr := EnsureTokens(server) - - // we reset the state so that we go from the authorized state to the state we want - base.FSM.Current = stateBefore - if oauthErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access) + headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server)) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -86,8 +80,8 @@ func apiAuthorizedRetry( // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { - // Tell the method that the token is expired - server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() + // Mark the token as expired and retry so we trigger the refresh flow + MarkTokenExpired(server) retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} @@ -205,6 +199,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, err } // This needs no further return value as it's best effort +// FIXME: doAuth should not be needed here func APIDisconnect(server Server) { apiAuthorized(server, http.MethodPost, "/disconnect", nil) } diff --git a/internal/server/common.go b/internal/server/common.go index c1ce074..801c778 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -22,7 +21,6 @@ type ServerBase struct { StartTime time.Time `json:"start_time"` EndTime time.Time `json:"expire_time"` Type string `json:"server_type"` - FSM *fsm.FSM `json:"-"` } type ServerType int8 @@ -214,7 +212,6 @@ func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) { func (servers *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, - fsm *fsm.FSM, ) (Server, error) { url := discoServer.BaseURL errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) @@ -244,7 +241,6 @@ func (servers *Servers) addInstituteAndCustom( discoServer.DisplayName, discoServer.Type, discoServer.SupportContact, - fsm, ) if instituteInitErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} @@ -256,16 +252,14 @@ func (servers *Servers) addInstituteAndCustom( func (servers *Servers) AddInstituteAccessServer( instituteServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(instituteServer, false, fsm) + return servers.addInstituteAndCustom(instituteServer, false) } func (servers *Servers) AddCustomServer( customServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(customServer, true, fsm) + return servers.addInstituteAndCustom(customServer, true) } func (servers *Servers) GetSecureLocation() string { @@ -274,11 +268,10 @@ func (servers *Servers) GetSecureLocation() string { func (servers *Servers) SetSecureLocation( chosenLocationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed to set secure location" // Make sure to add the current location - _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm) + _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) if addLocationErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} @@ -291,12 +284,11 @@ func (servers *Servers) SetSecureLocation( func (servers *Servers) AddSecureInternet( secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { errorMessage := "failed adding secure internet server" // If we have specified an organization ID // We also need to get an authorization template - initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm) + initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) if initErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} @@ -342,24 +334,28 @@ func ShouldRenewButton(server Server) bool { return true } -func Login(server Server) error { - return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth()) +func Login(server Server, doAuth func(string) error) error { + return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth) } -func EnsureTokens(server Server) error { - errorMessage := "failed ensuring server tokens" - if server.GetOAuth().NeedsRelogin() { - loginErr := Login(server) +func GetHeaderToken(server Server) string { + return server.GetOAuth().Token.Access +} - if loginErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} - } +func MarkTokenExpired(server Server) { + server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() +} + +func EnsureTokens(server Server) error { + ensureErr := server.GetOAuth().EnsureTokens() + if ensureErr != nil { + return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr} } return nil } func NeedsRelogin(server Server) bool { - return server.GetOAuth().NeedsRelogin() + return EnsureTokens(server) != nil } func CancelOAuth(server Server) { @@ -466,24 +462,45 @@ func openVPNGetConfig(server Server) (string, string, error) { return configOpenVPN, "openvpn", nil } -func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile" - base, baseErr := server.GetBase() +func HasValidProfile(server Server) (bool, error) { + errorMessage := "failed has valid profile check" + // Get new profiles using the info call + // This does not override the current profile + infoErr := APIInfo(server) + if infoErr != nil { + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} + } + + base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } - if !base.FSM.HasTransition(fsm.DISCONNECTED) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.DISCONNECTED, - }.CustomError(), + + // If there was a profile chosen and it doesn't exist anymore, reset it + if base.Profiles.Current != "" { + _, existsProfileErr := getCurrentProfile(server) + if existsProfileErr != nil { + base.Profiles.Current = "" } } - profile, profileErr := getCurrentProfile(server) + // Set the current profile if there is only one profile or profile is already selected + if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { + // Set the first profile if none is selected + if base.Profiles.Current == "" { + base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + } + return true, nil + } + + return false, nil +} + +func GetConfig(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration" + + profile, profileErr := getCurrentProfile(server) if profileErr != nil { return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } @@ -518,76 +535,6 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) return config, configType, nil } -func askForProfileID(server Server) error { - errorMessage := "failed asking for a server profile ID" - base, baseErr := server.GetBase() - - if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.HasTransition(fsm.ASK_PROFILE) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.ASK_PROFILE, - }.CustomError(), - } - } - base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false) - return nil -} - -func GetConfig(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - base, baseErr := server.GetBase() - - if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.InState(fsm.REQUEST_CONFIG) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateError{ - Got: base.FSM.Current, - Want: fsm.REQUEST_CONFIG, - }.CustomError(), - } - } - - // Get new profiles using the info call - // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} - } - - // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := getCurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" - } - } - - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - return getConfigWithProfile(server, forceTCP) - } - - profileErr := askForProfileID(server) - - if profileErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} - } - - return getConfigWithProfile(server, forceTCP) -} - func Disconnect(server Server) { APIDisconnect(server) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index e948480..0cad158 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" ) @@ -56,19 +55,17 @@ func (institute *InstituteAccessServer) init( displayName map[string]string, serverType string, supportContact []string, - fsm *fsm.FSM, ) error { errorMessage := fmt.Sprintf("failed initializing institute server %s", url) institute.Base.URL = url institute.Base.DisplayName = displayName institute.Base.SupportContact = supportContact - institute.Base.FSM = fsm institute.Base.Type = serverType endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } - institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm) + institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token) institute.Base.Endpoints = *endpoints return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 40c429b..d5689a8 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -70,7 +69,6 @@ func (servers *Servers) HasSecureLocation() bool { func (secure *SecureInternetHomeServer) addLocation( locationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (*ServerBase, error) { errorMessage := "failed adding a location" // Initialize the base map if it is non-nil @@ -95,9 +93,6 @@ func (secure *SecureInternetHomeServer) addLocation( base.Endpoints = *endpoints } - // Pass the fsm - base.FSM = fsm - // Ensure it is in the map secure.BaseMap[locationServer.CountryCode] = base return base, nil @@ -107,7 +102,6 @@ func (secure *SecureInternetHomeServer) addLocation( func (secure *SecureInternetHomeServer) init( homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed initializing secure internet home server" @@ -123,14 +117,14 @@ func (secure *SecureInternetHomeServer) init( // Make sure to set the authorization URL template secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate - base, baseErr := secure.addLocation(homeLocation, fsm) + base, baseErr := secure.addLocation(homeLocation) if baseErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } // Make sure oauth contains our endpoints - secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm) + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) return nil } |
