summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-03-22 12:16:54 +0100
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2023-09-25 09:43:37 +0200
commiteb57e36d3c017bef80277e97db1009c38893ce2d (patch)
tree11ae9fa5e75492690e3db4bde349e2accc3fa1c9 /client
parentf5fe3d75801830ab9f1d380f5b3238b9006cf48b (diff)
Exports + Client: Refactor registering a client
- Make sure the global exports state is only set on successful creating - Only call discovery when adding a server to ensure we get the most up to date args. Creating a client should have no network calls. Fixes #12 - Split creating a client in New and Register in the GO api
Diffstat (limited to 'client')
-rw-r--r--client/client.go59
-rw-r--r--client/client_test.go67
-rw-r--r--client/server.go20
3 files changed, 81 insertions, 65 deletions
diff --git a/client/client.go b/client/client.go
index 6df078d..70adb71 100644
--- a/client/client.go
+++ b/client/client.go
@@ -14,9 +14,9 @@ import (
"github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server"
- srvtypes "github.com/eduvpn/eduvpn-common/types/server"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
"github.com/eduvpn/eduvpn-common/types/protocol"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
)
@@ -121,36 +121,23 @@ type Client struct {
profileWg sync.WaitGroup
}
-// Register initializes the clientwith the following parameters:
+// New creates a new client with the following parameters:
// - name: the name of the client
// - directory: the directory where the config files are stored. Absolute or relative
// - stateCallback: the callback function for the FSM that takes two states (old and new) and the data as an interface
// - debug: whether or not we want to enable debugging
//
// It returns an error if initialization failed, for example when discovery cannot be obtained and when there are no servers.
-func (c *Client) Register(
- name string,
- version string,
- directory string,
- stateCallback func(FSMStateID, FSMStateID, interface{}) bool,
- debug bool,
-) (err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- if !c.InFSMState(StateDeregistered) {
- return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current)
- }
+func New(name string, version string, directory string, stateCallback func(FSMStateID, FSMStateID, interface{}) bool, debug bool) (c *Client, err error) {
+ // We create the client by filling fields one by one
+ c = &Client{}
if !isAllowedClientID(name) {
- return errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name)
+ return nil, errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name)
}
if len([]rune(version)) > 20 {
- return errors.Errorf("version is not allowed: '%s', must be max 20 characters", version)
+ return nil, errors.Errorf("version is not allowed: '%s', must be max 20 characters", version)
}
// Initialize the logger
@@ -160,7 +147,7 @@ func (c *Client) Register(
}
if err = log.Logger.Init(lvl, directory); err != nil {
- return err
+ return nil, err
}
// set client name
@@ -187,23 +174,15 @@ func (c *Client) Register(
log.Logger.Infof("Previous configuration not found")
}
- // Go to the No Server state after we're done
- defer c.FSM.GoTransition(StateNoServer)
-
- // Let's Connect! doesn't care about discovery
- if c.isLetsConnect() {
- return nil
- }
-
- // Check if we are able to fetch discovery, and log if something went wrong
- if _, err := c.DiscoServers(); err != nil {
- log.Logger.Warningf("Failed to get discovery servers: %v", err)
- }
+ return c, nil
+}
- if _, err := c.DiscoOrganizations(); err != nil {
- log.Logger.Warningf("Failed to get discovery organizations: %v", err)
+// Registering means updating the FSM to get to the initial state correctly
+func (c *Client) Register() error {
+ if !c.InFSMState(StateDeregistered) {
+ return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current)
}
-
+ c.FSM.GoTransition(StateNoServer)
return nil
}
@@ -386,8 +365,8 @@ func (c *Client) ServerList() (*srvtypes.List, error) {
}
cc := c.Servers.SecureInternetHomeServer.CurrentLocation
secureInternet = &srvtypes.SecureInternet{
- Server: generic,
- CountryCode: cc,
+ Server: generic,
+ CountryCode: cc,
// TODO: delisted
Delisted: false,
}
@@ -449,8 +428,8 @@ func (c *Client) CurrentServer() (*srvtypes.Current, error) {
cc := c.Servers.SecureInternetHomeServer.CurrentLocation
return &srvtypes.Current{
SecureInternet: &srvtypes.SecureInternet{
- Server: generic,
- CountryCode: cc,
+ Server: generic,
+ CountryCode: cc,
// TODO: delisted
Delisted: false,
},
diff --git a/client/client_test.go b/client/client_test.go
index 772be4b..56c38ff 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -77,7 +77,7 @@ func TestServer(t *testing.T) {
serverURI := getServerURI(t)
state := &Client{}
- registerErr := state.Register(
+ state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configstest",
@@ -87,10 +87,15 @@ func TestServer(t *testing.T) {
},
false,
)
- if registerErr != nil {
- t.Fatalf("Register error: %v", registerErr)
+ if err != nil {
+ t.Fatalf("Creating client error: %v", err)
+ }
+ err = state.Register()
+ if err != nil {
+ t.Fatalf("Registering error: %v", err)
}
+
addErr := state.AddCustomServer(serverURI)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
@@ -110,7 +115,7 @@ func testConnectOAuthParameter(
state := &Client{}
configDirectory := "test_oauth_parameters"
- registerErr := state.Register(
+ state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
configDirectory,
@@ -152,11 +157,15 @@ func testConnectOAuthParameter(
},
false,
)
- if registerErr != nil {
- t.Fatalf("Register error: %v", registerErr)
+ if err != nil {
+ t.Fatalf("Creating client error: %v", err)
+ }
+ err = state.Register()
+ if err != nil {
+ t.Fatalf("Registering error: %v", err)
}
- err := state.AddCustomServer(serverURI)
+ err = state.AddCustomServer(serverURI)
if errPrefix == "" {
if err != nil {
@@ -240,7 +249,7 @@ func TestTokenExpired(t *testing.T) {
// Get a vpn state
state := &Client{}
- registerErr := state.Register(
+ state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configsexpired",
@@ -250,8 +259,12 @@ func TestTokenExpired(t *testing.T) {
},
false,
)
- if registerErr != nil {
- t.Fatalf("Register error: %v", registerErr)
+ if err != nil {
+ t.Fatalf("Creating client error: %v", err)
+ }
+ err = state.Register()
+ if err != nil {
+ t.Fatalf("Registering error: %v", err)
}
addErr := state.AddCustomServer(serverURI)
@@ -302,7 +315,7 @@ func TestInvalidProfileCorrected(t *testing.T) {
serverURI := getServerURI(t)
state := &Client{}
- registerErr := state.Register(
+ state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configscancelprofile",
@@ -312,8 +325,12 @@ func TestInvalidProfileCorrected(t *testing.T) {
},
false,
)
- if registerErr != nil {
- t.Fatalf("Register error: %v", registerErr)
+ if err != nil {
+ t.Fatalf("Creating client error: %v", err)
+ }
+ err = state.Register()
+ if err != nil {
+ t.Fatalf("Registering error: %v", err)
}
addErr := state.AddCustomServer(serverURI)
@@ -359,7 +376,7 @@ func TestPreferTCP(t *testing.T) {
serverURI := getServerURI(t)
state := &Client{}
- registerErr := state.Register(
+ state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configsprefertcp",
@@ -369,8 +386,12 @@ func TestPreferTCP(t *testing.T) {
},
false,
)
- if registerErr != nil {
- t.Fatalf("Register error: %v", registerErr)
+ if err != nil {
+ t.Fatalf("Creating client error: %v", err)
+ }
+ err = state.Register()
+ if err != nil {
+ t.Fatalf("Registering error: %v", err)
}
addErr := state.AddCustomServer(serverURI)
@@ -419,28 +440,26 @@ func TestInvalidClientID(t *testing.T) {
}
for k, v := range tests {
- state := &Client{}
- registerErr := state.Register(
+ _, err := New(
k,
"0.1.0-test",
"configsclientid",
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, old, new, data, state)
return true
},
false,
)
if v {
- if registerErr != nil {
- t.Fatalf("expected valid register with clientID: %v, got error: %v", k, registerErr)
+ if err != nil {
+ t.Fatalf("expected valid register with clientID: %v, got error: %v", k, err)
}
continue
}
- if registerErr == nil {
+ if err == nil {
t.Fatalf("expected invalid register with clientID: %v, but got no error", k)
}
- if !strings.HasPrefix(registerErr.Error(), "client ID is not allowed") {
- t.Fatalf("register error has invalid prefix: %v", registerErr.Error())
+ if !strings.HasPrefix(err.Error(), "client ID is not allowed") {
+ t.Fatalf("register error has invalid prefix: %v", err.Error())
}
}
}
diff --git a/client/server.go b/client/server.go
index 6c0b4d2..7a14d00 100644
--- a/client/server.go
+++ b/client/server.go
@@ -9,8 +9,8 @@ import (
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
- srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/eduvpn/eduvpn-common/types/protocol"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
)
@@ -264,6 +264,15 @@ func (c *Client) AddInstituteServer(url string) (err error) {
// Indicate that we're loading the server
c.FSM.GoTransition(StateLoadingServer)
+ // Check if we are able to fetch discovery, and log if something went wrong
+ if _, err := c.DiscoServers(); err != nil {
+ log.Logger.Warningf("Failed to get discovery servers: %v", err)
+ }
+
+ if _, err := c.DiscoOrganizations(); err != nil {
+ log.Logger.Warningf("Failed to get discovery organizations: %v", err)
+ }
+
// FIXME: Do nothing with discovery here as the client already has it
// So pass a server as the parameter
var dSrv *discotypes.Server
@@ -317,6 +326,15 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) {
// Indicate that we're loading the server
c.FSM.GoTransition(StateLoadingServer)
+ // Check if we are able to fetch discovery, and log if something went wrong
+ if _, err := c.DiscoServers(); err != nil {
+ log.Logger.Warningf("Failed to get discovery servers: %v", err)
+ }
+
+ if _, err := c.DiscoOrganizations(); err != nil {
+ log.Logger.Warningf("Failed to get discovery organizations: %v", err)
+ }
+
// Get the secure internet URL from discovery
org, dSrv, err := c.Discovery.SecureHomeArgs(orgID)
if err != nil {