summaryrefslogtreecommitdiff
path: root/internal/server/servers.go
blob: 43716a484510df85cd896c0d753c0f58cbcb29c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package server

import (
	"context"
	"errors"
	"fmt"

	"codeberg.org/eduVPN/eduvpn-common/internal/config/v2"
	"codeberg.org/eduVPN/eduvpn-common/internal/discovery"
	"codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi"
	srvtypes "codeberg.org/eduVPN/eduvpn-common/types/server"
	"codeberg.org/jwijenbergh/eduoauth-go/v2"
)

// Callbacks defines the interface for doing certain callback operations
// that are needed by the Servers manager
type Callbacks interface {
	// eduvpnapi.Callbacks is the embedded API callback interface
	eduvpnapi.Callbacks
	// GettingConfig is called by ConnectWithCallbacks before each connect attempt
	// and before a replacement profile is requested.
	// A non-nil error aborts the connect flow
	GettingConfig() error
	// InvalidProfile is called when connecting fails with ErrInvalidProfile.
	// The returned string is set as the new profile ID of the server,
	// a non-nil error aborts the connect flow
	InvalidProfile(context.Context, *Server) (string, error)
}

// Servers is the main struct that contains information for configuring the servers
//
// Its fields are set by NewServers:
//   - clientID is the name that was passed to NewServers
//   - cb are the callbacks that ConnectWithCallbacks triggers
//   - config is the v2 configuration that the methods on Servers delegate to
type Servers struct {
	clientID string
	cb       Callbacks
	config   *v2.V2
}

// Remove removes a server with id `identifier` and type `t` from the configuration
// It returns an error if no configuration is available,
// otherwise it returns the result of the configuration's RemoveServer
func (s *Servers) Remove(identifier string, t srvtypes.Type) error {
	// Guard against a missing configuration the same way GetServer does
	// instead of calling into a nil config
	if s.config == nil {
		return errors.New("no configuration available")
	}
	return s.config.RemoveServer(identifier, t)
}

// NewServers builds a Servers value out of the client name,
// the callbacks to trigger and the configuration to operate on
func NewServers(name string, cb Callbacks, cfg *v2.V2) Servers {
	var srvs Servers
	srvs.clientID = name
	srvs.cb = cb
	srvs.config = cfg
	return srvs
}

// CurrentServer contains the information for the current active server
// Values of this type are created by the CurrentServer method on Servers
type CurrentServer struct {
	// it embeds the state file server
	*v2.Server
	// Key is the server key
	// ServerWithCallbacks uses its T field to select the lookup and its ID field as the lookup argument
	Key v2.ServerKey
	// srvs refers to the original servers manager
	// that the lookups in ServerWithCallbacks are performed on
	srvs *Servers
}

// ServerWithCallbacks turns the current server into a server struct and triggers callbacks as needed
// The lookup on the servers manager is selected by the type in the server key,
// a type that is not institute access, secure internet or custom results in an error
func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
	id, typ := cs.Key.ID, cs.Key.T
	if typ == srvtypes.TypeInstituteAccess {
		return cs.srvs.GetInstitute(ctx, id, disco, tokens, disableAuth)
	}
	if typ == srvtypes.TypeSecureInternet {
		return cs.srvs.GetSecure(ctx, id, disco, tokens, disableAuth)
	}
	if typ == srvtypes.TypeCustom {
		// custom servers are looked up without discovery
		return cs.srvs.GetCustom(ctx, id, tokens, disableAuth)
	}
	return nil, fmt.Errorf("no such server type: %d", typ)
}

// GetServer looks up the server with the given id and type in the state file
// An error is returned when no configuration is available
func (s *Servers) GetServer(id string, t srvtypes.Type) (*v2.Server, error) {
	cfg := s.config
	if cfg == nil {
		err := errors.New("no configuration available")
		return nil, err
	}
	srv, err := cfg.GetServer(id, t)
	return srv, err
}

// CurrentServer gets the current server from the state file and wraps it,
// together with its key and a reference back to the servers manager, into a neat type
// It returns an error if no configuration is available, if the configuration
// fails to give a current server or if it gives a server without a key
func (s *Servers) CurrentServer() (*CurrentServer, error) {
	// Guard against a missing configuration the same way GetServer does
	if s.config == nil {
		return nil, errors.New("no configuration available")
	}
	curr, k, err := s.config.CurrentServer()
	if err != nil {
		return nil, err
	}
	// The key is dereferenced below, report a missing key as an error instead of panicking
	if k == nil {
		return nil, errors.New("current server has no key")
	}
	return &CurrentServer{
		Server: curr,
		Key:    *k,
		srvs:   s,
	}, nil
}

// PublicCurrent gets the current server into a type that we can return to the client
// `disco` is passed on to the configuration unchanged
// It returns an error if no configuration is available,
// otherwise it returns the result of the configuration's PublicCurrent
func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) {
	// Guard against a missing configuration the same way GetServer does
	// instead of calling into a nil config
	if s.config == nil {
		return nil, errors.New("no configuration available")
	}
	return s.config.PublicCurrent(disco)
}

// ConnectWithCallbacks handles the /connect flow
// It calls callbacks as needed
//
// The flow is:
//  1. `srv` is set as the current server
//  2. the GettingConfig callback is called and a connect is attempted
//  3. if the attempt fails with ErrInvalidProfile, the GettingConfig callback is
//     called again, the InvalidProfile callback is asked for a new profile ID,
//     the profile ID is set on `srv`, and after one more GettingConfig callback
//     the connect is retried once
//
// `pTCP` is passed to the connect call unchanged
// Any failing step aborts the flow and its error is returned as is,
// the returned configuration is always nil when the error is non-nil
func (s *Servers) ConnectWithCallbacks(ctx context.Context, srv *Server, pTCP bool) (*srvtypes.Configuration, error) {
	if err := srv.SetCurrent(); err != nil {
		return nil, err
	}
	if err := s.cb.GettingConfig(); err != nil {
		return nil, err
	}
	cfg, err := srv.connect(ctx, pTCP)
	if err == nil {
		return cfg, nil
	}
	if !errors.Is(err, ErrInvalidProfile) {
		// Do not hand out the result of a failed connect
		return nil, err
	}
	if err := s.cb.GettingConfig(); err != nil {
		return nil, err
	}
	// Get a new profile from the callback
	pr, err := s.cb.InvalidProfile(ctx, srv)
	if err != nil {
		// The configuration of the failed first attempt must not leak out here
		return nil, err
	}
	if err := srv.SetProfileID(pr); err != nil {
		return nil, err
	}
	if err := s.cb.GettingConfig(); err != nil {
		return nil, err
	}
	// Retry once with the new profile
	cfg, err = srv.connect(ctx, pTCP)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}