From 3ac1d35257b56cca92ad0eb7f4d18abb366cf105 Mon Sep 17 00:00:00 2001 From: Aleksandar Pesic Date: Sun, 4 Dec 2022 21:48:20 +0100 Subject: simplify error handling fixes #6 Signed-off-by: Aleksandar Pesic --- internal/server/server.go | 318 +++++++++++++++++----------------------------- 1 file changed, 116 insertions(+), 202 deletions(-) (limited to 'internal/server/server.go') diff --git a/internal/server/server.go b/internal/server/server.go index 95244d5..de0fa9a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,11 @@ package server import ( - "errors" - "fmt" "time" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/wireguard" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type Type int8 @@ -21,10 +19,10 @@ const ( type Server interface { OAuth() *oauth.OAuth - // Get the authorization URL template function + // TemplateAuth returns the authorization URL template function TemplateAuth() func(string) string - // Gets the server base + // Base returns the server base Base() (*Base, error) } @@ -34,7 +32,7 @@ type EndpointList struct { Token string `json:"token_endpoint"` } -// Struct that defines the json format for /.well-known/vpn-user-portal". +// Endpoints defines the json format for /.well-known/vpn-user-portal". type Endpoints struct { API struct { V2 EndpointList `json:"http://eduvpn.org/api#2"` @@ -43,310 +41,226 @@ type Endpoints struct { V string `json:"v"` } -func ShouldRenewButton(server Server) bool { - base, baseErr := server.Base() - - if baseErr != nil { +func ShouldRenewButton(srv Server) bool { + b, err := srv.Base() + if err != nil { // FIXME: Log error here? return false } // Get current time - current := time.Now() + now := time.Now() // Session is expired - if !current.Before(base.EndTime) { + if !now.Before(b.EndTime) { return true } // 30 minutes have not passed - if !current.After(base.StartTime.Add(30 * time.Minute)) { + if !now.After(b.StartTime.Add(30 * time.Minute)) { return false } // Session will not expire today - if !current.Add(24 * time.Hour).After(base.EndTime) { + if !now.Add(24 * time.Hour).After(b.EndTime) { return false } // Session duration is less than 24 hours but not 75% has passed - duration := base.EndTime.Sub(base.StartTime) - percentTime := base.StartTime.Add((duration / 4) * 3) - if duration < time.Duration(24*time.Hour) && !current.After(percentTime) { + d := b.EndTime.Sub(b.StartTime) + pct := b.StartTime.Add((d / 4) * 3) + if d < 24*time.Hour && !now.After(pct) { return false } return true } -func OAuthURL(server Server, name string) (string, error) { - return server.OAuth().AuthURL(name, server.TemplateAuth()) +func OAuthURL(srv Server, name string) (string, error) { + return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(server Server) error { - return server.OAuth().Exchange() +func OAuthExchange(srv Server) error { + return srv.OAuth().Exchange() } -func HeaderToken(server Server) (string, error) { - token, tokenErr := server.OAuth().AccessToken() - if tokenErr != nil { - return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr) - } - return token, nil +func HeaderToken(srv Server) (string, error) { + return srv.OAuth().AccessToken() } -func MarkTokenExpired(server Server) { - server.OAuth().SetTokenExpired() +func MarkTokenExpired(srv Server) { + srv.OAuth().SetTokenExpired() } -func MarkTokensForRenew(server Server) { - server.OAuth().SetTokenRenew() +func MarkTokensForRenew(srv Server) { + srv.OAuth().SetTokenRenew() } -func NeedsRelogin(server Server) bool { - _, tokenErr := HeaderToken(server) - return tokenErr != nil +func NeedsRelogin(srv Server) bool { + _, err := HeaderToken(srv) + return err != nil } -func CancelOAuth(server Server) { - server.OAuth().Cancel() +func CancelOAuth(srv Server) { + srv.OAuth().Cancel() } -func CurrentProfile(server Server) (*Profile, error) { - errorMessage := "failed getting current profile" - base, baseErr := server.Base() - - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) +func CurrentProfile(srv Server) (*Profile, error) { + b, err := srv.Base() + if err != nil { + return nil, err } - profileID := base.Profiles.Current - for _, profile := range base.Profiles.Info.ProfileList { - if profile.ID == profileID { + pid := b.Profiles.Current + for _, profile := range b.Profiles.Info.ProfileList { + if profile.ID == pid { return &profile, nil } } - return nil, types.NewWrappedError( - errorMessage, - &CurrentProfileNotFoundError{ProfileID: profileID}, - ) + return nil, errors.Errorf("profile not found: " + pid) } -func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) { - errorMessage := "failed to get valid profiles" +func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { // No error wrapping here otherwise we wrap it too much - base, baseErr := server.Base() - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return nil, err } - profiles := base.ValidProfiles(clientSupportsWireguard) - if len(profiles.Info.ProfileList) == 0 { - return nil, types.NewWrappedError( - errorMessage, - errors.New("no profiles found with supported protocols"), - ) + ps := b.ValidProfiles(wireguardSupport) + if len(ps.Info.ProfileList) == 0 { + return nil, errors.Errorf("no profiles found with supported protocols") } - return &profiles, nil + return &ps, nil } -func wireguardGetConfig( - server Server, - preferTCP bool, - supportsOpenVPN bool, -) (string, string, error) { - errorMessage := "failed getting server WireGuard configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - wireguardKey, wireguardErr := wireguard.GenerateKey() - - if wireguardErr != nil { - return "", "", types.NewWrappedError(errorMessage, wireguardErr) + pid := b.Profiles.Current + key, err := wireguard.GenerateKey() + if err != nil { + return "", "", err } - wireguardPublicKey := wireguardKey.PublicKey().String() - config, content, expires, configErr := APIConnectWireguard( - server, - profileID, - wireguardPublicKey, - preferTCP, - supportsOpenVPN, - ) - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + pub := key.PublicKey().String() + cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport) + if err != nil { + return "", "", err } // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if content == "wireguard" { + if ct == "wireguard" { // This needs the go code a way to identify a connection // Use the uuid of the connection e.g. on Linux // This needs the client code to call the go code - config = wireguard.ConfigAddKey(config, wireguardKey) + cfg = wireguard.ConfigAddKey(cfg, key) } - return config, content, nil + return cfg, ct, nil } -func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { - errorMessage := "failed getting server OpenVPN configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP) + pid := b.Profiles.Current + cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + if err != nil { + return "", "", err } - return configOpenVPN, "openvpn", nil + return cfg, "openvpn", nil } -func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) { - errorMessage := "failed has valid profile check" - +func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { // Get new profiles using the info call // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return false, types.NewWrappedError(errorMessage, infoErr) + err := APIInfo(srv) + if err != nil { + return false, err } - base, baseErr := server.Base() - if baseErr != nil { - return false, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return false, err } // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := CurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" + if b.Profiles.Current != "" { + if _, err = CurrentProfile(srv); err != nil { + b.Profiles.Current = "" } } - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - profile, profileErr := CurrentProfile(server) - // shouldn't happen - if profileErr != nil { - return false, types.NewWrappedError(errorMessage, profileErr) - } - // Profile does not support OpenVPN but the client also doesn't support WireGuard - if !profile.supportsOpenVPN() && !clientSupportsWireguard { - return false, nil - } - return true, nil + if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" { + return false, nil } - return false, nil + // Set the current profile if there is only one profile or profile is already selected + // Set the first profile if none is selected + if b.Profiles.Current == "" { + b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID + } + p, err := CurrentProfile(srv) + // shouldn't happen + if err != nil { + return false, err + } + // Profile does not support OpenVPN but the client also doesn't support WireGuard + if !p.supportsOpenVPN() && !wireguardSupport { + return false, nil + } + return true, nil } -func RefreshEndpoints(server Server) error { - errorMessage := "failed to refresh server endpoints" - +func RefreshEndpoints(srv Server) error { // Re-initialize the endpoints // TODO: Make this a warning instead? - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) - } - - endpointsErr := base.InitializeEndpoints() - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) + b, err := srv.Base() + if err != nil { + return err } - return nil + return b.InitializeEndpoints() } -func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - - profile, profileErr := CurrentProfile(server) - if profileErr != nil { - return "", "", types.NewWrappedError(errorMessage, profileErr) +func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) { + p, err := CurrentProfile(server) + if err != nil { + return "", "", err } - supportsOpenVPN := profile.supportsOpenVPN() - supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard - - var config string - var configType string - var configErr error + ovpn := p.supportsOpenVPN() + wg := p.supportsWireguard() && wireguardSupport switch { // The config supports wireguard and optionally openvpn - case supportsWireguard: + case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN) + return wireguardGetConfig(server, preferTCP, ovpn) // The config only supports OpenVPN - case supportsOpenVPN: - config, configType, configErr = openVPNGetConfig(server, preferTCP) + case ovpn: + return openVPNGetConfig(server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: - return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found")) + return "", "", errors.Errorf("no supported protocol found") } - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) - } - - return config, configType, nil } func Disconnect(server Server) { APIDisconnect(server) } - -type CurrentProfileNotFoundError struct { - ProfileID string -} - -func (e *CurrentProfileNotFoundError) Error() string { - return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) -} - -type ConfigPreferTCPError struct{} - -func (e *ConfigPreferTCPError) Error() string { - return "failed to get config, prefer TCP is on but the server does not support OpenVPN" -} - -type EmptyURLError struct{} - -func (e *EmptyURLError) Error() string { - return "failed ensuring server, empty url provided" -} - -type CurrentNoMapError struct{} - -func (e *CurrentNoMapError) Error() string { - return "failed getting current server, no servers available" -} - -type CurrentNotFoundError struct{} - -func (e *CurrentNotFoundError) Error() string { - return "failed getting current server, not found" -} -- cgit v1.2.3