summaryrefslogtreecommitdiff
path: root/client/client.go
blob: e8fc02ce16eb3c6b9fcd48b05ae1b6362923dc4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
//go:generate go run golang.org/x/text/cmd/gotext -srclang=en update -out=zgotext.go -lang=da,de,en,es,fr,it,nl,pt,sl,ukr

// Package client implements the public interface for creating eduVPN/Let's Connect! clients
package client

import (
	"context"
	"errors"
	"sync"
	"time"

	"github.com/eduvpn/eduvpn-common/i18nerr"
	"github.com/eduvpn/eduvpn-common/internal/api"
	"github.com/eduvpn/eduvpn-common/internal/config"
	"github.com/eduvpn/eduvpn-common/internal/discovery"
	"github.com/eduvpn/eduvpn-common/internal/failover"
	"github.com/eduvpn/eduvpn-common/internal/fsm"
	"github.com/eduvpn/eduvpn-common/internal/http"
	"github.com/eduvpn/eduvpn-common/internal/log"
	"github.com/eduvpn/eduvpn-common/internal/server"
	"github.com/eduvpn/eduvpn-common/types/cookie"
	srvtypes "github.com/eduvpn/eduvpn-common/types/server"
	"github.com/jwijenbergh/eduoauth-go"
)

// Client is the main struct for the VPN client.
// It bundles the added servers, the state machine, the persisted
// configuration and the discovery manager behind one public API.
type Client struct {
	// Name is the client ID the client registered with
	Name string

	// Servers contains the servers that are added to this client
	Servers server.Servers

	// FSM is the state machine whose transitions are reported to the client
	FSM fsm.FSM

	// Debug records whether debug logging was requested in New
	Debug bool

	// TokenSetter, when non-nil, is called to hand updated OAuth tokens to the client
	TokenSetter func(sid string, stype srvtypes.Type, tok srvtypes.Tokens)

	// TokenGetter, when non-nil, is called to ask the client for OAuth tokens
	// when they cannot be found in the in-memory cache
	TokenGetter func(sid string, stype srvtypes.Type) *srvtypes.Tokens

	// tokCacher is the in-memory OAuth token cache
	tokCacher TokenCacher

	// cfg is the configuration that is persisted as the state file
	cfg *config.Config

	// mu is held by operations that must not run concurrently
	// (e.g. AddServer, GetConfig and RenewSession)
	mu sync.Mutex

	// discoMan hands out access to the discovery data
	discoMan *discovery.Manager
}

// GettingConfig is defined here to satisfy the server.Callbacks interface.
// It is invoked internally while a configuration is being obtained and moves
// the state machine to StateGettingConfig, unless it is already in that state.
func (c *Client) GettingConfig() error {
	if !c.FSM.InState(StateGettingConfig) {
		_, transErr := c.FSM.GoTransition(StateGettingConfig)
		return transErr
	}
	return nil
}

// InvalidProfile is defined here to satisfy the server.Callbacks interface.
// It is called when a profile is invalid and the user has to pick one:
// it transitions to StateAskProfile with the profiles of srv as data and
// blocks until the client replies through the cookie.
//
// It returns the profile ID chosen by the client, or an error if the profiles
// could not be obtained, the transition failed or no reply was received.
func (c *Client) InvalidProfile(ctx context.Context, srv *server.Server) (string, error) {
	ck := cookie.NewWithContext(ctx)
	prfs, err := srv.Profiles()
	if err != nil {
		return "", err
	}
	// we are guaranteed to have profiles > 0 (even after filtering)
	// because internally this callback is only triggered if there is a choice to make

	// Buffered so the goroutine below can always deliver its error and exit:
	// ck.Receive may return without reading errChan (e.g. once a reply has arrived),
	// in which case a send on an unbuffered channel would block forever.
	errChan := make(chan error, 1)
	go func() {
		if terr := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{
			C:    ck,
			Data: prfs,
		}); terr != nil {
			errChan <- terr
		}
	}()
	pID, err := ck.Receive(errChan)
	if err != nil {
		return "", err
	}

	return pID, nil
}

// goTransition moves the state machine to the state id.
// A failing transition is returned wrapped as an internal error; a transition
// that the client did not handle is only logged at debug level.
func (c *Client) goTransition(id fsm.StateID) error {
	handled, err := c.FSM.GoTransition(id)
	switch {
	case err != nil:
		return i18nerr.WrapInternal(err, "state transition error")
	case !handled:
		log.Logger.Debugf("transition not handled by the client to internal state: '%s'", GetStateName(id))
	}
	return nil
}

// New creates a new client with the following parameters:
//   - name: the client ID; it must be an allowed client ID
//   - version: the version of the client; at most 20 characters (runes)
//   - directory: the directory where the config and log files are stored. Absolute or relative
//   - stateCallback: the callback function for the FSM that takes two states (old and new) and the data as an interface
//   - debug: whether or not we want to enable debugging (debug log level)
//
// It returns an error if the client ID or the version is invalid, or if the log file
// cannot be initialized in directory.
//
// When the client uses discovery, the cached discovery servers are marked as expired,
// and the cached organizations too if no secure internet server is configured.
func New(name string, version string, directory string, stateCallback func(FSMStateID, FSMStateID, interface{}) bool, debug bool) (c *Client, err error) {
	// maxVersionLen is the maximum number of characters (runes) of the client version
	const maxVersionLen = 20

	// Validate the input before doing any work
	if !isAllowedClientID(name) {
		return nil, i18nerr.NewInternalf("The client registered with an invalid client ID: '%v'", name)
	}

	if len([]rune(version)) > maxVersionLen {
		return nil, i18nerr.NewInternalf("The client registered with an invalid version: '%v'", version)
	}

	// Initialize the logger
	lvl := log.LevelInfo
	if debug {
		lvl = log.LevelDebug
	}

	if err = log.Logger.Init(lvl, directory); err != nil {
		return nil, i18nerr.WrapInternalf(err, "The log file with directory: '%s' failed to initialize", directory)
	}

	// register HTTP agent
	http.RegisterAgent(userAgentName(name), version)

	// The FSM and the config are created in the same order as before
	// (composite literal calls are evaluated left to right)
	c = &Client{
		Name:  name,
		FSM:   newFSM(stateCallback),
		Debug: debug,
		cfg:   config.NewFromDirectory(directory),
	}

	// set the servers; the client itself is passed as the server callbacks
	c.Servers = server.NewServers(c.Name, c, c.cfg.V2)

	c.discoMan = discovery.NewManager(c.cfg.Discovery())

	if !c.hasDiscovery() {
		return c, nil
	}

	// Mark the cached discovery entries as expired (presumably so they are re-fetched on next use)
	disco, release := c.discoMan.Discovery(true)
	defer release()
	disco.MarkServersExpired()
	if !c.cfg.HasSecureInternet() {
		disco.MarkOrganizationsExpired()
	}

	return c, nil
}

// TriggerAuth is called when authorization is triggered.
// This function satisfies the server.Callbacks interface.
//
// It transitions to StateOAuthStarted:
//   - if wait is false (desktop clients), url itself is the transition data and
//     an empty string is returned;
//   - if wait is true, the transition data is a RequiredAskTransition carrying url
//     and a cookie, and TriggerAuth blocks until the client replies through that
//     cookie, returning the reply.
func (c *Client) TriggerAuth(ctx context.Context, url string, wait bool) (string, error) {
	if !wait {
		// Normal authorization (desktop clients)
		if err := c.FSM.GoTransitionRequired(StateOAuthStarted, url); err != nil {
			return "", err
		}
		return "", nil
	}

	// Get a reply from the client
	ck := cookie.NewWithContext(ctx)
	// Buffered so the goroutine below can always deliver its error and exit:
	// ck.Receive may return without reading errChan (e.g. once a reply has arrived),
	// in which case a send on an unbuffered channel would block forever.
	errChan := make(chan error, 1)
	go func() {
		if err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{
			C:    ck,
			Data: url,
		}); err != nil {
			errChan <- err
		}
	}()
	g, err := ck.Receive(errChan)
	if err != nil {
		return "", err
	}
	return g, nil
}

// AuthDone is called when authorization is done.
// This is defined to satisfy the server.Callbacks interface.
// It records the current time as the last authorization time of the server
// (when the server can be found), moves the state machine back to StateMain
// and saves the state file.
func (c *Client) AuthDone(id string, t srvtypes.Type) {
	if srv, getErr := c.Servers.GetServer(id, t); getErr == nil {
		srv.LastAuthorizeTime = time.Now()
	}
	if _, transErr := c.FSM.GoTransition(StateMain); transErr != nil {
		log.Logger.Debugf("unhandled auth done main transition: %v", transErr)
	}
	c.TrySave()
}

// TokensUpdated is called when tokens are updated.
// This is defined to satisfy the server.Callbacks interface.
// Tokens without an access token are ignored. Otherwise they are stored in the
// in-memory cache (a failure is only logged) and, if a TokenSetter is
// registered, forwarded to the client with the expiry as a Unix timestamp.
func (c *Client) TokensUpdated(id string, t srvtypes.Type, tok eduoauth.Token) {
	if len(tok.Access) == 0 {
		return
	}

	if cacheErr := c.tokCacher.Set(id, t, tok); cacheErr != nil {
		log.Logger.Warningf("failed to set tokens into cache with error: %v", cacheErr)
	}

	if setter := c.TokenSetter; setter != nil {
		setter(id, t, srvtypes.Tokens{
			Access:  tok.Access,
			Refresh: tok.Refresh,
			Expires: tok.ExpiredTimestamp.Unix(),
		})
	}
}

// Register moves the FSM to its initial main state.
// A failing transition is returned as an internal error.
func (c *Client) Register() error {
	return c.goTransition(StateMain)
}

// Deregister 'deregisters' the client. It calls Cancel on the discovery manager,
// saves the state file, moves the state machine to StateDeregistered, closes the
// log file and finally resets the client struct to its zero value.
// The discovery stays acquired from before the save until after the log file is closed.
func (c *Client) Deregister() {
	c.discoMan.Cancel()

	_, releaseDisco := c.discoMan.Discovery(false)

	c.TrySave()

	if _, transErr := c.FSM.GoTransition(StateDeregistered); transErr != nil {
		log.Logger.Debugf("failed deregistered transition: %v", transErr)
	}

	// Closing the log is best effort: there is nowhere left to report a failure
	_ = log.Logger.Close()
	releaseDisco()

	*c = Client{}
}

// ExpiryTimes returns the different Unix timestamps regarding expiry of the current server:
//   - StartTime / EndTime: the last authorization time and the expiry time
//   - ButtonTime: the time starting at which the renew button should be shown,
//     after 30 minutes and less than 24 hours
//   - CountdownTime: the time starting at which the countdown should be shown, less than 24 hours
//   - NotificationTimes: the list of times where notifications should be shown
//
// These times are reset when the VPN gets disconnected.
func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
	srv, err := c.Servers.CurrentServer()
	if err != nil {
		return nil, i18nerr.WrapInternal(err, "The current server was not found when getting the VPN expiration date")
	}

	start, end := srv.LastAuthorizeTime, srv.ExpireTime
	expiry := &srvtypes.Expiry{
		StartTime:         start.Unix(),
		EndTime:           end.Unix(),
		ButtonTime:        server.RenewButtonTime(start, end),
		CountdownTime:     server.CountdownTime(start, end),
		NotificationTimes: server.NotificationTimes(start, end),
	}
	return expiry, nil
}

// locationCallback asks the client to choose a secure internet location for the
// secure internet server with organization ID orgID.
// It transitions to StateAskLocation with the secure location list from discovery,
// waits for the reply through ck, stores the chosen country code on the server
// and saves the state file.
func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error {
	disco, release := c.discoMan.Discovery(false)
	locs := disco.SecureLocationList()
	release()

	// Buffered so the goroutine below can always deliver its error and exit:
	// ck.Receive may return without reading errChan (e.g. once a reply has arrived),
	// in which case a send on an unbuffered channel would block forever.
	errChan := make(chan error, 1)
	go func() {
		if terr := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{
			C:    ck,
			Data: locs,
		}); terr != nil {
			errChan <- terr
		}
	}()
	loc, err := ck.Receive(errChan)
	if err != nil {
		return err
	}
	srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet)
	if err != nil {
		return err
	}
	srv.CountryCode = loc
	c.TrySave()
	return nil
}

// TrySave tries to save the internal state file.
// It never returns an error: a missing config or a failed save is logged as a warning.
func (c *Client) TrySave() {
	log.Logger.Debugf("saving state file")
	if c.cfg != nil {
		if saveErr := c.cfg.Save(); saveErr != nil {
			log.Logger.Warningf("failed to save state file: %v", saveErr)
		}
		return
	}
	log.Logger.Warningf("no state file to save")
}

// AddServer adds a server with the given identifier and type.
//
// Parameters:
//   - ck: the cookie whose context is used for adding the server
//   - identifier: the server identifier; for non secure internet servers it is
//     normalized into an HTTPS URL (see convertIdentifier)
//   - _type: the server type; only custom servers can be added when the client does not use discovery
//   - ot: when non-nil, the server is added non-interactively and ot is passed on to the
//     Add* call (presumably an OAuth start time — TODO confirm)
//
// When adding interactively, the state machine goes to StateAddingServer and afterwards
// back to the state it was in. The state file is saved when the server was added successfully.
func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ot *int64) (err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.hasDiscovery() && _type != srvtypes.TypeCustom {
		return i18nerr.NewInternalf("Adding a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %v", identifier, _type)
	}
	// we are non-interactive if oauth time is non-nil
	ni := ot != nil
	// remember the state so the deferred function can return to it
	previousState := c.FSM.Current

	defer func() {
		if err == nil {
			c.TrySave()
		}
		// If we must run callbacks, go to the previous state if we're not in it
		if !ni && !c.FSM.InState(previousState) {
			c.FSM.GoTransition(previousState) //nolint:errcheck
		}
	}()

	if !ni {
		// goTransition already wraps the error in an UI error
		if err = c.goTransition(StateAddingServer); err != nil {
			return err
		}
	}

	// Normalize the identifier the same way GetConfig and RemoveServer do;
	// secure internet organization IDs are left unchanged.
	identifier, err = c.convertIdentifier(identifier, _type)
	if err != nil {
		return err
	}

	switch _type {
	case srvtypes.TypeInstituteAccess:
		err = c.Servers.AddInstitute(ck.Context(), c.discoMan, identifier, ot)
		if err != nil {
			return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier)
		}
	case srvtypes.TypeSecureInternet:
		err = c.Servers.AddSecure(ck.Context(), c.discoMan, identifier, ot)
		if err != nil {
			return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier)
		}
	case srvtypes.TypeCustom:
		err = c.Servers.AddCustom(ck.Context(), identifier, ot)
		if err != nil {
			return i18nerr.Wrapf(err, "Failed to add a server with URL: '%s'", identifier)
		}
	default:
		return i18nerr.NewInternalf("Failed to add server type: '%v'", _type)
	}
	return nil
}

// convertIdentifier normalizes the server identifier for a server of type t.
// Secure internet identifiers are returned unchanged. All other identifiers must be
// valid URLs and are converted with http.EnsureValidURL, which also converts the
// scheme to HTTPS. On failure the error quotes the identifier as it was passed in.
func (c *Client) convertIdentifier(identifier string, t srvtypes.Type) (string, error) {
	// assume secure internet identifiers are always valid as we can't really assume they are valid urls (+ always https)
	if t == srvtypes.TypeSecureInternet {
		return identifier, nil
	}
	// Convert into a separate variable: assigning back to identifier would make the
	// error message below show the (possibly empty) failed result instead of the input.
	converted, err := http.EnsureValidURL(identifier, true)
	if err != nil {
		return "", i18nerr.Wrapf(err, "The input: '%s' is not a valid URL", identifier)
	}
	return converted, nil
}

// GetConfig gets a VPN configuration for the server with the given identifier and type.
//
// Parameters:
//   - ck: the cookie whose context is used for all requests; it is also used to ask
//     the client for a secure internet location when one is needed
//   - identifier: the server identifier; for non secure internet servers it is
//     normalized into an HTTPS URL (see convertIdentifier)
//   - _type: the server type (institute access, secure internet or custom)
//   - pTCP: passed on to ConnectWithCallbacks (presumably "prefer TCP" — TODO confirm)
//   - startup: whether the client is autoconnecting on startup; passed on to the
//     server getters and used to word the returned errors
//
// Calls are serialized with c.mu. Once past the discovery check, the state file is
// saved on return; on success the state machine ends up in StateGotConfig, on failure
// it is moved back to the state it was in when GetConfig was called.
func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (cfg *srvtypes.Configuration, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	previousState := c.FSM.Current

	if !c.hasDiscovery() && _type != srvtypes.TypeCustom {
		return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type)
	}

	// Runs on every return below this point and inspects the named result err
	defer func() {
		c.TrySave()
		if err == nil {
			// it could be that we are not in getting config yet if we have just done authorization
			c.FSM.GoTransition(StateGettingConfig) //nolint:errcheck
			c.FSM.GoTransition(StateGotConfig)     //nolint:errcheck
		} else if !c.FSM.InState(previousState) {
			// go back to the previous state if an error occurred
			c.FSM.GoTransition(previousState) //nolint:errcheck
		}
	}()

	identifier, err = c.convertIdentifier(identifier, _type)
	if err != nil {
		return nil, err
	}
	// A failed transition is only logged here; on success the deferred
	// function transitions through StateGettingConfig again anyway
	err = c.GettingConfig()
	if err != nil {
		log.Logger.Debugf("failed getting config transition: %v", err)
	}

	// Missing tokens are not fatal: tok is passed on as-is to the server getters
	// (presumably nil tokens lead to a new authorization — TODO confirm)
	tok, err := c.retrieveTokens(identifier, _type)
	if err != nil {
		log.Logger.Debugf("no tokens found for server: '%s', with error: '%v'", identifier, err)
	}

	ctx := ck.Context()
	if _type != srvtypes.TypeCustom {
		disco, release := c.discoMan.Discovery(true)
		// make sure the servers are fetched fresh
		_, _, dserverr := disco.Servers(ctx)
		if dserverr != nil {
			log.Logger.Warningf("failed to fetch server discovery when getting config: %v", dserverr)
		}
		release()
	}

	var srv *server.Server
	switch _type {
	case srvtypes.TypeInstituteAccess:
		srv, err = c.Servers.GetInstitute(ctx, identifier, c.discoMan, tok, startup)
	case srvtypes.TypeSecureInternet:
		disco, release := c.discoMan.Discovery(true)
		// make sure the organizations are fetched if they need an update
		_, _, dorgerr := disco.Organizations(ctx)
		if dorgerr != nil {
			log.Logger.Warningf("failed to fetch organization discovery when getting config: %v", dorgerr)
		}
		release()
		srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup)

		// On ErrCountryNotFound ask the client for a location and retry once
		var cErr *discovery.ErrCountryNotFound
		if errors.As(err, &cErr) {
			err = c.locationCallback(ck, identifier)
			if err == nil {
				srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup)
			}
		}
	case srvtypes.TypeCustom:
		srv, err = c.Servers.GetCustom(ctx, identifier, tok, startup)
	default:
		err = i18nerr.NewInternalf("Server type: '%v' is not valid to get a config for", _type)
	}
	if err != nil {
		if startup {
			if errors.Is(err, api.ErrAuthorizeDisabled) {
				return nil, i18nerr.Newf("The client tried to autoconnect to the VPN server: '%s', but you need to authorizate again. Please manually connect again.", identifier)
			}
			return nil, i18nerr.Wrapf(err, "The client tried to autoconnect to the VPN server: '%s', but the operation failed to complete", identifier)
		}
		return nil, i18nerr.Wrapf(err, "Failed to connect to server: '%s'", identifier)
	}

	cfg, err = c.Servers.ConnectWithCallbacks(ck.Context(), srv, pTCP)
	if err != nil {
		return nil, i18nerr.Wrapf(err, "Failed to obtain a VPN configuration for server: '%s'", identifier)
	}
	return cfg, nil
}

// RemoveServer removes the server with the given identifier and type and saves
// the state file. Removing a secure internet server also marks the cached
// discovery organizations as expired.
func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) {
	if identifier, err = c.convertIdentifier(identifier, _type); err != nil {
		return err
	}
	if removeErr := c.Servers.Remove(identifier, _type); removeErr != nil {
		return i18nerr.WrapInternalf(removeErr, "Failed to remove server: '%s'", identifier)
	}

	// The discovery is acquired before saving and only released on return,
	// so the state file is written while it is held
	disco, release := c.discoMan.Discovery(true)
	defer release()
	if _type == srvtypes.TypeSecureInternet {
		disco.MarkOrganizationsExpired()
	}
	c.TrySave()
	return nil
}

// CurrentServer returns the public representation of the server that is currently configured.
func (c *Client) CurrentServer() (*srvtypes.Current, error) {
	current, err := c.Servers.PublicCurrent(c.discoMan)
	if err == nil {
		return current, nil
	}
	return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved")
}

// SetProfileID sets the profile ID pID as the current profile of the current
// server and saves the state file.
func (c *Client) SetProfileID(pID string) error {
	current, lookupErr := c.Servers.CurrentServer()
	if lookupErr != nil {
		return i18nerr.WrapInternalf(lookupErr, "Failed to set the profile ID: '%s'", pID)
	}
	current.Profiles.Current = pID
	c.TrySave()
	return nil
}

// retrieveTokens looks up the OAuth tokens for the server with ID sid and type t.
// The in-memory cache is tried first. If that fails and no TokenGetter is
// registered, the cache result and its error are returned unchanged; otherwise
// the client is asked through TokenGetter, and a nil answer is an error.
func (c *Client) retrieveTokens(sid string, t srvtypes.Type) (*eduoauth.Token, error) {
	cached, cacheErr := c.tokCacher.Get(sid, t)
	switch {
	case cacheErr == nil:
		return cached, nil
	case c.TokenGetter == nil:
		return cached, cacheErr
	}

	fromClient := c.TokenGetter(sid, t)
	if fromClient == nil {
		return nil, errors.New("client returned nil tokens")
	}
	tok := &eduoauth.Token{
		Access:           fromClient.Access,
		Refresh:          fromClient.Refresh,
		ExpiredTimestamp: time.Unix(fromClient.Expires, 0),
	}
	return tok, nil
}

// Cleanup cleans up the VPN connection by sending a /disconnect to the current server,
// using the cached or client-provided OAuth tokens.
// The state file is saved on return, whatever the outcome.
func (c *Client) Cleanup(ck *cookie.Cookie) error {
	defer c.TrySave()

	current, err := c.Servers.CurrentServer()
	if err != nil {
		return i18nerr.WrapInternal(err, "The current server was not found when cleaning up the connection")
	}
	tokens, err := c.retrieveTokens(current.Key.ID, current.Key.T)
	if err != nil {
		return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection")
	}

	ctx := ck.Context()
	authed, err := current.ServerWithCallbacks(ctx, c.discoMan, tokens, true)
	if err != nil {
		return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection")
	}
	if err = authed.Disconnect(ctx); err != nil {
		return i18nerr.WrapInternal(err, "Failed to cleanup the VPN connection")
	}
	return nil
}

// SetSecureLocation sets the secure internet location of the server with
// organization ID orgID to the country code countryCode and saves the state file.
// If a profile was cached for that location, it becomes the current profile again.
func (c *Client) SetSecureLocation(orgID string, countryCode string) error {
	// not supported with Let's Connect! & govVPN
	if !c.hasDiscovery() {
		return i18nerr.NewInternal("Setting a secure internet location with this client ID is not supported")
	}
	srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet)
	if err != nil {
		return i18nerr.WrapInternalf(err, "Failed to get the secure internet server with id: '%s' for setting a location", orgID)
	}
	srv.CountryCode = countryCode

	// Restore the profile cached for this location. A lookup in a nil map is
	// valid and simply reports that the key is absent.
	if cachedProfile, found := srv.LocationProfiles[srv.CountryCode]; found {
		srv.Profiles.Current = cachedProfile
	}

	c.TrySave()
	return nil
}

// RenewSession is called when the user clicks on the renew session button.
// It re-authorizes the current server by getting it without passing tokens.
// On failure the state machine is moved back to the state it was in (if it
// left it) and an internal error is returned.
func (c *Client) RenewSession(ck *cookie.Cookie) error {
	// Take the lock before looking up the current server, like AddServer and
	// GetConfig do, so the lookup and the re-authorization happen under one lock.
	c.mu.Lock()
	defer c.mu.Unlock()

	srv, err := c.Servers.CurrentServer()
	if err != nil {
		return i18nerr.WrapInternal(err, "The current server could not be retrieved when renewing the session")
	}
	previousState := c.FSM.Current

	// getting a server with no tokens means re-authorize
	_, err = srv.ServerWithCallbacks(ck.Context(), c.discoMan, nil, false)
	if err != nil {
		// Consistent with AddServer and GetConfig: only go back if we left the previous state
		if !c.FSM.InState(previousState) {
			c.FSM.GoTransition(previousState) //nolint:errcheck
		}
		return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session")
	}
	return nil
}

// StartFailover starts the failover procedure for gateway with the given mtu,
// using readRxBytes (handed to failover.New) to read the received byte count.
// The boolean result of the procedure is returned even when it fails.
func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readRxBytes func() (int64, error)) (bool, error) {
	procedure := failover.New(readRxBytes)

	result, err := procedure.Start(ck.Context(), gateway, mtu)
	if err == nil {
		return result, nil
	}
	return result, i18nerr.WrapInternalf(err, "Failover failed to complete with gateway: '%s' and MTU: '%d'", gateway, mtu)
}

// ServerList returns the public list of added servers, built while the
// discovery is held. The returned error is currently always nil.
func (c *Client) ServerList() (*srvtypes.List, error) {
	disco, release := c.discoMan.Discovery(false)
	defer release()
	return c.cfg.V2.PublicList(disco), nil
}