diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-26 14:50:22 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-26 15:33:04 +0200 |
| commit | 7e4494256a08f585523e01b1bbc51f41ff4e2b95 (patch) | |
| tree | ccbf873b2bfb11aa22f185e78ce1e2e5eebd094c | |
| parent | 448c51d2142c186f0490b9d51c0d73beb3c76863 (diff) | |
Refactor: Errors into custom export types and expose types
| -rw-r--r-- | exports/disco.go | 15 | ||||
| -rw-r--r-- | exports/error.h | 17 | ||||
| -rw-r--r-- | exports/exports.go | 153 | ||||
| -rw-r--r-- | exports/servers.go | 5 | ||||
| -rw-r--r-- | fsm.go | 2 | ||||
| -rw-r--r-- | internal/config/config.go | 2 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 2 | ||||
| -rw-r--r-- | internal/http/http.go | 2 | ||||
| -rw-r--r-- | internal/log/log.go | 2 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 2 | ||||
| -rw-r--r-- | internal/server/api.go | 2 | ||||
| -rw-r--r-- | internal/server/common.go | 2 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 2 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 2 | ||||
| -rw-r--r-- | internal/util/util.go | 2 | ||||
| -rw-r--r-- | internal/verify/verify.go | 2 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 2 | ||||
| -rw-r--r-- | state.go | 32 | ||||
| -rw-r--r-- | state_test.go | 2 | ||||
| -rw-r--r-- | types/error.go (renamed from internal/types/error.go) | 34 | ||||
| -rw-r--r-- | types/server.go (renamed from internal/types/server.go) | 0 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 78 | ||||
| -rw-r--r-- | wrappers/python/src/error.py | 15 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 41 |
24 files changed, 192 insertions, 226 deletions
diff --git a/exports/disco.go b/exports/disco.go index 73bc1ac..ac7ac7d 100644 --- a/exports/disco.go +++ b/exports/disco.go @@ -3,6 +3,7 @@ package main /* // for free and size_t #include <stdlib.h> +#include "error.h" typedef struct discoveryServer { const char* authentication_url_template; @@ -42,7 +43,7 @@ import ( "unsafe" eduvpn "github.com/eduvpn/eduvpn-common" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) func getCPtrDiscoOrganization( @@ -168,15 +169,15 @@ func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) { } //export GetDiscoServers -func GetDiscoServers(name *C.char) (*C.discoveryServers, *C.char) { +func GetDiscoServers(name *C.char) (*C.discoveryServers, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, getError(stateErr) } servers, serversErr := state.GetDiscoServers() if serversErr != nil { - return nil, C.CString(ErrorToString(serversErr)) + return nil, getError(serversErr) } returnedStruct := (*C.discoveryServers)( @@ -191,15 +192,15 @@ func GetDiscoServers(name *C.char) (*C.discoveryServers, *C.char) { } //export GetDiscoOrganizations -func GetDiscoOrganizations(name *C.char) (*C.discoveryOrganizations, *C.char) { +func GetDiscoOrganizations(name *C.char) (*C.discoveryOrganizations, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, getError(stateErr) } organizations, organizationsErr := state.GetDiscoOrganizations() if organizationsErr != nil { - return nil, C.CString(ErrorToString(organizationsErr)) + return nil, getError(organizationsErr) } returnedStruct := (*C.discoveryOrganizations)( diff --git a/exports/error.h b/exports/error.h new file mode 100644 index 0000000..64592e1 --- /dev/null +++ b/exports/error.h @@ -0,0 +1,17 @@ +#ifndef ERROR_H +#define ERROR_H + +typedef enum errorLevel { + ERR_OTHER, + ERR_INFO, + ERR_WARNING, + ERR_FATAL, +} errorLevel; + +typedef struct error { + errorLevel level; + const char* traceback; + const char* cause; +} error; + +#endif /* ERROR_H */ diff --git a/exports/exports.go b/exports/exports.go index 745b992..a2d97c7 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -2,6 +2,7 @@ package main /* #include <stdlib.h> +#include "error.h" typedef void (*PythonCB)(const char* name, int oldstate, int newstate, void* data); @@ -13,7 +14,6 @@ static void call_callback(PythonCB callback, const char *name, int oldstate, int import "C" import ( - "encoding/json" "fmt" "unsafe" @@ -90,7 +90,7 @@ func Register( config_directory *C.char, stateCallback C.PythonCB, debug C.int, -) *C.char { +) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { @@ -116,237 +116,228 @@ func Register( if registerErr != nil { delete(VPNStates, nameStr) } - return C.CString(ErrorToString(registerErr)) + return getError(registerErr) } //export Deregister -func Deregister(name *C.char) *C.char { +func Deregister(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } state.Deregister() return nil } -func ErrorToString(error error) string { - if error == nil { - return "" +func getError(err error) *C.error { + if err == nil { + return nil } + errorStruct := (*C.error)( + C.malloc(C.size_t(unsafe.Sizeof(C.error{}))), + ) + errorStruct.level = C.errorLevel(eduvpn.GetErrorLevel(err)) + errorStruct.traceback = C.CString(eduvpn.GetErrorTraceback(err)) + errorStruct.cause = C.CString(eduvpn.GetErrorCause(err).Error()) + return errorStruct +} - errorString, jsonErr := eduvpn.GetErrorJSONString(error) - if jsonErr != nil { - return "" - } - return errorString +//export FreeError +func FreeError(err *C.error) { + C.free(unsafe.Pointer(err.traceback)) + C.free(unsafe.Pointer(err.cause)) + C.free(unsafe.Pointer(err)) } //export CancelOAuth -func CancelOAuth(name *C.char) *C.char { +func CancelOAuth(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } cancelErr := state.CancelOAuth() - cancelErrString := ErrorToString(cancelErr) - return C.CString(cancelErrString) -} - -type configJSON struct { - Config string `json:"config"` - ConfigType string `json:"config_type"` -} - -func getConfigJSON(config string, configType string) *C.char { - object := &configJSON{Config: config, ConfigType: configType} - jsonBytes, jsonErr := json.Marshal(object) - - if jsonErr != nil { - panic(jsonErr) - } - - return C.CString(string(jsonBytes)) + return getError(cancelErr) } //export RemoveSecureInternet -func RemoveSecureInternet(name *C.char) *C.char { +func RemoveSecureInternet(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } removeErr := state.RemoveSecureInternet() - return C.CString(ErrorToString(removeErr)) + return getError(removeErr) } //export RemoveInstituteAccess -func RemoveInstituteAccess(name *C.char, url *C.char) *C.char { +func RemoveInstituteAccess(name *C.char, url *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } removeErr := state.RemoveInstituteAccess(C.GoString(url)) - return C.CString(ErrorToString(removeErr)) + return getError(removeErr) } //export RemoveCustomServer -func RemoveCustomServer(name *C.char, url *C.char) *C.char { +func RemoveCustomServer(name *C.char, url *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } removeErr := state.RemoveCustomServer(C.GoString(url)) - return C.CString(ErrorToString(removeErr)) + return getError(removeErr) } //export GetConfigSecureInternet -func GetConfigSecureInternet(name *C.char, orgID *C.char, forceTCP C.int) (*C.char, *C.char) { +func GetConfigSecureInternet(name *C.char, orgID *C.char, forceTCP C.int) (*C.char, *C.char, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, nil, getError(stateErr) } forceTCPBool := forceTCP == 1 config, configType, configErr := state.GetConfigSecureInternet(C.GoString(orgID), forceTCPBool) - return getConfigJSON(config, configType), C.CString(ErrorToString(configErr)) + return C.CString(config), C.CString(configType), getError(configErr) } //export GetConfigInstituteAccess -func GetConfigInstituteAccess(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char) { +func GetConfigInstituteAccess(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, nil, getError(stateErr) } forceTCPBool := forceTCP == 1 config, configType, configErr := state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool) - return getConfigJSON(config, configType), C.CString(ErrorToString(configErr)) + return C.CString(config), C.CString(configType), getError(configErr) } //export GetConfigCustomServer -func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char) { +func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, nil, getError(stateErr) } forceTCPBool := forceTCP == 1 config, configType, configErr := state.GetConfigCustomServer(C.GoString(url), forceTCPBool) - return getConfigJSON(config, configType), C.CString(ErrorToString(configErr)) + return C.CString(config), C.CString(configType), getError(configErr) } //export SetProfileID -func SetProfileID(name *C.char, data *C.char) *C.char { +func SetProfileID(name *C.char, data *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } profileErr := state.SetProfileID(C.GoString(data)) - return C.CString(ErrorToString(profileErr)) + return getError(profileErr) } //export ChangeSecureLocation -func ChangeSecureLocation(name *C.char) *C.char { +func ChangeSecureLocation(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } locationErr := state.ChangeSecureLocation() - return C.CString(ErrorToString(locationErr)) + return getError(locationErr) } //export SetSecureLocation -func SetSecureLocation(name *C.char, data *C.char) *C.char { +func SetSecureLocation(name *C.char, data *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } locationErr := state.SetSecureLocation(C.GoString(data)) - return C.CString(ErrorToString(locationErr)) + return getError(locationErr) } //export GoBack -func GoBack(name *C.char) *C.char { +func GoBack(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } goBackErr := state.GoBack() - return C.CString(ErrorToString(goBackErr)) + return getError(goBackErr) } //export SetSearchServer -func SetSearchServer(name *C.char) *C.char { +func SetSearchServer(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } setSearchErr := state.SetSearchServer() - return C.CString(ErrorToString(setSearchErr)) + return getError(setSearchErr) } //export SetDisconnected -func SetDisconnected(name *C.char, cleanup C.int) *C.char { +func SetDisconnected(name *C.char, cleanup C.int) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } setDisconnectedErr := state.SetDisconnected(int(cleanup) == 1) - return C.CString(ErrorToString(setDisconnectedErr)) + return getError(setDisconnectedErr) } //export SetDisconnecting -func SetDisconnecting(name *C.char) *C.char { +func SetDisconnecting(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } setDisconnectingErr := state.SetDisconnecting() - return C.CString(ErrorToString(setDisconnectingErr)) + return getError(setDisconnectingErr) } //export SetConnecting -func SetConnecting(name *C.char) *C.char { +func SetConnecting(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } setConnectingErr := state.SetConnecting() - return C.CString(ErrorToString(setConnectingErr)) + return getError(setConnectingErr) } //export SetConnected -func SetConnected(name *C.char) *C.char { +func SetConnected(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } setConnectedErr := state.SetConnected() - return C.CString(ErrorToString(setConnectedErr)) + return getError(setConnectedErr) } //export RenewSession -func RenewSession(name *C.char) *C.char { +func RenewSession(name *C.char) *C.error { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return C.CString(ErrorToString(stateErr)) + return getError(stateErr) } renewSessionErr := state.RenewSession() - return C.CString(ErrorToString(renewSessionErr)) + return getError(renewSessionErr) } //export ShouldRenewButton diff --git a/exports/servers.go b/exports/servers.go index a487176..a399db7 100644 --- a/exports/servers.go +++ b/exports/servers.go @@ -3,6 +3,7 @@ package main /* // for free and size_t #include <stdlib.h> +#include "error.h" // The struct for a single server profile typedef struct serverProfile { @@ -293,11 +294,11 @@ func getSavedServersWithOptions(state *eduvpn.VPNState, servers *server.Servers) //export GetSavedServers // This function takes the name as input which is the name of the client // It gets the state by name and then returns the saved servers as a c struct belonging to it -func GetSavedServers(name *C.char) (*C.servers, *C.char) { +func GetSavedServers(name *C.char) (*C.servers, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) + return nil, getError(stateErr) } servers := getSavedServersWithOptions(state, &state.Servers) return servers, nil @@ -5,7 +5,7 @@ import ( "fmt" "github.com/eduvpn/eduvpn-common/internal/fsm" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) type ( diff --git a/internal/config/config.go b/internal/config/config.go index 18e466a..0965998 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,7 +6,7 @@ import ( "io/ioutil" "path" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" ) diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index a3877f5..e7270ea 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -6,7 +6,7 @@ import ( "time" "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/internal/verify" ) diff --git a/internal/http/http.go b/internal/http/http.go index f9dafbb..3a81eb6 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) type URLParameters map[string]string diff --git a/internal/log/log.go b/internal/log/log.go index d6a7373..c0e9c7d 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -6,7 +6,7 @@ import ( "os" "path" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" ) diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 456c854..f4eacbc 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -11,7 +11,7 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" ) diff --git a/internal/server/api.go b/internal/server/api.go index d824c4a..4648a8f 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -10,7 +10,7 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { diff --git a/internal/server/common.go b/internal/server/common.go index eaa8cdf..7f4a0de 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -5,7 +5,7 @@ import ( "time" "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/internal/wireguard" ) diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index aaabeb1..c5b58ef 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) // An instute access server diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 2fbb143..3981022 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" ) diff --git a/internal/util/util.go b/internal/util/util.go index 0af89ac..f9e2f7b 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) func EnsureValidURL(s string) (string, error) { diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 47f5187..458e5e5 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/jedisct1/go-minisign" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) // Verify verifies the signature (.minisig file format) on signedJson. diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 52b9102..3d3ae8e 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -4,7 +4,7 @@ import ( "fmt" "regexp" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -10,7 +10,7 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" ) @@ -225,25 +225,25 @@ func (state *VPNState) retryConfigAuth( errorMessage := "failed authorized config retry" config, configType, configErr := state.getConfigAuth(chosenServer, forceTCP) if configErr != nil { + level := types.ERR_OTHER var error *oauth.OAuthTokensInvalidError + var oauthCancelledError *oauth.OAuthCancelledCallbackError // Only retry if the error is that the tokens are invalid if errors.As(configErr, &error) { - retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth( + config, configType, configErr = state.getConfigAuth( chosenServer, forceTCP, ) - if retryConfigErr != nil { - state.goBackInternal() - return "", "", &types.WrappedErrorMessage{ - Message: errorMessage, - Err: retryConfigErr, - } + if configErr == nil { + return config, configType, nil } - return retryConfig, retryConfigType, nil + } + if errors.As(configErr, &oauthCancelledError) { + level = types.ERR_INFO } state.goBackInternal() - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: level, Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -263,7 +263,7 @@ func (state *VPNState) getConfig( config, configType, configErr := state.retryConfigAuth(chosenServer, forceTCP) if configErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } currentServer, currentServerErr := state.Servers.GetCurrentServer() @@ -484,7 +484,7 @@ func (state *VPNState) GetConfigSecureInternet( GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -554,7 +554,7 @@ func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (stri GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -583,7 +583,7 @@ func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string, GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -948,10 +948,6 @@ func GetErrorTraceback(err error) string { return types.GetErrorTraceback(err) } -func GetErrorJSONString(err error) (string, error) { - return types.GetErrorJSONString(err) -} - func (state *VPNState) GetTranslated(languages map[string]string) string { return util.GetLanguageMatched(languages, state.Language) } diff --git a/state_test.go b/state_test.go index a557a4c..5647b12 100644 --- a/state_test.go +++ b/state_test.go @@ -14,7 +14,7 @@ import ( httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/eduvpn/eduvpn-common/internal/types" + "github.com/eduvpn/eduvpn-common/types" ) func ensureLocalWellKnown() { diff --git a/internal/types/error.go b/types/error.go index 0b3feae..607e6c6 100644 --- a/internal/types/error.go +++ b/types/error.go @@ -1,7 +1,6 @@ package types import ( - "encoding/json" "errors" "fmt" ) @@ -81,36 +80,3 @@ func GetErrorLevel(err error) ErrorLevel { } return ERR_OTHER } - -type WrappedErrorMessageJSON struct { - Level ErrorLevel `json:"level"` - Cause string `json:"cause"` - Traceback string `json:"traceback"` -} - -func GetErrorJSONString(err error) (string, error) { - var wrappedErr *WrappedErrorMessage - - var level ErrorLevel - var cause error - var traceback string - - if errors.As(err, &wrappedErr) { - level = wrappedErr.Level - cause = wrappedErr.Cause() - traceback = wrappedErr.Traceback() - } else { - level = ERR_OTHER - cause = err - traceback = err.Error() - } - - json, jsonErr := json.Marshal( - &WrappedErrorMessageJSON{Level: level, Cause: cause.Error(), Traceback: traceback}, - ) - - if jsonErr != nil { - return "", jsonErr - } - return string(json), nil -} diff --git a/internal/types/server.go b/types/server.go index 48f94fb..48f94fb 100644 --- a/internal/types/server.go +++ b/types/server.go diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index 3bafc0e..cb4ba9b 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -1,11 +1,10 @@ from ctypes import * from collections import defaultdict -from enum import Enum import pathlib import platform from typing import Tuple, Optional -import json from typing import List +from .error import WrappedError, ErrorLevel _lib_prefixes = defaultdict( lambda: "lib", @@ -37,10 +36,12 @@ except: lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile)) -class ErrorLevel(Enum): - ERR_OTHER = 0 - ERR_INFO = 1 - +class cError(Structure): + _fields_ = [ + ("level", c_int), + ("traceback", c_char_p), + ("cause", c_char_p), + ] class cServerLocations(Structure): _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] @@ -126,7 +127,11 @@ class cServers(Structure): class DataError(Structure): - _fields_ = [("data", c_void_p), ("error", c_void_p)] + _fields_ = [("data", c_void_p), ("error", POINTER(cError))] + + +class ConfigError(Structure): + _fields_ = [("config", c_char_p), ("config_type", c_char_p), ("error", POINTER(cError))] VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) @@ -149,17 +154,17 @@ lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None lib.Register.argtypes, lib.Register.restype = [ c_char_p, @@ -195,19 +200,13 @@ lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ c_void_p ], None lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None +lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError -class WrappedError: - def __init__(self, traceback: str, cause: str, level: ErrorLevel): - self.traceback = traceback - self.cause = cause - self.level = level - - def encode_args(args, types): for arg, t in zip(args, types): # c_char_p needs the str to be encoded to bytes @@ -239,37 +238,21 @@ def get_ptr_list_strings( return strings_list return [] - -def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: - error_string = get_ptr_string(ptr) - - if not error_string: +def get_error(ptr: c_void_p) -> Optional[WrappedError]: + if not ptr: return None - - error_json = json.loads(error_string) - - if not error_json: - return None - - if "level" not in error_json: - return error_string - level = error_json["level"] - traceback = error_json["traceback"] - cause = error_json["cause"] - return WrappedError(traceback, cause, ErrorLevel(level)) - - -def get_error(ptr: c_void_p) -> str: - error = get_ptr_error(ptr) - if not error: - return "" - - if not isinstance(error, WrappedError): - return error - return error.cause - - -def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, str]: + err = cast(ptr, POINTER(cError)).contents + wrapped = WrappedError(err.traceback.decode(), err.cause.decode(), ErrorLevel(err.level)) + lib.FreeError(ptr) + return wrapped + +def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]: + config = get_ptr_string(config_error.config) + config_type = get_ptr_string(config_error.config_type) + err = get_error(config_error.error) + return config, config_type, err + +def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]: data = data_conv(data_error.data) error = get_error(data_error.error) return data, error @@ -283,4 +266,5 @@ decode_map = { c_int: get_bool, c_void_p: get_error, DataError: get_data_error, + ConfigError: get_config_error, } diff --git a/wrappers/python/src/error.py b/wrappers/python/src/error.py new file mode 100644 index 0000000..50298bb --- /dev/null +++ b/wrappers/python/src/error.py @@ -0,0 +1,15 @@ +from enum import Enum + +class ErrorLevel(Enum): + ERR_OTHER = 0 + ERR_INFO = 1 + ERR_WARNING = 2 + ERR_FATAL = 3 + +class WrappedError(Exception): + def __init__(self, traceback: str, cause: str, level: ErrorLevel): + super(WrappedError, self).__init__(cause) + self.traceback = traceback + self.cause = cause + self.level = level + diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 1ee9dd7..01621ae 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -5,7 +5,6 @@ from .discovery import get_disco_organizations, get_disco_servers from .event import EventHandler from .state import State, StateType from .server import get_servers -import json eduvpn_objects = {} @@ -70,7 +69,7 @@ class EduVPN(object): cancel_oauth_err = self.go_function(lib.CancelOAuth) if cancel_oauth_err: - raise Exception(cancel_oauth_err) + raise cancel_oauth_err def deregister(self) -> None: self.go_function(lib.Deregister) @@ -85,7 +84,7 @@ class EduVPN(object): ) if register_err: - raise Exception(register_err) + raise register_err def get_disco_servers(self) -> str: servers, servers_err = self.go_function_custom_decode( @@ -93,7 +92,7 @@ class EduVPN(object): ) if servers_err: - raise Exception(servers_err) + raise servers_err return servers @@ -103,7 +102,7 @@ class EduVPN(object): ) if organizations_err: - raise Exception(organizations_err) + raise organizations_err return organizations @@ -111,19 +110,19 @@ class EduVPN(object): remove_err = self.go_function(lib.RemoveSecureInternet) if remove_err: - raise Exception(remove_err) + raise remove_err def remove_institute_access(self, url: str): remove_err = self.go_function(lib.RemoveInstituteAccess, url) if remove_err: - raise Exception(remove_err) + raise remove_err def remove_custom_server(self, url: str): remove_err = self.go_function(lib.RemoveCustomServer, url) if remove_err: - raise Exception(remove_err) + raise remove_err def get_config(self, url: str, func: callable, force_tcp: bool = False): # Because it could be the case that a profile callback is started, store a threading event @@ -131,17 +130,13 @@ class EduVPN(object): # The event is set in self.set_profile self.profile_event = threading.Event() - config_json, config_err = self.go_function(func, url, force_tcp) + config, config_type, config_err = self.go_function(func, url, force_tcp) self.profile_event = None self.location_event = None if config_err: - raise Exception(config_err) - - config_json_dict = json.loads(config_json) - config = config_json_dict["config"] - config_type = config_json_dict["config_type"] + raise config_err return config, config_type @@ -169,31 +164,31 @@ class EduVPN(object): connect_err = self.go_function(lib.SetConnected) if connect_err: - raise Exception(connect_err) + raise connect_err def set_disconnecting(self) -> None: disconnecting_err = self.go_function(lib.SetDisconnecting) if disconnecting_err: - raise Exception(disconnecting_err) + raise disconnecting_err def set_connecting(self) -> None: connecting_err = self.go_function(lib.SetConnecting) if connecting_err: - raise Exception(connecting_err) + raise connecting_err def set_disconnected(self, cleanup=True) -> None: disconnect_err = self.go_function(lib.SetDisconnected, cleanup) if disconnect_err: - raise Exception(disconnect_err) + raise disconnect_err def set_search_server(self) -> None: search_err = self.go_function(lib.SetSearchServer) if search_err: - raise Exception(search_err) + raise search_err def remove_class_callbacks(self, cls) -> None: self.event_handler.change_class_callbacks(cls, add=False) @@ -218,7 +213,7 @@ class EduVPN(object): self.profile_event.set() if profile_err: - raise Exception(profile_err) + raise profile_err def change_secure_location(self) -> None: # Set the location by country code @@ -226,7 +221,7 @@ class EduVPN(object): location_err = self.go_function(lib.ChangeSecureLocation) if location_err: - raise Exception(location_err) + raise location_err def set_secure_location(self, country_code: str) -> None: # Set the location by country code @@ -238,13 +233,13 @@ class EduVPN(object): self.location_event.set() if location_err: - raise Exception(location_err) + raise location_err def renew_session(self) -> None: renew_err = self.go_function(lib.RenewSession) if renew_err: - raise Exception(renew_err) + raise renew_err def should_renew_button(self) -> bool: return self.go_function(lib.ShouldRenewButton) |
