diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/fsm.go | 8 | ||||
| -rw-r--r-- | src/openvpn.go | 3 | ||||
| -rw-r--r-- | src/server.go | 26 | ||||
| -rw-r--r-- | src/server_test.go | 17 | ||||
| -rw-r--r-- | src/state.go | 10 | ||||
| -rw-r--r-- | src/wireguard.go | 3 |
6 files changed, 32 insertions, 35 deletions
@@ -119,21 +119,19 @@ func (eduvpn *VPNState) writeGraph() { f.WriteString(graph) } -func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) { +func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool { ok := eduvpn.HasTransition(newState) - received := "" - if ok { oldState := eduvpn.FSM.Current eduvpn.FSM.Current = newState if eduvpn.Debug { eduvpn.writeGraph() } - received = eduvpn.StateCallback(oldState.String(), newState.String(), data) + eduvpn.StateCallback(oldState.String(), newState.String(), data) } - return ok, received + return ok } func (eduvpn *VPNState) generateDotGraph() string { diff --git a/src/openvpn.go b/src/openvpn.go index 0cf8e36..95e1328 100644 --- a/src/openvpn.go +++ b/src/openvpn.go @@ -1,6 +1,7 @@ package eduvpn -func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) { +func (server *Server) OpenVPNGetConfig() (string, error) { + profile_id := server.Profiles.Current configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) if configErr != nil { diff --git a/src/server.go b/src/server.go index d8afb01..b7d55cb 100644 --- a/src/server.go +++ b/src/server.go @@ -94,7 +94,8 @@ func (profile *ServerProfile) supportsWireguard() bool { return false } -func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) { +func (server *Server) getCurrentProfile() (*ServerProfile, error) { + profile_id := server.Profiles.Current for _, profile := range server.Profiles.Info.ProfileList { if profile.ID == profile_id { return &profile, nil @@ -103,28 +104,28 @@ func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) return nil, errors.New("no profile found for id") } -func (server *Server) getConfigWithProfile(profile_id string) (string, error) { +func (server *Server) getConfigWithProfile() (string, error) { if !GetVPNState().HasTransition(HAS_CONFIG) { return "", errors.New("cannot get a config with a profile, invalid state") } - profile, profileErr := server.getProfileForID(profile_id) + profile, profileErr := server.getCurrentProfile() if profileErr != nil { return "", profileErr } if profile.supportsWireguard() { - return server.WireguardGetConfig(profile_id) + return server.WireguardGetConfig() } - return server.OpenVPNGetConfig(profile_id) + return server.OpenVPNGetConfig() } -func (server *Server) askForProfileID() (string, error) { +func (server *Server) askForProfileID() error { if !GetVPNState().HasTransition(ASK_PROFILE) { - return "", errors.New("cannot ask for a profile id, invalid state") + return errors.New("cannot ask for a profile id, invalid state") } - _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) - return profile_id, nil + GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) + return nil } func (server *Server) GetConfig() (string, error) { @@ -139,14 +140,15 @@ func (server *Server) GetConfig() (string, error) { // Set the current profile if there is only one profile if len(server.Profiles.Info.ProfileList) == 1 { - return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID) + server.Profiles.Current = server.Profiles.Info.ProfileList[0].ID + return server.getConfigWithProfile() } - profile_id, profileErr := server.askForProfileID() + profileErr := server.askForProfileID() if profileErr != nil { return "", nil } - return server.getConfigWithProfile(profile_id) + return server.getConfigWithProfile() } diff --git a/src/server_test.go b/src/server_test.go index 3045983..6199884 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -35,13 +35,10 @@ func LoginOAuthSelenium(t *testing.T, url string) { } } -func StateCallback(t *testing.T, oldState string, newState string, data string) string { +func StateCallback(t *testing.T, oldState string, newState string, data string) { if newState == "OAuth_Started" { go LoginOAuthSelenium(t, data) } - - // We have no data to send back - return "" } func Test_server(t *testing.T) { @@ -50,8 +47,8 @@ func Test_server(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string { - return StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) { + StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -68,7 +65,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string { + state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) { if newState == "OAuth_Started" { baseURL := "http://127.0.0.1:8000/callback" url, err := HTTPConstructURL(baseURL, parameters) @@ -78,8 +75,6 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect go http.Get(url) } - // We have no data to send back - return "" }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -129,8 +124,8 @@ func Test_token_expired(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string { - return StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) { + StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") diff --git a/src/state.go b/src/state.go index b9bb920..6c740c1 100644 --- a/src/state.go +++ b/src/state.go @@ -6,10 +6,10 @@ import ( type VPNState struct { // Info passed by the client - ConfigDirectory string `json:"-"` - Name string `json:"-"` - StateCallback func(string, string, string) string `json:"-"` - StateCallbackData string `json:"-"` + ConfigDirectory string `json:"-"` + Name string `json:"-"` + StateCallback func(string, string, string) `json:"-"` + StateCallbackData string `json:"-"` // The chosen server Server *Server `json:"server"` @@ -27,7 +27,7 @@ type VPNState struct { Debug bool `json:"-"` } -func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error { +func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { state.InitializeFSM() if !state.InState(DEREGISTERED) { return errors.New("app already registered") diff --git a/src/wireguard.go b/src/wireguard.go index 3f7da40..2f1c41c 100644 --- a/src/wireguard.go +++ b/src/wireguard.go @@ -26,7 +26,8 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func (server *Server) WireguardGetConfig(profile_id string) (string, error) { +func (server *Server) WireguardGetConfig() (string, error) { + profile_id := server.Profiles.Current wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { |
