summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-03 14:10:40 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-03 14:10:40 +0200
commit43604f7ffcbbf5b06ae481d2af7e66f6423f183f (patch)
tree6eb74ed54929edcfac61e5ca55078ab6670e0081
parent466450f0c47bdc614e66326d90e5fc6fb56ae732 (diff)
Refactor: Secure internet into a different type but with interface
-rw-r--r--internal/api.go54
-rw-r--r--internal/oauth.go29
-rw-r--r--internal/openvpn.go11
-rw-r--r--internal/server.go366
-rw-r--r--internal/wireguard.go12
-rw-r--r--state.go49
-rw-r--r--state_test.go28
7 files changed, 379 insertions, 170 deletions
diff --git a/internal/api.go b/internal/api.go
index da17f76..a987f00 100644
--- a/internal/api.go
+++ b/internal/api.go
@@ -10,22 +10,28 @@ import (
// Authorized wrappers on top of HTTP
// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth
-func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
+func apiAuthorized(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
opts = &HTTPOptionalParams{}
}
- url := server.Endpoints.API.V3.API + endpoint
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return nil, nil, baseErr
+ }
+
+ url := base.Endpoints.API.V3.API + endpoint
// Ensure we have valid tokens
- oauthErr := server.EnsureTokens()
+ oauthErr := EnsureTokens(server)
if oauthErr != nil {
return nil, nil, oauthErr
}
headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", server.OAuth.Token.Access)
+ headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access)
if opts.Headers != nil {
opts.Headers.Add(headerKey, headerValue)
} else {
@@ -34,17 +40,22 @@ func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOp
return HTTPMethodWithOpts(method, url, opts)
}
-func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
- header, body, bodyErr := server.apiAuthorized(method, endpoint, opts)
+func apiAuthorizedRetry(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
+ header, body, bodyErr := apiAuthorized(server, method, endpoint, opts)
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return nil, nil, &APIAuthorizedError{Err: baseErr}
+ }
if bodyErr != nil {
var error *HTTPStatusError
// Only retry authorized if we get a HTTP 401
if errors.As(bodyErr, &error) && error.Status == 401 {
- server.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error))
+ base.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error))
// Tell the method that the token is expired
- server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds()
- retryHeader, retryBody, retryErr := server.apiAuthorized(method, endpoint, opts)
+ server.GetOAuth().Token.ExpiredTimestamp = GenerateTimeSeconds()
+ retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
if retryErr != nil {
return nil, nil, &APIAuthorizedError{Err: retryErr}
}
@@ -55,8 +66,8 @@ func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *H
return header, body, nil
}
-func (server *Server) APIInfo() error {
- _, body, bodyErr := server.apiAuthorizedRetry(http.MethodGet, "/info", nil)
+func APIInfo(server Server) error {
+ _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil)
if bodyErr != nil {
return &APIInfoError{Err: bodyErr}
}
@@ -67,12 +78,17 @@ func (server *Server) APIInfo() error {
return &APIInfoError{Err: jsonErr}
}
- server.Profiles = structure
- server.ProfilesRaw = string(body)
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return &APIInfoError{Err: baseErr}
+ }
+ base.Profiles = structure
+ base.ProfilesRaw = string(body)
return nil
}
-func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (string, string, error) {
+func APIConnectWireguard(server Server, profile_id string, pubkey string) (string, string, error) {
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
@@ -82,7 +98,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
"profile_id": {profile_id},
"public_key": {pubkey},
}
- header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
+ header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
return "", "", &APIConnectWireguardError{Err: connectErr}
}
@@ -91,7 +107,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
return string(connectBody), expires, nil
}
-func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, error) {
+func APIConnectOpenVPN(server Server, profile_id string) (string, string, error) {
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
@@ -100,7 +116,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
urlForm := url.Values{
"profile_id": {profile_id},
}
- header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
+ header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
return "", "", &APIConnectOpenVPNError{Err: connectErr}
}
@@ -110,8 +126,8 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
}
// This needs no further return value as it's best effort
-func (server *Server) APIDisconnect() {
- server.apiAuthorizedRetry(http.MethodPost, "/disconnect", nil)
+func APIDisconnect(server Server) {
+ apiAuthorizedRetry(server, http.MethodPost, "/disconnect", nil)
}
type APIAuthorizedError struct {
diff --git a/internal/oauth.go b/internal/oauth.go
index c13ea99..c566425 100644
--- a/internal/oauth.go
+++ b/internal/oauth.go
@@ -52,11 +52,12 @@ func genVerifier() (string, error) {
}
type OAuth struct {
- Session OAuthExchangeSession `json:"-"`
- Token OAuthToken `json:"token"`
- TokenURL string `json:"token_url"`
- Logger *FileLogger `json:"-"`
- FSM *FSM `json:"-"`
+ Session OAuthExchangeSession `json:"-"`
+ Token OAuthToken `json:"token"`
+ BaseAuthorizationURL string `json:"base_authorization_url"`
+ TokenURL string `json:"token_url"`
+ Logger *FileLogger `json:"-"`
+ FSM *FSM `json:"-"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -216,13 +217,20 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
go oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) {
+func (oauth *OAuth) Update(fsm *FSM, logger *FileLogger) {
+ oauth.FSM = fsm
+ oauth.Logger = logger
+}
+
+func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *FSM, logger *FileLogger) {
+ oauth.BaseAuthorizationURL = baseAuthorizationURL
+ oauth.TokenURL = tokenURL
oauth.FSM = fsm
oauth.Logger = logger
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error {
+func (oauth *OAuth) start(name string) error {
if !oauth.FSM.HasTransition(OAUTH_STARTED) {
return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED}
}
@@ -249,7 +257,7 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
"redirect_uri": "http://127.0.0.1:8000/callback",
}
- authURL, urlErr := HTTPConstructURL(authorizationURL, parameters)
+ authURL, urlErr := HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
return &OAuthInitializeError{Err: urlErr}
@@ -257,7 +265,6 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
- oauth.TokenURL = tokenURL
oauth.Session = oauthSession
// Run the state callback in the background so that the user can login while we start the callback server
oauth.FSM.GoTransitionWithData(OAUTH_STARTED, authURL, true)
@@ -283,8 +290,8 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string) error {
- authInitializeErr := oauth.start(name, authorizationURL, tokenURL)
+func (oauth *OAuth) Login(name string) error {
+ authInitializeErr := oauth.start(name)
if authInitializeErr != nil {
return &OAuthLoginError{Err: authInitializeErr}
diff --git a/internal/openvpn.go b/internal/openvpn.go
index ed31fe2..45beb51 100644
--- a/internal/openvpn.go
+++ b/internal/openvpn.go
@@ -2,9 +2,14 @@ package internal
import "fmt"
-func (server *Server) OpenVPNGetConfig() (string, error) {
- profile_id := server.Profiles.Current
- configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id)
+func OpenVPNGetConfig(server Server) (string, error) {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return "", &OpenVPNGetConfigError{Err: baseErr}
+ }
+ profile_id := base.Profiles.Current
+ configOpenVPN, _, configErr := APIConnectOpenVPN(server, profile_id)
if configErr != nil {
return "", &OpenVPNGetConfigError{Err: configErr}
diff --git a/internal/server.go b/internal/server.go
index 489719e..7500a26 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -2,114 +2,225 @@ package internal
import (
"encoding/json"
- "errors"
"fmt"
)
-type Server struct {
- BaseURL string `json:"base_url"`
+// The base type for servers
+type ServerBase struct {
+ URL string `json:"base_url"`
Endpoints ServerEndpoints `json:"endpoints"`
- OAuth OAuth `json:"oauth"`
Profiles ServerProfileInfo `json:"profiles"`
ProfilesRaw string `json:"profiles_raw"`
Logger *FileLogger `json:"-"`
FSM *FSM `json:"-"`
}
-type Servers struct {
- List map[string]*Server `json:"list"`
- Current string `json:"current"`
- SecureHome string `json:"secure_home"`
+// An instute access server
+type InstituteAccessServer struct {
+ // An instute access server has its own OAuth
+ OAuth OAuth `json:"oauth"`
+
+ // Embed the server base
+ Base ServerBase `json:"base"`
+}
+
+// A secure internet server which has its own OAuth tokens
+// It specifies the current location url it is connected to
+type SecureInternetHomeServer struct {
+ OAuth OAuth `json:"oauth"`
+
+ // The home server has a list of info for each configured server
+ BaseMap map[string]*ServerBase `json:"base_map"`
+
+ // We have the home url and the current url
+ HomeURL string `json:"home_url"`
+ CurrentURL string `json:"current_url"`
}
-func (servers *Servers) GetCurrentServer() (*Server, error) {
- if servers.List == nil {
+type InstituteServers struct {
+ Map map[string]*InstituteAccessServer `json:"map"`
+ CurrentURL string `json:"current_url"`
+}
+
+func (servers *Servers) GetCurrentServer() (Server, error) {
+ if servers.IsSecureInternet {
+ return &servers.SecureInternetHomeServer, nil
+ }
+ currentInstitute := servers.InstituteServers.CurrentURL
+ institutes := servers.InstituteServers.Map
+ if institutes == nil {
return nil, &ServerGetCurrentNoMapError{}
}
- server, exists := servers.List[servers.Current]
+ institute, exists := institutes[currentInstitute]
- if !exists || server == nil {
+ if !exists || institute == nil {
return nil, &ServerGetCurrentNotFoundError{}
}
- return server, nil
+ return institute, nil
}
-func (server *Server) CancelOAuth() {
- server.OAuth.Cancel()
+type Servers struct {
+ InstituteServers InstituteServers `json:"institute_servers"`
+ SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"`
+ IsSecureInternet bool `json:"is_secure_internet"`
}
-func (server *Server) Init(url string, fsm *FSM, logger *FileLogger) error {
- server.BaseURL = url
- server.FSM = fsm
- server.Logger = logger
- server.OAuth.Init(fsm, logger)
- endpointsErr := server.GetEndpoints()
- if endpointsErr != nil {
- return &ServerInitializeError{URL: url, Err: endpointsErr}
+type Server interface {
+ // Gets the current OAuth object
+ GetOAuth() *OAuth
+
+ // Gets the server base
+ GetBase() (*ServerBase, error)
+
+ // initialize method
+ init(url string, fsm *FSM, logger *FileLogger) error
+}
+
+// For an institute, we can simply get the OAuth
+func (institute *InstituteAccessServer) GetOAuth() *OAuth {
+ return &institute.OAuth
+}
+
+func (secure *SecureInternetHomeServer) GetOAuth() *OAuth {
+ return &secure.OAuth
+}
+
+func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) {
+ return &institute.Base, nil
+}
+
+func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
+ if server.BaseMap == nil {
+ return nil, &ServerSecureInternetMapNotFoundError{}
}
- return nil
+
+ base, exists := server.BaseMap[server.CurrentURL]
+
+ if !exists {
+ return nil, &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}
+ }
+ return base, nil
}
-func (server *Server) EnsureTokens() error {
- if server.OAuth.NeedsRelogin() {
- server.Logger.Log(LOG_INFO, "OAuth: Tokens are invalid, relogging in")
- return server.Login()
+func (institute *InstituteAccessServer) init(url string, fsm *FSM, logger *FileLogger) error {
+ institute.Base.URL = url
+ institute.Base.FSM = fsm
+ institute.Base.Logger = logger
+ endpoints, endpointsErr := getEndpoints(url)
+ if endpointsErr != nil {
+ return &ServerInitializeError{URL: url, Err: endpointsErr}
}
+ institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm, logger)
+ institute.Base.Endpoints = *endpoints
return nil
}
-func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) {
- if url == "" {
- return nil, &ServerEnsureServerEmptyURLError{}
+func (secure *SecureInternetHomeServer) init(url string, fsm *FSM, logger *FileLogger) error {
+ // Initialize the base map if it is non-nil
+ if secure.BaseMap == nil {
+ secure.BaseMap = make(map[string]*ServerBase)
}
- if servers.List == nil {
- servers.List = make(map[string]*Server)
+
+ // Add it if not present
+ base, exists := secure.BaseMap[url]
+
+ if !exists || base == nil {
+ // Create the base to be added to the map
+ base = &ServerBase{}
+ base.URL = url
+ endpoints, endpointsErr := getEndpoints(url)
+ if endpointsErr != nil {
+ return &ServerInitializeError{URL: url, Err: endpointsErr}
+ }
+ base.Endpoints = *endpoints
}
- server, exists := servers.List[url]
+ // Pass the fsm and logger
+ base.FSM = fsm
+ base.Logger = logger
+
+ // Ensure it is in the map
+ secure.BaseMap[url] = base
- if !exists || server == nil {
- server = &Server{}
+ // Set the home url if it is not set yet
+ if secure.HomeURL == "" {
+ secure.HomeURL = url
+ // Make sure oauth contains our endpoints
+ secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger)
+ } else { // Else just pass in the fsm and logger
+ secure.OAuth.Update(fsm, logger)
}
- serverInitErr := server.Init(url, fsm, logger)
- if serverInitErr != nil {
- return nil, &ServerEnsureServerError{Err: serverInitErr}
+ // Set the current url
+ secure.CurrentURL = url
+ return nil
+}
+
+func Login(server Server) error {
+ return server.GetOAuth().Login("org.eduvpn.app.linux")
+}
+
+func EnsureTokens(server Server) error {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return &ServerEnsureTokensError{Err: baseErr}
}
- servers.List[url] = server
+ if server.GetOAuth().NeedsRelogin() {
+ base.Logger.Log(LOG_INFO, "OAuth: Tokens are invalid, relogging in")
+ loginErr := Login(server)
- if makeCurrent {
- servers.Current = url
+ if loginErr != nil {
+ return &ServerEnsureTokensError{Err: loginErr}
+ }
}
- return server, nil
+ return nil
}
-func (servers *Servers) getSecureInternetHome() (*Server, error) {
- server, exists := servers.List[servers.SecureHome]
+func NeedsRelogin(server Server) bool {
+ return server.GetOAuth().NeedsRelogin()
+}
+
+func CancelOAuth(server Server) {
+ server.GetOAuth().Cancel()
+}
+
+func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *FSM, logger *FileLogger) (Server, error) {
+ // Intialize the secure internet server
+ // This calls the init method which takes care of the rest
+ if isSecureInternet {
+ initErr := servers.SecureInternetHomeServer.init(url, fsm, logger)
+
+ if initErr != nil {
+ return nil, &ServerEnsureServerError{Err: initErr}
+ }
- if !exists || server == nil {
- return nil, &ServerGetSecureInternetHomeError{}
+ servers.IsSecureInternet = true
+ return &servers.SecureInternetHomeServer, nil
}
- return server, nil
-}
+ instituteServers := &servers.InstituteServers
-func (servers *Servers) EnsureSecureHome(server *Server) {
- if servers.SecureHome == "" {
- servers.SecureHome = server.BaseURL
+ if instituteServers.Map == nil {
+ instituteServers.Map = make(map[string]*InstituteAccessServer)
}
-}
-func (servers *Servers) CopySecureInternetOAuth(server *Server) error {
- secureHome, secureHomeErr := servers.getSecureInternetHome()
+ institute, exists := instituteServers.Map[url]
- if secureHomeErr != nil {
- return &ServerCopySecureInternetOAuthError{Err: secureHomeErr}
+ // initialize the server if it doesn't exist yet
+ if !exists {
+ institute = &InstituteAccessServer{}
}
- // Forward token properties
- server.OAuth = secureHome.OAuth
- return nil
+ // Set the current server
+ instituteServers.CurrentURL = url
+ instituteInitErr := institute.init(url, fsm, logger)
+ if instituteInitErr != nil {
+ return nil, &ServerEnsureServerError{Err: instituteInitErr}
+ }
+ instituteServers.Map[url] = institute
+ servers.IsSecureInternet = false
+ return institute, nil
}
type ServerProfile struct {
@@ -141,33 +252,22 @@ type ServerEndpoints struct {
V string `json:"v"`
}
-func (server *Server) Login() error {
- return server.OAuth.Login("org.eduvpn.app.linux", server.Endpoints.API.V3.Authorization, server.Endpoints.API.V3.Token)
-}
-
-func (server *Server) NeedsRelogin() bool {
- // Check if OAuth needs relogin
- return server.OAuth.NeedsRelogin()
-}
-
-func (server *Server) GetEndpoints() error {
- url := server.BaseURL + "/.well-known/vpn-user-portal"
+func getEndpoints(baseURL string) (*ServerEndpoints, error) {
+ url := fmt.Sprintf("%s/.well-known/vpn-user-portal", baseURL)
_, body, bodyErr := HTTPGet(url)
if bodyErr != nil {
- return &ServerGetEndpointsError{Err: bodyErr}
+ return nil, &ServerGetEndpointsError{Err: bodyErr}
}
- endpoints := ServerEndpoints{}
- jsonErr := json.Unmarshal(body, &endpoints)
+ endpoints := &ServerEndpoints{}
+ jsonErr := json.Unmarshal(body, endpoints)
if jsonErr != nil {
- return &ServerGetEndpointsError{Err: jsonErr}
+ return nil, &ServerGetEndpointsError{Err: jsonErr}
}
- server.Endpoints = endpoints
-
- return nil
+ return endpoints, nil
}
func (profile *ServerProfile) supportsWireguard() bool {
@@ -179,9 +279,14 @@ func (profile *ServerProfile) supportsWireguard() bool {
return false
}
-func (server *Server) getCurrentProfile() (*ServerProfile, error) {
- profileID := server.Profiles.Current
- for _, profile := range server.Profiles.Info.ProfileList {
+func getCurrentProfile(server Server) (*ServerProfile, error) {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return nil, &ServerGetCurrentProfileError{Err: baseErr}
+ }
+ profileID := base.Profiles.Current
+ for _, profile := range base.Profiles.Info.ProfileList {
if profile.ID == profileID {
return &profile, nil
}
@@ -189,53 +294,68 @@ func (server *Server) getCurrentProfile() (*ServerProfile, error) {
return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}
}
-func (server *Server) getConfigWithProfile() (string, error) {
- if !server.FSM.HasTransition(HAS_CONFIG) {
- return "", &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: HAS_CONFIG}
+func getConfigWithProfile(server Server) (string, error) {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return "", &ServerGetConfigWithProfileError{Err: baseErr}
+ }
+ if !base.FSM.HasTransition(HAS_CONFIG) {
+ return "", &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: HAS_CONFIG}
}
- profile, profileErr := server.getCurrentProfile()
+ profile, profileErr := getCurrentProfile(server)
if profileErr != nil {
return "", &ServerGetConfigWithProfileError{Err: profileErr}
}
if profile.supportsWireguard() {
- return server.WireguardGetConfig()
+ return WireguardGetConfig(server)
}
- return server.OpenVPNGetConfig()
+ return OpenVPNGetConfig(server)
}
-func (server *Server) askForProfileID() error {
- if !server.FSM.HasTransition(ASK_PROFILE) {
- return &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: ASK_PROFILE}
+func askForProfileID(server Server) error {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return &ServerAskForProfileIDError{Err: baseErr}
}
- server.FSM.GoTransitionWithData(ASK_PROFILE, server.ProfilesRaw, false)
+ if !base.FSM.HasTransition(ASK_PROFILE) {
+ return &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: ASK_PROFILE}
+ }
+ base.FSM.GoTransitionWithData(ASK_PROFILE, base.ProfilesRaw, false)
return nil
}
-func (server *Server) GetConfig() (string, error) {
- if !server.FSM.InState(REQUEST_CONFIG) {
- return "", errors.New(fmt.Sprintf("cannot get a config, invalid state %s", server.FSM.Current.String()))
+func GetConfig(server Server) (string, error) {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return "", &ServerGetConfigError{Err: baseErr}
+ }
+ if !base.FSM.InState(REQUEST_CONFIG) {
+ return "", &FSMWrongStateError{Got: base.FSM.Current, Want: REQUEST_CONFIG}
}
- infoErr := server.APIInfo()
+ infoErr := APIInfo(server)
if infoErr != nil {
return "", &ServerGetConfigError{Err: infoErr}
}
// Set the current profile if there is only one profile
- if len(server.Profiles.Info.ProfileList) == 1 {
- server.Profiles.Current = server.Profiles.Info.ProfileList[0].ID
- return server.getConfigWithProfile()
+ if len(base.Profiles.Info.ProfileList) == 1 {
+ base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
+ return getConfigWithProfile(server)
}
- profileErr := server.askForProfileID()
+ profileErr := askForProfileID(server)
if profileErr != nil {
return "", &ServerGetConfigError{Err: profileErr}
}
- return server.getConfigWithProfile()
+ return getConfigWithProfile(server)
}
type ServerGetCurrentProfileNotFoundError struct {
@@ -318,3 +438,49 @@ type ServerInitializeError struct {
func (e *ServerInitializeError) Error() string {
return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err)
}
+
+type ServerInstituteBaseNotFoundError struct {
+ Err error
+}
+
+func (e *ServerInstituteBaseNotFoundError) Error() string {
+ return "institute base not found"
+}
+
+type ServerSecureInternetMapNotFoundError struct{}
+
+func (e *ServerSecureInternetMapNotFoundError) Error() string {
+ return "secure internet map not found"
+}
+
+type ServerSecureInternetBaseNotFoundError struct {
+ Current string
+}
+
+func (e *ServerSecureInternetBaseNotFoundError) Error() string {
+ return fmt.Sprintf("secure internet base not found with current: %s", e.Current)
+}
+
+type ServerGetCurrentProfileError struct {
+ Err error
+}
+
+func (e *ServerGetCurrentProfileError) Error() string {
+ return fmt.Sprintf("failed getting current profile with error: %v", e.Err)
+}
+
+type ServerAskForProfileIDError struct {
+ Err error
+}
+
+func (e *ServerAskForProfileIDError) Error() string {
+ return fmt.Sprintf("ask for profile ID error: %v", e.Err)
+}
+
+type ServerEnsureTokensError struct {
+ Err error
+}
+
+func (e *ServerEnsureTokensError) Error() string {
+ return fmt.Sprintf("failed ensuring tokens with error: %v", e.Err)
+}
diff --git a/internal/wireguard.go b/internal/wireguard.go
index 7977dbc..318e0dc 100644
--- a/internal/wireguard.go
+++ b/internal/wireguard.go
@@ -30,8 +30,14 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string {
return interface_re.ReplaceAllString(config, to_replace)
}
-func (server *Server) WireguardGetConfig() (string, error) {
- profile_id := server.Profiles.Current
+func WireguardGetConfig(server Server) (string, error) {
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return "", &WireguardGetConfigError{Err: baseErr}
+ }
+
+ profile_id := base.Profiles.Current
wireguardKey, wireguardErr := wireguardGenerateKey()
if wireguardErr != nil {
@@ -39,7 +45,7 @@ func (server *Server) WireguardGetConfig() (string, error) {
}
wireguardPublicKey := wireguardKey.PublicKey().String()
- configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey)
+ configWireguard, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey)
if configErr != nil {
return "", &WireguardGetConfigError{Err: wireguardErr}
diff --git a/state.go b/state.go
index 767425c..47f23df 100644
--- a/state.go
+++ b/state.go
@@ -83,36 +83,38 @@ func (state *VPNState) CancelOAuth() error {
if serverErr != nil {
return &StateOAuthCancelError{Err: serverErr}
}
- server.CancelOAuth()
+ internal.CancelOAuth(server)
return nil
}
-func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) {
- if state.FSM.InState(internal.DEREGISTERED) {
- return "", &StateFSMNotRegisteredError{}
- }
+func (state *VPNState) chooseServer(url string, isSecureInternet bool) (internal.Server, error) {
// New server chosen, ensure the server is fresh
- server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true)
+ server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger)
if serverErr != nil {
- return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr}
+ return nil, serverErr
}
- // When we connect to secure internet, copy over the tokens from the home server
- if isSecureInternet {
- // Ensure the secure home server
- state.Servers.EnsureServer(state.Servers.SecureHome, &state.FSM, &state.Logger, false)
+ // Make sure we are in the chosen state if available
+ state.FSM.GoTransition(internal.CHOSEN_SERVER)
+ return server, nil
+}
- // Copy the tokens
- state.Servers.CopySecureInternetOAuth(server)
+func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) {
+ if state.FSM.InState(internal.DEREGISTERED) {
+ return "", &StateFSMNotRegisteredError{}
}
- // Make sure we are in the chosen state if available
- state.FSM.GoTransition(internal.CHOSEN_SERVER)
+ // Make sure the server is chosen
+ server, serverErr := state.chooseServer(url, isSecureInternet)
+
+ if serverErr != nil {
+ return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr}
+ }
// Relogin with oauth
// This moves the state to authorized
- if server.NeedsRelogin() {
- loginErr := server.Login()
+ if internal.NeedsRelogin(server) {
+ loginErr := internal.Login(server)
if loginErr != nil {
// We are possibly in oauth started
@@ -124,12 +126,9 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st
state.FSM.GoTransition(internal.AUTHORIZED)
}
- // Set the home server if it is not set already
- state.Servers.EnsureSecureHome(server)
-
state.FSM.GoTransition(internal.REQUEST_CONFIG)
- config, configErr := server.GetConfig()
+ config, configErr := internal.GetConfig(server)
if configErr != nil {
return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr}
@@ -171,7 +170,13 @@ func (state *VPNState) SetProfileID(profileID string) error {
if serverErr != nil {
return &StateSetProfileError{ProfileID: profileID, Err: serverErr}
}
- server.Profiles.Current = profileID
+
+ base, baseErr := server.GetBase()
+
+ if baseErr != nil {
+ return &StateSetProfileError{ProfileID: profileID, Err: baseErr}
+ }
+ base.Profiles.Current = profileID
return nil
}
diff --git a/state_test.go b/state_test.go
index c6e33e0..c174c72 100644
--- a/state_test.go
+++ b/state_test.go
@@ -161,29 +161,31 @@ func Test_token_expired(t *testing.T) {
_, configErr := state.ConnectInstituteAccess(serverURI)
if configErr != nil {
- t.Errorf("Connect error before expired: %v", configErr)
+ t.Fatalf("Connect error before expired: %v", configErr)
}
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- t.Errorf("No server found")
+ t.Fatalf("No server found")
}
- accessToken := server.OAuth.Token.Access
- refreshToken := server.OAuth.Token.Refresh
+ oauth := server.GetOAuth()
+
+ accessToken := oauth.Token.Access
+ refreshToken := oauth.Token.Refresh
// Wait for TTL so that the tokens expire
time.Sleep(time.Duration(expiredInt) * time.Second)
- infoErr := server.APIInfo()
+ infoErr := internal.APIInfo(server)
if infoErr != nil {
t.Errorf("Info error after expired: %v", infoErr)
}
// Check if tokens have changed
- accessTokenAfter := server.OAuth.Token.Access
- refreshTokenAfter := server.OAuth.Token.Refresh
+ accessTokenAfter := oauth.Token.Access
+ refreshTokenAfter := oauth.Token.Refresh
if accessToken == accessTokenAfter {
t.Errorf("Access token is the same after refresh")
@@ -221,21 +223,23 @@ func Test_token_invalid(t *testing.T) {
return
}
+ oauth := server.GetOAuth()
+
// Override tokens with invalid values
- server.OAuth.Token.Access = dummy_value
- server.OAuth.Token.Refresh = dummy_value
+ oauth.Token.Access = dummy_value
+ oauth.Token.Refresh = dummy_value
- infoErr := server.APIInfo()
+ infoErr := internal.APIInfo(server)
if infoErr != nil {
t.Errorf("Info error after invalid: %v", infoErr)
}
- if server.OAuth.Token.Access == dummy_value {
+ if oauth.Token.Access == dummy_value {
t.Errorf("Access token is equal to dummy value: %s", dummy_value)
}
- if server.OAuth.Token.Refresh == dummy_value {
+ if oauth.Token.Refresh == dummy_value {
t.Errorf("Refresh token is equal to dummy value: %s", dummy_value)
}
}