summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
authorAleksandar Pesic <peske.nis@gmail.com>2022-12-04 21:48:20 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-12 13:26:51 +0100
commit3ac1d35257b56cca92ad0eb7f4d18abb366cf105 (patch)
tree432db14d1f92a252518f371be420fa0d3ef044c8 /internal/server/server.go
parent37bca013bd4405548b274ac473acf959ad661ee6 (diff)
simplify error handling
fixes #6 Signed-off-by: Aleksandar Pesic <peske.nis@gmail.com>
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go318
1 files changed, 116 insertions, 202 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 95244d5..de0fa9a 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,11 @@
package server
import (
- "errors"
- "fmt"
"time"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type Type int8
@@ -21,10 +19,10 @@ const (
type Server interface {
OAuth() *oauth.OAuth
- // Get the authorization URL template function
+ // TemplateAuth returns the authorization URL template function
TemplateAuth() func(string) string
- // Gets the server base
+ // Base returns the server base
Base() (*Base, error)
}
@@ -34,7 +32,7 @@ type EndpointList struct {
Token string `json:"token_endpoint"`
}
-// Struct that defines the json format for /.well-known/vpn-user-portal".
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
type Endpoints struct {
API struct {
V2 EndpointList `json:"http://eduvpn.org/api#2"`
@@ -43,310 +41,226 @@ type Endpoints struct {
V string `json:"v"`
}
-func ShouldRenewButton(server Server) bool {
- base, baseErr := server.Base()
-
- if baseErr != nil {
+func ShouldRenewButton(srv Server) bool {
+ b, err := srv.Base()
+ if err != nil {
// FIXME: Log error here?
return false
}
// Get current time
- current := time.Now()
+ now := time.Now()
// Session is expired
- if !current.Before(base.EndTime) {
+ if !now.Before(b.EndTime) {
return true
}
// 30 minutes have not passed
- if !current.After(base.StartTime.Add(30 * time.Minute)) {
+ if !now.After(b.StartTime.Add(30 * time.Minute)) {
return false
}
// Session will not expire today
- if !current.Add(24 * time.Hour).After(base.EndTime) {
+ if !now.Add(24 * time.Hour).After(b.EndTime) {
return false
}
// Session duration is less than 24 hours but not 75% has passed
- duration := base.EndTime.Sub(base.StartTime)
- percentTime := base.StartTime.Add((duration / 4) * 3)
- if duration < time.Duration(24*time.Hour) && !current.After(percentTime) {
+ d := b.EndTime.Sub(b.StartTime)
+ pct := b.StartTime.Add((d / 4) * 3)
+ if d < 24*time.Hour && !now.After(pct) {
return false
}
return true
}
-func OAuthURL(server Server, name string) (string, error) {
- return server.OAuth().AuthURL(name, server.TemplateAuth())
+func OAuthURL(srv Server, name string) (string, error) {
+ return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
-func OAuthExchange(server Server) error {
- return server.OAuth().Exchange()
+func OAuthExchange(srv Server) error {
+ return srv.OAuth().Exchange()
}
-func HeaderToken(server Server) (string, error) {
- token, tokenErr := server.OAuth().AccessToken()
- if tokenErr != nil {
- return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr)
- }
- return token, nil
+func HeaderToken(srv Server) (string, error) {
+ return srv.OAuth().AccessToken()
}
-func MarkTokenExpired(server Server) {
- server.OAuth().SetTokenExpired()
+func MarkTokenExpired(srv Server) {
+ srv.OAuth().SetTokenExpired()
}
-func MarkTokensForRenew(server Server) {
- server.OAuth().SetTokenRenew()
+func MarkTokensForRenew(srv Server) {
+ srv.OAuth().SetTokenRenew()
}
-func NeedsRelogin(server Server) bool {
- _, tokenErr := HeaderToken(server)
- return tokenErr != nil
+func NeedsRelogin(srv Server) bool {
+ _, err := HeaderToken(srv)
+ return err != nil
}
-func CancelOAuth(server Server) {
- server.OAuth().Cancel()
+func CancelOAuth(srv Server) {
+ srv.OAuth().Cancel()
}
-func CurrentProfile(server Server) (*Profile, error) {
- errorMessage := "failed getting current profile"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+func CurrentProfile(srv Server) (*Profile, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profileID := base.Profiles.Current
- for _, profile := range base.Profiles.Info.ProfileList {
- if profile.ID == profileID {
+ pid := b.Profiles.Current
+ for _, profile := range b.Profiles.Info.ProfileList {
+ if profile.ID == pid {
return &profile, nil
}
}
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentProfileNotFoundError{ProfileID: profileID},
- )
+ return nil, errors.Errorf("profile not found: " + pid)
}
-func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) {
- errorMessage := "failed to get valid profiles"
+func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
// No error wrapping here otherwise we wrap it too much
- base, baseErr := server.Base()
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profiles := base.ValidProfiles(clientSupportsWireguard)
- if len(profiles.Info.ProfileList) == 0 {
- return nil, types.NewWrappedError(
- errorMessage,
- errors.New("no profiles found with supported protocols"),
- )
+ ps := b.ValidProfiles(wireguardSupport)
+ if len(ps.Info.ProfileList) == 0 {
+ return nil, errors.Errorf("no profiles found with supported protocols")
}
- return &profiles, nil
+ return &ps, nil
}
-func wireguardGetConfig(
- server Server,
- preferTCP bool,
- supportsOpenVPN bool,
-) (string, string, error) {
- errorMessage := "failed getting server WireGuard configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- wireguardKey, wireguardErr := wireguard.GenerateKey()
-
- if wireguardErr != nil {
- return "", "", types.NewWrappedError(errorMessage, wireguardErr)
+ pid := b.Profiles.Current
+ key, err := wireguard.GenerateKey()
+ if err != nil {
+ return "", "", err
}
- wireguardPublicKey := wireguardKey.PublicKey().String()
- config, content, expires, configErr := APIConnectWireguard(
- server,
- profileID,
- wireguardPublicKey,
- preferTCP,
- supportsOpenVPN,
- )
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ pub := key.PublicKey().String()
+ cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport)
+ if err != nil {
+ return "", "", err
}
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if content == "wireguard" {
+ if ct == "wireguard" {
// This needs the go code a way to identify a connection
// Use the uuid of the connection e.g. on Linux
// This needs the client code to call the go code
- config = wireguard.ConfigAddKey(config, wireguardKey)
+ cfg = wireguard.ConfigAddKey(cfg, key)
}
- return config, content, nil
+ return cfg, ct, nil
}
-func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting server OpenVPN configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP)
+ pid := b.Profiles.Current
+ cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ if err != nil {
+ return "", "", err
}
- return configOpenVPN, "openvpn", nil
+ return cfg, "openvpn", nil
}
-func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) {
- errorMessage := "failed has valid profile check"
-
+func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
// Get new profiles using the info call
// This does not override the current profile
- infoErr := APIInfo(server)
- if infoErr != nil {
- return false, types.NewWrappedError(errorMessage, infoErr)
+ err := APIInfo(srv)
+ if err != nil {
+ return false, err
}
- base, baseErr := server.Base()
- if baseErr != nil {
- return false, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return false, err
}
// If there was a profile chosen and it doesn't exist anymore, reset it
- if base.Profiles.Current != "" {
- _, existsProfileErr := CurrentProfile(server)
- if existsProfileErr != nil {
- base.Profiles.Current = ""
+ if b.Profiles.Current != "" {
+ if _, err = CurrentProfile(srv); err != nil {
+ b.Profiles.Current = ""
}
}
- // Set the current profile if there is only one profile or profile is already selected
- if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" {
- // Set the first profile if none is selected
- if base.Profiles.Current == "" {
- base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
- }
- profile, profileErr := CurrentProfile(server)
- // shouldn't happen
- if profileErr != nil {
- return false, types.NewWrappedError(errorMessage, profileErr)
- }
- // Profile does not support OpenVPN but the client also doesn't support WireGuard
- if !profile.supportsOpenVPN() && !clientSupportsWireguard {
- return false, nil
- }
- return true, nil
+ if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" {
+ return false, nil
}
- return false, nil
+ // Set the current profile if there is only one profile or profile is already selected
+ // Set the first profile if none is selected
+ if b.Profiles.Current == "" {
+ b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID
+ }
+ p, err := CurrentProfile(srv)
+ // shouldn't happen
+ if err != nil {
+ return false, err
+ }
+ // Profile does not support OpenVPN but the client also doesn't support WireGuard
+ if !p.supportsOpenVPN() && !wireguardSupport {
+ return false, nil
+ }
+ return true, nil
}
-func RefreshEndpoints(server Server) error {
- errorMessage := "failed to refresh server endpoints"
-
+func RefreshEndpoints(srv Server) error {
// Re-initialize the endpoints
// TODO: Make this a warning instead?
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
- }
-
- endpointsErr := base.InitializeEndpoints()
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- return nil
+ return b.InitializeEndpoints()
}
-func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting an OpenVPN/WireGuard configuration"
-
- profile, profileErr := CurrentProfile(server)
- if profileErr != nil {
- return "", "", types.NewWrappedError(errorMessage, profileErr)
+func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) {
+ p, err := CurrentProfile(server)
+ if err != nil {
+ return "", "", err
}
- supportsOpenVPN := profile.supportsOpenVPN()
- supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard
-
- var config string
- var configType string
- var configErr error
+ ovpn := p.supportsOpenVPN()
+ wg := p.supportsWireguard() && wireguardSupport
switch {
// The config supports wireguard and optionally openvpn
- case supportsWireguard:
+ case wg:
// A wireguard connect call needs to generate a wireguard key and add it to the config
// Also the server could send back an OpenVPN config if it supports OpenVPN
- config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN)
+ return wireguardGetConfig(server, preferTCP, ovpn)
// The config only supports OpenVPN
- case supportsOpenVPN:
- config, configType, configErr = openVPNGetConfig(server, preferTCP)
+ case ovpn:
+ return openVPNGetConfig(server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
- return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found"))
+ return "", "", errors.Errorf("no supported protocol found")
}
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
- }
-
- return config, configType, nil
}
func Disconnect(server Server) {
APIDisconnect(server)
}
-
-type CurrentProfileNotFoundError struct {
- ProfileID string
-}
-
-func (e *CurrentProfileNotFoundError) Error() string {
- return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID)
-}
-
-type ConfigPreferTCPError struct{}
-
-func (e *ConfigPreferTCPError) Error() string {
- return "failed to get config, prefer TCP is on but the server does not support OpenVPN"
-}
-
-type EmptyURLError struct{}
-
-func (e *EmptyURLError) Error() string {
- return "failed ensuring server, empty url provided"
-}
-
-type CurrentNoMapError struct{}
-
-func (e *CurrentNoMapError) Error() string {
- return "failed getting current server, no servers available"
-}
-
-type CurrentNotFoundError struct{}
-
-func (e *CurrentNotFoundError) Error() string {
- return "failed getting current server, not found"
-}