summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 11:18:14 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 11:18:42 +0100
commite9f8db8ee8fccf60e58deb1d72766f94a053bb16 (patch)
treeffa5a9be67717ecc8ff7bdc03d5f96028facb0e3 /internal
parentb4ff890ec2b459148d893499a34a6d2954530369 (diff)
Document: Add comments for most functions and packages
Errors and test files still need to be done. Also some getters are changed by removing the 'get' prefix
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go18
-rw-r--r--internal/discovery/discovery.go102
-rw-r--r--internal/fsm/fsm.go71
-rw-r--r--internal/http/http.go22
-rw-r--r--internal/log/log.go81
-rw-r--r--internal/oauth/oauth.go141
-rw-r--r--internal/util/util.go20
-rw-r--r--internal/util/util_test.go12
-rw-r--r--internal/verify/verify.go1
-rw-r--r--internal/wireguard/wireguard.go4
10 files changed, 318 insertions, 154 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 180b881..fa3045b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,3 +1,5 @@
+// Package config implements functions for saving a struct to a file
+// It then provides functions to later read it such that we can restore the same struct
package config
import (
@@ -10,21 +12,29 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
+// Config represents a configuration that saves the client's struct as JSON
type Config struct {
+ // Directory represents the path to where the data is saved
Directory string
+
+ // Name defines the name of file excluding the .json extension
Name string
}
+// Init initializes the configuration using the provided directory and name
func (config *Config) Init(directory string, name string) {
config.Directory = directory
config.Name = name
}
-func (config *Config) GetFilename() string {
+// filename returns the filename of the configuration as a full path
+func (config *Config) filename() string {
pathString := path.Join(config.Directory, config.Name)
return fmt.Sprintf("%s.json", pathString)
}
+// Save saves a structure 'readStruct' to the configuration
+// If it was unusuccessful, an an error is returned
func (config *Config) Save(readStruct interface{}) error {
errorMessage := "failed saving configuration"
configDirErr := util.EnsureDirectory(config.Directory)
@@ -35,11 +45,13 @@ func (config *Config) Save(readStruct interface{}) error {
if marshalErr != nil {
return types.NewWrappedError(errorMessage, marshalErr)
}
- return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600)
+ return ioutil.WriteFile(config.filename(), jsonString, 0o600)
}
+// Load loads the configuration and writes the structure to 'writeStruct'
+// If it was unsuccessful, an error is returned.
func (config *Config) Load(writeStruct interface{}) error {
- bytes, readErr := ioutil.ReadFile(config.GetFilename())
+ bytes, readErr := ioutil.ReadFile(config.filename())
if readErr != nil {
return types.NewWrappedError("failed loading configuration", readErr)
}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 9102593..7df209c 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -1,3 +1,4 @@
+// package discovery implements the server discovery by contacting disco.eduvpn.org and returning the data as a Go structure
package discovery
import (
@@ -10,12 +11,18 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
+
+// Discovery is the main structure used for this package
type Discovery struct {
- Organizations types.DiscoveryOrganizations
- Servers types.DiscoveryServers
+ // organizations represents the organizations that are returned by the discovery server
+ organizations types.DiscoveryOrganizations
+
+ // servers represents the servers that are returned by the discovery server
+ servers types.DiscoveryServers
}
-// Helper function that gets a disco json and fills the structure with it
+// getDiscoFile is a helper function that gets a disco json and fills the structure with it
+// If it was unsuccessful it returns an error
func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
@@ -61,6 +68,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
return nil
}
+// DetermineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated
// FIXME: Implement based on
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
// - [IMPLEMENTED] on "first launch" when offering the search for "Institute Access" and "Organizations";
@@ -68,12 +76,13 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
// - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation.
// - [IMPLEMENTED using a custom error message] NOTE: when the org_id that the user chose previously is no longer available in organization_list.json the application should ask the user to choose their organization (again). This can occur for example when the organization replaced their identity provider, uses a different domain after rebranding or simply ceased to exist.
func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
- return discovery.Organizations.Timestamp.IsZero()
+ return discovery.organizations.Timestamp.IsZero()
}
-func (discovery *Discovery) GetSecureLocationList() []string {
+// SecureLocationList returns a slice of all the available locations
+func (discovery *Discovery) SecureLocationList() []string {
var locations []string
- for _, currentServer := range discovery.Servers.List {
+ for _, currentServer := range discovery.servers.List {
if currentServer.Type == "secure_internet" {
locations = append(locations, currentServer.CountryCode)
}
@@ -81,38 +90,44 @@ func (discovery *Discovery) GetSecureLocationList() []string {
return locations
}
-func (discovery *Discovery) GetServerByURL(
- url string,
- _type string,
+// ServerByURL returns the discovery server by the base URL and the according type ("secure_internet", "institute_access")
+// An error is returned if and only if nil is returned for the server
+func (discovery *Discovery) ServerByURL(
+ baseURL string,
+ serverType string,
) (*types.DiscoveryServer, error) {
- for _, currentServer := range discovery.Servers.List {
- if currentServer.BaseURL == url && currentServer.Type == _type {
+ for _, currentServer := range discovery.servers.List {
+ if currentServer.BaseURL == baseURL && currentServer.Type == serverType {
return &currentServer, nil
}
}
return nil, types.NewWrappedError(
"failed getting server by URL from discovery",
- &GetServerByURLNotFoundError{URL: url, Type: _type},
+ &GetServerByURLNotFoundError{URL: baseURL, Type: serverType},
)
}
-func (discovery *Discovery) GetServerByCountryCode(
- code string,
- _type string,
+// ServerByCountryCode returns the discovery server by the country code and the according type ("secure_internet", "institute_access")
+// An error is returned if and only if nil is returned for the server
+func (discovery *Discovery) ServerByCountryCode(
+ countryCode string,
+ serverType string,
) (*types.DiscoveryServer, error) {
- for _, currentServer := range discovery.Servers.List {
- if currentServer.CountryCode == code && currentServer.Type == _type {
+ for _, currentServer := range discovery.servers.List {
+ if currentServer.CountryCode == countryCode && currentServer.Type == serverType {
return &currentServer, nil
}
}
return nil, types.NewWrappedError(
- "failed getting server by country code from discovery",
- &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type},
+ "failed getting server by country countryCode from discovery",
+ &GetServerByCountryCodeNotFoundError{CountryCode: countryCode, Type: serverType},
)
}
-func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) {
- for _, organization := range discovery.Organizations.List {
+// orgByID returns the discovery organization by the organization ID
+// An error is returned if and only if nil is returned for the organization
+func (discovery *Discovery) orgByID(orgID string) (*types.DiscoveryOrganization, error) {
+ for _, organization := range discovery.organizations.List {
if organization.OrgID == orgID {
return &organization, nil
}
@@ -123,11 +138,15 @@ func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganizati
)
}
-func (discovery *Discovery) GetSecureHomeArgs(
+// SecureHomeArgs returns the secure internet home server arguments:
+// - The organization it belongs to
+// - The secure internet server itself
+// An error is returned if and only if nil is returned for the organization
+func (discovery *Discovery) SecureHomeArgs(
orgID string,
) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
errorMessage := "failed getting Secure Internet Home arguments from discovery"
- org, orgErr := discovery.getOrgByID(orgID)
+ org, orgErr := discovery.orgByID(orgID)
if orgErr != nil {
return nil, nil, types.NewWrappedError(errorMessage, orgErr)
@@ -136,7 +155,7 @@ func (discovery *Discovery) GetSecureHomeArgs(
// Get a server with the base url
url := org.SecureInternetHome
- currentServer, serverErr := discovery.GetServerByURL(url, "secure_internet")
+ currentServer, serverErr := discovery.ServerByURL(url, "secure_internet")
if serverErr != nil {
return nil, nil, types.NewWrappedError(errorMessage, serverErr)
@@ -144,55 +163,58 @@ func (discovery *Discovery) GetSecureHomeArgs(
return org, currentServer, nil
}
+// DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
// - [Implemented] The application MUST always fetch the server_list.json at application start.
// - The application MAY refresh the server_list.json periodically, e.g. once every hour.
func (discovery *Discovery) DetermineServersUpdate() bool {
// No servers, we should update
- if discovery.Servers.Timestamp.IsZero() {
+ if discovery.servers.Timestamp.IsZero() {
return true
}
// 1 hour from the last update
- shouldUpdateTime := discovery.Servers.Timestamp.Add(1 * time.Hour)
+ shouldUpdateTime := discovery.servers.Timestamp.Add(1 * time.Hour)
now := time.Now()
return !now.Before(shouldUpdateTime)
}
-// Get the organization list
-func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganizations, error) {
+// Organizations returns the discovery organizations
+// If there was an error, a cached copy is returned if available
+func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, error) {
if !discovery.DetermineOrganizationsUpdate() {
- return &discovery.Organizations, nil
+ return &discovery.organizations, nil
}
file := "organization_list.json"
- bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
+ bodyErr := getDiscoFile(file, discovery.organizations.Version, &discovery.organizations)
if bodyErr != nil {
// Return previous with an error
- return &discovery.Organizations, types.NewWrappedError(
+ return &discovery.organizations, types.NewWrappedError(
"failed getting organizations in Discovery",
bodyErr,
)
}
- discovery.Organizations.Timestamp = time.Now()
- return &discovery.Organizations, nil
+ discovery.organizations.Timestamp = time.Now()
+ return &discovery.organizations, nil
}
-// Get the server list
-func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) {
+// Servers returns the discovery servers
+// If there was an error, a cached copy is returned if available
+func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) {
if !discovery.DetermineServersUpdate() {
- return &discovery.Servers, nil
+ return &discovery.servers, nil
}
file := "server_list.json"
- bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
+ bodyErr := getDiscoFile(file, discovery.servers.Version, &discovery.servers)
if bodyErr != nil {
// Return previous with an error
- return &discovery.Servers, types.NewWrappedError(
+ return &discovery.servers, types.NewWrappedError(
"failed getting servers in Discovery",
bodyErr,
)
}
// Update servers timestamp
- discovery.Servers.Timestamp = time.Now()
- return &discovery.Servers, nil
+ discovery.servers.Timestamp = time.Now()
+ return &discovery.servers, nil
}
type GetOrgByIDNotFoundError struct {
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index b51a5c9..b8fd644 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -1,3 +1,5 @@
+// Package fsm defines a finite state machine and has the ability to save this state machine to a graph file
+// This graph file can be visualized using mermaid.js
package fsm
import (
@@ -10,7 +12,9 @@ import (
)
type (
+ //StateID represents the Identifier of the state
FSMStateID int8
+ //StateIDSlice represents the list of state identifiers
FSMStateIDSlice []FSMStateID
)
@@ -26,8 +30,11 @@ func (v FSMStateIDSlice) Swap(i, j int) {
v[i], v[j] = v[j], v[i]
}
+// Transition indicates an arrow in the state graph
type FSMTransition struct {
+ // To represents the to-be-new state
To FSMStateID
+ // Description is what type of message the arrow gets in the graph
Description string
}
@@ -35,45 +42,60 @@ type (
FSMStates map[FSMStateID]FSMState
)
+// State represents a single node in the graph
type FSMState struct {
+ // Transitions indicates which out arrows this node has
Transitions []FSMTransition
-
- // Which state to go back to on a back transition
- BackState FSMStateID
}
+// FSM represents the total graph
type FSM struct {
+ // States is the map from state ID to states
States FSMStates
+
+ // Current is the current state represented by the identifier
Current FSMStateID
- // Info to be passed from the parent state
+ // Name represents the descriptive name of this state machine
Name string
+
+ // StateCallback is the function ran when a transition occurs
+ // It takes the old state, the new state and the data and returns if this is handled by the client
StateCallback func(FSMStateID, FSMStateID, interface{}) bool
+
+ // Directory represents the path where the state graph is stored
Directory string
- Debug bool
- GetName func(FSMStateID) string
+
+ // Generate represents whether we want to generate the graph
+ Generate bool
+
+ // GetStateName gets the name of a state as a string
+ GetStateName func(FSMStateID) string
}
+// Init initializes the state machine and sets it to the given current state
func (fsm *FSM) Init(
current FSMStateID,
states map[FSMStateID]FSMState,
callback func(FSMStateID, FSMStateID, interface{}) bool,
directory string,
nameGen func(FSMStateID) string,
- debug bool,
+ generate bool,
) {
fsm.States = states
fsm.Current = current
fsm.StateCallback = callback
fsm.Directory = directory
- fsm.GetName = nameGen
- fsm.Debug = debug
+ fsm.GetStateName = nameGen
+ fsm.Generate = generate
}
+// InState returns whether or not the state machine is in the given 'check' state
func (fsm *FSM) InState(check FSMStateID) bool {
return check == fsm.Current
}
+// HasTransition checks whether or not the state machine has a transition to the given 'check' state
func (fsm *FSM) HasTransition(check FSMStateID) bool {
for _, transitionState := range fsm.States[fsm.Current].Transitions {
if transitionState.To == check {
@@ -84,11 +106,13 @@ func (fsm *FSM) HasTransition(check FSMStateID) bool {
return false
}
+// getGraphFilename gets the full path to the graph filename including the .graph extension
func (fsm *FSM) getGraphFilename(extension string) string {
debugPath := path.Join(fsm.Directory, "graph")
return fmt.Sprintf("%s%s", debugPath, extension)
}
+// writeGraph writes the state machine to a .graph file
func (fsm *FSM) writeGraph() {
graph := fsm.GenerateGraph()
graphFile := fsm.getGraphFilename(".graph")
@@ -107,18 +131,18 @@ func (fsm *FSM) writeGraph() {
}
}
-func (fsm *FSM) GoBack() {
- fsm.GoTransition(fsm.States[fsm.Current].BackState)
-}
-
+// GoTransitionRequired transitions the state machine to a new state with associated state data 'data'
+// If this transition is not handled by the client, it returns an error
func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error {
oldState := fsm.Current
if !fsm.GoTransitionWithData(newState, data) {
- return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetName(oldState), fsm.GetName(newState)))
+ return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetStateName(oldState), fsm.GetStateName(newState)))
}
return nil
}
+// GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data'
+// It returns whether or not the transition is handled by the client
func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool {
ok := fsm.HasTransition(newState)
@@ -126,7 +150,7 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool
if ok {
oldState := fsm.Current
fsm.Current = newState
- if fsm.Debug {
+ if fsm.Generate {
fsm.writeGraph()
}
@@ -136,11 +160,14 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool
return handled
}
+// GoTransition is an alias to call GoTransitionWithData but have an empty string as data
func (fsm *FSM) GoTransition(newState FSMStateID) bool {
// No data means the callback is never required
return fsm.GoTransitionWithData(newState, "")
}
+// generateMermaidGraph generates a graph suitable to be converted by the mermaid.js tool
+// it returns the graph as a string
func (fsm *FSM) generateMermaidGraph() string {
graph := "graph TD\n"
sortedFSM := make(FSMStateIDSlice, 0, len(fsm.States))
@@ -152,15 +179,15 @@ func (fsm *FSM) generateMermaidGraph() string {
transitions := fsm.States[state].Transitions
for _, transition := range transitions {
if state == fsm.Current {
- graph += "\nstyle " + fsm.GetName(state) + " fill:cyan\n"
+ graph += "\nstyle " + fsm.GetStateName(state) + " fill:cyan\n"
} else {
- graph += "\nstyle " + fsm.GetName(state) + " fill:white\n"
+ graph += "\nstyle " + fsm.GetStateName(state) + " fill:white\n"
}
- graph += fsm.GetName(
+ graph += fsm.GetStateName(
state,
- ) + "(" + fsm.GetName(
+ ) + "(" + fsm.GetStateName(
state,
- ) + ") " + "-->|" + transition.Description + "| " + fsm.GetName(
+ ) + ") " + "-->|" + transition.Description + "| " + fsm.GetStateName(
transition.To,
) + "\n"
}
@@ -168,8 +195,10 @@ func (fsm *FSM) generateMermaidGraph() string {
return graph
}
+// GenerateGraph generates a mermaid graph if the state machine is initialized
+// If the graph cannot be generated, it returns the empty string
func (fsm *FSM) GenerateGraph() string {
- if fsm.GetName != nil {
+ if fsm.GetStateName != nil {
return fsm.generateMermaidGraph()
}
diff --git a/internal/http/http.go b/internal/http/http.go
index a9d3ea2..b21b901 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -1,3 +1,4 @@
+// Package http defines higher level helpers for the net/http package
package http
import (
@@ -12,8 +13,10 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
+// The URLParemeters as the name suggests is a type used for the parameters in the URL
type URLParameters map[string]string
+// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call
type HTTPOptionalParams struct {
Headers http.Header
URLParameters URLParameters
@@ -21,7 +24,7 @@ type HTTPOptionalParams struct {
Timeout time.Duration
}
-// Construct an URL including on parameters
+// ConstructURL creates a URL with the included parameters
func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) {
url, parseErr := url.Parse(baseURL)
if parseErr != nil {
@@ -44,23 +47,28 @@ func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error)
return url.String(), nil
}
-// Convenience functions
+// Get creates a Get request and returns the headers, body and an error
func HTTPGet(url string) (http.Header, []byte, error) {
return HTTPMethodWithOpts(http.MethodGet, url, nil)
}
+// Post creates a Post request and returns the headers, body and an error
func HTTPPost(url string, body url.Values) (http.Header, []byte, error) {
return HTTPMethodWithOpts(http.MethodGet, url, &HTTPOptionalParams{Body: body})
}
+// GetWithOpts creates a Get request with optional parameters and returns the headers, body and an error
func HTTPGetWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
return HTTPMethodWithOpts(http.MethodGet, url, opts)
}
+// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error
func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) {
return HTTPMethodWithOpts(http.MethodPost, url, opts)
}
+// optionalURL ensures that the URL contains the optional parameters
+// it returns the url (with parameters if success) and an error indicating success
func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
if opts != nil {
url, urlErr := HTTPConstructURL(url, opts.URLParameters)
@@ -76,6 +84,7 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
return url, nil
}
+// optionalHeaders ensures that the HTTP request uses the optional headers if defined
func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) {
// Add headers
if opts != nil && req != nil && opts.Headers != nil {
@@ -85,6 +94,7 @@ func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) {
}
}
+// optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil
func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader {
if opts != nil && opts.Body != nil {
return strings.NewReader(opts.Body.Encode())
@@ -92,6 +102,8 @@ func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader {
return nil
}
+// MethodWithOpts creates a HTTP request using a method (e.g. GET, POST), an url and optional parameters
+// It returns the HTTP headers, the body and an error if there is one
func HTTPMethodWithOpts(
method string,
url string,
@@ -155,12 +167,14 @@ func HTTPMethodWithOpts(
return resp.Header, body, nil
}
+// StatusError indicates that we have received a HTTP status error
type HTTPStatusError struct {
URL string
Body string
Status int
}
+// Error returns the StatusError as an error string
func (e *HTTPStatusError) Error() string {
return fmt.Sprintf(
"failed obtaining HTTP resource: %s as it gave an unsuccessful status code: %d. Body: %s",
@@ -170,12 +184,16 @@ func (e *HTTPStatusError) Error() string {
)
}
+// ParseJSONError indicates that the HTTP error is because of failed JSON parsing
+// It has the URL and the Body as context.
+// The underlying JSON parsing Err itself is also wrapped here.
type HTTPParseJSONError struct {
URL string
Body string
Err error
}
+// Error returns the ParseJSONError as an error string
func (e *HTTPParseJSONError) Error() string {
return fmt.Sprintf(
"failed parsing json %s for HTTP resource: %s with error: %v",
diff --git a/internal/log/log.go b/internal/log/log.go
index 2ab6549..7d032e9 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -1,3 +1,4 @@
+// Package log implements a basic level based logger
package log
import (
@@ -11,47 +12,60 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
+// FileLogger defines the type of logger that this package implements
+// As the name suggests, it saves the log to a file
type FileLogger struct {
+ // Level indicates which maximum level this logger actually forwards to the file
Level LogLevel
- File *os.File
+
+ // file represents a pointer to the open log file
+ file *os.File
}
type LogLevel int8
const (
- // No level set, not allowed
- LogNotSet LogLevel = iota
- // Log debug, this message is not an error but is there for debugging
- LogDebug
- // Log info, this message is not an error but is there for additional information
- LogInfo
- // Log only to provide a warning, the app still functions
- LogWarning
- // Log to provide a generic error, the app still functions but some functionality might not work
- LogError
- // Log to provide a fatal error, the app cannot function correctly when such an error occurs
- LogFatal
+ // LevelNotSet indicates level not set, not allowed
+ LevelNotSet LogLevel = iota
+
+ // LevelDebug indicates that the message is not an error but is there for debugging
+ LevelDebug
+
+ // LevelInfo indicates that the message is not an error but is there for additional information
+ LevelInfo
+
+ // LevelWarning indicates only a warning, the app still functions
+ LevelWarning
+
+ // LevelError indicates a generic error, the app still functions but some functionality might not work
+ LevelError
+
+ // LevelFatal indicates a fatal error, the app cannot function correctly when such an error occurs
+ LevelFatal
)
+// String returns the string of each level
func (e LogLevel) String() string {
switch e {
- case LogNotSet:
+ case LevelNotSet:
return "NOTSET"
- case LogDebug:
+ case LevelDebug:
return "DEBUG"
- case LogInfo:
+ case LevelInfo:
return "INFO"
- case LogWarning:
+ case LevelWarning:
return "WARNING"
- case LogError:
+ case LevelError:
return "ERROR"
- case LogFatal:
+ case LevelFatal:
return "FATAL"
default:
return "UNKNOWN"
}
}
+// Init initializes the logger by forwarding a max level 'level' and a directory 'directory' where the log should be stored
+// If the logger cannot be initialized, for example an error in opening the log file, an error is returned
func (logger *FileLogger) Init(level LogLevel, directory string) error {
errorMessage := "failed creating log"
@@ -60,7 +74,7 @@ func (logger *FileLogger) Init(level LogLevel, directory string) error {
return types.NewWrappedError(errorMessage, configDirErr)
}
logFile, logOpenErr := os.OpenFile(
- logger.getFilename(directory),
+ logger.filename(directory),
os.O_RDWR|os.O_CREATE|os.O_APPEND,
0o666,
)
@@ -69,11 +83,12 @@ func (logger *FileLogger) Init(level LogLevel, directory string) error {
}
multi := io.MultiWriter(os.Stdout, logFile)
log.SetOutput(multi)
- logger.File = logFile
+ logger.file = logFile
logger.Level = level
return nil
}
+// Inherit logs an error with a label using the error level of the error
func (logger *FileLogger) Inherit(label string, err error) {
level := types.GetErrorLevel(err)
@@ -90,36 +105,44 @@ func (logger *FileLogger) Inherit(label string, err error) {
}
}
+// Debug logs a message with parameters as level LevelDebug
func (logger *FileLogger) Debug(msg string, params ...interface{}) {
- logger.log(LogDebug, msg, params...)
+ logger.log(LevelDebug, msg, params...)
}
+// Debug logs a message with parameters as level LevelInfo
func (logger *FileLogger) Info(msg string, params ...interface{}) {
- logger.log(LogInfo, msg, params...)
+ logger.log(LevelInfo, msg, params...)
}
+// Debug logs a message with parameters as level LevelWarning
func (logger *FileLogger) Warning(msg string, params ...interface{}) {
- logger.log(LogWarning, msg, params...)
+ logger.log(LevelWarning, msg, params...)
}
+// Debug logs a message with parameters as level LevelError
func (logger *FileLogger) Error(msg string, params ...interface{}) {
- logger.log(LogError, msg, params...)
+ logger.log(LevelError, msg, params...)
}
+// Debug logs a message with parameters as level LevelFatal
func (logger *FileLogger) Fatal(msg string, params ...interface{}) {
- logger.log(LogFatal, msg, params...)
+ logger.log(LevelFatal, msg, params...)
}
+// Close closes the logger by closing the internal file
func (logger *FileLogger) Close() {
- logger.File.Close()
+ logger.file.Close()
}
-func (logger *FileLogger) getFilename(directory string) string {
+// filename returns the filename of the logger by returning the full path as a string
+func (logger *FileLogger) filename(directory string) string {
return path.Join(directory, "log")
}
+// log logs as level 'level' a message 'msg' with parameters 'params'
func (logger *FileLogger) log(level LogLevel, msg string, params ...interface{}) {
- if level >= logger.Level && logger.Level != LogNotSet {
+ if level >= logger.Level && logger.Level != LevelNotSet {
formattedMsg := fmt.Sprintf(msg, params...)
format := fmt.Sprintf("- Go - %s - %s", level.String(), formattedMsg)
// To log file
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 1f3b719..232f68c 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -1,3 +1,8 @@
+// package oauth implement an oauth client defined in e.g. rfc 6749
+// However, we try to follow some recommendations from the v2.1 oauth draft RFC
+// Some specific things we implement here:
+// - PKCE (RFC 7636)
+// - ISS (RFC 9207)
package oauth
import (
@@ -18,7 +23,7 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
-// Generates a random base64 string to be used for state
+// genState generates a random base64 string to be used for state
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
// "state": OPTIONAL. An opaque value used by the client to maintain
// state between the request and callback. The authorization server
@@ -35,7 +40,7 @@ func genState() (string, error) {
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
}
-// Generates a sha256 base64 challenge from a verifier
+// genChallengeS256 generates a sha256 base64 challenge from a verifier
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.8
func genChallengeS256(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
@@ -44,7 +49,7 @@ func genChallengeS256(verifier string) string {
return base64.RawURLEncoding.EncodeToString(hash[:])
}
-// Generates a verifier
+// genVerifier generates a verifier
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1
// The code_verifier is a unique high-entropy cryptographically random
// string generated for each authorization request, using the unreserved
@@ -70,75 +75,108 @@ func genVerifier() (string, error) {
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
}
+// OAuth defines the main structure for this package
type OAuth struct {
+ // ISS indicates the issuer indentifier of the authorization server as defined in RFC 9207
ISS string `json:"iss"`
- Session OAuthExchangeSession `json:"-"`
+
+ // Token is where the access and refresh tokens are stored along with the timestamps
Token OAuthToken `json:"token"`
+
+ // BaseAuthorizationURL is the URL where authorization should take place
BaseAuthorizationURL string `json:"base_authorization_url"`
+
+ // TokenURL is the URL where tokens should be obtained
TokenURL string `json:"token_url"`
+
+ // session is the internal in progress OAuth session
+ session OAuthExchangeSession `json:"-"`
}
-// This structure gets passed to the callback for easy access to the current state
+// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state
type OAuthExchangeSession struct {
- // returned from the callback
+ // CallbackError indicates an error returned by the server
CallbackError error
- // filled in in initialize
+ // ClientID is the ID of the OAuth client
ClientID string
+
+ // ISS indicates the issuer inditifer
ISS string
+
+ // State is the expected URL state paremeter
State string
+
+ // Verifier is the preimage of the challenge
Verifier string
- // filled in when constructing the callback
+ // Context is the context used for cancellation
Context context.Context
+
+ // Server is the server of the session
Server *http.Server
+
+ // Listener is the listener where the servers 'listens' on
Listener net.Listener
}
-// Struct that defines the json format for /.well-known/vpn-user-portal"
+// OAuthToken is a structure that defines the json format for /.well-known/vpn-user-portal"
type OAuthToken struct {
+ // Access is the access token returned by the server
Access string `json:"access_token"`
+
+ // Refresh token is the refresh token returned by the server
Refresh string `json:"refresh_token"`
+
+ // Type indicates which type of tokens we have
Type string `json:"token_type"`
+
+ // Expires is the expires time returned by the server
Expires int64 `json:"expires_in"`
+
+ // ExpiredTimestamp is the Expires field but converted to a Go timestamp
ExpiredTimestamp time.Time `json:"expires_in_timestamp"`
}
-// Sets up a listener
+// setupListener sets up an OAuth listener
+// If it was unsuccessful it returns an error
func (oauth *OAuth) setupListener() error {
errorMessage := "failed setting up listener"
- oauth.Session.Context = context.Background()
+ oauth.session.Context = context.Background()
// create a listener
listener, listenerErr := net.Listen("tcp", ":0")
if listenerErr != nil {
return types.NewWrappedError(errorMessage, listenerErr)
}
- oauth.Session.Listener = listener
+ oauth.session.Listener = listener
return nil
}
+// getTokensWithCallback gets the OAuth tokens using a local web server
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithCallback() error {
errorMessage := "failed getting tokens with callback"
- if oauth.Session.Listener == nil {
+ if oauth.session.Listener == nil {
return types.NewWrappedError(errorMessage, errors.New("no listener"))
}
mux := http.NewServeMux()
// server /callback over the listener address
- oauth.Session.Server = &http.Server{
+ oauth.session.Server = &http.Server{
Handler: mux,
}
mux.HandleFunc("/callback", oauth.Callback)
- if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed {
+ if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed {
return types.NewWrappedError(errorMessage, err)
}
- return oauth.Session.CallbackError
+ return oauth.session.CallbackError
}
-// Get the access and refresh tokens
+// getTokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
errorMessage := "failed getting tokens with the authorization code"
// Make sure the verifier is set as the parameter
@@ -151,9 +189,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
}
data := url.Values{
- "client_id": {oauth.Session.ClientID},
+ "client_id": {oauth.session.ClientID},
"code": {authCode},
- "code_verifier": {oauth.Session.Verifier},
+ "code_verifier": {oauth.session.Verifier},
"grant_type": {"authorization_code"},
"redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
}
@@ -185,15 +223,17 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
return nil
}
+// isTokensExpired returns if the OAuth tokens are expired using the expired timestamp
func (oauth *OAuth) isTokensExpired() bool {
expiredTime := oauth.Token.ExpiredTimestamp
currentTime := time.Now()
return !currentTime.Before(expiredTime)
}
-// Get the access and refresh tokens with a previously received refresh token
+// getTokensWithRefresh gets the access and refresh tokens with a previously received refresh token
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
+// If it was unsuccessful it returns an error
func (oauth *OAuth) getTokensWithRefresh() error {
errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
@@ -228,7 +268,8 @@ func (oauth *OAuth) getTokensWithRefresh() error {
return nil
}
-// Adapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25
+// responseTemplate is the HTML template for the OAuth authorized response
+// this template was dapted from: https://github.com/eduvpn/apple/blob/5b18f834be7aebfed00570ae0c2f7bcbaf1c69cc/EduVPN/Helpers/Mac/OAuthRedirectHTTPHandler.m#L25
const responseTemplate string = `
<!DOCTYPE html>
<html dir="ltr" xmlns="http://www.w3.org/1999/xhtml" lang="en"><head>
@@ -265,11 +306,14 @@ main {
</html>
`
+// oauthResponseHTML is a structure that is used to give back the OAuth response
type oauthResponseHTML struct {
Title string
Message string
}
+// writeResponseHTML writes the OAuth response using a response writer and the title + message
+// If it was unsuccessful it returns an error
func writeResponseHTML(w http.ResponseWriter, title string, message string) error {
errorMessage := "failed writing response HTML"
template, templateErr := template.New("oauth-response").Parse(responseTemplate)
@@ -287,21 +331,21 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
return nil
}
-//
-//// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
+// Callback is the public function used to get the OAuth tokens using an authorization code callback
+// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
errorMessage := "failed callback to retrieve the authorization code"
// Shutdown after we're done
defer func() {
// writing the html is best effort
- if oauth.Session.CallbackError != nil {
+ if oauth.session.CallbackError != nil {
_ = writeResponseHTML(w, "Authorization Failed", "The authorization has failed. See the log file for more information.")
} else {
_ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
}
- if oauth.Session.Server != nil {
- go oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
+ if oauth.session.Server != nil {
+ go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}()
@@ -310,10 +354,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
urlQuery := req.URL.Query()
extractedISS := urlQuery.Get("iss")
if extractedISS != "" {
- if oauth.Session.ISS != extractedISS {
- oauth.Session.CallbackError = types.NewWrappedError(
+ if oauth.session.ISS != extractedISS {
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
- &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
+ &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS},
)
return
}
@@ -324,19 +368,19 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
extractedState := urlQuery.Get("state")
if extractedState == "" {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
)
return
}
// The state is the first entry
- if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = types.NewWrappedError(
+ if extractedState != oauth.session.State {
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackStateMatchError{
State: extractedState,
- ExpectedState: oauth.Session.State,
+ ExpectedState: oauth.session.State,
},
)
return
@@ -345,7 +389,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// No authorization code
extractedCode := urlQuery.Get("code")
if extractedCode == "" {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
&OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
)
@@ -356,7 +400,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
if getTokensErr != nil {
- oauth.Session.CallbackError = types.NewWrappedError(
+ oauth.session.CallbackError = types.NewWrappedError(
errorMessage,
getTokensErr,
)
@@ -364,22 +408,28 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
+// Init initializes OAuth with the following parameters:
+// - OAuth server issuer identification
+// - The URL used for authorization
+// - The URL to obtain new tokens
func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL string) {
oauth.ISS = iss
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
}
+// GetListenerPort gets the listener for the OAuth web server
+// It returns the port as an integer and an error if there is any
func (oauth OAuth) GetListenerPort() (int, error) {
errorMessage := "failed to get listener port"
- if oauth.Session.Listener == nil {
+ if oauth.session.Listener == nil {
return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener"))
}
- return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil
+ return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil
}
-// Starts the OAuth exchange for eduvpn.
+// GetAuthURL gets the authorization url to start the OAuth procedure
func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) {
errorMessage := "failed starting OAuth exchange"
@@ -398,7 +448,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauthSession := OAuthExchangeSession{ClientID: name, ISS: oauth.ISS, State: state, Verifier: verifier}
- oauth.Session = oauthSession
+ oauth.session = oauthSession
// set up the listener to get the redirect URI
listenerErr := oauth.setupListener()
@@ -432,7 +482,8 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
return postProcessAuth(authURL), nil
}
-// Error definitions
+// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
+// If it was unsuccessful it returns an error
func (oauth *OAuth) Exchange() error {
tokenErr := oauth.getTokensWithCallback()
@@ -442,17 +493,21 @@ func (oauth *OAuth) Exchange() error {
return nil
}
+// Cancel cancels the existing OAuth
+// TODO: Use context for this
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = types.NewWrappedErrorLevel(
+ oauth.session.CallbackError = types.NewWrappedErrorLevel(
types.ErrInfo,
"cancelled OAuth",
&OAuthCancelledCallbackError{},
)
- if oauth.Session.Server != nil {
- oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
+ if oauth.session.Server != nil {
+ oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
}
}
+// EnsureTokens makes sure the OAuth tokens are still valid
+// if this cannot be guaranteed, it returns an error
func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens
diff --git a/internal/util/util.go b/internal/util/util.go
index cbe9c1b..8b39b9f 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -1,3 +1,4 @@
+// package util implements several utility functions that are used across the codebase
package util
import (
@@ -11,6 +12,12 @@ import (
"github.com/eduvpn/eduvpn-common/types"
)
+// EnsureValidURL ensures that the input URL is valid to be used internally
+// It does the following
+// - Sets the scheme to https if none is given
+// - It 'cleans' up the path using path.Clean
+// - It makes sure that the URL ends with a /
+// It returns an error if the URL cannot be parsed
func EnsureValidURL(s string) (string, error) {
parsedURL, parseErr := url.Parse(s)
if parseErr != nil {
@@ -38,7 +45,8 @@ func EnsureValidURL(s string) (string, error) {
return returnedURL, nil
}
-// Creates a random byteslice of `size`
+// MakeRandomByteSlice creates a cryptographically random byteslice of `size`
+// It returns the byte slice (or nil if error) and an error if it could not be generated.
func MakeRandomByteSlice(size int) ([]byte, error) {
byteSlice := make([]byte, size)
_, err := rand.Read(byteSlice)
@@ -48,6 +56,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) {
return byteSlice, nil
}
+// EnsureDirectory creates a directory with permission 700
func EnsureDirectory(directory string) error {
// Create with 700 permissions, read, write, execute only for the owner
mkdirErr := os.MkdirAll(directory, 0o700)
@@ -60,6 +69,7 @@ func EnsureDirectory(directory string) error {
return nil
}
+// WAYFEncode an input URL using 'skip Where Are You From' encoding
// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape
func WAYFEncode(input string) string {
@@ -68,6 +78,7 @@ func WAYFEncode(input string) string {
return url.QueryEscape(input)
}
+// ReplaceWAYF replaces an authorization template containing of @RETURN_TO@ and @ORG_ID@ with the authorization URL and the organization ID
// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
// We just return the authURL in the cases where the template is not given or is invalid
@@ -81,18 +92,19 @@ func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
return authURL
}
// Replace authURL
- authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1)
+ authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", url.QueryEscape(authURL), 1)
// If now there is no more ORG_ID, return as there weren't enough @ symbols
if !strings.Contains(authTemplate, "@ORG_ID@") {
return authURL
}
// Replace ORG ID
- authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1)
+ authTemplate = strings.Replace(authTemplate, "@ORG_ID@", url.QueryEscape(orgID), 1)
return authTemplate
}
-// https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching
+// GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag
+// It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching
func GetLanguageMatched(languageMap map[string]string, languageTag string) string {
// If no map is given, return the empty string
if len(languageMap) == 0 {
diff --git a/internal/util/util_test.go b/internal/util/util_test.go
index bb76752..be19b11 100644
--- a/internal/util/util_test.go
+++ b/internal/util/util_test.go
@@ -57,18 +57,6 @@ func TestMakeRandomByteSlice(t *testing.T) {
}
}
-func TestWAYFEncode(t *testing.T) {
- // AuthTemplate
- returnTo := "127.0.0.1:8000/test123bla/#wow "
-
- // URL encoding but with spaces replace as + instead of %20
- wantReturnTo := "127.0.0.1%3A8000%2Ftest123bla%2F%23wow+++"
- encode := WAYFEncode(returnTo)
- if encode != wantReturnTo {
- t.Fatalf("Got: %s, want: %s", encode, wantReturnTo)
- }
-}
-
func TestReplaceWAYF(t *testing.T) {
// We expect url encoding but the spaces to be correctly replace with a + instead of a %20
// And we expect that the return to and org_id are correctly replaced
diff --git a/internal/verify/verify.go b/internal/verify/verify.go
index 6432619..98a9c67 100644
--- a/internal/verify/verify.go
+++ b/internal/verify/verify.go
@@ -1,3 +1,4 @@
+// package verify implement signature verification using minisign
package verify
import (
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index bbb22e4..d9bd974 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -1,3 +1,4 @@
+// package wireguard implements a few helpers for the WireGuard protocol
package wireguard
import (
@@ -8,6 +9,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
+// GenerateKey generates a WireGuard private key using wgctrl
+// It returns an error if key generation failed
func GenerateKey() (wgtypes.Key, error) {
key, keyErr := wgtypes.GeneratePrivateKey()
@@ -20,6 +23,7 @@ func GenerateKey() (wgtypes.Key, error) {
return key, nil
}
+// ConfigAddKey takes the WireGuard configuration and adds the PrivateKey to the right section
// FIXME: Instead of doing a regex replace, decide if we should use a parser
func ConfigAddKey(config string, key wgtypes.Key) string {
interfaceSection := "[Interface]"