From 3ac1d35257b56cca92ad0eb7f4d18abb366cf105 Mon Sep 17 00:00:00 2001 From: Aleksandar Pesic Date: Sun, 4 Dec 2022 21:48:20 +0100 Subject: simplify error handling fixes #6 Signed-off-by: Aleksandar Pesic --- internal/config/config.go | 50 +++--- internal/discovery/discovery.go | 181 ++++++--------------- internal/fsm/fsm.go | 91 +++++------ internal/http/http.go | 128 +++++---------- internal/log/log.go | 100 ++++++++---- internal/oauth/oauth.go | 307 ++++++++++++----------------------- internal/server/api.go | 214 +++++++++++-------------- internal/server/base.go | 29 ++-- internal/server/custom.go | 41 ++--- internal/server/instituteaccess.go | 85 +++++----- internal/server/secureinternet.go | 151 +++++++----------- internal/server/server.go | 318 ++++++++++++++----------------------- internal/server/servers.go | 117 ++++++-------- internal/util/util.go | 112 ++++++------- internal/verify/verify.go | 144 ++--------------- internal/verify/verify_test.go | 123 +++++++------- internal/wireguard/wireguard.go | 13 +- 17 files changed, 850 insertions(+), 1354 deletions(-) (limited to 'internal') diff --git a/internal/config/config.go b/internal/config/config.go index 6761d62..ae023ff 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,12 +4,11 @@ package config import ( "encoding/json" - "fmt" - "io/ioutil" + "os" "path" "github.com/eduvpn/eduvpn-common/internal/util" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) // Config represents a configuration that saves the client's struct as JSON. @@ -22,38 +21,41 @@ type Config struct { } // Init initializes the configuration using the provided directory and name. -func (config *Config) Init(directory string, name string) { - config.Directory = directory - config.Name = name +func (c *Config) Init(directory string, name string) { + c.Directory = directory + c.Name = name } // filename returns the filename of the configuration as a full path. -func (config *Config) filename() string { - pathString := path.Join(config.Directory, config.Name) - return fmt.Sprintf("%s.json", pathString) +func (c *Config) filename() string { + return path.Join(c.Directory, c.Name) + ".json" } // Save saves a structure 'readStruct' to the configuration -// If it was unusuccessful, an error is returned. -func (config *Config) Save(readStruct interface{}) error { - errorMessage := "failed saving configuration" - configDirErr := util.EnsureDirectory(config.Directory) - if configDirErr != nil { - return types.NewWrappedError(errorMessage, configDirErr) +// If it was unsuccessful, an error is returned. +func (c *Config) Save(readStruct interface{}) error { + if err := util.EnsureDirectory(c.Directory); err != nil { + return err + } + cfg, err := json.Marshal(readStruct) + if err != nil { + return errors.WrapPrefix(err, "json.Marshal failed", 0) } - jsonString, marshalErr := json.Marshal(readStruct) - if marshalErr != nil { - return types.NewWrappedError(errorMessage, marshalErr) + if err = os.WriteFile(c.filename(), cfg, 0o600); err != nil { + return errors.WrapPrefix(err, "os.WriteFile failed", 0) } - return ioutil.WriteFile(config.filename(), jsonString, 0o600) + return nil } // Load loads the configuration and writes the structure to 'writeStruct' // If it was unsuccessful, an error is returned. -func (config *Config) Load(writeStruct interface{}) error { - bytes, readErr := ioutil.ReadFile(config.filename()) - if readErr != nil { - return types.NewWrappedError("failed loading configuration", readErr) +func (c *Config) Load(writeStruct interface{}) error { + bts, err := os.ReadFile(c.filename()) + if err != nil { + return errors.WrapPrefix(err, "failed loading configuration", 0) + } + if err = json.Unmarshal(bts, writeStruct); err != nil { + return errors.WrapPrefix(err, "json.Unmarshal failed", 0) } - return json.Unmarshal(bytes, writeStruct) + return nil } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 35c2689..32bec66 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,4 +1,4 @@ -// package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure +// Package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure package discovery import ( @@ -9,6 +9,7 @@ import ( "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/verify" "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) // Discovery is the main structure used for this package. @@ -23,45 +24,41 @@ type Discovery struct { // discoFile is a helper function that gets a disco JSON and fills the structure with it // If it was unsuccessful it returns an error. func discoFile(jsonFile string, previousVersion uint64, structure interface{}) error { - errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data - discoURL := "https://disco.eduvpn.org/v2/" - fileURL := discoURL + jsonFile - _, fileBody, fileErr := http.Get(fileURL) - - if fileErr != nil { - return types.NewWrappedError(errorMessage, fileErr) + du := "https://disco.eduvpn.org/v2/" + fu := du + jsonFile + _, body, err := http.Get(fu) + if err != nil { + return err } // Get signature sigFile := jsonFile + ".minisig" - sigURL := discoURL + sigFile - _, sigBody, sigFileErr := http.Get(sigURL) - - if sigFileErr != nil { - return types.NewWrappedError(errorMessage, sigFileErr) + sigURL := du + sigFile + _, sigBody, err := http.Get(sigURL) + if err != nil { + return err } // Verify signature // Set this to true when we want to force prehash - forcePrehash := false - verifySuccess, verifyErr := verify.Verify( + const forcePrehash = false + ok, err := verify.Verify( string(sigBody), - fileBody, + body, jsonFile, previousVersion, forcePrehash, ) - if !verifySuccess || verifyErr != nil { - return types.NewWrappedError(errorMessage, verifyErr) + if !ok || err != nil { + return err } // Parse JSON to extract version and list - jsonErr := json.Unmarshal(fileBody, structure) - - if jsonErr != nil { - return types.NewWrappedError(errorMessage, jsonErr) + if err = json.Unmarshal(body, structure); err != nil { + return errors.WrapPrefix(err, + fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile), 0) } return nil @@ -80,86 +77,67 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { // SecureLocationList returns a slice of all the available locations. func (discovery *Discovery) SecureLocationList() []string { - var locations []string - for _, currentServer := range discovery.servers.List { - if currentServer.Type == "secure_internet" { - locations = append(locations, currentServer.CountryCode) + var loc []string + for _, srv := range discovery.servers.List { + if srv.Type == "secure_internet" { + loc = append(loc, srv.CountryCode) } } - return locations + return loc } // ServerByURL returns the discovery server by the base URL and the according type ("secure_internet", "institute_access") // An error is returned if and only if nil is returned for the server. func (discovery *Discovery) ServerByURL( baseURL string, - serverType string, + srvType string, ) (*types.DiscoveryServer, error) { for _, currentServer := range discovery.servers.List { - if currentServer.BaseURL == baseURL && currentServer.Type == serverType { + if currentServer.BaseURL == baseURL && currentServer.Type == srvType { return ¤tServer, nil } } - return nil, types.NewWrappedError( - "failed getting server by URL from discovery", - &GetServerByURLNotFoundError{URL: baseURL, Type: serverType}, - ) + return nil, errors.Errorf("no server of type '%s' at URL '%s'", srvType, baseURL) } // ServerByCountryCode returns the discovery server by the country code and the according type ("secure_internet", "institute_access") // An error is returned if and only if nil is returned for the server. -func (discovery *Discovery) ServerByCountryCode( - countryCode string, - serverType string, -) (*types.DiscoveryServer, error) { - for _, currentServer := range discovery.servers.List { - if currentServer.CountryCode == countryCode && currentServer.Type == serverType { - return ¤tServer, nil +func (discovery *Discovery) ServerByCountryCode(countryCode string, srvType string) (*types.DiscoveryServer, error) { + for _, srv := range discovery.servers.List { + if srv.CountryCode == countryCode && srv.Type == srvType { + return &srv, nil } } - return nil, types.NewWrappedError( - "failed getting server by country countryCode from discovery", - &GetServerByCountryCodeNotFoundError{CountryCode: countryCode, Type: serverType}, - ) + return nil, errors.Errorf("no server of type '%s' with country code '%s'", srvType, countryCode) } // orgByID returns the discovery organization by the organization ID // An error is returned if and only if nil is returned for the organization. func (discovery *Discovery) orgByID(orgID string) (*types.DiscoveryOrganization, error) { - for _, organization := range discovery.organizations.List { - if organization.OrgID == orgID { - return &organization, nil + for _, org := range discovery.organizations.List { + if org.OrgID == orgID { + return &org, nil } } - return nil, types.NewWrappedError( - "failed getting Secure Internet Home URL from discovery", - &GetOrgByIDNotFoundError{ID: orgID}, - ) + return nil, errors.Errorf("no secure internet home found in organization '%s'", orgID) } // SecureHomeArgs returns the secure internet home server arguments: // - The organization it belongs to // - The secure internet server itself // An error is returned if and only if nil is returned for the organization. -func (discovery *Discovery) SecureHomeArgs( - orgID string, -) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { - errorMessage := "failed getting Secure Internet Home arguments from discovery" - org, orgErr := discovery.orgByID(orgID) - - if orgErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, orgErr) +func (discovery *Discovery) SecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { + org, err := discovery.orgByID(orgID) + if err != nil { + return nil, nil, err } // Get a server with the base url - url := org.SecureInternetHome - - currentServer, serverErr := discovery.ServerByURL(url, "secure_internet") - - if serverErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, serverErr) + srv, err := discovery.ServerByURL(org.SecureInternetHome, "secure_internet") + if err != nil { + return nil, nil, err } - return org, currentServer, nil + return org, srv, nil } // DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server @@ -172,9 +150,8 @@ func (discovery *Discovery) DetermineServersUpdate() bool { return true } // 1 hour from the last update - shouldUpdateTime := discovery.servers.Timestamp.Add(1 * time.Hour) - now := time.Now() - return !now.Before(shouldUpdateTime) + upd := discovery.servers.Timestamp.Add(1 * time.Hour) + return !time.Now().Before(upd) } // Organizations returns the discovery organizations @@ -184,13 +161,10 @@ func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, erro return &discovery.organizations, nil } file := "organization_list.json" - bodyErr := discoFile(file, discovery.organizations.Version, &discovery.organizations) - if bodyErr != nil { + err := discoFile(file, discovery.organizations.Version, &discovery.organizations) + if err != nil { // Return previous with an error - return &discovery.organizations, types.NewWrappedError( - "failed getting organizations in Discovery", - bodyErr, - ) + return &discovery.organizations, err } discovery.organizations.Timestamp = time.Now() return &discovery.organizations, nil @@ -203,63 +177,12 @@ func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) { return &discovery.servers, nil } file := "server_list.json" - bodyErr := discoFile(file, discovery.servers.Version, &discovery.servers) - if bodyErr != nil { + err := discoFile(file, discovery.servers.Version, &discovery.servers) + if err != nil { // Return previous with an error - return &discovery.servers, types.NewWrappedError( - "failed getting servers in Discovery", - bodyErr, - ) + return &discovery.servers, err } // Update servers timestamp discovery.servers.Timestamp = time.Now() return &discovery.servers, nil } - -type GetOrgByIDNotFoundError struct { - ID string -} - -func (e GetOrgByIDNotFoundError) Error() string { - return fmt.Sprintf( - "No Secure Internet Home found in organizations with ID %s. Please choose your server again", - e.ID, - ) -} - -type GetServerByURLNotFoundError struct { - URL string - Type string -} - -func (e GetServerByURLNotFoundError) Error() string { - return fmt.Sprintf( - "No institute access server found in organizations with URL %s and type %s. Please choose your server again", - e.URL, - e.Type, - ) -} - -type GetServerByCountryCodeNotFoundError struct { - CountryCode string - Type string -} - -func (e GetServerByCountryCodeNotFoundError) Error() string { - return fmt.Sprintf( - "No institute access server found in organizations with country code %s and type %s", - e.CountryCode, - e.Type, - ) -} - -type GetSecureHomeArgsNotFoundError struct { - URL string -} - -func (e GetSecureHomeArgsNotFoundError) Error() string { - return fmt.Sprintf( - "No Secure Internet Home found with URL: %s. Please choose your server again", - e.URL, - ) -} diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index 5bfa712..6c8923b 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -9,7 +9,7 @@ import ( "path" "sort" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type ( @@ -97,36 +97,47 @@ func (fsm *FSM) InState(check StateID) bool { } // HasTransition checks whether or not the state machine has a transition to the given 'check' state. -func (fsm *FSM) HasTransition(check StateID) bool { - for _, transitionState := range fsm.States[fsm.Current].Transitions { - if transitionState.To == check { - return true +//func (fsm *FSM) HasTransition(check StateID) bool { +// for _, transitionState := range fsm.States[fsm.Current].Transitions { +// if transitionState.To == check { +// return true +// } +// } +// +// return false +//} + +func (fsm *FSM) CheckTransition(desired StateID) error { + for _, ts := range fsm.States[fsm.Current].Transitions { + if ts.To == desired { + return nil } } - - return false + return errors.Errorf("fsm invalid transition attempt from '%v' to '%v'", fsm.Current, desired) } // graphFilename gets the full path to the graph filename including the .graph extension. func (fsm *FSM) graphFilename(extension string) string { - debugPath := path.Join(fsm.Directory, "graph") - return fmt.Sprintf("%s%s", debugPath, extension) + pth := path.Join(fsm.Directory, "graph") + return fmt.Sprintf("%s%s", pth, extension) } // writeGraph writes the state machine to a .graph file. func (fsm *FSM) writeGraph() { - graph := fsm.GenerateGraph() - graphFile := fsm.graphFilename(".graph") - graphImgFile := fsm.graphFilename(".png") - f, err := os.Create(graphFile) + gph := fsm.GenerateGraph() + gf := fsm.graphFilename(".graph") + gif := fsm.graphFilename(".png") + f, err := os.Create(gf) if err != nil { return } + defer func() { + _ = f.Close() + }() - _, writeErr := f.WriteString(graph) - f.Close() - if writeErr != nil { - cmd := exec.Command("mmdc", "-i", graphFile, "-o", graphImgFile, "--scale", "4") + _, err = f.WriteString(gph) + if err != nil { + cmd := exec.Command("mmdc", "-i", gf, "-o", gif, "--scale", "4") // Generating is best effort _ = cmd.Start() } @@ -137,14 +148,7 @@ func (fsm *FSM) writeGraph() { func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error { oldState := fsm.Current if !fsm.GoTransitionWithData(newState, data) { - return types.NewWrappedError( - "failed required transition", - fmt.Errorf( - "required transition not handled, from: %s -> to: %s", - fsm.GetStateName(oldState), - fsm.GetStateName(newState), - ), - ) + return errors.Errorf("fsm failed transition from '%v' to '%v'", oldState, newState) } return nil } @@ -152,20 +156,17 @@ func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error { // GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data' // It returns whether or not the transition is handled by the client. func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) bool { - ok := fsm.HasTransition(newState) - - handled := false - if ok { - oldState := fsm.Current - fsm.Current = newState - if fsm.Generate { - fsm.writeGraph() - } + if fsm.CheckTransition(newState) != nil { + return false + } - handled = fsm.StateCallback(oldState, newState, data) + prev := fsm.Current + fsm.Current = newState + if fsm.Generate { + fsm.writeGraph() } - return handled + return fsm.StateCallback(prev, newState, data) } // GoTransition is an alias to call GoTransitionWithData but have an empty string as data. @@ -177,21 +178,21 @@ func (fsm *FSM) GoTransition(newState StateID) bool { // generateMermaidGraph generates a graph suitable to be converted by the mermaid.js tool // it returns the graph as a string. func (fsm *FSM) generateMermaidGraph() string { - graph := "graph TD\n" - sortedFSM := make(StateIDSlice, 0, len(fsm.States)) + gph := "graph TD\n" + sf := make(StateIDSlice, 0, len(fsm.States)) for stateID := range fsm.States { - sortedFSM = append(sortedFSM, stateID) + sf = append(sf, stateID) } - sort.Sort(sortedFSM) - for _, state := range sortedFSM { + sort.Sort(sf) + for _, state := range sf { transitions := fsm.States[state].Transitions for _, transition := range transitions { if state == fsm.Current { - graph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n" + gph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n" } else { - graph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n" + gph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n" } - graph += fsm.GetStateName( + gph += fsm.GetStateName( state, ) + "(" + fsm.GetStateName( state, @@ -200,7 +201,7 @@ func (fsm *FSM) generateMermaidGraph() string { ) + "\n" } } - return graph + return gph } // GenerateGraph generates a mermaid graph if the state machine is initialized diff --git a/internal/http/http.go b/internal/http/http.go index 1d0ec45..e1f9bdf 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -4,16 +4,15 @@ package http import ( "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" "time" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -// The URLParemeters as the name suggests is a type used for the parameters in the URL. +// URLParameters is a type used for the parameters in the URL. type URLParameters map[string]string // OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call. @@ -25,26 +24,20 @@ type OptionalParams struct { } // ConstructURL creates a URL with the included parameters. -func ConstructURL(baseURL string, parameters URLParameters) (string, error) { - url, parseErr := url.Parse(baseURL) - if parseErr != nil { - return "", types.NewWrappedError( - fmt.Sprintf( - "failed to construct url: %s including parameters: %v", - url, - parameters, - ), - parseErr, - ) +func ConstructURL(baseURL string, params URLParameters) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", errors.WrapPrefix(err, + fmt.Sprintf("failed to construct url '%s' with parameters: %v", u, params), 0) } - q := url.Query() + q := u.Query() - for parameter, value := range parameters { - q.Set(parameter, value) + for p, value := range params { + q.Set(p, value) } - url.RawQuery = q.Encode() - return url.String(), nil + u.RawQuery = q.Encode() + return u.String(), nil } // Get creates a Get request and returns the headers, body and an error. @@ -52,16 +45,6 @@ func Get(url string) (http.Header, []byte, error) { return MethodWithOpts(http.MethodGet, url, nil) } -// Post creates a Post request and returns the headers, body and an error. -func Post(url string, body url.Values) (http.Header, []byte, error) { - return MethodWithOpts(http.MethodGet, url, &OptionalParams{Body: body}) -} - -// GetWithOpts creates a Get request with optional parameters and returns the headers, body and an error. -func GetWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { - return MethodWithOpts(http.MethodGet, url, opts) -} - // PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error. func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { return MethodWithOpts(http.MethodPost, url, opts) @@ -69,19 +52,12 @@ func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) // optionalURL ensures that the URL contains the optional parameters // it returns the url (with parameters if success) and an error indicating success. -func optionalURL(url string, opts *OptionalParams) (string, error) { - if opts != nil { - url, urlErr := ConstructURL(url, opts.URLParameters) - - if urlErr != nil { - return url, types.NewWrappedError( - fmt.Sprintf("failed to create HTTP request with url: %s", url), - urlErr, - ) - } - return url, nil +func optionalURL(urlStr string, opts *OptionalParams) (string, error) { + if opts == nil { + return urlStr, nil } - return url, nil + + return ConstructURL(urlStr, opts.URLParameters) } // optionalHeaders ensures that the HTTP request uses the optional headers if defined. @@ -106,16 +82,16 @@ func optionalBodyReader(opts *OptionalParams) io.Reader { // It returns the HTTP headers, the body and an error if there is one. func MethodWithOpts( method string, - url string, + urlStr string, opts *OptionalParams, ) (http.Header, []byte, error) { // Make sure the url contains all the parameters // This can return an error, - // it already has the right error so we don't wrap it further - url, urlErr := optionalURL(url, opts) - if urlErr != nil { + // it already has the right error, so we don't wrap it further + urlStr, err := optionalURL(urlStr, opts) + if err != nil { // No further type wrapping is needed here - return nil, nil, urlErr + return nil, nil, err } // Default timeout is 5 seconds @@ -125,15 +101,11 @@ func MethodWithOpts( timeout = opts.Timeout } - // Create a client - client := &http.Client{Timeout: timeout * time.Second} - - errorMessage := fmt.Sprintf("failed HTTP request with method %s and url %s", method, url) - // Create request object with the body reader generated from the optional arguments - req, reqErr := http.NewRequest(method, url, optionalBodyReader(opts)) - if reqErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, reqErr) + req, err := http.NewRequest(method, urlStr, optionalBodyReader(opts)) + if err != nil { + return nil, nil, errors.WrapPrefix(err, + fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0) } // See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi @@ -142,29 +114,34 @@ func MethodWithOpts( // Make sure the headers contain all the parameters optionalHeaders(req, opts) + // Create a client + c := &http.Client{Timeout: timeout * time.Second} + // Do request - resp, respErr := client.Do(req) - if respErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, respErr) + res, err := c.Do(req) + if err != nil { + return nil, nil, errors.WrapPrefix(err, + fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0) } // Request successful, make sure body is closed at the end - defer resp.Body.Close() + defer func() { + _ = res.Body.Close() + }() // Return a string - body, readErr := ioutil.ReadAll(resp.Body) - if readErr != nil { - return resp.Header, nil, types.NewWrappedError(errorMessage, readErr) + body, err := io.ReadAll(res.Body) + if err != nil { + return res.Header, nil, errors.WrapPrefix(err, + fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0) } - if resp.StatusCode < 200 || resp.StatusCode > 299 { - // We make this a custom error because we want to extract the status code later - statusErr := &StatusError{URL: url, Body: string(body), Status: resp.StatusCode} - return resp.Header, body, types.NewWrappedError(errorMessage, statusErr) + if res.StatusCode < 200 || res.StatusCode > 299 { + return res.Header, body, errors.Wrap(&StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}, 0) } // Return the body in bytes and signal the status error if there was one - return resp.Header, body, nil + return res.Header, body, nil } // StatusError indicates that we have received a HTTP status error. @@ -183,22 +160,3 @@ func (e *StatusError) Error() string { e.Body, ) } - -// ParseJSONError indicates that the HTTP error is because of failed JSON parsing -// It has the URL and the Body as context. -// The underlying JSON parsing Err itself is also wrapped here. -type ParseJSONError struct { - URL string - Body string - Err error -} - -// Error returns the ParseJSONError as an error string. -func (e *ParseJSONError) Error() string { - return fmt.Sprintf( - "failed parsing json %s for HTTP resource: %s with error: %v", - e.Body, - e.URL, - e.Err, - ) -} diff --git a/internal/log/log.go b/internal/log/log.go index 3c3218c..99d9f79 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -3,15 +3,49 @@ package log import ( "fmt" + "github.com/eduvpn/eduvpn-common/internal/oauth" "io" "log" "os" "path" "github.com/eduvpn/eduvpn-common/internal/util" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) +type ErrLevel int8 + +const ( + ErrOther ErrLevel = iota + ErrInfo + ErrWarning + ErrFatal +) + +func GetErrorLevel(err error) ErrLevel { + if err == nil { + return ErrOther + } + + getLevel := func(e error) ErrLevel { + if e == nil { + return ErrOther + } + + switch e.(type) { + case *oauth.CancelledCallbackError: + return ErrInfo + default: + return ErrOther + } + } + + if err1, ok := err.(*errors.Error); ok { + return getLevel(err1.Err) + } + return getLevel(err) +} + // FileLogger defines the type of logger that this package implements // As the name suggests, it saves the log to a file. type FileLogger struct { @@ -66,42 +100,40 @@ func (e Level) String() string { // Init initializes the logger by forwarding a max level 'level' and a directory 'directory' where the log should be stored // If the logger cannot be initialized, for example an error in opening the log file, an error is returned. -func (logger *FileLogger) Init(level Level, directory string) error { - errorMessage := "failed creating log" - - configDirErr := util.EnsureDirectory(directory) - if configDirErr != nil { - return types.NewWrappedError(errorMessage, configDirErr) +func (logger *FileLogger) Init(lvl Level, dir string) error { + err := util.EnsureDirectory(dir) + if err != nil { + return err } - logFile, logOpenErr := os.OpenFile( - logger.filename(directory), + f, err := os.OpenFile( + logger.filename(dir), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666, ) - if logOpenErr != nil { - return types.NewWrappedError(errorMessage, logOpenErr) + if err != nil { + return errors.WrapPrefix(err, "failed creating log", 0) } - multi := io.MultiWriter(os.Stdout, logFile) + multi := io.MultiWriter(os.Stdout, f) log.SetOutput(multi) - logger.file = logFile - logger.Level = level + logger.file = f + logger.Level = lvl return nil } // Inherit logs an error with a label using the error level of the error. -func (logger *FileLogger) Inherit(label string, err error) { - level := types.ErrorLevel(err) - - msg := fmt.Sprintf("%s with err: %s", label, types.ErrorTraceback(err)) - switch level { - case types.ErrInfo: - logger.Infof(msg) - case types.ErrWarning: - logger.Warningf(msg) - case types.ErrOther: - logger.Errorf(msg) - case types.ErrFatal: - logger.Fatalf(msg) +func (logger *FileLogger) Inherit(err error) { + if err == nil { + return + } + switch GetErrorLevel(err) { + case ErrInfo: + logger.Infof(err.Error()) + case ErrWarning: + logger.Warningf(err.Error()) + case ErrOther: + logger.Errorf(err.Error()) + case ErrFatal: + logger.Fatalf(err.Error()) } } @@ -131,8 +163,8 @@ func (logger *FileLogger) Fatalf(msg string, params ...interface{}) { } // Close closes the logger by closing the internal file. -func (logger *FileLogger) Close() { - logger.file.Close() +func (logger *FileLogger) Close() error { + return logger.file.Close() } // filename returns the filename of the logger by returning the full path as a string. @@ -141,11 +173,11 @@ func (logger *FileLogger) filename(directory string) string { } // log logs as level 'level' a message 'msg' with parameters 'params'. -func (logger *FileLogger) log(level Level, msg string, params ...interface{}) { - if level >= logger.Level && logger.Level != LevelNotSet { - formattedMsg := fmt.Sprintf(msg, params...) - format := fmt.Sprintf("- Go - %s - %s", level.String(), formattedMsg) +func (logger *FileLogger) log(lvl Level, msg string, params ...interface{}) { + if lvl >= logger.Level && logger.Level != LevelNotSet { + fMsg := fmt.Sprintf(msg, params...) + f := fmt.Sprintf("- Go - %s - %s", lvl.String(), fMsg) // To log file - log.Println(format) + log.Println(f) } } diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 802295d..3dcd3d3 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -1,4 +1,4 @@ -// package oauth implement an oauth client defined in e.g. rfc 6749 +// Package oauth implement an oauth client defined in e.g. rfc 6749 // However, we try to follow some recommendations from the v2.1 oauth draft RFC // Some specific things we implement here: // - PKCE (RFC 7636) @@ -10,7 +10,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "html/template" "net" @@ -20,7 +19,7 @@ import ( httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/util" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) // genState generates a random base64 string to be used for state @@ -31,13 +30,13 @@ import ( // client. // We implement it similarly to the verifier. func genState() (string, error) { - randomBytes, err := util.MakeRandomByteSlice(32) + bts, err := util.MakeRandomByteSlice(32) if err != nil { - return "", types.NewWrappedError("failed generating an OAuth state", err) + return "", err } - // For consistency we also use raw url encoding here - return base64.RawURLEncoding.EncodeToString(randomBytes), nil + // For consistency, we also use raw url encoding here + return base64.RawURLEncoding.EncodeToString(bts), nil } // genChallengeS256 generates a sha256 base64 challenge from a verifier @@ -68,10 +67,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", types.NewWrappedError( - "failed generating an OAuth verifier", - err, - ) + return "", err } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -89,10 +85,10 @@ type OAuth struct { TokenURL string `json:"token_url"` // session is the internal in progress OAuth session - session ExchangeSession `json:"-"` + session ExchangeSession // Token is where the access and refresh tokens are stored along with the timestamps - token Token `json:"-"` + token Token } // ExchangeSession is a structure that gets passed to the callback for easy access to the current state. @@ -126,39 +122,31 @@ type ExchangeSession struct { // It returns the access token as a string, possibly obtained fresh using the Refresh Token // If the token cannot be obtained, an error is returned and the token is an empty string. func (oauth *OAuth) AccessToken() (string, error) { - errorMessage := "failed getting access token" - tokens := oauth.token + ts := oauth.token // We have tokens... // The tokens are not expired yet - // So they should be valid, re-authorization not needed - if !tokens.Expired() { - return tokens.access, nil + // So they should be valid, re-login not needed + if !ts.Expired() { + return ts.access, nil } // Check if refresh is even possible by doing a simple check if the refresh token is empty // This is not needed but reduces API calls to the server - if tokens.refresh == "" { - return "", types.NewWrappedError( - errorMessage, - &TokensInvalidError{Cause: "no refresh token is present"}, - ) + if ts.refresh == "" { + return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) } // Otherwise refresh and then later return the access token if we are successful - refreshErr := oauth.tokensWithRefresh() - if refreshErr != nil { + err := oauth.tokensWithRefresh() + if err != nil { // We have failed to ensure the tokens due to refresh not working - return "", types.NewWrappedError( - errorMessage, - &TokensInvalidError{ - Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), - }, - ) + return "", errors.Wrap( + &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0) } // We have obtained new tokens with refresh - return tokens.access, nil + return ts.access, nil } // setupListener sets up an OAuth listener @@ -166,24 +154,22 @@ func (oauth *OAuth) AccessToken() (string, error) { // @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2 // "Loopback Interface Redirection". func (oauth *OAuth) setupListener() error { - errorMessage := "failed setting up listener" oauth.session.Context = context.Background() // create a listener - listener, listenerErr := net.Listen("tcp", "127.0.0.1:0") - if listenerErr != nil { - return types.NewWrappedError(errorMessage, listenerErr) + lst, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return errors.WrapPrefix(err, "net.Listen failed", 0) } - oauth.session.Listener = listener + oauth.session.Listener = lst return nil } // tokensWithCallback gets the OAuth tokens using a local web server // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithCallback() error { - errorMessage := "failed getting tokens with callback" if oauth.session.Listener == nil { - return types.NewWrappedError(errorMessage, errors.New("no listener")) + return errors.Errorf("failed getting tokens with callback: no listener") } mux := http.NewServeMux() // server /callback over the listener address @@ -196,7 +182,7 @@ func (oauth *OAuth) tokensWithCallback() error { mux.HandleFunc("/callback", oauth.Callback) if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed { - return types.NewWrappedError(errorMessage, err) + return errors.WrapPrefix(err, "failed getting tokens with callback", 0) } return oauth.session.CallbackError } @@ -205,23 +191,18 @@ func (oauth *OAuth) tokensWithCallback() error { // It calculates the expired timestamp by having a 'startTime' passed to it // The URL that is input here is used for additional context. func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error { - responseStructure := TokenResponse{} - - jsonErr := json.Unmarshal(response, &responseStructure) - if jsonErr != nil { - return types.NewWrappedError( - "failed filling OAuth tokens", - &httpw.ParseJSONError{URL: url, Body: string(response), Err: jsonErr}, - ) - } - - internalStructure := Token{} - internalStructure.expiredTimestamp = startTime.Add( - time.Second * time.Duration(responseStructure.Expires), - ) - internalStructure.access = responseStructure.Access - internalStructure.refresh = responseStructure.Refresh - oauth.token = internalStructure + res := TokenResponse{} + + err := json.Unmarshal(response, &res) + if err != nil { + return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0) + } + + oauth.token = Token{ + access: res.Access, + refresh: res.Refresh, + expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)), + } return nil } @@ -240,14 +221,13 @@ func (oauth *OAuth) SetTokenRenew() { // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithAuthCode(authCode string) error { - errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code - reqURL := oauth.TokenURL + u := oauth.TokenURL - port, portErr := oauth.ListenerPort() - if portErr != nil { - return types.NewWrappedError(errorMessage, portErr) + port, err := oauth.ListenerPort() + if err != nil { + return err } data := url.Values{ @@ -257,21 +237,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { "grant_type": {"authorization_code"}, "redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)}, } - headers := http.Header{ + h := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.OptionalParams{Headers: headers, Body: data} - currentTime := time.Now() - _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + opts := &httpw.OptionalParams{Headers: h, Body: data} + now := time.Now() + _, body, err := httpw.PostWithOpts(u, opts) + if err != nil { + return err } - fillErr := oauth.fillToken(body, currentTime, reqURL) - if fillErr != nil { - return types.NewWrappedError(errorMessage, fillErr) - } - return nil + return oauth.fillToken(body, now, u) } // tokensWithRefresh gets the access and refresh tokens with a previously received refresh token @@ -279,27 +255,22 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithRefresh() error { - errorMessage := "failed getting tokens with the refresh token" - reqURL := oauth.TokenURL + u := oauth.TokenURL data := url.Values{ "refresh_token": {oauth.token.refresh}, "grant_type": {"refresh_token"}, } - headers := http.Header{ + h := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.OptionalParams{Headers: headers, Body: data} - currentTime := time.Now() - _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + opts := &httpw.OptionalParams{Headers: h, Body: data} + now := time.Now() + _, body, err := httpw.PostWithOpts(u, opts) + if err != nil { + return err } - fillErr := oauth.fillToken(body, currentTime, reqURL) - if fillErr != nil { - return types.NewWrappedError(errorMessage, fillErr) - } - return nil + return oauth.fillToken(body, now, u) } // responseTemplate is the HTML template for the OAuth authorized response @@ -349,27 +320,17 @@ type oauthResponseHTML struct { // writeResponseHTML writes the OAuth response using a response writer and the title + message // If it was unsuccessful it returns an error. func writeResponseHTML(w http.ResponseWriter, title string, message string) error { - errorMessage := "failed writing response HTML" - template, templateErr := template.New("oauth-response").Parse(responseTemplate) - if templateErr != nil { - return types.NewWrappedError(errorMessage, templateErr) + t, err := template.New("oauth-response").Parse(responseTemplate) + if err != nil { + return errors.WrapPrefix(err, "failed writing response HTML", 0) } - executeErr := template.Execute(w, oauthResponseHTML{ - Title: title, - Message: message, - }) - if executeErr != nil { - return types.NewWrappedError(errorMessage, executeErr) - } - return nil + return t.Execute(w, oauthResponseHTML{Title: title, Message: message}) } // Callback is the public function used to get the OAuth tokens using an authorization code callback // The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { - errorMessage := "failed callback to retrieve the authorization code" - // Shutdown after we're done defer func() { // writing the html is best effort @@ -383,64 +344,49 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") } if oauth.session.Server != nil { - go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + go func() { + _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + }() } }() // ISS: https://www.rfc-editor.org/rfc/rfc9207.html // TODO: Make this a required parameter in the future - urlQuery := req.URL.Query() - extractedISS := urlQuery.Get("iss") - if extractedISS != "" { - if oauth.session.ISS != extractedISS { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS}, - ) + q := req.URL.Query() + iss := q.Get("iss") + if iss != "" { + if oauth.session.ISS != iss { + oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'", + oauth.session.ISS, iss) return } } // Make sure the state is present and matches to protect against cross-site request forgeries // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 - extractedState := urlQuery.Get("state") - if extractedState == "" { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackParameterError{Parameter: "state", URL: req.URL.String()}, - ) + state := q.Get("state") + if state == "" { + oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL) return } // The state is the first entry - if extractedState != oauth.session.State { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackStateMatchError{ - State: extractedState, - ExpectedState: oauth.session.State, - }, - ) + if state != oauth.session.State { + oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'", + oauth.session.State, state) return } // No authorization code - extractedCode := urlQuery.Get("code") - if extractedCode == "" { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackParameterError{Parameter: "code", URL: req.URL.String()}, - ) + code := q.Get("code") + if code == "" { + oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL) return } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - getTokensErr := oauth.tokensWithAuthCode(extractedCode) - if getTokensErr != nil { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - getTokensErr, - ) + if err := oauth.tokensWithAuthCode(code); err != nil { + oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0) return } } @@ -457,94 +403,78 @@ func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL strin // ListenerPort gets the listener for the OAuth web server // It returns the port as an integer and an error if there is any. -func (oauth OAuth) ListenerPort() (int, error) { - errorMessage := "failed to get listener port" - +func (oauth *OAuth) ListenerPort() (int, error) { if oauth.session.Listener == nil { - return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener")) + return 0, errors.Errorf("failed to get listener port") } return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil } // AuthURL gets the authorization url to start the OAuth procedure. func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (string, error) { - errorMessage := "failed starting OAuth exchange" - // Generate the verifier and challenge - verifier, verifierErr := genVerifier() - if verifierErr != nil { - return "", types.NewWrappedError(errorMessage, verifierErr) + v, err := genVerifier() + if err != nil { + return "", errors.WrapPrefix(err, "genVerifier error", 0) } - challenge := genChallengeS256(verifier) // Generate the state - state, stateErr := genState() - if stateErr != nil { - return "", types.NewWrappedError(errorMessage, stateErr) + state, err := genState() + if err != nil { + return "", errors.WrapPrefix(err, "genState error", 0) } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - oauthSession := ExchangeSession{ + oauth.session = ExchangeSession{ ClientID: name, ISS: oauth.ISS, State: state, - Verifier: verifier, + Verifier: v, } - oauth.session = oauthSession // set up the listener to get the redirect URI - listenerErr := oauth.setupListener() - if listenerErr != nil { - return "", types.NewWrappedError(errorMessage, stateErr) + if err = oauth.setupListener(); err != nil { + return "", errors.WrapPrefix(err, "oauth.setupListener error", 0) } // Get the listener port - port, portErr := oauth.ListenerPort() - if portErr != nil { - return "", types.NewWrappedError(errorMessage, portErr) + port, err := oauth.ListenerPort() + if err != nil { + return "", errors.WrapPrefix(err, "oauth.ListenerPort error", 0) } - parameters := map[string]string{ + params := map[string]string{ "client_id": name, "code_challenge_method": "S256", - "code_challenge": challenge, + "code_challenge": genChallengeS256(v), "response_type": "code", "scope": "config", "state": state, "redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port), } - authURL, urlErr := httpw.ConstructURL(oauth.BaseAuthorizationURL, parameters) + u, err := httpw.ConstructURL(oauth.BaseAuthorizationURL, params) - if urlErr != nil { - return "", types.NewWrappedError(errorMessage, urlErr) + if err != nil { + return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0) } // Return the url processed - return postProcessAuth(authURL), nil + return postProcessAuth(u), nil } // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error. func (oauth *OAuth) Exchange() error { - tokenErr := oauth.tokensWithCallback() - - if tokenErr != nil { - return types.NewWrappedError("failed finishing OAuth", tokenErr) - } - return nil + return oauth.tokensWithCallback() } // Cancel cancels the existing OAuth // TODO: Use context for this. func (oauth *OAuth) Cancel() { - oauth.session.CallbackError = types.NewWrappedErrorLevel( - types.ErrInfo, - "cancelled OAuth", - &CancelledCallbackError{}, - ) + oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0) if oauth.session.Server != nil { - oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck } } @@ -554,33 +484,6 @@ func (e *CancelledCallbackError) Error() string { return "client cancelled OAuth" } -type CallbackParameterError struct { - Parameter string - URL string -} - -func (e *CallbackParameterError) Error() string { - return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) -} - -type CallbackStateMatchError struct { - State string - ExpectedState string -} - -func (e *CallbackStateMatchError) Error() string { - return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) -} - -type CallbackISSMatchError struct { - ISS string - ExpectedISS string -} - -func (e *CallbackISSMatchError) Error() string { - return fmt.Sprintf("failed matching ISS, got: %s, want: %s", e.ISS, e.ExpectedISS) -} - type TokensInvalidError struct { Cause string } diff --git a/internal/server/api.go b/internal/server/api.go index 21ba6f4..dfa8e14 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -10,130 +9,114 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) func APIGetEndpoints(baseURL string) (*Endpoints, error) { - errorMessage := "failed getting server endpoints" - url, urlErr := url.Parse(baseURL) - if urlErr != nil { - return nil, types.NewWrappedError(errorMessage, urlErr) + u, err := url.Parse(baseURL) + if err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - wellKnownPath := "/.well-known/vpn-user-portal" + wk := "/.well-known/vpn-user-portal" - url.Path = path.Join(url.Path, wellKnownPath) - _, body, bodyErr := httpw.Get(url.String()) - - if bodyErr != nil { - return nil, types.NewWrappedError(errorMessage, bodyErr) + u.Path = path.Join(u.Path, wk) + _, body, err := httpw.Get(u.String()) + if err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - endpoints := &Endpoints{} - jsonErr := json.Unmarshal(body, endpoints) - - if jsonErr != nil { - return nil, types.NewWrappedError(errorMessage, jsonErr) + ep := &Endpoints{} + if err = json.Unmarshal(body, ep); err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - return endpoints, nil + return ep, nil } func apiAuthorized( - server Server, + srv Server, method string, endpoint string, opts *httpw.OptionalParams, ) (http.Header, []byte, error) { - errorMessage := "failed API authorized" // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &httpw.OptionalParams{} } - base, baseErr := server.Base() - - if baseErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } // Join the paths - url, urlErr := url.Parse(base.Endpoints.API.V3.API) - if urlErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, urlErr) + u, err := url.Parse(b.Endpoints.API.V3.API) + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } - url.Path = path.Join(url.Path, endpoint) + u.Path = path.Join(u.Path, endpoint) // Make sure the tokens are valid, this will return an error if re-login is needed - token, tokenErr := HeaderToken(server) - if tokenErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, tokenErr) + t, err := HeaderToken(srv) + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } - headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", token) + key := "Authorization" + val := fmt.Sprintf("Bearer %s", t) if opts.Headers != nil { - opts.Headers.Add(headerKey, headerValue) + opts.Headers.Add(key, val) } else { - opts.Headers = http.Header{headerKey: {headerValue}} + opts.Headers = http.Header{key: {val}} } - return httpw.MethodWithOpts(method, url.String(), opts) + return httpw.MethodWithOpts(method, u.String(), opts) } func apiAuthorizedRetry( - server Server, + srv Server, method string, endpoint string, opts *httpw.OptionalParams, ) (http.Header, []byte, error) { - errorMessage := "failed authorized API retry" - header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) - - if bodyErr != nil { - var error *httpw.StatusError - - // Only retry authorized if we get a HTTP 401 - if errors.As(bodyErr, &error) && error.Status == 401 { - // Mark the token as expired and retry so we trigger the refresh flow - MarkTokenExpired(server) - retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) - if retryErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, retryErr) - } - return retryHeader, retryBody, nil - } - return nil, nil, types.NewWrappedError(errorMessage, bodyErr) - } - return header, body, nil -} - -func APIInfo(server Server) error { - errorMessage := "failed API /info" - _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + h, body, err := apiAuthorized(srv, method, endpoint, opts) + if err == nil { + return h, body, nil } - structure := ProfileInfo{} - jsonErr := json.Unmarshal(body, &structure) - if jsonErr != nil { - return types.NewWrappedError(errorMessage, jsonErr) + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + if errors.As(err, &statErr) && statErr.Status == 401 { + // Mark the token as expired and retry, so we trigger the refresh flow + MarkTokenExpired(srv) + h, body, err = apiAuthorized(srv, method, endpoint, opts) } + return h, body, err +} - base, baseErr := server.Base() +func APIInfo(srv Server) error { + _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil) + if err != nil { + return err + } + pi := ProfileInfo{} + if err = json.Unmarshal(body, &pi); err != nil { + return errors.WrapPrefix(err, "failed API /info", 0) + } - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return err } // Store the profiles and make sure that the current profile is not overwritten - previousProfile := base.Profiles.Current - base.Profiles = structure - base.Profiles.Current = previousProfile + prev := b.Profiles.Current + b.Profiles = pi + b.Profiles.Current = prev return nil } // see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 -func GetPreferTCPString(preferTCP bool) string { +func boolToYesNo(preferTCP bool) string { if preferTCP { return "yes" } @@ -141,88 +124,77 @@ func GetPreferTCPString(preferTCP bool) string { } func APIConnectWireguard( - server Server, + srv Server, profileID string, pubkey string, preferTCP bool, - supportsOpenVPN bool, + openVPNSupport bool, ) (string, string, time.Time, error) { - errorMessage := "failed obtaining a WireGuard configuration" - headers := http.Header{ + hdrs := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, } // This profile also supports OpenVPN // Indicate that we also accept OpenVPN profiles - if supportsOpenVPN { - headers.Add("accept", "application/x-openvpn-profile") + if openVPNSupport { + hdrs.Add("accept", "application/x-openvpn-profile") } - urlForm := url.Values{ + vals := url.Values{ "profile_id": {profileID}, "public_key": {pubkey}, - "prefer_tcp": {GetPreferTCPString(preferTCP)}, + "prefer_tcp": {boolToYesNo(preferTCP)}, } - header, connectBody, connectErr := apiAuthorizedRetry( - server, - http.MethodPost, - "/connect", - &httpw.OptionalParams{Headers: headers, Body: urlForm}, - ) - if connectErr != nil { - return "", "", time.Time{}, types.NewWrappedError( - errorMessage, - connectErr, - ) + h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", "", time.Time{}, err } - expires := header.Get("expires") + exp := h.Get("expires") - pTime, pTimeErr := http.ParseTime(expires) - if pTimeErr != nil { - return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) + ptm, err := http.ParseTime(exp) + if err != nil { + return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0) } - contentType := header.Get("content-type") - - content := "openvpn" - if contentType == "application/x-wireguard-profile" { - content = "wireguard" + ct := h.Get("content-type") + c := "openvpn" + if ct == "application/x-wireguard-profile" { + c = "wireguard" } - return string(connectBody), content, pTime, nil + + return string(body), c, ptm, nil } -func APIConnectOpenVPN(server Server, profileID string, preferTCP bool) (string, time.Time, error) { - errorMessage := "failed obtaining an OpenVPN configuration" - headers := http.Header{ +func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) { + hdrs := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, } - urlForm := url.Values{ + vals := url.Values{ "profile_id": {profileID}, - "prefer_tcp": {GetPreferTCPString(preferTCP)}, + "prefer_tcp": {boolToYesNo(preferTCP)}, } - header, connectBody, connectErr := apiAuthorizedRetry( - server, - http.MethodPost, - "/connect", - &httpw.OptionalParams{Headers: headers, Body: urlForm}, - ) - if connectErr != nil { - return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr) + h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", time.Time{}, err } - expires := header.Get("expires") - pTime, pTimeErr := http.ParseTime(expires) - if pTimeErr != nil { - return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) + exp := h.Get("expires") + ptm, err := http.ParseTime(exp) + if err != nil { + return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0) } - return string(connectBody), pTime, nil + + return string(body), ptm, nil } +// APIDisconnect disconnects from the API. // This needs no further return value as it's best effort. func APIDisconnect(server Server) { _, _, _ = apiAuthorized(server, http.MethodPost, "/disconnect", nil) diff --git a/internal/server/base.go b/internal/server/base.go index bb88eb3..81049cf 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -2,11 +2,9 @@ package server import ( "time" - - "github.com/eduvpn/eduvpn-common/types" ) -// The base type for servers. +// Base is the base type for servers. type Base struct { URL string `json:"base_url"` DisplayName map[string]string `json:"display_name"` @@ -18,28 +16,27 @@ type Base struct { Type string `json:"server_type"` } -func (base *Base) InitializeEndpoints() error { - errorMessage := "failed initializing endpoints" - endpoints, endpointsErr := APIGetEndpoints(base.URL) - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) +func (b *Base) InitializeEndpoints() error { + ep, err := APIGetEndpoints(b.URL) + if err != nil { + return err } - base.Endpoints = *endpoints + b.Endpoints = *ep return nil } -func (base *Base) ValidProfiles(clientSupportsWireguard bool) ProfileInfo { - var validProfiles []Profile - for _, profile := range base.Profiles.Info.ProfileList { +func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo { + var vps []Profile + for _, p := range b.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn // Also the client does not support wireguard - if !profile.supportsOpenVPN() && !clientSupportsWireguard { + if !p.supportsOpenVPN() && !wireguardSupport { continue } - validProfiles = append(validProfiles, profile) + vps = append(vps, p) } return ProfileInfo{ - Current: base.Profiles.Current, - Info: ProfileListInfo{ProfileList: validProfiles}, + Current: b.Profiles.Current, + Info: ProfileListInfo{ProfileList: vps}, } } diff --git a/internal/server/custom.go b/internal/server/custom.go index d376727..bf0b230 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -1,42 +1,35 @@ package server import ( - "errors" - "fmt" - - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -func (servers *Servers) SetCustomServer(server Server) error { - errorMessage := "failed setting custom server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetCustomServer(server Server) error { + b, err := server.Base() + if err != nil { + return err } - if base.Type != "custom_server" { - return types.NewWrappedError(errorMessage, errors.New("not a custom server")) + if b.Type != "custom_server" { + return errors.WrapPrefix(err, "not a custom server", 0) } - if _, ok := servers.CustomServers.Map[base.URL]; ok { - servers.CustomServers.CurrentURL = base.URL - servers.IsType = CustomServerType + if _, ok := ss.CustomServers.Map[b.URL]; ok { + ss.CustomServers.CurrentURL = b.URL + ss.IsType = CustomServerType } else { - return types.NewWrappedError(errorMessage, errors.New("not a custom server")) + return errors.Errorf("not a custom server") } return nil } -func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { - if server, ok := servers.CustomServers.Map[url]; ok { - return server, nil +func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { + if srv, ok := ss.CustomServers.Map[url]; ok { + return srv, nil } - return nil, types.NewWrappedError( - "failed to get institute access server", - fmt.Errorf("no custom server with URL: %s", url), - ) + return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) } -func (servers *Servers) RemoveCustomServer(url string) { - servers.CustomServers.Remove(url) +func (ss *Servers) RemoveCustomServer(url string) { + ss.CustomServers.Remove(url) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index 9b6f735..56ed1cf 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -1,14 +1,10 @@ package server import ( - "errors" - "fmt" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -// An instute access server. type InstituteAccessServer struct { // An instute access server has its own OAuth Auth oauth.OAuth `json:"oauth"` @@ -22,80 +18,75 @@ type InstituteAccessServers struct { CurrentURL string `json:"current_url"` } -func (servers *Servers) SetInstituteAccess(server Server) error { - errorMessage := "failed setting institute access server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetInstituteAccess(srv Server) error { + b, err := srv.Base() + if err != nil { + return err } - if base.Type != "institute_access" { - return types.NewWrappedError(errorMessage, errors.New("not an institute access server")) + if b.Type != "institute_access" { + return errors.Errorf("not an institute access server") } - if _, ok := servers.InstituteServers.Map[base.URL]; ok { - servers.InstituteServers.CurrentURL = base.URL - servers.IsType = InstituteAccessServerType + if _, ok := ss.InstituteServers.Map[b.URL]; ok { + ss.InstituteServers.CurrentURL = b.URL + ss.IsType = InstituteAccessServerType } else { - return types.NewWrappedError(errorMessage, errors.New("no such institute access server")) + return errors.Errorf("no such institute access server") } return nil } -func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { - if server, ok := servers.InstituteServers.Map[url]; ok { - return server, nil +func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { + if srv, ok := ss.InstituteServers.Map[url]; ok { + return srv, nil } - return nil, types.NewWrappedError( - "failed to get institute access server", - fmt.Errorf("no institute access server with URL: %s", url), - ) + return nil, errors.Errorf("no institute access server with URL: %s", url) } -func (servers *Servers) RemoveInstituteAccess(url string) { - servers.InstituteServers.Remove(url) +func (ss *Servers) RemoveInstituteAccess(url string) { + ss.InstituteServers.Remove(url) } -func (servers *InstituteAccessServers) Remove(url string) { +func (iass *InstituteAccessServers) Remove(url string) { // Reset the current url - if servers.CurrentURL == url { - servers.CurrentURL = "" + if iass.CurrentURL == url { + iass.CurrentURL = "" } // Delete the url from the map - delete(servers.Map, url) + delete(iass.Map, url) } -func (institute *InstituteAccessServer) TemplateAuth() func(string) string { +func (ias *InstituteAccessServer) TemplateAuth() func(string) string { return func(authURL string) string { return authURL } } -func (institute *InstituteAccessServer) Base() (*Base, error) { - return &institute.Basic, nil +func (ias *InstituteAccessServer) Base() (*Base, error) { + return &ias.Basic, nil } -func (institute *InstituteAccessServer) OAuth() *oauth.OAuth { - return &institute.Auth +func (ias *InstituteAccessServer) OAuth() *oauth.OAuth { + return &ias.Auth } -func (institute *InstituteAccessServer) init( +func (ias *InstituteAccessServer) init( url string, - displayName map[string]string, - serverType string, + name map[string]string, + srvType string, supportContact []string, ) error { - errorMessage := fmt.Sprintf("failed initializing server %s", url) - institute.Basic.URL = url - institute.Basic.DisplayName = displayName - institute.Basic.SupportContact = supportContact - institute.Basic.Type = serverType - endpointsErr := institute.Basic.InitializeEndpoints() - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) + ias.Basic.URL = url + ias.Basic.DisplayName = name + ias.Basic.SupportContact = supportContact + ias.Basic.Type = srvType + err := ias.Basic.InitializeEndpoints() + if err != nil { + return err } - API := institute.Basic.Endpoints.API.V3 - institute.Auth.Init(url, API.Authorization, API.Token) + API := ias.Basic.Endpoints.API.V3 + ias.Auth.Init(url, API.Authorization, API.Token) return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 998390d..12263a6 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -1,15 +1,13 @@ package server import ( - "errors" - "fmt" - "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -// A secure internet server which has its own OAuth tokens +// SecureInternetHomeServer secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to. type SecureInternetHomeServer struct { Auth oauth.OAuth `json:"oauth"` @@ -24,150 +22,111 @@ type SecureInternetHomeServer struct { CurrentLocation string `json:"current_location"` } -func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { - if !servers.HasSecureLocation() { - return nil, errors.New("no secure internet home server") +func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { + if !ss.HasSecureLocation() { + return nil, errors.Errorf("no secure internet home server") } - return &servers.SecureInternetHomeServer, nil + return &ss.SecureInternetHomeServer, nil } -func (servers *Servers) SetSecureInternet(server Server) error { - errorMessage := "failed setting secure internet server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetSecureInternet(server Server) error { + b, err := server.Base() + if err != nil { + return err } - if base.Type != "secure_internet" { - return types.NewWrappedError(errorMessage, errors.New("not a secure internet server")) + if b.Type != "secure_internet" { + return errors.Errorf("not a secure internet server") } // The location should already be configured // TODO: check for location? - servers.IsType = SecureInternetServerType + ss.IsType = SecureInternetServerType return nil } -func (servers *Servers) RemoveSecureInternet() { +func (ss *Servers) RemoveSecureInternet() { // Empty out the struct - servers.SecureInternetHomeServer = SecureInternetHomeServer{} + ss.SecureInternetHomeServer = SecureInternetHomeServer{} // If the current server is secure internet, default to custom server - if servers.IsType == SecureInternetServerType { - servers.IsType = CustomServerType + if ss.IsType == SecureInternetServerType { + ss.IsType = CustomServerType } } -func (server *SecureInternetHomeServer) TemplateAuth() func(string) string { +func (s *SecureInternetHomeServer) TemplateAuth() func(string) string { return func(authURL string) string { - return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID) + return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) } } -func (server *SecureInternetHomeServer) Base() (*Base, error) { - errorMessage := "failed getting current secure internet home base" - if server.BaseMap == nil { - return nil, types.NewWrappedError( - errorMessage, - &SecureInternetMapNotFoundError{}, - ) +func (s *SecureInternetHomeServer) Base() (*Base, error) { + if s.BaseMap == nil { + return nil, errors.Errorf("secure internet map not found") } - base, exists := server.BaseMap[server.CurrentLocation] - - if !exists { - return nil, types.NewWrappedError( - errorMessage, - &SecureInternetBaseNotFoundError{Current: server.CurrentLocation}, - ) + b, ok := s.BaseMap[s.CurrentLocation] + if !ok { + return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) } - return base, nil + return b, nil } -func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth { - return &server.Auth +func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth { + return &s.Auth } -func (servers *Servers) HasSecureLocation() bool { - return servers.SecureInternetHomeServer.CurrentLocation != "" +func (ss *Servers) HasSecureLocation() bool { + return ss.SecureInternetHomeServer.CurrentLocation != "" } -func (server *SecureInternetHomeServer) addLocation( - locationServer *types.DiscoveryServer, -) (*Base, error) { - errorMessage := "failed adding a location" +func (s *SecureInternetHomeServer) addLocation(locSrv *types.DiscoveryServer) (*Base, error) { // Initialize the base map if it is non-nil - if server.BaseMap == nil { - server.BaseMap = make(map[string]*Base) + if s.BaseMap == nil { + s.BaseMap = make(map[string]*Base) } // Add the location to the base map - base, exists := server.BaseMap[locationServer.CountryCode] - - if !exists || base == nil { + b, ok := s.BaseMap[locSrv.CountryCode] + if !ok || b == nil { // Create the base to be added to the map - base = &Base{} - base.URL = locationServer.BaseURL - base.DisplayName = server.DisplayName - base.SupportContact = locationServer.SupportContact - base.Type = "secure_internet" - endpointsErr := base.InitializeEndpoints() - if endpointsErr != nil { - return nil, types.NewWrappedError(errorMessage, endpointsErr) + b = &Base{} + b.URL = locSrv.BaseURL + b.DisplayName = s.DisplayName + b.SupportContact = locSrv.SupportContact + b.Type = "secure_internet" + if err := b.InitializeEndpoints(); err != nil { + return nil, err } } // Ensure it is in the map - server.BaseMap[locationServer.CountryCode] = base - return base, nil + s.BaseMap[locSrv.CountryCode] = b + return b, nil } // Initializes the home server and adds its own location. -func (server *SecureInternetHomeServer) init( - homeOrg *types.DiscoveryOrganization, - homeLocation *types.DiscoveryServer, -) error { - errorMessage := "failed initializing secure internet home server" - - if server.HomeOrganizationID != homeOrg.OrgID { +func (s *SecureInternetHomeServer) init( + homeOrg *types.DiscoveryOrganization, homeLoc *types.DiscoveryServer) error { + if s.HomeOrganizationID != homeOrg.OrgID { // New home organisation, clear everything - *server = SecureInternetHomeServer{} + *s = SecureInternetHomeServer{} } // Make sure to set the organization ID - server.HomeOrganizationID = homeOrg.OrgID - server.DisplayName = homeOrg.DisplayName + s.HomeOrganizationID = homeOrg.OrgID + s.DisplayName = homeOrg.DisplayName // Make sure to set the authorization URL template - server.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate + s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate - base, baseErr := server.addLocation(homeLocation) - - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) + b, err := s.addLocation(homeLoc) + if err != nil { + return err } // Make sure oauth contains our endpoints - server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) + s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) return nil } - -type SecureInternetHomeNotFoundError struct{} - -func (e *SecureInternetHomeNotFoundError) Error() string { - return "failed to get secure internet home server, not found" -} - -type SecureInternetMapNotFoundError struct{} - -func (e *SecureInternetMapNotFoundError) Error() string { - return "secure internet map not found" -} - -type SecureInternetBaseNotFoundError struct { - Current string -} - -func (e *SecureInternetBaseNotFoundError) Error() string { - return fmt.Sprintf("secure internet base not found with current location: %s", e.Current) -} diff --git a/internal/server/server.go b/internal/server/server.go index 95244d5..de0fa9a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,11 @@ package server import ( - "errors" - "fmt" "time" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/wireguard" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type Type int8 @@ -21,10 +19,10 @@ const ( type Server interface { OAuth() *oauth.OAuth - // Get the authorization URL template function + // TemplateAuth returns the authorization URL template function TemplateAuth() func(string) string - // Gets the server base + // Base returns the server base Base() (*Base, error) } @@ -34,7 +32,7 @@ type EndpointList struct { Token string `json:"token_endpoint"` } -// Struct that defines the json format for /.well-known/vpn-user-portal". +// Endpoints defines the json format for /.well-known/vpn-user-portal". type Endpoints struct { API struct { V2 EndpointList `json:"http://eduvpn.org/api#2"` @@ -43,310 +41,226 @@ type Endpoints struct { V string `json:"v"` } -func ShouldRenewButton(server Server) bool { - base, baseErr := server.Base() - - if baseErr != nil { +func ShouldRenewButton(srv Server) bool { + b, err := srv.Base() + if err != nil { // FIXME: Log error here? return false } // Get current time - current := time.Now() + now := time.Now() // Session is expired - if !current.Before(base.EndTime) { + if !now.Before(b.EndTime) { return true } // 30 minutes have not passed - if !current.After(base.StartTime.Add(30 * time.Minute)) { + if !now.After(b.StartTime.Add(30 * time.Minute)) { return false } // Session will not expire today - if !current.Add(24 * time.Hour).After(base.EndTime) { + if !now.Add(24 * time.Hour).After(b.EndTime) { return false } // Session duration is less than 24 hours but not 75% has passed - duration := base.EndTime.Sub(base.StartTime) - percentTime := base.StartTime.Add((duration / 4) * 3) - if duration < time.Duration(24*time.Hour) && !current.After(percentTime) { + d := b.EndTime.Sub(b.StartTime) + pct := b.StartTime.Add((d / 4) * 3) + if d < 24*time.Hour && !now.After(pct) { return false } return true } -func OAuthURL(server Server, name string) (string, error) { - return server.OAuth().AuthURL(name, server.TemplateAuth()) +func OAuthURL(srv Server, name string) (string, error) { + return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(server Server) error { - return server.OAuth().Exchange() +func OAuthExchange(srv Server) error { + return srv.OAuth().Exchange() } -func HeaderToken(server Server) (string, error) { - token, tokenErr := server.OAuth().AccessToken() - if tokenErr != nil { - return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr) - } - return token, nil +func HeaderToken(srv Server) (string, error) { + return srv.OAuth().AccessToken() } -func MarkTokenExpired(server Server) { - server.OAuth().SetTokenExpired() +func MarkTokenExpired(srv Server) { + srv.OAuth().SetTokenExpired() } -func MarkTokensForRenew(server Server) { - server.OAuth().SetTokenRenew() +func MarkTokensForRenew(srv Server) { + srv.OAuth().SetTokenRenew() } -func NeedsRelogin(server Server) bool { - _, tokenErr := HeaderToken(server) - return tokenErr != nil +func NeedsRelogin(srv Server) bool { + _, err := HeaderToken(srv) + return err != nil } -func CancelOAuth(server Server) { - server.OAuth().Cancel() +func CancelOAuth(srv Server) { + srv.OAuth().Cancel() } -func CurrentProfile(server Server) (*Profile, error) { - errorMessage := "failed getting current profile" - base, baseErr := server.Base() - - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) +func CurrentProfile(srv Server) (*Profile, error) { + b, err := srv.Base() + if err != nil { + return nil, err } - profileID := base.Profiles.Current - for _, profile := range base.Profiles.Info.ProfileList { - if profile.ID == profileID { + pid := b.Profiles.Current + for _, profile := range b.Profiles.Info.ProfileList { + if profile.ID == pid { return &profile, nil } } - return nil, types.NewWrappedError( - errorMessage, - &CurrentProfileNotFoundError{ProfileID: profileID}, - ) + return nil, errors.Errorf("profile not found: " + pid) } -func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) { - errorMessage := "failed to get valid profiles" +func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { // No error wrapping here otherwise we wrap it too much - base, baseErr := server.Base() - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return nil, err } - profiles := base.ValidProfiles(clientSupportsWireguard) - if len(profiles.Info.ProfileList) == 0 { - return nil, types.NewWrappedError( - errorMessage, - errors.New("no profiles found with supported protocols"), - ) + ps := b.ValidProfiles(wireguardSupport) + if len(ps.Info.ProfileList) == 0 { + return nil, errors.Errorf("no profiles found with supported protocols") } - return &profiles, nil + return &ps, nil } -func wireguardGetConfig( - server Server, - preferTCP bool, - supportsOpenVPN bool, -) (string, string, error) { - errorMessage := "failed getting server WireGuard configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - wireguardKey, wireguardErr := wireguard.GenerateKey() - - if wireguardErr != nil { - return "", "", types.NewWrappedError(errorMessage, wireguardErr) + pid := b.Profiles.Current + key, err := wireguard.GenerateKey() + if err != nil { + return "", "", err } - wireguardPublicKey := wireguardKey.PublicKey().String() - config, content, expires, configErr := APIConnectWireguard( - server, - profileID, - wireguardPublicKey, - preferTCP, - supportsOpenVPN, - ) - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + pub := key.PublicKey().String() + cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport) + if err != nil { + return "", "", err } // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if content == "wireguard" { + if ct == "wireguard" { // This needs the go code a way to identify a connection // Use the uuid of the connection e.g. on Linux // This needs the client code to call the go code - config = wireguard.ConfigAddKey(config, wireguardKey) + cfg = wireguard.ConfigAddKey(cfg, key) } - return config, content, nil + return cfg, ct, nil } -func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { - errorMessage := "failed getting server OpenVPN configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP) + pid := b.Profiles.Current + cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + if err != nil { + return "", "", err } - return configOpenVPN, "openvpn", nil + return cfg, "openvpn", nil } -func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) { - errorMessage := "failed has valid profile check" - +func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { // Get new profiles using the info call // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return false, types.NewWrappedError(errorMessage, infoErr) + err := APIInfo(srv) + if err != nil { + return false, err } - base, baseErr := server.Base() - if baseErr != nil { - return false, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return false, err } // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := CurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" + if b.Profiles.Current != "" { + if _, err = CurrentProfile(srv); err != nil { + b.Profiles.Current = "" } } - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - profile, profileErr := CurrentProfile(server) - // shouldn't happen - if profileErr != nil { - return false, types.NewWrappedError(errorMessage, profileErr) - } - // Profile does not support OpenVPN but the client also doesn't support WireGuard - if !profile.supportsOpenVPN() && !clientSupportsWireguard { - return false, nil - } - return true, nil + if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" { + return false, nil } - return false, nil + // Set the current profile if there is only one profile or profile is already selected + // Set the first profile if none is selected + if b.Profiles.Current == "" { + b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID + } + p, err := CurrentProfile(srv) + // shouldn't happen + if err != nil { + return false, err + } + // Profile does not support OpenVPN but the client also doesn't support WireGuard + if !p.supportsOpenVPN() && !wireguardSupport { + return false, nil + } + return true, nil } -func RefreshEndpoints(server Server) error { - errorMessage := "failed to refresh server endpoints" - +func RefreshEndpoints(srv Server) error { // Re-initialize the endpoints // TODO: Make this a warning instead? - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) - } - - endpointsErr := base.InitializeEndpoints() - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) + b, err := srv.Base() + if err != nil { + return err } - return nil + return b.InitializeEndpoints() } -func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - - profile, profileErr := CurrentProfile(server) - if profileErr != nil { - return "", "", types.NewWrappedError(errorMessage, profileErr) +func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) { + p, err := CurrentProfile(server) + if err != nil { + return "", "", err } - supportsOpenVPN := profile.supportsOpenVPN() - supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard - - var config string - var configType string - var configErr error + ovpn := p.supportsOpenVPN() + wg := p.supportsWireguard() && wireguardSupport switch { // The config supports wireguard and optionally openvpn - case supportsWireguard: + case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN) + return wireguardGetConfig(server, preferTCP, ovpn) // The config only supports OpenVPN - case supportsOpenVPN: - config, configType, configErr = openVPNGetConfig(server, preferTCP) + case ovpn: + return openVPNGetConfig(server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: - return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found")) + return "", "", errors.Errorf("no supported protocol found") } - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) - } - - return config, configType, nil } func Disconnect(server Server) { APIDisconnect(server) } - -type CurrentProfileNotFoundError struct { - ProfileID string -} - -func (e *CurrentProfileNotFoundError) Error() string { - return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) -} - -type ConfigPreferTCPError struct{} - -func (e *ConfigPreferTCPError) Error() string { - return "failed to get config, prefer TCP is on but the server does not support OpenVPN" -} - -type EmptyURLError struct{} - -func (e *EmptyURLError) Error() string { - return "failed ensuring server, empty url provided" -} - -type CurrentNoMapError struct{} - -func (e *CurrentNoMapError) Error() string { - return "failed getting current server, no servers available" -} - -type CurrentNotFoundError struct{} - -func (e *CurrentNotFoundError) Error() string { - return "failed getting current server, not found" -} diff --git a/internal/server/servers.go b/internal/server/servers.go index a076770..b34dcff 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -1,9 +1,8 @@ package server import ( - "fmt" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type Servers struct { @@ -14,125 +13,105 @@ type Servers struct { IsType Type `json:"is_secure_internet"` } -func (servers *Servers) AddSecureInternet( +func (ss *Servers) AddSecureInternet( secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, ) (Server, error) { - errorMessage := "failed adding secure internet server" // If we have specified an organization ID // We also need to get an authorization template - initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) + err := ss.SecureInternetHomeServer.init(secureOrg, secureServer) - if initErr != nil { - return nil, types.NewWrappedError(errorMessage, initErr) + if err != nil { + return nil, err } - servers.IsType = SecureInternetServerType - return &servers.SecureInternetHomeServer, nil + ss.IsType = SecureInternetServerType + return &ss.SecureInternetHomeServer, nil } -func (servers *Servers) GetCurrentServer() (Server, error) { - errorMessage := "failed getting current server" - if servers.IsType == SecureInternetServerType { - if !servers.HasSecureLocation() { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNotFoundError{}, - ) +func (ss *Servers) GetCurrentServer() (Server, error) { + //TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server) + if ss.IsType == SecureInternetServerType { + if !ss.HasSecureLocation() { + return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType) } - return &servers.SecureInternetHomeServer, nil + return &ss.SecureInternetHomeServer, nil } - serversStruct := &servers.InstituteServers + srvs := &ss.InstituteServers - if servers.IsType == CustomServerType { - serversStruct = &servers.CustomServers + if ss.IsType == CustomServerType { + srvs = &ss.CustomServers } - currentServerURL := serversStruct.CurrentURL - bases := serversStruct.Map - if bases == nil { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNoMapError{}, - ) + bs := srvs.Map + if bs == nil { + return nil, errors.Errorf("srvs.Map is nil") } - server, exists := bases[currentServerURL] - if !exists || server == nil { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNotFoundError{}, - ) + if srv, ok := bs[srvs.CurrentURL]; !ok || srv == nil { + return nil, errors.Errorf("server not found") + } else { + return srv, nil } - return server, nil } -func (servers *Servers) addInstituteAndCustom( +func (ss *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, ) (Server, error) { url := discoServer.BaseURL - errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) - toAddServers := &servers.InstituteServers - serverType := InstituteAccessServerType + srvs := &ss.InstituteServers + srvType := InstituteAccessServerType if isCustom { - toAddServers = &servers.CustomServers - serverType = CustomServerType + srvs = &ss.CustomServers + srvType = CustomServerType } - if toAddServers.Map == nil { - toAddServers.Map = make(map[string]*InstituteAccessServer) + if srvs.Map == nil { + srvs.Map = make(map[string]*InstituteAccessServer) } - server, exists := toAddServers.Map[url] + srv, ok := srvs.Map[url] // initialize the server if it doesn't exist yet - if !exists { - server = &InstituteAccessServer{} + if !ok { + srv = &InstituteAccessServer{} } - instituteInitErr := server.init( - url, - discoServer.DisplayName, - discoServer.Type, - discoServer.SupportContact, - ) - if instituteInitErr != nil { - return nil, types.NewWrappedError(errorMessage, instituteInitErr) + if err := srv.init(url, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil { + return nil, err } - toAddServers.Map[url] = server - servers.IsType = serverType - return server, nil + srvs.Map[url] = srv + ss.IsType = srvType + return srv, nil } -func (servers *Servers) AddInstituteAccessServer( +func (ss *Servers) AddInstituteAccessServer( instituteServer *types.DiscoveryServer, ) (Server, error) { - return servers.addInstituteAndCustom(instituteServer, false) + return ss.addInstituteAndCustom(instituteServer, false) } -func (servers *Servers) AddCustomServer( +func (ss *Servers) AddCustomServer( customServer *types.DiscoveryServer, ) (Server, error) { - return servers.addInstituteAndCustom(customServer, true) + return ss.addInstituteAndCustom(customServer, true) } -func (servers *Servers) GetSecureLocation() string { - return servers.SecureInternetHomeServer.CurrentLocation +func (ss *Servers) GetSecureLocation() string { + return ss.SecureInternetHomeServer.CurrentLocation } -func (servers *Servers) SetSecureLocation( +func (ss *Servers) SetSecureLocation( chosenLocationServer *types.DiscoveryServer, ) error { - errorMessage := "failed to set secure location" // Make sure to add the current location - _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) - if addLocationErr != nil { - return types.NewWrappedError(errorMessage, addLocationErr) + if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil { + return err } - servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode + ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode return nil } diff --git a/internal/util/util.go b/internal/util/util.go index ddd165d..558dba7 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -1,4 +1,4 @@ -// package util implements several utility functions that are used across the codebase +// Package util implements several utility functions that are used across the codebase package util import ( @@ -9,7 +9,7 @@ import ( "path" "strings" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) // EnsureValidURL ensures that the input URL is valid to be used internally @@ -19,135 +19,119 @@ import ( // - It makes sure that the URL ends with a / // It returns an error if the URL cannot be parsed. func EnsureValidURL(s string) (string, error) { - parsedURL, parseErr := url.Parse(s) - if parseErr != nil { - return "", types.NewWrappedError( - fmt.Sprintf("failed parsing url: %s", s), - parseErr, - ) + u, err := url.Parse(s) + if err != nil { + return "", errors.WrapPrefix(err, "failed parsing url", 0) } - if parsedURL.Scheme == "" { - parsedURL.Scheme = "https" + if u.Scheme == "" { + u.Scheme = "https" } - if parsedURL.Path != "" { + if u.Path != "" { // Clean the path // https://pkg.go.dev/path#Clean - parsedURL.Path = path.Clean(parsedURL.Path) + u.Path = path.Clean(u.Path) } - returnedURL := parsedURL.String() + str := u.String() // Make sure the URL ends with a / - if returnedURL[len(returnedURL)-1:] != "/" { - returnedURL += "/" + if str[len(str)-1:] != "/" { + str += "/" } - return returnedURL, nil + return str, nil } -// MakeRandomByteSlice creates a cryptographically random byteslice of `size` +// MakeRandomByteSlice creates a cryptographically random bytes slice of `size` // It returns the byte slice (or nil if error) and an error if it could not be generated. func MakeRandomByteSlice(size int) ([]byte, error) { - byteSlice := make([]byte, size) - _, err := rand.Read(byteSlice) - if err != nil { - return nil, types.NewWrappedError("failed reading random", err) + bs := make([]byte, size) + if _, err := rand.Read(bs); err != nil { + return nil, errors.WrapPrefix(err, "failed reading random", 0) } - return byteSlice, nil + return bs, nil } // EnsureDirectory creates a directory with permission 700. -func EnsureDirectory(directory string) error { +func EnsureDirectory(dir string) error { // Create with 700 permissions, read, write, execute only for the owner - mkdirErr := os.MkdirAll(directory, 0o700) - if mkdirErr != nil { - return types.NewWrappedError( - fmt.Sprintf("failed to create directory %s", directory), - mkdirErr, - ) + err := os.MkdirAll(dir, 0o700) + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to create directory '%s'", dir), 0) } return nil } -// WAYFEncode an input URL using 'skip Where Are You From' encoding -// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md -// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape. -func WAYFEncode(input string) string { - // QueryReplace already replaces a space with a + - // see https://go.dev/play/p/pOfrn-Wsq5 - return url.QueryEscape(input) -} - // ReplaceWAYF replaces an authorization template containing of @RETURN_TO@ and @ORG_ID@ with the authorization URL and the organization ID // See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md -func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { +func ReplaceWAYF(authTplt string, authURL string, orgID string) string { // We just return the authURL in the cases where the template is not given or is invalid - if authTemplate == "" { + if authTplt == "" { return authURL } - if !strings.Contains(authTemplate, "@RETURN_TO@") { + if !strings.Contains(authTplt, "@RETURN_TO@") { return authURL } - if !strings.Contains(authTemplate, "@ORG_ID@") { + if !strings.Contains(authTplt, "@ORG_ID@") { return authURL } // Replace authURL - authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", url.QueryEscape(authURL), 1) + authTplt = strings.Replace(authTplt, "@RETURN_TO@", url.QueryEscape(authURL), 1) // If now there is no more ORG_ID, return as there weren't enough @ symbols - if !strings.Contains(authTemplate, "@ORG_ID@") { + if !strings.Contains(authTplt, "@ORG_ID@") { return authURL } // Replace ORG ID - authTemplate = strings.Replace(authTemplate, "@ORG_ID@", url.QueryEscape(orgID), 1) - return authTemplate + authTplt = strings.Replace(authTplt, "@ORG_ID@", url.QueryEscape(orgID), 1) + return authTplt } // GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag // It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching -func GetLanguageMatched(languageMap map[string]string, languageTag string) string { +func GetLanguageMatched(langMap map[string]string, langTag string) string { // If no map is given, return the empty string - if len(languageMap) == 0 { + if len(langMap) == 0 { return "" } // Try to find the exact match - if val, ok := languageMap[languageTag]; ok { + if val, ok := langMap[langTag]; ok { return val } // Try to find a key that starts with the OS language setting - for k := range languageMap { - if strings.HasPrefix(k, languageTag) { - return languageMap[k] + for k := range langMap { + if strings.HasPrefix(k, langTag) { + return langMap[k] } } // Try to find a key that starts with the first part of the OS language (e.g. de-) - splitted := strings.Split(languageTag, "-") + pts := strings.Split(langTag, "-") // We have a "-" - if len(splitted) > 1 { - for k := range languageMap { - if strings.HasPrefix(k, splitted[0]+"-") { - return languageMap[k] + if len(pts) > 1 { + for k := range langMap { + if strings.HasPrefix(k, pts[0]+"-") { + return langMap[k] } } } // search for just the language (e.g. de) - for k := range languageMap { - if k == splitted[0] { - return languageMap[k] + for k := range langMap { + if k == pts[0] { + return langMap[k] } } // Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available! // We force an entry that is english exactly or with an english prefix - for k := range languageMap { + for k := range langMap { if k == "en" || strings.HasPrefix(k, "en-") { - return languageMap[k] + return langMap[k] } } // Otherwise just return one - for k := range languageMap { - return languageMap[k] + for k := range langMap { + return langMap[k] } return "" diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 55b82b6..cd74a2b 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -1,10 +1,10 @@ -// package verify implement signature verification using minisign +// Package verify implement signature verification using minisign package verify import ( "fmt" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" "github.com/jedisct1/go-minisign" ) @@ -31,7 +31,7 @@ func Verify( "RWRtBSX1alxyGX+Xn3LuZnWUT0w//B6EmTJvgaAxBMYzlQeI+jdrO6KF", // fkooman@tuxed.net, kolla@uninett.no "RWQKqtqvd0R7rUDp0rWzbtYPA3towPWcLDCl7eY9pBMMI/ohCmrS0WiM", // RoSp } - valid, err := verifyWithKeys( + return verifyWithKeys( signatureFileContent, signedJSON, expectedFileName, @@ -39,10 +39,6 @@ func Verify( keyStrs, forcePrehash, ) - if err != nil { - return valid, types.NewWrappedError("failed signature verify", err) - } - return valid, nil } // verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJSON. @@ -67,23 +63,21 @@ func verifyWithKeys( case "server_list.json", "organization_list.json": break default: - return false, &UnknownExpectedFilenameError{ - Filename: filename, - Expected: "server_list.json or organization_list.json", - } + return false, errors.Errorf( + "invalid filename '%s'; expected 'server_list.json' or 'organization_list.json'", + filename) } sig, err := minisign.DecodeSignature(signatureFileContent) if err != nil { - return false, &InvalidSignatureFormatError{Err: err} + return false, errors.WrapPrefix(err, "invalid signature format", 0) } // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} { - return false, &InvalidSignatureAlgorithmError{ - Algorithm: string(sig.SignatureAlgorithm[:]), - WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)", - } + return false, errors.Errorf( + "invalid signature algorithm '%s'; expected `ED (BLAKE2b-prehashed EdDSA)`", + sig.SignatureAlgorithm[:]) } // Find allowed key used for signature @@ -91,7 +85,7 @@ func verifyWithKeys( key, err := minisign.NewPublicKey(keyStr) if err != nil { // Should only happen if Verify is wrong or extraKey is invalid - return false, &CreatePublicKeyError{PublicKey: keyStr, Err: err} + return false, errors.WrapPrefix(err, fmt.Sprintf("failed to create public key '%s'", keyStr), 0) } if sig.KeyId != key.KeyId { @@ -100,7 +94,7 @@ func verifyWithKeys( valid, err := key.Verify(signedJSON, sig) if !valid { - return false, &InvalidSignatureError{Err: err} + return false, errors.WrapPrefix(err, "invalid signature", 0) } // Parse trusted comment @@ -114,125 +108,21 @@ func verifyWithKeys( &sigFileName, ) if err != nil { - return false, &InvalidTrustedCommentError{ - TrustedComment: sig.TrustedComment, - Err: err, - } + return false, errors.WrapPrefix(err, fmt.Sprintf("invalid trusted comment '%s'", sig.TrustedComment), 0) } if sigFileName != filename { - return false, &WrongSigFilenameError{Filename: filename, SigFilename: sigFileName} + return false, errors.Errorf("wrong filename '%s'; expected filename '%s' for signature", + filename, sigFileName) } if signTime < minSignTime { - return false, &SigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime} + return false, errors.Errorf("sign time %d is before sign tim: %d", signTime, minSignTime) } return true, nil } // No matching allowed key found - return false, &UnknownKeyError{Filename: filename} -} - -type UnknownExpectedFilenameError struct { - Filename string - Expected string -} - -func (e *UnknownExpectedFilenameError) Error() string { - return fmt.Sprintf("invalid filename: %s, expected: %s", e.Filename, e.Expected) -} - -type InvalidSignatureFormatError struct { - Err error -} - -func (e *InvalidSignatureFormatError) Error() string { - return fmt.Sprintf("invalid signature format with error: %v", e.Err) -} - -func (e *InvalidSignatureFormatError) Unwrap() error { - return e.Err -} - -type InvalidSignatureAlgorithmError struct { - Algorithm string - WantedAlgorithm string -} - -func (e *InvalidSignatureAlgorithmError) Error() string { - return fmt.Sprintf( - "invalid signature algorithm: %s, wanted: %s", - e.Algorithm, - e.WantedAlgorithm, - ) -} - -type CreatePublicKeyError struct { - PublicKey string - Err error -} - -func (e *CreatePublicKeyError) Error() string { - return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err) -} - -func (e *CreatePublicKeyError) Unwrap() error { - return e.Err -} - -type InvalidSignatureError struct { - Err error -} - -func (e *InvalidSignatureError) Error() string { - return fmt.Sprintf("invalid signature with error: %v", e.Err) -} - -func (e *InvalidSignatureError) Unwrap() error { - return e.Err -} - -type InvalidTrustedCommentError struct { - TrustedComment string - Err error -} - -func (e *InvalidTrustedCommentError) Error() string { - return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err) -} - -func (e *InvalidTrustedCommentError) Unwrap() error { - return e.Err -} - -type WrongSigFilenameError struct { - Filename string - SigFilename string -} - -func (e *WrongSigFilenameError) Error() string { - return fmt.Sprintf( - "wrong filename: %s, expected filename: %s for signature", - e.Filename, - e.SigFilename, - ) -} - -type SigTimeEarlierError struct { - SigTime uint64 - MinSigTime uint64 -} - -func (e *SigTimeEarlierError) Error() string { - return fmt.Sprintf("Sign time: %d is earlier than sign time: %d", e.SigTime, e.MinSigTime) -} - -type UnknownKeyError struct { - Filename string -} - -func (e *UnknownKeyError) Error() string { - return fmt.Sprintf("signature for filename: %s was created with an unknown key", e.Filename) + return false, errors.Errorf("signature for filename '%s' was created with an unknown key", filename) } diff --git a/internal/verify/verify_test.go b/internal/verify/verify_test.go index 8ebed4c..a80cbfc 100644 --- a/internal/verify/verify_test.go +++ b/internal/verify/verify_test.go @@ -2,10 +2,10 @@ package verify import ( "bufio" - "errors" "fmt" "io/ioutil" "os" + "strings" "testing" ) @@ -28,29 +28,17 @@ func Test_verifyWithKeys(t *testing.T) { pk = []string{scanner.Text()} } - var ( - verifyCreatePublicKeyError *CreatePublicKeyError - verifyInvalidSignatureAlgorithmError *InvalidSignatureAlgorithmError - verifyWrongSigFilenameError *WrongSigFilenameError - verifyInvalidTrustedCommentError *InvalidTrustedCommentError - verifyInvalidSignatureFormatError *InvalidSignatureFormatError - verifyInvalidSignatureError *InvalidSignatureError - verifySigTimeEarlierError *SigTimeEarlierError - verifyUnknownExpectedFilenameError *UnknownExpectedFilenameError - verifyUnknownKeyError *UnknownKeyError - ) - tests := []struct { - expectedErr interface{} - testName string - signatureFile string - jsonFile string - expectedFileName string - minSignTime uint64 - allowedPks []string + expectedErrPrefix string + testName string + signatureFile string + jsonFile string + expectedFileName string + minSignTime uint64 + allowedPks []string }{ { - &verifyInvalidSignatureAlgorithmError, + "invalid signature algorithm '", "pure", "server_list.json.pure.minisig", "server_list.json", @@ -58,9 +46,8 @@ func Test_verifyWithKeys(t *testing.T) { 10, pk, }, - { - nil, + "", "valid server_list", "server_list.json.minisig", "server_list.json", @@ -69,7 +56,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - nil, + "", "TC no hashed", "server_list.json.tc_nohashed.minisig", "server_list.json", @@ -78,7 +65,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - nil, + "", "TC later time", "server_list.json.tc_latertime.minisig", "server_list.json", @@ -87,7 +74,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "server_list TC file:organization_list", "server_list.json.tc_orglist.minisig", "server_list.json", @@ -96,7 +83,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "organization_list as server_list", "organization_list.json.minisig", "organization_list.json", @@ -105,7 +92,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "TC file:otherfile", "server_list.json.tc_otherfile.minisig", "server_list.json", @@ -114,7 +101,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidTrustedCommentError, + "invalid trusted comment '", "TC no file", "server_list.json.tc_nofile.minisig", "server_list.json", @@ -123,7 +110,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidTrustedCommentError, + "invalid trusted comment '", "TC no time", "server_list.json.tc_notime.minisig", "server_list.json", @@ -132,7 +119,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidTrustedCommentError, + "invalid trusted comment '", "TC empty time", "server_list.json.tc_emptytime.minisig", "server_list.json", @@ -141,7 +128,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "TC empty file", "server_list.json.tc_emptyfile.minisig", "server_list.json", @@ -150,7 +137,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidTrustedCommentError, + "invalid trusted comment '", "TC random", "server_list.json.tc_random.minisig", "server_list.json", @@ -159,7 +146,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - nil, + "", "large time", "server_list.json.large_time.minisig", "server_list.json", @@ -168,7 +155,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - nil, + "", "lower min time", "server_list.json.minisig", "server_list.json", @@ -177,7 +164,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifySigTimeEarlierError, + "sign time", "higher min time", "server_list.json.minisig", "server_list.json", @@ -185,9 +172,8 @@ func Test_verifyWithKeys(t *testing.T) { 11, pk, }, - { - nil, + "", "valid organization_list", "organization_list.json.minisig", "organization_list.json", @@ -196,7 +182,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "organization_list TC file:server_list", "organization_list.json.tc_servlist.minisig", "organization_list.json", @@ -205,7 +191,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "server_list as organization_list", "server_list.json.minisig", "server_list.json", @@ -215,7 +201,7 @@ func Test_verifyWithKeys(t *testing.T) { }, { - &verifyUnknownExpectedFilenameError, + "invalid filename '", "valid other_list", "other_list.json.minisig", "other_list.json", @@ -224,7 +210,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyWrongSigFilenameError, + "wrong filename '", "other_list as server_list", "other_list.json.minisig", "other_list.json", @@ -232,9 +218,8 @@ func Test_verifyWithKeys(t *testing.T) { 10, pk, }, - { - &verifyInvalidSignatureFormatError, + "invalid signature format", "invalid signature file", "random.txt", "server_list.json", @@ -243,7 +228,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidSignatureFormatError, + "invalid signature format", "empty signature file", "empty", "server_list.json", @@ -253,7 +238,7 @@ func Test_verifyWithKeys(t *testing.T) { }, { - &verifyUnknownKeyError, + "signature for filename '", "wrong key", "server_list.json.wrong_key.minisig", "server_list.json", @@ -263,7 +248,7 @@ func Test_verifyWithKeys(t *testing.T) { }, { - &verifyInvalidSignatureAlgorithmError, + "invalid signature algorithm '", "forged pure signature", "server_list.json.forged_pure.minisig", "server_list.json.blake2b", @@ -272,7 +257,7 @@ func Test_verifyWithKeys(t *testing.T) { pk, }, { - &verifyInvalidSignatureError, + "invalid signature", "forged key ID", "server_list.json.forged_keyid.minisig", "server_list.json", @@ -282,7 +267,7 @@ func Test_verifyWithKeys(t *testing.T) { }, { - &verifyUnknownKeyError, + "signature for filename '", "no allowed keys", "server_list.json.minisig", "server_list.json", @@ -291,7 +276,7 @@ func Test_verifyWithKeys(t *testing.T) { []string{}, }, { - nil, + "", "multiple allowed keys 1", "server_list.json.minisig", "server_list.json", @@ -302,7 +287,7 @@ func Test_verifyWithKeys(t *testing.T) { }, }, { - nil, + "", "multiple allowed keys 2", "server_list.json.minisig", "server_list.json", @@ -313,7 +298,7 @@ func Test_verifyWithKeys(t *testing.T) { }, }, { - &verifyCreatePublicKeyError, + "failed to create public key '", "invalid allowed key", "server_list.json.minisig", "server_list.json", @@ -345,7 +330,7 @@ func Test_verifyWithKeys(t *testing.T) { t.Run(tt.testName, func(t *testing.T) { valid, err := verifyWithKeys(string(files[tt.signatureFile]), files[tt.jsonFile], tt.expectedFileName, tt.minSignTime, tt.allowedPks, forcePrehash) - compareResults(t, valid, err, tt.expectedErr, func() string { + compareResults(t, valid, err, tt.expectedErrPrefix, func() string { return fmt.Sprintf( "verifyWithKeys(%q, %q, %q, %v, %v, %t)", tt.signatureFile, @@ -366,17 +351,33 @@ func compareResults( t *testing.T, ret bool, err error, - expectedErr interface{}, + expectedErrPrefix string, callStr func() string, ) { - // different error returned - if expectedErr != nil && !errors.As(err, expectedErr) { - t.Errorf("%v\nerror %T = %v, wantErr %T", callStr(), err, err, expectedErr) + if expectedErrPrefix == "" { + // we don't expect any error + if err != nil { + t.Errorf("error not expected but returned '%s'", err.Error()) + } + if !ret { + t.Errorf("error is nil and result is false") + } return } - // different boolean returned - expectedBool := expectedErr == nil - if ret != expectedBool { - t.Errorf("%v\n= %v, want %v", callStr(), ret, expectedBool) + + if err == nil { + // we expect an error but received nil + t.Errorf("expected error prefix '%s' but received nil", expectedErrPrefix) + return + } + + if !strings.HasPrefix(err.Error(), expectedErrPrefix) { + // wrong error + t.Errorf("expected error prefix '%s' for error '%s'", expectedErrPrefix, err.Error()) + return + } + + if ret { + t.Errorf("error is not nil and result is true") } } diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 7da2623..0419ff6 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -1,24 +1,21 @@ -// package wireguard implements a few helpers for the WireGuard protocol +// Package wireguard implements a few helpers for the WireGuard protocol package wireguard import ( "fmt" "regexp" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // GenerateKey generates a WireGuard private key using wgctrl // It returns an error if key generation failed. func GenerateKey() (wgtypes.Key, error) { - key, keyErr := wgtypes.GeneratePrivateKey() + key, err := wgtypes.GeneratePrivateKey() - if keyErr != nil { - return key, types.NewWrappedError( - "failed generating WireGuard key", - keyErr, - ) + if err != nil { + return key, errors.WrapPrefix(err, "failed generating WireGuard key", 0) } return key, nil } -- cgit v1.2.3