From 7339e77c6eda5b96874dfc099d5c58da8ed53629 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 28 Nov 2022 11:52:04 +0100 Subject: Refactor: Remove most get prefixes for receiver functions --- client/client.go | 4 +-- client/client_test.go | 8 +++--- client/fsm.go | 2 +- client/server.go | 16 +++++------ cmd/cli/main.go | 4 +-- exports/exports.go | 6 ++--- exports/servers.go | 10 +++---- internal/discovery/discovery.go | 8 +++--- internal/fsm/fsm.go | 8 +++--- internal/log/log.go | 4 +-- internal/oauth/oauth.go | 30 ++++++++++----------- internal/server/api.go | 6 ++--- internal/server/common.go | 55 +++++++++++++++++++------------------- internal/server/custom.go | 2 +- internal/server/instituteaccess.go | 35 ++++++++++++------------ internal/server/secureinternet.go | 18 ++++++------- types/error.go | 16 +++++------ 17 files changed, 115 insertions(+), 117 deletions(-) diff --git a/client/client.go b/client/client.go index 34981db..958dd25 100644 --- a/client/client.go +++ b/client/client.go @@ -147,7 +147,7 @@ func (client *Client) Deregister() { // Save the config saveErr := client.Config.Save(&client) if saveErr != nil { - client.Logger.Info("failed saving configuration, error: %s", types.GetErrorTraceback(saveErr)) + client.Logger.Info("failed saving configuration, error: %s", types.ErrorTraceback(saveErr)) } // Empty out the state @@ -157,7 +157,7 @@ func (client *Client) Deregister() { // askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. func (client *Client) askProfile(chosenServer server.Server) error { errorMessage := "failed asking for profiles" - profiles, profilesErr := server.GetValidProfiles(chosenServer, client.SupportsWireguard) + profiles, profilesErr := server.ValidProfiles(chosenServer, client.SupportsWireguard) if profilesErr != nil { return types.NewWrappedError(errorMessage, profilesErr) } diff --git a/client/client_test.go b/client/client_test.go index 2a240bd..a125e7e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -120,7 +120,7 @@ func testConnectOAuthParameter( if serverErr != nil { t.Fatalf("No server with error: %v", serverErr) } - port, portErr := server.GetOAuth().GetListenerPort() + port, portErr := server.OAuth().ListenerPort() if portErr != nil { _ = state.CancelOAuth() t.Fatalf("No port with error: %v", portErr) @@ -247,7 +247,7 @@ func TestTokenExpired(t *testing.T) { t.Fatalf("No server found") } - serverOAuth := currentServer.GetOAuth() + serverOAuth := currentServer.OAuth() accessToken := serverOAuth.Token.Access refreshToken := serverOAuth.Token.Refresh @@ -310,7 +310,7 @@ func TestTokenInvalid(t *testing.T) { t.Fatalf("No server found") } - serverOAuth := currentServer.GetOAuth() + serverOAuth := currentServer.OAuth() // Override tokens with invalid values serverOAuth.Token.Access = dummyValue @@ -366,7 +366,7 @@ func TestInvalidProfileCorrected(t *testing.T) { t.Fatalf("No server found") } - base, baseErr := currentServer.GetBase() + base, baseErr := currentServer.Base() if baseErr != nil { t.Fatalf("No base found") } diff --git a/client/fsm.go b/client/fsm.go index 159464a..f4bfe21 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -393,7 +393,7 @@ func (client *Client) goBackInternal() { client.Logger.Info( fmt.Sprintf( "Failed going back, error: %s", - types.GetErrorTraceback(goBackErr), + types.ErrorTraceback(goBackErr), ), ) } diff --git a/client/server.go b/client/server.go index 5b1a32b..5fed292 100644 --- a/client/server.go +++ b/client/server.go @@ -36,7 +36,7 @@ func (client *Client) getConfigAuth( } // We return the error otherwise we wrap it too much - return server.GetConfig(chosenServer, client.SupportsWireguard, preferTCP) + return server.Config(chosenServer, client.SupportsWireguard, preferTCP) } // retryConfigAuth retries the getConfigAuth function if the tokens are invalid. @@ -104,7 +104,7 @@ func (client *Client) getConfig( if saveErr != nil { client.Logger.Info( "Failed saving configuration after getting a server: %s", - types.GetErrorTraceback(saveErr), + types.ErrorTraceback(saveErr), ) } @@ -153,7 +153,7 @@ func (client *Client) RemoveSecureInternet() error { if saveErr != nil { client.Logger.Info( "Failed saving configuration after removing a secure internet server: %s", - types.GetErrorTraceback(saveErr), + types.ErrorTraceback(saveErr), ) } return nil @@ -177,7 +177,7 @@ func (client *Client) RemoveInstituteAccess(url string) error { if saveErr != nil { client.Logger.Info( "Failed saving configuration after removing an institute access server: %s", - types.GetErrorTraceback(saveErr), + types.ErrorTraceback(saveErr), ) } return nil @@ -201,7 +201,7 @@ func (client *Client) RemoveCustomServer(url string) error { if saveErr != nil { client.Logger.Info( "Failed saving configuration after removing a custom server: %s", - types.GetErrorTraceback(saveErr), + types.ErrorTraceback(saveErr), ) } return nil @@ -566,7 +566,7 @@ func (client *Client) ShouldRenewButton() bool { if currentServerErr != nil { client.Logger.Info( "No server found to renew with err: %s", - types.GetErrorTraceback(currentServerErr), + types.ErrorTraceback(currentServerErr), ) return false } @@ -581,7 +581,7 @@ func (client *Client) ensureLogin(chosenServer server.Server) error { // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { - url, urlErr := server.GetOAuthURL(chosenServer, client.Name) + url, urlErr := server.OAuthURL(chosenServer, client.Name) goTransitionErr := client.FSM.GoTransitionRequired(StateOAuthStarted, url) if goTransitionErr != nil { @@ -615,7 +615,7 @@ func (client *Client) SetProfileID(profileID string) error { return client.handleError(errorMessage, serverErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { client.goBackInternal() return client.handleError(errorMessage, baseErr) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index ec78754..f4f17e4 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -142,8 +142,8 @@ func printConfig(url string, serverType ServerTypes) { if configErr != nil { // Show the usage of tracebacks and causes - fmt.Println("Error getting config:", types.GetErrorTraceback(configErr)) - fmt.Println("Error getting config, cause:", types.GetErrorCause(configErr)) + fmt.Println("Error getting config:", types.ErrorTraceback(configErr)) + fmt.Println("Error getting config, cause:", types.ErrorCause(configErr)) return } diff --git a/exports/exports.go b/exports/exports.go index fb972fc..124a3a5 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -141,9 +141,9 @@ func getError(err error) *C.error { errorStruct := (*C.error)( C.malloc(C.size_t(unsafe.Sizeof(C.error{}))), ) - errorStruct.level = C.errorLevel(types.GetErrorLevel(err)) - errorStruct.traceback = C.CString(types.GetErrorTraceback(err)) - errorStruct.cause = C.CString(types.GetErrorCause(err).Error()) + errorStruct.level = C.errorLevel(types.ErrorLevel(err)) + errorStruct.traceback = C.CString(types.ErrorTraceback(err)) + errorStruct.cause = C.CString(types.ErrorCause(err).Error()) return errorStruct } diff --git a/exports/servers.go b/exports/servers.go index 0c3cc5f..d31fa44 100644 --- a/exports/servers.go +++ b/exports/servers.go @@ -189,7 +189,7 @@ func getCPtrServer(state *client.Client, base *client.ServerBase) *C.server { locationsStruct.total_locations, locationsStruct.locations = getCPtrListStrings(locations) cServer.locations = locationsStruct - profiles := base.GetValidProfiles(state.SupportsWireguard) + profiles := base.ValidProfiles(state.SupportsWireguard) cServer.profiles = getCPtrProfiles(&profiles) // No endtime is given if we get servers when it has been partially initialised if base.EndTime.IsZero() { @@ -232,7 +232,7 @@ func getCPtrServers( servers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(serversPtr))[:totalServers:totalServers] index := 0 for _, currentServer := range serverMap { - cServer := getCPtrServer(state, ¤tServer.Base) + cServer := getCPtrServer(state, ¤tServer.Basic) servers[index] = cServer index += 1 } @@ -282,7 +282,7 @@ func getSavedServersWithOptions(state *client.Client, servers *server.Servers) * totalCustom, customPtr := getCPtrServers(state, servers.CustomServers.Map) totalInstitute, institutePtr := getCPtrServers(state, servers.InstituteServers.Map) var secureServerPtr *C.server = nil - secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase() + secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.Base() if secureInternetBaseErr == nil && secureInternetBase != nil { // FIXME: log error? secureServerPtr = getCPtrServer(state, secureInternetBase) @@ -328,7 +328,7 @@ func GetCurrentServer(name *C.char) (*C.server, *C.error) { if serverErr != nil { return nil, getError(serverErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, getError(baseErr) } @@ -369,7 +369,7 @@ func getTransitionProfiles(data interface{}) *C.serverProfiles { func getTransitionServer(state *client.Client, data interface{}) *C.server { if server, ok := data.(server.Server); ok { - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { // TODO: LOG return nil diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 7df209c..d7fb273 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -21,9 +21,9 @@ type Discovery struct { servers types.DiscoveryServers } -// getDiscoFile is a helper function that gets a disco json and fills the structure with it +// discoFile is a helper function that gets a disco JSON and fills the structure with it // If it was unsuccessful it returns an error -func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { +func discoFile(jsonFile string, previousVersion uint64, structure interface{}) error { errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data discoURL := "https://disco.eduvpn.org/v2/" @@ -185,7 +185,7 @@ func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, erro return &discovery.organizations, nil } file := "organization_list.json" - bodyErr := getDiscoFile(file, discovery.organizations.Version, &discovery.organizations) + bodyErr := discoFile(file, discovery.organizations.Version, &discovery.organizations) if bodyErr != nil { // Return previous with an error return &discovery.organizations, types.NewWrappedError( @@ -204,7 +204,7 @@ func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) { return &discovery.servers, nil } file := "server_list.json" - bodyErr := getDiscoFile(file, discovery.servers.Version, &discovery.servers) + bodyErr := discoFile(file, discovery.servers.Version, &discovery.servers) if bodyErr != nil { // Return previous with an error return &discovery.servers, types.NewWrappedError( diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index b8fd644..c3c7efa 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -106,8 +106,8 @@ func (fsm *FSM) HasTransition(check FSMStateID) bool { return false } -// getGraphFilename gets the full path to the graph filename including the .graph extension -func (fsm *FSM) getGraphFilename(extension string) string { +// graphFilename gets the full path to the graph filename including the .graph extension +func (fsm *FSM) graphFilename(extension string) string { debugPath := path.Join(fsm.Directory, "graph") return fmt.Sprintf("%s%s", debugPath, extension) } @@ -115,8 +115,8 @@ func (fsm *FSM) getGraphFilename(extension string) string { // writeGraph writes the state machine to a .graph file func (fsm *FSM) writeGraph() { graph := fsm.GenerateGraph() - graphFile := fsm.getGraphFilename(".graph") - graphImgFile := fsm.getGraphFilename(".png") + graphFile := fsm.graphFilename(".graph") + graphImgFile := fsm.graphFilename(".png") f, err := os.Create(graphFile) if err != nil { return diff --git a/internal/log/log.go b/internal/log/log.go index 7d032e9..eaedc28 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -90,9 +90,9 @@ func (logger *FileLogger) Init(level LogLevel, directory string) error { // Inherit logs an error with a label using the error level of the error func (logger *FileLogger) Inherit(label string, err error) { - level := types.GetErrorLevel(err) + level := types.ErrorLevel(err) - msg := fmt.Sprintf("%s with err: %s", label, types.GetErrorTraceback(err)) + msg := fmt.Sprintf("%s with err: %s", label, types.ErrorTraceback(err)) switch level { case types.ErrInfo: logger.Info(msg) diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 232f68c..fe78cd3 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -153,9 +153,9 @@ func (oauth *OAuth) setupListener() error { return nil } -// getTokensWithCallback gets the OAuth tokens using a local web server +// tokensWithCallback gets the OAuth tokens using a local web server // If it was unsuccessful it returns an error -func (oauth *OAuth) getTokensWithCallback() error { +func (oauth *OAuth) tokensWithCallback() error { errorMessage := "failed getting tokens with callback" if oauth.session.Listener == nil { return types.NewWrappedError(errorMessage, errors.New("no listener")) @@ -173,17 +173,17 @@ func (oauth *OAuth) getTokensWithCallback() error { return oauth.session.CallbackError } -// getTokensWithAuthCode gets the access and refresh tokens using the authorization code +// tokensWithAuthCode gets the access and refresh tokens using the authorization code // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error -func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { +func (oauth *OAuth) tokensWithAuthCode(authCode string) error { errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code reqURL := oauth.TokenURL - port, portErr := oauth.GetListenerPort() + port, portErr := oauth.ListenerPort() if portErr != nil { return types.NewWrappedError(errorMessage, portErr) } @@ -230,11 +230,11 @@ func (oauth *OAuth) isTokensExpired() bool { return !currentTime.Before(expiredTime) } -// getTokensWithRefresh gets the access and refresh tokens with a previously received refresh token +// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error -func (oauth *OAuth) getTokensWithRefresh() error { +func (oauth *OAuth) tokensWithRefresh() error { errorMessage := "failed getting tokens with the refresh token" reqURL := oauth.TokenURL data := url.Values{ @@ -398,7 +398,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - getTokensErr := oauth.getTokensWithAuthCode(extractedCode) + getTokensErr := oauth.tokensWithAuthCode(extractedCode) if getTokensErr != nil { oauth.session.CallbackError = types.NewWrappedError( errorMessage, @@ -418,9 +418,9 @@ func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL strin oauth.TokenURL = tokenURL } -// GetListenerPort gets the listener for the OAuth web server +// ListenerPort gets the listener for the OAuth web server // It returns the port as an integer and an error if there is any -func (oauth OAuth) GetListenerPort() (int, error) { +func (oauth OAuth) ListenerPort() (int, error) { errorMessage := "failed to get listener port" if oauth.session.Listener == nil { @@ -429,8 +429,8 @@ func (oauth OAuth) GetListenerPort() (int, error) { return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil } -// GetAuthURL gets the authorization url to start the OAuth procedure -func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) { +// AuthURL gets the authorization url to start the OAuth procedure +func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (string, error) { errorMessage := "failed starting OAuth exchange" // Generate the verifier and challenge @@ -457,7 +457,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) } // Get the listener port - port, portErr := oauth.GetListenerPort() + port, portErr := oauth.ListenerPort() if portErr != nil { return "", types.NewWrappedError(errorMessage, portErr) } @@ -485,7 +485,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error func (oauth *OAuth) Exchange() error { - tokenErr := oauth.getTokensWithCallback() + tokenErr := oauth.tokensWithCallback() if tokenErr != nil { return types.NewWrappedError("failed finishing OAuth", tokenErr) @@ -526,7 +526,7 @@ func (oauth *OAuth) EnsureTokens() error { } // Otherwise try to refresh them and return if successful - refreshErr := oauth.getTokensWithRefresh() + refreshErr := oauth.tokensWithRefresh() // We have obtained new tokens with refresh if refreshErr != nil { // We have failed to ensure the tokens due to refresh not working diff --git a/internal/server/api.go b/internal/server/api.go index d315ada..559b787 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -50,7 +50,7 @@ func apiAuthorized( if opts == nil { opts = &httpw.HTTPOptionalParams{} } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, nil, types.NewWrappedError(errorMessage, baseErr) @@ -70,7 +70,7 @@ func apiAuthorized( } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server)) + headerValue := fmt.Sprintf("Bearer %s", HeaderToken(server)) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -119,7 +119,7 @@ func APIInfo(server Server) error { return types.NewWrappedError(errorMessage, jsonErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) diff --git a/internal/server/common.go b/internal/server/common.go index 8f4eabc..16208eb 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -39,14 +39,13 @@ type Servers struct { } type Server interface { - // Gets the current OAuth object - GetOAuth() *oauth.OAuth + OAuth() *oauth.OAuth // Get the authorization URL template function - GetTemplateAuth() func(string) string + TemplateAuth() func(string) string // Gets the server base - GetBase() (*ServerBase, error) + Base() (*ServerBase, error) } type ServerProfile struct { @@ -216,7 +215,7 @@ func (servers *Servers) AddSecureInternet( } func ShouldRenewButton(server Server) bool { - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { // FIXME: Log error here? @@ -251,28 +250,28 @@ func ShouldRenewButton(server Server) bool { return true } -func GetOAuthURL(server Server, name string) (string, error) { - return server.GetOAuth().GetAuthURL(name, server.GetTemplateAuth()) +func OAuthURL(server Server, name string) (string, error) { + return server.OAuth().AuthURL(name, server.TemplateAuth()) } func OAuthExchange(server Server) error { - return server.GetOAuth().Exchange() + return server.OAuth().Exchange() } -func GetHeaderToken(server Server) string { - return server.GetOAuth().Token.Access +func HeaderToken(server Server) string { + return server.OAuth().Token.Access } func MarkTokenExpired(server Server) { - server.GetOAuth().Token.ExpiredTimestamp = time.Now() + server.OAuth().Token.ExpiredTimestamp = time.Now() } func MarkTokensForRenew(server Server) { - server.GetOAuth().Token = oauth.OAuthToken{} + server.OAuth().Token = oauth.OAuthToken{} } func EnsureTokens(server Server) error { - ensureErr := server.GetOAuth().EnsureTokens() + ensureErr := server.OAuth().EnsureTokens() if ensureErr != nil { return types.NewWrappedError("failed ensuring server tokens", ensureErr) } @@ -284,7 +283,7 @@ func NeedsRelogin(server Server) bool { } func CancelOAuth(server Server) { - server.GetOAuth().Cancel() + server.OAuth().Cancel() } func (profile *ServerProfile) supportsProtocol(protocol string) bool { @@ -304,9 +303,9 @@ func (profile *ServerProfile) supportsOpenVPN() bool { return profile.supportsProtocol("openvpn") } -func getCurrentProfile(server Server) (*ServerProfile, error) { +func Profile(server Server) (*ServerProfile, error) { errorMessage := "failed getting current profile" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, types.NewWrappedError(errorMessage, baseErr) @@ -334,7 +333,7 @@ func (base *ServerBase) InitializeEndpoints() error { return nil } -func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { +func (base *ServerBase) ValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { var validProfiles []ServerProfile for _, profile := range base.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn @@ -347,14 +346,14 @@ func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerPro return ServerProfileInfo{Current: base.Profiles.Current, Info: ServerProfileListInfo{ProfileList: validProfiles}} } -func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { +func ValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { errorMessage := "failed to get valid profiles" // No error wrapping here otherwise we wrap it too much - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, types.NewWrappedError(errorMessage, baseErr) } - profiles := base.GetValidProfiles(clientSupportsWireguard) + profiles := base.ValidProfiles(clientSupportsWireguard) if len(profiles.Info.ProfileList) == 0 { return nil, types.NewWrappedError(errorMessage, errors.New("no profiles found with supported protocols")) } @@ -363,7 +362,7 @@ func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfi func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) { errorMessage := "failed getting server WireGuard configuration" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return "", "", types.NewWrappedError(errorMessage, baseErr) @@ -406,7 +405,7 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { errorMessage := "failed getting server OpenVPN configuration" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return "", "", types.NewWrappedError(errorMessage, baseErr) @@ -435,14 +434,14 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) return false, types.NewWrappedError(errorMessage, infoErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return false, types.NewWrappedError(errorMessage, baseErr) } // If there was a profile chosen and it doesn't exist anymore, reset it if base.Profiles.Current != "" { - _, existsProfileErr := getCurrentProfile(server) + _, existsProfileErr := Profile(server) if existsProfileErr != nil { base.Profiles.Current = "" } @@ -454,7 +453,7 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } - profile, profileErr := getCurrentProfile(server) + profile, profileErr := Profile(server) // shouldn't happen if profileErr != nil { return false, types.NewWrappedError(errorMessage, profileErr) @@ -474,7 +473,7 @@ func RefreshEndpoints(server Server) error { // Re-initialize the endpoints // TODO: Make this a warning instead? - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -487,10 +486,10 @@ func RefreshEndpoints(server Server) error { return nil } -func GetConfig(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { +func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { errorMessage := "failed getting an OpenVPN/WireGuard configuration" - profile, profileErr := getCurrentProfile(server) + profile, profileErr := Profile(server) if profileErr != nil { return "", "", types.NewWrappedError(errorMessage, profileErr) } diff --git a/internal/server/custom.go b/internal/server/custom.go index 8bde848..f8899b3 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -9,7 +9,7 @@ import ( func (servers *Servers) SetCustomServer(server Server) error { errorMessage := "failed setting custom server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index 33d8b52..ca37dcd 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -11,10 +11,10 @@ import ( // An instute access server type InstituteAccessServer struct { // An instute access server has its own OAuth - OAuth oauth.OAuth `json:"oauth"` + Auth oauth.OAuth `json:"oauth"` // Embed the server base - Base ServerBase `json:"base"` + Basic ServerBase `json:"base"` } type InstituteAccessServers struct { @@ -24,7 +24,7 @@ type InstituteAccessServers struct { func (servers *Servers) SetInstituteAccess(server Server) error { errorMessage := "failed setting institute access server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -63,19 +63,18 @@ func (servers *InstituteAccessServers) Remove(url string) { delete(servers.Map, url) } -// For an institute, we can simply get the OAuth -func (institute *InstituteAccessServer) GetOAuth() *oauth.OAuth { - return &institute.OAuth -} - -func (institute *InstituteAccessServer) GetTemplateAuth() func(string) string { +func (institute *InstituteAccessServer) TemplateAuth() func(string) string { return func(authURL string) string { return authURL } } -func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { - return &institute.Base, nil +func (institute *InstituteAccessServer) Base() (*ServerBase, error) { + return &institute.Basic, nil +} + +func (institute *InstituteAccessServer) OAuth() *oauth.OAuth { + return &institute.Auth } func (institute *InstituteAccessServer) init( @@ -85,15 +84,15 @@ func (institute *InstituteAccessServer) init( supportContact []string, ) error { errorMessage := fmt.Sprintf("failed initializing server %s", url) - institute.Base.URL = url - institute.Base.DisplayName = displayName - institute.Base.SupportContact = supportContact - institute.Base.Type = serverType - endpointsErr := institute.Base.InitializeEndpoints() + institute.Basic.URL = url + institute.Basic.DisplayName = displayName + institute.Basic.SupportContact = supportContact + institute.Basic.Type = serverType + endpointsErr := institute.Basic.InitializeEndpoints() if endpointsErr != nil { return types.NewWrappedError(errorMessage, endpointsErr) } - API := institute.Base.Endpoints.API.V3 - institute.OAuth.Init(url, API.Authorization, API.Token) + API := institute.Basic.Endpoints.API.V3 + institute.Auth.Init(url, API.Authorization, API.Token) return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index f0b308f..0dc9ef1 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -12,8 +12,8 @@ import ( // A secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to type SecureInternetHomeServer struct { + Auth oauth.OAuth `json:"oauth"` DisplayName map[string]string `json:"display_name"` - OAuth oauth.OAuth `json:"oauth"` // The home server has a list of info for each configured server location BaseMap map[string]*ServerBase `json:"base_map"` @@ -33,7 +33,7 @@ func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer func (servers *Servers) SetSecureInternet(server Server) error { errorMessage := "failed setting secure internet server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -58,17 +58,13 @@ func (servers *Servers) RemoveSecureInternet() { } } -func (server *SecureInternetHomeServer) GetOAuth() *oauth.OAuth { - return &server.OAuth -} - -func (server *SecureInternetHomeServer) GetTemplateAuth() func(string) string { +func (server *SecureInternetHomeServer) TemplateAuth() func(string) string { return func(authURL string) string { return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID) } } -func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { +func (server *SecureInternetHomeServer) Base() (*ServerBase, error) { errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { return nil, types.NewWrappedError( @@ -88,6 +84,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { return base, nil } +func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth { + return &server.Auth +} + func (servers *Servers) HasSecureLocation() bool { return servers.SecureInternetHomeServer.CurrentLocation != "" } @@ -148,7 +148,7 @@ func (server *SecureInternetHomeServer) init( } // Make sure oauth contains our endpoints - server.OAuth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) + server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) return nil } diff --git a/types/error.go b/types/error.go index fd56f3d..a7f70b0 100644 --- a/types/error.go +++ b/types/error.go @@ -5,11 +5,11 @@ import ( "fmt" ) -type ErrorLevel int8 +type ErrLevel int8 const ( // All other errors, default - ErrOther ErrorLevel = iota + ErrOther ErrLevel = iota // The erorr is just here as additional info ErrInfo @@ -22,18 +22,18 @@ const ( ) type WrappedErrorMessage struct { - Level ErrorLevel + Level ErrLevel Message string Err error } // NewWrappedError returns a WrappedErrorMessage and uses the error level from the parent func NewWrappedError(message string, err error) *WrappedErrorMessage { - return &WrappedErrorMessage{Level: GetErrorLevel(err), Message: message, Err: err} + return &WrappedErrorMessage{Level: ErrorLevel(err), Message: message, Err: err} } // NewWrappedError returns a WrappedErrorMessage and uses the given error level from the parent -func NewWrappedErrorLevel(level ErrorLevel, message string, err error) *WrappedErrorMessage { +func NewWrappedErrorLevel(level ErrLevel, message string, err error) *WrappedErrorMessage { return &WrappedErrorMessage{Level: level, Message: message, Err: err} } @@ -70,7 +70,7 @@ func (e *WrappedErrorMessage) Error() string { return fmt.Sprintf("Got error: %s, with cause: %s", e.Message, e.Err) } -func GetErrorTraceback(err error) string { +func ErrorTraceback(err error) string { var wrappedErr *WrappedErrorMessage if errors.As(err, &wrappedErr) { @@ -79,7 +79,7 @@ func GetErrorTraceback(err error) string { return err.Error() } -func GetErrorCause(err error) error { +func ErrorCause(err error) error { var wrappedErr *WrappedErrorMessage if errors.As(err, &wrappedErr) { @@ -88,7 +88,7 @@ func GetErrorCause(err error) error { return err } -func GetErrorLevel(err error) ErrorLevel { +func ErrorLevel(err error) ErrLevel { var wrappedErr *WrappedErrorMessage if errors.As(err, &wrappedErr) { -- cgit v1.2.3