diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-28 11:18:14 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-28 11:18:42 +0100 |
| commit | e9f8db8ee8fccf60e58deb1d72766f94a053bb16 (patch) | |
| tree | ffa5a9be67717ecc8ff7bdc03d5f96028facb0e3 /internal | |
| parent | b4ff890ec2b459148d893499a34a6d2954530369 (diff) | |
Document: Add comments for most functions and packages
Errors and test files still need to be done. Also some getters are
changed by removing the 'get' prefix
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 18 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 102 | ||||
| -rw-r--r-- | internal/fsm/fsm.go | 71 | ||||
| -rw-r--r-- | internal/http/http.go | 22 | ||||
| -rw-r--r-- | internal/log/log.go | 81 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 141 | ||||
| -rw-r--r-- | internal/util/util.go | 20 | ||||
| -rw-r--r-- | internal/util/util_test.go | 12 | ||||
| -rw-r--r-- | internal/verify/verify.go | 1 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 4 |
10 files changed, 318 insertions, 154 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 180b881..fa3045b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,3 +1,5 @@ +// Package config implements functions for saving a struct to a file +// It then provides functions to later read it such that we can restore the same struct package config import ( @@ -10,21 +12,29 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) +// Config represents a configuration that saves the client's struct as JSON type Config struct { + // Directory represents the path to where the data is saved Directory string + + // Name defines the name of file excluding the .json extension Name string } +// Init initializes the configuration using the provided directory and name func (config *Config) Init(directory string, name string) { config.Directory = directory config.Name = name } -func (config *Config) GetFilename() string { +// filename returns the filename of the configuration as a full path +func (config *Config) filename() string { pathString := path.Join(config.Directory, config.Name) return fmt.Sprintf("%s.json", pathString) } +// Save saves a structure 'readStruct' to the configuration +// If it was unusuccessful, an an error is returned func (config *Config) Save(readStruct interface{}) error { errorMessage := "failed saving configuration" configDirErr := util.EnsureDirectory(config.Directory) @@ -35,11 +45,13 @@ func (config *Config) Save(readStruct interface{}) error { if marshalErr != nil { return types.NewWrappedError(errorMessage, marshalErr) } - return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600) + return ioutil.WriteFile(config.filename(), jsonString, 0o600) } +// Load loads the configuration and writes the structure to 'writeStruct' +// If it was unsuccessful, an error is returned. func (config *Config) Load(writeStruct interface{}) error { - bytes, readErr := ioutil.ReadFile(config.GetFilename()) + bytes, readErr := ioutil.ReadFile(config.filename()) if readErr != nil { return types.NewWrappedError("failed loading configuration", readErr) } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 9102593..7df209c 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,3 +1,4 @@ +// package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure package discovery import ( @@ -10,12 +11,18 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) + +// Discovery is the main structure used for this package type Discovery struct { - Organizations types.DiscoveryOrganizations - Servers types.DiscoveryServers + // organizations represents the organizations that are returned by the discovery server + organizations types.DiscoveryOrganizations + + // servers represents the servers that are returned by the discovery server + servers types.DiscoveryServers } -// Helper function that gets a disco json and fills the structure with it +// getDiscoFile is a helper function that gets a disco json and fills the structure with it +// If it was unsuccessful it returns an error func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data @@ -61,6 +68,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} return nil } +// DetermineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated // FIXME: Implement based on // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md // - [IMPLEMENTED] on "first launch" when offering the search for "Institute Access" and "Organizations"; @@ -68,12 +76,13 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} // - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation. // - [IMPLEMENTED using a custom error message] NOTE: when the org_id that the user chose previously is no longer available in organization_list.json the application should ask the user to choose their organization (again). This can occur for example when the organization replaced their identity provider, uses a different domain after rebranding or simply ceased to exist. func (discovery *Discovery) DetermineOrganizationsUpdate() bool { - return discovery.Organizations.Timestamp.IsZero() + return discovery.organizations.Timestamp.IsZero() } -func (discovery *Discovery) GetSecureLocationList() []string { +// SecureLocationList returns a slice of all the available locations +func (discovery *Discovery) SecureLocationList() []string { var locations []string - for _, currentServer := range discovery.Servers.List { + for _, currentServer := range discovery.servers.List { if currentServer.Type == "secure_internet" { locations = append(locations, currentServer.CountryCode) } @@ -81,38 +90,44 @@ func (discovery *Discovery) GetSecureLocationList() []string { return locations } -func (discovery *Discovery) GetServerByURL( - url string, - _type string, +// ServerByURL returns the discovery server by the base URL and the according type ("secure_internet", "institute_access") +// An error is returned if and only if nil is returned for the server +func (discovery *Discovery) ServerByURL( + baseURL string, + serverType string, ) (*types.DiscoveryServer, error) { - for _, currentServer := range discovery.Servers.List { - if currentServer.BaseURL == url && currentServer.Type == _type { + for _, currentServer := range discovery.servers.List { + if currentServer.BaseURL == baseURL && currentServer.Type == serverType { return ¤tServer, nil } } return nil, types.NewWrappedError( "failed getting server by URL from discovery", - &GetServerByURLNotFoundError{URL: url, Type: _type}, + &GetServerByURLNotFoundError{URL: baseURL, Type: serverType}, ) } -func (discovery *Discovery) GetServerByCountryCode( - code string, - _type string, +// ServerByCountryCode returns the discovery server by the country code and the according type ("secure_internet", "institute_access") +// An error is returned if and only if nil is returned for the server +func (discovery *Discovery) ServerByCountryCode( + countryCode string, + serverType string, ) (*types.DiscoveryServer, error) { - for _, currentServer := range discovery.Servers.List { - if currentServer.CountryCode == code && currentServer.Type == _type { + for _, currentServer := range discovery.servers.List { + if currentServer.CountryCode == countryCode && currentServer.Type == serverType { return ¤tServer, nil } } return nil, types.NewWrappedError( - "failed getting server by country code from discovery", - &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}, + "failed getting server by country countryCode from discovery", + &GetServerByCountryCodeNotFoundError{CountryCode: countryCode, Type: serverType}, ) } -func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) { - for _, organization := range discovery.Organizations.List { +// orgByID returns the discovery organization by the organization ID +// An error is returned if and only if nil is returned for the organization +func (discovery *Discovery) orgByID(orgID string) (*types.DiscoveryOrganization, error) { + for _, organization := range discovery.organizations.List { if organization.OrgID == orgID { return &organization, nil } @@ -123,11 +138,15 @@ func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganizati ) } -func (discovery *Discovery) GetSecureHomeArgs( +// SecureHomeArgs returns the secure internet home server arguments: +// - The organization it belongs to +// - The secure internet server itself +// An error is returned if and only if nil is returned for the organization +func (discovery *Discovery) SecureHomeArgs( orgID string, ) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { errorMessage := "failed getting Secure Internet Home arguments from discovery" - org, orgErr := discovery.getOrgByID(orgID) + org, orgErr := discovery.orgByID(orgID) if orgErr != nil { return nil, nil, types.NewWrappedError(errorMessage, orgErr) @@ -136,7 +155,7 @@ func (discovery *Discovery) GetSecureHomeArgs( // Get a server with the base url url := org.SecureInternetHome - currentServer, serverErr := discovery.GetServerByURL(url, "secure_internet") + currentServer, serverErr := discovery.ServerByURL(url, "secure_internet") if serverErr != nil { return nil, nil, types.NewWrappedError(errorMessage, serverErr) @@ -144,55 +163,58 @@ func (discovery *Discovery) GetSecureHomeArgs( return org, currentServer, nil } +// DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md // - [Implemented] The application MUST always fetch the server_list.json at application start. // - The application MAY refresh the server_list.json periodically, e.g. once every hour. func (discovery *Discovery) DetermineServersUpdate() bool { // No servers, we should update - if discovery.Servers.Timestamp.IsZero() { + if discovery.servers.Timestamp.IsZero() { return true } // 1 hour from the last update - shouldUpdateTime := discovery.Servers.Timestamp.Add(1 * time.Hour) + shouldUpdateTime := discovery.servers.Timestamp.Add(1 * time.Hour) now := time.Now() return !now.Before(shouldUpdateTime) } -// Get the organization list -func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganizations, error) { +// Organizations returns the discovery organizations +// If there was an error, a cached copy is returned if available +func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, error) { if !discovery.DetermineOrganizationsUpdate() { - return &discovery.Organizations, nil + return &discovery.organizations, nil } file := "organization_list.json" - bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) + bodyErr := getDiscoFile(file, discovery.organizations.Version, &discovery.organizations) if bodyErr != nil { // Return previous with an error - return &discovery.Organizations, types.NewWrappedError( + return &discovery.organizations, types.NewWrappedError( "failed getting organizations in Discovery", bodyErr, ) } - discovery.Organizations.Timestamp = time.Now() - return &discovery.Organizations, nil + discovery.organizations.Timestamp = time.Now() + return &discovery.organizations, nil } -// Get the server list -func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) { +// Servers returns the discovery servers +// If there was an error, a cached copy is returned if available +func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) { if !discovery.DetermineServersUpdate() { - return &discovery.Servers, nil + return &discovery.servers, nil } file := "server_list.json" - bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) + bodyErr := getDiscoFile(file, discovery.servers.Version, &discovery.servers) if bodyErr != nil { // Return previous with an error - return &discovery.Servers, types.NewWrappedError( + return &discovery.servers, types.NewWrappedError( "failed getting servers in Discovery", bodyErr, ) } // Update servers timestamp - discovery.Servers.Timestamp = time.Now() - return &discovery.Servers, nil + discovery.servers.Timestamp = time.Now() + return &discovery.servers, nil } type GetOrgByIDNotFoundError struct { diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index b51a5c9..b8fd644 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -1,3 +1,5 @@ +// Package fsm defines a finite state machine and has the ability to save this state machine to a graph file +// This graph file can be visualized using mermaid.js package fsm import ( @@ -10,7 +12,9 @@ import ( ) type ( + //StateID represents the Identifier of the state FSMStateID int8 + //StateIDSlice represents the list of state identifiers FSMStateIDSlice []FSMStateID ) @@ -26,8 +30,11 @@ func (v FSMStateIDSlice) Swap(i, j int) { v[i], v[j] = v[j], v[i] } +// Transition indicates an arrow in the state graph type FSMTransition struct { + // To represents the to-be-new state To FSMStateID + // Description is what type of message the arrow gets in the graph Description string } @@ -35,45 +42,60 @@ type ( FSMStates map[FSMStateID]FSMState ) +// State represents a single node in the graph type FSMState struct { + // Transitions indicates which out arrows this node has Transitions []FSMTransition - - // Which state to go back to on a back transition - BackState FSMStateID } +// FSM represents the total graph type FSM struct { + // States is the map from state ID to states States FSMStates + + // Current is the current state represented by the identifier Current FSMStateID - // Info to be passed from the parent state + // Name represents the descriptive name of this state machine Name string + + // StateCallback is the function ran when a transition occurs + // It takes the old state, the new state and the data and returns if this is handled by the client StateCallback func(FSMStateID, FSMStateID, interface{}) bool + + // Directory represents the path where the state graph is stored Directory string - Debug bool - GetName func(FSMStateID) string + + // Generate represents whether we want to generate the graph + Generate bool + + // GetStateName gets the name of a state as a string + GetStateName func(FSMStateID) string } +// Init initializes the state machine and sets it to the given current state func (fsm *FSM) Init( current FSMStateID, states map[FSMStateID]FSMState, callback func(FSMStateID, FSMStateID, interface{}) bool, directory string, nameGen func(FSMStateID) string, - debug bool, + generate bool, ) { fsm.States = states fsm.Current = current fsm.StateCallback = callback fsm.Directory = directory - fsm.GetName = nameGen - fsm.Debug = debug + fsm.GetStateName = nameGen + fsm.Generate = generate } +// InState returns whether or not the state machine is in the given 'check' state func (fsm *FSM) InState(check FSMStateID) bool { return check == fsm.Current } +// HasTransition checks whether or not the state machine has a transition to the given 'check' state func (fsm *FSM) HasTransition(check FSMStateID) bool { for _, transitionState := range fsm.States[fsm.Current].Transitions { if transitionState.To == check { @@ -84,11 +106,13 @@ func (fsm *FSM) HasTransition(check FSMStateID) bool { return false } +// getGraphFilename gets the full path to the graph filename including the .graph extension func (fsm *FSM) getGraphFilename(extension string) string { debugPath := path.Join(fsm.Directory, "graph") return fmt.Sprintf("%s%s", debugPath, extension) } +// writeGraph writes the state machine to a .graph file func (fsm *FSM) writeGraph() { graph := fsm.GenerateGraph() graphFile := fsm.getGraphFilename(".graph") @@ -107,18 +131,18 @@ func (fsm *FSM) writeGraph() { } } -func (fsm *FSM) GoBack() { - fsm.GoTransition(fsm.States[fsm.Current].BackState) -} - +// GoTransitionRequired transitions the state machine to a new state with associated state data 'data' +// If this transition is not handled by the client, it returns an error func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error { oldState := fsm.Current if !fsm.GoTransitionWithData(newState, data) { - return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetName(oldState), fsm.GetName(newState))) + return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetStateName(oldState), fsm.GetStateName(newState))) } return nil } +// GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data' +// It returns whether or not the transition is handled by the client func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool { ok := fsm.HasTransition(newState) @@ -126,7 +150,7 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool if ok { oldState := fsm.Current fsm.Current = newState - if fsm.Debug { + if fsm.Generate { fsm.writeGraph() } @@ -136,11 +160,14 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool return handled } +// GoTransition is an alias to call GoTransitionWithData but have an empty string as data func (fsm *FSM) GoTransition(newState FSMStateID) bool { // No data means the callback is never required return fsm.GoTransitionWithData(newState, "") } +// generateMermaidGraph generates a graph suitable to be converted by the mermaid.js tool +// it returns the graph as a string func (fsm *FSM) generateMermaidGraph() string { graph := "graph TD\n" sortedFSM := make(FSMStateIDSlice, 0, len(fsm.States)) @@ -152,15 +179,15 @@ func (fsm *FSM) generateMermaidGraph() string { transitions := fsm.States[state].Transitions for _, transition := range transitions { if state == fsm.Current { - graph += "\nstyle " + fsm.GetName(state) + " fill:cyan\n" + graph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n" } else { - graph += "\nstyle " + fsm.GetName(state) + " fill:white\n" + graph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n" } - graph += fsm.GetName( + graph += fsm.GetStateName( state, - ) + "(" + fsm.GetName( + ) + "(" + fsm.GetStateName( state, - ) + ") " + "-->|" + transition.Description + "| " + fsm.GetName( + ) + ") " + "-->|" + transition.Description + "| " + fsm.GetStateName( transition.To, ) + "\n" } @@ -168,8 +195,10 @@ func (fsm *FSM) generateMermaidGraph() string { return graph } +// GenerateGraph generates a mermaid graph if the state machine is initialized +// If the graph cannot be generated, it returns the empty string func (fsm *FSM) GenerateGraph() string { - if fsm.GetName != nil { + if fsm.GetStateName != nil { return fsm.generateMermaidGraph() } diff --git a/internal/http/http.go b/internal/http/http.go index a9d3ea2..b21b901 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -1,3 +1,4 @@ +// Package http defines higher level helpers for the net/http package package http import ( @@ -12,8 +13,10 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) +// The URLParemeters as the name suggests is a type used for the parameters in the URL type URLParameters map[string]string +// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call type HTTPOptionalParams struct { Headers http.Header URLParameters URLParameters @@ -21,7 +24,7 @@ type HTTPOptionalParams struct { Timeout time.Duration } -// Construct an URL including on parameters +// ConstructURL creates a URL with the included parameters func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { url, parseErr := url.Parse(baseURL) if parseErr != nil { @@ -44,23 +47,28 @@ func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) return url.String(), nil } -// Convenience functions +// Get creates a Get request and returns the headers, body and an error func HTTPGet(url string) (http.Header, []byte, error) { return HTTPMethodWithOpts(http.MethodGet, url, nil) } +// Post creates a Post request and returns the headers, body and an error func HTTPPost(url string, body url.Values) (http.Header, []byte, error) { return HTTPMethodWithOpts(http.MethodGet, url, &HTTPOptionalParams{Body: body}) } +// GetWithOpts creates a Get request with optional parameters and returns the headers, body and an error func HTTPGetWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) { return HTTPMethodWithOpts(http.MethodGet, url, opts) } +// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) { return HTTPMethodWithOpts(http.MethodPost, url, opts) } +// optionalURL ensures that the URL contains the optional parameters +// it returns the url (with parameters if success) and an error indicating success func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { if opts != nil { url, urlErr := HTTPConstructURL(url, opts.URLParameters) @@ -76,6 +84,7 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { return url, nil } +// optionalHeaders ensures that the HTTP request uses the optional headers if defined func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { // Add headers if opts != nil && req != nil && opts.Headers != nil { @@ -85,6 +94,7 @@ func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { } } +// optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader { if opts != nil && opts.Body != nil { return strings.NewReader(opts.Body.Encode()) @@ -92,6 +102,8 @@ func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader { return nil } +// MethodWithOpts creates a HTTP request using a method (e.g. GET, POST), an url and optional parameters +// It returns the HTTP headers, the body and an error if there is one func HTTPMethodWithOpts( method string, url string, @@ -155,12 +167,14 @@ func HTTPMethodWithOpts( return resp.Header, body, nil } +// StatusError indicates that we have received a HTTP status error type HTTPStatusError struct { URL string Body string Status int } +// Error returns the StatusError as an error string func (e *HTTPStatusError) Error() string { return fmt.Sprintf( "failed obtaining HTTP resource: %s as it gave an unsuccessful status code: %d. Body: %s", @@ -170,12 +184,16 @@ func (e *HTTPStatusError) Error() string { ) } +// ParseJSONError indicates that the HTTP error is because of failed JSON parsing +// It has the URL and the Body as context. +// The underlying JSON parsing Err itself is also wrapped here. type HTTPParseJSONError struct { URL string Body string Err error } +// Error returns the ParseJSONError as an error string func (e *HTTPParseJSONError) Error() string { return fmt.Sprintf( "failed parsing json %s for HTTP resource: %s with error: %v", diff --git a/internal/log/log.go b/internal/log/log.go index 2ab6549..7d032e9 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -1,3 +1,4 @@ +// Package log implements a basic level based logger package log import ( @@ -11,47 +12,60 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) +// FileLogger defines the type of logger that this package implements +// As the name suggests, it saves the log to a file type FileLogger struct { + // Level indicates which maximum level this logger actually forwards to the file Level LogLevel - File *os.File + + // file represents a pointer to the open log file + file *os.File } type LogLevel int8 const ( - // No level set, not allowed - LogNotSet LogLevel = iota - // Log debug, this message is not an error but is there for debugging - LogDebug - // Log info, this message is not an error but is there for additional information - LogInfo - // Log only to provide a warning, the app still functions - LogWarning - // Log to provide a generic error, the app still functions but some functionality might not work - LogError - // Log to provide a fatal error, the app cannot function correctly when such an error occurs - LogFatal + // LevelNotSet indicates level not set, not allowed + LevelNotSet LogLevel = iota + + // LevelDebug indicates that the message is not an error but is there for debugging + LevelDebug + + // LevelInfo indicates that the message is not an error but is there for additional information + LevelInfo + + // LevelWarning indicates only a warning, the app still functions + LevelWarning + + // LevelError indicates a generic error, the app still functions but some functionality might not work + LevelError + + // LevelFatal indicates a fatal error, the app cannot function correctly when such an error occurs + LevelFatal ) +// String returns the string of each level func (e LogLevel) String() string { switch e { - case LogNotSet: + case LevelNotSet: return "NOTSET" - case LogDebug: + case LevelDebug: return "DEBUG" - case LogInfo: + case LevelInfo: return "INFO" - case LogWarning: + case LevelWarning: return "WARNING" - case LogError: + case LevelError: return "ERROR" - case LogFatal: + case LevelFatal: return "FATAL" default: return "UNKNOWN" } } +// Init initializes the logger by forwarding a max level 'level' and a directory 'directory' where the log should be stored +// If the logger cannot be initialized, for example an error in opening the log file, an error is returned func (logger *FileLogger) Init(level LogLevel, directory string) error { errorMessage := "failed creating log" @@ -60,7 +74,7 @@ func (logger *FileLogger) Init(level LogLevel, directory string) error { return types.NewWrappedError(errorMessage, configDirErr) } logFile, logOpenErr := os.OpenFile( - logger.getFilename(directory), + logger.filename(directory), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666, ) @@ -69,11 +83,12 @@ func (logger *FileLogger) Init(level LogLevel, directory string) error { } multi := io.MultiWriter(os.Stdout, logFile) log.SetOutput(multi) - logger.File = logFile + logger.file = logFile logger.Level = level return nil } +// Inherit logs an error with a label using the error level of the error func (logger *FileLogger) Inherit(label string, err error) { level := types.GetErrorLevel(err) @@ -90,36 +105,44 @@ func (logger *FileLogger) Inherit(label string, err error) { } } +// Debug logs a message with parameters as level LevelDebug func (logger *FileLogger) Debug(msg string, params ...interface{}) { - logger.log(LogDebug, msg, params...) + logger.log(LevelDebug, msg, params...) } +// Debug logs a message with parameters as level LevelInfo func (logger *FileLogger) Info(msg string, params ...interface{}) { - logger.log(LogInfo, msg, params...) + logger.log(LevelInfo, msg, params...) } +// Debug logs a message with parameters as level LevelWarning func (logger *FileLogger) Warning(msg string, params ...interface{}) { - logger.log(LogWarning, msg, params...) + logger.log(LevelWarning, msg, params...) } +// Debug logs a message with parameters as level LevelError func (logger *FileLogger) Error(msg string, params ...interface{}) { - logger.log(LogError, msg, params...) + logger.log(LevelError, msg, params...) } +// Debug logs a message with parameters as level LevelFatal func (logger *FileLogger) Fatal(msg string, params ...interface{}) { - logger.log(LogFatal, msg, params...) + logger.log(LevelFatal, msg, params...) } +// Close closes the logger by closing the internal file func (logger *FileLogger) Close() { - logger.File.Close() + logger.file.Close() } -func (logger *FileLogger) getFilename(directory string) string { +// filename returns the filename of the logger by returning the full path as a string +func (logger *FileLogger) filename(directory string) string { return path.Join(directory, "log") } +// log logs as level 'level' a message 'msg' with parameters 'params' func (logger *FileLogger) log(level LogLevel, msg string, params ...interface{}) { - if level >= logger.Level && logger.Level != LogNotSet { + if level >= logger.Level && logger.Level != LevelNotSet { formattedMsg := fmt.Sprintf(msg, params...) format := fmt.Sprintf("- Go - %s - %s", level.String(), formattedMsg) // To log file diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 1f3b719..232f68c 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -1,3 +1,8 @@ +// package oauth implement an oauth client defined in e.g. rfc 6749 +// However, we try to follow some recommendations from the v2.1 oauth draft RFC +// Some specific things we implement here: +// - PKCE (RFC 7636) +// - ISS (RFC 9207) package oauth import ( @@ -18,7 +23,7 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) -// Generates a random base64 string to be used for state +// genState generates a random base64 string to be used for state // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 // "state": OPTIONAL. An opaque value used by the client to maintain // state between the request and callback. The authorization server @@ -35,7 +40,7 @@ func genState() (string, error) { return base64.RawURLEncoding.EncodeToString(randomBytes), nil } -// Generates a sha256 base64 challenge from a verifier +// genChallengeS256 generates a sha256 base64 challenge from a verifier // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8 func genChallengeS256(verifier string) string { hash := sha256.Sum256([]byte(verifier)) @@ -44,7 +49,7 @@ func genChallengeS256(verifier string) string { return base64.RawURLEncoding.EncodeToString(hash[:]) } -// Generates a verifier +// genVerifier generates a verifier // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 // The code_verifier is a unique high-entropy cryptographically random // string generated for each authorization request, using the unreserved @@ -70,75 +75,108 @@ func genVerifier() (string, error) { return base64.RawURLEncoding.EncodeToString(randomBytes), nil } +// OAuth defines the main structure for this package type OAuth struct { + // ISS indicates the issuer indentifier of the authorization server as defined in RFC 9207 ISS string `json:"iss"` - Session OAuthExchangeSession `json:"-"` + + // Token is where the access and refresh tokens are stored along with the timestamps Token OAuthToken `json:"token"` + + // BaseAuthorizationURL is the URL where authorization should take place BaseAuthorizationURL string `json:"base_authorization_url"` + + // TokenURL is the URL where tokens should be obtained TokenURL string `json:"token_url"` + + // session is the internal in progress OAuth session + session OAuthExchangeSession `json:"-"` } -// This structure gets passed to the callback for easy access to the current state +// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state type OAuthExchangeSession struct { - // returned from the callback + // CallbackError indicates an error returned by the server CallbackError error - // filled in in initialize + // ClientID is the ID of the OAuth client ClientID string + + // ISS indicates the issuer inditifer ISS string + + // State is the expected URL state paremeter State string + + // Verifier is the preimage of the challenge Verifier string - // filled in when constructing the callback + // Context is the context used for cancellation Context context.Context + + // Server is the server of the session Server *http.Server + + // Listener is the listener where the servers 'listens' on Listener net.Listener } -// Struct that defines the json format for /.well-known/vpn-user-portal" +// OAuthToken is a structure that defines the json format for /.well-known/vpn-user-portal" type OAuthToken struct { + // Access is the access token returned by the server Access string `json:"access_token"` + + // Refresh token is the refresh token returned by the server Refresh string `json:"refresh_token"` + + // Type indicates which type of tokens we have Type string `json:"token_type"` + + // Expires is the expires time returned by the server Expires int64 `json:"expires_in"` + + // ExpiredTimestamp is the Expires field but converted to a Go timestamp ExpiredTimestamp time.Time `json:"expires_in_timestamp"` } -// Sets up a listener +// setupListener sets up an OAuth listener +// If it was unsuccessful it returns an error func (oauth *OAuth) setupListener() error { errorMessage := "failed setting up listener" - oauth.Session.Context = context.Background() + oauth.session.Context = context.Background() // create a listener listener, listenerErr := net.Listen("tcp", ":0") if listenerErr != nil { return types.NewWrappedError(errorMessage, listenerErr) } - oauth.Session.Listener = listener + oauth.session.Listener = listener return nil } +// getTokensWithCallback gets the OAuth tokens using a local web server +// If it was unsuccessful it returns an error func (oauth *OAuth) getTokensWithCallback() error { errorMessage := "failed getting tokens with callback" - if oauth.Session.Listener == nil { + if oauth.session.Listener == nil { return types.NewWrappedError(errorMessage, errors.New("no listener")) } mux := http.NewServeMux() // server /callback over the listener address - oauth.Session.Server = &http.Server{ + oauth.session.Server = &http.Server{ Handler: mux, } mux.HandleFunc("/callback", oauth.Callback) - if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed { + if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed { return types.NewWrappedError(errorMessage, err) } - return oauth.Session.CallbackError + return oauth.session.CallbackError } -// Get the access and refresh tokens +// getTokensWithAuthCode gets the access and refresh tokens using the authorization code // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +// If it was unsuccessful it returns an error func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter @@ -151,9 +189,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { } data := url.Values{ - "client_id": {oauth.Session.ClientID}, + "client_id": {oauth.session.ClientID}, "code": {authCode}, - "code_verifier": {oauth.Session.Verifier}, + "code_verifier": {oauth.session.Verifier}, "grant_type": {"authorization_code"}, "redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)}, } @@ -185,15 +223,17 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { return nil } +// isTokensExpired returns if the OAuth tokens are expired using the expired timestamp func (oauth *OAuth) isTokensExpired() bool { expiredTime := oauth.Token.ExpiredTimestamp currentTime := time.Now() return !currentTime.Before(expiredTime) } -// Get the access and refresh tokens with a previously received refresh token +// getTokensWithRefresh gets the access and refresh tokens with a previously received refresh token // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +// If it was unsuccessful it returns an error func (oauth *OAuth) getTokensWithRefresh() error { errorMessage := "failed getting tokens with the refresh token" reqURL := oauth.TokenURL @@ -228,7 +268,8 @@ func (oauth *OAuth) getTokensWithRefresh() error { return nil } -// Adapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25 +// responseTemplate is the HTML template for the OAuth authorized response +// this template was dapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25 const responseTemplate string = ` <!DOCTYPE html> <html dir="ltr" xmlns="http://www.w3.org/1999/xhtml" lang="en"><head> @@ -265,11 +306,14 @@ main { </html> ` +// oauthResponseHTML is a structure that is used to give back the OAuth response type oauthResponseHTML struct { Title string Message string } +// writeResponseHTML writes the OAuth response using a response writer and the title + message +// If it was unsuccessful it returns an error func writeResponseHTML(w http.ResponseWriter, title string, message string) error { errorMessage := "failed writing response HTML" template, templateErr := template.New("oauth-response").Parse(responseTemplate) @@ -287,21 +331,21 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro return nil } -// -//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 +// Callback is the public function used to get the OAuth tokens using an authorization code callback +// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { errorMessage := "failed callback to retrieve the authorization code" // Shutdown after we're done defer func() { // writing the html is best effort - if oauth.Session.CallbackError != nil { + if oauth.session.CallbackError != nil { _ = writeResponseHTML(w, "Authorization Failed", "The authorization has failed. See the log file for more information.") } else { _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") } - if oauth.Session.Server != nil { - go oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck + if oauth.session.Server != nil { + go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck } }() @@ -310,10 +354,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { urlQuery := req.URL.Query() extractedISS := urlQuery.Get("iss") if extractedISS != "" { - if oauth.Session.ISS != extractedISS { - oauth.Session.CallbackError = types.NewWrappedError( + if oauth.session.ISS != extractedISS { + oauth.session.CallbackError = types.NewWrappedError( errorMessage, - &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS}, + &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS}, ) return } @@ -324,19 +368,19 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 extractedState := urlQuery.Get("state") if extractedState == "" { - oauth.Session.CallbackError = types.NewWrappedError( + oauth.session.CallbackError = types.NewWrappedError( errorMessage, &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, ) return } // The state is the first entry - if extractedState != oauth.Session.State { - oauth.Session.CallbackError = types.NewWrappedError( + if extractedState != oauth.session.State { + oauth.session.CallbackError = types.NewWrappedError( errorMessage, &OAuthCallbackStateMatchError{ State: extractedState, - ExpectedState: oauth.Session.State, + ExpectedState: oauth.session.State, }, ) return @@ -345,7 +389,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // No authorization code extractedCode := urlQuery.Get("code") if extractedCode == "" { - oauth.Session.CallbackError = types.NewWrappedError( + oauth.session.CallbackError = types.NewWrappedError( errorMessage, &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, ) @@ -356,7 +400,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Obtaining the access and refresh tokens getTokensErr := oauth.getTokensWithAuthCode(extractedCode) if getTokensErr != nil { - oauth.Session.CallbackError = types.NewWrappedError( + oauth.session.CallbackError = types.NewWrappedError( errorMessage, getTokensErr, ) @@ -364,22 +408,28 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { } } +// Init initializes OAuth with the following parameters: +// - OAuth server issuer identification +// - The URL used for authorization +// - The URL to obtain new tokens func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL string) { oauth.ISS = iss oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL } +// GetListenerPort gets the listener for the OAuth web server +// It returns the port as an integer and an error if there is any func (oauth OAuth) GetListenerPort() (int, error) { errorMessage := "failed to get listener port" - if oauth.Session.Listener == nil { + if oauth.session.Listener == nil { return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener")) } - return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil + return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil } -// Starts the OAuth exchange for eduvpn. +// GetAuthURL gets the authorization url to start the OAuth procedure func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) { errorMessage := "failed starting OAuth exchange" @@ -398,7 +448,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauthSession := OAuthExchangeSession{ClientID: name, ISS: oauth.ISS, State: state, Verifier: verifier} - oauth.Session = oauthSession + oauth.session = oauthSession // set up the listener to get the redirect URI listenerErr := oauth.setupListener() @@ -432,7 +482,8 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) return postProcessAuth(authURL), nil } -// Error definitions +// Exchange starts the OAuth exchange by getting the tokens with the redirect callback +// If it was unsuccessful it returns an error func (oauth *OAuth) Exchange() error { tokenErr := oauth.getTokensWithCallback() @@ -442,17 +493,21 @@ func (oauth *OAuth) Exchange() error { return nil } +// Cancel cancels the existing OAuth +// TODO: Use context for this func (oauth *OAuth) Cancel() { - oauth.Session.CallbackError = types.NewWrappedErrorLevel( + oauth.session.CallbackError = types.NewWrappedErrorLevel( types.ErrInfo, "cancelled OAuth", &OAuthCancelledCallbackError{}, ) - if oauth.Session.Server != nil { - oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck + if oauth.session.Server != nil { + oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck } } +// EnsureTokens makes sure the OAuth tokens are still valid +// if this cannot be guaranteed, it returns an error func (oauth *OAuth) EnsureTokens() error { errorMessage := "failed ensuring OAuth tokens" // Access Token or Refresh Tokens empty, we can not ensure the tokens diff --git a/internal/util/util.go b/internal/util/util.go index cbe9c1b..8b39b9f 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -1,3 +1,4 @@ +// package util implements several utility functions that are used across the codebase package util import ( @@ -11,6 +12,12 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) +// EnsureValidURL ensures that the input URL is valid to be used internally +// It does the following +// - Sets the scheme to https if none is given +// - It 'cleans' up the path using path.Clean +// - It makes sure that the URL ends with a / +// It returns an error if the URL cannot be parsed func EnsureValidURL(s string) (string, error) { parsedURL, parseErr := url.Parse(s) if parseErr != nil { @@ -38,7 +45,8 @@ func EnsureValidURL(s string) (string, error) { return returnedURL, nil } -// Creates a random byteslice of `size` +// MakeRandomByteSlice creates a cryptographically random byteslice of `size` +// It returns the byte slice (or nil if error) and an error if it could not be generated. func MakeRandomByteSlice(size int) ([]byte, error) { byteSlice := make([]byte, size) _, err := rand.Read(byteSlice) @@ -48,6 +56,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) { return byteSlice, nil } +// EnsureDirectory creates a directory with permission 700 func EnsureDirectory(directory string) error { // Create with 700 permissions, read, write, execute only for the owner mkdirErr := os.MkdirAll(directory, 0o700) @@ -60,6 +69,7 @@ func EnsureDirectory(directory string) error { return nil } +// WAYFEncode an input URL using 'skip Where Are You From' encoding // See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md // URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape func WAYFEncode(input string) string { @@ -68,6 +78,7 @@ func WAYFEncode(input string) string { return url.QueryEscape(input) } +// ReplaceWAYF replaces an authorization template containing of @RETURN_TO@ and @ORG_ID@ with the authorization URL and the organization ID // See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { // We just return the authURL in the cases where the template is not given or is invalid @@ -81,18 +92,19 @@ func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { return authURL } // Replace authURL - authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1) + authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", url.QueryEscape(authURL), 1) // If now there is no more ORG_ID, return as there weren't enough @ symbols if !strings.Contains(authTemplate, "@ORG_ID@") { return authURL } // Replace ORG ID - authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1) + authTemplate = strings.Replace(authTemplate, "@ORG_ID@", url.QueryEscape(orgID), 1) return authTemplate } -// https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching +// GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag +// It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching func GetLanguageMatched(languageMap map[string]string, languageTag string) string { // If no map is given, return the empty string if len(languageMap) == 0 { diff --git a/internal/util/util_test.go b/internal/util/util_test.go index bb76752..be19b11 100644 --- a/internal/util/util_test.go +++ b/internal/util/util_test.go @@ -57,18 +57,6 @@ func TestMakeRandomByteSlice(t *testing.T) { } } -func TestWAYFEncode(t *testing.T) { - // AuthTemplate - returnTo := "127.0.0.1:8000/test123bla/#wow " - - // URL encoding but with spaces replace as + instead of %20 - wantReturnTo := "127.0.0.1%3A8000%2Ftest123bla%2F%23wow+++" - encode := WAYFEncode(returnTo) - if encode != wantReturnTo { - t.Fatalf("Got: %s, want: %s", encode, wantReturnTo) - } -} - func TestReplaceWAYF(t *testing.T) { // We expect url encoding but the spaces to be correctly replace with a + instead of a %20 // And we expect that the return to and org_id are correctly replaced diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 6432619..98a9c67 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -1,3 +1,4 @@ +// package verify implement signature verification using minisign package verify import ( diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index bbb22e4..d9bd974 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -1,3 +1,4 @@ +// package wireguard implements a few helpers for the WireGuard protocol package wireguard import ( @@ -8,6 +9,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// GenerateKey generates a WireGuard private key using wgctrl +// It returns an error if key generation failed func GenerateKey() (wgtypes.Key, error) { key, keyErr := wgtypes.GeneratePrivateKey() @@ -20,6 +23,7 @@ func GenerateKey() (wgtypes.Key, error) { return key, nil } +// ConfigAddKey takes the WireGuard configuration and adds the PrivateKey to the right section // FIXME: Instead of doing a regex replace, decide if we should use a parser func ConfigAddKey(config string, key wgtypes.Key) string { interfaceSection := "[Interface]" |
