diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/discovery/discovery.go | 4 | ||||
| -rw-r--r-- | internal/fsm/fsm.go | 52 | ||||
| -rw-r--r-- | internal/http/http.go | 48 | ||||
| -rw-r--r-- | internal/log/log.go | 12 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 62 | ||||
| -rw-r--r-- | internal/oauth/token.go | 10 | ||||
| -rw-r--r-- | internal/server/api.go | 22 | ||||
| -rw-r--r-- | internal/server/common.go | 88 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 4 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 26 | ||||
| -rw-r--r-- | internal/verify/verify.go | 62 | ||||
| -rw-r--r-- | internal/verify/verify_test.go | 18 |
12 files changed, 204 insertions, 204 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 17d5b08..40fa165 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -28,7 +28,7 @@ func discoFile(jsonFile string, previousVersion uint64, structure interface{}) e // Get json data discoURL := "https://disco.eduvpn.org/v2/" fileURL := discoURL + jsonFile - _, fileBody, fileErr := http.HTTPGet(fileURL) + _, fileBody, fileErr := http.Get(fileURL) if fileErr != nil { return types.NewWrappedError(errorMessage, fileErr) @@ -37,7 +37,7 @@ func discoFile(jsonFile string, previousVersion uint64, structure interface{}) e // Get signature sigFile := jsonFile + ".minisig" sigURL := discoURL + sigFile - _, sigBody, sigFileErr := http.HTTPGet(sigURL) + _, sigBody, sigFileErr := http.Get(sigURL) if sigFileErr != nil { return types.NewWrappedError(errorMessage, sigFileErr) diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index e6f3f3a..4114a32 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -12,56 +12,56 @@ import ( ) type ( - //StateID represents the Identifier of the state. - FSMStateID int8 - //StateIDSlice represents the list of state identifiers. - FSMStateIDSlice []FSMStateID + // StateID represents the Identifier of the state. + StateID int8 + // StateIDSlice represents the list of state identifiers. + StateIDSlice []StateID ) -func (v FSMStateIDSlice) Len() int { +func (v StateIDSlice) Len() int { return len(v) } -func (v FSMStateIDSlice) Less(i, j int) bool { +func (v StateIDSlice) Less(i, j int) bool { return v[i] < v[j] } -func (v FSMStateIDSlice) Swap(i, j int) { +func (v StateIDSlice) Swap(i, j int) { v[i], v[j] = v[j], v[i] } // Transition indicates an arrow in the state graph. -type FSMTransition struct { +type Transition struct { // To represents the to-be-new state - To FSMStateID + To StateID // Description is what type of message the arrow gets in the graph Description string } type ( - FSMStates map[FSMStateID]FSMState + States map[StateID]State ) // State represents a single node in the graph. -type FSMState struct { +type State struct { // Transitions indicates which out arrows this node has - Transitions []FSMTransition + Transitions []Transition } // FSM represents the total graph. type FSM struct { // States is the map from state ID to states - States FSMStates + States States // Current is the current state represented by the identifier - Current FSMStateID + Current StateID // Name represents the descriptive name of this state machine Name string // StateCallback is the function ran when a transition occurs // It takes the old state, the new state and the data and returns if this is handled by the client - StateCallback func(FSMStateID, FSMStateID, interface{}) bool + StateCallback func(StateID, StateID, interface{}) bool // Directory represents the path where the state graph is stored Directory string @@ -70,16 +70,16 @@ type FSM struct { Generate bool // GetStateName gets the name of a state as a string - GetStateName func(FSMStateID) string + GetStateName func(StateID) string } // Init initializes the state machine and sets it to the given current state. func (fsm *FSM) Init( - current FSMStateID, - states map[FSMStateID]FSMState, - callback func(FSMStateID, FSMStateID, interface{}) bool, + current StateID, + states States, + callback func(StateID, StateID, interface{}) bool, directory string, - nameGen func(FSMStateID) string, + nameGen func(StateID) string, generate bool, ) { fsm.States = states @@ -91,12 +91,12 @@ func (fsm *FSM) Init( } // InState returns whether or not the state machine is in the given 'check' state. -func (fsm *FSM) InState(check FSMStateID) bool { +func (fsm *FSM) InState(check StateID) bool { return check == fsm.Current } // HasTransition checks whether or not the state machine has a transition to the given 'check' state. -func (fsm *FSM) HasTransition(check FSMStateID) bool { +func (fsm *FSM) HasTransition(check StateID) bool { for _, transitionState := range fsm.States[fsm.Current].Transitions { if transitionState.To == check { return true @@ -133,7 +133,7 @@ func (fsm *FSM) writeGraph() { // GoTransitionRequired transitions the state machine to a new state with associated state data 'data' // If this transition is not handled by the client, it returns an error. -func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error { +func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error { oldState := fsm.Current if !fsm.GoTransitionWithData(newState, data) { return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetStateName(oldState), fsm.GetStateName(newState))) @@ -143,7 +143,7 @@ func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) erro // GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data' // It returns whether or not the transition is handled by the client. -func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool { +func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) bool { ok := fsm.HasTransition(newState) handled := false @@ -161,7 +161,7 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool } // GoTransition is an alias to call GoTransitionWithData but have an empty string as data. -func (fsm *FSM) GoTransition(newState FSMStateID) bool { +func (fsm *FSM) GoTransition(newState StateID) bool { // No data means the callback is never required return fsm.GoTransitionWithData(newState, "") } @@ -170,7 +170,7 @@ func (fsm *FSM) GoTransition(newState FSMStateID) bool { // it returns the graph as a string. func (fsm *FSM) generateMermaidGraph() string { graph := "graph TD\n" - sortedFSM := make(FSMStateIDSlice, 0, len(fsm.States)) + sortedFSM := make(StateIDSlice, 0, len(fsm.States)) for stateID := range fsm.States { sortedFSM = append(sortedFSM, stateID) } diff --git a/internal/http/http.go b/internal/http/http.go index 7e87e3c..5f3d783 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -17,7 +17,7 @@ import ( type URLParameters map[string]string // OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call. -type HTTPOptionalParams struct { +type OptionalParams struct { Headers http.Header URLParameters URLParameters Body url.Values @@ -25,7 +25,7 @@ type HTTPOptionalParams struct { } // ConstructURL creates a URL with the included parameters. -func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { +func ConstructURL(baseURL string, parameters URLParameters) (string, error) { url, parseErr := url.Parse(baseURL) if parseErr != nil { return "", types.NewWrappedError( @@ -48,30 +48,30 @@ func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) } // Get creates a Get request and returns the headers, body and an error. -func HTTPGet(url string) (http.Header, []byte, error) { - return HTTPMethodWithOpts(http.MethodGet, url, nil) +func Get(url string) (http.Header, []byte, error) { + return MethodWithOpts(http.MethodGet, url, nil) } // Post creates a Post request and returns the headers, body and an error. -func HTTPPost(url string, body url.Values) (http.Header, []byte, error) { - return HTTPMethodWithOpts(http.MethodGet, url, &HTTPOptionalParams{Body: body}) +func Post(url string, body url.Values) (http.Header, []byte, error) { + return MethodWithOpts(http.MethodGet, url, &OptionalParams{Body: body}) } // GetWithOpts creates a Get request with optional parameters and returns the headers, body and an error. -func HTTPGetWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) { - return HTTPMethodWithOpts(http.MethodGet, url, opts) +func GetWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { + return MethodWithOpts(http.MethodGet, url, opts) } // PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error. -func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte, error) { - return HTTPMethodWithOpts(http.MethodPost, url, opts) +func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { + return MethodWithOpts(http.MethodPost, url, opts) } // optionalURL ensures that the URL contains the optional parameters // it returns the url (with parameters if success) and an error indicating success. -func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { +func optionalURL(url string, opts *OptionalParams) (string, error) { if opts != nil { - url, urlErr := HTTPConstructURL(url, opts.URLParameters) + url, urlErr := ConstructURL(url, opts.URLParameters) if urlErr != nil { return url, types.NewWrappedError( @@ -85,7 +85,7 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { } // optionalHeaders ensures that the HTTP request uses the optional headers if defined. -func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { +func optionalHeaders(req *http.Request, opts *OptionalParams) { // Add headers if opts != nil && req != nil && opts.Headers != nil { for k, v := range opts.Headers { @@ -95,7 +95,7 @@ func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { } // optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil. -func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader { +func optionalBodyReader(opts *OptionalParams) io.Reader { if opts != nil && opts.Body != nil { return strings.NewReader(opts.Body.Encode()) } @@ -104,15 +104,15 @@ func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader { // MethodWithOpts creates a HTTP request using a method (e.g. GET, POST), an url and optional parameters // It returns the HTTP headers, the body and an error if there is one. -func HTTPMethodWithOpts( +func MethodWithOpts( method string, url string, - opts *HTTPOptionalParams, + opts *OptionalParams, ) (http.Header, []byte, error) { // Make sure the url contains all the parameters // This can return an error, // it already has the right error so so we don't wrap it further - url, urlErr := httpOptionalURL(url, opts) + url, urlErr := optionalURL(url, opts) if urlErr != nil { // No further type wrapping is needed here return nil, nil, urlErr @@ -131,7 +131,7 @@ func HTTPMethodWithOpts( errorMessage := fmt.Sprintf("failed HTTP request with method %s and url %s", method, url) // Create request object with the body reader generated from the optional arguments - req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts)) + req, reqErr := http.NewRequest(method, url, optionalBodyReader(opts)) if reqErr != nil { return nil, nil, types.NewWrappedError(errorMessage, reqErr) } @@ -140,7 +140,7 @@ func HTTPMethodWithOpts( req.Close = true // Make sure the headers contain all the parameters - httpOptionalHeaders(req, opts) + optionalHeaders(req, opts) // Do request resp, respErr := client.Do(req) @@ -159,7 +159,7 @@ func HTTPMethodWithOpts( if resp.StatusCode < 200 || resp.StatusCode > 299 { // We make this a custom error because we want to extract the status code later - statusErr := &HTTPStatusError{URL: url, Body: string(body), Status: resp.StatusCode} + statusErr := &StatusError{URL: url, Body: string(body), Status: resp.StatusCode} return resp.Header, body, types.NewWrappedError(errorMessage, statusErr) } @@ -168,14 +168,14 @@ func HTTPMethodWithOpts( } // StatusError indicates that we have received a HTTP status error. -type HTTPStatusError struct { +type StatusError struct { URL string Body string Status int } // Error returns the StatusError as an error string. -func (e *HTTPStatusError) Error() string { +func (e *StatusError) Error() string { return fmt.Sprintf( "failed obtaining HTTP resource: %s as it gave an unsuccessful status code: %d. Body: %s", e.URL, @@ -187,14 +187,14 @@ func (e *HTTPStatusError) Error() string { // ParseJSONError indicates that the HTTP error is because of failed JSON parsing // It has the URL and the Body as context. // The underlying JSON parsing Err itself is also wrapped here. -type HTTPParseJSONError struct { +type ParseJSONError struct { URL string Body string Err error } // Error returns the ParseJSONError as an error string. -func (e *HTTPParseJSONError) Error() string { +func (e *ParseJSONError) Error() string { return fmt.Sprintf( "failed parsing json %s for HTTP resource: %s with error: %v", e.Body, diff --git a/internal/log/log.go b/internal/log/log.go index 3aa0f0b..43bc737 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -16,17 +16,17 @@ import ( // As the name suggests, it saves the log to a file. type FileLogger struct { // Level indicates which maximum level this logger actually forwards to the file - Level LogLevel + Level Level // file represents a pointer to the open log file file *os.File } -type LogLevel int8 +type Level int8 const ( // LevelNotSet indicates level not set, not allowed. - LevelNotSet LogLevel = iota + LevelNotSet Level = iota // LevelDebug indicates that the message is not an error but is there for debugging. LevelDebug @@ -45,7 +45,7 @@ const ( ) // String returns the string of each level. -func (e LogLevel) String() string { +func (e Level) String() string { switch e { case LevelNotSet: return "NOTSET" @@ -66,7 +66,7 @@ func (e LogLevel) String() string { // Init initializes the logger by forwarding a max level 'level' and a directory 'directory' where the log should be stored // If the logger cannot be initialized, for example an error in opening the log file, an error is returned. -func (logger *FileLogger) Init(level LogLevel, directory string) error { +func (logger *FileLogger) Init(level Level, directory string) error { errorMessage := "failed creating log" configDirErr := util.EnsureDirectory(directory) @@ -141,7 +141,7 @@ func (logger *FileLogger) filename(directory string) string { } // log logs as level 'level' a message 'msg' with parameters 'params'. -func (logger *FileLogger) log(level LogLevel, msg string, params ...interface{}) { +func (logger *FileLogger) log(level Level, msg string, params ...interface{}) { if level >= logger.Level && logger.Level != LevelNotSet { formattedMsg := fmt.Sprintf(msg, params...) format := fmt.Sprintf("- Go - %s - %s", level.String(), formattedMsg) diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 84ecdc4..3c1e5d6 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -87,14 +87,14 @@ type OAuth struct { TokenURL string `json:"token_url"` // session is the internal in progress OAuth session - session OAuthExchangeSession `json:"-"` + session ExchangeSession `json:"-"` // Token is where the access and refresh tokens are stored along with the timestamps - token OAuthToken `json:"-"` + token Token `json:"-"` } -// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state. -type OAuthExchangeSession struct { +// ExchangeSession is a structure that gets passed to the callback for easy access to the current state. +type ExchangeSession struct { // CallbackError indicates an error returned by the server CallbackError error @@ -137,7 +137,7 @@ func (oauth *OAuth) AccessToken() (string, error) { // Check if refresh is even possible by doing a simple check if the refresh token is empty // This is not needed but reduces API calls to the server if tokens.refresh == "" { - return "", types.NewWrappedError(errorMessage, &OAuthTokensInvalidError{Cause: "no refresh token is present"}) + return "", types.NewWrappedError(errorMessage, &TokensInvalidError{Cause: "no refresh token is present"}) } // Otherwise refresh and then later return the access token if we are successful @@ -146,7 +146,7 @@ func (oauth *OAuth) AccessToken() (string, error) { // We have failed to ensure the tokens due to refresh not working return "", types.NewWrappedError( errorMessage, - &OAuthTokensInvalidError{ + &TokensInvalidError{ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), }, ) @@ -195,17 +195,17 @@ func (oauth *OAuth) tokensWithCallback() error { // It calculates the expired timestamp by having a 'startTime' passed to it // The URL that is input here is used for additional context. func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error { - responseStructure := OAuthTokenResponse{} + responseStructure := TokenResponse{} jsonErr := json.Unmarshal(response, &responseStructure) if jsonErr != nil { return types.NewWrappedError( "failed filling OAuth tokens", - &httpw.HTTPParseJSONError{URL: url, Body: string(response), Err: jsonErr}, + &httpw.ParseJSONError{URL: url, Body: string(response), Err: jsonErr}, ) } - internalStructure := OAuthToken{} + internalStructure := Token{} internalStructure.expiredTimestamp = startTime.Add( time.Second * time.Duration(responseStructure.Expires), ) @@ -222,7 +222,7 @@ func (oauth *OAuth) SetTokenExpired() { // SetTokenRenew sets the tokens for renewal by completely clearing the structure. func (oauth *OAuth) SetTokenRenew() { - oauth.token = OAuthToken{} + oauth.token = Token{} } // tokensWithAuthCode gets the access and refresh tokens using the authorization code @@ -250,9 +250,9 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + opts := &httpw.OptionalParams{Headers: headers, Body: data} currentTime := time.Now() - _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) + _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) if bodyErr != nil { return types.NewWrappedError(errorMessage, bodyErr) } @@ -278,9 +278,9 @@ func (oauth *OAuth) tokensWithRefresh() error { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + opts := &httpw.OptionalParams{Headers: headers, Body: data} currentTime := time.Now() - _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) + _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) if bodyErr != nil { return types.NewWrappedError(errorMessage, bodyErr) } @@ -381,7 +381,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { if oauth.session.ISS != extractedISS { oauth.session.CallbackError = types.NewWrappedError( errorMessage, - &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS}, + &CallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS}, ) return } @@ -394,7 +394,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { if extractedState == "" { oauth.session.CallbackError = types.NewWrappedError( errorMessage, - &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, + &CallbackParameterError{Parameter: "state", URL: req.URL.String()}, ) return } @@ -402,7 +402,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { if extractedState != oauth.session.State { oauth.session.CallbackError = types.NewWrappedError( errorMessage, - &OAuthCallbackStateMatchError{ + &CallbackStateMatchError{ State: extractedState, ExpectedState: oauth.session.State, }, @@ -415,7 +415,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { if extractedCode == "" { oauth.session.CallbackError = types.NewWrappedError( errorMessage, - &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, + &CallbackParameterError{Parameter: "code", URL: req.URL.String()}, ) return } @@ -471,7 +471,7 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - oauthSession := OAuthExchangeSession{ClientID: name, ISS: oauth.ISS, State: state, Verifier: verifier} + oauthSession := ExchangeSession{ClientID: name, ISS: oauth.ISS, State: state, Verifier: verifier} oauth.session = oauthSession // set up the listener to get the redirect URI @@ -496,7 +496,7 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s "redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port), } - authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) + authURL, urlErr := httpw.ConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { return "", types.NewWrappedError(errorMessage, urlErr) @@ -523,50 +523,50 @@ func (oauth *OAuth) Cancel() { oauth.session.CallbackError = types.NewWrappedErrorLevel( types.ErrInfo, "cancelled OAuth", - &OAuthCancelledCallbackError{}, + &CancelledCallbackError{}, ) if oauth.session.Server != nil { oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck } } -type OAuthCancelledCallbackError struct{} +type CancelledCallbackError struct{} -func (e *OAuthCancelledCallbackError) Error() string { +func (e *CancelledCallbackError) Error() string { return "client cancelled OAuth" } -type OAuthCallbackParameterError struct { +type CallbackParameterError struct { Parameter string URL string } -func (e *OAuthCallbackParameterError) Error() string { +func (e *CallbackParameterError) Error() string { return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) } -type OAuthCallbackStateMatchError struct { +type CallbackStateMatchError struct { State string ExpectedState string } -func (e *OAuthCallbackStateMatchError) Error() string { +func (e *CallbackStateMatchError) Error() string { return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } -type OAuthCallbackISSMatchError struct { +type CallbackISSMatchError struct { ISS string ExpectedISS string } -func (e *OAuthCallbackISSMatchError) Error() string { +func (e *CallbackISSMatchError) Error() string { return fmt.Sprintf("failed matching ISS, got: %s, want: %s", e.ISS, e.ExpectedISS) } -type OAuthTokensInvalidError struct { +type TokensInvalidError struct { Cause string } -func (e *OAuthTokensInvalidError) Error() string { +func (e *TokensInvalidError) Error() string { return fmt.Sprintf("tokens are invalid due to: %s", e.Cause) } diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 8ceb9a8..eb79357 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -2,8 +2,8 @@ package oauth import "time" -// OAuthTokenResponse defines the OAuth response from the server that includes the tokens. -type OAuthTokenResponse struct { +// TokenResponse defines the OAuth response from the server that includes the tokens. +type TokenResponse struct { // Access is the access token returned by the server Access string `json:"access_token"` @@ -18,8 +18,8 @@ type OAuthTokenResponse struct { } -// OAuthToken is a structure that contains our access and refresh tokens and a timestamp when they expire. -type OAuthToken struct { +// Token is a structure that contains our access and refresh tokens and a timestamp when they expire. +type Token struct { // Access is the access token returned by the server access string @@ -31,7 +31,7 @@ type OAuthToken struct { } // Expired checks if the access token is expired. -func (tokens *OAuthToken) Expired() bool { +func (tokens *Token) Expired() bool { currentTime := time.Now() return !currentTime.Before(tokens.expiredTimestamp) } diff --git a/internal/server/api.go b/internal/server/api.go index 65aadca..21ba6f4 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -13,7 +13,7 @@ import ( "github.com/eduvpn/eduvpn-common/types" ) -func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { +func APIGetEndpoints(baseURL string) (*Endpoints, error) { errorMessage := "failed getting server endpoints" url, urlErr := url.Parse(baseURL) if urlErr != nil { @@ -23,13 +23,13 @@ func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { wellKnownPath := "/.well-known/vpn-user-portal" url.Path = path.Join(url.Path, wellKnownPath) - _, body, bodyErr := httpw.HTTPGet(url.String()) + _, body, bodyErr := httpw.Get(url.String()) if bodyErr != nil { return nil, types.NewWrappedError(errorMessage, bodyErr) } - endpoints := &ServerEndpoints{} + endpoints := &Endpoints{} jsonErr := json.Unmarshal(body, endpoints) if jsonErr != nil { @@ -43,12 +43,12 @@ func apiAuthorized( server Server, method string, endpoint string, - opts *httpw.HTTPOptionalParams, + opts *httpw.OptionalParams, ) (http.Header, []byte, error) { errorMessage := "failed API authorized" // Ensure optional is not nil as we will fill it with headers if opts == nil { - opts = &httpw.HTTPOptionalParams{} + opts = &httpw.OptionalParams{} } base, baseErr := server.Base() @@ -76,20 +76,20 @@ func apiAuthorized( } else { opts.Headers = http.Header{headerKey: {headerValue}} } - return httpw.HTTPMethodWithOpts(method, url.String(), opts) + return httpw.MethodWithOpts(method, url.String(), opts) } func apiAuthorizedRetry( server Server, method string, endpoint string, - opts *httpw.HTTPOptionalParams, + opts *httpw.OptionalParams, ) (http.Header, []byte, error) { errorMessage := "failed authorized API retry" header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) if bodyErr != nil { - var error *httpw.HTTPStatusError + var error *httpw.StatusError // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { @@ -112,7 +112,7 @@ func APIInfo(server Server) error { if bodyErr != nil { return types.NewWrappedError(errorMessage, bodyErr) } - structure := ServerProfileInfo{} + structure := ProfileInfo{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { @@ -168,7 +168,7 @@ func APIConnectWireguard( server, http.MethodPost, "/connect", - &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}, + &httpw.OptionalParams{Headers: headers, Body: urlForm}, ) if connectErr != nil { return "", "", time.Time{}, types.NewWrappedError( @@ -209,7 +209,7 @@ func APIConnectOpenVPN(server Server, profileID string, preferTCP bool) (string, server, http.MethodPost, "/connect", - &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}, + &httpw.OptionalParams{Headers: headers, Body: urlForm}, ) if connectErr != nil { return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr) diff --git a/internal/server/common.go b/internal/server/common.go index 7f6599a..e8c8e51 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -11,21 +11,21 @@ import ( ) // The base type for servers. -type ServerBase struct { +type Base struct { URL string `json:"base_url"` DisplayName map[string]string `json:"display_name"` SupportContact []string `json:"support_contact"` - Endpoints ServerEndpoints `json:"endpoints"` - Profiles ServerProfileInfo `json:"profiles"` + Endpoints Endpoints `json:"endpoints"` + Profiles ProfileInfo `json:"profiles"` StartTime time.Time `json:"start_time"` EndTime time.Time `json:"expire_time"` Type string `json:"server_type"` } -type ServerType int8 +type Type int8 const ( - CustomServerType ServerType = iota + CustomServerType Type = iota InstituteAccessServerType SecureInternetServerType ) @@ -35,7 +35,7 @@ type Servers struct { CustomServers InstituteAccessServers `json:"custom_servers"` InstituteServers InstituteAccessServers `json:"institute_servers"` SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"` - IsType ServerType `json:"is_secure_internet"` + IsType Type `json:"is_secure_internet"` } type Server interface { @@ -45,48 +45,48 @@ type Server interface { TemplateAuth() func(string) string // Gets the server base - Base() (*ServerBase, error) + Base() (*Base, error) } -type ServerProfile struct { +type Profile struct { ID string `json:"profile_id"` DisplayName string `json:"display_name"` VPNProtoList []string `json:"vpn_proto_list"` DefaultGateway bool `json:"default_gateway"` } -type ServerProfileListInfo struct { - ProfileList []ServerProfile `json:"profile_list"` +type ProfileListInfo struct { + ProfileList []Profile `json:"profile_list"` } -type ServerProfileInfo struct { +type ProfileInfo struct { Current string `json:"current_profile"` - Info ServerProfileListInfo `json:"info"` + Info ProfileListInfo `json:"info"` } -func (info ServerProfileInfo) GetCurrentProfileIndex() int { +func (info ProfileInfo) GetCurrentProfileIndex() int { index := 0 for _, profile := range info.Info.ProfileList { if profile.ID == info.Current { return index } - index += 1 + index++ } // Default is 'first' profile return 0 } -type ServerEndpointList struct { +type EndpointList struct { API string `json:"api_endpoint"` Authorization string `json:"authorization_endpoint"` Token string `json:"token_endpoint"` } // Struct that defines the json format for /.well-known/vpn-user-portal". -type ServerEndpoints struct { +type Endpoints struct { API struct { - V2 ServerEndpointList `json:"http://eduvpn.org/api#2"` - V3 ServerEndpointList `json:"http://eduvpn.org/api#3"` + V2 EndpointList `json:"http://eduvpn.org/api#2"` + V3 EndpointList `json:"http://eduvpn.org/api#3"` } `json:"api"` V string `json:"v"` } @@ -97,7 +97,7 @@ func (servers *Servers) GetCurrentServer() (Server, error) { if !servers.HasSecureLocation() { return nil, types.NewWrappedError( errorMessage, - &ServerGetCurrentNotFoundError{}, + &CurrentNotFoundError{}, ) } return &servers.SecureInternetHomeServer, nil @@ -113,7 +113,7 @@ func (servers *Servers) GetCurrentServer() (Server, error) { if bases == nil { return nil, types.NewWrappedError( errorMessage, - &ServerGetCurrentNoMapError{}, + &CurrentNoMapError{}, ) } server, exists := bases[currentServerURL] @@ -121,7 +121,7 @@ func (servers *Servers) GetCurrentServer() (Server, error) { if !exists || server == nil { return nil, types.NewWrappedError( errorMessage, - &ServerGetCurrentNotFoundError{}, + &CurrentNotFoundError{}, ) } return server, nil @@ -283,7 +283,7 @@ func CancelOAuth(server Server) { server.OAuth().Cancel() } -func (profile *ServerProfile) supportsProtocol(protocol string) bool { +func (profile *Profile) supportsProtocol(protocol string) bool { for _, proto := range profile.VPNProtoList { if proto == protocol { return true @@ -292,15 +292,15 @@ func (profile *ServerProfile) supportsProtocol(protocol string) bool { return false } -func (profile *ServerProfile) supportsWireguard() bool { +func (profile *Profile) supportsWireguard() bool { return profile.supportsProtocol("wireguard") } -func (profile *ServerProfile) supportsOpenVPN() bool { +func (profile *Profile) supportsOpenVPN() bool { return profile.supportsProtocol("openvpn") } -func Profile(server Server) (*ServerProfile, error) { +func CurrentProfile(server Server) (*Profile, error) { errorMessage := "failed getting current profile" base, baseErr := server.Base() @@ -316,11 +316,11 @@ func Profile(server Server) (*ServerProfile, error) { return nil, types.NewWrappedError( errorMessage, - &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}, + &CurrentProfileNotFoundError{ProfileID: profileID}, ) } -func (base *ServerBase) InitializeEndpoints() error { +func (base *Base) InitializeEndpoints() error { errorMessage := "failed initializing endpoints" endpoints, endpointsErr := APIGetEndpoints(base.URL) if endpointsErr != nil { @@ -330,8 +330,8 @@ func (base *ServerBase) InitializeEndpoints() error { return nil } -func (base *ServerBase) ValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { - var validProfiles []ServerProfile +func (base *Base) ValidProfiles(clientSupportsWireguard bool) ProfileInfo { + var validProfiles []Profile for _, profile := range base.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn // Also the client does not support wireguard @@ -340,10 +340,10 @@ func (base *ServerBase) ValidProfiles(clientSupportsWireguard bool) ServerProfil } validProfiles = append(validProfiles, profile) } - return ServerProfileInfo{Current: base.Profiles.Current, Info: ServerProfileListInfo{ProfileList: validProfiles}} + return ProfileInfo{Current: base.Profiles.Current, Info: ProfileListInfo{ProfileList: validProfiles}} } -func ValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { +func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) { errorMessage := "failed to get valid profiles" // No error wrapping here otherwise we wrap it too much base, baseErr := server.Base() @@ -438,7 +438,7 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) // If there was a profile chosen and it doesn't exist anymore, reset it if base.Profiles.Current != "" { - _, existsProfileErr := Profile(server) + _, existsProfileErr := CurrentProfile(server) if existsProfileErr != nil { base.Profiles.Current = "" } @@ -450,7 +450,7 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } - profile, profileErr := Profile(server) + profile, profileErr := CurrentProfile(server) // shouldn't happen if profileErr != nil { return false, types.NewWrappedError(errorMessage, profileErr) @@ -486,7 +486,7 @@ func RefreshEndpoints(server Server) error { func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { errorMessage := "failed getting an OpenVPN/WireGuard configuration" - profile, profileErr := Profile(server) + profile, profileErr := CurrentProfile(server) if profileErr != nil { return "", "", types.NewWrappedError(errorMessage, profileErr) } @@ -522,34 +522,34 @@ func Disconnect(server Server) { APIDisconnect(server) } -type ServerGetCurrentProfileNotFoundError struct { +type CurrentProfileNotFoundError struct { ProfileID string } -func (e *ServerGetCurrentProfileNotFoundError) Error() string { +func (e *CurrentProfileNotFoundError) Error() string { return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) } -type ServerGetConfigForceTCPError struct{} +type ConfigPreferTCPError struct{} -func (e *ServerGetConfigForceTCPError) Error() string { +func (e *ConfigPreferTCPError) Error() string { return "failed to get config, prefer TCP is on but the server does not support OpenVPN" } -type ServerEnsureServerEmptyURLError struct{} +type EmptyURLError struct{} -func (e *ServerEnsureServerEmptyURLError) Error() string { +func (e *EmptyURLError) Error() string { return "failed ensuring server, empty url provided" } -type ServerGetCurrentNoMapError struct{} +type CurrentNoMapError struct{} -func (e *ServerGetCurrentNoMapError) Error() string { +func (e *CurrentNoMapError) Error() string { return "failed getting current server, no servers available" } -type ServerGetCurrentNotFoundError struct{} +type CurrentNotFoundError struct{} -func (e *ServerGetCurrentNotFoundError) Error() string { +func (e *CurrentNotFoundError) Error() string { return "failed getting current server, not found" } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index c0594a7..f76323c 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -14,7 +14,7 @@ type InstituteAccessServer struct { Auth oauth.OAuth `json:"oauth"` // Embed the server base - Basic ServerBase `json:"base"` + Basic Base `json:"base"` } type InstituteAccessServers struct { @@ -69,7 +69,7 @@ func (institute *InstituteAccessServer) TemplateAuth() func(string) string { } } -func (institute *InstituteAccessServer) Base() (*ServerBase, error) { +func (institute *InstituteAccessServer) Base() (*Base, error) { return &institute.Basic, nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index c6a353b..fa4c9c9 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -16,7 +16,7 @@ type SecureInternetHomeServer struct { DisplayName map[string]string `json:"display_name"` // The home server has a list of info for each configured server location - BaseMap map[string]*ServerBase `json:"base_map"` + BaseMap map[string]*Base `json:"base_map"` // We have the authorization URL template, the home organization ID and the current location AuthorizationTemplate string `json:"authorization_template"` @@ -64,12 +64,12 @@ func (server *SecureInternetHomeServer) TemplateAuth() func(string) string { } } -func (server *SecureInternetHomeServer) Base() (*ServerBase, error) { +func (server *SecureInternetHomeServer) Base() (*Base, error) { errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { return nil, types.NewWrappedError( errorMessage, - &ServerSecureInternetMapNotFoundError{}, + &SecureInternetMapNotFoundError{}, ) } @@ -78,7 +78,7 @@ func (server *SecureInternetHomeServer) Base() (*ServerBase, error) { if !exists { return nil, types.NewWrappedError( errorMessage, - &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}, + &SecureInternetBaseNotFoundError{Current: server.CurrentLocation}, ) } return base, nil @@ -94,11 +94,11 @@ func (servers *Servers) HasSecureLocation() bool { func (server *SecureInternetHomeServer) addLocation( locationServer *types.DiscoveryServer, -) (*ServerBase, error) { +) (*Base, error) { errorMessage := "failed adding a location" // Initialize the base map if it is non-nil if server.BaseMap == nil { - server.BaseMap = make(map[string]*ServerBase) + server.BaseMap = make(map[string]*Base) } // Add the location to the base map @@ -106,7 +106,7 @@ func (server *SecureInternetHomeServer) addLocation( if !exists || base == nil { // Create the base to be added to the map - base = &ServerBase{} + base = &Base{} base.URL = locationServer.BaseURL base.DisplayName = server.DisplayName base.SupportContact = locationServer.SupportContact @@ -152,22 +152,22 @@ func (server *SecureInternetHomeServer) init( return nil } -type ServerGetSecureInternetHomeError struct{} +type SecureInternetHomeNotFoundError struct{} -func (e *ServerGetSecureInternetHomeError) Error() string { +func (e *SecureInternetHomeNotFoundError) Error() string { return "failed to get secure internet home server, not found" } -type ServerSecureInternetMapNotFoundError struct{} +type SecureInternetMapNotFoundError struct{} -func (e *ServerSecureInternetMapNotFoundError) Error() string { +func (e *SecureInternetMapNotFoundError) Error() string { return "secure internet map not found" } -type ServerSecureInternetBaseNotFoundError struct { +type SecureInternetBaseNotFoundError struct { Current string } -func (e *ServerSecureInternetBaseNotFoundError) Error() string { +func (e *SecureInternetBaseNotFoundError) Error() string { return fmt.Sprintf("secure internet base not found with current location: %s", e.Current) } diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 83765a0..55b82b6 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -67,7 +67,7 @@ func verifyWithKeys( case "server_list.json", "organization_list.json": break default: - return false, &VerifyUnknownExpectedFilenameError{ + return false, &UnknownExpectedFilenameError{ Filename: filename, Expected: "server_list.json or organization_list.json", } @@ -75,12 +75,12 @@ func verifyWithKeys( sig, err := minisign.DecodeSignature(signatureFileContent) if err != nil { - return false, &VerifyInvalidSignatureFormatError{Err: err} + return false, &InvalidSignatureFormatError{Err: err} } // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} { - return false, &VerifyInvalidSignatureAlgorithmError{ + return false, &InvalidSignatureAlgorithmError{ Algorithm: string(sig.SignatureAlgorithm[:]), WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)", } @@ -91,7 +91,7 @@ func verifyWithKeys( key, err := minisign.NewPublicKey(keyStr) if err != nil { // Should only happen if Verify is wrong or extraKey is invalid - return false, &VerifyCreatePublicKeyError{PublicKey: keyStr, Err: err} + return false, &CreatePublicKeyError{PublicKey: keyStr, Err: err} } if sig.KeyId != key.KeyId { @@ -100,7 +100,7 @@ func verifyWithKeys( valid, err := key.Verify(signedJSON, sig) if !valid { - return false, &VerifyInvalidSignatureError{Err: err} + return false, &InvalidSignatureError{Err: err} } // Parse trusted comment @@ -114,54 +114,54 @@ func verifyWithKeys( &sigFileName, ) if err != nil { - return false, &VerifyInvalidTrustedCommentError{ + return false, &InvalidTrustedCommentError{ TrustedComment: sig.TrustedComment, Err: err, } } if sigFileName != filename { - return false, &VerifyWrongSigFilenameError{Filename: filename, SigFilename: sigFileName} + return false, &WrongSigFilenameError{Filename: filename, SigFilename: sigFileName} } if signTime < minSignTime { - return false, &VerifySigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime} + return false, &SigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime} } return true, nil } // No matching allowed key found - return false, &VerifyUnknownKeyError{Filename: filename} + return false, &UnknownKeyError{Filename: filename} } -type VerifyUnknownExpectedFilenameError struct { +type UnknownExpectedFilenameError struct { Filename string Expected string } -func (e *VerifyUnknownExpectedFilenameError) Error() string { +func (e *UnknownExpectedFilenameError) Error() string { return fmt.Sprintf("invalid filename: %s, expected: %s", e.Filename, e.Expected) } -type VerifyInvalidSignatureFormatError struct { +type InvalidSignatureFormatError struct { Err error } -func (e *VerifyInvalidSignatureFormatError) Error() string { +func (e *InvalidSignatureFormatError) Error() string { return fmt.Sprintf("invalid signature format with error: %v", e.Err) } -func (e *VerifyInvalidSignatureFormatError) Unwrap() error { +func (e *InvalidSignatureFormatError) Unwrap() error { return e.Err } -type VerifyInvalidSignatureAlgorithmError struct { +type InvalidSignatureAlgorithmError struct { Algorithm string WantedAlgorithm string } -func (e *VerifyInvalidSignatureAlgorithmError) Error() string { +func (e *InvalidSignatureAlgorithmError) Error() string { return fmt.Sprintf( "invalid signature algorithm: %s, wanted: %s", e.Algorithm, @@ -169,50 +169,50 @@ func (e *VerifyInvalidSignatureAlgorithmError) Error() string { ) } -type VerifyCreatePublicKeyError struct { +type CreatePublicKeyError struct { PublicKey string Err error } -func (e *VerifyCreatePublicKeyError) Error() string { +func (e *CreatePublicKeyError) Error() string { return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err) } -func (e *VerifyCreatePublicKeyError) Unwrap() error { +func (e *CreatePublicKeyError) Unwrap() error { return e.Err } -type VerifyInvalidSignatureError struct { +type InvalidSignatureError struct { Err error } -func (e *VerifyInvalidSignatureError) Error() string { +func (e *InvalidSignatureError) Error() string { return fmt.Sprintf("invalid signature with error: %v", e.Err) } -func (e *VerifyInvalidSignatureError) Unwrap() error { +func (e *InvalidSignatureError) Unwrap() error { return e.Err } -type VerifyInvalidTrustedCommentError struct { +type InvalidTrustedCommentError struct { TrustedComment string Err error } -func (e *VerifyInvalidTrustedCommentError) Error() string { +func (e *InvalidTrustedCommentError) Error() string { return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err) } -func (e *VerifyInvalidTrustedCommentError) Unwrap() error { +func (e *InvalidTrustedCommentError) Unwrap() error { return e.Err } -type VerifyWrongSigFilenameError struct { +type WrongSigFilenameError struct { Filename string SigFilename string } -func (e *VerifyWrongSigFilenameError) Error() string { +func (e *WrongSigFilenameError) Error() string { return fmt.Sprintf( "wrong filename: %s, expected filename: %s for signature", e.Filename, @@ -220,19 +220,19 @@ func (e *VerifyWrongSigFilenameError) Error() string { ) } -type VerifySigTimeEarlierError struct { +type SigTimeEarlierError struct { SigTime uint64 MinSigTime uint64 } -func (e *VerifySigTimeEarlierError) Error() string { +func (e *SigTimeEarlierError) Error() string { return fmt.Sprintf("Sign time: %d is earlier than sign time: %d", e.SigTime, e.MinSigTime) } -type VerifyUnknownKeyError struct { +type UnknownKeyError struct { Filename string } -func (e *VerifyUnknownKeyError) Error() string { +func (e *UnknownKeyError) Error() string { return fmt.Sprintf("signature for filename: %s was created with an unknown key", e.Filename) } diff --git a/internal/verify/verify_test.go b/internal/verify/verify_test.go index e250ee7..8ebed4c 100644 --- a/internal/verify/verify_test.go +++ b/internal/verify/verify_test.go @@ -29,15 +29,15 @@ func Test_verifyWithKeys(t *testing.T) { } var ( - verifyCreatePublicKeyError *VerifyCreatePublicKeyError - verifyInvalidSignatureAlgorithmError *VerifyInvalidSignatureAlgorithmError - verifyWrongSigFilenameError *VerifyWrongSigFilenameError - verifyInvalidTrustedCommentError *VerifyInvalidTrustedCommentError - verifyInvalidSignatureFormatError *VerifyInvalidSignatureFormatError - verifyInvalidSignatureError *VerifyInvalidSignatureError - verifySigTimeEarlierError *VerifySigTimeEarlierError - verifyUnknownExpectedFilenameError *VerifyUnknownExpectedFilenameError - verifyUnknownKeyError *VerifyUnknownKeyError + verifyCreatePublicKeyError *CreatePublicKeyError + verifyInvalidSignatureAlgorithmError *InvalidSignatureAlgorithmError + verifyWrongSigFilenameError *WrongSigFilenameError + verifyInvalidTrustedCommentError *InvalidTrustedCommentError + verifyInvalidSignatureFormatError *InvalidSignatureFormatError + verifyInvalidSignatureError *InvalidSignatureError + verifySigTimeEarlierError *SigTimeEarlierError + verifyUnknownExpectedFilenameError *UnknownExpectedFilenameError + verifyUnknownKeyError *UnknownKeyError ) tests := []struct { |
