From 7339e77c6eda5b96874dfc099d5c58da8ed53629 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 28 Nov 2022 11:52:04 +0100 Subject: Refactor: Remove most get prefixes for receiver functions --- internal/server/api.go | 6 ++--- internal/server/common.go | 55 +++++++++++++++++++------------------- internal/server/custom.go | 2 +- internal/server/instituteaccess.go | 35 ++++++++++++------------ internal/server/secureinternet.go | 18 ++++++------- 5 files changed, 57 insertions(+), 59 deletions(-) (limited to 'internal/server') diff --git a/internal/server/api.go b/internal/server/api.go index d315ada..559b787 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -50,7 +50,7 @@ func apiAuthorized( if opts == nil { opts = &httpw.HTTPOptionalParams{} } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, nil, types.NewWrappedError(errorMessage, baseErr) @@ -70,7 +70,7 @@ func apiAuthorized( } headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server)) + headerValue := fmt.Sprintf("Bearer %s", HeaderToken(server)) if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { @@ -119,7 +119,7 @@ func APIInfo(server Server) error { return types.NewWrappedError(errorMessage, jsonErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) diff --git a/internal/server/common.go b/internal/server/common.go index 8f4eabc..16208eb 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -39,14 +39,13 @@ type Servers struct { } type Server interface { - // Gets the current OAuth object - GetOAuth() *oauth.OAuth + OAuth() *oauth.OAuth // Get the authorization URL template function - GetTemplateAuth() func(string) string + TemplateAuth() func(string) string // Gets the server base - GetBase() (*ServerBase, error) + Base() (*ServerBase, error) } type ServerProfile struct { @@ -216,7 +215,7 @@ func (servers *Servers) AddSecureInternet( } func ShouldRenewButton(server Server) bool { - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { // FIXME: Log error here? @@ -251,28 +250,28 @@ func ShouldRenewButton(server Server) bool { return true } -func GetOAuthURL(server Server, name string) (string, error) { - return server.GetOAuth().GetAuthURL(name, server.GetTemplateAuth()) +func OAuthURL(server Server, name string) (string, error) { + return server.OAuth().AuthURL(name, server.TemplateAuth()) } func OAuthExchange(server Server) error { - return server.GetOAuth().Exchange() + return server.OAuth().Exchange() } -func GetHeaderToken(server Server) string { - return server.GetOAuth().Token.Access +func HeaderToken(server Server) string { + return server.OAuth().Token.Access } func MarkTokenExpired(server Server) { - server.GetOAuth().Token.ExpiredTimestamp = time.Now() + server.OAuth().Token.ExpiredTimestamp = time.Now() } func MarkTokensForRenew(server Server) { - server.GetOAuth().Token = oauth.OAuthToken{} + server.OAuth().Token = oauth.OAuthToken{} } func EnsureTokens(server Server) error { - ensureErr := server.GetOAuth().EnsureTokens() + ensureErr := server.OAuth().EnsureTokens() if ensureErr != nil { return types.NewWrappedError("failed ensuring server tokens", ensureErr) } @@ -284,7 +283,7 @@ func NeedsRelogin(server Server) bool { } func CancelOAuth(server Server) { - server.GetOAuth().Cancel() + server.OAuth().Cancel() } func (profile *ServerProfile) supportsProtocol(protocol string) bool { @@ -304,9 +303,9 @@ func (profile *ServerProfile) supportsOpenVPN() bool { return profile.supportsProtocol("openvpn") } -func getCurrentProfile(server Server) (*ServerProfile, error) { +func Profile(server Server) (*ServerProfile, error) { errorMessage := "failed getting current profile" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, types.NewWrappedError(errorMessage, baseErr) @@ -334,7 +333,7 @@ func (base *ServerBase) InitializeEndpoints() error { return nil } -func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { +func (base *ServerBase) ValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { var validProfiles []ServerProfile for _, profile := range base.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn @@ -347,14 +346,14 @@ func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerPro return ServerProfileInfo{Current: base.Profiles.Current, Info: ServerProfileListInfo{ProfileList: validProfiles}} } -func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { +func ValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) { errorMessage := "failed to get valid profiles" // No error wrapping here otherwise we wrap it too much - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return nil, types.NewWrappedError(errorMessage, baseErr) } - profiles := base.GetValidProfiles(clientSupportsWireguard) + profiles := base.ValidProfiles(clientSupportsWireguard) if len(profiles.Info.ProfileList) == 0 { return nil, types.NewWrappedError(errorMessage, errors.New("no profiles found with supported protocols")) } @@ -363,7 +362,7 @@ func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfi func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) { errorMessage := "failed getting server WireGuard configuration" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return "", "", types.NewWrappedError(errorMessage, baseErr) @@ -406,7 +405,7 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { errorMessage := "failed getting server OpenVPN configuration" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return "", "", types.NewWrappedError(errorMessage, baseErr) @@ -435,14 +434,14 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) return false, types.NewWrappedError(errorMessage, infoErr) } - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return false, types.NewWrappedError(errorMessage, baseErr) } // If there was a profile chosen and it doesn't exist anymore, reset it if base.Profiles.Current != "" { - _, existsProfileErr := getCurrentProfile(server) + _, existsProfileErr := Profile(server) if existsProfileErr != nil { base.Profiles.Current = "" } @@ -454,7 +453,7 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) if base.Profiles.Current == "" { base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID } - profile, profileErr := getCurrentProfile(server) + profile, profileErr := Profile(server) // shouldn't happen if profileErr != nil { return false, types.NewWrappedError(errorMessage, profileErr) @@ -474,7 +473,7 @@ func RefreshEndpoints(server Server) error { // Re-initialize the endpoints // TODO: Make this a warning instead? - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -487,10 +486,10 @@ func RefreshEndpoints(server Server) error { return nil } -func GetConfig(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { +func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { errorMessage := "failed getting an OpenVPN/WireGuard configuration" - profile, profileErr := getCurrentProfile(server) + profile, profileErr := Profile(server) if profileErr != nil { return "", "", types.NewWrappedError(errorMessage, profileErr) } diff --git a/internal/server/custom.go b/internal/server/custom.go index 8bde848..f8899b3 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -9,7 +9,7 @@ import ( func (servers *Servers) SetCustomServer(server Server) error { errorMessage := "failed setting custom server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index 33d8b52..ca37dcd 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -11,10 +11,10 @@ import ( // An instute access server type InstituteAccessServer struct { // An instute access server has its own OAuth - OAuth oauth.OAuth `json:"oauth"` + Auth oauth.OAuth `json:"oauth"` // Embed the server base - Base ServerBase `json:"base"` + Basic ServerBase `json:"base"` } type InstituteAccessServers struct { @@ -24,7 +24,7 @@ type InstituteAccessServers struct { func (servers *Servers) SetInstituteAccess(server Server) error { errorMessage := "failed setting institute access server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -63,19 +63,18 @@ func (servers *InstituteAccessServers) Remove(url string) { delete(servers.Map, url) } -// For an institute, we can simply get the OAuth -func (institute *InstituteAccessServer) GetOAuth() *oauth.OAuth { - return &institute.OAuth -} - -func (institute *InstituteAccessServer) GetTemplateAuth() func(string) string { +func (institute *InstituteAccessServer) TemplateAuth() func(string) string { return func(authURL string) string { return authURL } } -func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { - return &institute.Base, nil +func (institute *InstituteAccessServer) Base() (*ServerBase, error) { + return &institute.Basic, nil +} + +func (institute *InstituteAccessServer) OAuth() *oauth.OAuth { + return &institute.Auth } func (institute *InstituteAccessServer) init( @@ -85,15 +84,15 @@ func (institute *InstituteAccessServer) init( supportContact []string, ) error { errorMessage := fmt.Sprintf("failed initializing server %s", url) - institute.Base.URL = url - institute.Base.DisplayName = displayName - institute.Base.SupportContact = supportContact - institute.Base.Type = serverType - endpointsErr := institute.Base.InitializeEndpoints() + institute.Basic.URL = url + institute.Basic.DisplayName = displayName + institute.Basic.SupportContact = supportContact + institute.Basic.Type = serverType + endpointsErr := institute.Basic.InitializeEndpoints() if endpointsErr != nil { return types.NewWrappedError(errorMessage, endpointsErr) } - API := institute.Base.Endpoints.API.V3 - institute.OAuth.Init(url, API.Authorization, API.Token) + API := institute.Basic.Endpoints.API.V3 + institute.Auth.Init(url, API.Authorization, API.Token) return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index f0b308f..0dc9ef1 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -12,8 +12,8 @@ import ( // A secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to type SecureInternetHomeServer struct { + Auth oauth.OAuth `json:"oauth"` DisplayName map[string]string `json:"display_name"` - OAuth oauth.OAuth `json:"oauth"` // The home server has a list of info for each configured server location BaseMap map[string]*ServerBase `json:"base_map"` @@ -33,7 +33,7 @@ func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer func (servers *Servers) SetSecureInternet(server Server) error { errorMessage := "failed setting secure internet server" - base, baseErr := server.GetBase() + base, baseErr := server.Base() if baseErr != nil { return types.NewWrappedError(errorMessage, baseErr) } @@ -58,17 +58,13 @@ func (servers *Servers) RemoveSecureInternet() { } } -func (server *SecureInternetHomeServer) GetOAuth() *oauth.OAuth { - return &server.OAuth -} - -func (server *SecureInternetHomeServer) GetTemplateAuth() func(string) string { +func (server *SecureInternetHomeServer) TemplateAuth() func(string) string { return func(authURL string) string { return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID) } } -func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { +func (server *SecureInternetHomeServer) Base() (*ServerBase, error) { errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { return nil, types.NewWrappedError( @@ -88,6 +84,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { return base, nil } +func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth { + return &server.Auth +} + func (servers *Servers) HasSecureLocation() bool { return servers.SecureInternetHomeServer.CurrentLocation != "" } @@ -148,7 +148,7 @@ func (server *SecureInternetHomeServer) init( } // Make sure oauth contains our endpoints - server.OAuth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) + server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) return nil } -- cgit v1.2.3