diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 6 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 52 | ||||
| -rw-r--r-- | internal/http/http.go | 24 | ||||
| -rw-r--r-- | internal/log/log.go | 7 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 121 | ||||
| -rw-r--r-- | internal/server/api.go | 36 | ||||
| -rw-r--r-- | internal/server/common.go | 62 | ||||
| -rw-r--r-- | internal/server/custom.go | 8 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 10 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 24 | ||||
| -rw-r--r-- | internal/util/util.go | 18 | ||||
| -rw-r--r-- | internal/verify/verify.go | 2 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 8 |
13 files changed, 190 insertions, 188 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 1d5a201..180b881 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,11 +29,11 @@ func (config *Config) Save(readStruct interface{}) error { errorMessage := "failed saving configuration" configDirErr := util.EnsureDirectory(config.Directory) if configDirErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr} + return types.NewWrappedError(errorMessage, configDirErr) } jsonString, marshalErr := json.Marshal(readStruct) if marshalErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: marshalErr} + return types.NewWrappedError(errorMessage, marshalErr) } return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600) } @@ -41,7 +41,7 @@ func (config *Config) Save(readStruct interface{}) error { func (config *Config) Load(writeStruct interface{}) error { bytes, readErr := ioutil.ReadFile(config.GetFilename()) if readErr != nil { - return &types.WrappedErrorMessage{Message: "failed loading configuration", Err: readErr} + return types.NewWrappedError("failed loading configuration", readErr) } return json.Unmarshal(bytes, writeStruct) } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 01773fa..d639406 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -25,7 +25,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} + return types.NewWrappedError(errorMessage, fileErr) } // Get signature @@ -34,7 +34,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} + return types.NewWrappedError(errorMessage, sigFileErr) } // Verify signature @@ -49,14 +49,14 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} ) if !verifySuccess || verifyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} + return types.NewWrappedError(errorMessage, verifyErr) } // Parse JSON to extract version and list jsonErr := json.Unmarshal(fileBody, structure) if jsonErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return types.NewWrappedError(errorMessage, jsonErr) } return nil @@ -91,10 +91,10 @@ func (discovery *Discovery) GetServerByURL( return &server, nil } } - return nil, &types.WrappedErrorMessage{ - Message: "failed getting server by URL from discovery", - Err: &GetServerByURLNotFoundError{URL: url, Type: _type}, - } + return nil, types.NewWrappedError( + "failed getting server by URL from discovery", + &GetServerByURLNotFoundError{URL: url, Type: _type}, + ) } func (discovery *Discovery) GetServerByCountryCode( @@ -106,10 +106,10 @@ func (discovery *Discovery) GetServerByCountryCode( return &server, nil } } - return nil, &types.WrappedErrorMessage{ - Message: "failed getting server by country code from discovery", - Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}, - } + return nil, types.NewWrappedError( + "failed getting server by country code from discovery", + &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}, + ) } func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) { @@ -118,10 +118,10 @@ func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganizati return &organization, nil } } - return nil, &types.WrappedErrorMessage{ - Message: "failed getting Secure Internet Home URL from discovery", - Err: &GetOrgByIDNotFoundError{ID: orgID}, - } + return nil, types.NewWrappedError( + "failed getting Secure Internet Home URL from discovery", + &GetOrgByIDNotFoundError{ID: orgID}, + ) } func (discovery *Discovery) GetSecureHomeArgs( @@ -131,7 +131,7 @@ func (discovery *Discovery) GetSecureHomeArgs( org, orgErr := discovery.getOrgByID(orgID) if orgErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr} + return nil, nil, types.NewWrappedError(errorMessage, orgErr) } // Get a server with the base url @@ -140,7 +140,7 @@ func (discovery *Discovery) GetSecureHomeArgs( server, serverErr := discovery.GetServerByURL(url, "secure_internet") if serverErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + return nil, nil, types.NewWrappedError(errorMessage, serverErr) } return org, server, nil } @@ -168,10 +168,10 @@ func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganization bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) if bodyErr != nil { // Return previous with an error - return &discovery.Organizations, &types.WrappedErrorMessage{ - Message: "failed getting organizations in Discovery", - Err: bodyErr, - } + return &discovery.Organizations, types.NewWrappedError( + "failed getting organizations in Discovery", + bodyErr, + ) } discovery.Organizations.Timestamp = util.GetCurrentTime() return &discovery.Organizations, nil @@ -186,10 +186,10 @@ func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) { bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) if bodyErr != nil { // Return previous with an error - return &discovery.Servers, &types.WrappedErrorMessage{ - Message: "failed getting servers in Discovery", - Err: bodyErr, - } + return &discovery.Servers, types.NewWrappedError( + "failed getting servers in Discovery", + bodyErr, + ) } // Update servers timestamp discovery.Servers.Timestamp = util.GetCurrentTime() diff --git a/internal/http/http.go b/internal/http/http.go index 02d83a6..6ff853b 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -25,14 +25,14 @@ type HTTPOptionalParams struct { func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { url, parseErr := url.Parse(baseURL) if parseErr != nil { - return "", &types.WrappedErrorMessage{ - Message: fmt.Sprintf( + return "", types.NewWrappedError( + fmt.Sprintf( "failed to construct url: %s including parameters: %v", url, parameters, ), - Err: parseErr, - } + parseErr, + ) } q := url.Query() @@ -66,10 +66,10 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { url, urlErr := HTTPConstructURL(url, opts.URLParameters) if urlErr != nil { - return url, &types.WrappedErrorMessage{ - Message: fmt.Sprintf("failed to create HTTP request with url: %s", url), - Err: urlErr, - } + return url, types.NewWrappedError( + fmt.Sprintf("failed to create HTTP request with url: %s", url), + urlErr, + ) } return url, nil } @@ -121,7 +121,7 @@ func HTTPMethodWithOpts( // Create request object with the body reader generated from the optional arguments req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts)) if reqErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: reqErr} + return nil, nil, types.NewWrappedError(errorMessage, reqErr) } // See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi @@ -133,7 +133,7 @@ func HTTPMethodWithOpts( // Do request resp, respErr := client.Do(req) if respErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: respErr} + return nil, nil, types.NewWrappedError(errorMessage, respErr) } // Request successful, make sure body is closed at the end @@ -142,13 +142,13 @@ func HTTPMethodWithOpts( // Return a string body, readErr := ioutil.ReadAll(resp.Body) if readErr != nil { - return resp.Header, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: readErr} + return resp.Header, nil, types.NewWrappedError(errorMessage, readErr) } if resp.StatusCode < 200 || resp.StatusCode > 299 { // We make this a custom error because we want to extract the status code later statusErr := &HTTPStatusError{URL: url, Body: string(body), Status: resp.StatusCode} - return resp.Header, body, &types.WrappedErrorMessage{Message: errorMessage, Err: statusErr} + return resp.Header, body, types.NewWrappedError(errorMessage, statusErr) } // Return the body in bytes and signal the status error if there was one diff --git a/internal/log/log.go b/internal/log/log.go index 970480f..eabecb9 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -52,7 +52,7 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er configDirErr := util.EnsureDirectory(directory) if configDirErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr} + return types.NewWrappedError(errorMessage, configDirErr) } logFile, logOpenErr := os.OpenFile( logger.getFilename(directory, name), @@ -60,7 +60,7 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er 0o666, ) if logOpenErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: logOpenErr} + return types.NewWrappedError(errorMessage, logOpenErr) } log.SetOutput(logFile) logger.File = logFile @@ -68,9 +68,10 @@ func (logger *FileLogger) Init(level LogLevel, name string, directory string) er return nil } -func (logger *FileLogger) Inherit(err error, msg string) { +func (logger *FileLogger) Inherit(label string, err error) { level := types.GetErrorLevel(err) + msg := fmt.Sprintf("%s with err: %s", label, types.GetErrorTraceback(err)) switch level { case types.ERR_INFO: logger.Info(msg) diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 6ac773c..df29a9c 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -28,7 +28,7 @@ import ( func genState() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err} + return "", types.NewWrappedError("failed generating an OAuth state", err) } // For consistency we also use raw url encoding here @@ -61,10 +61,10 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &types.WrappedErrorMessage{ - Message: "failed generating an OAuth verifier", - Err: err, - } + return "", types.NewWrappedError( + "failed generating an OAuth verifier", + err, + ) } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -111,7 +111,7 @@ func (oauth *OAuth) setupListener() error { // create a listener listener, listenerErr := net.Listen("tcp", ":0") if listenerErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr} + return types.NewWrappedError(errorMessage, listenerErr) } oauth.Session.Listener = listener return nil @@ -120,7 +120,7 @@ func (oauth *OAuth) setupListener() error { func (oauth *OAuth) getTokensWithCallback() error { errorMessage := "failed getting tokens with callback" if oauth.Session.Listener == nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")} + return types.NewWrappedError(errorMessage, errors.New("No listener")) } mux := http.NewServeMux() // server /callback over the listener address @@ -130,7 +130,7 @@ func (oauth *OAuth) getTokensWithCallback() error { mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed { - return &types.WrappedErrorMessage{Message: errorMessage, Err: err} + return types.NewWrappedError(errorMessage, err) } return oauth.Session.CallbackError } @@ -146,7 +146,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { port, portErr := oauth.GetListenerPort() if portErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + return types.NewWrappedError(errorMessage, portErr) } data := url.Values{ @@ -163,7 +163,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := util.GetCurrentTime() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return types.NewWrappedError(errorMessage, bodyErr) } tokenStructure := OAuthToken{} @@ -171,10 +171,10 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, - } + return types.NewWrappedError( + errorMessage, + &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, + ) } tokenStructure.ExpiredTimestamp = current_time.Add( @@ -207,17 +207,17 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := util.GetCurrentTime() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return types.NewWrappedError(errorMessage, bodyErr) } tokenStructure := OAuthToken{} jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, - } + return types.NewWrappedError( + errorMessage, + &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, + ) } tokenStructure.ExpiredTimestamp = current_time.Add( @@ -273,7 +273,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro errorMessage := "failed writing response HTML" template, templateErr := template.New("oauth-response").Parse(responseTemplate) if templateErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: templateErr} + return types.NewWrappedError(errorMessage, templateErr) } executeErr := template.Execute(w, oauthResponseHTML{ @@ -281,7 +281,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro Message: message, }) if executeErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: executeErr} + return types.NewWrappedError(errorMessage, executeErr) } return nil } @@ -310,10 +310,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { extractedISS := urlQuery.Get("iss") if extractedISS != "" { if oauth.Session.ISS != extractedISS { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS}, + ) return } @@ -323,31 +323,31 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 extractedState := urlQuery.Get("state") if extractedState == "" { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, + ) return } // The state is the first entry if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackStateMatchError{ + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackStateMatchError{ State: extractedState, ExpectedState: oauth.Session.State, }, - } + ) return } // No authorization code extractedCode := urlQuery.Get("code") if extractedCode == "" { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, + ) return } @@ -355,10 +355,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Obtaining the access and refresh tokens getTokensErr := oauth.getTokensWithAuthCode(extractedCode) if getTokensErr != nil { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: getTokensErr, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + getTokensErr, + ) return } } @@ -372,7 +372,7 @@ func (oauth OAuth) GetListenerPort() (int, error) { errorMessage := "failed to get listener port" if oauth.Session.Listener == nil { - return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")} + return 0, types.NewWrappedError(errorMessage, errors.New("No OAuth listener")) } return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil } @@ -384,14 +384,14 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} + return "", types.NewWrappedError(errorMessage, verifierErr) } challenge := genChallengeS256(verifier) // Generate the state state, stateErr := genState() if stateErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + return "", types.NewWrappedError(errorMessage, stateErr) } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -401,13 +401,13 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str // set up the listener to get the redirect URI listenerErr := oauth.setupListener() if listenerErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + return "", types.NewWrappedError(errorMessage, stateErr) } // Get the listener port port, portErr := oauth.GetListenerPort() if portErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + return "", types.NewWrappedError(errorMessage, portErr) } parameters := map[string]string{ @@ -423,7 +423,7 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + return "", types.NewWrappedError(errorMessage, urlErr) } // Return the url processed @@ -435,16 +435,17 @@ func (oauth *OAuth) Exchange() error { tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr} + return types.NewWrappedError("failed finishing OAuth", tokenErr) } return nil } func (oauth *OAuth) Cancel() { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: "cancelled OAuth", - Err: &OAuthCancelledCallbackError{}, - } + oauth.Session.CallbackError = types.NewWrappedErrorLevel( + types.ERR_INFO, + "cancelled OAuth", + &OAuthCancelledCallbackError{}, + ) if oauth.Session.Server != nil { oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck } @@ -454,10 +455,10 @@ func (oauth *OAuth) EnsureTokens() error { errorMessage := "failed ensuring OAuth tokens" // Access Token or Refresh Tokens empty, we can not ensure the tokens if oauth.Token.Access == "" && oauth.Token.Refresh == "" { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}, - } + return types.NewWrappedError( + errorMessage, + &OAuthTokensInvalidError{Cause: "tokens are empty"}, + ) } // We have tokens... @@ -472,12 +473,12 @@ func (oauth *OAuth) EnsureTokens() error { // We have obtained new tokens with refresh if refreshErr != nil { // We have failed to ensure the tokens due to refresh not working - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthTokensInvalidError{ + return types.NewWrappedError( + errorMessage, + &OAuthTokensInvalidError{ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), }, - } + ) } return nil diff --git a/internal/server/api.go b/internal/server/api.go index 05d2528..be7281c 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -17,21 +17,21 @@ func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { errorMessage := "failed getting server endpoints" url, urlErr := url.Parse(baseURL) if urlErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + return nil, types.NewWrappedError(errorMessage, urlErr) } url.Path = path.Join(url.Path, WellKnownPath) _, body, bodyErr := httpw.HTTPGet(url.String()) if bodyErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return nil, types.NewWrappedError(errorMessage, bodyErr) } endpoints := &ServerEndpoints{} jsonErr := json.Unmarshal(body, endpoints) if jsonErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return nil, types.NewWrappedError(errorMessage, jsonErr) } return endpoints, nil @@ -51,20 +51,20 @@ func apiAuthorized( base, baseErr := server.GetBase() if baseErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return nil, nil, types.NewWrappedError(errorMessage, baseErr) } // Join the paths url, urlErr := url.Parse(base.Endpoints.API.V3.API) if urlErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + return nil, nil, types.NewWrappedError(errorMessage, urlErr) } url.Path = path.Join(url.Path, endpoint) // Make sure the tokens are valid, this will return an error if re-login is needed oauthErr := EnsureTokens(server) if oauthErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} + return nil, nil, types.NewWrappedError(errorMessage, oauthErr) } headerKey := "Authorization" @@ -95,11 +95,11 @@ func apiAuthorizedRetry( MarkTokenExpired(server) retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} + return nil, nil, types.NewWrappedError(errorMessage, retryErr) } return retryHeader, retryBody, nil } - return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return nil, nil, types.NewWrappedError(errorMessage, bodyErr) } return header, body, nil } @@ -108,19 +108,19 @@ func APIInfo(server Server) error { errorMessage := "failed API /info" _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) if bodyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return types.NewWrappedError(errorMessage, bodyErr) } structure := ServerProfileInfo{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return types.NewWrappedError(errorMessage, jsonErr) } base, baseErr := server.GetBase() if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return types.NewWrappedError(errorMessage, baseErr) } // Store the profiles and make sure that the current profile is not overwritten @@ -169,17 +169,17 @@ func APIConnectWireguard( &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}, ) if connectErr != nil { - return "", "", time.Time{}, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: connectErr, - } + return "", "", time.Time{}, types.NewWrappedError( + errorMessage, + connectErr, + ) } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} + return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) } contentType := header.Get("content-type") @@ -210,13 +210,13 @@ func APIConnectOpenVPN(server Server, profile_id string, preferTCP bool) (string &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}, ) if connectErr != nil { - return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr} + return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr) } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} + return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) } return string(connectBody), pTime, nil } diff --git a/internal/server/common.go b/internal/server/common.go index bf72bc6..bf6f4ca 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -97,10 +97,10 @@ func (servers *Servers) GetCurrentServer() (Server, error) { errorMessage := "failed getting current server" if servers.IsType == SecureInternetServerType { if !servers.HasSecureLocation() { - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerGetCurrentNotFoundError{}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerGetCurrentNotFoundError{}, + ) } return &servers.SecureInternetHomeServer, nil } @@ -113,18 +113,18 @@ func (servers *Servers) GetCurrentServer() (Server, error) { currentServerURL := serversStruct.CurrentURL bases := serversStruct.Map if bases == nil { - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerGetCurrentNoMapError{}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerGetCurrentNoMapError{}, + ) } server, exists := bases[currentServerURL] if !exists || server == nil { - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerGetCurrentNotFoundError{}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerGetCurrentNotFoundError{}, + ) } return server, nil } @@ -161,7 +161,7 @@ func (servers *Servers) addInstituteAndCustom( discoServer.SupportContact, ) if instituteInitErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} + return nil, types.NewWrappedError(errorMessage, instituteInitErr) } toAddServers.Map[url] = server servers.IsType = serverType @@ -192,7 +192,7 @@ func (servers *Servers) SetSecureLocation( _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) if addLocationErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} + return types.NewWrappedError(errorMessage, addLocationErr) } servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode @@ -209,7 +209,7 @@ func (servers *Servers) AddSecureInternet( initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) if initErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} + return nil, types.NewWrappedError(errorMessage, initErr) } servers.IsType = SecureInternetServerType @@ -255,7 +255,7 @@ func ShouldRenewButton(server Server) bool { func GetISS(server Server) (string, error) { base, baseErr := server.GetBase() if baseErr != nil { - return "", &types.WrappedErrorMessage{Message: "failed getting server ISS", Err: baseErr} + return "", types.NewWrappedError("failed getting server ISS", baseErr) } // We have already ensured that the base URL ends with a / return base.URL, nil @@ -288,7 +288,7 @@ func MarkTokensForRenew(server Server) { func EnsureTokens(server Server) error { ensureErr := server.GetOAuth().EnsureTokens() if ensureErr != nil { - return &types.WrappedErrorMessage{Message: "failed ensuring server tokens", Err: ensureErr} + return types.NewWrappedError("failed ensuring server tokens", ensureErr) } return nil } @@ -323,7 +323,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { base, baseErr := server.GetBase() if baseErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return nil, types.NewWrappedError(errorMessage, baseErr) } profileID := base.Profiles.Current for _, profile := range base.Profiles.Info.ProfileList { @@ -332,10 +332,10 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { } } - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}, + ) } func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) { @@ -343,14 +343,14 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return "", "", types.NewWrappedError(errorMessage, baseErr) } profile_id := base.Profiles.Current wireguardKey, wireguardErr := wireguard.GenerateKey() if wireguardErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: wireguardErr} + return "", "", types.NewWrappedError(errorMessage, wireguardErr) } wireguardPublicKey := wireguardKey.PublicKey().String() @@ -363,7 +363,7 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st ) if configErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", types.NewWrappedError(errorMessage, configErr) } // Store start and end time @@ -386,7 +386,7 @@ func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return "", "", types.NewWrappedError(errorMessage, baseErr) } profile_id := base.Profiles.Current configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id, preferTCP) @@ -396,7 +396,7 @@ func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { base.EndTime = expires if configErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", types.NewWrappedError(errorMessage, configErr) } return configOpenVPN, "openvpn", nil @@ -409,12 +409,12 @@ func HasValidProfile(server Server) (bool, error) { // This does not override the current profile infoErr := APIInfo(server) if infoErr != nil { - return false, &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} + return false, types.NewWrappedError(errorMessage, infoErr) } base, baseErr := server.GetBase() if baseErr != nil { - return false, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return false, types.NewWrappedError(errorMessage, baseErr) } // If there was a profile chosen and it doesn't exist anymore, reset it @@ -442,7 +442,7 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) { profile, profileErr := getCurrentProfile(server) if profileErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} + return "", "", types.NewWrappedError(errorMessage, profileErr) } supportsOpenVPN := profile.supportsOpenVPN() @@ -461,7 +461,7 @@ func GetConfig(server Server, preferTCP bool) (string, string, error) { } if configErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} + return "", "", types.NewWrappedError(errorMessage, configErr) } return config, configType, nil diff --git a/internal/server/custom.go b/internal/server/custom.go index feda1f3..6ba6503 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -11,18 +11,18 @@ func (servers *Servers) SetCustomServer(server Server) error { errorMessage := "failed setting custom server" base, baseErr := server.GetBase() if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return types.NewWrappedError(errorMessage, baseErr) } if base.Type != "custom_server" { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a custom server")} + return types.NewWrappedError(errorMessage, errors.New("Not a custom server")) } if _, ok := servers.CustomServers.Map[base.URL]; ok { servers.CustomServers.CurrentURL = base.URL servers.IsType = CustomServerType } else { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a custom server")} + return types.NewWrappedError(errorMessage, errors.New("Not a custom server")) } return nil } @@ -31,7 +31,7 @@ func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, err if server, ok := servers.CustomServers.Map[url]; ok { return server, nil } - return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No custom server with URL: %s", url)} + return nil, types.NewWrappedError("failed to get institute access server", fmt.Errorf("No custom server with URL: %s", url)) } func (servers *Servers) RemoveCustomServer(url string) { diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index bf0e2bc..0f097b0 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -26,18 +26,18 @@ func (servers *Servers) SetInstituteAccess(server Server) error { errorMessage := "failed setting institute access server" base, baseErr := server.GetBase() if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return types.NewWrappedError(errorMessage, baseErr) } if base.Type != "institute_access" { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not an institute access server")} + return types.NewWrappedError(errorMessage, errors.New("Not an institute access server")) } if _, ok := servers.InstituteServers.Map[base.URL]; ok { servers.InstituteServers.CurrentURL = base.URL servers.IsType = InstituteAccessServerType } else { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No such institute access server")} + return types.NewWrappedError(errorMessage, errors.New("No such institute access server")) } return nil } @@ -46,7 +46,7 @@ func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, if server, ok := servers.InstituteServers.Map[url]; ok { return server, nil } - return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No institute access server with URL: %s", url)} + return nil, types.NewWrappedError("failed to get institute access server", fmt.Errorf("No institute access server with URL: %s", url)) } func (servers *Servers) RemoveInstituteAccess(url string) { @@ -91,7 +91,7 @@ func (institute *InstituteAccessServer) init( institute.Base.Type = serverType endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} + return types.NewWrappedError(errorMessage, endpointsErr) } institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token) institute.Base.Endpoints = *endpoints diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 27d48a5..93e83cf 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -35,11 +35,11 @@ func (servers *Servers) SetSecureInternet(server Server) error { errorMessage := "failed setting secure internet server" base, baseErr := server.GetBase() if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return types.NewWrappedError(errorMessage, baseErr) } if base.Type != "secure_internet" { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("Not a secure internet server")} + return types.NewWrappedError(errorMessage, errors.New("Not a secure internet server")) } // The location should already be configured @@ -71,19 +71,19 @@ func (secure *SecureInternetHomeServer) GetTemplateAuth() func(string) string { func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerSecureInternetMapNotFoundError{}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerSecureInternetMapNotFoundError{}, + ) } base, exists := server.BaseMap[server.CurrentLocation] if !exists { - return nil, &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}, - } + return nil, types.NewWrappedError( + errorMessage, + &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}, + ) } return base, nil } @@ -113,7 +113,7 @@ func (secure *SecureInternetHomeServer) addLocation( base.Type = "secure_internet" endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL) if endpointsErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} + return nil, types.NewWrappedError(errorMessage, endpointsErr) } base.Endpoints = *endpoints } @@ -145,7 +145,7 @@ func (secure *SecureInternetHomeServer) init( base, baseErr := secure.addLocation(homeLocation) if baseErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} + return types.NewWrappedError(errorMessage, baseErr) } // Make sure oauth contains our endpoints diff --git a/internal/util/util.go b/internal/util/util.go index a500e15..ef52ce2 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -15,10 +15,10 @@ import ( func EnsureValidURL(s string) (string, error) { parsedURL, parseErr := url.Parse(s) if parseErr != nil { - return "", &types.WrappedErrorMessage{ - Message: fmt.Sprintf("failed parsing url: %s", s), - Err: parseErr, - } + return "", types.NewWrappedError( + fmt.Sprintf("failed parsing url: %s", s), + parseErr, + ) } if parsedURL.Scheme == "" { @@ -44,7 +44,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) { byteSlice := make([]byte, size) _, err := rand.Read(byteSlice) if err != nil { - return nil, &types.WrappedErrorMessage{Message: "failed reading random", Err: err} + return nil, types.NewWrappedError("failed reading random", err) } return byteSlice, nil } @@ -57,10 +57,10 @@ func EnsureDirectory(directory string) error { // Create with 700 permissions, read, write, execute only for the owner mkdirErr := os.MkdirAll(directory, 0o700) if mkdirErr != nil { - return &types.WrappedErrorMessage{ - Message: fmt.Sprintf("failed to create directory %s", directory), - Err: mkdirErr, - } + return types.NewWrappedError( + fmt.Sprintf("failed to create directory %s", directory), + mkdirErr, + ) } return nil } diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 43b6c74..2dd0472 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -39,7 +39,7 @@ func Verify( forcePrehash, ) if err != nil { - return valid, &types.WrappedErrorMessage{Message: "failed signature verify", Err: err} + return valid, types.NewWrappedError("failed signature verify", err) } return valid, nil } diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 3d3ae8e..0a1ba5f 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -12,10 +12,10 @@ func GenerateKey() (wgtypes.Key, error) { key, keyErr := wgtypes.GeneratePrivateKey() if keyErr != nil { - return key, &types.WrappedErrorMessage{ - Message: "failed generating WireGuard key", - Err: keyErr, - } + return key, types.NewWrappedError( + "failed generating WireGuard key", + keyErr, + ) } return key, nil } |
