diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/api.go | 8 | ||||
| -rw-r--r-- | src/fsm.go | 16 | ||||
| -rw-r--r-- | src/openvpn.go | 4 | ||||
| -rw-r--r-- | src/server.go | 55 | ||||
| -rw-r--r-- | src/server_test.go | 18 | ||||
| -rw-r--r-- | src/state.go | 7 | ||||
| -rw-r--r-- | src/wireguard.go | 4 |
7 files changed, 68 insertions, 44 deletions
@@ -62,9 +62,7 @@ func (server *Server) APIInfo() error { } server.Profiles = structure - - // FIXME: Implement profile selection callback - server.Profiles.Current = 0 + server.ProfilesRaw = string(body) return nil } @@ -75,7 +73,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str } urlForm := url.Values{ - "profile_id": {"default"}, + "profile_id": {profile_id}, "public_key": {pubkey}, } header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) @@ -94,7 +92,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro } urlForm := url.Values{ - "profile_id": {"default"}, + "profile_id": {profile_id}, } header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { @@ -23,6 +23,9 @@ const ( // Authenticated means the OAuth process has finished and the user is now authenticated with the server AUTHENTICATED + // Ask profile means the go code is asking for a profile selection from the ui + ASK_PROFILE + // Connected means the user has been connected to the server CONNECTED ) @@ -37,6 +40,8 @@ func (s FSMStateID) String() string { return "Chosen_Server" case OAUTH_STARTED: return "OAuth_Started" + case ASK_PROFILE: + return "Ask_Profile" case AUTHENTICATED: return "Authenticated" case CONNECTED: @@ -88,19 +93,21 @@ func (eduvpn *VPNState) writeGraph() { f.WriteString(graph) } -func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool { +func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) { ok := eduvpn.HasTransition(newState) + received := "" + if ok { oldState := eduvpn.FSM.Current eduvpn.FSM.Current = newState if eduvpn.Debug { eduvpn.writeGraph() } - eduvpn.StateCallback(oldState.String(), newState.String(), data) + received = eduvpn.StateCallback(oldState.String(), newState.String(), data) } - return ok + return ok, received } func (eduvpn *VPNState) generateDotGraph() string { @@ -141,7 +148,8 @@ func (eduvpn *VPNState) InitializeFSM() { NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, CHOSEN_SERVER: {{AUTHENTICATED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}}, OAUTH_STARTED: {{AUTHENTICATED, "User authorizes with browser"}}, - AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}}, + AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}, {ASK_PROFILE, "Connect, multiple profiles detected"}}, + ASK_PROFILE: {{CONNECTED, "OS reports connected"}}, CONNECTED: {{AUTHENTICATED, "OS reports disconnected"}}, }, Current: DEREGISTERED, diff --git a/src/openvpn.go b/src/openvpn.go index 2cab2c4..0cf8e36 100644 --- a/src/openvpn.go +++ b/src/openvpn.go @@ -1,7 +1,7 @@ package eduvpn -func (server *Server) OpenVPNGetConfig() (string, error) { - configOpenVPN, _, configErr := server.APIConnectOpenVPN("default") +func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) { + configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) if configErr != nil { return "", configErr diff --git a/src/server.go b/src/server.go index 8a600d4..20d9136 100644 --- a/src/server.go +++ b/src/server.go @@ -10,6 +10,7 @@ type Server struct { Endpoints *ServerEndpoints `json:"endpoints"` OAuth *OAuth `json:"oauth"` Profiles *ServerProfileInfo `json:"profiles"` + ProfilesRaw string `json:"profiles_raw"` } type ServerProfile struct { @@ -20,7 +21,7 @@ type ServerProfile struct { } type ServerProfileInfo struct { - Current uint8 `json:"current_profile"` + Current string `json:"current_profile"` Info struct { ProfileList []ServerProfile `json:"profile_list"` } `json:"info"` @@ -84,17 +85,6 @@ func (server *Server) GetEndpoints() error { return nil } -func (profiles *ServerProfileInfo) getCurrentProfile() (*ServerProfile, error) { - if profiles.Info.ProfileList == nil { - return nil, errors.New("No server profiles") - } - - if (int)(profiles.Current) >= len(profiles.Info.ProfileList) { - return nil, errors.New("Invalid profile") - } - return &profiles.Info.ProfileList[profiles.Current], nil -} - func (profile *ServerProfile) supportsWireguard() bool { for _, proto := range profile.VPNProtoList { if proto == "wireguard" { @@ -104,12 +94,31 @@ func (profile *ServerProfile) supportsWireguard() bool { return false } -func (server *Server) GetCurrentProfile() (*ServerProfile, error) { - if server.Profiles == nil { - return nil, errors.New("No server profiles found") +func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) { + for _, profile := range server.Profiles.Info.ProfileList { + if profile.ID == profile_id { + return &profile, nil + } } + return nil, errors.New("no profile found for id") +} + +func (server *Server) getConfigWithProfile(profile_id string) (string, error) { + profile, profileErr := server.getProfileForID(profile_id) - return server.Profiles.getCurrentProfile() + if profileErr != nil { + return "", profileErr + } + + if profile.supportsWireguard() { + return server.WireguardGetConfig(profile_id) + } + return server.OpenVPNGetConfig(profile_id) +} + +func (server *Server) askForProfileID() (string, error) { + _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) + return profile_id, nil } func (server *Server) GetConfig() (string, error) { @@ -119,14 +128,16 @@ func (server *Server) GetConfig() (string, error) { return "", infoErr } - profile, profileErr := server.GetCurrentProfile() + // Set the current profile if there is only one profile + if len(server.Profiles.Info.ProfileList) == 1 { + return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID) + } + + profile_id, profileErr := server.askForProfileID() if profileErr != nil { - return "", profileErr + return "", nil } - if profile.supportsWireguard() { - return server.WireguardGetConfig() - } - return server.OpenVPNGetConfig() + return server.getConfigWithProfile(profile_id) } diff --git a/src/server_test.go b/src/server_test.go index f914583..5b6ec5a 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -35,10 +35,13 @@ func LoginOAuthSelenium(t *testing.T, url string) { } } -func StateCallback(t *testing.T, oldState string, newState string, data string) { +func StateCallback(t *testing.T, oldState string, newState string, data string) string { if newState == "OAuth_Started" { go LoginOAuthSelenium(t, data) } + + // We have no data to send back + return "" } func Test_server(t *testing.T) { @@ -47,8 +50,8 @@ func Test_server(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) { - StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string { + return StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -65,7 +68,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) { + state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string { if newState == "OAuth_Started" { baseURL := "http://127.0.0.1:8000/callback" url, err := HTTPConstructURL(baseURL, parameters) @@ -73,7 +76,10 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect t.Errorf("Error: Constructing url %s with parameters %s", baseURL, fmt.Sprint(parameters)) } go http.Get(url) + } + // We have no data to send back + return "" }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -123,8 +129,8 @@ func Test_token_expired(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) { - StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string { + return StateCallback(t, old, new, data) }, false) diff --git a/src/state.go b/src/state.go index 1eb8ebd..95bb75c 100644 --- a/src/state.go +++ b/src/state.go @@ -8,7 +8,8 @@ type VPNState struct { // Info passed by the client ConfigDirectory string `json:"-"` Name string `json:"-"` - StateCallback func(string, string, string) `json:"-"` + StateCallback func(string, string, string) string `json:"-"` + StateCallbackData string `json:"-"` // The chosen server Server *Server `json:"server"` @@ -26,7 +27,7 @@ type VPNState struct { Debug bool `json:"-"` } -func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { +func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error { state.InitializeFSM() if !state.InState(DEREGISTERED) { return errors.New("app already registered") @@ -71,7 +72,7 @@ func (state *VPNState) Deregister() error { } func (state *VPNState) Connect(url string) (string, error) { - if state.Server == nil { + if state.Server == nil || state.Server.BaseURL != url { state.Server = &Server{} } initializeErr := state.Server.Initialize(url) diff --git a/src/wireguard.go b/src/wireguard.go index b701b1d..3f7da40 100644 --- a/src/wireguard.go +++ b/src/wireguard.go @@ -26,7 +26,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func (server *Server) WireguardGetConfig() (string, error) { +func (server *Server) WireguardGetConfig(profile_id string) (string, error) { wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { @@ -34,7 +34,7 @@ func (server *Server) WireguardGetConfig() (string, error) { } wireguardPublicKey := wireguardKey.PublicKey().String() - configWireguard, _, configErr := server.APIConnectWireguard("default", wireguardPublicKey) + configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey) if configErr != nil { return "", configErr |
