summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 11:52:04 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 12:03:16 +0100
commit7339e77c6eda5b96874dfc099d5c58da8ed53629 (patch)
treeb602159b0c397cbaa4f8983aea987274163fe357 /internal/server
parente9f8db8ee8fccf60e58deb1d72766f94a053bb16 (diff)
Refactor: Remove most get prefixes for receiver functions
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/api.go6
-rw-r--r--internal/server/common.go55
-rw-r--r--internal/server/custom.go2
-rw-r--r--internal/server/instituteaccess.go35
-rw-r--r--internal/server/secureinternet.go18
5 files changed, 57 insertions, 59 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index d315ada..559b787 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -50,7 +50,7 @@ func apiAuthorized(
if opts == nil {
opts = &httpw.HTTPOptionalParams{}
}
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return nil, nil, types.NewWrappedError(errorMessage, baseErr)
@@ -70,7 +70,7 @@ func apiAuthorized(
}
headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", GetHeaderToken(server))
+ headerValue := fmt.Sprintf("Bearer %s", HeaderToken(server))
if opts.Headers != nil {
opts.Headers.Add(headerKey, headerValue)
} else {
@@ -119,7 +119,7 @@ func APIInfo(server Server) error {
return types.NewWrappedError(errorMessage, jsonErr)
}
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return types.NewWrappedError(errorMessage, baseErr)
diff --git a/internal/server/common.go b/internal/server/common.go
index 8f4eabc..16208eb 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -39,14 +39,13 @@ type Servers struct {
}
type Server interface {
- // Gets the current OAuth object
- GetOAuth() *oauth.OAuth
+ OAuth() *oauth.OAuth
// Get the authorization URL template function
- GetTemplateAuth() func(string) string
+ TemplateAuth() func(string) string
// Gets the server base
- GetBase() (*ServerBase, error)
+ Base() (*ServerBase, error)
}
type ServerProfile struct {
@@ -216,7 +215,7 @@ func (servers *Servers) AddSecureInternet(
}
func ShouldRenewButton(server Server) bool {
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
// FIXME: Log error here?
@@ -251,28 +250,28 @@ func ShouldRenewButton(server Server) bool {
return true
}
-func GetOAuthURL(server Server, name string) (string, error) {
- return server.GetOAuth().GetAuthURL(name, server.GetTemplateAuth())
+func OAuthURL(server Server, name string) (string, error) {
+ return server.OAuth().AuthURL(name, server.TemplateAuth())
}
func OAuthExchange(server Server) error {
- return server.GetOAuth().Exchange()
+ return server.OAuth().Exchange()
}
-func GetHeaderToken(server Server) string {
- return server.GetOAuth().Token.Access
+func HeaderToken(server Server) string {
+ return server.OAuth().Token.Access
}
func MarkTokenExpired(server Server) {
- server.GetOAuth().Token.ExpiredTimestamp = time.Now()
+ server.OAuth().Token.ExpiredTimestamp = time.Now()
}
func MarkTokensForRenew(server Server) {
- server.GetOAuth().Token = oauth.OAuthToken{}
+ server.OAuth().Token = oauth.OAuthToken{}
}
func EnsureTokens(server Server) error {
- ensureErr := server.GetOAuth().EnsureTokens()
+ ensureErr := server.OAuth().EnsureTokens()
if ensureErr != nil {
return types.NewWrappedError("failed ensuring server tokens", ensureErr)
}
@@ -284,7 +283,7 @@ func NeedsRelogin(server Server) bool {
}
func CancelOAuth(server Server) {
- server.GetOAuth().Cancel()
+ server.OAuth().Cancel()
}
func (profile *ServerProfile) supportsProtocol(protocol string) bool {
@@ -304,9 +303,9 @@ func (profile *ServerProfile) supportsOpenVPN() bool {
return profile.supportsProtocol("openvpn")
}
-func getCurrentProfile(server Server) (*ServerProfile, error) {
+func Profile(server Server) (*ServerProfile, error) {
errorMessage := "failed getting current profile"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return nil, types.NewWrappedError(errorMessage, baseErr)
@@ -334,7 +333,7 @@ func (base *ServerBase) InitializeEndpoints() error {
return nil
}
-func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerProfileInfo {
+func (base *ServerBase) ValidProfiles(clientSupportsWireguard bool) ServerProfileInfo {
var validProfiles []ServerProfile
for _, profile := range base.Profiles.Info.ProfileList {
// Not a valid profile because it does not support openvpn
@@ -347,14 +346,14 @@ func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerPro
return ServerProfileInfo{Current: base.Profiles.Current, Info: ServerProfileListInfo{ProfileList: validProfiles}}
}
-func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) {
+func ValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfileInfo, error) {
errorMessage := "failed to get valid profiles"
// No error wrapping here otherwise we wrap it too much
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return nil, types.NewWrappedError(errorMessage, baseErr)
}
- profiles := base.GetValidProfiles(clientSupportsWireguard)
+ profiles := base.ValidProfiles(clientSupportsWireguard)
if len(profiles.Info.ProfileList) == 0 {
return nil, types.NewWrappedError(errorMessage, errors.New("no profiles found with supported protocols"))
}
@@ -363,7 +362,7 @@ func GetValidProfiles(server Server, clientSupportsWireguard bool) (*ServerProfi
func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (string, string, error) {
errorMessage := "failed getting server WireGuard configuration"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return "", "", types.NewWrappedError(errorMessage, baseErr)
@@ -406,7 +405,7 @@ func wireguardGetConfig(server Server, preferTCP bool, supportsOpenVPN bool) (st
func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
errorMessage := "failed getting server OpenVPN configuration"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return "", "", types.NewWrappedError(errorMessage, baseErr)
@@ -435,14 +434,14 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error)
return false, types.NewWrappedError(errorMessage, infoErr)
}
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return false, types.NewWrappedError(errorMessage, baseErr)
}
// If there was a profile chosen and it doesn't exist anymore, reset it
if base.Profiles.Current != "" {
- _, existsProfileErr := getCurrentProfile(server)
+ _, existsProfileErr := Profile(server)
if existsProfileErr != nil {
base.Profiles.Current = ""
}
@@ -454,7 +453,7 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error)
if base.Profiles.Current == "" {
base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
}
- profile, profileErr := getCurrentProfile(server)
+ profile, profileErr := Profile(server)
// shouldn't happen
if profileErr != nil {
return false, types.NewWrappedError(errorMessage, profileErr)
@@ -474,7 +473,7 @@ func RefreshEndpoints(server Server) error {
// Re-initialize the endpoints
// TODO: Make this a warning instead?
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return types.NewWrappedError(errorMessage, baseErr)
}
@@ -487,10 +486,10 @@ func RefreshEndpoints(server Server) error {
return nil
}
-func GetConfig(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) {
+func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) {
errorMessage := "failed getting an OpenVPN/WireGuard configuration"
- profile, profileErr := getCurrentProfile(server)
+ profile, profileErr := Profile(server)
if profileErr != nil {
return "", "", types.NewWrappedError(errorMessage, profileErr)
}
diff --git a/internal/server/custom.go b/internal/server/custom.go
index 8bde848..f8899b3 100644
--- a/internal/server/custom.go
+++ b/internal/server/custom.go
@@ -9,7 +9,7 @@ import (
func (servers *Servers) SetCustomServer(server Server) error {
errorMessage := "failed setting custom server"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return types.NewWrappedError(errorMessage, baseErr)
}
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
index 33d8b52..ca37dcd 100644
--- a/internal/server/instituteaccess.go
+++ b/internal/server/instituteaccess.go
@@ -11,10 +11,10 @@ import (
// An instute access server
type InstituteAccessServer struct {
// An instute access server has its own OAuth
- OAuth oauth.OAuth `json:"oauth"`
+ Auth oauth.OAuth `json:"oauth"`
// Embed the server base
- Base ServerBase `json:"base"`
+ Basic ServerBase `json:"base"`
}
type InstituteAccessServers struct {
@@ -24,7 +24,7 @@ type InstituteAccessServers struct {
func (servers *Servers) SetInstituteAccess(server Server) error {
errorMessage := "failed setting institute access server"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return types.NewWrappedError(errorMessage, baseErr)
}
@@ -63,19 +63,18 @@ func (servers *InstituteAccessServers) Remove(url string) {
delete(servers.Map, url)
}
-// For an institute, we can simply get the OAuth
-func (institute *InstituteAccessServer) GetOAuth() *oauth.OAuth {
- return &institute.OAuth
-}
-
-func (institute *InstituteAccessServer) GetTemplateAuth() func(string) string {
+func (institute *InstituteAccessServer) TemplateAuth() func(string) string {
return func(authURL string) string {
return authURL
}
}
-func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) {
- return &institute.Base, nil
+func (institute *InstituteAccessServer) Base() (*ServerBase, error) {
+ return &institute.Basic, nil
+}
+
+func (institute *InstituteAccessServer) OAuth() *oauth.OAuth {
+ return &institute.Auth
}
func (institute *InstituteAccessServer) init(
@@ -85,15 +84,15 @@ func (institute *InstituteAccessServer) init(
supportContact []string,
) error {
errorMessage := fmt.Sprintf("failed initializing server %s", url)
- institute.Base.URL = url
- institute.Base.DisplayName = displayName
- institute.Base.SupportContact = supportContact
- institute.Base.Type = serverType
- endpointsErr := institute.Base.InitializeEndpoints()
+ institute.Basic.URL = url
+ institute.Basic.DisplayName = displayName
+ institute.Basic.SupportContact = supportContact
+ institute.Basic.Type = serverType
+ endpointsErr := institute.Basic.InitializeEndpoints()
if endpointsErr != nil {
return types.NewWrappedError(errorMessage, endpointsErr)
}
- API := institute.Base.Endpoints.API.V3
- institute.OAuth.Init(url, API.Authorization, API.Token)
+ API := institute.Basic.Endpoints.API.V3
+ institute.Auth.Init(url, API.Authorization, API.Token)
return nil
}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index f0b308f..0dc9ef1 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -12,8 +12,8 @@ import (
// A secure internet server which has its own OAuth tokens
// It specifies the current location url it is connected to
type SecureInternetHomeServer struct {
+ Auth oauth.OAuth `json:"oauth"`
DisplayName map[string]string `json:"display_name"`
- OAuth oauth.OAuth `json:"oauth"`
// The home server has a list of info for each configured server location
BaseMap map[string]*ServerBase `json:"base_map"`
@@ -33,7 +33,7 @@ func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer
func (servers *Servers) SetSecureInternet(server Server) error {
errorMessage := "failed setting secure internet server"
- base, baseErr := server.GetBase()
+ base, baseErr := server.Base()
if baseErr != nil {
return types.NewWrappedError(errorMessage, baseErr)
}
@@ -58,17 +58,13 @@ func (servers *Servers) RemoveSecureInternet() {
}
}
-func (server *SecureInternetHomeServer) GetOAuth() *oauth.OAuth {
- return &server.OAuth
-}
-
-func (server *SecureInternetHomeServer) GetTemplateAuth() func(string) string {
+func (server *SecureInternetHomeServer) TemplateAuth() func(string) string {
return func(authURL string) string {
return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID)
}
}
-func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
+func (server *SecureInternetHomeServer) Base() (*ServerBase, error) {
errorMessage := "failed getting current secure internet home base"
if server.BaseMap == nil {
return nil, types.NewWrappedError(
@@ -88,6 +84,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
return base, nil
}
+func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth {
+ return &server.Auth
+}
+
func (servers *Servers) HasSecureLocation() bool {
return servers.SecureInternetHomeServer.CurrentLocation != ""
}
@@ -148,7 +148,7 @@ func (server *SecureInternetHomeServer) init(
}
// Make sure oauth contains our endpoints
- server.OAuth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token)
+ server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token)
return nil
}