diff options
| -rw-r--r-- | exports/exports.go | 6 | ||||
| -rw-r--r-- | src/oauth.go | 16 | ||||
| -rw-r--r-- | src/server.go | 37 | ||||
| -rw-r--r-- | src/server_test.go | 30 | ||||
| -rw-r--r-- | src/state.go | 21 |
5 files changed, 78 insertions, 32 deletions
diff --git a/exports/exports.go b/exports/exports.go index c9ef41b..84cd571 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -86,7 +86,11 @@ func SetProfileID(data *C.char) *C.char { // Set current profile to id profile_id := C.GoString(data) - state.Server.Profiles.Current = profile_id + server, serverErr := state.Servers.GetCurrentServer() + if serverErr != nil { + return C.CString("No server found for setting a profile ID") + } + server.Profiles.Current = profile_id return C.CString("") } diff --git a/src/oauth.go b/src/oauth.go index cb13d19..263690e 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -220,7 +220,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // It returns the authurl for the browser and an error if present func (eduvpn *VPNState) InitializeOAuth() error { if !eduvpn.HasTransition(OAUTH_STARTED) { - return errors.New("Failed starting oauth, invalid state") + return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", eduvpn.FSM.Current.String())) } // Generate the state state, stateErr := genState() @@ -245,7 +245,11 @@ func (eduvpn *VPNState) InitializeOAuth() error { "redirect_uri": "http://127.0.0.1:8000/callback", } - authURL, urlErr := HTTPConstructURL(eduvpn.Server.Endpoints.API.V3.Authorization, parameters) + server, serverErr := eduvpn.Servers.GetCurrentServer() + if serverErr != nil { + return errors.New("OAuth Initialize no server found") + } + authURL, urlErr := HTTPConstructURL(server.Endpoints.API.V3.Authorization, parameters) if urlErr != nil { // shouldn't happen panic(urlErr) @@ -253,7 +257,7 @@ func (eduvpn *VPNState) InitializeOAuth() error { // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier} - eduvpn.Server.OAuth = OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession} + server.OAuth = OAuth{TokenURL: server.Endpoints.API.V3.Token, Session: oauthSession} eduvpn.GoTransitionWithData(OAUTH_STARTED, authURL) return nil } @@ -263,7 +267,11 @@ func (eduvpn *VPNState) FinishOAuth() error { if !eduvpn.HasTransition(AUTHENTICATED) { return errors.New("invalid state to finish oauth") } - tokenErr := eduvpn.Server.OAuth.getTokensWithCallback() + server, serverErr := eduvpn.Servers.GetCurrentServer() + if serverErr != nil { + return errors.New("OAuth Initialize No server found") + } + tokenErr := server.OAuth.getTokensWithCallback() if tokenErr != nil { return tokenErr } diff --git a/src/server.go b/src/server.go index f4aab66..3dca26b 100644 --- a/src/server.go +++ b/src/server.go @@ -13,6 +13,39 @@ type Server struct { ProfilesRaw string `json:"profiles_raw"` } +type Servers struct { + List map[string]*Server `json:"list"` + Current string `json:"current"` +} + +func (servers *Servers) GetCurrentServer() (*Server, error) { + if servers.List == nil { + return nil, errors.New("No map found to get Current Server") + } + server, exists := servers.List[servers.Current] + + if !exists || server == nil { + return nil, errors.New("Current Server not found") + } + return server, nil +} + +func (servers *Servers) EnsureServer(url string) *Server { + if servers.List == nil { + servers.List = make(map[string]*Server) + } + + server, exists := servers.List[url] + + if !exists || server == nil { + server = &Server{} + server.Initialize(url) + servers.List[url] = server + } + servers.Current = url + return server +} + type ServerProfile struct { ID string `json:"profile_id"` DisplayName string `json:"display_name"` @@ -43,15 +76,11 @@ type ServerEndpoints struct { } func (server *Server) Initialize(url string) error { - if !GetVPNState().HasTransition(CHOSEN_SERVER) { - return errors.New("cannot choose a server") - } server.BaseURL = url endpointsErr := server.GetEndpoints() if endpointsErr != nil { return endpointsErr } - GetVPNState().GoTransition(CHOSEN_SERVER) return nil } diff --git a/src/server_test.go b/src/server_test.go index 7081bde..ccf58f6 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -134,21 +134,26 @@ func Test_token_expired(t *testing.T) { t.Errorf("Connect error before expired: %v", configErr) } - accessToken := state.Server.OAuth.Token.Access - refreshToken := state.Server.OAuth.Token.Refresh + server, serverErr := state.Servers.GetCurrentServer() + if serverErr != nil { + t.Errorf("No server found") + } + + accessToken := server.OAuth.Token.Access + refreshToken := server.OAuth.Token.Refresh // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - infoErr := state.Server.APIInfo() + infoErr := server.APIInfo() if infoErr != nil { t.Errorf("Info error after expired: %v", infoErr) } // Check if tokens have changed - accessTokenAfter := state.Server.OAuth.Token.Access - refreshTokenAfter := state.Server.OAuth.Token.Refresh + accessTokenAfter := server.OAuth.Token.Access + refreshTokenAfter := server.OAuth.Token.Refresh if accessToken == accessTokenAfter { t.Errorf("Access token is the same after refresh") @@ -184,21 +189,26 @@ func Test_token_invalid(t *testing.T) { dummy_value := "37" + server, serverErr := state.Servers.GetCurrentServer() + if serverErr != nil { + t.Errorf("No server found") + } + // Override tokens with invalid values - state.Server.OAuth.Token.Access = dummy_value - state.Server.OAuth.Token.Refresh = dummy_value + server.OAuth.Token.Access = dummy_value + server.OAuth.Token.Refresh = dummy_value - infoErr := state.Server.APIInfo() + infoErr := server.APIInfo() if infoErr != nil { t.Errorf("Info error after invalid: %v", infoErr) } - if state.Server.OAuth.Token.Access == dummy_value { + if server.OAuth.Token.Access == dummy_value { t.Errorf("Access token is equal to dummy value: %s", dummy_value) } - if state.Server.OAuth.Token.Refresh == dummy_value { + if server.OAuth.Token.Refresh == dummy_value { t.Errorf("Refresh token is equal to dummy value: %s", dummy_value) } } diff --git a/src/state.go b/src/state.go index e15cb3c..c0e512f 100644 --- a/src/state.go +++ b/src/state.go @@ -12,7 +12,7 @@ type VPNState struct { StateCallbackData string `json:"-"` // The chosen server - Server Server `json:"server"` + Servers Servers `json:"servers"` // The list of servers and organizations from disco DiscoList DiscoLists `json:"-"` @@ -62,25 +62,20 @@ func (state *VPNState) Deregister() error { // Write the config state.WriteConfig() - // Re-initialize the server and FSM - state.Server = Server{} + // Re-initialize the servers and FSM + state.Servers = Servers{} state.InitializeFSM() return nil } func (state *VPNState) Connect(url string) (string, error) { // New server chosen, ensure the server is fresh - if state.Server.BaseURL != url { - state.Server = Server{} - } - initializeErr := state.Server.Initialize(url) - - if initializeErr != nil { - return "", initializeErr - } + server := state.Servers.EnsureServer(url) + // Make sure we are in the chosen state if available + state.GoTransition(CHOSEN_SERVER) // Relogin with oauth // This moves the state to authenticated - if state.Server.NeedsRelogin() { + if server.NeedsRelogin() { loginErr := state.LoginOAuth() if loginErr != nil { @@ -92,7 +87,7 @@ func (state *VPNState) Connect(url string) (string, error) { state.GoTransition(REQUEST_CONFIG) - config, configErr := state.Server.GetConfig() + config, configErr := server.GetConfig() if configErr != nil { return "", configErr |
