summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorAleksandar Pesic <peske.nis@gmail.com>2022-12-04 21:48:20 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-12 13:26:51 +0100
commit3ac1d35257b56cca92ad0eb7f4d18abb366cf105 (patch)
tree432db14d1f92a252518f371be420fa0d3ef044c8 /internal
parent37bca013bd4405548b274ac473acf959ad661ee6 (diff)
simplify error handling
fixes #6 Signed-off-by: Aleksandar Pesic <peske.nis@gmail.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go50
-rw-r--r--internal/discovery/discovery.go181
-rw-r--r--internal/fsm/fsm.go91
-rw-r--r--internal/http/http.go128
-rw-r--r--internal/log/log.go100
-rw-r--r--internal/oauth/oauth.go307
-rw-r--r--internal/server/api.go214
-rw-r--r--internal/server/base.go29
-rw-r--r--internal/server/custom.go41
-rw-r--r--internal/server/instituteaccess.go85
-rw-r--r--internal/server/secureinternet.go151
-rw-r--r--internal/server/server.go318
-rw-r--r--internal/server/servers.go117
-rw-r--r--internal/util/util.go112
-rw-r--r--internal/verify/verify.go144
-rw-r--r--internal/verify/verify_test.go123
-rw-r--r--internal/wireguard/wireguard.go13
17 files changed, 850 insertions, 1354 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 6761d62..ae023ff 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -4,12 +4,11 @@ package config
import (
"encoding/json"
- "fmt"
- "io/ioutil"
+ "os"
"path"
"github.com/eduvpn/eduvpn-common/internal/util"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
// Config represents a configuration that saves the client's struct as JSON.
@@ -22,38 +21,41 @@ type Config struct {
}
// Init initializes the configuration using the provided directory and name.
-func (config *Config) Init(directory string, name string) {
- config.Directory = directory
- config.Name = name
+func (c *Config) Init(directory string, name string) {
+ c.Directory = directory
+ c.Name = name
}
// filename returns the filename of the configuration as a full path.
-func (config *Config) filename() string {
- pathString := path.Join(config.Directory, config.Name)
- return fmt.Sprintf("%s.json", pathString)
+func (c *Config) filename() string {
+ return path.Join(c.Directory, c.Name) + ".json"
}
// Save saves a structure 'readStruct' to the configuration
-// If it was unusuccessful, an error is returned.
-func (config *Config) Save(readStruct interface{}) error {
- errorMessage := "failed saving configuration"
- configDirErr := util.EnsureDirectory(config.Directory)
- if configDirErr != nil {
- return types.NewWrappedError(errorMessage, configDirErr)
+// If it was unsuccessful, an error is returned.
+func (c *Config) Save(readStruct interface{}) error {
+ if err := util.EnsureDirectory(c.Directory); err != nil {
+ return err
+ }
+ cfg, err := json.Marshal(readStruct)
+ if err != nil {
+ return errors.WrapPrefix(err, "json.Marshal failed", 0)
}
- jsonString, marshalErr := json.Marshal(readStruct)
- if marshalErr != nil {
- return types.NewWrappedError(errorMessage, marshalErr)
+ if err = os.WriteFile(c.filename(), cfg, 0o600); err != nil {
+ return errors.WrapPrefix(err, "os.WriteFile failed", 0)
}
- return ioutil.WriteFile(config.filename(), jsonString, 0o600)
+ return nil
}
// Load loads the configuration and writes the structure to 'writeStruct'
// If it was unsuccessful, an error is returned.
-func (config *Config) Load(writeStruct interface{}) error {
- bytes, readErr := ioutil.ReadFile(config.filename())
- if readErr != nil {
- return types.NewWrappedError("failed loading configuration", readErr)
+func (c *Config) Load(writeStruct interface{}) error {
+ bts, err := os.ReadFile(c.filename())
+ if err != nil {
+ return errors.WrapPrefix(err, "failed loading configuration", 0)
+ }
+ if err = json.Unmarshal(bts, writeStruct); err != nil {
+ return errors.WrapPrefix(err, "json.Unmarshal failed", 0)
}
- return json.Unmarshal(bytes, writeStruct)
+ return nil
}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 35c2689..32bec66 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -1,4 +1,4 @@
-// package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure
+// Package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure
package discovery
import (
@@ -9,6 +9,7 @@ import (
"github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/verify"
"github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
// Discovery is the main structure used for this package.
@@ -23,45 +24,41 @@ type Discovery struct {
// discoFile is a helper function that gets a disco JSON and fills the structure with it
// If it was unsuccessful it returns an error.
func discoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
- errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
- discoURL := "https://disco.eduvpn.org/v2/"
- fileURL := discoURL + jsonFile
- _, fileBody, fileErr := http.Get(fileURL)
-
- if fileErr != nil {
- return types.NewWrappedError(errorMessage, fileErr)
+ du := "https://disco.eduvpn.org/v2/"
+ fu := du + jsonFile
+ _, body, err := http.Get(fu)
+ if err != nil {
+ return err
}
// Get signature
sigFile := jsonFile + ".minisig"
- sigURL := discoURL + sigFile
- _, sigBody, sigFileErr := http.Get(sigURL)
-
- if sigFileErr != nil {
- return types.NewWrappedError(errorMessage, sigFileErr)
+ sigURL := du + sigFile
+ _, sigBody, err := http.Get(sigURL)
+ if err != nil {
+ return err
}
// Verify signature
// Set this to true when we want to force prehash
- forcePrehash := false
- verifySuccess, verifyErr := verify.Verify(
+ const forcePrehash = false
+ ok, err := verify.Verify(
string(sigBody),
- fileBody,
+ body,
jsonFile,
previousVersion,
forcePrehash,
)
- if !verifySuccess || verifyErr != nil {
- return types.NewWrappedError(errorMessage, verifyErr)
+ if !ok || err != nil {
+ return err
}
// Parse JSON to extract version and list
- jsonErr := json.Unmarshal(fileBody, structure)
-
- if jsonErr != nil {
- return types.NewWrappedError(errorMessage, jsonErr)
+ if err = json.Unmarshal(body, structure); err != nil {
+ return errors.WrapPrefix(err,
+ fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile), 0)
}
return nil
@@ -80,86 +77,67 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
// SecureLocationList returns a slice of all the available locations.
func (discovery *Discovery) SecureLocationList() []string {
- var locations []string
- for _, currentServer := range discovery.servers.List {
- if currentServer.Type == "secure_internet" {
- locations = append(locations, currentServer.CountryCode)
+ var loc []string
+ for _, srv := range discovery.servers.List {
+ if srv.Type == "secure_internet" {
+ loc = append(loc, srv.CountryCode)
}
}
- return locations
+ return loc
}
// ServerByURL returns the discovery server by the base URL and the according type ("secure_internet", "institute_access")
// An error is returned if and only if nil is returned for the server.
func (discovery *Discovery) ServerByURL(
baseURL string,
- serverType string,
+ srvType string,
) (*types.DiscoveryServer, error) {
for _, currentServer := range discovery.servers.List {
- if currentServer.BaseURL == baseURL && currentServer.Type == serverType {
+ if currentServer.BaseURL == baseURL && currentServer.Type == srvType {
return &currentServer, nil
}
}
- return nil, types.NewWrappedError(
- "failed getting server by URL from discovery",
- &GetServerByURLNotFoundError{URL: baseURL, Type: serverType},
- )
+ return nil, errors.Errorf("no server of type '%s' at URL '%s'", srvType, baseURL)
}
// ServerByCountryCode returns the discovery server by the country code and the according type ("secure_internet", "institute_access")
// An error is returned if and only if nil is returned for the server.
-func (discovery *Discovery) ServerByCountryCode(
- countryCode string,
- serverType string,
-) (*types.DiscoveryServer, error) {
- for _, currentServer := range discovery.servers.List {
- if currentServer.CountryCode == countryCode && currentServer.Type == serverType {
- return &currentServer, nil
+func (discovery *Discovery) ServerByCountryCode(countryCode string, srvType string) (*types.DiscoveryServer, error) {
+ for _, srv := range discovery.servers.List {
+ if srv.CountryCode == countryCode && srv.Type == srvType {
+ return &srv, nil
}
}
- return nil, types.NewWrappedError(
- "failed getting server by country countryCode from discovery",
- &GetServerByCountryCodeNotFoundError{CountryCode: countryCode, Type: serverType},
- )
+ return nil, errors.Errorf("no server of type '%s' with country code '%s'", srvType, countryCode)
}
// orgByID returns the discovery organization by the organization ID
// An error is returned if and only if nil is returned for the organization.
func (discovery *Discovery) orgByID(orgID string) (*types.DiscoveryOrganization, error) {
- for _, organization := range discovery.organizations.List {
- if organization.OrgID == orgID {
- return &organization, nil
+ for _, org := range discovery.organizations.List {
+ if org.OrgID == orgID {
+ return &org, nil
}
}
- return nil, types.NewWrappedError(
- "failed getting Secure Internet Home URL from discovery",
- &GetOrgByIDNotFoundError{ID: orgID},
- )
+ return nil, errors.Errorf("no secure internet home found in organization '%s'", orgID)
}
// SecureHomeArgs returns the secure internet home server arguments:
// - The organization it belongs to
// - The secure internet server itself
// An error is returned if and only if nil is returned for the organization.
-func (discovery *Discovery) SecureHomeArgs(
- orgID string,
-) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
- errorMessage := "failed getting Secure Internet Home arguments from discovery"
- org, orgErr := discovery.orgByID(orgID)
-
- if orgErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, orgErr)
+func (discovery *Discovery) SecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
+ org, err := discovery.orgByID(orgID)
+ if err != nil {
+ return nil, nil, err
}
// Get a server with the base url
- url := org.SecureInternetHome
-
- currentServer, serverErr := discovery.ServerByURL(url, "secure_internet")
-
- if serverErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, serverErr)
+ srv, err := discovery.ServerByURL(org.SecureInternetHome, "secure_internet")
+ if err != nil {
+ return nil, nil, err
}
- return org, currentServer, nil
+ return org, srv, nil
}
// DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server
@@ -172,9 +150,8 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
return true
}
// 1 hour from the last update
- shouldUpdateTime := discovery.servers.Timestamp.Add(1 * time.Hour)
- now := time.Now()
- return !now.Before(shouldUpdateTime)
+ upd := discovery.servers.Timestamp.Add(1 * time.Hour)
+ return !time.Now().Before(upd)
}
// Organizations returns the discovery organizations
@@ -184,13 +161,10 @@ func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, erro
return &discovery.organizations, nil
}
file := "organization_list.json"
- bodyErr := discoFile(file, discovery.organizations.Version, &discovery.organizations)
- if bodyErr != nil {
+ err := discoFile(file, discovery.organizations.Version, &discovery.organizations)
+ if err != nil {
// Return previous with an error
- return &discovery.organizations, types.NewWrappedError(
- "failed getting organizations in Discovery",
- bodyErr,
- )
+ return &discovery.organizations, err
}
discovery.organizations.Timestamp = time.Now()
return &discovery.organizations, nil
@@ -203,63 +177,12 @@ func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) {
return &discovery.servers, nil
}
file := "server_list.json"
- bodyErr := discoFile(file, discovery.servers.Version, &discovery.servers)
- if bodyErr != nil {
+ err := discoFile(file, discovery.servers.Version, &discovery.servers)
+ if err != nil {
// Return previous with an error
- return &discovery.servers, types.NewWrappedError(
- "failed getting servers in Discovery",
- bodyErr,
- )
+ return &discovery.servers, err
}
// Update servers timestamp
discovery.servers.Timestamp = time.Now()
return &discovery.servers, nil
}
-
-type GetOrgByIDNotFoundError struct {
- ID string
-}
-
-func (e GetOrgByIDNotFoundError) Error() string {
- return fmt.Sprintf(
- "No Secure Internet Home found in organizations with ID %s. Please choose your server again",
- e.ID,
- )
-}
-
-type GetServerByURLNotFoundError struct {
- URL string
- Type string
-}
-
-func (e GetServerByURLNotFoundError) Error() string {
- return fmt.Sprintf(
- "No institute access server found in organizations with URL %s and type %s. Please choose your server again",
- e.URL,
- e.Type,
- )
-}
-
-type GetServerByCountryCodeNotFoundError struct {
- CountryCode string
- Type string
-}
-
-func (e GetServerByCountryCodeNotFoundError) Error() string {
- return fmt.Sprintf(
- "No institute access server found in organizations with country code %s and type %s",
- e.CountryCode,
- e.Type,
- )
-}
-
-type GetSecureHomeArgsNotFoundError struct {
- URL string
-}
-
-func (e GetSecureHomeArgsNotFoundError) Error() string {
- return fmt.Sprintf(
- "No Secure Internet Home found with URL: %s. Please choose your server again",
- e.URL,
- )
-}
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index 5bfa712..6c8923b 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -9,7 +9,7 @@ import (
"path"
"sort"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type (
@@ -97,36 +97,47 @@ func (fsm *FSM) InState(check StateID) bool {
}
// HasTransition checks whether or not the state machine has a transition to the given 'check' state.
-func (fsm *FSM) HasTransition(check StateID) bool {
- for _, transitionState := range fsm.States[fsm.Current].Transitions {
- if transitionState.To == check {
- return true
+//func (fsm *FSM) HasTransition(check StateID) bool {
+// for _, transitionState := range fsm.States[fsm.Current].Transitions {
+// if transitionState.To == check {
+// return true
+// }
+// }
+//
+// return false
+//}
+
+func (fsm *FSM) CheckTransition(desired StateID) error {
+ for _, ts := range fsm.States[fsm.Current].Transitions {
+ if ts.To == desired {
+ return nil
}
}
-
- return false
+ return errors.Errorf("fsm invalid transition attempt from '%v' to '%v'", fsm.Current, desired)
}
// graphFilename gets the full path to the graph filename including the .graph extension.
func (fsm *FSM) graphFilename(extension string) string {
- debugPath := path.Join(fsm.Directory, "graph")
- return fmt.Sprintf("%s%s", debugPath, extension)
+ pth := path.Join(fsm.Directory, "graph")
+ return fmt.Sprintf("%s%s", pth, extension)
}
// writeGraph writes the state machine to a .graph file.
func (fsm *FSM) writeGraph() {
- graph := fsm.GenerateGraph()
- graphFile := fsm.graphFilename(".graph")
- graphImgFile := fsm.graphFilename(".png")
- f, err := os.Create(graphFile)
+ gph := fsm.GenerateGraph()
+ gf := fsm.graphFilename(".graph")
+ gif := fsm.graphFilename(".png")
+ f, err := os.Create(gf)
if err != nil {
return
}
+ defer func() {
+ _ = f.Close()
+ }()
- _, writeErr := f.WriteString(graph)
- f.Close()
- if writeErr != nil {
- cmd := exec.Command("mmdc", "-i", graphFile, "-o", graphImgFile, "--scale", "4")
+ _, err = f.WriteString(gph)
+ if err != nil {
+ cmd := exec.Command("mmdc", "-i", gf, "-o", gif, "--scale", "4")
// Generating is best effort
_ = cmd.Start()
}
@@ -137,14 +148,7 @@ func (fsm *FSM) writeGraph() {
func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error {
oldState := fsm.Current
if !fsm.GoTransitionWithData(newState, data) {
- return types.NewWrappedError(
- "failed required transition",
- fmt.Errorf(
- "required transition not handled, from: %s -> to: %s",
- fsm.GetStateName(oldState),
- fsm.GetStateName(newState),
- ),
- )
+ return errors.Errorf("fsm failed transition from '%v' to '%v'", oldState, newState)
}
return nil
}
@@ -152,20 +156,17 @@ func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error {
// GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data'
// It returns whether or not the transition is handled by the client.
func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) bool {
- ok := fsm.HasTransition(newState)
-
- handled := false
- if ok {
- oldState := fsm.Current
- fsm.Current = newState
- if fsm.Generate {
- fsm.writeGraph()
- }
+ if fsm.CheckTransition(newState) != nil {
+ return false
+ }
- handled = fsm.StateCallback(oldState, newState, data)
+ prev := fsm.Current
+ fsm.Current = newState
+ if fsm.Generate {
+ fsm.writeGraph()
}
- return handled
+ return fsm.StateCallback(prev, newState, data)
}
// GoTransition is an alias to call GoTransitionWithData but have an empty string as data.
@@ -177,21 +178,21 @@ func (fsm *FSM) GoTransition(newState StateID) bool {
// generateMermaidGraph generates a graph suitable to be converted by the mermaid.js tool
// it returns the graph as a string.
func (fsm *FSM) generateMermaidGraph() string {
- graph := "graph TD\n"
- sortedFSM := make(StateIDSlice, 0, len(fsm.States))
+ gph := "graph TD\n"
+ sf := make(StateIDSlice, 0, len(fsm.States))
for stateID := range fsm.States {
- sortedFSM = append(sortedFSM, stateID)
+ sf = append(sf, stateID)
}
- sort.Sort(sortedFSM)
- for _, state := range sortedFSM {
+ sort.Sort(sf)
+ for _, state := range sf {
transitions := fsm.States[state].Transitions
for _, transition := range transitions {
if state == fsm.Current {
- graph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n"
+ gph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n"
} else {
- graph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n"
+ gph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n"
}
- graph += fsm.GetStateName(
+ gph += fsm.GetStateName(
state,
) + "(" + fsm.GetStateName(
state,
@@ -200,7 +201,7 @@ func (fsm *FSM) generateMermaidGraph() string {
) + "\n"
}
}
- return graph
+ return gph
}
// GenerateGraph generates a mermaid graph if the state machine is initialized
diff --git a/internal/http/http.go b/internal/http/http.go
index 1d0ec45..e1f9bdf 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -4,16 +4,15 @@ package http
import (
"fmt"
"io"
- "io/ioutil"
"net/http"
"net/url"
"strings"
"time"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-// The URLParemeters as the name suggests is a type used for the parameters in the URL.
+// URLParameters is a type used for the parameters in the URL.
type URLParameters map[string]string
// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call.
@@ -25,26 +24,20 @@ type OptionalParams struct {
}
// ConstructURL creates a URL with the included parameters.
-func ConstructURL(baseURL string, parameters URLParameters) (string, error) {
- url, parseErr := url.Parse(baseURL)
- if parseErr != nil {
- return "", types.NewWrappedError(
- fmt.Sprintf(
- "failed to construct url: %s including parameters: %v",
- url,
- parameters,
- ),
- parseErr,
- )
+func ConstructURL(baseURL string, params URLParameters) (string, error) {
+ u, err := url.Parse(baseURL)
+ if err != nil {
+ return "", errors.WrapPrefix(err,
+ fmt.Sprintf("failed to construct url '%s' with parameters: %v", u, params), 0)
}
- q := url.Query()
+ q := u.Query()
- for parameter, value := range parameters {
- q.Set(parameter, value)
+ for p, value := range params {
+ q.Set(p, value)
}
- url.RawQuery = q.Encode()
- return url.String(), nil
+ u.RawQuery = q.Encode()
+ return u.String(), nil
}
// Get creates a Get request and returns the headers, body and an error.
@@ -52,16 +45,6 @@ func Get(url string) (http.Header, []byte, error) {
return MethodWithOpts(http.MethodGet, url, nil)
}
-// Post creates a Post request and returns the headers, body and an error.
-func Post(url string, body url.Values) (http.Header, []byte, error) {
- return MethodWithOpts(http.MethodGet, url, &OptionalParams{Body: body})
-}
-
-// GetWithOpts creates a Get request with optional parameters and returns the headers, body and an error.
-func GetWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) {
- return MethodWithOpts(http.MethodGet, url, opts)
-}
-
// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error.
func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) {
return MethodWithOpts(http.MethodPost, url, opts)
@@ -69,19 +52,12 @@ func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error)
// optionalURL ensures that the URL contains the optional parameters
// it returns the url (with parameters if success) and an error indicating success.
-func optionalURL(url string, opts *OptionalParams) (string, error) {
- if opts != nil {
- url, urlErr := ConstructURL(url, opts.URLParameters)
-
- if urlErr != nil {
- return url, types.NewWrappedError(
- fmt.Sprintf("failed to create HTTP request with url: %s", url),
- urlErr,
- )
- }
- return url, nil
+func optionalURL(urlStr string, opts *OptionalParams) (string, error) {
+ if opts == nil {
+ return urlStr, nil
}
- return url, nil
+
+ return ConstructURL(urlStr, opts.URLParameters)
}
// optionalHeaders ensures that the HTTP request uses the optional headers if defined.
@@ -106,16 +82,16 @@ func optionalBodyReader(opts *OptionalParams) io.Reader {
// It returns the HTTP headers, the body and an error if there is one.
func MethodWithOpts(
method string,
- url string,
+ urlStr string,
opts *OptionalParams,
) (http.Header, []byte, error) {
// Make sure the url contains all the parameters
// This can return an error,
- // it already has the right error so we don't wrap it further
- url, urlErr := optionalURL(url, opts)
- if urlErr != nil {
+ // it already has the right error, so we don't wrap it further
+ urlStr, err := optionalURL(urlStr, opts)
+ if err != nil {
// No further type wrapping is needed here
- return nil, nil, urlErr
+ return nil, nil, err
}
// Default timeout is 5 seconds
@@ -125,15 +101,11 @@ func MethodWithOpts(
timeout = opts.Timeout
}
- // Create a client
- client := &http.Client{Timeout: timeout * time.Second}
-
- errorMessage := fmt.Sprintf("failed HTTP request with method %s and url %s", method, url)
-
// Create request object with the body reader generated from the optional arguments
- req, reqErr := http.NewRequest(method, url, optionalBodyReader(opts))
- if reqErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, reqErr)
+ req, err := http.NewRequest(method, urlStr, optionalBodyReader(opts))
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err,
+ fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0)
}
// See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi
@@ -142,29 +114,34 @@ func MethodWithOpts(
// Make sure the headers contain all the parameters
optionalHeaders(req, opts)
+ // Create a client
+ c := &http.Client{Timeout: timeout * time.Second}
+
// Do request
- resp, respErr := client.Do(req)
- if respErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, respErr)
+ res, err := c.Do(req)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err,
+ fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0)
}
// Request successful, make sure body is closed at the end
- defer resp.Body.Close()
+ defer func() {
+ _ = res.Body.Close()
+ }()
// Return a string
- body, readErr := ioutil.ReadAll(resp.Body)
- if readErr != nil {
- return resp.Header, nil, types.NewWrappedError(errorMessage, readErr)
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ return res.Header, nil, errors.WrapPrefix(err,
+ fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0)
}
- if resp.StatusCode < 200 || resp.StatusCode > 299 {
- // We make this a custom error because we want to extract the status code later
- statusErr := &StatusError{URL: url, Body: string(body), Status: resp.StatusCode}
- return resp.Header, body, types.NewWrappedError(errorMessage, statusErr)
+ if res.StatusCode < 200 || res.StatusCode > 299 {
+ return res.Header, body, errors.Wrap(&StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}, 0)
}
// Return the body in bytes and signal the status error if there was one
- return resp.Header, body, nil
+ return res.Header, body, nil
}
// StatusError indicates that we have received a HTTP status error.
@@ -183,22 +160,3 @@ func (e *StatusError) Error() string {
e.Body,
)
}
-
-// ParseJSONError indicates that the HTTP error is because of failed JSON parsing
-// It has the URL and the Body as context.
-// The underlying JSON parsing Err itself is also wrapped here.
-type ParseJSONError struct {
- URL string
- Body string
- Err error
-}
-
-// Error returns the ParseJSONError as an error string.
-func (e *ParseJSONError) Error() string {
- return fmt.Sprintf(
- "failed parsing json %s for HTTP resource: %s with error: %v",
- e.Body,
- e.URL,
- e.Err,
- )
-}
diff --git a/internal/log/log.go b/internal/log/log.go
index 3c3218c..99d9f79 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -3,15 +3,49 @@ package log
import (
"fmt"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
"io"
"log"
"os"
"path"
"github.com/eduvpn/eduvpn-common/internal/util"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
+type ErrLevel int8
+
+const (
+ ErrOther ErrLevel = iota
+ ErrInfo
+ ErrWarning
+ ErrFatal
+)
+
+func GetErrorLevel(err error) ErrLevel {
+ if err == nil {
+ return ErrOther
+ }
+
+ getLevel := func(e error) ErrLevel {
+ if e == nil {
+ return ErrOther
+ }
+
+ switch e.(type) {
+ case *oauth.CancelledCallbackError:
+ return ErrInfo
+ default:
+ return ErrOther
+ }
+ }
+
+ if err1, ok := err.(*errors.Error); ok {
+ return getLevel(err1.Err)
+ }
+ return getLevel(err)
+}
+
// FileLogger defines the type of logger that this package implements
// As the name suggests, it saves the log to a file.
type FileLogger struct {
@@ -66,42 +100,40 @@ func (e Level) String() string {
// Init initializes the logger by forwarding a max level 'level' and a directory 'directory' where the log should be stored
// If the logger cannot be initialized, for example an error in opening the log file, an error is returned.
-func (logger *FileLogger) Init(level Level, directory string) error {
- errorMessage := "failed creating log"
-
- configDirErr := util.EnsureDirectory(directory)
- if configDirErr != nil {
- return types.NewWrappedError(errorMessage, configDirErr)
+func (logger *FileLogger) Init(lvl Level, dir string) error {
+ err := util.EnsureDirectory(dir)
+ if err != nil {
+ return err
}
- logFile, logOpenErr := os.OpenFile(
- logger.filename(directory),
+ f, err := os.OpenFile(
+ logger.filename(dir),
os.O_RDWR|os.O_CREATE|os.O_APPEND,
0o666,
)
- if logOpenErr != nil {
- return types.NewWrappedError(errorMessage, logOpenErr)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed creating log", 0)
}
- multi := io.MultiWriter(os.Stdout, logFile)
+ multi := io.MultiWriter(os.Stdout, f)
log.SetOutput(multi)
- logger.file = logFile
- logger.Level = level
+ logger.file = f
+ logger.Level = lvl
return nil
}
// Inherit logs an error with a label using the error level of the error.
-func (logger *FileLogger) Inherit(label string, err error) {
- level := types.ErrorLevel(err)
-
- msg := fmt.Sprintf("%s with err: %s", label, types.ErrorTraceback(err))
- switch level {
- case types.ErrInfo:
- logger.Infof(msg)
- case types.ErrWarning:
- logger.Warningf(msg)
- case types.ErrOther:
- logger.Errorf(msg)
- case types.ErrFatal:
- logger.Fatalf(msg)
+func (logger *FileLogger) Inherit(err error) {
+ if err == nil {
+ return
+ }
+ switch GetErrorLevel(err) {
+ case ErrInfo:
+ logger.Infof(err.Error())
+ case ErrWarning:
+ logger.Warningf(err.Error())
+ case ErrOther:
+ logger.Errorf(err.Error())
+ case ErrFatal:
+ logger.Fatalf(err.Error())
}
}
@@ -131,8 +163,8 @@ func (logger *FileLogger) Fatalf(msg string, params ...interface{}) {
}
// Close closes the logger by closing the internal file.
-func (logger *FileLogger) Close() {
- logger.file.Close()
+func (logger *FileLogger) Close() error {
+ return logger.file.Close()
}
// filename returns the filename of the logger by returning the full path as a string.
@@ -141,11 +173,11 @@ func (logger *FileLogger) filename(directory string) string {
}
// log logs as level 'level' a message 'msg' with parameters 'params'.
-func (logger *FileLogger) log(level Level, msg string, params ...interface{}) {
- if level >= logger.Level && logger.Level != LevelNotSet {
- formattedMsg := fmt.Sprintf(msg, params...)
- format := fmt.Sprintf("- Go - %s - %s", level.String(), formattedMsg)
+func (logger *FileLogger) log(lvl Level, msg string, params ...interface{}) {
+ if lvl >= logger.Level && logger.Level != LevelNotSet {
+ fMsg := fmt.Sprintf(msg, params...)
+ f := fmt.Sprintf("- Go - %s - %s", lvl.String(), fMsg)
// To log file
- log.Println(format)
+ log.Println(f)
}
}
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 802295d..3dcd3d3 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -1,4 +1,4 @@
-// package oauth implement an oauth client defined in e.g. rfc 6749
+// Package oauth implement an oauth client defined in e.g. rfc 6749
// However, we try to follow some recommendations from the v2.1 oauth draft RFC
// Some specific things we implement here:
// - PKCE (RFC 7636)
@@ -10,7 +10,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
- "errors"
"fmt"
"html/template"
"net"
@@ -20,7 +19,7 @@ import (
httpw "github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/util"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
// genState generates a random base64 string to be used for state
@@ -31,13 +30,13 @@ import (
// client.
// We implement it similarly to the verifier.
func genState() (string, error) {
- randomBytes, err := util.MakeRandomByteSlice(32)
+ bts, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", types.NewWrappedError("failed generating an OAuth state", err)
+ return "", err
}
- // For consistency we also use raw url encoding here
- return base64.RawURLEncoding.EncodeToString(randomBytes), nil
+ // For consistency, we also use raw url encoding here
+ return base64.RawURLEncoding.EncodeToString(bts), nil
}
// genChallengeS256 generates a sha256 base64 challenge from a verifier
@@ -68,10 +67,7 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", types.NewWrappedError(
- "failed generating an OAuth verifier",
- err,
- )
+ return "", err
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -89,10 +85,10 @@ type OAuth struct {
TokenURL string `json:"token_url"`
// session is the internal in progress OAuth session
- session ExchangeSession `json:"-"`
+ session ExchangeSession
// Token is where the access and refresh tokens are stored along with the timestamps
- token Token `json:"-"`
+ token Token
}
// ExchangeSession is a structure that gets passed to the callback for easy access to the current state.
@@ -126,39 +122,31 @@ type ExchangeSession struct {
// It returns the access token as a string, possibly obtained fresh using the Refresh Token
// If the token cannot be obtained, an error is returned and the token is an empty string.
func (oauth *OAuth) AccessToken() (string, error) {
- errorMessage := "failed getting access token"
- tokens := oauth.token
+ ts := oauth.token
// We have tokens...
// The tokens are not expired yet
- // So they should be valid, re-authorization not needed
- if !tokens.Expired() {
- return tokens.access, nil
+ // So they should be valid, re-login not needed
+ if !ts.Expired() {
+ return ts.access, nil
}
// Check if refresh is even possible by doing a simple check if the refresh token is empty
// This is not needed but reduces API calls to the server
- if tokens.refresh == "" {
- return "", types.NewWrappedError(
- errorMessage,
- &TokensInvalidError{Cause: "no refresh token is present"},
- )
+ if ts.refresh == "" {
+ return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
}
// Otherwise refresh and then later return the access token if we are successful
- refreshErr := oauth.tokensWithRefresh()
- if refreshErr != nil {
+ err := oauth.tokensWithRefresh()
+ if err != nil {
// We have failed to ensure the tokens due to refresh not working
- return "", types.NewWrappedError(
- errorMessage,
- &TokensInvalidError{
- Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
- },
- )
+ return "", errors.Wrap(
+ &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0)
}
// We have obtained new tokens with refresh
- return tokens.access, nil
+ return ts.access, nil
}
// setupListener sets up an OAuth listener
@@ -166,24 +154,22 @@ func (oauth *OAuth) AccessToken() (string, error) {
// @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2
// "Loopback Interface Redirection".
func (oauth *OAuth) setupListener() error {
- errorMessage := "failed setting up listener"
oauth.session.Context = context.Background()
// create a listener
- listener, listenerErr := net.Listen("tcp", "127.0.0.1:0")
- if listenerErr != nil {
- return types.NewWrappedError(errorMessage, listenerErr)
+ lst, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return errors.WrapPrefix(err, "net.Listen failed", 0)
}
- oauth.session.Listener = listener
+ oauth.session.Listener = lst
return nil
}
// tokensWithCallback gets the OAuth tokens using a local web server
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithCallback() error {
- errorMessage := "failed getting tokens with callback"
if oauth.session.Listener == nil {
- return types.NewWrappedError(errorMessage, errors.New("no listener"))
+ return errors.Errorf("failed getting tokens with callback: no listener")
}
mux := http.NewServeMux()
// server /callback over the listener address
@@ -196,7 +182,7 @@ func (oauth *OAuth) tokensWithCallback() error {
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed {
- return types.NewWrappedError(errorMessage, err)
+ return errors.WrapPrefix(err, "failed getting tokens with callback", 0)
}
return oauth.session.CallbackError
}
@@ -205,23 +191,18 @@ func (oauth *OAuth) tokensWithCallback() error {
// It calculates the expired timestamp by having a 'startTime' passed to it
// The URL that is input here is used for additional context.
func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error {
- responseStructure := TokenResponse{}
-
- jsonErr := json.Unmarshal(response, &responseStructure)
- if jsonErr != nil {
- return types.NewWrappedError(
- "failed filling OAuth tokens",
- &httpw.ParseJSONError{URL: url, Body: string(response), Err: jsonErr},
- )
- }
-
- internalStructure := Token{}
- internalStructure.expiredTimestamp = startTime.Add(
- time.Second * time.Duration(responseStructure.Expires),
- )
- internalStructure.access = responseStructure.Access
- internalStructure.refresh = responseStructure.Refresh
- oauth.token = internalStructure
+ res := TokenResponse{}
+
+ err := json.Unmarshal(response, &res)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0)
+ }
+
+ oauth.token = Token{
+ access: res.Access,
+ refresh: res.Refresh,
+ expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)),
+ }
return nil
}
@@ -240,14 +221,13 @@ func (oauth *OAuth) SetTokenRenew() {
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
- errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
- reqURL := oauth.TokenURL
+ u := oauth.TokenURL
- port, portErr := oauth.ListenerPort()
- if portErr != nil {
- return types.NewWrappedError(errorMessage, portErr)
+ port, err := oauth.ListenerPort()
+ if err != nil {
+ return err
}
data := url.Values{
@@ -257,21 +237,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
"grant_type": {"authorization_code"},
"redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
}
- headers := http.Header{
+ h := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
- opts := &httpw.OptionalParams{Headers: headers, Body: data}
- currentTime := time.Now()
- _, body, bodyErr := httpw.PostWithOpts(reqURL, opts)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ opts := &httpw.OptionalParams{Headers: h, Body: data}
+ now := time.Now()
+ _, body, err := httpw.PostWithOpts(u, opts)
+ if err != nil {
+ return err
}
- fillErr := oauth.fillToken(body, currentTime, reqURL)
- if fillErr != nil {
- return types.NewWrappedError(errorMessage, fillErr)
- }
- return nil
+ return oauth.fillToken(body, now, u)
}
// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token
@@ -279,27 +255,22 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
func (oauth *OAuth) tokensWithRefresh() error {
- errorMessage := "failed getting tokens with the refresh token"
- reqURL := oauth.TokenURL
+ u := oauth.TokenURL
data := url.Values{
"refresh_token": {oauth.token.refresh},
"grant_type": {"refresh_token"},
}
- headers := http.Header{
+ h := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
- opts := &httpw.OptionalParams{Headers: headers, Body: data}
- currentTime := time.Now()
- _, body, bodyErr := httpw.PostWithOpts(reqURL, opts)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ opts := &httpw.OptionalParams{Headers: h, Body: data}
+ now := time.Now()
+ _, body, err := httpw.PostWithOpts(u, opts)
+ if err != nil {
+ return err
}
- fillErr := oauth.fillToken(body, currentTime, reqURL)
- if fillErr != nil {
- return types.NewWrappedError(errorMessage, fillErr)
- }
- return nil
+ return oauth.fillToken(body, now, u)
}
// responseTemplate is the HTML template for the OAuth authorized response
@@ -349,27 +320,17 @@ type oauthResponseHTML struct {
// writeResponseHTML writes the OAuth response using a response writer and the title + message
// If it was unsuccessful it returns an error.
func writeResponseHTML(w http.ResponseWriter, title string, message string) error {
- errorMessage := "failed writing response HTML"
- template, templateErr := template.New("oauth-response").Parse(responseTemplate)
- if templateErr != nil {
- return types.NewWrappedError(errorMessage, templateErr)
+ t, err := template.New("oauth-response").Parse(responseTemplate)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed writing response HTML", 0)
}
- executeErr := template.Execute(w, oauthResponseHTML{
- Title: title,
- Message: message,
- })
- if executeErr != nil {
- return types.NewWrappedError(errorMessage, executeErr)
- }
- return nil
+ return t.Execute(w, oauthResponseHTML{Title: title, Message: message})
}
// Callback is the public function used to get the OAuth tokens using an authorization code callback
// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
- errorMessage := "failed callback to retrieve the authorization code"
-
// Shutdown after we're done
defer func() {
// writing the html is best effort
@@ -383,64 +344,49 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
_ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
}
if oauth.session.Server != nil {
- go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ go func() {
+ _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ }()
}
}()
// ISS: https://www.rfc-editor.org/rfc/rfc9207.html
// TODO: Make this a required parameter in the future
- urlQuery := req.URL.Query()
- extractedISS := urlQuery.Get("iss")
- if extractedISS != "" {
- if oauth.session.ISS != extractedISS {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS},
- )
+ q := req.URL.Query()
+ iss := q.Get("iss")
+ if iss != "" {
+ if oauth.session.ISS != iss {
+ oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'",
+ oauth.session.ISS, iss)
return
}
}
// Make sure the state is present and matches to protect against cross-site request forgeries
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
- extractedState := urlQuery.Get("state")
- if extractedState == "" {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackParameterError{Parameter: "state", URL: req.URL.String()},
- )
+ state := q.Get("state")
+ if state == "" {
+ oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL)
return
}
// The state is the first entry
- if extractedState != oauth.session.State {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackStateMatchError{
- State: extractedState,
- ExpectedState: oauth.session.State,
- },
- )
+ if state != oauth.session.State {
+ oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'",
+ oauth.session.State, state)
return
}
// No authorization code
- extractedCode := urlQuery.Get("code")
- if extractedCode == "" {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- &CallbackParameterError{Parameter: "code", URL: req.URL.String()},
- )
+ code := q.Get("code")
+ if code == "" {
+ oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL)
return
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- getTokensErr := oauth.tokensWithAuthCode(extractedCode)
- if getTokensErr != nil {
- oauth.session.CallbackError = types.NewWrappedError(
- errorMessage,
- getTokensErr,
- )
+ if err := oauth.tokensWithAuthCode(code); err != nil {
+ oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0)
return
}
}
@@ -457,94 +403,78 @@ func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL strin
// ListenerPort gets the listener for the OAuth web server
// It returns the port as an integer and an error if there is any.
-func (oauth OAuth) ListenerPort() (int, error) {
- errorMessage := "failed to get listener port"
-
+func (oauth *OAuth) ListenerPort() (int, error) {
if oauth.session.Listener == nil {
- return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener"))
+ return 0, errors.Errorf("failed to get listener port")
}
return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil
}
// AuthURL gets the authorization url to start the OAuth procedure.
func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (string, error) {
- errorMessage := "failed starting OAuth exchange"
-
// Generate the verifier and challenge
- verifier, verifierErr := genVerifier()
- if verifierErr != nil {
- return "", types.NewWrappedError(errorMessage, verifierErr)
+ v, err := genVerifier()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "genVerifier error", 0)
}
- challenge := genChallengeS256(verifier)
// Generate the state
- state, stateErr := genState()
- if stateErr != nil {
- return "", types.NewWrappedError(errorMessage, stateErr)
+ state, err := genState()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "genState error", 0)
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
- oauthSession := ExchangeSession{
+ oauth.session = ExchangeSession{
ClientID: name,
ISS: oauth.ISS,
State: state,
- Verifier: verifier,
+ Verifier: v,
}
- oauth.session = oauthSession
// set up the listener to get the redirect URI
- listenerErr := oauth.setupListener()
- if listenerErr != nil {
- return "", types.NewWrappedError(errorMessage, stateErr)
+ if err = oauth.setupListener(); err != nil {
+ return "", errors.WrapPrefix(err, "oauth.setupListener error", 0)
}
// Get the listener port
- port, portErr := oauth.ListenerPort()
- if portErr != nil {
- return "", types.NewWrappedError(errorMessage, portErr)
+ port, err := oauth.ListenerPort()
+ if err != nil {
+ return "", errors.WrapPrefix(err, "oauth.ListenerPort error", 0)
}
- parameters := map[string]string{
+ params := map[string]string{
"client_id": name,
"code_challenge_method": "S256",
- "code_challenge": challenge,
+ "code_challenge": genChallengeS256(v),
"response_type": "code",
"scope": "config",
"state": state,
"redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port),
}
- authURL, urlErr := httpw.ConstructURL(oauth.BaseAuthorizationURL, parameters)
+ u, err := httpw.ConstructURL(oauth.BaseAuthorizationURL, params)
- if urlErr != nil {
- return "", types.NewWrappedError(errorMessage, urlErr)
+ if err != nil {
+ return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0)
}
// Return the url processed
- return postProcessAuth(authURL), nil
+ return postProcessAuth(u), nil
}
// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
// If it was unsuccessful it returns an error.
func (oauth *OAuth) Exchange() error {
- tokenErr := oauth.tokensWithCallback()
-
- if tokenErr != nil {
- return types.NewWrappedError("failed finishing OAuth", tokenErr)
- }
- return nil
+ return oauth.tokensWithCallback()
}
// Cancel cancels the existing OAuth
// TODO: Use context for this.
func (oauth *OAuth) Cancel() {
- oauth.session.CallbackError = types.NewWrappedErrorLevel(
- types.ErrInfo,
- "cancelled OAuth",
- &CancelledCallbackError{},
- )
+ oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0)
if oauth.session.Server != nil {
- oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
+ _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}
@@ -554,33 +484,6 @@ func (e *CancelledCallbackError) Error() string {
return "client cancelled OAuth"
}
-type CallbackParameterError struct {
- Parameter string
- URL string
-}
-
-func (e *CallbackParameterError) Error() string {
- return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL)
-}
-
-type CallbackStateMatchError struct {
- State string
- ExpectedState string
-}
-
-func (e *CallbackStateMatchError) Error() string {
- return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState)
-}
-
-type CallbackISSMatchError struct {
- ISS string
- ExpectedISS string
-}
-
-func (e *CallbackISSMatchError) Error() string {
- return fmt.Sprintf("failed matching ISS, got: %s, want: %s", e.ISS, e.ExpectedISS)
-}
-
type TokensInvalidError struct {
Cause string
}
diff --git a/internal/server/api.go b/internal/server/api.go
index 21ba6f4..dfa8e14 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -2,7 +2,6 @@ package server
import (
"encoding/json"
- "errors"
"fmt"
"net/http"
"net/url"
@@ -10,130 +9,114 @@ import (
"time"
httpw "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
func APIGetEndpoints(baseURL string) (*Endpoints, error) {
- errorMessage := "failed getting server endpoints"
- url, urlErr := url.Parse(baseURL)
- if urlErr != nil {
- return nil, types.NewWrappedError(errorMessage, urlErr)
+ u, err := url.Parse(baseURL)
+ if err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- wellKnownPath := "/.well-known/vpn-user-portal"
+ wk := "/.well-known/vpn-user-portal"
- url.Path = path.Join(url.Path, wellKnownPath)
- _, body, bodyErr := httpw.Get(url.String())
-
- if bodyErr != nil {
- return nil, types.NewWrappedError(errorMessage, bodyErr)
+ u.Path = path.Join(u.Path, wk)
+ _, body, err := httpw.Get(u.String())
+ if err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- endpoints := &Endpoints{}
- jsonErr := json.Unmarshal(body, endpoints)
-
- if jsonErr != nil {
- return nil, types.NewWrappedError(errorMessage, jsonErr)
+ ep := &Endpoints{}
+ if err = json.Unmarshal(body, ep); err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- return endpoints, nil
+ return ep, nil
}
func apiAuthorized(
- server Server,
+ srv Server,
method string,
endpoint string,
opts *httpw.OptionalParams,
) (http.Header, []byte, error) {
- errorMessage := "failed API authorized"
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
opts = &httpw.OptionalParams{}
}
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
// Join the paths
- url, urlErr := url.Parse(base.Endpoints.API.V3.API)
- if urlErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, urlErr)
+ u, err := url.Parse(b.Endpoints.API.V3.API)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
- url.Path = path.Join(url.Path, endpoint)
+ u.Path = path.Join(u.Path, endpoint)
// Make sure the tokens are valid, this will return an error if re-login is needed
- token, tokenErr := HeaderToken(server)
- if tokenErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, tokenErr)
+ t, err := HeaderToken(srv)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
- headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", token)
+ key := "Authorization"
+ val := fmt.Sprintf("Bearer %s", t)
if opts.Headers != nil {
- opts.Headers.Add(headerKey, headerValue)
+ opts.Headers.Add(key, val)
} else {
- opts.Headers = http.Header{headerKey: {headerValue}}
+ opts.Headers = http.Header{key: {val}}
}
- return httpw.MethodWithOpts(method, url.String(), opts)
+ return httpw.MethodWithOpts(method, u.String(), opts)
}
func apiAuthorizedRetry(
- server Server,
+ srv Server,
method string,
endpoint string,
opts *httpw.OptionalParams,
) (http.Header, []byte, error) {
- errorMessage := "failed authorized API retry"
- header, body, bodyErr := apiAuthorized(server, method, endpoint, opts)
-
- if bodyErr != nil {
- var error *httpw.StatusError
-
- // Only retry authorized if we get a HTTP 401
- if errors.As(bodyErr, &error) && error.Status == 401 {
- // Mark the token as expired and retry so we trigger the refresh flow
- MarkTokenExpired(server)
- retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
- if retryErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, retryErr)
- }
- return retryHeader, retryBody, nil
- }
- return nil, nil, types.NewWrappedError(errorMessage, bodyErr)
- }
- return header, body, nil
-}
-
-func APIInfo(server Server) error {
- errorMessage := "failed API /info"
- _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ h, body, err := apiAuthorized(srv, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
}
- structure := ProfileInfo{}
- jsonErr := json.Unmarshal(body, &structure)
- if jsonErr != nil {
- return types.NewWrappedError(errorMessage, jsonErr)
+ statErr := &httpw.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ MarkTokenExpired(srv)
+ h, body, err = apiAuthorized(srv, method, endpoint, opts)
}
+ return h, body, err
+}
- base, baseErr := server.Base()
+func APIInfo(srv Server) error {
+ _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil)
+ if err != nil {
+ return err
+ }
+ pi := ProfileInfo{}
+ if err = json.Unmarshal(body, &pi); err != nil {
+ return errors.WrapPrefix(err, "failed API /info", 0)
+ }
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
// Store the profiles and make sure that the current profile is not overwritten
- previousProfile := base.Profiles.Current
- base.Profiles = structure
- base.Profiles.Current = previousProfile
+ prev := b.Profiles.Current
+ b.Profiles = pi
+ b.Profiles.Current = prev
return nil
}
// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
-func GetPreferTCPString(preferTCP bool) string {
+func boolToYesNo(preferTCP bool) string {
if preferTCP {
return "yes"
}
@@ -141,88 +124,77 @@ func GetPreferTCPString(preferTCP bool) string {
}
func APIConnectWireguard(
- server Server,
+ srv Server,
profileID string,
pubkey string,
preferTCP bool,
- supportsOpenVPN bool,
+ openVPNSupport bool,
) (string, string, time.Time, error) {
- errorMessage := "failed obtaining a WireGuard configuration"
- headers := http.Header{
+ hdrs := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
}
// This profile also supports OpenVPN
// Indicate that we also accept OpenVPN profiles
- if supportsOpenVPN {
- headers.Add("accept", "application/x-openvpn-profile")
+ if openVPNSupport {
+ hdrs.Add("accept", "application/x-openvpn-profile")
}
- urlForm := url.Values{
+ vals := url.Values{
"profile_id": {profileID},
"public_key": {pubkey},
- "prefer_tcp": {GetPreferTCPString(preferTCP)},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
}
- header, connectBody, connectErr := apiAuthorizedRetry(
- server,
- http.MethodPost,
- "/connect",
- &httpw.OptionalParams{Headers: headers, Body: urlForm},
- )
- if connectErr != nil {
- return "", "", time.Time{}, types.NewWrappedError(
- errorMessage,
- connectErr,
- )
+ h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", "", time.Time{}, err
}
- expires := header.Get("expires")
+ exp := h.Get("expires")
- pTime, pTimeErr := http.ParseTime(expires)
- if pTimeErr != nil {
- return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
+ ptm, err := http.ParseTime(exp)
+ if err != nil {
+ return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0)
}
- contentType := header.Get("content-type")
-
- content := "openvpn"
- if contentType == "application/x-wireguard-profile" {
- content = "wireguard"
+ ct := h.Get("content-type")
+ c := "openvpn"
+ if ct == "application/x-wireguard-profile" {
+ c = "wireguard"
}
- return string(connectBody), content, pTime, nil
+
+ return string(body), c, ptm, nil
}
-func APIConnectOpenVPN(server Server, profileID string, preferTCP bool) (string, time.Time, error) {
- errorMessage := "failed obtaining an OpenVPN configuration"
- headers := http.Header{
+func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) {
+ hdrs := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
}
- urlForm := url.Values{
+ vals := url.Values{
"profile_id": {profileID},
- "prefer_tcp": {GetPreferTCPString(preferTCP)},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
}
- header, connectBody, connectErr := apiAuthorizedRetry(
- server,
- http.MethodPost,
- "/connect",
- &httpw.OptionalParams{Headers: headers, Body: urlForm},
- )
- if connectErr != nil {
- return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr)
+ h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", time.Time{}, err
}
- expires := header.Get("expires")
- pTime, pTimeErr := http.ParseTime(expires)
- if pTimeErr != nil {
- return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
+ exp := h.Get("expires")
+ ptm, err := http.ParseTime(exp)
+ if err != nil {
+ return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0)
}
- return string(connectBody), pTime, nil
+
+ return string(body), ptm, nil
}
+// APIDisconnect disconnects from the API.
// This needs no further return value as it's best effort.
func APIDisconnect(server Server) {
_, _, _ = apiAuthorized(server, http.MethodPost, "/disconnect", nil)
diff --git a/internal/server/base.go b/internal/server/base.go
index bb88eb3..81049cf 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -2,11 +2,9 @@ package server
import (
"time"
-
- "github.com/eduvpn/eduvpn-common/types"
)
-// The base type for servers.
+// Base is the base type for servers.
type Base struct {
URL string `json:"base_url"`
DisplayName map[string]string `json:"display_name"`
@@ -18,28 +16,27 @@ type Base struct {
Type string `json:"server_type"`
}
-func (base *Base) InitializeEndpoints() error {
- errorMessage := "failed initializing endpoints"
- endpoints, endpointsErr := APIGetEndpoints(base.URL)
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+func (b *Base) InitializeEndpoints() error {
+ ep, err := APIGetEndpoints(b.URL)
+ if err != nil {
+ return err
}
- base.Endpoints = *endpoints
+ b.Endpoints = *ep
return nil
}
-func (base *Base) ValidProfiles(clientSupportsWireguard bool) ProfileInfo {
- var validProfiles []Profile
- for _, profile := range base.Profiles.Info.ProfileList {
+func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo {
+ var vps []Profile
+ for _, p := range b.Profiles.Info.ProfileList {
// Not a valid profile because it does not support openvpn
// Also the client does not support wireguard
- if !profile.supportsOpenVPN() && !clientSupportsWireguard {
+ if !p.supportsOpenVPN() && !wireguardSupport {
continue
}
- validProfiles = append(validProfiles, profile)
+ vps = append(vps, p)
}
return ProfileInfo{
- Current: base.Profiles.Current,
- Info: ProfileListInfo{ProfileList: validProfiles},
+ Current: b.Profiles.Current,
+ Info: ProfileListInfo{ProfileList: vps},
}
}
diff --git a/internal/server/custom.go b/internal/server/custom.go
index d376727..bf0b230 100644
--- a/internal/server/custom.go
+++ b/internal/server/custom.go
@@ -1,42 +1,35 @@
package server
import (
- "errors"
- "fmt"
-
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-func (servers *Servers) SetCustomServer(server Server) error {
- errorMessage := "failed setting custom server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetCustomServer(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
}
- if base.Type != "custom_server" {
- return types.NewWrappedError(errorMessage, errors.New("not a custom server"))
+ if b.Type != "custom_server" {
+ return errors.WrapPrefix(err, "not a custom server", 0)
}
- if _, ok := servers.CustomServers.Map[base.URL]; ok {
- servers.CustomServers.CurrentURL = base.URL
- servers.IsType = CustomServerType
+ if _, ok := ss.CustomServers.Map[b.URL]; ok {
+ ss.CustomServers.CurrentURL = b.URL
+ ss.IsType = CustomServerType
} else {
- return types.NewWrappedError(errorMessage, errors.New("not a custom server"))
+ return errors.Errorf("not a custom server")
}
return nil
}
-func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) {
- if server, ok := servers.CustomServers.Map[url]; ok {
- return server, nil
+func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) {
+ if srv, ok := ss.CustomServers.Map[url]; ok {
+ return srv, nil
}
- return nil, types.NewWrappedError(
- "failed to get institute access server",
- fmt.Errorf("no custom server with URL: %s", url),
- )
+ return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url)
}
-func (servers *Servers) RemoveCustomServer(url string) {
- servers.CustomServers.Remove(url)
+func (ss *Servers) RemoveCustomServer(url string) {
+ ss.CustomServers.Remove(url)
}
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
index 9b6f735..56ed1cf 100644
--- a/internal/server/instituteaccess.go
+++ b/internal/server/instituteaccess.go
@@ -1,14 +1,10 @@
package server
import (
- "errors"
- "fmt"
-
"github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-// An instute access server.
type InstituteAccessServer struct {
// An instute access server has its own OAuth
Auth oauth.OAuth `json:"oauth"`
@@ -22,80 +18,75 @@ type InstituteAccessServers struct {
CurrentURL string `json:"current_url"`
}
-func (servers *Servers) SetInstituteAccess(server Server) error {
- errorMessage := "failed setting institute access server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetInstituteAccess(srv Server) error {
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- if base.Type != "institute_access" {
- return types.NewWrappedError(errorMessage, errors.New("not an institute access server"))
+ if b.Type != "institute_access" {
+ return errors.Errorf("not an institute access server")
}
- if _, ok := servers.InstituteServers.Map[base.URL]; ok {
- servers.InstituteServers.CurrentURL = base.URL
- servers.IsType = InstituteAccessServerType
+ if _, ok := ss.InstituteServers.Map[b.URL]; ok {
+ ss.InstituteServers.CurrentURL = b.URL
+ ss.IsType = InstituteAccessServerType
} else {
- return types.NewWrappedError(errorMessage, errors.New("no such institute access server"))
+ return errors.Errorf("no such institute access server")
}
return nil
}
-func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) {
- if server, ok := servers.InstituteServers.Map[url]; ok {
- return server, nil
+func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) {
+ if srv, ok := ss.InstituteServers.Map[url]; ok {
+ return srv, nil
}
- return nil, types.NewWrappedError(
- "failed to get institute access server",
- fmt.Errorf("no institute access server with URL: %s", url),
- )
+ return nil, errors.Errorf("no institute access server with URL: %s", url)
}
-func (servers *Servers) RemoveInstituteAccess(url string) {
- servers.InstituteServers.Remove(url)
+func (ss *Servers) RemoveInstituteAccess(url string) {
+ ss.InstituteServers.Remove(url)
}
-func (servers *InstituteAccessServers) Remove(url string) {
+func (iass *InstituteAccessServers) Remove(url string) {
// Reset the current url
- if servers.CurrentURL == url {
- servers.CurrentURL = ""
+ if iass.CurrentURL == url {
+ iass.CurrentURL = ""
}
// Delete the url from the map
- delete(servers.Map, url)
+ delete(iass.Map, url)
}
-func (institute *InstituteAccessServer) TemplateAuth() func(string) string {
+func (ias *InstituteAccessServer) TemplateAuth() func(string) string {
return func(authURL string) string {
return authURL
}
}
-func (institute *InstituteAccessServer) Base() (*Base, error) {
- return &institute.Basic, nil
+func (ias *InstituteAccessServer) Base() (*Base, error) {
+ return &ias.Basic, nil
}
-func (institute *InstituteAccessServer) OAuth() *oauth.OAuth {
- return &institute.Auth
+func (ias *InstituteAccessServer) OAuth() *oauth.OAuth {
+ return &ias.Auth
}
-func (institute *InstituteAccessServer) init(
+func (ias *InstituteAccessServer) init(
url string,
- displayName map[string]string,
- serverType string,
+ name map[string]string,
+ srvType string,
supportContact []string,
) error {
- errorMessage := fmt.Sprintf("failed initializing server %s", url)
- institute.Basic.URL = url
- institute.Basic.DisplayName = displayName
- institute.Basic.SupportContact = supportContact
- institute.Basic.Type = serverType
- endpointsErr := institute.Basic.InitializeEndpoints()
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+ ias.Basic.URL = url
+ ias.Basic.DisplayName = name
+ ias.Basic.SupportContact = supportContact
+ ias.Basic.Type = srvType
+ err := ias.Basic.InitializeEndpoints()
+ if err != nil {
+ return err
}
- API := institute.Basic.Endpoints.API.V3
- institute.Auth.Init(url, API.Authorization, API.Token)
+ API := ias.Basic.Endpoints.API.V3
+ ias.Auth.Init(url, API.Authorization, API.Token)
return nil
}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 998390d..12263a6 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -1,15 +1,13 @@
package server
import (
- "errors"
- "fmt"
-
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/util"
"github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-// A secure internet server which has its own OAuth tokens
+// SecureInternetHomeServer secure internet server which has its own OAuth tokens
// It specifies the current location url it is connected to.
type SecureInternetHomeServer struct {
Auth oauth.OAuth `json:"oauth"`
@@ -24,150 +22,111 @@ type SecureInternetHomeServer struct {
CurrentLocation string `json:"current_location"`
}
-func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) {
- if !servers.HasSecureLocation() {
- return nil, errors.New("no secure internet home server")
+func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) {
+ if !ss.HasSecureLocation() {
+ return nil, errors.Errorf("no secure internet home server")
}
- return &servers.SecureInternetHomeServer, nil
+ return &ss.SecureInternetHomeServer, nil
}
-func (servers *Servers) SetSecureInternet(server Server) error {
- errorMessage := "failed setting secure internet server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetSecureInternet(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
}
- if base.Type != "secure_internet" {
- return types.NewWrappedError(errorMessage, errors.New("not a secure internet server"))
+ if b.Type != "secure_internet" {
+ return errors.Errorf("not a secure internet server")
}
// The location should already be configured
// TODO: check for location?
- servers.IsType = SecureInternetServerType
+ ss.IsType = SecureInternetServerType
return nil
}
-func (servers *Servers) RemoveSecureInternet() {
+func (ss *Servers) RemoveSecureInternet() {
// Empty out the struct
- servers.SecureInternetHomeServer = SecureInternetHomeServer{}
+ ss.SecureInternetHomeServer = SecureInternetHomeServer{}
// If the current server is secure internet, default to custom server
- if servers.IsType == SecureInternetServerType {
- servers.IsType = CustomServerType
+ if ss.IsType == SecureInternetServerType {
+ ss.IsType = CustomServerType
}
}
-func (server *SecureInternetHomeServer) TemplateAuth() func(string) string {
+func (s *SecureInternetHomeServer) TemplateAuth() func(string) string {
return func(authURL string) string {
- return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID)
+ return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID)
}
}
-func (server *SecureInternetHomeServer) Base() (*Base, error) {
- errorMessage := "failed getting current secure internet home base"
- if server.BaseMap == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &SecureInternetMapNotFoundError{},
- )
+func (s *SecureInternetHomeServer) Base() (*Base, error) {
+ if s.BaseMap == nil {
+ return nil, errors.Errorf("secure internet map not found")
}
- base, exists := server.BaseMap[server.CurrentLocation]
-
- if !exists {
- return nil, types.NewWrappedError(
- errorMessage,
- &SecureInternetBaseNotFoundError{Current: server.CurrentLocation},
- )
+ b, ok := s.BaseMap[s.CurrentLocation]
+ if !ok {
+ return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation)
}
- return base, nil
+ return b, nil
}
-func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth {
- return &server.Auth
+func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth {
+ return &s.Auth
}
-func (servers *Servers) HasSecureLocation() bool {
- return servers.SecureInternetHomeServer.CurrentLocation != ""
+func (ss *Servers) HasSecureLocation() bool {
+ return ss.SecureInternetHomeServer.CurrentLocation != ""
}
-func (server *SecureInternetHomeServer) addLocation(
- locationServer *types.DiscoveryServer,
-) (*Base, error) {
- errorMessage := "failed adding a location"
+func (s *SecureInternetHomeServer) addLocation(locSrv *types.DiscoveryServer) (*Base, error) {
// Initialize the base map if it is non-nil
- if server.BaseMap == nil {
- server.BaseMap = make(map[string]*Base)
+ if s.BaseMap == nil {
+ s.BaseMap = make(map[string]*Base)
}
// Add the location to the base map
- base, exists := server.BaseMap[locationServer.CountryCode]
-
- if !exists || base == nil {
+ b, ok := s.BaseMap[locSrv.CountryCode]
+ if !ok || b == nil {
// Create the base to be added to the map
- base = &Base{}
- base.URL = locationServer.BaseURL
- base.DisplayName = server.DisplayName
- base.SupportContact = locationServer.SupportContact
- base.Type = "secure_internet"
- endpointsErr := base.InitializeEndpoints()
- if endpointsErr != nil {
- return nil, types.NewWrappedError(errorMessage, endpointsErr)
+ b = &Base{}
+ b.URL = locSrv.BaseURL
+ b.DisplayName = s.DisplayName
+ b.SupportContact = locSrv.SupportContact
+ b.Type = "secure_internet"
+ if err := b.InitializeEndpoints(); err != nil {
+ return nil, err
}
}
// Ensure it is in the map
- server.BaseMap[locationServer.CountryCode] = base
- return base, nil
+ s.BaseMap[locSrv.CountryCode] = b
+ return b, nil
}
// Initializes the home server and adds its own location.
-func (server *SecureInternetHomeServer) init(
- homeOrg *types.DiscoveryOrganization,
- homeLocation *types.DiscoveryServer,
-) error {
- errorMessage := "failed initializing secure internet home server"
-
- if server.HomeOrganizationID != homeOrg.OrgID {
+func (s *SecureInternetHomeServer) init(
+ homeOrg *types.DiscoveryOrganization, homeLoc *types.DiscoveryServer) error {
+ if s.HomeOrganizationID != homeOrg.OrgID {
// New home organisation, clear everything
- *server = SecureInternetHomeServer{}
+ *s = SecureInternetHomeServer{}
}
// Make sure to set the organization ID
- server.HomeOrganizationID = homeOrg.OrgID
- server.DisplayName = homeOrg.DisplayName
+ s.HomeOrganizationID = homeOrg.OrgID
+ s.DisplayName = homeOrg.DisplayName
// Make sure to set the authorization URL template
- server.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate
+ s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate
- base, baseErr := server.addLocation(homeLocation)
-
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+ b, err := s.addLocation(homeLoc)
+ if err != nil {
+ return err
}
// Make sure oauth contains our endpoints
- server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token)
+ s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token)
return nil
}
-
-type SecureInternetHomeNotFoundError struct{}
-
-func (e *SecureInternetHomeNotFoundError) Error() string {
- return "failed to get secure internet home server, not found"
-}
-
-type SecureInternetMapNotFoundError struct{}
-
-func (e *SecureInternetMapNotFoundError) Error() string {
- return "secure internet map not found"
-}
-
-type SecureInternetBaseNotFoundError struct {
- Current string
-}
-
-func (e *SecureInternetBaseNotFoundError) Error() string {
- return fmt.Sprintf("secure internet base not found with current location: %s", e.Current)
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index 95244d5..de0fa9a 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,11 @@
package server
import (
- "errors"
- "fmt"
"time"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type Type int8
@@ -21,10 +19,10 @@ const (
type Server interface {
OAuth() *oauth.OAuth
- // Get the authorization URL template function
+ // TemplateAuth returns the authorization URL template function
TemplateAuth() func(string) string
- // Gets the server base
+ // Base returns the server base
Base() (*Base, error)
}
@@ -34,7 +32,7 @@ type EndpointList struct {
Token string `json:"token_endpoint"`
}
-// Struct that defines the json format for /.well-known/vpn-user-portal".
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
type Endpoints struct {
API struct {
V2 EndpointList `json:"http://eduvpn.org/api#2"`
@@ -43,310 +41,226 @@ type Endpoints struct {
V string `json:"v"`
}
-func ShouldRenewButton(server Server) bool {
- base, baseErr := server.Base()
-
- if baseErr != nil {
+func ShouldRenewButton(srv Server) bool {
+ b, err := srv.Base()
+ if err != nil {
// FIXME: Log error here?
return false
}
// Get current time
- current := time.Now()
+ now := time.Now()
// Session is expired
- if !current.Before(base.EndTime) {
+ if !now.Before(b.EndTime) {
return true
}
// 30 minutes have not passed
- if !current.After(base.StartTime.Add(30 * time.Minute)) {
+ if !now.After(b.StartTime.Add(30 * time.Minute)) {
return false
}
// Session will not expire today
- if !current.Add(24 * time.Hour).After(base.EndTime) {
+ if !now.Add(24 * time.Hour).After(b.EndTime) {
return false
}
// Session duration is less than 24 hours but not 75% has passed
- duration := base.EndTime.Sub(base.StartTime)
- percentTime := base.StartTime.Add((duration / 4) * 3)
- if duration < time.Duration(24*time.Hour) && !current.After(percentTime) {
+ d := b.EndTime.Sub(b.StartTime)
+ pct := b.StartTime.Add((d / 4) * 3)
+ if d < 24*time.Hour && !now.After(pct) {
return false
}
return true
}
-func OAuthURL(server Server, name string) (string, error) {
- return server.OAuth().AuthURL(name, server.TemplateAuth())
+func OAuthURL(srv Server, name string) (string, error) {
+ return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
-func OAuthExchange(server Server) error {
- return server.OAuth().Exchange()
+func OAuthExchange(srv Server) error {
+ return srv.OAuth().Exchange()
}
-func HeaderToken(server Server) (string, error) {
- token, tokenErr := server.OAuth().AccessToken()
- if tokenErr != nil {
- return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr)
- }
- return token, nil
+func HeaderToken(srv Server) (string, error) {
+ return srv.OAuth().AccessToken()
}
-func MarkTokenExpired(server Server) {
- server.OAuth().SetTokenExpired()
+func MarkTokenExpired(srv Server) {
+ srv.OAuth().SetTokenExpired()
}
-func MarkTokensForRenew(server Server) {
- server.OAuth().SetTokenRenew()
+func MarkTokensForRenew(srv Server) {
+ srv.OAuth().SetTokenRenew()
}
-func NeedsRelogin(server Server) bool {
- _, tokenErr := HeaderToken(server)
- return tokenErr != nil
+func NeedsRelogin(srv Server) bool {
+ _, err := HeaderToken(srv)
+ return err != nil
}
-func CancelOAuth(server Server) {
- server.OAuth().Cancel()
+func CancelOAuth(srv Server) {
+ srv.OAuth().Cancel()
}
-func CurrentProfile(server Server) (*Profile, error) {
- errorMessage := "failed getting current profile"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+func CurrentProfile(srv Server) (*Profile, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profileID := base.Profiles.Current
- for _, profile := range base.Profiles.Info.ProfileList {
- if profile.ID == profileID {
+ pid := b.Profiles.Current
+ for _, profile := range b.Profiles.Info.ProfileList {
+ if profile.ID == pid {
return &profile, nil
}
}
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentProfileNotFoundError{ProfileID: profileID},
- )
+ return nil, errors.Errorf("profile not found: " + pid)
}
-func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) {
- errorMessage := "failed to get valid profiles"
+func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
// No error wrapping here otherwise we wrap it too much
- base, baseErr := server.Base()
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profiles := base.ValidProfiles(clientSupportsWireguard)
- if len(profiles.Info.ProfileList) == 0 {
- return nil, types.NewWrappedError(
- errorMessage,
- errors.New("no profiles found with supported protocols"),
- )
+ ps := b.ValidProfiles(wireguardSupport)
+ if len(ps.Info.ProfileList) == 0 {
+ return nil, errors.Errorf("no profiles found with supported protocols")
}
- return &profiles, nil
+ return &ps, nil
}
-func wireguardGetConfig(
- server Server,
- preferTCP bool,
- supportsOpenVPN bool,
-) (string, string, error) {
- errorMessage := "failed getting server WireGuard configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- wireguardKey, wireguardErr := wireguard.GenerateKey()
-
- if wireguardErr != nil {
- return "", "", types.NewWrappedError(errorMessage, wireguardErr)
+ pid := b.Profiles.Current
+ key, err := wireguard.GenerateKey()
+ if err != nil {
+ return "", "", err
}
- wireguardPublicKey := wireguardKey.PublicKey().String()
- config, content, expires, configErr := APIConnectWireguard(
- server,
- profileID,
- wireguardPublicKey,
- preferTCP,
- supportsOpenVPN,
- )
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ pub := key.PublicKey().String()
+ cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport)
+ if err != nil {
+ return "", "", err
}
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if content == "wireguard" {
+ if ct == "wireguard" {
// This needs the go code a way to identify a connection
// Use the uuid of the connection e.g. on Linux
// This needs the client code to call the go code
- config = wireguard.ConfigAddKey(config, wireguardKey)
+ cfg = wireguard.ConfigAddKey(cfg, key)
}
- return config, content, nil
+ return cfg, ct, nil
}
-func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting server OpenVPN configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP)
+ pid := b.Profiles.Current
+ cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ if err != nil {
+ return "", "", err
}
- return configOpenVPN, "openvpn", nil
+ return cfg, "openvpn", nil
}
-func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) {
- errorMessage := "failed has valid profile check"
-
+func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
// Get new profiles using the info call
// This does not override the current profile
- infoErr := APIInfo(server)
- if infoErr != nil {
- return false, types.NewWrappedError(errorMessage, infoErr)
+ err := APIInfo(srv)
+ if err != nil {
+ return false, err
}
- base, baseErr := server.Base()
- if baseErr != nil {
- return false, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return false, err
}
// If there was a profile chosen and it doesn't exist anymore, reset it
- if base.Profiles.Current != "" {
- _, existsProfileErr := CurrentProfile(server)
- if existsProfileErr != nil {
- base.Profiles.Current = ""
+ if b.Profiles.Current != "" {
+ if _, err = CurrentProfile(srv); err != nil {
+ b.Profiles.Current = ""
}
}
- // Set the current profile if there is only one profile or profile is already selected
- if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" {
- // Set the first profile if none is selected
- if base.Profiles.Current == "" {
- base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
- }
- profile, profileErr := CurrentProfile(server)
- // shouldn't happen
- if profileErr != nil {
- return false, types.NewWrappedError(errorMessage, profileErr)
- }
- // Profile does not support OpenVPN but the client also doesn't support WireGuard
- if !profile.supportsOpenVPN() && !clientSupportsWireguard {
- return false, nil
- }
- return true, nil
+ if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" {
+ return false, nil
}
- return false, nil
+ // Set the current profile if there is only one profile or profile is already selected
+ // Set the first profile if none is selected
+ if b.Profiles.Current == "" {
+ b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID
+ }
+ p, err := CurrentProfile(srv)
+ // shouldn't happen
+ if err != nil {
+ return false, err
+ }
+ // Profile does not support OpenVPN but the client also doesn't support WireGuard
+ if !p.supportsOpenVPN() && !wireguardSupport {
+ return false, nil
+ }
+ return true, nil
}
-func RefreshEndpoints(server Server) error {
- errorMessage := "failed to refresh server endpoints"
-
+func RefreshEndpoints(srv Server) error {
// Re-initialize the endpoints
// TODO: Make this a warning instead?
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
- }
-
- endpointsErr := base.InitializeEndpoints()
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- return nil
+ return b.InitializeEndpoints()
}
-func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting an OpenVPN/WireGuard configuration"
-
- profile, profileErr := CurrentProfile(server)
- if profileErr != nil {
- return "", "", types.NewWrappedError(errorMessage, profileErr)
+func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) {
+ p, err := CurrentProfile(server)
+ if err != nil {
+ return "", "", err
}
- supportsOpenVPN := profile.supportsOpenVPN()
- supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard
-
- var config string
- var configType string
- var configErr error
+ ovpn := p.supportsOpenVPN()
+ wg := p.supportsWireguard() && wireguardSupport
switch {
// The config supports wireguard and optionally openvpn
- case supportsWireguard:
+ case wg:
// A wireguard connect call needs to generate a wireguard key and add it to the config
// Also the server could send back an OpenVPN config if it supports OpenVPN
- config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN)
+ return wireguardGetConfig(server, preferTCP, ovpn)
// The config only supports OpenVPN
- case supportsOpenVPN:
- config, configType, configErr = openVPNGetConfig(server, preferTCP)
+ case ovpn:
+ return openVPNGetConfig(server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
- return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found"))
+ return "", "", errors.Errorf("no supported protocol found")
}
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
- }
-
- return config, configType, nil
}
func Disconnect(server Server) {
APIDisconnect(server)
}
-
-type CurrentProfileNotFoundError struct {
- ProfileID string
-}
-
-func (e *CurrentProfileNotFoundError) Error() string {
- return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID)
-}
-
-type ConfigPreferTCPError struct{}
-
-func (e *ConfigPreferTCPError) Error() string {
- return "failed to get config, prefer TCP is on but the server does not support OpenVPN"
-}
-
-type EmptyURLError struct{}
-
-func (e *EmptyURLError) Error() string {
- return "failed ensuring server, empty url provided"
-}
-
-type CurrentNoMapError struct{}
-
-func (e *CurrentNoMapError) Error() string {
- return "failed getting current server, no servers available"
-}
-
-type CurrentNotFoundError struct{}
-
-func (e *CurrentNotFoundError) Error() string {
- return "failed getting current server, not found"
-}
diff --git a/internal/server/servers.go b/internal/server/servers.go
index a076770..b34dcff 100644
--- a/internal/server/servers.go
+++ b/internal/server/servers.go
@@ -1,9 +1,8 @@
package server
import (
- "fmt"
-
"github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type Servers struct {
@@ -14,125 +13,105 @@ type Servers struct {
IsType Type `json:"is_secure_internet"`
}
-func (servers *Servers) AddSecureInternet(
+func (ss *Servers) AddSecureInternet(
secureOrg *types.DiscoveryOrganization,
secureServer *types.DiscoveryServer,
) (Server, error) {
- errorMessage := "failed adding secure internet server"
// If we have specified an organization ID
// We also need to get an authorization template
- initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer)
+ err := ss.SecureInternetHomeServer.init(secureOrg, secureServer)
- if initErr != nil {
- return nil, types.NewWrappedError(errorMessage, initErr)
+ if err != nil {
+ return nil, err
}
- servers.IsType = SecureInternetServerType
- return &servers.SecureInternetHomeServer, nil
+ ss.IsType = SecureInternetServerType
+ return &ss.SecureInternetHomeServer, nil
}
-func (servers *Servers) GetCurrentServer() (Server, error) {
- errorMessage := "failed getting current server"
- if servers.IsType == SecureInternetServerType {
- if !servers.HasSecureLocation() {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNotFoundError{},
- )
+func (ss *Servers) GetCurrentServer() (Server, error) {
+ //TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server)
+ if ss.IsType == SecureInternetServerType {
+ if !ss.HasSecureLocation() {
+ return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType)
}
- return &servers.SecureInternetHomeServer, nil
+ return &ss.SecureInternetHomeServer, nil
}
- serversStruct := &servers.InstituteServers
+ srvs := &ss.InstituteServers
- if servers.IsType == CustomServerType {
- serversStruct = &servers.CustomServers
+ if ss.IsType == CustomServerType {
+ srvs = &ss.CustomServers
}
- currentServerURL := serversStruct.CurrentURL
- bases := serversStruct.Map
- if bases == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNoMapError{},
- )
+ bs := srvs.Map
+ if bs == nil {
+ return nil, errors.Errorf("srvs.Map is nil")
}
- server, exists := bases[currentServerURL]
- if !exists || server == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNotFoundError{},
- )
+ if srv, ok := bs[srvs.CurrentURL]; !ok || srv == nil {
+ return nil, errors.Errorf("server not found")
+ } else {
+ return srv, nil
}
- return server, nil
}
-func (servers *Servers) addInstituteAndCustom(
+func (ss *Servers) addInstituteAndCustom(
discoServer *types.DiscoveryServer,
isCustom bool,
) (Server, error) {
url := discoServer.BaseURL
- errorMessage := fmt.Sprintf("failed adding institute access server: %s", url)
- toAddServers := &servers.InstituteServers
- serverType := InstituteAccessServerType
+ srvs := &ss.InstituteServers
+ srvType := InstituteAccessServerType
if isCustom {
- toAddServers = &servers.CustomServers
- serverType = CustomServerType
+ srvs = &ss.CustomServers
+ srvType = CustomServerType
}
- if toAddServers.Map == nil {
- toAddServers.Map = make(map[string]*InstituteAccessServer)
+ if srvs.Map == nil {
+ srvs.Map = make(map[string]*InstituteAccessServer)
}
- server, exists := toAddServers.Map[url]
+ srv, ok := srvs.Map[url]
// initialize the server if it doesn't exist yet
- if !exists {
- server = &InstituteAccessServer{}
+ if !ok {
+ srv = &InstituteAccessServer{}
}
- instituteInitErr := server.init(
- url,
- discoServer.DisplayName,
- discoServer.Type,
- discoServer.SupportContact,
- )
- if instituteInitErr != nil {
- return nil, types.NewWrappedError(errorMessage, instituteInitErr)
+ if err := srv.init(url, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil {
+ return nil, err
}
- toAddServers.Map[url] = server
- servers.IsType = serverType
- return server, nil
+ srvs.Map[url] = srv
+ ss.IsType = srvType
+ return srv, nil
}
-func (servers *Servers) AddInstituteAccessServer(
+func (ss *Servers) AddInstituteAccessServer(
instituteServer *types.DiscoveryServer,
) (Server, error) {
- return servers.addInstituteAndCustom(instituteServer, false)
+ return ss.addInstituteAndCustom(instituteServer, false)
}
-func (servers *Servers) AddCustomServer(
+func (ss *Servers) AddCustomServer(
customServer *types.DiscoveryServer,
) (Server, error) {
- return servers.addInstituteAndCustom(customServer, true)
+ return ss.addInstituteAndCustom(customServer, true)
}
-func (servers *Servers) GetSecureLocation() string {
- return servers.SecureInternetHomeServer.CurrentLocation
+func (ss *Servers) GetSecureLocation() string {
+ return ss.SecureInternetHomeServer.CurrentLocation
}
-func (servers *Servers) SetSecureLocation(
+func (ss *Servers) SetSecureLocation(
chosenLocationServer *types.DiscoveryServer,
) error {
- errorMessage := "failed to set secure location"
// Make sure to add the current location
- _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer)
- if addLocationErr != nil {
- return types.NewWrappedError(errorMessage, addLocationErr)
+ if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil {
+ return err
}
- servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
+ ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
return nil
}
diff --git a/internal/util/util.go b/internal/util/util.go
index ddd165d..558dba7 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -1,4 +1,4 @@
-// package util implements several utility functions that are used across the codebase
+// Package util implements several utility functions that are used across the codebase
package util
import (
@@ -9,7 +9,7 @@ import (
"path"
"strings"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
// EnsureValidURL ensures that the input URL is valid to be used internally
@@ -19,135 +19,119 @@ import (
// - It makes sure that the URL ends with a /
// It returns an error if the URL cannot be parsed.
func EnsureValidURL(s string) (string, error) {
- parsedURL, parseErr := url.Parse(s)
- if parseErr != nil {
- return "", types.NewWrappedError(
- fmt.Sprintf("failed parsing url: %s", s),
- parseErr,
- )
+ u, err := url.Parse(s)
+ if err != nil {
+ return "", errors.WrapPrefix(err, "failed parsing url", 0)
}
- if parsedURL.Scheme == "" {
- parsedURL.Scheme = "https"
+ if u.Scheme == "" {
+ u.Scheme = "https"
}
- if parsedURL.Path != "" {
+ if u.Path != "" {
// Clean the path
// https://pkg.go.dev/path#Clean
- parsedURL.Path = path.Clean(parsedURL.Path)
+ u.Path = path.Clean(u.Path)
}
- returnedURL := parsedURL.String()
+ str := u.String()
// Make sure the URL ends with a /
- if returnedURL[len(returnedURL)-1:] != "/" {
- returnedURL += "/"
+ if str[len(str)-1:] != "/" {
+ str += "/"
}
- return returnedURL, nil
+ return str, nil
}
-// MakeRandomByteSlice creates a cryptographically random byteslice of `size`
+// MakeRandomByteSlice creates a cryptographically random bytes slice of `size`
// It returns the byte slice (or nil if error) and an error if it could not be generated.
func MakeRandomByteSlice(size int) ([]byte, error) {
- byteSlice := make([]byte, size)
- _, err := rand.Read(byteSlice)
- if err != nil {
- return nil, types.NewWrappedError("failed reading random", err)
+ bs := make([]byte, size)
+ if _, err := rand.Read(bs); err != nil {
+ return nil, errors.WrapPrefix(err, "failed reading random", 0)
}
- return byteSlice, nil
+ return bs, nil
}
// EnsureDirectory creates a directory with permission 700.
-func EnsureDirectory(directory string) error {
+func EnsureDirectory(dir string) error {
// Create with 700 permissions, read, write, execute only for the owner
- mkdirErr := os.MkdirAll(directory, 0o700)
- if mkdirErr != nil {
- return types.NewWrappedError(
- fmt.Sprintf("failed to create directory %s", directory),
- mkdirErr,
- )
+ err := os.MkdirAll(dir, 0o700)
+ if err != nil {
+ return errors.WrapPrefix(err, fmt.Sprintf("failed to create directory '%s'", dir), 0)
}
return nil
}
-// WAYFEncode an input URL using 'skip Where Are You From' encoding
-// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
-// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape.
-func WAYFEncode(input string) string {
- // QueryReplace already replaces a space with a +
- // see https://go.dev/play/p/pOfrn-Wsq5
- return url.QueryEscape(input)
-}
-
// ReplaceWAYF replaces an authorization template containing of @RETURN_TO@ and @ORG_ID@ with the authorization URL and the organization ID
// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
-func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
+func ReplaceWAYF(authTplt string, authURL string, orgID string) string {
// We just return the authURL in the cases where the template is not given or is invalid
- if authTemplate == "" {
+ if authTplt == "" {
return authURL
}
- if !strings.Contains(authTemplate, "@RETURN_TO@") {
+ if !strings.Contains(authTplt, "@RETURN_TO@") {
return authURL
}
- if !strings.Contains(authTemplate, "@ORG_ID@") {
+ if !strings.Contains(authTplt, "@ORG_ID@") {
return authURL
}
// Replace authURL
- authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", url.QueryEscape(authURL), 1)
+ authTplt = strings.Replace(authTplt, "@RETURN_TO@", url.QueryEscape(authURL), 1)
// If now there is no more ORG_ID, return as there weren't enough @ symbols
- if !strings.Contains(authTemplate, "@ORG_ID@") {
+ if !strings.Contains(authTplt, "@ORG_ID@") {
return authURL
}
// Replace ORG ID
- authTemplate = strings.Replace(authTemplate, "@ORG_ID@", url.QueryEscape(orgID), 1)
- return authTemplate
+ authTplt = strings.Replace(authTplt, "@ORG_ID@", url.QueryEscape(orgID), 1)
+ return authTplt
}
// GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag
// It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching
-func GetLanguageMatched(languageMap map[string]string, languageTag string) string {
+func GetLanguageMatched(langMap map[string]string, langTag string) string {
// If no map is given, return the empty string
- if len(languageMap) == 0 {
+ if len(langMap) == 0 {
return ""
}
// Try to find the exact match
- if val, ok := languageMap[languageTag]; ok {
+ if val, ok := langMap[langTag]; ok {
return val
}
// Try to find a key that starts with the OS language setting
- for k := range languageMap {
- if strings.HasPrefix(k, languageTag) {
- return languageMap[k]
+ for k := range langMap {
+ if strings.HasPrefix(k, langTag) {
+ return langMap[k]
}
}
// Try to find a key that starts with the first part of the OS language (e.g. de-)
- splitted := strings.Split(languageTag, "-")
+ pts := strings.Split(langTag, "-")
// We have a "-"
- if len(splitted) > 1 {
- for k := range languageMap {
- if strings.HasPrefix(k, splitted[0]+"-") {
- return languageMap[k]
+ if len(pts) > 1 {
+ for k := range langMap {
+ if strings.HasPrefix(k, pts[0]+"-") {
+ return langMap[k]
}
}
}
// search for just the language (e.g. de)
- for k := range languageMap {
- if k == splitted[0] {
- return languageMap[k]
+ for k := range langMap {
+ if k == pts[0] {
+ return langMap[k]
}
}
// Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available!
// We force an entry that is english exactly or with an english prefix
- for k := range languageMap {
+ for k := range langMap {
if k == "en" || strings.HasPrefix(k, "en-") {
- return languageMap[k]
+ return langMap[k]
}
}
// Otherwise just return one
- for k := range languageMap {
- return languageMap[k]
+ for k := range langMap {
+ return langMap[k]
}
return ""
diff --git a/internal/verify/verify.go b/internal/verify/verify.go
index 55b82b6..cd74a2b 100644
--- a/internal/verify/verify.go
+++ b/internal/verify/verify.go
@@ -1,10 +1,10 @@
-// package verify implement signature verification using minisign
+// Package verify implement signature verification using minisign
package verify
import (
"fmt"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
"github.com/jedisct1/go-minisign"
)
@@ -31,7 +31,7 @@ func Verify(
"RWRtBSX1alxyGX+Xn3LuZnWUT0w//B6EmTJvgaAxBMYzlQeI+jdrO6KF", // fkooman@tuxed.net, kolla@uninett.no
"RWQKqtqvd0R7rUDp0rWzbtYPA3towPWcLDCl7eY9pBMMI/ohCmrS0WiM", // RoSp
}
- valid, err := verifyWithKeys(
+ return verifyWithKeys(
signatureFileContent,
signedJSON,
expectedFileName,
@@ -39,10 +39,6 @@ func Verify(
keyStrs,
forcePrehash,
)
- if err != nil {
- return valid, types.NewWrappedError("failed signature verify", err)
- }
- return valid, nil
}
// verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJSON.
@@ -67,23 +63,21 @@ func verifyWithKeys(
case "server_list.json", "organization_list.json":
break
default:
- return false, &UnknownExpectedFilenameError{
- Filename: filename,
- Expected: "server_list.json or organization_list.json",
- }
+ return false, errors.Errorf(
+ "invalid filename '%s'; expected 'server_list.json' or 'organization_list.json'",
+ filename)
}
sig, err := minisign.DecodeSignature(signatureFileContent)
if err != nil {
- return false, &InvalidSignatureFormatError{Err: err}
+ return false, errors.WrapPrefix(err, "invalid signature format", 0)
}
// Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format
if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} {
- return false, &InvalidSignatureAlgorithmError{
- Algorithm: string(sig.SignatureAlgorithm[:]),
- WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)",
- }
+ return false, errors.Errorf(
+ "invalid signature algorithm '%s'; expected `ED (BLAKE2b-prehashed EdDSA)`",
+ sig.SignatureAlgorithm[:])
}
// Find allowed key used for signature
@@ -91,7 +85,7 @@ func verifyWithKeys(
key, err := minisign.NewPublicKey(keyStr)
if err != nil {
// Should only happen if Verify is wrong or extraKey is invalid
- return false, &CreatePublicKeyError{PublicKey: keyStr, Err: err}
+ return false, errors.WrapPrefix(err, fmt.Sprintf("failed to create public key '%s'", keyStr), 0)
}
if sig.KeyId != key.KeyId {
@@ -100,7 +94,7 @@ func verifyWithKeys(
valid, err := key.Verify(signedJSON, sig)
if !valid {
- return false, &InvalidSignatureError{Err: err}
+ return false, errors.WrapPrefix(err, "invalid signature", 0)
}
// Parse trusted comment
@@ -114,125 +108,21 @@ func verifyWithKeys(
&sigFileName,
)
if err != nil {
- return false, &InvalidTrustedCommentError{
- TrustedComment: sig.TrustedComment,
- Err: err,
- }
+ return false, errors.WrapPrefix(err, fmt.Sprintf("invalid trusted comment '%s'", sig.TrustedComment), 0)
}
if sigFileName != filename {
- return false, &WrongSigFilenameError{Filename: filename, SigFilename: sigFileName}
+ return false, errors.Errorf("wrong filename '%s'; expected filename '%s' for signature",
+ filename, sigFileName)
}
if signTime < minSignTime {
- return false, &SigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime}
+ return false, errors.Errorf("sign time %d is before sign tim: %d", signTime, minSignTime)
}
return true, nil
}
// No matching allowed key found
- return false, &UnknownKeyError{Filename: filename}
-}
-
-type UnknownExpectedFilenameError struct {
- Filename string
- Expected string
-}
-
-func (e *UnknownExpectedFilenameError) Error() string {
- return fmt.Sprintf("invalid filename: %s, expected: %s", e.Filename, e.Expected)
-}
-
-type InvalidSignatureFormatError struct {
- Err error
-}
-
-func (e *InvalidSignatureFormatError) Error() string {
- return fmt.Sprintf("invalid signature format with error: %v", e.Err)
-}
-
-func (e *InvalidSignatureFormatError) Unwrap() error {
- return e.Err
-}
-
-type InvalidSignatureAlgorithmError struct {
- Algorithm string
- WantedAlgorithm string
-}
-
-func (e *InvalidSignatureAlgorithmError) Error() string {
- return fmt.Sprintf(
- "invalid signature algorithm: %s, wanted: %s",
- e.Algorithm,
- e.WantedAlgorithm,
- )
-}
-
-type CreatePublicKeyError struct {
- PublicKey string
- Err error
-}
-
-func (e *CreatePublicKeyError) Error() string {
- return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err)
-}
-
-func (e *CreatePublicKeyError) Unwrap() error {
- return e.Err
-}
-
-type InvalidSignatureError struct {
- Err error
-}
-
-func (e *InvalidSignatureError) Error() string {
- return fmt.Sprintf("invalid signature with error: %v", e.Err)
-}
-
-func (e *InvalidSignatureError) Unwrap() error {
- return e.Err
-}
-
-type InvalidTrustedCommentError struct {
- TrustedComment string
- Err error
-}
-
-func (e *InvalidTrustedCommentError) Error() string {
- return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err)
-}
-
-func (e *InvalidTrustedCommentError) Unwrap() error {
- return e.Err
-}
-
-type WrongSigFilenameError struct {
- Filename string
- SigFilename string
-}
-
-func (e *WrongSigFilenameError) Error() string {
- return fmt.Sprintf(
- "wrong filename: %s, expected filename: %s for signature",
- e.Filename,
- e.SigFilename,
- )
-}
-
-type SigTimeEarlierError struct {
- SigTime uint64
- MinSigTime uint64
-}
-
-func (e *SigTimeEarlierError) Error() string {
- return fmt.Sprintf("Sign time: %d is earlier than sign time: %d", e.SigTime, e.MinSigTime)
-}
-
-type UnknownKeyError struct {
- Filename string
-}
-
-func (e *UnknownKeyError) Error() string {
- return fmt.Sprintf("signature for filename: %s was created with an unknown key", e.Filename)
+ return false, errors.Errorf("signature for filename '%s' was created with an unknown key", filename)
}
diff --git a/internal/verify/verify_test.go b/internal/verify/verify_test.go
index 8ebed4c..a80cbfc 100644
--- a/internal/verify/verify_test.go
+++ b/internal/verify/verify_test.go
@@ -2,10 +2,10 @@ package verify
import (
"bufio"
- "errors"
"fmt"
"io/ioutil"
"os"
+ "strings"
"testing"
)
@@ -28,29 +28,17 @@ func Test_verifyWithKeys(t *testing.T) {
pk = []string{scanner.Text()}
}
- var (
- verifyCreatePublicKeyError *CreatePublicKeyError
- verifyInvalidSignatureAlgorithmError *InvalidSignatureAlgorithmError
- verifyWrongSigFilenameError *WrongSigFilenameError
- verifyInvalidTrustedCommentError *InvalidTrustedCommentError
- verifyInvalidSignatureFormatError *InvalidSignatureFormatError
- verifyInvalidSignatureError *InvalidSignatureError
- verifySigTimeEarlierError *SigTimeEarlierError
- verifyUnknownExpectedFilenameError *UnknownExpectedFilenameError
- verifyUnknownKeyError *UnknownKeyError
- )
-
tests := []struct {
- expectedErr interface{}
- testName string
- signatureFile string
- jsonFile string
- expectedFileName string
- minSignTime uint64
- allowedPks []string
+ expectedErrPrefix string
+ testName string
+ signatureFile string
+ jsonFile string
+ expectedFileName string
+ minSignTime uint64
+ allowedPks []string
}{
{
- &verifyInvalidSignatureAlgorithmError,
+ "invalid signature algorithm '",
"pure",
"server_list.json.pure.minisig",
"server_list.json",
@@ -58,9 +46,8 @@ func Test_verifyWithKeys(t *testing.T) {
10,
pk,
},
-
{
- nil,
+ "",
"valid server_list",
"server_list.json.minisig",
"server_list.json",
@@ -69,7 +56,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- nil,
+ "",
"TC no hashed",
"server_list.json.tc_nohashed.minisig",
"server_list.json",
@@ -78,7 +65,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- nil,
+ "",
"TC later time",
"server_list.json.tc_latertime.minisig",
"server_list.json",
@@ -87,7 +74,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"server_list TC file:organization_list",
"server_list.json.tc_orglist.minisig",
"server_list.json",
@@ -96,7 +83,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"organization_list as server_list",
"organization_list.json.minisig",
"organization_list.json",
@@ -105,7 +92,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"TC file:otherfile",
"server_list.json.tc_otherfile.minisig",
"server_list.json",
@@ -114,7 +101,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidTrustedCommentError,
+ "invalid trusted comment '",
"TC no file",
"server_list.json.tc_nofile.minisig",
"server_list.json",
@@ -123,7 +110,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidTrustedCommentError,
+ "invalid trusted comment '",
"TC no time",
"server_list.json.tc_notime.minisig",
"server_list.json",
@@ -132,7 +119,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidTrustedCommentError,
+ "invalid trusted comment '",
"TC empty time",
"server_list.json.tc_emptytime.minisig",
"server_list.json",
@@ -141,7 +128,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"TC empty file",
"server_list.json.tc_emptyfile.minisig",
"server_list.json",
@@ -150,7 +137,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidTrustedCommentError,
+ "invalid trusted comment '",
"TC random",
"server_list.json.tc_random.minisig",
"server_list.json",
@@ -159,7 +146,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- nil,
+ "",
"large time",
"server_list.json.large_time.minisig",
"server_list.json",
@@ -168,7 +155,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- nil,
+ "",
"lower min time",
"server_list.json.minisig",
"server_list.json",
@@ -177,7 +164,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifySigTimeEarlierError,
+ "sign time",
"higher min time",
"server_list.json.minisig",
"server_list.json",
@@ -185,9 +172,8 @@ func Test_verifyWithKeys(t *testing.T) {
11,
pk,
},
-
{
- nil,
+ "",
"valid organization_list",
"organization_list.json.minisig",
"organization_list.json",
@@ -196,7 +182,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"organization_list TC file:server_list",
"organization_list.json.tc_servlist.minisig",
"organization_list.json",
@@ -205,7 +191,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"server_list as organization_list",
"server_list.json.minisig",
"server_list.json",
@@ -215,7 +201,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
{
- &verifyUnknownExpectedFilenameError,
+ "invalid filename '",
"valid other_list",
"other_list.json.minisig",
"other_list.json",
@@ -224,7 +210,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyWrongSigFilenameError,
+ "wrong filename '",
"other_list as server_list",
"other_list.json.minisig",
"other_list.json",
@@ -232,9 +218,8 @@ func Test_verifyWithKeys(t *testing.T) {
10,
pk,
},
-
{
- &verifyInvalidSignatureFormatError,
+ "invalid signature format",
"invalid signature file",
"random.txt",
"server_list.json",
@@ -243,7 +228,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidSignatureFormatError,
+ "invalid signature format",
"empty signature file",
"empty",
"server_list.json",
@@ -253,7 +238,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
{
- &verifyUnknownKeyError,
+ "signature for filename '",
"wrong key",
"server_list.json.wrong_key.minisig",
"server_list.json",
@@ -263,7 +248,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
{
- &verifyInvalidSignatureAlgorithmError,
+ "invalid signature algorithm '",
"forged pure signature",
"server_list.json.forged_pure.minisig",
"server_list.json.blake2b",
@@ -272,7 +257,7 @@ func Test_verifyWithKeys(t *testing.T) {
pk,
},
{
- &verifyInvalidSignatureError,
+ "invalid signature",
"forged key ID",
"server_list.json.forged_keyid.minisig",
"server_list.json",
@@ -282,7 +267,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
{
- &verifyUnknownKeyError,
+ "signature for filename '",
"no allowed keys",
"server_list.json.minisig",
"server_list.json",
@@ -291,7 +276,7 @@ func Test_verifyWithKeys(t *testing.T) {
[]string{},
},
{
- nil,
+ "",
"multiple allowed keys 1",
"server_list.json.minisig",
"server_list.json",
@@ -302,7 +287,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
},
{
- nil,
+ "",
"multiple allowed keys 2",
"server_list.json.minisig",
"server_list.json",
@@ -313,7 +298,7 @@ func Test_verifyWithKeys(t *testing.T) {
},
},
{
- &verifyCreatePublicKeyError,
+ "failed to create public key '",
"invalid allowed key",
"server_list.json.minisig",
"server_list.json",
@@ -345,7 +330,7 @@ func Test_verifyWithKeys(t *testing.T) {
t.Run(tt.testName, func(t *testing.T) {
valid, err := verifyWithKeys(string(files[tt.signatureFile]), files[tt.jsonFile],
tt.expectedFileName, tt.minSignTime, tt.allowedPks, forcePrehash)
- compareResults(t, valid, err, tt.expectedErr, func() string {
+ compareResults(t, valid, err, tt.expectedErrPrefix, func() string {
return fmt.Sprintf(
"verifyWithKeys(%q, %q, %q, %v, %v, %t)",
tt.signatureFile,
@@ -366,17 +351,33 @@ func compareResults(
t *testing.T,
ret bool,
err error,
- expectedErr interface{},
+ expectedErrPrefix string,
callStr func() string,
) {
- // different error returned
- if expectedErr != nil && !errors.As(err, expectedErr) {
- t.Errorf("%v\nerror %T = %v, wantErr %T", callStr(), err, err, expectedErr)
+ if expectedErrPrefix == "" {
+ // we don't expect any error
+ if err != nil {
+ t.Errorf("error not expected but returned '%s'", err.Error())
+ }
+ if !ret {
+ t.Errorf("error is nil and result is false")
+ }
return
}
- // different boolean returned
- expectedBool := expectedErr == nil
- if ret != expectedBool {
- t.Errorf("%v\n= %v, want %v", callStr(), ret, expectedBool)
+
+ if err == nil {
+ // we expect an error but received nil
+ t.Errorf("expected error prefix '%s' but received nil", expectedErrPrefix)
+ return
+ }
+
+ if !strings.HasPrefix(err.Error(), expectedErrPrefix) {
+ // wrong error
+ t.Errorf("expected error prefix '%s' for error '%s'", expectedErrPrefix, err.Error())
+ return
+ }
+
+ if ret {
+ t.Errorf("error is not nil and result is true")
}
}
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 7da2623..0419ff6 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -1,24 +1,21 @@
-// package wireguard implements a few helpers for the WireGuard protocol
+// Package wireguard implements a few helpers for the WireGuard protocol
package wireguard
import (
"fmt"
"regexp"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// GenerateKey generates a WireGuard private key using wgctrl
// It returns an error if key generation failed.
func GenerateKey() (wgtypes.Key, error) {
- key, keyErr := wgtypes.GeneratePrivateKey()
+ key, err := wgtypes.GeneratePrivateKey()
- if keyErr != nil {
- return key, types.NewWrappedError(
- "failed generating WireGuard key",
- keyErr,
- )
+ if err != nil {
+ return key, errors.WrapPrefix(err, "failed generating WireGuard key", 0)
}
return key, nil
}