diff options
| -rw-r--r-- | exports/exports.go | 15 | ||||
| -rw-r--r-- | src/api.go | 8 | ||||
| -rw-r--r-- | src/fsm.go | 16 | ||||
| -rw-r--r-- | src/openvpn.go | 4 | ||||
| -rw-r--r-- | src/server.go | 55 | ||||
| -rw-r--r-- | src/server_test.go | 18 | ||||
| -rw-r--r-- | src/state.go | 7 | ||||
| -rw-r--r-- | src/wireguard.go | 4 | ||||
| -rw-r--r-- | wrappers/python/eduvpncommon/__init__.py | 1 | ||||
| -rw-r--r-- | wrappers/python/eduvpncommon/main.py | 7 | ||||
| -rw-r--r-- | wrappers/python/main.py | 6 |
11 files changed, 95 insertions, 46 deletions
diff --git a/exports/exports.go b/exports/exports.go index e34721e..7d2b1ce 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -18,9 +18,9 @@ import "github.com/jwijenbergh/eduvpn-common/src" var P_StateCallback C.PythonCB -func StateCallback(old_state string, new_state string, data string) { +func StateCallback(old_state string, new_state string, data string) string { if P_StateCallback == nil { - return + return "" } oldState_c := C.CString(old_state) newState_c := C.CString(new_state) @@ -29,6 +29,12 @@ func StateCallback(old_state string, new_state string, data string) { C.free(unsafe.Pointer(oldState_c)) C.free(unsafe.Pointer(newState_c)) C.free(unsafe.Pointer(data_c)) + + // Get state data and reset + state := eduvpn.GetVPNState() + received_data := state.StateCallbackData + state.StateCallbackData = "" + return received_data } //export Register @@ -75,6 +81,11 @@ func GetServersList() (*C.char, *C.char) { return C.CString(servers), C.CString(ErrorToString(serversErr)) } +//export SendData +func SendData(data *C.char) { + eduvpn.GetVPNState().StateCallbackData = C.GoString(data) +} + //export FreeString func FreeString(addr *C.char) { C.free(unsafe.Pointer(addr)) @@ -62,9 +62,7 @@ func (server *Server) APIInfo() error { } server.Profiles = structure - - // FIXME: Implement profile selection callback - server.Profiles.Current = 0 + server.ProfilesRaw = string(body) return nil } @@ -75,7 +73,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str } urlForm := url.Values{ - "profile_id": {"default"}, + "profile_id": {profile_id}, "public_key": {pubkey}, } header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) @@ -94,7 +92,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro } urlForm := url.Values{ - "profile_id": {"default"}, + "profile_id": {profile_id}, } header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { @@ -23,6 +23,9 @@ const ( // Authenticated means the OAuth process has finished and the user is now authenticated with the server AUTHENTICATED + // Ask profile means the go code is asking for a profile selection from the ui + ASK_PROFILE + // Connected means the user has been connected to the server CONNECTED ) @@ -37,6 +40,8 @@ func (s FSMStateID) String() string { return "Chosen_Server" case OAUTH_STARTED: return "OAuth_Started" + case ASK_PROFILE: + return "Ask_Profile" case AUTHENTICATED: return "Authenticated" case CONNECTED: @@ -88,19 +93,21 @@ func (eduvpn *VPNState) writeGraph() { f.WriteString(graph) } -func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool { +func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) { ok := eduvpn.HasTransition(newState) + received := "" + if ok { oldState := eduvpn.FSM.Current eduvpn.FSM.Current = newState if eduvpn.Debug { eduvpn.writeGraph() } - eduvpn.StateCallback(oldState.String(), newState.String(), data) + received = eduvpn.StateCallback(oldState.String(), newState.String(), data) } - return ok + return ok, received } func (eduvpn *VPNState) generateDotGraph() string { @@ -141,7 +148,8 @@ func (eduvpn *VPNState) InitializeFSM() { NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, CHOSEN_SERVER: {{AUTHENTICATED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}}, OAUTH_STARTED: {{AUTHENTICATED, "User authorizes with browser"}}, - AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}}, + AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}, {ASK_PROFILE, "Connect, multiple profiles detected"}}, + ASK_PROFILE: {{CONNECTED, "OS reports connected"}}, CONNECTED: {{AUTHENTICATED, "OS reports disconnected"}}, }, Current: DEREGISTERED, diff --git a/src/openvpn.go b/src/openvpn.go index 2cab2c4..0cf8e36 100644 --- a/src/openvpn.go +++ b/src/openvpn.go @@ -1,7 +1,7 @@ package eduvpn -func (server *Server) OpenVPNGetConfig() (string, error) { - configOpenVPN, _, configErr := server.APIConnectOpenVPN("default") +func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) { + configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) if configErr != nil { return "", configErr diff --git a/src/server.go b/src/server.go index 8a600d4..20d9136 100644 --- a/src/server.go +++ b/src/server.go @@ -10,6 +10,7 @@ type Server struct { Endpoints *ServerEndpoints `json:"endpoints"` OAuth *OAuth `json:"oauth"` Profiles *ServerProfileInfo `json:"profiles"` + ProfilesRaw string `json:"profiles_raw"` } type ServerProfile struct { @@ -20,7 +21,7 @@ type ServerProfile struct { } type ServerProfileInfo struct { - Current uint8 `json:"current_profile"` + Current string `json:"current_profile"` Info struct { ProfileList []ServerProfile `json:"profile_list"` } `json:"info"` @@ -84,17 +85,6 @@ func (server *Server) GetEndpoints() error { return nil } -func (profiles *ServerProfileInfo) getCurrentProfile() (*ServerProfile, error) { - if profiles.Info.ProfileList == nil { - return nil, errors.New("No server profiles") - } - - if (int)(profiles.Current) >= len(profiles.Info.ProfileList) { - return nil, errors.New("Invalid profile") - } - return &profiles.Info.ProfileList[profiles.Current], nil -} - func (profile *ServerProfile) supportsWireguard() bool { for _, proto := range profile.VPNProtoList { if proto == "wireguard" { @@ -104,12 +94,31 @@ func (profile *ServerProfile) supportsWireguard() bool { return false } -func (server *Server) GetCurrentProfile() (*ServerProfile, error) { - if server.Profiles == nil { - return nil, errors.New("No server profiles found") +func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) { + for _, profile := range server.Profiles.Info.ProfileList { + if profile.ID == profile_id { + return &profile, nil + } } + return nil, errors.New("no profile found for id") +} + +func (server *Server) getConfigWithProfile(profile_id string) (string, error) { + profile, profileErr := server.getProfileForID(profile_id) - return server.Profiles.getCurrentProfile() + if profileErr != nil { + return "", profileErr + } + + if profile.supportsWireguard() { + return server.WireguardGetConfig(profile_id) + } + return server.OpenVPNGetConfig(profile_id) +} + +func (server *Server) askForProfileID() (string, error) { + _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) + return profile_id, nil } func (server *Server) GetConfig() (string, error) { @@ -119,14 +128,16 @@ func (server *Server) GetConfig() (string, error) { return "", infoErr } - profile, profileErr := server.GetCurrentProfile() + // Set the current profile if there is only one profile + if len(server.Profiles.Info.ProfileList) == 1 { + return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID) + } + + profile_id, profileErr := server.askForProfileID() if profileErr != nil { - return "", profileErr + return "", nil } - if profile.supportsWireguard() { - return server.WireguardGetConfig() - } - return server.OpenVPNGetConfig() + return server.getConfigWithProfile(profile_id) } diff --git a/src/server_test.go b/src/server_test.go index f914583..5b6ec5a 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -35,10 +35,13 @@ func LoginOAuthSelenium(t *testing.T, url string) { } } -func StateCallback(t *testing.T, oldState string, newState string, data string) { +func StateCallback(t *testing.T, oldState string, newState string, data string) string { if newState == "OAuth_Started" { go LoginOAuthSelenium(t, data) } + + // We have no data to send back + return "" } func Test_server(t *testing.T) { @@ -47,8 +50,8 @@ func Test_server(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) { - StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string { + return StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -65,7 +68,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) { + state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string { if newState == "OAuth_Started" { baseURL := "http://127.0.0.1:8000/callback" url, err := HTTPConstructURL(baseURL, parameters) @@ -73,7 +76,10 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect t.Errorf("Error: Constructing url %s with parameters %s", baseURL, fmt.Sprint(parameters)) } go http.Get(url) + } + // We have no data to send back + return "" }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -123,8 +129,8 @@ func Test_token_expired(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) { - StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string { + return StateCallback(t, old, new, data) }, false) diff --git a/src/state.go b/src/state.go index 1eb8ebd..95bb75c 100644 --- a/src/state.go +++ b/src/state.go @@ -8,7 +8,8 @@ type VPNState struct { // Info passed by the client ConfigDirectory string `json:"-"` Name string `json:"-"` - StateCallback func(string, string, string) `json:"-"` + StateCallback func(string, string, string) string `json:"-"` + StateCallbackData string `json:"-"` // The chosen server Server *Server `json:"server"` @@ -26,7 +27,7 @@ type VPNState struct { Debug bool `json:"-"` } -func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { +func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error { state.InitializeFSM() if !state.InState(DEREGISTERED) { return errors.New("app already registered") @@ -71,7 +72,7 @@ func (state *VPNState) Deregister() error { } func (state *VPNState) Connect(url string) (string, error) { - if state.Server == nil { + if state.Server == nil || state.Server.BaseURL != url { state.Server = &Server{} } initializeErr := state.Server.Initialize(url) diff --git a/src/wireguard.go b/src/wireguard.go index b701b1d..3f7da40 100644 --- a/src/wireguard.go +++ b/src/wireguard.go @@ -26,7 +26,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func (server *Server) WireguardGetConfig() (string, error) { +func (server *Server) WireguardGetConfig(profile_id string) (string, error) { wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { @@ -34,7 +34,7 @@ func (server *Server) WireguardGetConfig() (string, error) { } wireguardPublicKey := wireguardKey.PublicKey().String() - configWireguard, _, configErr := server.APIConnectWireguard("default", wireguardPublicKey) + configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey) if configErr != nil { return "", configErr diff --git a/wrappers/python/eduvpncommon/__init__.py b/wrappers/python/eduvpncommon/__init__.py index 056ce18..50af8eb 100644 --- a/wrappers/python/eduvpncommon/__init__.py +++ b/wrappers/python/eduvpncommon/__init__.py @@ -38,6 +38,7 @@ lib.Deregister.argtypes, lib.Deregister.restype = [], None lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [], DataError lib.GetServersList.argtypes, lib.GetServersList.restype = [], DataError +lib.SendData.argtypes, lib.SendData.restype = [c_char_p], None # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None diff --git a/wrappers/python/eduvpncommon/main.py b/wrappers/python/eduvpncommon/main.py index 53b9831..8d2de9c 100644 --- a/wrappers/python/eduvpncommon/main.py +++ b/wrappers/python/eduvpncommon/main.py @@ -47,6 +47,10 @@ def register_callback(eduvpn): ) +def SendData(data): + lib.SendData(data.encode("utf-8")) + + class EduVPN(object): def __init__(self, name, config_directory): self.event_handler = EventHandler() @@ -75,6 +79,9 @@ class EduVPN(object): def callback(self, old_state, new_state, data): self.event.run(old_state, new_state, data) + def send_data(self, data): + return SendData(data) + class EventHandler(object): def __init__(self): diff --git a/wrappers/python/main.py b/wrappers/python/main.py index d6a568e..bc29239 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -11,6 +11,12 @@ def oauth_initialized(url): webbrowser.open(url) +@_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter) +def ask_profile(profiles): + print("ASK PROFILE CB", profiles) + _eduvpn.send_data("prefer-openvpn") + + success = _eduvpn.register(debug=True) if not success: |
