diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/oauth/oauth.go | 72 | ||||
| -rw-r--r-- | internal/server/api.go | 15 | ||||
| -rw-r--r-- | internal/server/common.go | 155 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 5 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 10 |
5 files changed, 90 insertions, 167 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index bab1de2..59d0061 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -10,7 +10,6 @@ import ( "net/url" "time" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -65,7 +64,6 @@ type OAuth struct { Token OAuthToken `json:"token"` BaseAuthorizationURL string `json:"base_authorization_url"` TokenURL string `json:"token_url"` - FSM *fsm.FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -250,24 +248,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { } } -func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM) { +func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) { oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL - oauth.FSM = fsm } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error { +func (oauth *OAuth) start(name string, postProcessAuth func(string) string, doAuth func(string) error) error { errorMessage := "failed starting OAuth exchange" - if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: oauth.FSM.Current, - Want: fsm.OAUTH_STARTED, - }.CustomError(), - } - } // Generate the state state, stateErr := genState() if stateErr != nil { @@ -300,29 +288,22 @@ func (oauth *OAuth) start(name string, postprocessAuth func(string) string) erro // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession - // Run the state callback in the background so that the user can login while we start the callback server - oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true) + + // Run the auth callback with the authurl processed + doAuthErr := doAuth(postProcessAuth(authURL)) + if doAuthErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + } return nil } // Error definitions func (oauth *OAuth) Finish() error { - errorMessage := "failed finishing OAuth" - if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: oauth.FSM.Current, - Want: fsm.AUTHORIZED, - }.CustomError(), - } - } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr} + return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr} } - oauth.FSM.GoTransition(fsm.AUTHORIZED) return nil } @@ -334,9 +315,9 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error { +func (oauth *OAuth) Login(name string, postprocessAuth func(string) string, doAuth func(string) error) error { errorMessage := "failed OAuth login" - authInitializeErr := oauth.start(name, postprocessAuth) + authInitializeErr := oauth.start(name, postprocessAuth, doAuth) if authInitializeErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} @@ -350,28 +331,29 @@ func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) erro return nil } -func (oauth *OAuth) NeedsRelogin() bool { - // Access Token or Refresh Tokens empty, definitely needs a relogin - if oauth.Token.Access == "" || oauth.Token.Refresh == "" { - return true +func (oauth *OAuth) EnsureTokens() error { + errorMessage := "failed ensuring OAuth tokens" + // Access Token or Refresh Tokens empty, we can not ensure the tokens + if oauth.Token.Access == "" && oauth.Token.Refresh == "" { + return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}} } // We have tokens... - // The tokens are not expired yet - // No relogin is needed + // So they should be valid, re-login not needed if !oauth.isTokensExpired() { - return false + return nil } + // Otherwise try to refresh them and return if successful refreshErr := oauth.getTokensWithRefresh() // We have obtained new tokens with refresh - if refreshErr == nil { - return false + if refreshErr != nil { + // We have failed to ensure the tokens due to refresh not working + return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}} } - // Otherwise relogin is really needed - return true + return nil } type OAuthCancelledCallbackError struct{} @@ -397,3 +379,11 @@ type OAuthCallbackStateMatchError struct { func (e *OAuthCallbackStateMatchError) Error() string { return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } + +type OAuthTokensInvalidError struct { + Cause string +} + +func (e *OAuthTokensInvalidError) Error() string { + return fmt.Sprintf("tokens are invalid due to: %s", e.Cause) +} diff --git a/internal/server/api.go b/internal/server/api.go index 57d91c6..80ecf2e 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -10,7 +10,6 @@ import ( httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/types" - "github.com/jwijenbergh/eduvpn-common/internal/util" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { @@ -51,19 +50,14 @@ func apiAuthorized( url := base.Endpoints.API.V3.API + endpoint - // Ensure we have valid tokens - stateBefore := base.FSM.Current + // Make sure the tokens are valid, this will return an error if re-login is needed oauthErr := EnsureTokens(server) - - // we reset the state so that we go from the authorized state to the state we want - base.FSM.Current = stateBefore - if oauthErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access) + headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server)) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -86,8 +80,8 @@ func apiAuthorizedRetry( // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { - // Tell the method that the token is expired - server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() + // Mark the token as expired and retry so we trigger the refresh flow + MarkTokenExpired(server) retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} @@ -205,6 +199,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, err } // This needs no further return value as it's best effort +// FIXME: doAuth should not be needed here func APIDisconnect(server Server) { apiAuthorized(server, http.MethodPost, "/disconnect", nil) } diff --git a/internal/server/common.go b/internal/server/common.go index c1ce074..801c778 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -22,7 +21,6 @@ type ServerBase struct { StartTime time.Time `json:"start_time"` EndTime time.Time `json:"expire_time"` Type string `json:"server_type"` - FSM *fsm.FSM `json:"-"` } type ServerType int8 @@ -214,7 +212,6 @@ func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) { func (servers *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, - fsm *fsm.FSM, ) (Server, error) { url := discoServer.BaseURL errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) @@ -244,7 +241,6 @@ func (servers *Servers) addInstituteAndCustom( discoServer.DisplayName, discoServer.Type, discoServer.SupportContact, - fsm, ) if instituteInitErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} @@ -256,16 +252,14 @@ func (servers *Servers) addInstituteAndCustom( func (servers *Servers) AddInstituteAccessServer( instituteServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(instituteServer, false, fsm) + return servers.addInstituteAndCustom(instituteServer, false) } func (servers *Servers) AddCustomServer( customServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { - return servers.addInstituteAndCustom(customServer, true, fsm) + return servers.addInstituteAndCustom(customServer, true) } func (servers *Servers) GetSecureLocation() string { @@ -274,11 +268,10 @@ func (servers *Servers) GetSecureLocation() string { func (servers *Servers) SetSecureLocation( chosenLocationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed to set secure location" // Make sure to add the current location - _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm) + _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) if addLocationErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} @@ -291,12 +284,11 @@ func (servers *Servers) SetSecureLocation( func (servers *Servers) AddSecureInternet( secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (Server, error) { errorMessage := "failed adding secure internet server" // If we have specified an organization ID // We also need to get an authorization template - initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm) + initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) if initErr != nil { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} @@ -342,24 +334,28 @@ func ShouldRenewButton(server Server) bool { return true } -func Login(server Server) error { - return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth()) +func Login(server Server, doAuth func(string) error) error { + return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth(), doAuth) } -func EnsureTokens(server Server) error { - errorMessage := "failed ensuring server tokens" - if server.GetOAuth().NeedsRelogin() { - loginErr := Login(server) +func GetHeaderToken(server Server) string { + return server.GetOAuth().Token.Access +} - if loginErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} - } +func MarkTokenExpired(server Server) { + server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime() +} + +func EnsureTokens(server Server) error { + ensureErr := server.GetOAuth().EnsureTokens() + if ensureErr != nil { + return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr} } return nil } func NeedsRelogin(server Server) bool { - return server.GetOAuth().NeedsRelogin() + return EnsureTokens(server) != nil } func CancelOAuth(server Server) { @@ -466,24 +462,45 @@ func openVPNGetConfig(server Server) (string, string, error) { return configOpenVPN, "openvpn", nil } -func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile" - base, baseErr := server.GetBase() +func HasValidProfile(server Server) (bool, error) { + errorMessage := "failed has valid profile check" + // Get new profiles using the info call + // This does not override the current profile + infoErr := APIInfo(server) + if infoErr != nil { + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} + } + + base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } - if !base.FSM.HasTransition(fsm.DISCONNECTED) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.DISCONNECTED, - }.CustomError(), + + // If there was a profile chosen and it doesn't exist anymore, reset it + if base.Profiles.Current != "" { + _, existsProfileErr := getCurrentProfile(server) + if existsProfileErr != nil { + base.Profiles.Current = "" } } - profile, profileErr := getCurrentProfile(server) + // Set the current profile if there is only one profile or profile is already selected + if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { + // Set the first profile if none is selected + if base.Profiles.Current == "" { + base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + } + return true, nil + } + + return false, nil +} + +func GetConfig(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration" + + profile, profileErr := getCurrentProfile(server) if profileErr != nil { return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } @@ -518,76 +535,6 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) return config, configType, nil } -func askForProfileID(server Server) error { - errorMessage := "failed asking for a server profile ID" - base, baseErr := server.GetBase() - - if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.HasTransition(fsm.ASK_PROFILE) { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateTransitionError{ - Got: base.FSM.Current, - Want: fsm.ASK_PROFILE, - }.CustomError(), - } - } - base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, &base.Profiles, false) - return nil -} - -func GetConfig(server Server, forceTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - base, baseErr := server.GetBase() - - if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - if !base.FSM.InState(fsm.REQUEST_CONFIG) { - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: fsm.WrongStateError{ - Got: base.FSM.Current, - Want: fsm.REQUEST_CONFIG, - }.CustomError(), - } - } - - // Get new profiles using the info call - // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} - } - - // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := getCurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" - } - } - - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - return getConfigWithProfile(server, forceTCP) - } - - profileErr := askForProfileID(server) - - if profileErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} - } - - return getConfigWithProfile(server, forceTCP) -} - func Disconnect(server Server) { APIDisconnect(server) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index e948480..0cad158 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" ) @@ -56,19 +55,17 @@ func (institute *InstituteAccessServer) init( displayName map[string]string, serverType string, supportContact []string, - fsm *fsm.FSM, ) error { errorMessage := fmt.Sprintf("failed initializing institute server %s", url) institute.Base.URL = url institute.Base.DisplayName = displayName institute.Base.SupportContact = supportContact - institute.Base.FSM = fsm institute.Base.Type = serverType endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } - institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm) + institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token) institute.Base.Endpoints = *endpoints return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 40c429b..d5689a8 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -3,7 +3,6 @@ package server import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" @@ -70,7 +69,6 @@ func (servers *Servers) HasSecureLocation() bool { func (secure *SecureInternetHomeServer) addLocation( locationServer *types.DiscoveryServer, - fsm *fsm.FSM, ) (*ServerBase, error) { errorMessage := "failed adding a location" // Initialize the base map if it is non-nil @@ -95,9 +93,6 @@ func (secure *SecureInternetHomeServer) addLocation( base.Endpoints = *endpoints } - // Pass the fsm - base.FSM = fsm - // Ensure it is in the map secure.BaseMap[locationServer.CountryCode] = base return base, nil @@ -107,7 +102,6 @@ func (secure *SecureInternetHomeServer) addLocation( func (secure *SecureInternetHomeServer) init( homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, - fsm *fsm.FSM, ) error { errorMessage := "failed initializing secure internet home server" @@ -123,14 +117,14 @@ func (secure *SecureInternetHomeServer) init( // Make sure to set the authorization URL template secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate - base, baseErr := secure.addLocation(homeLocation, fsm) + base, baseErr := secure.addLocation(homeLocation) if baseErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } // Make sure oauth contains our endpoints - secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm) + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) return nil } |
