summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-07 17:44:07 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-07 17:44:07 +0200
commite1bd5ec1c939f5431925ab3bb83352d0a275ebd9 (patch)
tree5272a8592b52757ca288e20a759c244ecb962a3b /internal
parent9be031fda160f7bb8e3294ab6620a1510828bd97 (diff)
Refactor: Remove the usage of the FSM in other internal packages
This removes the FSM from being imported and thus used in other internal packages such as `oauth` or `server`. The benefit is that it becomes much easier now to reason about the FSM as it's only used in the public package. Additionally, we do not have to re-initialize the server and the oauth structure with the FSM pointer.
Diffstat (limited to 'internal')
-rw-r--r--internal/oauth/oauth.go72
-rw-r--r--internal/server/api.go15
-rw-r--r--internal/server/common.go155
-rw-r--r--internal/server/instituteaccess.go5
-rw-r--r--internal/server/secureinternet.go10
5 files changed, 90 insertions, 167 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index bab1de2..59d0061 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -10,7 +10,6 @@ import (
"net/url"
"time"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
@@ -65,7 +64,6 @@ type OAuth struct {
Token OAuthToken `json:"token"`
BaseAuthorizationURL string `json:"base_authorization_url"`
TokenURL string `json:"token_url"`
- FSM *fsm.FSM `json:"-"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -250,24 +248,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
-func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM) {
+func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) {
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
- oauth.FSM = fsm
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error {
+func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error {
errorMessage := "failed starting OAuth exchange"
- if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: oauth.FSM.Current,
- Want: fsm.OAUTH_STARTED,
- }.CustomError(),
- }
- }
// Generate the state
state, stateErr := genState()
if stateErr != nil {
@@ -300,29 +288,22 @@ func (oauth *OAuth) start(name string, postprocessAuth func(string) string) erro
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
oauth.Session = oauthSession
- // Run the state callback in the background so that the user can login while we start the callback server
- oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true)
+
+ // Run the auth callback with the authurl processed
+ doAuthErr := doAuth(postProcessAuth(authURL))
+ if doAuthErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ }
return nil
}
// Error definitions
func (oauth *OAuth) Finish() error {
- errorMessage := "failed finishing OAuth"
- if !oauth.FSM.HasTransition(fsm.AUTHORIZED) {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: oauth.FSM.Current,
- Want: fsm.AUTHORIZED,
- }.CustomError(),
- }
- }
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr}
+ return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr}
}
- oauth.FSM.GoTransition(fsm.AUTHORIZED)
return nil
}
@@ -334,9 +315,9 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error {
+func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error {
errorMessage := "failed OAuth login"
- authInitializeErr := oauth.start(name, postprocessAuth)
+ authInitializeErr := oauth.start(name, postprocessAuth, doAuth)
if authInitializeErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
@@ -350,28 +331,29 @@ func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) erro
return nil
}
-func (oauth *OAuth) NeedsRelogin() bool {
- // Access Token or Refresh Tokens empty, definitely needs a relogin
- if oauth.Token.Access == "" || oauth.Token.Refresh == "" {
- return true
+func (oauth *OAuth) EnsureTokens() error {
+ errorMessage := "failed ensuring OAuth tokens"
+ // Access Token or Refresh Tokens empty, we can not ensure the tokens
+ if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}}
}
// We have tokens...
-
// The tokens are not expired yet
- // No relogin is needed
+ // So they should be valid, re-login not needed
if !oauth.isTokensExpired() {
- return false
+ return nil
}
+ // Otherwise try to refresh them and return if successful
refreshErr := oauth.getTokensWithRefresh()
// We have obtained new tokens with refresh
- if refreshErr == nil {
- return false
+ if refreshErr != nil {
+ // We have failed to ensure the tokens due to refresh not working
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}}
}
- // Otherwise relogin is really needed
- return true
+ return nil
}
type OAuthCancelledCallbackError struct{}
@@ -397,3 +379,11 @@ type OAuthCallbackStateMatchError struct {
func (e *OAuthCallbackStateMatchError) Error() string {
return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
+
+type OAuthTokensInvalidError struct {
+ Cause string
+}
+
+func (e *OAuthTokensInvalidError) Error() string {
+ return fmt.Sprintf("tokens are invalid due to: %s", e.Cause)
+}
diff --git a/internal/server/api.go b/internal/server/api.go
index 57d91c6..80ecf2e 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -10,7 +10,6 @@ import (
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/types"
- "github.com/jwijenbergh/eduvpn-common/internal/util"
)
func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) {
@@ -51,19 +50,14 @@ func apiAuthorized(
url := base.Endpoints.API.V3.API + endpoint
- // Ensure we have valid tokens
- stateBefore := base.FSM.Current
+ // Make sure the tokens are valid, this will return an error if re-login is needed
oauthErr := EnsureTokens(server)
-
- // we reset the state so that we go from the authorized state to the state we want
- base.FSM.Current = stateBefore
-
if oauthErr != nil {
return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
}
headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access)
+ headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server))
if opts.Headers != nil {
opts.Headers.Add(headerKey, headerValue)
} else {
@@ -86,8 +80,8 @@ func apiAuthorizedRetry(
// Only retry authorized if we get a HTTP 401
if errors.As(bodyErr, &error) && error.Status == 401 {
- // Tell the method that the token is expired
- server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime()
+ // Mark the token as expired and retry so we trigger the refresh flow
+ MarkTokenExpired(server)
retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
if retryErr != nil {
return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr}
@@ -205,6 +199,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, err
}
// This needs no further return value as it's best effort
+// FIXME: doAuth should not be needed here
func APIDisconnect(server Server) {
apiAuthorized(server, http.MethodPost, "/disconnect", nil)
}
diff --git a/internal/server/common.go b/internal/server/common.go
index c1ce074..801c778 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -4,7 +4,6 @@ import (
"fmt"
"time"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
"github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
@@ -22,7 +21,6 @@ type ServerBase struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"expire_time"`
Type string `json:"server_type"`
- FSM *fsm.FSM `json:"-"`
}
type ServerType int8
@@ -214,7 +212,6 @@ func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) {
func (servers *Servers) addInstituteAndCustom(
discoServer *types.DiscoveryServer,
isCustom bool,
- fsm *fsm.FSM,
) (Server, error) {
url := discoServer.BaseURL
errorMessage := fmt.Sprintf("failed adding institute access server: %s", url)
@@ -244,7 +241,6 @@ func (servers *Servers) addInstituteAndCustom(
discoServer.DisplayName,
discoServer.Type,
discoServer.SupportContact,
- fsm,
)
if instituteInitErr != nil {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr}
@@ -256,16 +252,14 @@ func (servers *Servers) addInstituteAndCustom(
func (servers *Servers) AddInstituteAccessServer(
instituteServer *types.DiscoveryServer,
- fsm *fsm.FSM,
) (Server, error) {
- return servers.addInstituteAndCustom(instituteServer, false, fsm)
+ return servers.addInstituteAndCustom(instituteServer, false)
}
func (servers *Servers) AddCustomServer(
customServer *types.DiscoveryServer,
- fsm *fsm.FSM,
) (Server, error) {
- return servers.addInstituteAndCustom(customServer, true, fsm)
+ return servers.addInstituteAndCustom(customServer, true)
}
func (servers *Servers) GetSecureLocation() string {
@@ -274,11 +268,10 @@ func (servers *Servers) GetSecureLocation() string {
func (servers *Servers) SetSecureLocation(
chosenLocationServer *types.DiscoveryServer,
- fsm *fsm.FSM,
) error {
errorMessage := "failed to set secure location"
// Make sure to add the current location
- _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm)
+ _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer)
if addLocationErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr}
@@ -291,12 +284,11 @@ func (servers *Servers) SetSecureLocation(
func (servers *Servers) AddSecureInternet(
secureOrg *types.DiscoveryOrganization,
secureServer *types.DiscoveryServer,
- fsm *fsm.FSM,
) (Server, error) {
errorMessage := "failed adding secure internet server"
// If we have specified an organization ID
// We also need to get an authorization template
- initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm)
+ initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer)
if initErr != nil {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
@@ -342,24 +334,28 @@ func ShouldRenewButton(server Server) bool {
return true
}
-func Login(server Server) error {
- return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth())
+func Login(server Server, doAuth func(string) error) error {
+ return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth)
}
-func EnsureTokens(server Server) error {
- errorMessage := "failed ensuring server tokens"
- if server.GetOAuth().NeedsRelogin() {
- loginErr := Login(server)
+func GetHeaderToken(server Server) string {
+ return server.GetOAuth().Token.Access
+}
- if loginErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
- }
+func MarkTokenExpired(server Server) {
+ server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime()
+}
+
+func EnsureTokens(server Server) error {
+ ensureErr := server.GetOAuth().EnsureTokens()
+ if ensureErr != nil {
+ return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr}
}
return nil
}
func NeedsRelogin(server Server) bool {
- return server.GetOAuth().NeedsRelogin()
+ return EnsureTokens(server) != nil
}
func CancelOAuth(server Server) {
@@ -466,24 +462,45 @@ func openVPNGetConfig(server Server) (string, string, error) {
return configOpenVPN, "openvpn", nil
}
-func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) {
- errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile"
- base, baseErr := server.GetBase()
+func HasValidProfile(server Server) (bool, error) {
+ errorMessage := "failed has valid profile check"
+ // Get new profiles using the info call
+ // This does not override the current profile
+ infoErr := APIInfo(server)
+ if infoErr != nil {
+ return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr}
+ }
+
+ base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
- if !base.FSM.HasTransition(fsm.DISCONNECTED) {
- return "", "", &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: base.FSM.Current,
- Want: fsm.DISCONNECTED,
- }.CustomError(),
+
+ // If there was a profile chosen and it doesn't exist anymore, reset it
+ if base.Profiles.Current != "" {
+ _, existsProfileErr := getCurrentProfile(server)
+ if existsProfileErr != nil {
+ base.Profiles.Current = ""
}
}
- profile, profileErr := getCurrentProfile(server)
+ // Set the current profile if there is only one profile or profile is already selected
+ if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" {
+ // Set the first profile if none is selected
+ if base.Profiles.Current == "" {
+ base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
+ }
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func GetConfig(server Server, forceTCP bool) (string, string, error) {
+ errorMessage := "failed getting an OpenVPN/WireGuard configuration"
+
+ profile, profileErr := getCurrentProfile(server)
if profileErr != nil {
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr}
}
@@ -518,76 +535,6 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error)
return config, configType, nil
}
-func askForProfileID(server Server) error {
- errorMessage := "failed asking for a server profile ID"
- base, baseErr := server.GetBase()
-
- if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
- }
- if !base.FSM.HasTransition(fsm.ASK_PROFILE) {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateTransitionError{
- Got: base.FSM.Current,
- Want: fsm.ASK_PROFILE,
- }.CustomError(),
- }
- }
- base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false)
- return nil
-}
-
-func GetConfig(server Server, forceTCP bool) (string, string, error) {
- errorMessage := "failed getting an OpenVPN/WireGuard configuration"
- base, baseErr := server.GetBase()
-
- if baseErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
- }
- if !base.FSM.InState(fsm.REQUEST_CONFIG) {
- return "", "", &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: fsm.WrongStateError{
- Got: base.FSM.Current,
- Want: fsm.REQUEST_CONFIG,
- }.CustomError(),
- }
- }
-
- // Get new profiles using the info call
- // This does not override the current profile
- infoErr := APIInfo(server)
- if infoErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr}
- }
-
- // If there was a profile chosen and it doesn't exist anymore, reset it
- if base.Profiles.Current != "" {
- _, existsProfileErr := getCurrentProfile(server)
- if existsProfileErr != nil {
- base.Profiles.Current = ""
- }
- }
-
- // Set the current profile if there is only one profile or profile is already selected
- if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" {
- // Set the first profile if none is selected
- if base.Profiles.Current == "" {
- base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
- }
- return getConfigWithProfile(server, forceTCP)
- }
-
- profileErr := askForProfileID(server)
-
- if profileErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr}
- }
-
- return getConfigWithProfile(server, forceTCP)
-}
-
func Disconnect(server Server) {
APIDisconnect(server)
}
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
index e948480..0cad158 100644
--- a/internal/server/instituteaccess.go
+++ b/internal/server/instituteaccess.go
@@ -3,7 +3,6 @@ package server
import (
"fmt"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
"github.com/jwijenbergh/eduvpn-common/internal/types"
)
@@ -56,19 +55,17 @@ func (institute *InstituteAccessServer) init(
displayName map[string]string,
serverType string,
supportContact []string,
- fsm *fsm.FSM,
) error {
errorMessage := fmt.Sprintf("failed initializing institute server %s", url)
institute.Base.URL = url
institute.Base.DisplayName = displayName
institute.Base.SupportContact = supportContact
- institute.Base.FSM = fsm
institute.Base.Type = serverType
endpoints, endpointsErr := APIGetEndpoints(url)
if endpointsErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
}
- institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm)
+ institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token)
institute.Base.Endpoints = *endpoints
return nil
}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 40c429b..d5689a8 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -3,7 +3,6 @@ package server
import (
"fmt"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
"github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
@@ -70,7 +69,6 @@ func (servers *Servers) HasSecureLocation() bool {
func (secure *SecureInternetHomeServer) addLocation(
locationServer *types.DiscoveryServer,
- fsm *fsm.FSM,
) (*ServerBase, error) {
errorMessage := "failed adding a location"
// Initialize the base map if it is non-nil
@@ -95,9 +93,6 @@ func (secure *SecureInternetHomeServer) addLocation(
base.Endpoints = *endpoints
}
- // Pass the fsm
- base.FSM = fsm
-
// Ensure it is in the map
secure.BaseMap[locationServer.CountryCode] = base
return base, nil
@@ -107,7 +102,6 @@ func (secure *SecureInternetHomeServer) addLocation(
func (secure *SecureInternetHomeServer) init(
homeOrg *types.DiscoveryOrganization,
homeLocation *types.DiscoveryServer,
- fsm *fsm.FSM,
) error {
errorMessage := "failed initializing secure internet home server"
@@ -123,14 +117,14 @@ func (secure *SecureInternetHomeServer) init(
// Make sure to set the authorization URL template
secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate
- base, baseErr := secure.addLocation(homeLocation, fsm)
+ base, baseErr := secure.addLocation(homeLocation)
if baseErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
// Make sure oauth contains our endpoints
- secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm)
+ secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token)
return nil
}