summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-02 14:34:35 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-05-02 14:34:35 +0200
commit466450f0c47bdc614e66326d90e5fc6fb56ae732 (patch)
treea01518a58d50d2f8449d37dadecc40e35c9f1fe1
parenta2a8efdcaad3d9b1852b1367a7cd7e8c5860cecf (diff)
Refactor: Wrap most errors in a custom type
-rw-r--r--internal/api.go53
-rw-r--r--internal/config.go22
-rw-r--r--internal/fsm.go18
-rw-r--r--internal/http.go109
-rw-r--r--internal/log.go14
-rw-r--r--internal/oauth.go113
-rw-r--r--internal/openvpn.go12
-rw-r--r--internal/server.go115
-rw-r--r--internal/verify.go154
-rw-r--r--internal/wireguard.go28
-rw-r--r--state.go76
-rw-r--r--state_test.go36
12 files changed, 531 insertions, 219 deletions
diff --git a/internal/api.go b/internal/api.go
index 2ed605e..da17f76 100644
--- a/internal/api.go
+++ b/internal/api.go
@@ -9,6 +9,7 @@ import (
)
// Authorized wrappers on top of HTTP
+// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth
func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
@@ -38,28 +39,32 @@ func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *H
if bodyErr != nil {
var error *HTTPStatusError
- // Only retry authroized if we get a HTTP 401
+ // Only retry authorized if we get a HTTP 401
if errors.As(bodyErr, &error) && error.Status == 401 {
server.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error))
// Tell the method that the token is expired
server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds()
- return server.apiAuthorized(method, endpoint, opts)
+ retryHeader, retryBody, retryErr := server.apiAuthorized(method, endpoint, opts)
+ if retryErr != nil {
+ return nil, nil, &APIAuthorizedError{Err: retryErr}
+ }
+ return retryHeader, retryBody, nil
}
- return header, nil, bodyErr
+ return nil, nil, &APIAuthorizedError{Err: bodyErr}
}
- return header, body, bodyErr
+ return header, body, nil
}
func (server *Server) APIInfo() error {
_, body, bodyErr := server.apiAuthorizedRetry(http.MethodGet, "/info", nil)
if bodyErr != nil {
- return bodyErr
+ return &APIInfoError{Err: bodyErr}
}
structure := ServerProfileInfo{}
jsonErr := json.Unmarshal(body, &structure)
if jsonErr != nil {
- return jsonErr
+ return &APIInfoError{Err: jsonErr}
}
server.Profiles = structure
@@ -79,7 +84,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
}
header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", "", connectErr
+ return "", "", &APIConnectWireguardError{Err: connectErr}
}
expires := header.Get("expires")
@@ -97,7 +102,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
}
header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", "", connectErr
+ return "", "", &APIConnectOpenVPNError{Err: connectErr}
}
expires := header.Get("expires")
@@ -108,3 +113,35 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
func (server *Server) APIDisconnect() {
server.apiAuthorizedRetry(http.MethodPost, "/disconnect", nil)
}
+
+type APIAuthorizedError struct {
+ Err error
+}
+
+func (e *APIAuthorizedError) Error() string {
+ return fmt.Sprintf("failed api authorized call with error: %v", e.Err)
+}
+
+type APIConnectWireguardError struct {
+ Err error
+}
+
+func (e *APIConnectWireguardError) Error() string {
+ return fmt.Sprintf("failed api /connect wireguard call with error: %v", e.Err)
+}
+
+type APIConnectOpenVPNError struct {
+ Err error
+}
+
+func (e *APIConnectOpenVPNError) Error() string {
+ return fmt.Sprintf("failed api /connect OpenVPN call with error: %v", e.Err)
+}
+
+type APIInfoError struct {
+ Err error
+}
+
+func (e *APIInfoError) Error() string {
+ return fmt.Sprintf("failed api /info call with error: %v", e.Err)
+}
diff --git a/internal/config.go b/internal/config.go
index 47f773e..a135ac6 100644
--- a/internal/config.go
+++ b/internal/config.go
@@ -25,11 +25,11 @@ func (config *Config) GetFilename() string {
func (config *Config) Save(readStruct interface{}) error {
configDirErr := EnsureDirectory(config.Directory)
if configDirErr != nil {
- return configDirErr
+ return &ConfigSaveError{Err: configDirErr}
}
jsonString, marshalErr := json.Marshal(readStruct)
if marshalErr != nil {
- return marshalErr
+ return &ConfigSaveError{Err: marshalErr}
}
return ioutil.WriteFile(config.GetFilename(), jsonString, 0o644)
}
@@ -37,7 +37,23 @@ func (config *Config) Save(readStruct interface{}) error {
func (config *Config) Load(writeStruct interface{}) error {
bytes, readErr := ioutil.ReadFile(config.GetFilename())
if readErr != nil {
- return readErr
+ return &ConfigLoadError{Err: readErr}
}
return json.Unmarshal(bytes, writeStruct)
}
+
+type ConfigSaveError struct {
+ Err error
+}
+
+func (e *ConfigSaveError) Error() string {
+ return fmt.Sprintf("failed to save config with error: %v", e.Err)
+}
+
+type ConfigLoadError struct {
+ Err error
+}
+
+func (e *ConfigLoadError) Error() string {
+ return fmt.Sprintf("failed to load config with error: %v", e.Err)
+}
diff --git a/internal/fsm.go b/internal/fsm.go
index 1bcc479..0b9ad1e 100644
--- a/internal/fsm.go
+++ b/internal/fsm.go
@@ -206,3 +206,21 @@ func (fsm *FSM) generateMermaidGraph() string {
func (fsm *FSM) GenerateGraph() string {
return fsm.generateMermaidGraph()
}
+
+type FSMWrongStateTransitionError struct {
+ Got FSMStateID
+ Want FSMStateID
+}
+
+func (e *FSMWrongStateTransitionError) Error() string {
+ return fmt.Sprintf("wrong FSM state, got: %s, want a state with a transition to: %s", e.Got.String(), e.Want.String())
+}
+
+type FSMWrongStateError struct {
+ Got FSMStateID
+ Want FSMStateID
+}
+
+func (e *FSMWrongStateError) Error() string {
+ return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String())
+}
diff --git a/internal/http.go b/internal/http.go
index 8ca8cb9..0b1eda4 100644
--- a/internal/http.go
+++ b/internal/http.go
@@ -9,52 +9,6 @@ import (
"strings"
)
-type HTTPResourceError struct {
- URL string
- Err error
-}
-
-func (e *HTTPResourceError) Error() string {
- return fmt.Sprintf("failed obtaining HTTP resource %s with error %v", e.URL, e.Err)
-}
-
-type HTTPStatusError struct {
- URL string
- Status int
-}
-
-func (e *HTTPStatusError) Error() string {
- return fmt.Sprintf("failed obtaining HTTP resource %s as it gave an unsuccesful status code %d", e.URL, e.Status)
-}
-
-type HTTPReadError struct {
- URL string
- Err error
-}
-
-func (e *HTTPReadError) Error() string {
- return fmt.Sprintf("failed reading HTTP resource %s with error %v", e.URL, e.Err)
-}
-
-type HTTPParseJsonError struct {
- URL string
- Body string
- Err error
-}
-
-func (e *HTTPParseJsonError) Error() string {
- return fmt.Sprintf("failed parsing json %s for HTTP resource %s with error %v", e.Body, e.URL, e.Err)
-}
-
-type HTTPRequestCreateError struct {
- URL string
- Err error
-}
-
-func (e *HTTPRequestCreateError) Error() string {
- return fmt.Sprintf("failed to create HTTP request with url %s and error %v", e.URL, e.Err)
-}
-
type URLParameters map[string]string
type HTTPOptionalParams struct {
@@ -65,9 +19,9 @@ type HTTPOptionalParams struct {
// Construct an URL including on parameters
func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) {
- url, err := url.Parse(baseURL)
- if err != nil {
- return "", err
+ url, parseErr := url.Parse(baseURL)
+ if parseErr != nil {
+ return "", &HTTPConstructURLError{URL: baseURL, Parameters: parameters, Err: parseErr}
}
q := url.Query()
@@ -130,6 +84,7 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
// it already has the right error so so we don't wrap it further
url, urlErr := httpOptionalURL(url, opts)
if urlErr != nil {
+ // No further type wrapping is needed here
return nil, nil, urlErr
}
@@ -170,3 +125,59 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
// Return the body in bytes and signal the status error if there was one
return resp.Header, body, nil
}
+
+type HTTPResourceError struct {
+ URL string
+ Err error
+}
+
+func (e *HTTPResourceError) Error() string {
+ return fmt.Sprintf("failed obtaining HTTP resource: %s with error: %v", e.URL, e.Err)
+}
+
+type HTTPStatusError struct {
+ URL string
+ Status int
+}
+
+func (e *HTTPStatusError) Error() string {
+ return fmt.Sprintf("failed obtaining HTTP resource: %s as it gave an unsuccesful status code: %d", e.URL, e.Status)
+}
+
+type HTTPReadError struct {
+ URL string
+ Err error
+}
+
+func (e *HTTPReadError) Error() string {
+ return fmt.Sprintf("failed reading HTTP resource: %s with error: %v", e.URL, e.Err)
+}
+
+type HTTPParseJsonError struct {
+ URL string
+ Body string
+ Err error
+}
+
+func (e *HTTPParseJsonError) Error() string {
+ return fmt.Sprintf("failed parsing json %s for HTTP resource: %s with error: %v", e.Body, e.URL, e.Err)
+}
+
+type HTTPRequestCreateError struct {
+ URL string
+ Err error
+}
+
+func (e *HTTPRequestCreateError) Error() string {
+ return fmt.Sprintf("failed to create HTTP request with url: %s and error: %v", e.URL, e.Err)
+}
+
+type HTTPConstructURLError struct {
+ URL string
+ Parameters URLParameters
+ Err error
+}
+
+func (e *HTTPConstructURLError) Error() string {
+ return fmt.Sprintf("failed to construct url: %s including parameters: %v with error: %v", e.URL, e.Parameters, e.Err)
+}
diff --git a/internal/log.go b/internal/log.go
index 9248fc0..5109ba2 100644
--- a/internal/log.go
+++ b/internal/log.go
@@ -39,11 +39,11 @@ func (e LogLevel) String() string {
func (logger *FileLogger) Init(level LogLevel, name string, directory string) error {
configDirErr := EnsureDirectory(directory)
if configDirErr != nil {
- return configDirErr
+ return &LogInitializeError{Name: name, Directory: directory, Err: configDirErr}
}
logFile, logOpenErr := os.OpenFile(logger.getFilename(directory, name), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666)
if logOpenErr != nil {
- return logOpenErr
+ return &LogInitializeError{Name: name, Directory: directory, Err: logOpenErr}
}
log.SetOutput(logFile)
logger.File = logFile
@@ -65,3 +65,13 @@ func (logger *FileLogger) Log(level LogLevel, str string) {
func (logger *FileLogger) Close() {
logger.File.Close()
}
+
+type LogInitializeError struct {
+ Name string
+ Directory string
+ Err error
+}
+
+func (e *LogInitializeError) Error() string {
+ return fmt.Sprintf("failed initializing logging with name: %s and directory: %s with error: %v", e.Name, e.Directory, e.Err)
+}
diff --git a/internal/oauth.go b/internal/oauth.go
index 98af5a4..c13ea99 100644
--- a/internal/oauth.go
+++ b/internal/oauth.go
@@ -5,7 +5,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
- "errors"
"fmt"
"net/http"
"net/url"
@@ -20,7 +19,7 @@ import (
func genState() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenStateUnableError{Err: err}
+ return "", &OAuthGenStateError{Err: err}
}
// For consistency we also use raw url encoding here
@@ -46,7 +45,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenVerifierUnableError{Err: err}
+ return "", &OAuthGenVerifierError{Err: err}
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -95,7 +94,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
}
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed {
- return &OAuthFailedCallbackError{Addr: addr, Err: err}
+ return &OAuthCallbackError{Addr: addr, Err: err}
}
return oauth.Session.CallbackError
}
@@ -122,7 +121,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return bodyErr
+ return &OAuthAuthError{Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -160,7 +159,7 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := GenerateTimeSeconds()
_, body, bodyErr := HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return bodyErr
+ return &OAuthRefreshError{Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -181,7 +180,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Extract the authorization code
code, success := req.URL.Query()["code"]
if !success {
- oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()}
+ oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -192,14 +191,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state, success := req.URL.Query()["state"]
if !success {
- oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()}
+ oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
// The state is the first entry
extractedState := state[0]
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
+ oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -208,7 +207,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
err := oauth.getTokensWithAuthCode(extractedCode)
if err != nil {
- oauth.Session.CallbackError = &OAuthFailedCallbackGetTokensError{Err: err}
+ oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err}
go oauth.Session.Server.Shutdown(oauth.Session.Context)
return
}
@@ -225,18 +224,18 @@ func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) {
// Starts the OAuth exchange for eduvpn.
func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error {
if !oauth.FSM.HasTransition(OAUTH_STARTED) {
- return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", oauth.FSM.Current.String()))
+ return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED}
}
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return &OAuthFailedInitializeError{Err: stateErr}
+ return &OAuthInitializeError{Err: stateErr}
}
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return &OAuthFailedInitializeError{Err: verifierErr}
+ return &OAuthInitializeError{Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -252,8 +251,8 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
authURL, urlErr := HTTPConstructURL(authorizationURL, parameters)
- if urlErr != nil { // shouldn't happen
- panic(urlErr)
+ if urlErr != nil {
+ return &OAuthInitializeError{Err: urlErr}
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -268,12 +267,12 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string)
// Error definitions
func (oauth *OAuth) Finish() error {
if !oauth.FSM.HasTransition(AUTHORIZED) {
- return errors.New("invalid state to finish oauth")
+ return &FSMWrongStateError{Got: oauth.FSM.Current, Want: AUTHORIZED}
}
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return tokenErr
+ return &OAuthFinishError{Err: tokenErr}
}
oauth.FSM.GoTransition(AUTHORIZED)
return nil
@@ -288,13 +287,13 @@ func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string)
authInitializeErr := oauth.start(name, authorizationURL, tokenURL)
if authInitializeErr != nil {
- return authInitializeErr
+ return &OAuthLoginError{Err: authInitializeErr}
}
oauthErr := oauth.Finish()
if oauthErr != nil {
- return oauthErr
+ return &OAuthLoginError{Err: oauthErr}
}
return nil
}
@@ -329,64 +328,96 @@ func (oauth *OAuth) NeedsRelogin() bool {
type OAuthCancelledCallbackError struct{}
func (e *OAuthCancelledCallbackError) Error() string {
- return fmt.Sprintf("Client cancelled OAuth")
+ return fmt.Sprintf("client cancelled OAuth")
}
-type OAuthGenStateUnableError struct {
+type OAuthGenStateError struct {
Err error
}
-func (e *OAuthGenStateUnableError) Error() string {
- return fmt.Sprintf("failed generating state with error %v", e.Err)
+func (e *OAuthGenStateError) Error() string {
+ return fmt.Sprintf("failed generating state with error: %v", e.Err)
}
-type OAuthGenVerifierUnableError struct {
+type OAuthGenVerifierError struct {
Err error
}
-func (e *OAuthGenVerifierUnableError) Error() string {
- return fmt.Sprintf("failed generating verifier with error %v", e.Err)
+func (e *OAuthGenVerifierError) Error() string {
+ return fmt.Sprintf("failed generating verifier with error: %v", e.Err)
}
-type OAuthFailedCallbackError struct {
+type OAuthCallbackError struct {
Addr string
Err error
}
-func (e *OAuthFailedCallbackError) Error() string {
- return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err)
+func (e *OAuthCallbackError) Error() string {
+ return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err)
}
-type OAuthFailedCallbackParameterError struct {
+type OAuthCallbackParameterError struct {
Parameter string
URL string
}
-func (e *OAuthFailedCallbackParameterError) Error() string {
- return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL)
+func (e *OAuthCallbackParameterError) Error() string {
+ return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL)
}
-type OAuthFailedCallbackStateMatchError struct {
+type OAuthCallbackStateMatchError struct {
State string
ExpectedState string
}
-func (e *OAuthFailedCallbackStateMatchError) Error() string {
- return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState)
+func (e *OAuthCallbackStateMatchError) Error() string {
+ return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
-type OAuthFailedCallbackGetTokensError struct {
+type OAuthCallbackGetTokensError struct {
Err error
}
-func (e *OAuthFailedCallbackGetTokensError) Error() string {
- return fmt.Sprintf("failed getting tokens with error %v", e.Err)
+func (e *OAuthCallbackGetTokensError) Error() string {
+ return fmt.Sprintf("failed getting tokens with error: %v", e.Err)
}
-type OAuthFailedInitializeError struct {
+type OAuthFinishError struct {
Err error
}
-func (e *OAuthFailedInitializeError) Error() string {
- return fmt.Sprintf("failed initializing OAuth with error %v", e.Err)
+func (e *OAuthFinishError) Error() string {
+ return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err)
+}
+
+type OAuthLoginError struct {
+ Err error
+}
+
+func (e *OAuthLoginError) Error() string {
+ return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err)
+}
+
+type OAuthInitializeError struct {
+ Err error
+}
+
+func (e *OAuthInitializeError) Error() string {
+ return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err)
+}
+
+type OAuthAuthError struct {
+ Err error
+}
+
+func (e *OAuthAuthError) Error() string {
+ return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err)
+}
+
+type OAuthRefreshError struct {
+ Err error
+}
+
+func (e *OAuthRefreshError) Error() string {
+ return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err)
}
diff --git a/internal/openvpn.go b/internal/openvpn.go
index 1b2e626..ed31fe2 100644
--- a/internal/openvpn.go
+++ b/internal/openvpn.go
@@ -1,12 +1,22 @@
package internal
+import "fmt"
+
func (server *Server) OpenVPNGetConfig() (string, error) {
profile_id := server.Profiles.Current
configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id)
if configErr != nil {
- return "", configErr
+ return "", &OpenVPNGetConfigError{Err: configErr}
}
return configOpenVPN, nil
}
+
+type OpenVPNGetConfigError struct {
+ Err error
+}
+
+func (e *OpenVPNGetConfigError) Error() string {
+ return fmt.Sprintf("failed getting OpenVPN config with error: %v", e.Err)
+}
diff --git a/internal/server.go b/internal/server.go
index 1d6f1e1..489719e 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -24,12 +24,12 @@ type Servers struct {
func (servers *Servers) GetCurrentServer() (*Server, error) {
if servers.List == nil {
- return nil, errors.New("No map found to get Current Server")
+ return nil, &ServerGetCurrentNoMapError{}
}
server, exists := servers.List[servers.Current]
if !exists || server == nil {
- return nil, errors.New("Current Server not found")
+ return nil, &ServerGetCurrentNotFoundError{}
}
return server, nil
}
@@ -45,7 +45,7 @@ func (server *Server) Init(url string, fsm *FSM, logger *FileLogger) error {
server.OAuth.Init(fsm, logger)
endpointsErr := server.GetEndpoints()
if endpointsErr != nil {
- return endpointsErr
+ return &ServerInitializeError{URL: url, Err: endpointsErr}
}
return nil
}
@@ -60,7 +60,7 @@ func (server *Server) EnsureTokens() error {
func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) {
if url == "" {
- return nil, errors.New("Emtpy URL to ensure Server")
+ return nil, &ServerEnsureServerEmptyURLError{}
}
if servers.List == nil {
servers.List = make(map[string]*Server)
@@ -74,7 +74,7 @@ func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, m
serverInitErr := server.Init(url, fsm, logger)
if serverInitErr != nil {
- return nil, serverInitErr
+ return nil, &ServerEnsureServerError{Err: serverInitErr}
}
servers.List[url] = server
@@ -88,7 +88,7 @@ func (servers *Servers) getSecureInternetHome() (*Server, error) {
server, exists := servers.List[servers.SecureHome]
if !exists || server == nil {
- return nil, errors.New("No secure internet home found")
+ return nil, &ServerGetSecureInternetHomeError{}
}
return server, nil
@@ -104,7 +104,7 @@ func (servers *Servers) CopySecureInternetOAuth(server *Server) error {
secureHome, secureHomeErr := servers.getSecureInternetHome()
if secureHomeErr != nil {
- return secureHomeErr
+ return &ServerCopySecureInternetOAuthError{Err: secureHomeErr}
}
// Forward token properties
@@ -155,14 +155,14 @@ func (server *Server) GetEndpoints() error {
_, body, bodyErr := HTTPGet(url)
if bodyErr != nil {
- return bodyErr
+ return &ServerGetEndpointsError{Err: bodyErr}
}
endpoints := ServerEndpoints{}
jsonErr := json.Unmarshal(body, &endpoints)
if jsonErr != nil {
- return jsonErr
+ return &ServerGetEndpointsError{Err: jsonErr}
}
server.Endpoints = endpoints
@@ -180,23 +180,23 @@ func (profile *ServerProfile) supportsWireguard() bool {
}
func (server *Server) getCurrentProfile() (*ServerProfile, error) {
- profile_id := server.Profiles.Current
+ profileID := server.Profiles.Current
for _, profile := range server.Profiles.Info.ProfileList {
- if profile.ID == profile_id {
+ if profile.ID == profileID {
return &profile, nil
}
}
- return nil, errors.New(fmt.Sprintf("no profile found for id %s", profile_id))
+ return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}
}
func (server *Server) getConfigWithProfile() (string, error) {
if !server.FSM.HasTransition(HAS_CONFIG) {
- return "", errors.New("cannot get a config with a profile, invalid state")
+ return "", &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: HAS_CONFIG}
}
profile, profileErr := server.getCurrentProfile()
if profileErr != nil {
- return "", profileErr
+ return "", &ServerGetConfigWithProfileError{Err: profileErr}
}
if profile.supportsWireguard() {
@@ -207,7 +207,7 @@ func (server *Server) getConfigWithProfile() (string, error) {
func (server *Server) askForProfileID() error {
if !server.FSM.HasTransition(ASK_PROFILE) {
- return errors.New("cannot ask for a profile id, invalid state")
+ return &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: ASK_PROFILE}
}
server.FSM.GoTransitionWithData(ASK_PROFILE, server.ProfilesRaw, false)
return nil
@@ -220,7 +220,7 @@ func (server *Server) GetConfig() (string, error) {
infoErr := server.APIInfo()
if infoErr != nil {
- return "", infoErr
+ return "", &ServerGetConfigError{Err: infoErr}
}
// Set the current profile if there is only one profile
@@ -232,8 +232,89 @@ func (server *Server) GetConfig() (string, error) {
profileErr := server.askForProfileID()
if profileErr != nil {
- return "", nil
+ return "", &ServerGetConfigError{Err: profileErr}
}
return server.getConfigWithProfile()
}
+
+type ServerGetCurrentProfileNotFoundError struct {
+ ProfileID string
+}
+
+func (e *ServerGetCurrentProfileNotFoundError) Error() string {
+ return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID)
+}
+
+type ServerGetConfigWithProfileError struct {
+ Err error
+}
+
+func (e *ServerGetConfigWithProfileError) Error() string {
+ return fmt.Sprintf("failed to get config including profile with error %v", e.Err)
+}
+
+type ServerGetEndpointsError struct {
+ Err error
+}
+
+func (e *ServerGetEndpointsError) Error() string {
+ return fmt.Sprintf("failed to get server endpoint with error %v", e.Err)
+}
+
+type ServerGetSecureInternetHomeError struct{}
+
+func (e *ServerGetSecureInternetHomeError) Error() string {
+ return "failed to get secure internet home server, not found"
+}
+
+type ServerCopySecureInternetOAuthError struct {
+ Err error
+}
+
+func (e *ServerCopySecureInternetOAuthError) Error() string {
+ return fmt.Sprintf("failed to copy oauth tokens from home server with error %v", e.Err)
+}
+
+type ServerEnsureServerEmptyURLError struct{}
+
+func (e *ServerEnsureServerEmptyURLError) Error() string {
+ return "failed ensuring server, empty url provided"
+}
+
+type ServerEnsureServerError struct {
+ Err error
+}
+
+func (e *ServerEnsureServerError) Error() string {
+ return fmt.Sprintf("failed ensuring server with error %v", e.Err)
+}
+
+type ServerGetCurrentNoMapError struct{}
+
+func (e *ServerGetCurrentNoMapError) Error() string {
+ return "failed getting current server, no servers available"
+}
+
+type ServerGetCurrentNotFoundError struct{}
+
+func (e *ServerGetCurrentNotFoundError) Error() string {
+ return "failed getting current server, not found"
+}
+
+type ServerGetConfigError struct {
+ Err error
+}
+
+func (e *ServerGetConfigError) Error() string {
+ return fmt.Sprintf("failed getting server config with error %v", e.Err)
+}
+
+type ServerInitializeError struct {
+ URL string
+ Err error
+}
+
+func (e *ServerInitializeError) Error() string {
+ return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err)
+}
diff --git a/internal/verify.go b/internal/verify.go
index 9128777..713e4d7 100644
--- a/internal/verify.go
+++ b/internal/verify.go
@@ -58,13 +58,81 @@ func InsecureTestingSetExtraKey(keyString string) {
extraKey = keyString
}
+// verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJson.
+//
+// Verification is performed using a matching key in allowedPublicKeys.
+// The signature is checked to be a Ed25519 Minisign (optionally Ed25519 Blake2b-512 prehashed, see forcePrehash) signature with a valid trusted comment.
+// The file type that is verified is indicated by expectedFileName, which must be one of "server_list.json"/"organization_list.json".
+// The trusted comment is checked to be of the form "timestamp:<timestamp>\tfile:<expectedFileName>", optionally suffixed by something, e.g. "\thashed".
+// The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds.
+//
+// The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure.
+func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) {
+ switch filename {
+ case "server_list.json", "organization_list.json":
+ break
+ default:
+ return false, &VerifyUnknownExpectedFilenameError{Filename: filename, Expected: "server_list.json or organization_list.json"}
+ }
+
+ sig, err := minisign.DecodeSignature(signatureFileContent)
+ if err != nil {
+ return false, &VerifyInvalidSignatureFormatError{Err: err}
+ }
+
+ // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format
+ if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} {
+ return false, &VerifyInvalidSignatureAlgorithmError{Algorithm: string(sig.SignatureAlgorithm[:]), WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)"}
+ }
+
+ // Find allowed key used for signature
+ for _, keyStr := range allowedPublicKeys {
+ key, err := minisign.NewPublicKey(keyStr)
+ if err != nil {
+ // Should only happen if Verify is wrong or extraKey is invalid
+ return false, &VerifyCreatePublicKeyError{PublicKey: keyStr, Err: err}
+ }
+
+ if sig.KeyId != key.KeyId {
+ continue // Wrong key
+ }
+
+ valid, err := key.Verify(signedJson, sig)
+ if !valid {
+ return false, &VerifyInvalidSignatureError{Err: err}
+ }
+
+ // Parse trusted comment
+ var signTime uint64
+ var sigFileName string
+ // sigFileName cannot have spaces
+ _, err = fmt.Sscanf(sig.TrustedComment, "trusted comment: timestamp:%d\tfile:%s", &signTime, &sigFileName)
+ if err != nil {
+ return false, &VerifyInvalidTrustedCommentError{TrustedComment: sig.TrustedComment, Err: err}
+ }
+
+ if sigFileName != filename {
+ return false, &VerifyWrongSigFilenameError{Filename: filename, SigFilename: sigFileName}
+ }
+
+ if signTime < minSignTime {
+ return false, &VerifySigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime}
+ }
+
+ return true, nil
+ }
+
+ // No matching allowed key found
+ return false, &VerifyUnknownKeyError{Filename: filename}
+}
+
type VerifyUnknownExpectedFilenameError struct {
Filename string
Expected string
}
func (e *VerifyUnknownExpectedFilenameError) Error() string {
- return fmt.Sprintf("invalid filename %s, expected %s", e.Filename, e.Expected)
+ return fmt.Sprintf("invalid filename: %s, expected: %s", e.Filename, e.Expected)
}
type VerifyInvalidSignatureFormatError struct {
@@ -72,7 +140,7 @@ type VerifyInvalidSignatureFormatError struct {
}
func (e *VerifyInvalidSignatureFormatError) Error() string {
- return fmt.Sprintf("invalid signature format, error %v", e.Err)
+ return fmt.Sprintf("invalid signature format with error: %v", e.Err)
}
type VerifyInvalidSignatureAlgorithmError struct {
@@ -81,7 +149,7 @@ type VerifyInvalidSignatureAlgorithmError struct {
}
func (e *VerifyInvalidSignatureAlgorithmError) Error() string {
- return fmt.Sprintf("invalid signature algorithm %s, wanted %s", e.Algorithm, e.WantedAlgorithm)
+ return fmt.Sprintf("invalid signature algorithm: %s, wanted: %s", e.Algorithm, e.WantedAlgorithm)
}
type VerifyCreatePublicKeyError struct {
@@ -90,7 +158,7 @@ type VerifyCreatePublicKeyError struct {
}
func (e *VerifyCreatePublicKeyError) Error() string {
- return fmt.Sprintf("failed to create public key %s with error %v", e.PublicKey, e.Err)
+ return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err)
}
type VerifyInvalidSignatureError struct {
@@ -98,7 +166,7 @@ type VerifyInvalidSignatureError struct {
}
func (e *VerifyInvalidSignatureError) Error() string {
- return fmt.Sprintf("invalid signature with error %v", e.Err)
+ return fmt.Sprintf("invalid signature with error: %v", e.Err)
}
type VerifyInvalidTrustedCommentError struct {
@@ -107,7 +175,7 @@ type VerifyInvalidTrustedCommentError struct {
}
func (e *VerifyInvalidTrustedCommentError) Error() string {
- return fmt.Sprintf("invalid trusted comment %s with error %v", e.TrustedComment, e.Err)
+ return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err)
}
type VerifyWrongSigFilenameError struct {
@@ -116,7 +184,7 @@ type VerifyWrongSigFilenameError struct {
}
func (e *VerifyWrongSigFilenameError) Error() string {
- return fmt.Sprintf("wrong filename %s, expected filename %s for signature", e.Filename, e.SigFilename)
+ return fmt.Sprintf("wrong filename: %s, expected filename: %s for signature", e.Filename, e.SigFilename)
}
type VerifySigTimeEarlierError struct {
@@ -125,7 +193,7 @@ type VerifySigTimeEarlierError struct {
}
func (e *VerifySigTimeEarlierError) Error() string {
- return fmt.Sprintf("Sign time %d is earlier than sign time %d", e.SigTime, e.MinSigTime)
+ return fmt.Sprintf("Sign time: %d is earlier than sign time: %d", e.SigTime, e.MinSigTime)
}
type VerifyUnknownKeyError struct {
@@ -133,73 +201,5 @@ type VerifyUnknownKeyError struct {
}
func (e *VerifyUnknownKeyError) Error() string {
- return fmt.Sprintf("signature for filename %s was created with an unknown key", e.Filename)
-}
-
-// verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJson.
-//
-// Verification is performed using a matching key in allowedPublicKeys.
-// The signature is checked to be a Ed25519 Minisign (optionally Ed25519 Blake2b-512 prehashed, see forcePrehash) signature with a valid trusted comment.
-// The file type that is verified is indicated by expectedFileName, which must be one of "server_list.json"/"organization_list.json".
-// The trusted comment is checked to be of the form "timestamp:<timestamp>\tfile:<expectedFileName>", optionally suffixed by something, e.g. "\thashed".
-// The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds.
-//
-// The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure.
-func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) {
- switch filename {
- case "server_list.json", "organization_list.json":
- break
- default:
- return false, &VerifyUnknownExpectedFilenameError{Filename: filename, Expected: "server_list.json or organization_list.json"}
- }
-
- sig, err := minisign.DecodeSignature(signatureFileContent)
- if err != nil {
- return false, &VerifyInvalidSignatureFormatError{Err: err}
- }
-
- // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format
- if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} {
- return false, &VerifyInvalidSignatureAlgorithmError{Algorithm: string(sig.SignatureAlgorithm[:]), WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)"}
- }
-
- // Find allowed key used for signature
- for _, keyStr := range allowedPublicKeys {
- key, err := minisign.NewPublicKey(keyStr)
- if err != nil {
- // Should only happen if Verify is wrong or extraKey is invalid
- return false, &VerifyCreatePublicKeyError{PublicKey: keyStr, Err: err}
- }
-
- if sig.KeyId != key.KeyId {
- continue // Wrong key
- }
-
- valid, err := key.Verify(signedJson, sig)
- if !valid {
- return false, &VerifyInvalidSignatureError{Err: err}
- }
-
- // Parse trusted comment
- var signTime uint64
- var sigFileName string
- // sigFileName cannot have spaces
- _, err = fmt.Sscanf(sig.TrustedComment, "trusted comment: timestamp:%d\tfile:%s", &signTime, &sigFileName)
- if err != nil {
- return false, &VerifyInvalidTrustedCommentError{TrustedComment: sig.TrustedComment, Err: err}
- }
-
- if sigFileName != filename {
- return false, &VerifyWrongSigFilenameError{Filename: filename, SigFilename: sigFileName}
- }
-
- if signTime < minSignTime {
- return false, &VerifySigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime}
- }
-
- return true, nil
- }
-
- // No matching allowed key found
- return false, &VerifyUnknownKeyError{Filename: filename}
+ return fmt.Sprintf("signature for filename: %s was created with an unknown key", e.Filename)
}
diff --git a/internal/wireguard.go b/internal/wireguard.go
index 4ec12bd..7977dbc 100644
--- a/internal/wireguard.go
+++ b/internal/wireguard.go
@@ -8,8 +8,12 @@ import (
)
func wireguardGenerateKey() (wgtypes.Key, error) {
- key, error := wgtypes.GeneratePrivateKey()
- return key, error
+ key, keyErr := wgtypes.GeneratePrivateKey()
+
+ if keyErr != nil {
+ return key, &WireguardGenerateKeyError{Err: keyErr}
+ }
+ return key, nil
}
// FIXME: Instead of doing a regex replace, decide if we should use a parser
@@ -31,14 +35,14 @@ func (server *Server) WireguardGetConfig() (string, error) {
wireguardKey, wireguardErr := wireguardGenerateKey()
if wireguardErr != nil {
- return "", wireguardErr
+ return "", &WireguardGetConfigError{Err: wireguardErr}
}
wireguardPublicKey := wireguardKey.PublicKey().String()
configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey)
if configErr != nil {
- return "", configErr
+ return "", &WireguardGetConfigError{Err: wireguardErr}
}
// FIXME: Store expiry
@@ -50,3 +54,19 @@ func (server *Server) WireguardGetConfig() (string, error) {
return configWireguardKey, nil
}
+
+type WireguardGenerateKeyError struct {
+ Err error
+}
+
+func (e *WireguardGenerateKeyError) Error() string {
+ return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err)
+}
+
+type WireguardGetConfigError struct {
+ Err error
+}
+
+func (e *WireguardGetConfigError) Error() string {
+ return fmt.Sprintf("failed getting Wireguard config with error: %v", e.Err)
+}
diff --git a/state.go b/state.go
index c69cf37..767425c 100644
--- a/state.go
+++ b/state.go
@@ -1,7 +1,7 @@
package eduvpn
import (
- "errors"
+ "fmt"
"github.com/jwijenbergh/eduvpn-common/internal"
)
@@ -28,7 +28,7 @@ type VPNState struct {
func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
if !state.FSM.InState(internal.DEREGISTERED) {
- return errors.New("app already registered")
+ return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED}
}
// Initialize the logger
logLevel := internal.LOG_WARNING
@@ -39,7 +39,7 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun
loggerErr := state.Logger.Init(logLevel, name, directory)
if loggerErr != nil {
- return errors.New("Failed to create a logger")
+ return &StateRegisterError{Err: loggerErr}
}
// Initialize the FSM
@@ -75,13 +75,13 @@ func (state *VPNState) Deregister() error {
func (state *VPNState) CancelOAuth() error {
if !state.FSM.InState(internal.OAUTH_STARTED) {
- return errors.New("cannot cancel oauth, oauth not started")
+ return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.OAUTH_STARTED}
}
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- return serverErr
+ return &StateOAuthCancelError{Err: serverErr}
}
server.CancelOAuth()
return nil
@@ -89,13 +89,13 @@ func (state *VPNState) CancelOAuth() error {
func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
- return "", errors.New("app not registered")
+ return "", &StateFSMNotRegisteredError{}
}
// New server chosen, ensure the server is fresh
server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true)
if serverErr != nil {
- return "", serverErr
+ return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr}
}
// When we connect to secure internet, copy over the tokens from the home server
@@ -118,7 +118,7 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st
// We are possibly in oauth started
// So go to chosen server
state.FSM.GoTransition(internal.CHOSEN_SERVER)
- return "", loginErr
+ return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr}
}
} else { // OAuth was valid, ensure we are in the authorized state
state.FSM.GoTransition(internal.AUTHORIZED)
@@ -132,7 +132,7 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st
config, configErr := server.GetConfig()
if configErr != nil {
- return "", configErr
+ return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr}
} else {
state.FSM.GoTransition(internal.HAS_CONFIG)
}
@@ -150,27 +150,77 @@ func (state *VPNState) ConnectSecureInternet(url string) (string, error) {
func (state *VPNState) GetDiscoOrganizations() (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
- return "", errors.New("app not registered")
+ return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED}
}
return state.Discovery.GetOrganizationsList()
}
func (state *VPNState) GetDiscoServers() (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
- return "", errors.New("app not registered")
+ return "", &StateFSMNotRegisteredError{}
}
return state.Discovery.GetServersList()
}
func (state *VPNState) SetProfileID(profileID string) error {
if !state.FSM.InState(internal.ASK_PROFILE) {
- return errors.New("Invalid state for setting a profile")
+ return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.ASK_PROFILE}
}
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- return errors.New("No server found for setting a profile ID")
+ return &StateSetProfileError{ProfileID: profileID, Err: serverErr}
}
server.Profiles.Current = profileID
return nil
}
+
+type StateSetProfileError struct {
+ ProfileID string
+ Err error
+}
+
+func (e *StateSetProfileError) Error() string {
+ return fmt.Sprintf("failed to set profile ID %s with error %v", e.ProfileID, e.Err)
+}
+
+type StateRegisterError struct {
+ Err error
+}
+
+func (e *StateRegisterError) Error() string {
+ return fmt.Sprintf("failed to register with error %v", e.Err)
+}
+
+type StateFSMNotRegisteredError struct{}
+
+func (e *StateFSMNotRegisteredError) Error() string {
+ return fmt.Sprintf("state is not registered. Current FSM state: %s", internal.DEREGISTERED.String())
+}
+
+type StateWrongFSMStateError struct {
+ Got internal.FSMStateID
+ Want internal.FSMStateID
+}
+
+func (e *StateWrongFSMStateError) Error() string {
+ return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String())
+}
+
+type StateOAuthCancelError struct {
+ Err error
+}
+
+func (e *StateOAuthCancelError) Error() string {
+ return fmt.Sprintf("failed cancelling OAuth for state with error: %v", e.Err)
+}
+
+type StateConnectError struct {
+ URL string
+ IsSecureInternet bool
+ Err error
+}
+
+func (e *StateConnectError) Error() string {
+ return fmt.Sprintf("failed connecting to server: %s (is secure internet: %v) with error: %v", e.URL, e.IsSecureInternet, e.Err)
+}
diff --git a/state_test.go b/state_test.go
index 4320a6d..c6e33e0 100644
--- a/state_test.go
+++ b/state_test.go
@@ -84,15 +84,43 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter
}, false)
_, configErr := state.ConnectInstituteAccess(serverURI)
- if !errors.As(configErr, expectedErr) {
- t.Errorf("error %T = %v, wantErr %T", configErr, configErr, expectedErr)
+ var stateErr *StateConnectError
+ var loginErr *internal.OAuthLoginError
+ var finishErr *internal.OAuthFinishError
+
+ // We go through the chain of errors by unwrapping them one by one
+
+ // First ensure we get a state connect error
+ if !errors.As(configErr, &stateErr) {
+ t.Errorf("error %T = %v, wantErr %T", configErr, configErr, stateErr)
+ }
+
+ // Then ensure we get a login error
+ gotLoginErr := stateErr.Err
+
+ if !errors.As(gotLoginErr, &loginErr) {
+ t.Errorf("error %T = %v, wantErr %T", gotLoginErr, gotLoginErr, loginErr)
+ }
+
+ // Then ensure we get a finish error
+ gotFinishErr := loginErr.Err
+
+ if !errors.As(gotFinishErr, &finishErr) {
+ t.Errorf("error %T = %v, wantErr %T", gotFinishErr, gotFinishErr, finishErr)
+ }
+
+ // Then ensure we get the expected inner error
+ gotExpectedErr := finishErr.Err
+
+ if !errors.As(gotExpectedErr, expectedErr) {
+ t.Errorf("error %T = %v, wantErr %T", gotExpectedErr, gotExpectedErr, expectedErr)
}
}
func Test_connect_oauth_parameters(t *testing.T) {
var (
- failedCallbackParameterError *internal.OAuthFailedCallbackParameterError
- failedCallbackStateMatchError *internal.OAuthFailedCallbackStateMatchError
+ failedCallbackParameterError *internal.OAuthCallbackParameterError
+ failedCallbackStateMatchError *internal.OAuthCallbackStateMatchError
)
tests := []struct {