summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/cli/main.go4
-rw-r--r--exports/exports.go1
-rw-r--r--internal/config/config.go25
-rw-r--r--internal/discovery/discovery.go61
-rw-r--r--internal/fsm/fsm.go21
-rw-r--r--internal/http/http.go55
-rw-r--r--internal/log/log.go18
-rw-r--r--internal/oauth/oauth.go144
-rw-r--r--internal/server/api.go78
-rw-r--r--internal/server/server.go152
-rw-r--r--internal/types/error.go62
-rw-r--r--internal/util/util.go7
-rw-r--r--internal/verify/verify.go21
-rw-r--r--internal/wireguard/wireguard.go12
-rw-r--r--state.go91
-rw-r--r--state_test.go31
16 files changed, 283 insertions, 500 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go
index a9bff2e..e6be0bf 100644
--- a/cmd/cli/main.go
+++ b/cmd/cli/main.go
@@ -205,7 +205,9 @@ func printConfig(url string, isInstitute bool) {
config, _, configErr := getConfig(state, url, isInstitute)
if configErr != nil {
- fmt.Println("Error getting config", configErr)
+ // Show the usage of tracebacks and causes
+ fmt.Println("Error getting config:", state.GetErrorTraceback(configErr))
+ fmt.Println("Error getting config, cause:", state.GetErrorCause(configErr))
return
}
diff --git a/exports/exports.go b/exports/exports.go
index 2139254..79b3606 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -17,6 +17,7 @@ import (
"errors"
"fmt"
"unsafe"
+
"github.com/jwijenbergh/eduvpn-common"
)
diff --git a/internal/config/config.go b/internal/config/config.go
index a9ebec7..a74bcd8 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -5,6 +5,8 @@ import (
"fmt"
"io/ioutil"
"path"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
)
@@ -24,13 +26,14 @@ func (config *Config) GetFilename() string {
}
func (config *Config) Save(readStruct interface{}) error {
+ errorMessage := "failed saving configuration"
configDirErr := util.EnsureDirectory(config.Directory)
if configDirErr != nil {
- return &ConfigSaveError{Err: configDirErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr}
}
jsonString, marshalErr := json.Marshal(readStruct)
if marshalErr != nil {
- return &ConfigSaveError{Err: marshalErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: marshalErr}
}
return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600)
}
@@ -38,23 +41,7 @@ func (config *Config) Save(readStruct interface{}) error {
func (config *Config) Load(writeStruct interface{}) error {
bytes, readErr := ioutil.ReadFile(config.GetFilename())
if readErr != nil {
- return &ConfigLoadError{Err: readErr}
+ return &types.WrappedErrorMessage{Message: "failed loading configuration", Err: readErr}
}
return json.Unmarshal(bytes, writeStruct)
}
-
-type ConfigSaveError struct {
- Err error
-}
-
-func (e *ConfigSaveError) Error() string {
- return fmt.Sprintf("failed to save config with error: %v", e.Err)
-}
-
-type ConfigLoadError struct {
- Err error
-}
-
-func (e *ConfigLoadError) Error() string {
- return fmt.Sprintf("failed to load config with error: %v", e.Err)
-}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index d72b4a6..ac3bf57 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -3,50 +3,15 @@ package discovery
import (
"encoding/json"
"fmt"
+
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/log"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
"github.com/jwijenbergh/eduvpn-common/internal/verify"
)
-type DiscoFileError struct {
- URL string
- Err error
-}
-
-func (e *DiscoFileError) Error() string {
- return fmt.Sprintf("failed obtaining disco file %s with error %v", e.URL, e.Err)
-}
-
-type DiscoSigFileError struct {
- URL string
- Err error
-}
-
-func (e *DiscoSigFileError) Error() string {
- return fmt.Sprintf("failed obtaining disco signature file %s with error %v", e.URL, e.Err)
-}
-
-type DiscoVerifyError struct {
- File string
- Sigfile string
- Err error
-}
-
-func (e *DiscoVerifyError) Error() string {
- return fmt.Sprintf("failed verifying file %s with signature %s due to error %v", e.File, e.Sigfile, e.Err)
-}
-
-type DiscoJSONError struct {
- Body string
- Err error
-}
-
-func (e *DiscoJSONError) Error() string {
- return fmt.Sprintf("failed parsing JSON for contents %s with error %v", e.Body, e.Err)
-}
-
type OrganizationList struct {
JSON json.RawMessage `json:"organization_list"`
Version uint64 `json:"v"`
@@ -68,13 +33,14 @@ type Discovery struct {
// Helper function that gets a disco json
func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
+ errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
discoURL := "https://disco.eduvpn.org/v2/"
fileURL := discoURL + jsonFile
_, fileBody, fileErr := http.HTTPGet(fileURL)
if fileErr != nil {
- return &DiscoFileError{fileURL, fileErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
}
// Get signature
@@ -83,7 +49,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, sigBody, sigFileErr := http.HTTPGet(sigURL)
if sigFileErr != nil {
- return &DiscoSigFileError{URL: sigURL, Err: sigFileErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
}
// Verify signature
@@ -92,28 +58,19 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash)
if !verifySuccess || verifyErr != nil {
- return &DiscoVerifyError{File: jsonFile, Sigfile: sigFile, Err: verifyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
}
// Parse JSON to extract version and list
jsonErr := json.Unmarshal(fileBody, structure)
if jsonErr != nil {
- return &DiscoJSONError{Body: string(fileBody), Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
return nil
}
-type GetListError struct {
- File string
- Err error
-}
-
-func (e *GetListError) Error() string {
- return fmt.Sprintf("failed getting disco list file %s with error %v", e.File, e.Err)
-}
-
func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
discovery.FSM = fsm
discovery.Logger = logger
@@ -155,7 +112,7 @@ func (discovery *Discovery) GetOrganizationsList() (string, error) {
err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
if err != nil {
// Return previous with an error
- return string(discovery.Organizations.JSON), &GetListError{File: file, Err: err}
+ return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err}
}
return string(discovery.Organizations.JSON), nil
}
@@ -169,7 +126,7 @@ func (discovery *Discovery) GetServersList() (string, error) {
err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
if err != nil {
// Return previous with an error
- return string(discovery.Servers.JSON), &GetListError{File: file, Err: err}
+ return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err}
}
// Update servers timestamp
discovery.Servers.Timestamp = util.GenerateTimeSeconds()
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index bb7f330..3d0bfd6 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -1,12 +1,15 @@
package fsm
import (
+ "errors"
"fmt"
"os"
"os/exec"
"path"
"sort"
+
"github.com/jwijenbergh/eduvpn-common/internal/log"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
type (
@@ -209,20 +212,26 @@ func (fsm *FSM) GenerateGraph() string {
return fsm.generateMermaidGraph()
}
-type FSMWrongStateTransitionError struct {
+type DeregisteredError struct{}
+
+func (e DeregisteredError) CustomError() *types.WrappedErrorMessage {
+ return &types.WrappedErrorMessage{Message: "Client not registered with the GO library", Err: errors.New("the current FSM state is deregistered, but the function needs a state that is not deregistered")}
+}
+
+type WrongStateTransitionError struct {
Got FSMStateID
Want FSMStateID
}
-func (e *FSMWrongStateTransitionError) Error() string {
- return fmt.Sprintf("wrong FSM state, got: %s, want a state with a transition to: %s", e.Got.String(), e.Want.String())
+func (e WrongStateTransitionError) CustomError() *types.WrappedErrorMessage {
+ return &types.WrappedErrorMessage{Message: "Wrong FSM transition", Err: errors.New(fmt.Sprintf("wrong FSM state, got: %s, want: a state with a transition to: %s", e.Got.String(), e.Want.String()))}
}
-type FSMWrongStateError struct {
+type WrongStateError struct {
Got FSMStateID
Want FSMStateID
}
-func (e *FSMWrongStateError) Error() string {
- return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String())
+func (e WrongStateError) CustomError() *types.WrappedErrorMessage {
+ return &types.WrappedErrorMessage{Message: "Wrong FSM State", Err: errors.New(fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()))}
}
diff --git a/internal/http/http.go b/internal/http/http.go
index 87346f1..3c8e4e1 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -7,6 +7,8 @@ import (
"net/http"
"net/url"
"strings"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
type URLParameters map[string]string
@@ -21,7 +23,7 @@ type HTTPOptionalParams struct {
func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) {
url, parseErr := url.Parse(baseURL)
if parseErr != nil {
- return "", &HTTPConstructURLError{URL: baseURL, Parameters: parameters, Err: parseErr}
+ return "", &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to construct url: %s including parameters: %v", url, parameters), Err: parseErr}
}
q := url.Query()
@@ -55,7 +57,7 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
url, urlErr := HTTPConstructURL(url, opts.URLParameters)
if urlErr != nil {
- return url, &HTTPRequestCreateError{URL: url, Err: urlErr}
+ return url, &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to create HTTP request with url: %s", url), Err: urlErr}
}
return url, nil
}
@@ -91,10 +93,12 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
// Create a client
client := &http.Client{}
+ errorMessage := fmt.Sprintf("failed HTTP request with method %s and url %s", method, url)
+
// Create request object with the body reader generated from the optional arguments
req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts))
if reqErr != nil {
- return nil, nil, &HTTPRequestCreateError{URL: url, Err: reqErr}
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: reqErr}
}
// See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi
@@ -106,7 +110,7 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
// Do request
resp, respErr := client.Do(req)
if respErr != nil {
- return nil, nil, &HTTPResourceError{URL: url, Err: respErr}
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: respErr}
}
// Request successful, make sure body is closed at the end
@@ -115,26 +119,19 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht
// Return a string
body, readErr := ioutil.ReadAll(resp.Body)
if readErr != nil {
- return resp.Header, nil, &HTTPReadError{URL: url, Err: readErr}
+ return resp.Header, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: readErr}
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
- return resp.Header, body, &HTTPStatusError{URL: url, Status: resp.StatusCode}
+ // We make this a custom error because we want to extract the status code later
+ statusErr := &HTTPStatusError{URL: url, Status: resp.StatusCode}
+ return resp.Header, body, &types.WrappedErrorMessage{Message: errorMessage, Err: statusErr}
}
// Return the body in bytes and signal the status error if there was one
return resp.Header, body, nil
}
-type HTTPResourceError struct {
- URL string
- Err error
-}
-
-func (e *HTTPResourceError) Error() string {
- return fmt.Sprintf("failed obtaining HTTP resource: %s with error: %v", e.URL, e.Err)
-}
-
type HTTPStatusError struct {
URL string
Status int
@@ -144,15 +141,6 @@ func (e *HTTPStatusError) Error() string {
return fmt.Sprintf("failed obtaining HTTP resource: %s as it gave an unsuccesful status code: %d", e.URL, e.Status)
}
-type HTTPReadError struct {
- URL string
- Err error
-}
-
-func (e *HTTPReadError) Error() string {
- return fmt.Sprintf("failed reading HTTP resource: %s with error: %v", e.URL, e.Err)
-}
-
type HTTPParseJsonError struct {
URL string
Body string
@@ -162,22 +150,3 @@ type HTTPParseJsonError struct {
func (e *HTTPParseJsonError) Error() string {
return fmt.Sprintf("failed parsing json %s for HTTP resource: %s with error: %v", e.Body, e.URL, e.Err)
}
-
-type HTTPRequestCreateError struct {
- URL string
- Err error
-}
-
-func (e *HTTPRequestCreateError) Error() string {
- return fmt.Sprintf("failed to create HTTP request with url: %s and error: %v", e.URL, e.Err)
-}
-
-type HTTPConstructURLError struct {
- URL string
- Parameters URLParameters
- Err error
-}
-
-func (e *HTTPConstructURLError) Error() string {
- return fmt.Sprintf("failed to construct url: %s including parameters: %v with error: %v", e.URL, e.Parameters, e.Err)
-}
diff --git a/internal/log/log.go b/internal/log/log.go
index cba3364..f4024e2 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -5,6 +5,8 @@ import (
"log"
"os"
"path"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
)
@@ -38,13 +40,15 @@ func (e LogLevel) String() string {
}
func (logger *FileLogger) Init(level LogLevel, name string, directory string) error {
+ errorMessage := "failed creating log"
+
configDirErr := util.EnsureDirectory(directory)
if configDirErr != nil {
- return &LogInitializeError{Name: name, Directory: directory, Err: configDirErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr}
}
logFile, logOpenErr := os.OpenFile(logger.getFilename(directory, name), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666)
if logOpenErr != nil {
- return &LogInitializeError{Name: name, Directory: directory, Err: logOpenErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: logOpenErr}
}
log.SetOutput(logFile)
logger.File = logFile
@@ -66,13 +70,3 @@ func (logger *FileLogger) Log(level LogLevel, str string) {
func (logger *FileLogger) Close() {
logger.File.Close()
}
-
-type LogInitializeError struct {
- Name string
- Directory string
- Err error
-}
-
-func (e *LogInitializeError) Error() string {
- return fmt.Sprintf("failed initializing logging with name: %s and directory: %s with error: %v", e.Name, e.Directory, e.Err)
-}
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index f6ed916..824db90 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -8,10 +8,12 @@ import (
"fmt"
"net/http"
"net/url"
+
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
- "github.com/jwijenbergh/eduvpn-common/internal/util"
"github.com/jwijenbergh/eduvpn-common/internal/log"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
+ "github.com/jwijenbergh/eduvpn-common/internal/util"
)
// Generates a random base64 string to be used for state
@@ -23,7 +25,7 @@ import (
func genState() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenStateError{Err: err}
+ return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err}
}
// For consistency we also use raw url encoding here
@@ -49,7 +51,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &OAuthGenVerifierError{Err: err}
+ return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth verifier", Err: err}
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -60,8 +62,8 @@ type OAuth struct {
Token OAuthToken `json:"token"`
BaseAuthorizationURL string `json:"base_authorization_url"`
TokenURL string `json:"token_url"`
- Logger *log.FileLogger `json:"-"`
- FSM *fsm.FSM `json:"-"`
+ Logger *log.FileLogger `json:"-"`
+ FSM *fsm.FSM `json:"-"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -99,7 +101,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
}
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed {
- return &OAuthCallbackError{Addr: addr, Err: err}
+ return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err}
}
return oauth.Session.CallbackError
}
@@ -108,9 +110,9 @@ func (oauth *OAuth) getTokensWithCallback() error {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
+ errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
-
reqURL := oauth.TokenURL
data := url.Values{
"client_id": {oauth.Session.ClientID},
@@ -126,7 +128,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := util.GenerateTimeSeconds()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &OAuthAuthError{Err: bodyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
tokenStructure := OAuthToken{}
@@ -134,7 +136,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}}
}
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
@@ -152,6 +154,7 @@ func (oauth *OAuth) isTokensExpired() bool {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
func (oauth *OAuth) getTokensWithRefresh() error {
+ errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
data := url.Values{
"refresh_token": {oauth.Token.Refresh},
@@ -164,14 +167,14 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := util.GenerateTimeSeconds()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &OAuthRefreshError{Err: bodyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
tokenStructure := OAuthToken{}
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}}
}
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
@@ -182,11 +185,15 @@ func (oauth *OAuth) getTokensWithRefresh() error {
//
//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
+ errorMessage := "failed callback to retrieve the authorization code"
// Extract the authorization code
code, success := req.URL.Query()["code"]
- if !success {
- oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}
+ // Shutdown after we're done
+ defer func() {
go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ }()
+ if !success {
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}}
return
}
// The code is the first entry
@@ -195,30 +202,25 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Make sure the state is present and matches to protect against cross-site request forgeries
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state, success := req.URL.Query()["state"]
+
if !success {
- oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}}
return
}
// The state is the first entry
extractedState := state[0]
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}}
return
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- err := oauth.getTokensWithAuthCode(extractedCode)
- if err != nil {
- oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err}
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
+ getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
+ if getTokensErr != nil {
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: getTokensErr}
return
}
-
- // Shutdown the server as we're done listening
- go oauth.Session.Server.Shutdown(oauth.Session.Context)
}
func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) {
@@ -235,19 +237,20 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.
// Starts the OAuth exchange for eduvpn.
func (oauth *OAuth) start(name string) error {
+ errorMessage := "failed starting OAuth exchange"
if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
- return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
}
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return &OAuthInitializeError{Err: stateErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
}
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return &OAuthInitializeError{Err: verifierErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
}
challenge := genChallengeS256(verifier)
@@ -264,7 +267,7 @@ func (oauth *OAuth) start(name string) error {
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
- return &OAuthInitializeError{Err: urlErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -277,34 +280,36 @@ func (oauth *OAuth) start(name string) error {
// Error definitions
func (oauth *OAuth) Finish() error {
+ errorMessage := "failed finishing OAuth"
if !oauth.FSM.HasTransition(fsm.AUTHORIZED) {
- return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}.CustomError()}
}
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &OAuthFinishError{Err: tokenErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr}
}
oauth.FSM.GoTransition(fsm.AUTHORIZED)
return nil
}
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = &OAuthCancelledCallbackError{}
+ oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: "failed cancelling OAuth", Err: &OAuthCancelledCallbackError{}}
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
func (oauth *OAuth) Login(name string) error {
+ errorMessage := "failed OAuth login"
authInitializeErr := oauth.start(name)
if authInitializeErr != nil {
- return &OAuthLoginError{Err: authInitializeErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
}
oauthErr := oauth.Finish()
if oauthErr != nil {
- return &OAuthLoginError{Err: oauthErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
}
return nil
}
@@ -342,31 +347,6 @@ func (e *OAuthCancelledCallbackError) Error() string {
return fmt.Sprintf("client cancelled OAuth")
}
-type OAuthGenStateError struct {
- Err error
-}
-
-func (e *OAuthGenStateError) Error() string {
- return fmt.Sprintf("failed generating state with error: %v", e.Err)
-}
-
-type OAuthGenVerifierError struct {
- Err error
-}
-
-func (e *OAuthGenVerifierError) Error() string {
- return fmt.Sprintf("failed generating verifier with error: %v", e.Err)
-}
-
-type OAuthCallbackError struct {
- Addr string
- Err error
-}
-
-func (e *OAuthCallbackError) Error() string {
- return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err)
-}
-
type OAuthCallbackParameterError struct {
Parameter string
URL string
@@ -384,51 +364,3 @@ type OAuthCallbackStateMatchError struct {
func (e *OAuthCallbackStateMatchError) Error() string {
return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
}
-
-type OAuthCallbackGetTokensError struct {
- Err error
-}
-
-func (e *OAuthCallbackGetTokensError) Error() string {
- return fmt.Sprintf("failed getting tokens with error: %v", e.Err)
-}
-
-type OAuthFinishError struct {
- Err error
-}
-
-func (e *OAuthFinishError) Error() string {
- return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err)
-}
-
-type OAuthLoginError struct {
- Err error
-}
-
-func (e *OAuthLoginError) Error() string {
- return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err)
-}
-
-type OAuthInitializeError struct {
- Err error
-}
-
-func (e *OAuthInitializeError) Error() string {
- return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err)
-}
-
-type OAuthAuthError struct {
- Err error
-}
-
-func (e *OAuthAuthError) Error() string {
- return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err)
-}
-
-type OAuthRefreshError struct {
- Err error
-}
-
-func (e *OAuthRefreshError) Error() string {
- return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err)
-}
diff --git a/internal/server/api.go b/internal/server/api.go
index 96bd641..c8c7180 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -6,32 +6,34 @@ import (
"fmt"
"net/http"
"net/url"
+
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/log"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
)
func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) {
+ errorMessage := "failed getting server endpoints"
url := fmt.Sprintf("%s/%s", baseURL, WellKnownPath)
_, body, bodyErr := httpw.HTTPGet(url)
if bodyErr != nil {
- return nil, &APIGetEndpointsError{Err: bodyErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
endpoints := &ServerEndpoints{}
jsonErr := json.Unmarshal(body, endpoints)
if jsonErr != nil {
- return nil, &APIGetEndpointsError{Err: jsonErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
return endpoints, nil
}
-// Authorized wrappers on top of HTTP
-// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth
func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) {
+ errorMessage := "failed API authorized"
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
opts = &httpw.HTTPOptionalParams{}
@@ -39,7 +41,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT
base, baseErr := server.GetBase()
if baseErr != nil {
- return nil, nil, baseErr
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
url := base.Endpoints.API.V3.API + endpoint
@@ -52,7 +54,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT
base.FSM.Current = stateBefore
if oauthErr != nil {
- return nil, nil, oauthErr
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
}
headerKey := "Authorization"
@@ -66,11 +68,12 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT
}
func apiAuthorizedRetry(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) {
+ errorMessage := "failed authorized API retry"
header, body, bodyErr := apiAuthorized(server, method, endpoint, opts)
base, baseErr := server.GetBase()
if baseErr != nil {
- return nil, nil, &APIAuthorizedError{Err: baseErr}
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
if bodyErr != nil {
var error *httpw.HTTPStatusError
@@ -82,31 +85,32 @@ func apiAuthorizedRetry(server Server, method string, endpoint string, opts *htt
server.GetOAuth().Token.ExpiredTimestamp = util.GenerateTimeSeconds()
retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
if retryErr != nil {
- return nil, nil, &APIAuthorizedError{Err: retryErr}
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr}
}
return retryHeader, retryBody, nil
}
- return nil, nil, &APIAuthorizedError{Err: bodyErr}
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
return header, body, nil
}
func APIInfo(server Server) error {
+ errorMessage := "failed API /info"
_, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil)
if bodyErr != nil {
- return &APIInfoError{Err: bodyErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
}
structure := ServerProfileInfo{}
jsonErr := json.Unmarshal(body, &structure)
if jsonErr != nil {
- return &APIInfoError{Err: jsonErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
base, baseErr := server.GetBase()
if baseErr != nil {
- return &APIInfoError{Err: baseErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
// Store the profiles and make sure that the current profile is not overwritten
@@ -118,6 +122,7 @@ func APIInfo(server Server) error {
}
func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, int64, error) {
+ errorMessage := "failed obtaining a WireGuard configuration"
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
@@ -133,14 +138,14 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor
}
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", "", 0, &APIConnectWireguardError{Err: connectErr}
+ return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", "", 0, &APIConnectWireguardError{Err: pTimeErr}
+ return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
}
contentType := header.Get("content-type")
@@ -153,6 +158,7 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor
}
func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) {
+ errorMessage := "failed obtaining an OpenVPN configuration"
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
@@ -164,13 +170,13 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error)
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", 0, &APIConnectOpenVPNError{Err: connectErr}
+ return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", 0, &APIConnectOpenVPNError{Err: pTimeErr}
+ return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
}
return string(connectBody), pTime.Unix(), nil
}
@@ -179,43 +185,3 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error)
func APIDisconnect(server Server) {
apiAuthorizedRetry(server, http.MethodPost, "/disconnect", nil)
}
-
-type APIAuthorizedError struct {
- Err error
-}
-
-func (e *APIAuthorizedError) Error() string {
- return fmt.Sprintf("failed api authorized call with error: %v", e.Err)
-}
-
-type APIConnectWireguardError struct {
- Err error
-}
-
-func (e *APIConnectWireguardError) Error() string {
- return fmt.Sprintf("failed api /connect wireguard call with error: %v", e.Err)
-}
-
-type APIConnectOpenVPNError struct {
- Err error
-}
-
-func (e *APIConnectOpenVPNError) Error() string {
- return fmt.Sprintf("failed api /connect OpenVPN call with error: %v", e.Err)
-}
-
-type APIInfoError struct {
- Err error
-}
-
-func (e *APIInfoError) Error() string {
- return fmt.Sprintf("failed api /info call with error: %v", e.Err)
-}
-
-type APIGetEndpointsError struct {
- Err error
-}
-
-func (e *APIGetEndpointsError) Error() string {
- return fmt.Sprintf("failed to get server endpoint with error %v", e.Err)
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index a1fb749..ce72400 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,9 +2,11 @@ package server
import (
"fmt"
+
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/log"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"github.com/jwijenbergh/eduvpn-common/internal/util"
"github.com/jwijenbergh/eduvpn-common/internal/wireguard"
)
@@ -17,8 +19,8 @@ type ServerBase struct {
ProfilesRaw string `json:"profiles_raw"`
StartTime int64 `json:"start-time"`
EndTime int64 `json:"end-time"`
- Logger *log.FileLogger `json:"-"`
- FSM *fsm.FSM `json:"-"`
+ Logger *log.FileLogger `json:"-"`
+ FSM *fsm.FSM `json:"-"`
}
// An instute access server
@@ -49,18 +51,19 @@ type InstituteServers struct {
}
func (servers *Servers) GetCurrentServer() (Server, error) {
+ errorMessage := "failed getting current server"
if servers.IsSecureInternet {
return &servers.SecureInternetHomeServer, nil
}
currentInstitute := servers.InstituteServers.CurrentURL
institutes := servers.InstituteServers.Map
if institutes == nil {
- return nil, &ServerGetCurrentNoMapError{}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNoMapError{}}
}
institute, exists := institutes[currentInstitute]
if !exists || institute == nil {
- return nil, &ServerGetCurrentNotFoundError{}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNotFoundError{}}
}
return institute, nil
}
@@ -96,25 +99,27 @@ func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) {
}
func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
+ errorMessage := "failed getting current secure internet home base"
if server.BaseMap == nil {
- return nil, &ServerSecureInternetMapNotFoundError{}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}}
}
base, exists := server.BaseMap[server.CurrentURL]
if !exists {
- return nil, &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}}
}
return base, nil
}
func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := fmt.Sprintf("failed initializing institute server %s", url)
institute.Base.URL = url
institute.Base.FSM = fsm
institute.Base.Logger = logger
endpoints, endpointsErr := APIGetEndpoints(url)
if endpointsErr != nil {
- return &ServerInitializeError{URL: url, Err: endpointsErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
}
institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm, logger)
institute.Base.Endpoints = *endpoints
@@ -122,6 +127,7 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l
}
func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url)
// Initialize the base map if it is non-nil
if secure.BaseMap == nil {
secure.BaseMap = make(map[string]*ServerBase)
@@ -136,7 +142,7 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l
base.URL = url
endpoints, endpointsErr := APIGetEndpoints(url)
if endpointsErr != nil {
- return &ServerInitializeError{URL: url, Err: endpointsErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
}
base.Endpoints = *endpoints
}
@@ -166,7 +172,7 @@ func ShouldRenewButton(server Server) (bool, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
- //return false, &GetRenewButtonTimeError{Err: baseErr}
+ // return false, &GetRenewButtonTimeError{Err: baseErr}
return false, nil
}
@@ -186,7 +192,7 @@ func ShouldRenewButton(server Server) (bool, error) {
// Session duration is less than 24 hours but not 75% has passed
duration := base.EndTime - base.StartTime
// TODO: Is converting to float64 okay here?
- if duration < 24*60*60 && float64(current) <= (float64(base.StartTime) + 0.75*float64(duration)) {
+ if duration < 24*60*60 && float64(current) <= (float64(base.StartTime)+0.75*float64(duration)) {
return false, nil
}
@@ -198,17 +204,18 @@ func Login(server Server) error {
}
func EnsureTokens(server Server) error {
+ errorMessage := "failed ensuring server tokens"
base, baseErr := server.GetBase()
if baseErr != nil {
- return &ServerEnsureTokensError{Err: baseErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
if server.GetOAuth().NeedsRelogin() {
base.Logger.Log(log.LOG_INFO, "OAuth: Tokens are invalid, relogging in")
loginErr := Login(server)
if loginErr != nil {
- return &ServerEnsureTokensError{Err: loginErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
}
}
return nil
@@ -223,13 +230,14 @@ func CancelOAuth(server Server) {
}
func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
+ errorMessage := "failed ensuring server"
// Intialize the secure internet server
// This calls the init method which takes care of the rest
if isSecureInternet {
initErr := servers.SecureInternetHomeServer.init(url, fsm, logger)
if initErr != nil {
- return nil, &ServerEnsureServerError{Err: initErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
}
servers.IsSecureInternet = true
@@ -253,7 +261,7 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm
instituteServers.CurrentURL = url
instituteInitErr := institute.init(url, fsm, logger)
if instituteInitErr != nil {
- return nil, &ServerEnsureServerError{Err: instituteInitErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr}
}
instituteServers.Map[url] = institute
servers.IsSecureInternet = false
@@ -310,10 +318,11 @@ func (profile *ServerProfile) supportsOpenVPN() bool {
}
func getCurrentProfile(server Server) (*ServerProfile, error) {
+ errorMessage := "failed getting current profile"
base, baseErr := server.GetBase()
if baseErr != nil {
- return nil, &ServerGetCurrentProfileError{Err: baseErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
profileID := base.Profiles.Current
for _, profile := range base.Profiles.Info.ProfileList {
@@ -321,28 +330,30 @@ func getCurrentProfile(server Server) (*ServerProfile, error) {
return &profile, nil
}
}
- return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}
+
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}}
}
func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) {
+ errorMessage := "failed getting server WireGuard configuration"
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", baseErr
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
profile_id := base.Profiles.Current
wireguardKey, wireguardErr := wireguard.GenerateKey()
if wireguardErr != nil {
- return "", "", wireguardErr
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: wireguardErr}
}
wireguardPublicKey := wireguardKey.PublicKey().String()
config, content, expires, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN)
if configErr != nil {
- return "", "", wireguardErr
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
// Store start and end time
@@ -361,10 +372,11 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er
}
func openVPNGetConfig(server Server) (string, string, error) {
+ errorMessage := "failed getting server OpenVPN configuration"
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", baseErr
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
profile_id := base.Profiles.Current
configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id)
@@ -374,25 +386,26 @@ func openVPNGetConfig(server Server) (string, string, error) {
base.EndTime = expires
if configErr != nil {
- return "", "", configErr
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
return configOpenVPN, "openvpn", nil
}
func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) {
+ errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile"
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", &ServerGetConfigWithProfileError{Err: baseErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
if !base.FSM.HasTransition(fsm.HAS_CONFIG) {
- return "", "", &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()}
}
profile, profileErr := getCurrentProfile(server)
if profileErr != nil {
- return "", "", &ServerGetConfigWithProfileError{Err: profileErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr}
}
supportsOpenVPN := profile.supportsOpenVPN()
@@ -400,7 +413,7 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error)
// If forceTCP we must be able to get a config with OpenVPN
if forceTCP && supportsOpenVPN {
- return "", "", &ServerGetConfigForceTCPError{}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetConfigForceTCPError{}}
}
var config string
@@ -416,40 +429,42 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error)
}
if configErr != nil {
- return "", "", &ServerGetConfigWithProfileError{Err: configErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
return config, configType, nil
}
func askForProfileID(server Server) error {
+ errorMessage := "failed asking for a server profile ID"
base, baseErr := server.GetBase()
if baseErr != nil {
- return &ServerAskForProfileIDError{Err: baseErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
if !base.FSM.HasTransition(fsm.ASK_PROFILE) {
- return &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE}.CustomError()}
}
base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, base.ProfilesRaw, false)
return nil
}
func GetConfig(server Server, forceTCP bool) (string, string, error) {
+ errorMessage := "failed getting an OpenVPN/WireGuard configuration"
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", &ServerGetConfigError{Err: baseErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
if !base.FSM.InState(fsm.REQUEST_CONFIG) {
- return "", "", &fsm.FSMWrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG}.CustomError()}
}
// Get new profiles using the info call
// This does not override the current profile
infoErr := APIInfo(server)
if infoErr != nil {
- return "", "", &ServerGetConfigError{Err: infoErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr}
}
// If there was a profile chosen and it doesn't exist anymore, reset it
@@ -473,7 +488,7 @@ func GetConfig(server Server, forceTCP bool) (string, string, error) {
profileErr := askForProfileID(server)
if profileErr != nil {
- return "", "", &ServerGetConfigError{Err: profileErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr}
}
return getConfigWithProfile(server, forceTCP)
@@ -487,14 +502,6 @@ func (e *ServerGetCurrentProfileNotFoundError) Error() string {
return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID)
}
-type ServerGetConfigWithProfileError struct {
- Err error
-}
-
-func (e *ServerGetConfigWithProfileError) Error() string {
- return fmt.Sprintf("failed to get config including profile with error %v", e.Err)
-}
-
type ServerGetConfigForceTCPError struct{}
func (e *ServerGetConfigForceTCPError) Error() string {
@@ -507,28 +514,12 @@ func (e *ServerGetSecureInternetHomeError) Error() string {
return "failed to get secure internet home server, not found"
}
-type ServerCopySecureInternetOAuthError struct {
- Err error
-}
-
-func (e *ServerCopySecureInternetOAuthError) Error() string {
- return fmt.Sprintf("failed to copy oauth tokens from home server with error %v", e.Err)
-}
-
type ServerEnsureServerEmptyURLError struct{}
func (e *ServerEnsureServerEmptyURLError) Error() string {
return "failed ensuring server, empty url provided"
}
-type ServerEnsureServerError struct {
- Err error
-}
-
-func (e *ServerEnsureServerError) Error() string {
- return fmt.Sprintf("failed ensuring server with error %v", e.Err)
-}
-
type ServerGetCurrentNoMapError struct{}
func (e *ServerGetCurrentNoMapError) Error() string {
@@ -541,31 +532,6 @@ func (e *ServerGetCurrentNotFoundError) Error() string {
return "failed getting current server, not found"
}
-type ServerGetConfigError struct {
- Err error
-}
-
-func (e *ServerGetConfigError) Error() string {
- return fmt.Sprintf("failed getting server config with error %v", e.Err)
-}
-
-type ServerInitializeError struct {
- URL string
- Err error
-}
-
-func (e *ServerInitializeError) Error() string {
- return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err)
-}
-
-type ServerInstituteBaseNotFoundError struct {
- Err error
-}
-
-func (e *ServerInstituteBaseNotFoundError) Error() string {
- return "institute base not found"
-}
-
type ServerSecureInternetMapNotFoundError struct{}
func (e *ServerSecureInternetMapNotFoundError) Error() string {
@@ -579,27 +545,3 @@ type ServerSecureInternetBaseNotFoundError struct {
func (e *ServerSecureInternetBaseNotFoundError) Error() string {
return fmt.Sprintf("secure internet base not found with current: %s", e.Current)
}
-
-type ServerGetCurrentProfileError struct {
- Err error
-}
-
-func (e *ServerGetCurrentProfileError) Error() string {
- return fmt.Sprintf("failed getting current profile with error: %v", e.Err)
-}
-
-type ServerAskForProfileIDError struct {
- Err error
-}
-
-func (e *ServerAskForProfileIDError) Error() string {
- return fmt.Sprintf("ask for profile ID error: %v", e.Err)
-}
-
-type ServerEnsureTokensError struct {
- Err error
-}
-
-func (e *ServerEnsureTokensError) Error() string {
- return fmt.Sprintf("failed ensuring tokens with error: %v", e.Err)
-}
diff --git a/internal/types/error.go b/internal/types/error.go
new file mode 100644
index 0000000..fda7c9c
--- /dev/null
+++ b/internal/types/error.go
@@ -0,0 +1,62 @@
+package types
+
+import (
+ "errors"
+ "fmt"
+)
+
+type WrappedErrorMessage struct {
+ Message string
+ Err error
+}
+
+func (e *WrappedErrorMessage) Unwrap() error {
+ return e.Err
+}
+
+func (e *WrappedErrorMessage) Cause() error {
+ causeErr := e.Err
+ for errors.Unwrap(causeErr) != nil {
+ causeErr = errors.Unwrap(causeErr)
+ }
+ return causeErr
+}
+
+func (e *WrappedErrorMessage) Traceback() string {
+ returnStr := fmt.Sprintf("Traceback for error: %s", e.Message)
+ causeErr := e.Err
+ for errors.Unwrap(causeErr) != nil {
+ causeErr = errors.Unwrap(causeErr)
+ var wrappedErr *WrappedErrorMessage
+
+ errorStr := causeErr.Error()
+
+ if errors.As(causeErr, &wrappedErr) {
+ errorStr = wrappedErr.Message
+ }
+ returnStr += fmt.Sprintf("\n - %s", errorStr)
+ }
+ return returnStr
+}
+
+func (e *WrappedErrorMessage) Error() string {
+ return fmt.Sprintf("Got error: %s, with cause: %s", e.Message, e.Err)
+}
+
+func GetErrorTraceback(err error) string {
+ var wrappedErr *WrappedErrorMessage
+
+ if errors.As(err, &wrappedErr) {
+ return wrappedErr.Traceback()
+ }
+ return err.Error()
+}
+
+func GetErrorCause(err error) error {
+ var wrappedErr *WrappedErrorMessage
+
+ if errors.As(err, &wrappedErr) {
+ return wrappedErr.Cause()
+ }
+ return err
+}
diff --git a/internal/util/util.go b/internal/util/util.go
index 4bdd1b5..8dee61e 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -2,8 +2,11 @@ package util
import (
"crypto/rand"
+ "fmt"
"os"
"time"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
// Creates a random byteslice of `size`
@@ -11,7 +14,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) {
byteSlice := make([]byte, size)
_, err := rand.Read(byteSlice)
if err != nil {
- return nil, err
+ return nil, &types.WrappedErrorMessage{Message: "failed reading random", Err: err}
}
return byteSlice, nil
}
@@ -24,7 +27,7 @@ func GenerateTimeSeconds() int64 {
func EnsureDirectory(directory string) error {
mkdirErr := os.MkdirAll(directory, os.ModePerm)
if mkdirErr != nil {
- return mkdirErr
+ return &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to create directory %s", directory), Err: mkdirErr}
}
return nil
}
diff --git a/internal/verify/verify.go b/internal/verify/verify.go
index 2d53b2e..b159297 100644
--- a/internal/verify/verify.go
+++ b/internal/verify/verify.go
@@ -6,6 +6,7 @@ import (
"os"
"github.com/jedisct1/go-minisign"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
// getKeys returns keys taken from https://git.sr.ht/~eduvpn/disco.eduvpn.org#public-keys.
@@ -28,16 +29,19 @@ func getKeys() []string {
//
// Verify is a wrapper around verifyWithKeys where allowedPublicKeys is set to the list from https://git.sr.ht/~eduvpn/disco.eduvpn.org#public-keys.
func Verify(signatureFileContent string, signedJson []byte, expectedFileName string, minSignTime uint64, forcePrehash bool) (bool, error) {
+ errorMessage := "failed signature verify"
keyStrs := getKeys()
if extraKey != "" {
keyStrs = append(keyStrs, extraKey)
_, err := fmt.Fprintf(os.Stderr, "INSECURE TEST MODE ENABLED WITH KEY %q\n", extraKey)
+ err = &types.WrappedErrorMessage{Message: errorMessage, Err: err}
if err != nil {
panic(err)
}
}
valid, err := verifyWithKeys(signatureFileContent, signedJson, expectedFileName, minSignTime, keyStrs, forcePrehash)
if err != nil {
+ err = &types.WrappedErrorMessage{Message: errorMessage, Err: err}
var verifyCreatePublickeyError *VerifyCreatePublicKeyError
if errors.As(err, &verifyCreatePublickeyError) {
panic(err) // This should not happen unless keyStrs has an invalid key
@@ -67,6 +71,7 @@ func InsecureTestingSetExtraKey(keyString string) {
// The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds.
//
// The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure.
+// Note that every error path is wrapped in a custom type here because minisign does not return custom error types, they use errors.New
func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) {
switch filename {
case "server_list.json", "organization_list.json":
@@ -143,6 +148,10 @@ func (e *VerifyInvalidSignatureFormatError) Error() string {
return fmt.Sprintf("invalid signature format with error: %v", e.Err)
}
+func (e *VerifyInvalidSignatureFormatError) Unwrap() error {
+ return e.Err
+}
+
type VerifyInvalidSignatureAlgorithmError struct {
Algorithm string
WantedAlgorithm string
@@ -161,6 +170,10 @@ func (e *VerifyCreatePublicKeyError) Error() string {
return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err)
}
+func (e *VerifyCreatePublicKeyError) Unwrap() error {
+ return e.Err
+}
+
type VerifyInvalidSignatureError struct {
Err error
}
@@ -169,6 +182,10 @@ func (e *VerifyInvalidSignatureError) Error() string {
return fmt.Sprintf("invalid signature with error: %v", e.Err)
}
+func (e *VerifyInvalidSignatureError) Unwrap() error {
+ return e.Err
+}
+
type VerifyInvalidTrustedCommentError struct {
TrustedComment string
Err error
@@ -178,6 +195,10 @@ func (e *VerifyInvalidTrustedCommentError) Error() string {
return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err)
}
+func (e *VerifyInvalidTrustedCommentError) Unwrap() error {
+ return e.Err
+}
+
type VerifyWrongSigFilenameError struct {
Filename string
SigFilename string
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index db20067..bb26b69 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -3,6 +3,8 @@ package wireguard
import (
"fmt"
"regexp"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -10,7 +12,7 @@ func GenerateKey() (wgtypes.Key, error) {
key, keyErr := wgtypes.GeneratePrivateKey()
if keyErr != nil {
- return key, &WireguardGenerateKeyError{Err: keyErr}
+ return key, &types.WrappedErrorMessage{Message: "failed generating WireGuard key", Err: keyErr}
}
return key, nil
}
@@ -28,11 +30,3 @@ func ConfigAddKey(config string, key wgtypes.Key) string {
to_replace := fmt.Sprintf("%s\nPrivateKey = %s", interface_section, key.String())
return interface_re.ReplaceAllString(config, to_replace)
}
-
-type WireguardGenerateKeyError struct {
- Err error
-}
-
-func (e *WireguardGenerateKeyError) Error() string {
- return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err)
-}
diff --git a/state.go b/state.go
index 6abc901..9187b54 100644
--- a/state.go
+++ b/state.go
@@ -1,13 +1,12 @@
package eduvpn
import (
- "fmt"
-
"github.com/jwijenbergh/eduvpn-common/internal/config"
"github.com/jwijenbergh/eduvpn-common/internal/discovery"
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/log"
"github.com/jwijenbergh/eduvpn-common/internal/server"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
type VPNState struct {
@@ -31,8 +30,9 @@ type VPNState struct {
}
func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
+ errorMessage := "failed to register with the GO library"
if !state.FSM.InState(fsm.DEREGISTERED) {
- return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
// Initialize the logger
logLevel := log.LOG_WARNING
@@ -43,7 +43,7 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun
loggerErr := state.Logger.Init(logLevel, name, directory)
if loggerErr != nil {
- return &StateRegisterError{Err: loggerErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: loggerErr}
}
// Initialize the FSM
@@ -78,14 +78,15 @@ func (state *VPNState) Deregister() error {
}
func (state *VPNState) CancelOAuth() error {
+ errorMessage := "failed to cancel OAuth"
if !state.FSM.InState(fsm.OAUTH_STARTED) {
- return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
}
currentServer, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- return &StateOAuthCancelError{Err: serverErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
server.CancelOAuth(currentServer)
return nil
@@ -96,7 +97,7 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.S
server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger)
if serverErr != nil {
- return nil, serverErr
+ return nil, &types.WrappedErrorMessage{Message: "failed to choose server", Err: serverErr}
}
// Make sure we are in the chosen state if available
@@ -105,20 +106,21 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.S
}
func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, string, error) {
+ errorMessage := "failed to get a configuration for OpenVPN/Wireguard"
if state.FSM.InState(fsm.DEREGISTERED) {
- return "", "", &StateFSMNotRegisteredError{}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
// Go to no server if possible, else return an error
if !state.FSM.InState(fsm.NO_SERVER) && !state.FSM.GoTransition(fsm.NO_SERVER) {
- return "", "", &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER}.CustomError()}
}
// Make sure the server is chosen
chosenServer, serverErr := state.chooseServer(url, isSecureInternet)
if serverErr != nil {
- return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
// Relogin with oauth
// This moves the state to authorized
@@ -129,7 +131,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f
// We are possibly in oauth started
// So go to no server
state.FSM.GoTransition(fsm.NO_SERVER)
- return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
}
} else { // OAuth was valid, ensure we are in the authorized state
state.FSM.GoTransition(fsm.AUTHORIZED)
@@ -142,7 +144,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f
if configErr != nil {
// Go back to no server if possible
state.FSM.GoTransition(fsm.NO_SERVER)
- return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr}
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
} else {
state.FSM.GoTransition(fsm.HAS_CONFIG)
}
@@ -160,32 +162,33 @@ func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (strin
func (state *VPNState) GetDiscoOrganizations() (string, error) {
if state.FSM.InState(fsm.DEREGISTERED) {
- return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED}
+ return "", &types.WrappedErrorMessage{Message: "failed to get the organizations with Discovery", Err: fsm.DeregisteredError{}.CustomError()}
}
return state.Discovery.GetOrganizationsList()
}
func (state *VPNState) GetDiscoServers() (string, error) {
if state.FSM.InState(fsm.DEREGISTERED) {
- return "", &StateFSMNotRegisteredError{}
+ return "", &types.WrappedErrorMessage{Message: "failed to get the servers with Discovery", Err: fsm.DeregisteredError{}.CustomError()}
}
return state.Discovery.GetServersList()
}
func (state *VPNState) SetProfileID(profileID string) error {
+ errorMessage := "failed to set the profile ID for the current server"
if !state.FSM.InState(fsm.ASK_PROFILE) {
- return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.ASK_PROFILE}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.ASK_PROFILE}.CustomError()}
}
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- return &StateSetProfileError{ProfileID: profileID, Err: serverErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
base, baseErr := server.GetBase()
if baseErr != nil {
- return &StateSetProfileError{ProfileID: profileID, Err: baseErr}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
base.Profiles.Current = profileID
return nil
@@ -193,7 +196,7 @@ func (state *VPNState) SetProfileID(profileID string) error {
func (state *VPNState) SetConnected() error {
if !state.FSM.HasTransition(fsm.CONNECTED) {
- return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.CONNECTED}
+ return fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.CONNECTED}.CustomError()
}
state.FSM.GoTransition(fsm.CONNECTED)
@@ -202,59 +205,17 @@ func (state *VPNState) SetConnected() error {
func (state *VPNState) SetDisconnected() error {
if !state.FSM.HasTransition(fsm.HAS_CONFIG) {
- return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG}
+ return fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()
}
state.FSM.GoTransition(fsm.HAS_CONFIG)
return nil
}
-type StateSetProfileError struct {
- ProfileID string
- Err error
-}
-
-func (e *StateSetProfileError) Error() string {
- return fmt.Sprintf("failed to set profile ID: %s with error: %v", e.ProfileID, e.Err)
-}
-
-type StateRegisterError struct {
- Err error
-}
-
-func (e *StateRegisterError) Error() string {
- return fmt.Sprintf("failed to register with error: %v", e.Err)
-}
-
-type StateFSMNotRegisteredError struct{}
-
-func (e *StateFSMNotRegisteredError) Error() string {
- return fmt.Sprintf("state is not registered. Current FSM state: %s", fsm.DEREGISTERED.String())
-}
-
-type StateWrongFSMStateError struct {
- Got fsm.FSMStateID
- Want fsm.FSMStateID
-}
-
-func (e *StateWrongFSMStateError) Error() string {
- return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String())
-}
-
-type StateOAuthCancelError struct {
- Err error
-}
-
-func (e *StateOAuthCancelError) Error() string {
- return fmt.Sprintf("failed cancelling OAuth for state with error: %v", e.Err)
-}
-
-type StateConnectError struct {
- URL string
- IsSecureInternet bool
- Err error
+func (state *VPNState) GetErrorTraceback(err error) string {
+ return types.GetErrorTraceback(err)
}
-func (e *StateConnectError) Error() string {
- return fmt.Sprintf("failed connecting to server: %s (is secure internet: %v) with error: %v", e.URL, e.IsSecureInternet, e.Err)
+func (state *VPNState) GetErrorCause(err error) error {
+ return types.GetErrorCause(err)
}
diff --git a/state_test.go b/state_test.go
index 521513a..76f4763 100644
--- a/state_test.go
+++ b/state_test.go
@@ -15,6 +15,7 @@ import (
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
"github.com/jwijenbergh/eduvpn-common/internal/server"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
func ensureLocalWellKnown() {
@@ -96,34 +97,16 @@ func test_connect_oauth_parameter(t *testing.T, parameters httpw.URLParameters,
}, false)
_, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
- var stateErr *StateConnectError
- var loginErr *oauth.OAuthLoginError
- var finishErr *oauth.OAuthFinishError
+ var wrappedErr *types.WrappedErrorMessage
- // We go through the chain of errors by unwrapping them one by one
-
- // First ensure we get a state connect error
- if !errors.As(configErr, &stateErr) {
- t.Fatalf("error %T = %v, wantErr %T", configErr, configErr, stateErr)
- }
-
- // Then ensure we get a login error
- gotLoginErr := stateErr.Err
-
- if !errors.As(gotLoginErr, &loginErr) {
- t.Fatalf("error %T = %v, wantErr %T", gotLoginErr, gotLoginErr, loginErr)
- }
-
- // Then ensure we get a finish error
- gotFinishErr := loginErr.Err
-
- if !errors.As(gotFinishErr, &finishErr) {
- t.Fatalf("error %T = %v, wantErr %T", gotFinishErr, gotFinishErr, finishErr)
+ // We ensure the error is of a wrappedErrorMessage
+ if !errors.As(configErr, &wrappedErr) {
+ t.Fatalf("error %T = %v, wantErr %T", configErr, configErr, wrappedErr)
}
- // Then ensure we get the expected inner error
- gotExpectedErr := finishErr.Err
+ gotExpectedErr := wrappedErr.Cause()
+ // Then we check if the cause is correct
if !errors.As(gotExpectedErr, expectedErr) {
t.Fatalf("error %T = %v, wantErr %T", gotExpectedErr, gotExpectedErr, expectedErr)
}