diff options
Diffstat (limited to 'internal/server/server.go')
| -rw-r--r-- | internal/server/server.go | 130 |
1 files changed, 74 insertions, 56 deletions
diff --git a/internal/server/server.go b/internal/server/server.go index 4bd8766..e7229c5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,23 +1,21 @@ package server import ( + "context" "os" "time" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/profile" "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type Type int8 - -const ( - CustomServerType Type = iota - InstituteAccessServerType - SecureInternetServerType -) - type Server interface { OAuth() *oauth.OAuth @@ -25,27 +23,13 @@ type Server interface { TemplateAuth() func(string) string // Base returns the server base - Base() (*Base, error) + Base() (*base.Base, error) - // RefreshEndpoints - RefreshEndpoints(*discovery.Discovery) error -} + // NeedsLocation checks if the server needs a secure internet location + NeedsLocation() bool -type EndpointList struct { - API string `json:"api_endpoint"` - Authorization string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` -} - -type EndpointsVersions struct { - V2 EndpointList `json:"http://eduvpn.org/api#2"` - V3 EndpointList `json:"http://eduvpn.org/api#3"` -} - -// Endpoints defines the json format for /.well-known/vpn-user-portal". -type Endpoints struct { - API EndpointsVersions `json:"api"` - V string `json:"v"` + // Public returns the representation that will be passed over the CGO barrier + Public() (interface{}, error) } func UpdateTokens(srv Server, t oauth.Token) { @@ -56,12 +40,12 @@ func OAuthURL(srv Server, name string) (string, error) { return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(srv Server) error { - return srv.OAuth().Exchange() +func OAuthExchange(ctx context.Context, srv Server) error { + return srv.OAuth().Exchange(ctx) } -func HeaderToken(srv Server) (string, error) { - return srv.OAuth().AccessToken() +func HeaderToken(ctx context.Context, srv Server) (string, error) { + return srv.OAuth().AccessToken(ctx) } func MarkTokenExpired(srv Server) { @@ -72,16 +56,13 @@ func MarkTokensForRenew(srv Server) { srv.OAuth().SetTokenRenew() } -func NeedsRelogin(srv Server) bool { - _, err := HeaderToken(srv) +func NeedsRelogin(ctx context.Context, srv Server) bool { + // TODO: this error can be a context cancel + _, err := HeaderToken(ctx, srv) return err != nil } -func CancelOAuth(srv Server) { - srv.OAuth().Cancel() -} - -func CurrentProfile(srv Server) (*Profile, error) { +func CurrentProfile(srv Server) (*profile.Profile, error) { b, err := srv.Base() if err != nil { return nil, err @@ -96,19 +77,31 @@ func CurrentProfile(srv Server) (*Profile, error) { return nil, errors.Errorf("profile not found: " + pID) } -func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { +func ValidProfiles(srv Server, wireguardSupport bool) (*[]profile.Profile, error) { // No error wrapping here otherwise we wrap it too much b, err := srv.Base() if err != nil { return nil, err } - ps := b.ValidProfiles(wireguardSupport) - if len(ps.Info.ProfileList) == 0 { + ps := b.Profiles.Supported(wireguardSupport) + if len(ps) == 0 { return nil, errors.Errorf("no profiles found with supported protocols") } return &ps, nil } +func Profile(srv Server, id string) error { + b, err := srv.Base() + if err != nil { + return err + } + if !b.Profiles.Has(id) { + return errors.Errorf("no profile available with id: %s", id) + } + b.Profiles.Current = id + return nil +} + type ConfigData struct { // The configuration Config string @@ -120,7 +113,18 @@ type ConfigData struct { Tokens oauth.Token } -func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { +// Public gets the public data from the types package +// dg specifies if this config is default gateway +func (c *ConfigData) Public(dg bool) srvtypes.Configuration { + return srvtypes.Configuration{ + VPNConfig: c.Config, + Protocol: protocol.New(c.Type), + DefaultGateway: dg, + Tokens: c.Tokens.Public(), + } +} + +func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err @@ -133,7 +137,7 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi } pub := key.PublicKey().String() - cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport) + cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport) if err != nil { return nil, err } @@ -159,13 +163,13 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil } -func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { +func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err } pid := b.Profiles.Current - cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) + cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP) if err != nil { return nil, err } @@ -184,15 +188,14 @@ func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil } -func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { - // Get new profiles using the info call - // This does not override the current profile - err := APIInfo(srv) +func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { + b, err := srv.Base() if err != nil { return false, err } - - b, err := srv.Base() + // Get new profiles using the info call + // This does not override the current profile + err = api.Info(ctx, b, srv.OAuth()) if err != nil { return false, err } @@ -225,7 +228,18 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { return true, nil } -func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { +func RefreshEndpoints(ctx context.Context, srv Server) error { + // Re-initialize the endpoints + // TODO: Make this a warning instead? + b, err := srv.Base() + if err != nil { + return err + } + + return api.Endpoints(ctx, b) +} + +func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { return nil, err @@ -250,10 +264,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - cfg, err = wireguardGetConfig(server, preferTCP, ovpn) + cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn) // The config only supports OpenVPN case ovpn: - cfg, err = openVPNGetConfig(server, preferTCP) + cfg, err = openVPNGetConfig(ctx, server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: return nil, errors.New("no supported protocol found") @@ -267,6 +281,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, return cfg, err } -func Disconnect(server Server) error { - return APIDisconnect(server) +func Disconnect(ctx context.Context, server Server) error { + b, err := server.Base() + if err != nil { + return err + } + return api.Disconnect(ctx, b, server.OAuth()) } |
