diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-03 14:10:40 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-03 14:10:40 +0200 |
| commit | 43604f7ffcbbf5b06ae481d2af7e66f6423f183f (patch) | |
| tree | 6eb74ed54929edcfac61e5ca55078ab6670e0081 | |
| parent | 466450f0c47bdc614e66326d90e5fc6fb56ae732 (diff) | |
Refactor: Secure internet into a different type but with interface
| -rw-r--r-- | internal/api.go | 54 | ||||
| -rw-r--r-- | internal/oauth.go | 29 | ||||
| -rw-r--r-- | internal/openvpn.go | 11 | ||||
| -rw-r--r-- | internal/server.go | 366 | ||||
| -rw-r--r-- | internal/wireguard.go | 12 | ||||
| -rw-r--r-- | state.go | 49 | ||||
| -rw-r--r-- | state_test.go | 28 |
7 files changed, 379 insertions, 170 deletions
diff --git a/internal/api.go b/internal/api.go index da17f76..a987f00 100644 --- a/internal/api.go +++ b/internal/api.go @@ -10,22 +10,28 @@ import ( // Authorized wrappers on top of HTTP // the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth -func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { +func apiAuthorized(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &HTTPOptionalParams{} } - url := server.Endpoints.API.V3.API + endpoint + base, baseErr := server.GetBase() + + if baseErr != nil { + return nil, nil, baseErr + } + + url := base.Endpoints.API.V3.API + endpoint // Ensure we have valid tokens - oauthErr := server.EnsureTokens() + oauthErr := EnsureTokens(server) if oauthErr != nil { return nil, nil, oauthErr } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", server.OAuth.Token.Access) + headerValue := fmt.Sprintf("Bearer %s", server.GetOAuth().Token.Access) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -34,17 +40,22 @@ func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOp return HTTPMethodWithOpts(method, url, opts) } -func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { - header, body, bodyErr := server.apiAuthorized(method, endpoint, opts) +func apiAuthorizedRetry(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { + header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) + base, baseErr := server.GetBase() + + if baseErr != nil { + return nil, nil, &APIAuthorizedError{Err: baseErr} + } if bodyErr != nil { var error *HTTPStatusError // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { - server.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error)) + base.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error)) // Tell the method that the token is expired - server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds() - retryHeader, retryBody, retryErr := server.apiAuthorized(method, endpoint, opts) + server.GetOAuth().Token.ExpiredTimestamp = GenerateTimeSeconds() + retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { return nil, nil, &APIAuthorizedError{Err: retryErr} } @@ -55,8 +66,8 @@ func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *H return header, body, nil } -func (server *Server) APIInfo() error { - _, body, bodyErr := server.apiAuthorizedRetry(http.MethodGet, "/info", nil) +func APIInfo(server Server) error { + _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) if bodyErr != nil { return &APIInfoError{Err: bodyErr} } @@ -67,12 +78,17 @@ func (server *Server) APIInfo() error { return &APIInfoError{Err: jsonErr} } - server.Profiles = structure - server.ProfilesRaw = string(body) + base, baseErr := server.GetBase() + + if baseErr != nil { + return &APIInfoError{Err: baseErr} + } + base.Profiles = structure + base.ProfilesRaw = string(body) return nil } -func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (string, string, error) { +func APIConnectWireguard(server Server, profile_id string, pubkey string) (string, string, error) { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, @@ -82,7 +98,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str "profile_id": {profile_id}, "public_key": {pubkey}, } - header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", &APIConnectWireguardError{Err: connectErr} } @@ -91,7 +107,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str return string(connectBody), expires, nil } -func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, error) { +func APIConnectOpenVPN(server Server, profile_id string) (string, string, error) { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, @@ -100,7 +116,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro urlForm := url.Values{ "profile_id": {profile_id}, } - header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", &APIConnectOpenVPNError{Err: connectErr} } @@ -110,8 +126,8 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro } // This needs no further return value as it's best effort -func (server *Server) APIDisconnect() { - server.apiAuthorizedRetry(http.MethodPost, "/disconnect", nil) +func APIDisconnect(server Server) { + apiAuthorizedRetry(server, http.MethodPost, "/disconnect", nil) } type APIAuthorizedError struct { diff --git a/internal/oauth.go b/internal/oauth.go index c13ea99..c566425 100644 --- a/internal/oauth.go +++ b/internal/oauth.go @@ -52,11 +52,12 @@ func genVerifier() (string, error) { } type OAuth struct { - Session OAuthExchangeSession `json:"-"` - Token OAuthToken `json:"token"` - TokenURL string `json:"token_url"` - Logger *FileLogger `json:"-"` - FSM *FSM `json:"-"` + Session OAuthExchangeSession `json:"-"` + Token OAuthToken `json:"token"` + BaseAuthorizationURL string `json:"base_authorization_url"` + TokenURL string `json:"token_url"` + Logger *FileLogger `json:"-"` + FSM *FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -216,13 +217,20 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { go oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) { +func (oauth *OAuth) Update(fsm *FSM, logger *FileLogger) { + oauth.FSM = fsm + oauth.Logger = logger +} + +func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *FSM, logger *FileLogger) { + oauth.BaseAuthorizationURL = baseAuthorizationURL + oauth.TokenURL = tokenURL oauth.FSM = fsm oauth.Logger = logger } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error { +func (oauth *OAuth) start(name string) error { if !oauth.FSM.HasTransition(OAUTH_STARTED) { return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED} } @@ -249,7 +257,7 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) "redirect_uri": "http://127.0.0.1:8000/callback", } - authURL, urlErr := HTTPConstructURL(authorizationURL, parameters) + authURL, urlErr := HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { return &OAuthInitializeError{Err: urlErr} @@ -257,7 +265,6 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} - oauth.TokenURL = tokenURL oauth.Session = oauthSession // Run the state callback in the background so that the user can login while we start the callback server oauth.FSM.GoTransitionWithData(OAUTH_STARTED, authURL, true) @@ -283,8 +290,8 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string) error { - authInitializeErr := oauth.start(name, authorizationURL, tokenURL) +func (oauth *OAuth) Login(name string) error { + authInitializeErr := oauth.start(name) if authInitializeErr != nil { return &OAuthLoginError{Err: authInitializeErr} diff --git a/internal/openvpn.go b/internal/openvpn.go index ed31fe2..45beb51 100644 --- a/internal/openvpn.go +++ b/internal/openvpn.go @@ -2,9 +2,14 @@ package internal import "fmt" -func (server *Server) OpenVPNGetConfig() (string, error) { - profile_id := server.Profiles.Current - configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) +func OpenVPNGetConfig(server Server) (string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", &OpenVPNGetConfigError{Err: baseErr} + } + profile_id := base.Profiles.Current + configOpenVPN, _, configErr := APIConnectOpenVPN(server, profile_id) if configErr != nil { return "", &OpenVPNGetConfigError{Err: configErr} diff --git a/internal/server.go b/internal/server.go index 489719e..7500a26 100644 --- a/internal/server.go +++ b/internal/server.go @@ -2,114 +2,225 @@ package internal import ( "encoding/json" - "errors" "fmt" ) -type Server struct { - BaseURL string `json:"base_url"` +// The base type for servers +type ServerBase struct { + URL string `json:"base_url"` Endpoints ServerEndpoints `json:"endpoints"` - OAuth OAuth `json:"oauth"` Profiles ServerProfileInfo `json:"profiles"` ProfilesRaw string `json:"profiles_raw"` Logger *FileLogger `json:"-"` FSM *FSM `json:"-"` } -type Servers struct { - List map[string]*Server `json:"list"` - Current string `json:"current"` - SecureHome string `json:"secure_home"` +// An instute access server +type InstituteAccessServer struct { + // An instute access server has its own OAuth + OAuth OAuth `json:"oauth"` + + // Embed the server base + Base ServerBase `json:"base"` +} + +// A secure internet server which has its own OAuth tokens +// It specifies the current location url it is connected to +type SecureInternetHomeServer struct { + OAuth OAuth `json:"oauth"` + + // The home server has a list of info for each configured server + BaseMap map[string]*ServerBase `json:"base_map"` + + // We have the home url and the current url + HomeURL string `json:"home_url"` + CurrentURL string `json:"current_url"` } -func (servers *Servers) GetCurrentServer() (*Server, error) { - if servers.List == nil { +type InstituteServers struct { + Map map[string]*InstituteAccessServer `json:"map"` + CurrentURL string `json:"current_url"` +} + +func (servers *Servers) GetCurrentServer() (Server, error) { + if servers.IsSecureInternet { + return &servers.SecureInternetHomeServer, nil + } + currentInstitute := servers.InstituteServers.CurrentURL + institutes := servers.InstituteServers.Map + if institutes == nil { return nil, &ServerGetCurrentNoMapError{} } - server, exists := servers.List[servers.Current] + institute, exists := institutes[currentInstitute] - if !exists || server == nil { + if !exists || institute == nil { return nil, &ServerGetCurrentNotFoundError{} } - return server, nil + return institute, nil } -func (server *Server) CancelOAuth() { - server.OAuth.Cancel() +type Servers struct { + InstituteServers InstituteServers `json:"institute_servers"` + SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"` + IsSecureInternet bool `json:"is_secure_internet"` } -func (server *Server) Init(url string, fsm *FSM, logger *FileLogger) error { - server.BaseURL = url - server.FSM = fsm - server.Logger = logger - server.OAuth.Init(fsm, logger) - endpointsErr := server.GetEndpoints() - if endpointsErr != nil { - return &ServerInitializeError{URL: url, Err: endpointsErr} +type Server interface { + // Gets the current OAuth object + GetOAuth() *OAuth + + // Gets the server base + GetBase() (*ServerBase, error) + + // initialize method + init(url string, fsm *FSM, logger *FileLogger) error +} + +// For an institute, we can simply get the OAuth +func (institute *InstituteAccessServer) GetOAuth() *OAuth { + return &institute.OAuth +} + +func (secure *SecureInternetHomeServer) GetOAuth() *OAuth { + return &secure.OAuth +} + +func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { + return &institute.Base, nil +} + +func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { + if server.BaseMap == nil { + return nil, &ServerSecureInternetMapNotFoundError{} } - return nil + + base, exists := server.BaseMap[server.CurrentURL] + + if !exists { + return nil, &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL} + } + return base, nil } -func (server *Server) EnsureTokens() error { - if server.OAuth.NeedsRelogin() { - server.Logger.Log(LOG_INFO, "OAuth: Tokens are invalid, relogging in") - return server.Login() +func (institute *InstituteAccessServer) init(url string, fsm *FSM, logger *FileLogger) error { + institute.Base.URL = url + institute.Base.FSM = fsm + institute.Base.Logger = logger + endpoints, endpointsErr := getEndpoints(url) + if endpointsErr != nil { + return &ServerInitializeError{URL: url, Err: endpointsErr} } + institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm, logger) + institute.Base.Endpoints = *endpoints return nil } -func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) { - if url == "" { - return nil, &ServerEnsureServerEmptyURLError{} +func (secure *SecureInternetHomeServer) init(url string, fsm *FSM, logger *FileLogger) error { + // Initialize the base map if it is non-nil + if secure.BaseMap == nil { + secure.BaseMap = make(map[string]*ServerBase) } - if servers.List == nil { - servers.List = make(map[string]*Server) + + // Add it if not present + base, exists := secure.BaseMap[url] + + if !exists || base == nil { + // Create the base to be added to the map + base = &ServerBase{} + base.URL = url + endpoints, endpointsErr := getEndpoints(url) + if endpointsErr != nil { + return &ServerInitializeError{URL: url, Err: endpointsErr} + } + base.Endpoints = *endpoints } - server, exists := servers.List[url] + // Pass the fsm and logger + base.FSM = fsm + base.Logger = logger + + // Ensure it is in the map + secure.BaseMap[url] = base - if !exists || server == nil { - server = &Server{} + // Set the home url if it is not set yet + if secure.HomeURL == "" { + secure.HomeURL = url + // Make sure oauth contains our endpoints + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger) + } else { // Else just pass in the fsm and logger + secure.OAuth.Update(fsm, logger) } - serverInitErr := server.Init(url, fsm, logger) - if serverInitErr != nil { - return nil, &ServerEnsureServerError{Err: serverInitErr} + // Set the current url + secure.CurrentURL = url + return nil +} + +func Login(server Server) error { + return server.GetOAuth().Login("org.eduvpn.app.linux") +} + +func EnsureTokens(server Server) error { + base, baseErr := server.GetBase() + + if baseErr != nil { + return &ServerEnsureTokensError{Err: baseErr} } - servers.List[url] = server + if server.GetOAuth().NeedsRelogin() { + base.Logger.Log(LOG_INFO, "OAuth: Tokens are invalid, relogging in") + loginErr := Login(server) - if makeCurrent { - servers.Current = url + if loginErr != nil { + return &ServerEnsureTokensError{Err: loginErr} + } } - return server, nil + return nil } -func (servers *Servers) getSecureInternetHome() (*Server, error) { - server, exists := servers.List[servers.SecureHome] +func NeedsRelogin(server Server) bool { + return server.GetOAuth().NeedsRelogin() +} + +func CancelOAuth(server Server) { + server.GetOAuth().Cancel() +} + +func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *FSM, logger *FileLogger) (Server, error) { + // Intialize the secure internet server + // This calls the init method which takes care of the rest + if isSecureInternet { + initErr := servers.SecureInternetHomeServer.init(url, fsm, logger) + + if initErr != nil { + return nil, &ServerEnsureServerError{Err: initErr} + } - if !exists || server == nil { - return nil, &ServerGetSecureInternetHomeError{} + servers.IsSecureInternet = true + return &servers.SecureInternetHomeServer, nil } - return server, nil -} + instituteServers := &servers.InstituteServers -func (servers *Servers) EnsureSecureHome(server *Server) { - if servers.SecureHome == "" { - servers.SecureHome = server.BaseURL + if instituteServers.Map == nil { + instituteServers.Map = make(map[string]*InstituteAccessServer) } -} -func (servers *Servers) CopySecureInternetOAuth(server *Server) error { - secureHome, secureHomeErr := servers.getSecureInternetHome() + institute, exists := instituteServers.Map[url] - if secureHomeErr != nil { - return &ServerCopySecureInternetOAuthError{Err: secureHomeErr} + // initialize the server if it doesn't exist yet + if !exists { + institute = &InstituteAccessServer{} } - // Forward token properties - server.OAuth = secureHome.OAuth - return nil + // Set the current server + instituteServers.CurrentURL = url + instituteInitErr := institute.init(url, fsm, logger) + if instituteInitErr != nil { + return nil, &ServerEnsureServerError{Err: instituteInitErr} + } + instituteServers.Map[url] = institute + servers.IsSecureInternet = false + return institute, nil } type ServerProfile struct { @@ -141,33 +252,22 @@ type ServerEndpoints struct { V string `json:"v"` } -func (server *Server) Login() error { - return server.OAuth.Login("org.eduvpn.app.linux", server.Endpoints.API.V3.Authorization, server.Endpoints.API.V3.Token) -} - -func (server *Server) NeedsRelogin() bool { - // Check if OAuth needs relogin - return server.OAuth.NeedsRelogin() -} - -func (server *Server) GetEndpoints() error { - url := server.BaseURL + "/.well-known/vpn-user-portal" +func getEndpoints(baseURL string) (*ServerEndpoints, error) { + url := fmt.Sprintf("%s/.well-known/vpn-user-portal", baseURL) _, body, bodyErr := HTTPGet(url) if bodyErr != nil { - return &ServerGetEndpointsError{Err: bodyErr} + return nil, &ServerGetEndpointsError{Err: bodyErr} } - endpoints := ServerEndpoints{} - jsonErr := json.Unmarshal(body, &endpoints) + endpoints := &ServerEndpoints{} + jsonErr := json.Unmarshal(body, endpoints) if jsonErr != nil { - return &ServerGetEndpointsError{Err: jsonErr} + return nil, &ServerGetEndpointsError{Err: jsonErr} } - server.Endpoints = endpoints - - return nil + return endpoints, nil } func (profile *ServerProfile) supportsWireguard() bool { @@ -179,9 +279,14 @@ func (profile *ServerProfile) supportsWireguard() bool { return false } -func (server *Server) getCurrentProfile() (*ServerProfile, error) { - profileID := server.Profiles.Current - for _, profile := range server.Profiles.Info.ProfileList { +func getCurrentProfile(server Server) (*ServerProfile, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return nil, &ServerGetCurrentProfileError{Err: baseErr} + } + profileID := base.Profiles.Current + for _, profile := range base.Profiles.Info.ProfileList { if profile.ID == profileID { return &profile, nil } @@ -189,53 +294,68 @@ func (server *Server) getCurrentProfile() (*ServerProfile, error) { return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} } -func (server *Server) getConfigWithProfile() (string, error) { - if !server.FSM.HasTransition(HAS_CONFIG) { - return "", &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: HAS_CONFIG} +func getConfigWithProfile(server Server) (string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", &ServerGetConfigWithProfileError{Err: baseErr} + } + if !base.FSM.HasTransition(HAS_CONFIG) { + return "", &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: HAS_CONFIG} } - profile, profileErr := server.getCurrentProfile() + profile, profileErr := getCurrentProfile(server) if profileErr != nil { return "", &ServerGetConfigWithProfileError{Err: profileErr} } if profile.supportsWireguard() { - return server.WireguardGetConfig() + return WireguardGetConfig(server) } - return server.OpenVPNGetConfig() + return OpenVPNGetConfig(server) } -func (server *Server) askForProfileID() error { - if !server.FSM.HasTransition(ASK_PROFILE) { - return &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: ASK_PROFILE} +func askForProfileID(server Server) error { + base, baseErr := server.GetBase() + + if baseErr != nil { + return &ServerAskForProfileIDError{Err: baseErr} } - server.FSM.GoTransitionWithData(ASK_PROFILE, server.ProfilesRaw, false) + if !base.FSM.HasTransition(ASK_PROFILE) { + return &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: ASK_PROFILE} + } + base.FSM.GoTransitionWithData(ASK_PROFILE, base.ProfilesRaw, false) return nil } -func (server *Server) GetConfig() (string, error) { - if !server.FSM.InState(REQUEST_CONFIG) { - return "", errors.New(fmt.Sprintf("cannot get a config, invalid state %s", server.FSM.Current.String())) +func GetConfig(server Server) (string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", &ServerGetConfigError{Err: baseErr} + } + if !base.FSM.InState(REQUEST_CONFIG) { + return "", &FSMWrongStateError{Got: base.FSM.Current, Want: REQUEST_CONFIG} } - infoErr := server.APIInfo() + infoErr := APIInfo(server) if infoErr != nil { return "", &ServerGetConfigError{Err: infoErr} } // Set the current profile if there is only one profile - if len(server.Profiles.Info.ProfileList) == 1 { - server.Profiles.Current = server.Profiles.Info.ProfileList[0].ID - return server.getConfigWithProfile() + if len(base.Profiles.Info.ProfileList) == 1 { + base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + return getConfigWithProfile(server) } - profileErr := server.askForProfileID() + profileErr := askForProfileID(server) if profileErr != nil { return "", &ServerGetConfigError{Err: profileErr} } - return server.getConfigWithProfile() + return getConfigWithProfile(server) } type ServerGetCurrentProfileNotFoundError struct { @@ -318,3 +438,49 @@ type ServerInitializeError struct { func (e *ServerInitializeError) Error() string { return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err) } + +type ServerInstituteBaseNotFoundError struct { + Err error +} + +func (e *ServerInstituteBaseNotFoundError) Error() string { + return "institute base not found" +} + +type ServerSecureInternetMapNotFoundError struct{} + +func (e *ServerSecureInternetMapNotFoundError) Error() string { + return "secure internet map not found" +} + +type ServerSecureInternetBaseNotFoundError struct { + Current string +} + +func (e *ServerSecureInternetBaseNotFoundError) Error() string { + return fmt.Sprintf("secure internet base not found with current: %s", e.Current) +} + +type ServerGetCurrentProfileError struct { + Err error +} + +func (e *ServerGetCurrentProfileError) Error() string { + return fmt.Sprintf("failed getting current profile with error: %v", e.Err) +} + +type ServerAskForProfileIDError struct { + Err error +} + +func (e *ServerAskForProfileIDError) Error() string { + return fmt.Sprintf("ask for profile ID error: %v", e.Err) +} + +type ServerEnsureTokensError struct { + Err error +} + +func (e *ServerEnsureTokensError) Error() string { + return fmt.Sprintf("failed ensuring tokens with error: %v", e.Err) +} diff --git a/internal/wireguard.go b/internal/wireguard.go index 7977dbc..318e0dc 100644 --- a/internal/wireguard.go +++ b/internal/wireguard.go @@ -30,8 +30,14 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func (server *Server) WireguardGetConfig() (string, error) { - profile_id := server.Profiles.Current +func WireguardGetConfig(server Server) (string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", &WireguardGetConfigError{Err: baseErr} + } + + profile_id := base.Profiles.Current wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { @@ -39,7 +45,7 @@ func (server *Server) WireguardGetConfig() (string, error) { } wireguardPublicKey := wireguardKey.PublicKey().String() - configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey) + configWireguard, _, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey) if configErr != nil { return "", &WireguardGetConfigError{Err: wireguardErr} @@ -83,36 +83,38 @@ func (state *VPNState) CancelOAuth() error { if serverErr != nil { return &StateOAuthCancelError{Err: serverErr} } - server.CancelOAuth() + internal.CancelOAuth(server) return nil } -func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) { - if state.FSM.InState(internal.DEREGISTERED) { - return "", &StateFSMNotRegisteredError{} - } +func (state *VPNState) chooseServer(url string, isSecureInternet bool) (internal.Server, error) { // New server chosen, ensure the server is fresh - server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true) + server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger) if serverErr != nil { - return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr} + return nil, serverErr } - // When we connect to secure internet, copy over the tokens from the home server - if isSecureInternet { - // Ensure the secure home server - state.Servers.EnsureServer(state.Servers.SecureHome, &state.FSM, &state.Logger, false) + // Make sure we are in the chosen state if available + state.FSM.GoTransition(internal.CHOSEN_SERVER) + return server, nil +} - // Copy the tokens - state.Servers.CopySecureInternetOAuth(server) +func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) { + if state.FSM.InState(internal.DEREGISTERED) { + return "", &StateFSMNotRegisteredError{} } - // Make sure we are in the chosen state if available - state.FSM.GoTransition(internal.CHOSEN_SERVER) + // Make sure the server is chosen + server, serverErr := state.chooseServer(url, isSecureInternet) + + if serverErr != nil { + return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr} + } // Relogin with oauth // This moves the state to authorized - if server.NeedsRelogin() { - loginErr := server.Login() + if internal.NeedsRelogin(server) { + loginErr := internal.Login(server) if loginErr != nil { // We are possibly in oauth started @@ -124,12 +126,9 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st state.FSM.GoTransition(internal.AUTHORIZED) } - // Set the home server if it is not set already - state.Servers.EnsureSecureHome(server) - state.FSM.GoTransition(internal.REQUEST_CONFIG) - config, configErr := server.GetConfig() + config, configErr := internal.GetConfig(server) if configErr != nil { return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr} @@ -171,7 +170,13 @@ func (state *VPNState) SetProfileID(profileID string) error { if serverErr != nil { return &StateSetProfileError{ProfileID: profileID, Err: serverErr} } - server.Profiles.Current = profileID + + base, baseErr := server.GetBase() + + if baseErr != nil { + return &StateSetProfileError{ProfileID: profileID, Err: baseErr} + } + base.Profiles.Current = profileID return nil } diff --git a/state_test.go b/state_test.go index c6e33e0..c174c72 100644 --- a/state_test.go +++ b/state_test.go @@ -161,29 +161,31 @@ func Test_token_expired(t *testing.T) { _, configErr := state.ConnectInstituteAccess(serverURI) if configErr != nil { - t.Errorf("Connect error before expired: %v", configErr) + t.Fatalf("Connect error before expired: %v", configErr) } server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - t.Errorf("No server found") + t.Fatalf("No server found") } - accessToken := server.OAuth.Token.Access - refreshToken := server.OAuth.Token.Refresh + oauth := server.GetOAuth() + + accessToken := oauth.Token.Access + refreshToken := oauth.Token.Refresh // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - infoErr := server.APIInfo() + infoErr := internal.APIInfo(server) if infoErr != nil { t.Errorf("Info error after expired: %v", infoErr) } // Check if tokens have changed - accessTokenAfter := server.OAuth.Token.Access - refreshTokenAfter := server.OAuth.Token.Refresh + accessTokenAfter := oauth.Token.Access + refreshTokenAfter := oauth.Token.Refresh if accessToken == accessTokenAfter { t.Errorf("Access token is the same after refresh") @@ -221,21 +223,23 @@ func Test_token_invalid(t *testing.T) { return } + oauth := server.GetOAuth() + // Override tokens with invalid values - server.OAuth.Token.Access = dummy_value - server.OAuth.Token.Refresh = dummy_value + oauth.Token.Access = dummy_value + oauth.Token.Refresh = dummy_value - infoErr := server.APIInfo() + infoErr := internal.APIInfo(server) if infoErr != nil { t.Errorf("Info error after invalid: %v", infoErr) } - if server.OAuth.Token.Access == dummy_value { + if oauth.Token.Access == dummy_value { t.Errorf("Access token is equal to dummy value: %s", dummy_value) } - if server.OAuth.Token.Refresh == dummy_value { + if oauth.Token.Refresh == dummy_value { t.Errorf("Refresh token is equal to dummy value: %s", dummy_value) } } |
