summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-26 16:31:45 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-26 16:31:45 +0200
commit39f0e8e26ab37c4b83c1933ba90bae15cd7e04fc (patch)
treea1d496c73fa5d2793105151835cb7e06b82c17b8
parent5608d9a858c2323002305ea1fedb5793a40edc58 (diff)
State: Add a state map to exports instead of a global singleton
-rw-r--r--cmd/cli/main.go2
-rw-r--r--exports/exports.go79
-rw-r--r--state.go9
-rw-r--r--wrappers/python/eduvpncommon/__init__.py12
-rw-r--r--wrappers/python/eduvpncommon/main.py42
5 files changed, 97 insertions, 47 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go
index 1406175..c31716f 100644
--- a/cmd/cli/main.go
+++ b/cmd/cli/main.go
@@ -34,7 +34,7 @@ func main() {
urlString = "https://" + urlString
}
- state := eduvpn.GetVPNState()
+ state := &eduvpn.VPNState{}
state.Register("org.eduvpn.app.linux", "configs", logState, true)
config, configErr := state.Connect(urlString)
diff --git a/exports/exports.go b/exports/exports.go
index d81f1ce..f21a354 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -13,11 +13,16 @@ void call_callback(PythonCB callback, const char* oldstate, const char* newstate
}
*/
import "C"
+import "errors"
+import "fmt"
import "unsafe"
import "github.com/jwijenbergh/eduvpn-common"
var P_StateCallback C.PythonCB
+
+var VPNStates map[string]*eduvpn.VPNState
+
func StateCallback(old_state string, new_state string, data string) {
if P_StateCallback == nil {
return
@@ -31,18 +36,44 @@ func StateCallback(old_state string, new_state string, data string) {
C.free(unsafe.Pointer(data_c))
}
+
+func GetVPNState(name string) (*eduvpn.VPNState, error) {
+ state, exists := VPNStates[name]
+
+ if !exists || state == nil {
+ return nil, errors.New(fmt.Sprintf("State with name %s not found", name))
+ }
+
+ return state, nil
+}
+
//export Register
func Register(name *C.char, config_directory *C.char, stateCallback C.PythonCB, debug C.int) *C.char {
+ state := &eduvpn.VPNState{}
+ nameStr := C.GoString(name)
+
+ if VPNStates == nil {
+ VPNStates = make(map[string]*eduvpn.VPNState)
+ }
+ VPNStates[nameStr] = state
P_StateCallback = stateCallback
- state := eduvpn.GetVPNState()
- registerErr := state.Register(C.GoString(name), C.GoString(config_directory), StateCallback, debug != 0)
+ registerErr := state.Register(nameStr, C.GoString(config_directory), StateCallback, debug != 0)
+
+ if registerErr != nil {
+ delete(VPNStates, nameStr)
+ }
return C.CString(ErrorToString(registerErr))
}
//export Deregister
-func Deregister() {
- state := eduvpn.GetVPNState()
+func Deregister(name *C.char) *C.char {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return C.CString(ErrorToString(stateErr))
+ }
state.Deregister()
+ return nil
}
func ErrorToString(error error) string {
@@ -54,38 +85,58 @@ func ErrorToString(error error) string {
}
//export CancelOAuth
-func CancelOAuth() (*C.char) {
- state := eduvpn.GetVPNState()
+func CancelOAuth(name *C.char) *C.char {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return C.CString(ErrorToString(stateErr))
+ }
cancelErr := state.CancelOAuth()
cancelErrString := ErrorToString(cancelErr)
return C.CString(cancelErrString)
}
//export Connect
-func Connect(url *C.char) (*C.char, *C.char) {
- state := eduvpn.GetVPNState()
+func Connect(name *C.char, url *C.char) (*C.char, *C.char) {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return nil, C.CString(ErrorToString(stateErr))
+ }
config, configErr := state.Connect(C.GoString(url))
return C.CString(config), C.CString(ErrorToString(configErr))
}
//export GetOrganizationsList
-func GetOrganizationsList() (*C.char, *C.char) {
- state := eduvpn.GetVPNState()
+func GetOrganizationsList(name *C.char) (*C.char, *C.char) {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return nil, C.CString(ErrorToString(stateErr))
+ }
organizations, organizationsErr := state.GetDiscoOrganizations()
return C.CString(organizations), C.CString(ErrorToString(organizationsErr))
}
//export GetServersList
-func GetServersList() (*C.char, *C.char) {
- state := eduvpn.GetVPNState()
+func GetServersList(name *C.char) (*C.char, *C.char) {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return nil, C.CString(ErrorToString(stateErr))
+ }
servers, serversErr := state.GetDiscoServers()
return C.CString(servers), C.CString(ErrorToString(serversErr))
}
//export SetProfileID
-func SetProfileID(data *C.char) *C.char {
- state := eduvpn.GetVPNState()
+func SetProfileID(name *C.char, data *C.char) *C.char {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return C.CString(ErrorToString(stateErr))
+ }
profileErr := state.SetProfileID(C.GoString(data))
return C.CString(ErrorToString(profileErr))
}
diff --git a/state.go b/state.go
index 1f9f52c..3ca0a4b 100644
--- a/state.go
+++ b/state.go
@@ -25,15 +25,6 @@ type VPNState struct {
Debug bool `json:"-"`
}
-var VPNStateInstance *VPNState
-
-func GetVPNState() *VPNState {
- if VPNStateInstance == nil {
- VPNStateInstance = &VPNState{}
- }
- return VPNStateInstance
-}
-
func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
if !state.FSM.InState(internal.DEREGISTERED) {
return errors.New("app already registered")
diff --git a/wrappers/python/eduvpncommon/__init__.py b/wrappers/python/eduvpncommon/__init__.py
index e0fe0d0..1df305b 100644
--- a/wrappers/python/eduvpncommon/__init__.py
+++ b/wrappers/python/eduvpncommon/__init__.py
@@ -33,13 +33,13 @@ class DataError(Structure):
VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p)
# Exposed functions
-lib.Connect.argtypes, lib.Connect.restype = [c_char_p], DataError
-lib.Deregister.argtypes, lib.Deregister.restype = [], None
+lib.Connect.argtypes, lib.Connect.restype = [c_char_p, c_char_p], DataError
+lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p
lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p
-lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [], DataError
-lib.GetServersList.argtypes, lib.GetServersList.restype = [], DataError
-lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [], c_void_p
-lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], c_void_p
+lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [c_char_p], DataError
+lib.GetServersList.argtypes, lib.GetServersList.restype = [c_char_p], DataError
+lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
+lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
diff --git a/wrappers/python/eduvpncommon/main.py b/wrappers/python/eduvpncommon/main.py
index eae7014..5474ade 100644
--- a/wrappers/python/eduvpncommon/main.py
+++ b/wrappers/python/eduvpncommon/main.py
@@ -18,23 +18,29 @@ def Register(name, config_directory, state_callback, debug):
err_string = GetPtrString(ptr_err)
return err_string
-def CancelOAuth():
- ptr_err = lib.CancelOAuth()
+def CancelOAuth(name):
+ name_bytes = name.encode("utf-8")
+ ptr_err = lib.CancelOAuth(name_bytes)
err_string = GetPtrString(ptr_err)
return err_string
-def Deregister():
- lib.Deregister()
+def Deregister(name):
+ name_bytes = name.encode("utf-8")
+ ptr_err = lib.Deregister(name_bytes)
+ err_string = GetPtrString(ptr_err)
+ return err_string
-def GetDiscoServers():
- servers, serversErr = GetDataError(lib.GetServersList())
- organizations, organizationsErr = GetDataError(lib.GetOrganizationsList())
+def GetDiscoServers(name):
+ name_bytes = name.encode("utf-8")
+ servers, serversErr = GetDataError(lib.GetServersList(name_bytes))
+ organizations, organizationsErr = GetDataError(lib.GetOrganizationsList(name_bytes))
return servers, serversErr, organizations, organizationsErr
-def Connect(url):
+def Connect(name, url):
+ name_bytes = name.encode("utf-8")
url_bytes = url.encode("utf-8")
- data_error = lib.Connect(url_bytes)
+ data_error = lib.Connect(name_bytes, url_bytes)
return GetDataError(data_error)
@@ -51,8 +57,10 @@ def register_callback(eduvpn):
)
-def SetProfileID(profile_id) -> str:
- error_string = lib.SetProfileID(profile_id.encode("utf-8"))
+def SetProfileID(name, profile_id) -> str:
+ name_bytes = name.encode("utf-8")
+ profile_bytes = profile_id.encode("utf-8")
+ error_string = lib.SetProfileID(name_bytes, profile_bytes)
return GetPtrString(error_string)
@@ -64,19 +72,19 @@ class EduVPN(object):
register_callback(self)
def cancel_oauth(self) -> str:
- return CancelOAuth()
+ return CancelOAuth(self.name)
- def deregister(self):
- Deregister()
+ def deregister(self) -> str:
+ return Deregister(self.name)
def register(self, debug=False) -> bool:
return Register(self.name, self.config_directory, callback_function, debug) == ""
def get_disco(self):
- return GetDiscoServers()
+ return GetDiscoServers(self.name)
def connect(self, url):
- return Connect(url)
+ return Connect(self.name, url)
@property
def event(self):
@@ -86,7 +94,7 @@ class EduVPN(object):
self.event.run(old_state, new_state, data)
def set_profile(self, profile_id) -> str:
- return SetProfileID(profile_id)
+ return SetProfileID(self.name, profile_id)
class EventHandler(object):