summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go6
-rw-r--r--internal/discovery/discovery.go52
-rw-r--r--internal/http/http.go24
-rw-r--r--internal/log/log.go7
-rw-r--r--internal/oauth/oauth.go121
-rw-r--r--internal/server/api.go36
-rw-r--r--internal/server/common.go62
-rw-r--r--internal/server/custom.go8
-rw-r--r--internal/server/instituteaccess.go10
-rw-r--r--internal/server/secureinternet.go24
-rw-r--r--internal/util/util.go18
-rw-r--r--internal/verify/verify.go2
-rw-r--r--internal/wireguard/wireguard.go8
13 files changed, 190 insertions, 188 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 1d5a201..180b881 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -29,11 +29,11 @@ func (config *Config) Save(readStruct interface{}) error {
errorMessage := "failed saving configuration"
configDirErr := util.EnsureDirectory(config.Directory)
if configDirErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr}
+ return types.NewWrappedError(errorMessage, configDirErr)
}
jsonString, marshalErr := json.Marshal(readStruct)
if marshalErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: marshalErr}
+ return types.NewWrappedError(errorMessage, marshalErr)
}
return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600)
}
@@ -41,7 +41,7 @@ func (config *Config) Save(readStruct interface{}) error {
func (config *Config) Load(writeStruct interface{}) error {
bytes, readErr := ioutil.ReadFile(config.GetFilename())
if readErr != nil {
- return &types.WrappedErrorMessage{Message: "failed loading configuration", Err: readErr}
+ return types.NewWrappedError("failed loading configuration", readErr)
}
return json.Unmarshal(bytes, writeStruct)
}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 01773fa..d639406 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -25,7 +25,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, fileBody, fileErr := http.HTTPGet(fileURL)
if fileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
+ return types.NewWrappedError(errorMessage, fileErr)
}
// Get signature
@@ -34,7 +34,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, sigBody, sigFileErr := http.HTTPGet(sigURL)
if sigFileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
+ return types.NewWrappedError(errorMessage, sigFileErr)
}
// Verify signature
@@ -49,14 +49,14 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
)
if !verifySuccess || verifyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
+ return types.NewWrappedError(errorMessage, verifyErr)
}
// Parse JSON to extract version and list
jsonErr := json.Unmarshal(fileBody, structure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return types.NewWrappedError(errorMessage, jsonErr)
}
return nil
@@ -91,10 +91,10 @@ func (discovery *Discovery) GetServerByURL(
return &server, nil
}
}
- return nil, &types.WrappedErrorMessage{
- Message: "failed getting server by URL from discovery",
- Err: &GetServerByURLNotFoundError{URL: url, Type: _type},
- }
+ return nil, types.NewWrappedError(
+ "failed getting server by URL from discovery",
+ &GetServerByURLNotFoundError{URL: url, Type: _type},
+ )
}
func (discovery *Discovery) GetServerByCountryCode(
@@ -106,10 +106,10 @@ func (discovery *Discovery) GetServerByCountryCode(
return &server, nil
}
}
- return nil, &types.WrappedErrorMessage{
- Message: "failed getting server by country code from discovery",
- Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type},
- }
+ return nil, types.NewWrappedError(
+ "failed getting server by country code from discovery",
+ &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type},
+ )
}
func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) {
@@ -118,10 +118,10 @@ func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganizati
return &organization, nil
}
}
- return nil, &types.WrappedErrorMessage{
- Message: "failed getting Secure Internet Home URL from discovery",
- Err: &GetOrgByIDNotFoundError{ID: orgID},
- }
+ return nil, types.NewWrappedError(
+ "failed getting Secure Internet Home URL from discovery",
+ &GetOrgByIDNotFoundError{ID: orgID},
+ )
}
func (discovery *Discovery) GetSecureHomeArgs(
@@ -131,7 +131,7 @@ func (discovery *Discovery) GetSecureHomeArgs(
org, orgErr := discovery.getOrgByID(orgID)
if orgErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr}
+ return nil, nil, types.NewWrappedError(errorMessage, orgErr)
}
// Get a server with the base url
@@ -140,7 +140,7 @@ func (discovery *Discovery) GetSecureHomeArgs(
server, serverErr := discovery.GetServerByURL(url, "secure_internet")
if serverErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ return nil, nil, types.NewWrappedError(errorMessage, serverErr)
}
return org, server, nil
}
@@ -168,10 +168,10 @@ func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganization
bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
if bodyErr != nil {
// Return previous with an error
- return &discovery.Organizations, &types.WrappedErrorMessage{
- Message: "failed getting organizations in Discovery",
- Err: bodyErr,
- }
+ return &discovery.Organizations, types.NewWrappedError(
+ "failed getting organizations in Discovery",
+ bodyErr,
+ )
}
discovery.Organizations.Timestamp = util.GetCurrentTime()
return &discovery.Organizations, nil
@@ -186,10 +186,10 @@ func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) {
bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
if bodyErr != nil {
// Return previous with an error
- return &discovery.Servers, &types.WrappedErrorMessage{
- Message: "failed getting servers in Discovery",
- Err: bodyErr,
- }
+ return &discovery.Servers, types.NewWrappedError(
+ "failed getting servers in Discovery",
+ bodyErr,
+ )
}
// Update servers timestamp
discovery.Servers.Timestamp = util.GetCurrentTime()
diff --git a/internal/http/http.go b/internal/http/http.go
index 02d83a6..6ff853b 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -25,14 +25,14 @@ type HTTPOptionalParams struct {
func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) {
url, parseErr := url.Parse(baseURL)
if parseErr != nil {
- return "", &types.WrappedErrorMessage{
- Message: fmt.Sprintf(
+ return "", types.NewWrappedError(
+ fmt.Sprintf(
"failed to construct url: %s including parameters: %v",
url,
parameters,
),
- Err: parseErr,
- }
+ parseErr,
+ )
}
q := url.Query()
@@ -66,10 +66,10 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
url, urlErr := HTTPConstructURL(url, opts.URLParameters)
if urlErr != nil {
- return url, &types.WrappedErrorMessage{
- Message: fmt.Sprintf("failed to create HTTP request with url: %s", url),
- Err: urlErr,
- }
+ return url, types.NewWrappedError(
+ fmt.Sprintf("failed to create HTTP request with url: %s", url),
+ urlErr,
+ )
}
return url, nil
}
@@ -121,7 +121,7 @@ func HTTPMethodWithOpts(
// Create request object with the body reader generated from the optional arguments
req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts))
if reqErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: reqErr}
+ return nil, nil, types.NewWrappedError(errorMessage, reqErr)
}
// See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi
@@ -133,7 +133,7 @@ func HTTPMethodWithOpts(
// Do request
resp, respErr := client.Do(req)
if respErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: respErr}
+ return nil, nil, types.NewWrappedError(errorMessage, respErr)
}
// Request successful, make sure body is closed at the end
@@ -142,13 +142,13 @@ func HTTPMethodWithOpts(
// Return a string
body, readErr := ioutil.ReadAll(resp.Body)
if readErr != nil {
- return resp.Header, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: readErr}
+ return resp.Header, nil, types.NewWrappedError(errorMessage, readErr)
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
// We make this a custom error because we want to extract the status code later
statusErr := &HTTPStatusError{URL: url, Body: string(body), Status: resp.StatusCode}
- return resp.Header, body, &types.WrappedErrorMessage{Message: errorMessage, Err: statusErr}
+ return resp.Header, body, types.NewWrappedError(errorMessage, statusErr)
}
// Return the body in bytes and signal the status error if there was one
diff --git a/internal/log/log.go b/internal/log/log.go
index 970480f..eabecb9 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -52,7 +52,7 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er
configDirErr := util.EnsureDirectory(directory)
if configDirErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr}
+ return types.NewWrappedError(errorMessage, configDirErr)
}
logFile, logOpenErr := os.OpenFile(
logger.getFilename(directory, name),
@@ -60,7 +60,7 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er
0o666,
)
if logOpenErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: logOpenErr}
+ return types.NewWrappedError(errorMessage, logOpenErr)
}
log.SetOutput(logFile)
logger.File = logFile
@@ -68,9 +68,10 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er
return nil
}
-func (logger *FileLogger) Inherit(err error, msg string) {
+func (logger *FileLogger) Inherit(label string, err error) {
level := types.GetErrorLevel(err)
+ msg := fmt.Sprintf("%s with err: %s", label, types.GetErrorTraceback(err))
switch level {
case types.ERR_INFO:
logger.Info(msg)
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 6ac773c..df29a9c 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -28,7 +28,7 @@ import (
func genState() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err}
+ return "", types.NewWrappedError("failed generating an OAuth state", err)
}
// For consistency we also use raw url encoding here
@@ -61,10 +61,10 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &types.WrappedErrorMessage{
- Message: "failed generating an OAuth verifier",
- Err: err,
- }
+ return "", types.NewWrappedError(
+ "failed generating an OAuth verifier",
+ err,
+ )
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -111,7 +111,7 @@ func (oauth *OAuth) setupListener() error {
// create a listener
listener, listenerErr := net.Listen("tcp", ":0")
if listenerErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr}
+ return types.NewWrappedError(errorMessage, listenerErr)
}
oauth.Session.Listener = listener
return nil
@@ -120,7 +120,7 @@ func (oauth *OAuth) setupListener() error {
func (oauth *OAuth) getTokensWithCallback() error {
errorMessage := "failed getting tokens with callback"
if oauth.Session.Listener == nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")}
+ return types.NewWrappedError(errorMessage, errors.New("No listener"))
}
mux := http.NewServeMux()
// server /callback over the listener address
@@ -130,7 +130,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: err}
+ return types.NewWrappedError(errorMessage, err)
}
return oauth.Session.CallbackError
}
@@ -146,7 +146,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
port, portErr := oauth.GetListenerPort()
if portErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ return types.NewWrappedError(errorMessage, portErr)
}
data := url.Values{
@@ -163,7 +163,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := util.GetCurrentTime()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return types.NewWrappedError(errorMessage, bodyErr)
}
tokenStructure := OAuthToken{}
@@ -171,10 +171,10 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
+ )
}
tokenStructure.ExpiredTimestamp = current_time.Add(
@@ -207,17 +207,17 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := util.GetCurrentTime()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return types.NewWrappedError(errorMessage, bodyErr)
}
tokenStructure := OAuthToken{}
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
+ )
}
tokenStructure.ExpiredTimestamp = current_time.Add(
@@ -273,7 +273,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
errorMessage := "failed writing response HTML"
template, templateErr := template.New("oauth-response").Parse(responseTemplate)
if templateErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: templateErr}
+ return types.NewWrappedError(errorMessage, templateErr)
}
executeErr := template.Execute(w, oauthResponseHTML{
@@ -281,7 +281,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
Message: message,
})
if executeErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: executeErr}
+ return types.NewWrappedError(errorMessage, executeErr)
}
return nil
}
@@ -310,10 +310,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
extractedISS := urlQuery.Get("iss")
if extractedISS != "" {
if oauth.Session.ISS != extractedISS {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
+ )
return
}
@@ -323,31 +323,31 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
extractedState := urlQuery.Get("state")
if extractedState == "" {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
+ )
return
}
// The state is the first entry
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackStateMatchError{
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackStateMatchError{
State: extractedState,
ExpectedState: oauth.Session.State,
},
- }
+ )
return
}
// No authorization code
extractedCode := urlQuery.Get("code")
if extractedCode == "" {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
+ )
return
}
@@ -355,10 +355,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
if getTokensErr != nil {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: getTokensErr,
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ getTokensErr,
+ )
return
}
}
@@ -372,7 +372,7 @@ func (oauth OAuth) GetListenerPort() (int, error) {
errorMessage := "failed to get listener port"
if oauth.Session.Listener == nil {
- return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")}
+ return 0, types.NewWrappedError(errorMessage, errors.New("No OAuth listener"))
}
return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil
}
@@ -384,14 +384,14 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
+ return "", types.NewWrappedError(errorMessage, verifierErr)
}
challenge := genChallengeS256(verifier)
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ return "", types.NewWrappedError(errorMessage, stateErr)
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -401,13 +401,13 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
// set up the listener to get the redirect URI
listenerErr := oauth.setupListener()
if listenerErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ return "", types.NewWrappedError(errorMessage, stateErr)
}
// Get the listener port
port, portErr := oauth.GetListenerPort()
if portErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ return "", types.NewWrappedError(errorMessage, portErr)
}
parameters := map[string]string{
@@ -423,7 +423,7 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ return "", types.NewWrappedError(errorMessage, urlErr)
}
// Return the url processed
@@ -435,16 +435,17 @@ func (oauth *OAuth) Exchange() error {
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr}
+ return types.NewWrappedError("failed finishing OAuth", tokenErr)
}
return nil
}
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: "cancelled OAuth",
- Err: &OAuthCancelledCallbackError{},
- }
+ oauth.Session.CallbackError = types.NewWrappedErrorLevel(
+ types.ERR_INFO,
+ "cancelled OAuth",
+ &OAuthCancelledCallbackError{},
+ )
if oauth.Session.Server != nil {
oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
}
@@ -454,10 +455,10 @@ func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens
if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthTokensInvalidError{Cause: "tokens are empty"},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{Cause: "tokens are empty"},
+ )
}
// We have tokens...
@@ -472,12 +473,12 @@ func (oauth *OAuth) EnsureTokens() error {
// We have obtained new tokens with refresh
if refreshErr != nil {
// We have failed to ensure the tokens due to refresh not working
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthTokensInvalidError{
+ return types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{
Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
},
- }
+ )
}
return nil
diff --git a/internal/server/api.go b/internal/server/api.go
index 05d2528..be7281c 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -17,21 +17,21 @@ func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) {
errorMessage := "failed getting server endpoints"
url, urlErr := url.Parse(baseURL)
if urlErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ return nil, types.NewWrappedError(errorMessage, urlErr)
}
url.Path = path.Join(url.Path, WellKnownPath)
_, body, bodyErr := httpw.HTTPGet(url.String())
if bodyErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return nil, types.NewWrappedError(errorMessage, bodyErr)
}
endpoints := &ServerEndpoints{}
jsonErr := json.Unmarshal(body, endpoints)
if jsonErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return nil, types.NewWrappedError(errorMessage, jsonErr)
}
return endpoints, nil
@@ -51,20 +51,20 @@ func apiAuthorized(
base, baseErr := server.GetBase()
if baseErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return nil, nil, types.NewWrappedError(errorMessage, baseErr)
}
// Join the paths
url, urlErr := url.Parse(base.Endpoints.API.V3.API)
if urlErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ return nil, nil, types.NewWrappedError(errorMessage, urlErr)
}
url.Path = path.Join(url.Path, endpoint)
// Make sure the tokens are valid, this will return an error if re-login is needed
oauthErr := EnsureTokens(server)
if oauthErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr}
+ return nil, nil, types.NewWrappedError(errorMessage, oauthErr)
}
headerKey := "Authorization"
@@ -95,11 +95,11 @@ func apiAuthorizedRetry(
MarkTokenExpired(server)
retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
if retryErr != nil {
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr}
+ return nil, nil, types.NewWrappedError(errorMessage, retryErr)
}
return retryHeader, retryBody, nil
}
- return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return nil, nil, types.NewWrappedError(errorMessage, bodyErr)
}
return header, body, nil
}
@@ -108,19 +108,19 @@ func APIInfo(server Server) error {
errorMessage := "failed API /info"
_, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil)
if bodyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return types.NewWrappedError(errorMessage, bodyErr)
}
structure := ServerProfileInfo{}
jsonErr := json.Unmarshal(body, &structure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return types.NewWrappedError(errorMessage, jsonErr)
}
base, baseErr := server.GetBase()
if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return types.NewWrappedError(errorMessage, baseErr)
}
// Store the profiles and make sure that the current profile is not overwritten
@@ -169,17 +169,17 @@ func APIConnectWireguard(
&httpw.HTTPOptionalParams{Headers: headers, Body: urlForm},
)
if connectErr != nil {
- return "", "", time.Time{}, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: connectErr,
- }
+ return "", "", time.Time{}, types.NewWrappedError(
+ errorMessage,
+ connectErr,
+ )
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
+ return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
}
contentType := header.Get("content-type")
@@ -210,13 +210,13 @@ func APIConnectOpenVPN(server Server, profile_id string, preferTCP bool) (string
&httpw.HTTPOptionalParams{Headers: headers, Body: urlForm},
)
if connectErr != nil {
- return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
+ return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr)
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
+ return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
}
return string(connectBody), pTime, nil
}
diff --git a/internal/server/common.go b/internal/server/common.go
index bf72bc6..bf6f4ca 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -97,10 +97,10 @@ func (servers *Servers) GetCurrentServer() (Server, error) {
errorMessage := "failed getting current server"
if servers.IsType == SecureInternetServerType {
if !servers.HasSecureLocation() {
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerGetCurrentNotFoundError{},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerGetCurrentNotFoundError{},
+ )
}
return &servers.SecureInternetHomeServer, nil
}
@@ -113,18 +113,18 @@ func (servers *Servers) GetCurrentServer() (Server, error) {
currentServerURL := serversStruct.CurrentURL
bases := serversStruct.Map
if bases == nil {
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerGetCurrentNoMapError{},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerGetCurrentNoMapError{},
+ )
}
server, exists := bases[currentServerURL]
if !exists || server == nil {
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerGetCurrentNotFoundError{},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerGetCurrentNotFoundError{},
+ )
}
return server, nil
}
@@ -161,7 +161,7 @@ func (servers *Servers) addInstituteAndCustom(
discoServer.SupportContact,
)
if instituteInitErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr}
+ return nil, types.NewWrappedError(errorMessage, instituteInitErr)
}
toAddServers.Map[url] = server
servers.IsType = serverType
@@ -192,7 +192,7 @@ func (servers *Servers) SetSecureLocation(
_, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer)
if addLocationErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr}
+ return types.NewWrappedError(errorMessage, addLocationErr)
}
servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
@@ -209,7 +209,7 @@ func (servers *Servers) AddSecureInternet(
initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer)
if initErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
+ return nil, types.NewWrappedError(errorMessage, initErr)
}
servers.IsType = SecureInternetServerType
@@ -255,7 +255,7 @@ func ShouldRenewButton(server Server) bool {
func GetISS(server Server) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", &types.WrappedErrorMessage{Message: "failed getting server ISS", Err: baseErr}
+ return "", types.NewWrappedError("failed getting server ISS", baseErr)
}
// We have already ensured that the base URL ends with a /
return base.URL, nil
@@ -288,7 +288,7 @@ func MarkTokensForRenew(server Server) {
func EnsureTokens(server Server) error {
ensureErr := server.GetOAuth().EnsureTokens()
if ensureErr != nil {
- return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr}
+ return types.NewWrappedError("failed ensuring server tokens", ensureErr)
}
return nil
}
@@ -323,7 +323,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return nil, types.NewWrappedError(errorMessage, baseErr)
}
profileID := base.Profiles.Current
for _, profile := range base.Profiles.Info.ProfileList {
@@ -332,10 +332,10 @@ func getCurrentProfile(server Server) (*ServerProfile, error) {
}
}
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerGetCurrentProfileNotFoundError{ProfileID: profileID},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerGetCurrentProfileNotFoundError{ProfileID: profileID},
+ )
}
func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) {
@@ -343,14 +343,14 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return "", "", types.NewWrappedError(errorMessage, baseErr)
}
profile_id := base.Profiles.Current
wireguardKey, wireguardErr := wireguard.GenerateKey()
if wireguardErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: wireguardErr}
+ return "", "", types.NewWrappedError(errorMessage, wireguardErr)
}
wireguardPublicKey := wireguardKey.PublicKey().String()
@@ -363,7 +363,7 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st
)
if configErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
+ return "", "", types.NewWrappedError(errorMessage, configErr)
}
// Store start and end time
@@ -386,7 +386,7 @@ func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return "", "", types.NewWrappedError(errorMessage, baseErr)
}
profile_id := base.Profiles.Current
configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id, preferTCP)
@@ -396,7 +396,7 @@ func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
base.EndTime = expires
if configErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
+ return "", "", types.NewWrappedError(errorMessage, configErr)
}
return configOpenVPN, "openvpn", nil
@@ -409,12 +409,12 @@ func HasValidProfile(server Server) (bool, error) {
// This does not override the current profile
infoErr := APIInfo(server)
if infoErr != nil {
- return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr}
+ return false, types.NewWrappedError(errorMessage, infoErr)
}
base, baseErr := server.GetBase()
if baseErr != nil {
- return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return false, types.NewWrappedError(errorMessage, baseErr)
}
// If there was a profile chosen and it doesn't exist anymore, reset it
@@ -442,7 +442,7 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) {
profile, profileErr := getCurrentProfile(server)
if profileErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr}
+ return "", "", types.NewWrappedError(errorMessage, profileErr)
}
supportsOpenVPN := profile.supportsOpenVPN()
@@ -461,7 +461,7 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) {
}
if configErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
+ return "", "", types.NewWrappedError(errorMessage, configErr)
}
return config, configType, nil
diff --git a/internal/server/custom.go b/internal/server/custom.go
index feda1f3..6ba6503 100644
--- a/internal/server/custom.go
+++ b/internal/server/custom.go
@@ -11,18 +11,18 @@ func (servers *Servers) SetCustomServer(server Server) error {
errorMessage := "failed setting custom server"
base, baseErr := server.GetBase()
if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return types.NewWrappedError(errorMessage, baseErr)
}
if base.Type != "custom_server" {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a custom server")}
+ return types.NewWrappedError(errorMessage, errors.New("Not a custom server"))
}
if _, ok := servers.CustomServers.Map[base.URL]; ok {
servers.CustomServers.CurrentURL = base.URL
servers.IsType = CustomServerType
} else {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a custom server")}
+ return types.NewWrappedError(errorMessage, errors.New("Not a custom server"))
}
return nil
}
@@ -31,7 +31,7 @@ func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, err
if server, ok := servers.CustomServers.Map[url]; ok {
return server, nil
}
- return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No custom server with URL: %s", url)}
+ return nil, types.NewWrappedError("failed to get institute access server", fmt.Errorf("No custom server with URL: %s", url))
}
func (servers *Servers) RemoveCustomServer(url string) {
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
index bf0e2bc..0f097b0 100644
--- a/internal/server/instituteaccess.go
+++ b/internal/server/instituteaccess.go
@@ -26,18 +26,18 @@ func (servers *Servers) SetInstituteAccess(server Server) error {
errorMessage := "failed setting institute access server"
base, baseErr := server.GetBase()
if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return types.NewWrappedError(errorMessage, baseErr)
}
if base.Type != "institute_access" {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not an institute access server")}
+ return types.NewWrappedError(errorMessage, errors.New("Not an institute access server"))
}
if _, ok := servers.InstituteServers.Map[base.URL]; ok {
servers.InstituteServers.CurrentURL = base.URL
servers.IsType = InstituteAccessServerType
} else {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No such institute access server")}
+ return types.NewWrappedError(errorMessage, errors.New("No such institute access server"))
}
return nil
}
@@ -46,7 +46,7 @@ func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer,
if server, ok := servers.InstituteServers.Map[url]; ok {
return server, nil
}
- return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No institute access server with URL: %s", url)}
+ return nil, types.NewWrappedError("failed to get institute access server", fmt.Errorf("No institute access server with URL: %s", url))
}
func (servers *Servers) RemoveInstituteAccess(url string) {
@@ -91,7 +91,7 @@ func (institute *InstituteAccessServer) init(
institute.Base.Type = serverType
endpoints, endpointsErr := APIGetEndpoints(url)
if endpointsErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
+ return types.NewWrappedError(errorMessage, endpointsErr)
}
institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token)
institute.Base.Endpoints = *endpoints
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 27d48a5..93e83cf 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -35,11 +35,11 @@ func (servers *Servers) SetSecureInternet(server Server) error {
errorMessage := "failed setting secure internet server"
base, baseErr := server.GetBase()
if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return types.NewWrappedError(errorMessage, baseErr)
}
if base.Type != "secure_internet" {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a secure internet server")}
+ return types.NewWrappedError(errorMessage, errors.New("Not a secure internet server"))
}
// The location should already be configured
@@ -71,19 +71,19 @@ func (secure *SecureInternetHomeServer) GetTemplateAuth() func(string) string {
func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
errorMessage := "failed getting current secure internet home base"
if server.BaseMap == nil {
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerSecureInternetMapNotFoundError{},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerSecureInternetMapNotFoundError{},
+ )
}
base, exists := server.BaseMap[server.CurrentLocation]
if !exists {
- return nil, &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation},
- }
+ return nil, types.NewWrappedError(
+ errorMessage,
+ &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation},
+ )
}
return base, nil
}
@@ -113,7 +113,7 @@ func (secure *SecureInternetHomeServer) addLocation(
base.Type = "secure_internet"
endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL)
if endpointsErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
+ return nil, types.NewWrappedError(errorMessage, endpointsErr)
}
base.Endpoints = *endpoints
}
@@ -145,7 +145,7 @@ func (secure *SecureInternetHomeServer) init(
base, baseErr := secure.addLocation(homeLocation)
if baseErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
+ return types.NewWrappedError(errorMessage, baseErr)
}
// Make sure oauth contains our endpoints
diff --git a/internal/util/util.go b/internal/util/util.go
index a500e15..ef52ce2 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -15,10 +15,10 @@ import (
func EnsureValidURL(s string) (string, error) {
parsedURL, parseErr := url.Parse(s)
if parseErr != nil {
- return "", &types.WrappedErrorMessage{
- Message: fmt.Sprintf("failed parsing url: %s", s),
- Err: parseErr,
- }
+ return "", types.NewWrappedError(
+ fmt.Sprintf("failed parsing url: %s", s),
+ parseErr,
+ )
}
if parsedURL.Scheme == "" {
@@ -44,7 +44,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) {
byteSlice := make([]byte, size)
_, err := rand.Read(byteSlice)
if err != nil {
- return nil, &types.WrappedErrorMessage{Message: "failed reading random", Err: err}
+ return nil, types.NewWrappedError("failed reading random", err)
}
return byteSlice, nil
}
@@ -57,10 +57,10 @@ func EnsureDirectory(directory string) error {
// Create with 700 permissions, read, write, execute only for the owner
mkdirErr := os.MkdirAll(directory, 0o700)
if mkdirErr != nil {
- return &types.WrappedErrorMessage{
- Message: fmt.Sprintf("failed to create directory %s", directory),
- Err: mkdirErr,
- }
+ return types.NewWrappedError(
+ fmt.Sprintf("failed to create directory %s", directory),
+ mkdirErr,
+ )
}
return nil
}
diff --git a/internal/verify/verify.go b/internal/verify/verify.go
index 43b6c74..2dd0472 100644
--- a/internal/verify/verify.go
+++ b/internal/verify/verify.go
@@ -39,7 +39,7 @@ func Verify(
forcePrehash,
)
if err != nil {
- return valid, &types.WrappedErrorMessage{Message: "failed signature verify", Err: err}
+ return valid, types.NewWrappedError("failed signature verify", err)
}
return valid, nil
}
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 3d3ae8e..0a1ba5f 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -12,10 +12,10 @@ func GenerateKey() (wgtypes.Key, error) {
key, keyErr := wgtypes.GeneratePrivateKey()
if keyErr != nil {
- return key, &types.WrappedErrorMessage{
- Message: "failed generating WireGuard key",
- Err: keyErr,
- }
+ return key, types.NewWrappedError(
+ "failed generating WireGuard key",
+ keyErr,
+ )
}
return key, nil
}