summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go130
1 files changed, 74 insertions, 56 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 4bd8766..e7229c5 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,23 +1,21 @@
package server
import (
+ "context"
"os"
"time"
"github.com/eduvpn/eduvpn-common/internal/discovery"
"github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/api"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/profile"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
+ "github.com/eduvpn/eduvpn-common/types/protocol"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
)
-type Type int8
-
-const (
- CustomServerType Type = iota
- InstituteAccessServerType
- SecureInternetServerType
-)
-
type Server interface {
OAuth() *oauth.OAuth
@@ -25,27 +23,13 @@ type Server interface {
TemplateAuth() func(string) string
// Base returns the server base
- Base() (*Base, error)
+ Base() (*base.Base, error)
- // RefreshEndpoints
- RefreshEndpoints(*discovery.Discovery) error
-}
+ // NeedsLocation checks if the server needs a secure internet location
+ NeedsLocation() bool
-type EndpointList struct {
- API string `json:"api_endpoint"`
- Authorization string `json:"authorization_endpoint"`
- Token string `json:"token_endpoint"`
-}
-
-type EndpointsVersions struct {
- V2 EndpointList `json:"http://eduvpn.org/api#2"`
- V3 EndpointList `json:"http://eduvpn.org/api#3"`
-}
-
-// Endpoints defines the json format for /.well-known/vpn-user-portal".
-type Endpoints struct {
- API EndpointsVersions `json:"api"`
- V string `json:"v"`
+ // Public returns the representation that will be passed over the CGO barrier
+ Public() (interface{}, error)
}
func UpdateTokens(srv Server, t oauth.Token) {
@@ -56,12 +40,12 @@ func OAuthURL(srv Server, name string) (string, error) {
return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
-func OAuthExchange(srv Server) error {
- return srv.OAuth().Exchange()
+func OAuthExchange(ctx context.Context, srv Server) error {
+ return srv.OAuth().Exchange(ctx)
}
-func HeaderToken(srv Server) (string, error) {
- return srv.OAuth().AccessToken()
+func HeaderToken(ctx context.Context, srv Server) (string, error) {
+ return srv.OAuth().AccessToken(ctx)
}
func MarkTokenExpired(srv Server) {
@@ -72,16 +56,13 @@ func MarkTokensForRenew(srv Server) {
srv.OAuth().SetTokenRenew()
}
-func NeedsRelogin(srv Server) bool {
- _, err := HeaderToken(srv)
+func NeedsRelogin(ctx context.Context, srv Server) bool {
+ // TODO: this error can be a context cancel
+ _, err := HeaderToken(ctx, srv)
return err != nil
}
-func CancelOAuth(srv Server) {
- srv.OAuth().Cancel()
-}
-
-func CurrentProfile(srv Server) (*Profile, error) {
+func CurrentProfile(srv Server) (*profile.Profile, error) {
b, err := srv.Base()
if err != nil {
return nil, err
@@ -96,19 +77,31 @@ func CurrentProfile(srv Server) (*Profile, error) {
return nil, errors.Errorf("profile not found: " + pID)
}
-func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
+func ValidProfiles(srv Server, wireguardSupport bool) (*[]profile.Profile, error) {
// No error wrapping here otherwise we wrap it too much
b, err := srv.Base()
if err != nil {
return nil, err
}
- ps := b.ValidProfiles(wireguardSupport)
- if len(ps.Info.ProfileList) == 0 {
+ ps := b.Profiles.Supported(wireguardSupport)
+ if len(ps) == 0 {
return nil, errors.Errorf("no profiles found with supported protocols")
}
return &ps, nil
}
+func Profile(srv Server, id string) error {
+ b, err := srv.Base()
+ if err != nil {
+ return err
+ }
+ if !b.Profiles.Has(id) {
+ return errors.Errorf("no profile available with id: %s", id)
+ }
+ b.Profiles.Current = id
+ return nil
+}
+
type ConfigData struct {
// The configuration
Config string
@@ -120,7 +113,18 @@ type ConfigData struct {
Tokens oauth.Token
}
-func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
+// Public gets the public data from the types package
+// dg specifies if this config is default gateway
+func (c *ConfigData) Public(dg bool) srvtypes.Configuration {
+ return srvtypes.Configuration{
+ VPNConfig: c.Config,
+ Protocol: protocol.New(c.Type),
+ DefaultGateway: dg,
+ Tokens: c.Tokens.Public(),
+ }
+}
+
+func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
return nil, err
@@ -133,7 +137,7 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi
}
pub := key.PublicKey().String()
- cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport)
+ cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport)
if err != nil {
return nil, err
}
@@ -159,13 +163,13 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi
return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil
}
-func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
+func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
return nil, err
}
pid := b.Profiles.Current
- cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
+ cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP)
if err != nil {
return nil, err
}
@@ -184,15 +188,14 @@ func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil
}
-func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
- // Get new profiles using the info call
- // This does not override the current profile
- err := APIInfo(srv)
+func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) {
+ b, err := srv.Base()
if err != nil {
return false, err
}
-
- b, err := srv.Base()
+ // Get new profiles using the info call
+ // This does not override the current profile
+ err = api.Info(ctx, b, srv.OAuth())
if err != nil {
return false, err
}
@@ -225,7 +228,18 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
return true, nil
}
-func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
+func RefreshEndpoints(ctx context.Context, srv Server) error {
+ // Re-initialize the endpoints
+ // TODO: Make this a warning instead?
+ b, err := srv.Base()
+ if err != nil {
+ return err
+ }
+
+ return api.Endpoints(ctx, b)
+}
+
+func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
p, err := CurrentProfile(server)
if err != nil {
return nil, err
@@ -250,10 +264,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData,
case wg:
// A wireguard connect call needs to generate a wireguard key and add it to the config
// Also the server could send back an OpenVPN config if it supports OpenVPN
- cfg, err = wireguardGetConfig(server, preferTCP, ovpn)
+ cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn)
// The config only supports OpenVPN
case ovpn:
- cfg, err = openVPNGetConfig(server, preferTCP)
+ cfg, err = openVPNGetConfig(ctx, server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
return nil, errors.New("no supported protocol found")
@@ -267,6 +281,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData,
return cfg, err
}
-func Disconnect(server Server) error {
- return APIDisconnect(server)
+func Disconnect(ctx context.Context, server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
+ }
+ return api.Disconnect(ctx, b, server.OAuth())
}