summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/oauth.go16
-rw-r--r--src/server.go37
-rw-r--r--src/server_test.go30
-rw-r--r--src/state.go21
4 files changed, 73 insertions, 31 deletions
diff --git a/src/oauth.go b/src/oauth.go
index cb13d19..263690e 100644
--- a/src/oauth.go
+++ b/src/oauth.go
@@ -220,7 +220,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// It returns the authurl for the browser and an error if present
func (eduvpn *VPNState) InitializeOAuth() error {
if !eduvpn.HasTransition(OAUTH_STARTED) {
- return errors.New("Failed starting oauth, invalid state")
+ return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", eduvpn.FSM.Current.String()))
}
// Generate the state
state, stateErr := genState()
@@ -245,7 +245,11 @@ func (eduvpn *VPNState) InitializeOAuth() error {
"redirect_uri": "http://127.0.0.1:8000/callback",
}
- authURL, urlErr := HTTPConstructURL(eduvpn.Server.Endpoints.API.V3.Authorization, parameters)
+ server, serverErr := eduvpn.Servers.GetCurrentServer()
+ if serverErr != nil {
+ return errors.New("OAuth Initialize no server found")
+ }
+ authURL, urlErr := HTTPConstructURL(server.Endpoints.API.V3.Authorization, parameters)
if urlErr != nil { // shouldn't happen
panic(urlErr)
@@ -253,7 +257,7 @@ func (eduvpn *VPNState) InitializeOAuth() error {
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier}
- eduvpn.Server.OAuth = OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession}
+ server.OAuth = OAuth{TokenURL: server.Endpoints.API.V3.Token, Session: oauthSession}
eduvpn.GoTransitionWithData(OAUTH_STARTED, authURL)
return nil
}
@@ -263,7 +267,11 @@ func (eduvpn *VPNState) FinishOAuth() error {
if !eduvpn.HasTransition(AUTHENTICATED) {
return errors.New("invalid state to finish oauth")
}
- tokenErr := eduvpn.Server.OAuth.getTokensWithCallback()
+ server, serverErr := eduvpn.Servers.GetCurrentServer()
+ if serverErr != nil {
+ return errors.New("OAuth Initialize No server found")
+ }
+ tokenErr := server.OAuth.getTokensWithCallback()
if tokenErr != nil {
return tokenErr
}
diff --git a/src/server.go b/src/server.go
index f4aab66..3dca26b 100644
--- a/src/server.go
+++ b/src/server.go
@@ -13,6 +13,39 @@ type Server struct {
ProfilesRaw string `json:"profiles_raw"`
}
+type Servers struct {
+ List map[string]*Server `json:"list"`
+ Current string `json:"current"`
+}
+
+func (servers *Servers) GetCurrentServer() (*Server, error) {
+ if servers.List == nil {
+ return nil, errors.New("No map found to get Current Server")
+ }
+ server, exists := servers.List[servers.Current]
+
+ if !exists || server == nil {
+ return nil, errors.New("Current Server not found")
+ }
+ return server, nil
+}
+
+func (servers *Servers) EnsureServer(url string) *Server {
+ if servers.List == nil {
+ servers.List = make(map[string]*Server)
+ }
+
+ server, exists := servers.List[url]
+
+ if !exists || server == nil {
+ server = &Server{}
+ server.Initialize(url)
+ servers.List[url] = server
+ }
+ servers.Current = url
+ return server
+}
+
type ServerProfile struct {
ID string `json:"profile_id"`
DisplayName string `json:"display_name"`
@@ -43,15 +76,11 @@ type ServerEndpoints struct {
}
func (server *Server) Initialize(url string) error {
- if !GetVPNState().HasTransition(CHOSEN_SERVER) {
- return errors.New("cannot choose a server")
- }
server.BaseURL = url
endpointsErr := server.GetEndpoints()
if endpointsErr != nil {
return endpointsErr
}
- GetVPNState().GoTransition(CHOSEN_SERVER)
return nil
}
diff --git a/src/server_test.go b/src/server_test.go
index 7081bde..ccf58f6 100644
--- a/src/server_test.go
+++ b/src/server_test.go
@@ -134,21 +134,26 @@ func Test_token_expired(t *testing.T) {
t.Errorf("Connect error before expired: %v", configErr)
}
- accessToken := state.Server.OAuth.Token.Access
- refreshToken := state.Server.OAuth.Token.Refresh
+ server, serverErr := state.Servers.GetCurrentServer()
+ if serverErr != nil {
+ t.Errorf("No server found")
+ }
+
+ accessToken := server.OAuth.Token.Access
+ refreshToken := server.OAuth.Token.Refresh
// Wait for TTL so that the tokens expire
time.Sleep(time.Duration(expiredInt) * time.Second)
- infoErr := state.Server.APIInfo()
+ infoErr := server.APIInfo()
if infoErr != nil {
t.Errorf("Info error after expired: %v", infoErr)
}
// Check if tokens have changed
- accessTokenAfter := state.Server.OAuth.Token.Access
- refreshTokenAfter := state.Server.OAuth.Token.Refresh
+ accessTokenAfter := server.OAuth.Token.Access
+ refreshTokenAfter := server.OAuth.Token.Refresh
if accessToken == accessTokenAfter {
t.Errorf("Access token is the same after refresh")
@@ -184,21 +189,26 @@ func Test_token_invalid(t *testing.T) {
dummy_value := "37"
+ server, serverErr := state.Servers.GetCurrentServer()
+ if serverErr != nil {
+ t.Errorf("No server found")
+ }
+
// Override tokens with invalid values
- state.Server.OAuth.Token.Access = dummy_value
- state.Server.OAuth.Token.Refresh = dummy_value
+ server.OAuth.Token.Access = dummy_value
+ server.OAuth.Token.Refresh = dummy_value
- infoErr := state.Server.APIInfo()
+ infoErr := server.APIInfo()
if infoErr != nil {
t.Errorf("Info error after invalid: %v", infoErr)
}
- if state.Server.OAuth.Token.Access == dummy_value {
+ if server.OAuth.Token.Access == dummy_value {
t.Errorf("Access token is equal to dummy value: %s", dummy_value)
}
- if state.Server.OAuth.Token.Refresh == dummy_value {
+ if server.OAuth.Token.Refresh == dummy_value {
t.Errorf("Refresh token is equal to dummy value: %s", dummy_value)
}
}
diff --git a/src/state.go b/src/state.go
index e15cb3c..c0e512f 100644
--- a/src/state.go
+++ b/src/state.go
@@ -12,7 +12,7 @@ type VPNState struct {
StateCallbackData string `json:"-"`
// The chosen server
- Server Server `json:"server"`
+ Servers Servers `json:"servers"`
// The list of servers and organizations from disco
DiscoList DiscoLists `json:"-"`
@@ -62,25 +62,20 @@ func (state *VPNState) Deregister() error {
// Write the config
state.WriteConfig()
- // Re-initialize the server and FSM
- state.Server = Server{}
+ // Re-initialize the servers and FSM
+ state.Servers = Servers{}
state.InitializeFSM()
return nil
}
func (state *VPNState) Connect(url string) (string, error) {
// New server chosen, ensure the server is fresh
- if state.Server.BaseURL != url {
- state.Server = Server{}
- }
- initializeErr := state.Server.Initialize(url)
-
- if initializeErr != nil {
- return "", initializeErr
- }
+ server := state.Servers.EnsureServer(url)
+ // Make sure we are in the chosen state if available
+ state.GoTransition(CHOSEN_SERVER)
// Relogin with oauth
// This moves the state to authenticated
- if state.Server.NeedsRelogin() {
+ if server.NeedsRelogin() {
loginErr := state.LoginOAuth()
if loginErr != nil {
@@ -92,7 +87,7 @@ func (state *VPNState) Connect(url string) (string, error) {
state.GoTransition(REQUEST_CONFIG)
- config, configErr := state.Server.GetConfig()
+ config, configErr := server.GetConfig()
if configErr != nil {
return "", configErr