diff options
Diffstat (limited to 'client')
| -rw-r--r-- | client/client.go | 59 | ||||
| -rw-r--r-- | client/client_test.go | 67 | ||||
| -rw-r--r-- | client/server.go | 20 |
3 files changed, 81 insertions, 65 deletions
diff --git a/client/client.go b/client/client.go index 6df078d..70adb71 100644 --- a/client/client.go +++ b/client/client.go @@ -14,9 +14,9 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) @@ -121,36 +121,23 @@ type Client struct { profileWg sync.WaitGroup } -// Register initializes the clientwith the following parameters: +// New creates a new client with the following parameters: // - name: the name of the client // - directory: the directory where the config files are stored. Absolute or relative // - stateCallback: the callback function for the FSM that takes two states (old and new) and the data as an interface // - debug: whether or not we want to enable debugging // // It returns an error if initialization failed, for example when discovery cannot be obtained and when there are no servers. -func (c *Client) Register( - name string, - version string, - directory string, - stateCallback func(FSMStateID, FSMStateID, interface{}) bool, - debug bool, -) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if !c.InFSMState(StateDeregistered) { - return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) - } +func New(name string, version string, directory string, stateCallback func(FSMStateID, FSMStateID, interface{}) bool, debug bool) (c *Client, err error) { + // We create the client by filling fields one by one + c = &Client{} if !isAllowedClientID(name) { - return errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name) + return nil, errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name) } if len([]rune(version)) > 20 { - return errors.Errorf("version is not allowed: '%s', must be max 20 characters", version) + return nil, errors.Errorf("version is not allowed: '%s', must be max 20 characters", version) } // Initialize the logger @@ -160,7 +147,7 @@ func (c *Client) Register( } if err = log.Logger.Init(lvl, directory); err != nil { - return err + return nil, err } // set client name @@ -187,23 +174,15 @@ func (c *Client) Register( log.Logger.Infof("Previous configuration not found") } - // Go to the No Server state after we're done - defer c.FSM.GoTransition(StateNoServer) - - // Let's Connect! doesn't care about discovery - if c.isLetsConnect() { - return nil - } - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } + return c, nil +} - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) +// Registering means updating the FSM to get to the initial state correctly +func (c *Client) Register() error { + if !c.InFSMState(StateDeregistered) { + return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } - + c.FSM.GoTransition(StateNoServer) return nil } @@ -386,8 +365,8 @@ func (c *Client) ServerList() (*srvtypes.List, error) { } cc := c.Servers.SecureInternetHomeServer.CurrentLocation secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + Server: generic, + CountryCode: cc, // TODO: delisted Delisted: false, } @@ -449,8 +428,8 @@ func (c *Client) CurrentServer() (*srvtypes.Current, error) { cc := c.Servers.SecureInternetHomeServer.CurrentLocation return &srvtypes.Current{ SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + Server: generic, + CountryCode: cc, // TODO: delisted Delisted: false, }, diff --git a/client/client_test.go b/client/client_test.go index 772be4b..56c38ff 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -77,7 +77,7 @@ func TestServer(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configstest", @@ -87,10 +87,15 @@ func TestServer(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } + addErr := state.AddCustomServer(serverURI) if addErr != nil { t.Fatalf("Add error: %v", addErr) @@ -110,7 +115,7 @@ func testConnectOAuthParameter( state := &Client{} configDirectory := "test_oauth_parameters" - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", configDirectory, @@ -152,11 +157,15 @@ func testConnectOAuthParameter( }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } - err := state.AddCustomServer(serverURI) + err = state.AddCustomServer(serverURI) if errPrefix == "" { if err != nil { @@ -240,7 +249,7 @@ func TestTokenExpired(t *testing.T) { // Get a vpn state state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsexpired", @@ -250,8 +259,12 @@ func TestTokenExpired(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -302,7 +315,7 @@ func TestInvalidProfileCorrected(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configscancelprofile", @@ -312,8 +325,12 @@ func TestInvalidProfileCorrected(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -359,7 +376,7 @@ func TestPreferTCP(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsprefertcp", @@ -369,8 +386,12 @@ func TestPreferTCP(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -419,28 +440,26 @@ func TestInvalidClientID(t *testing.T) { } for k, v := range tests { - state := &Client{} - registerErr := state.Register( + _, err := New( k, "0.1.0-test", "configsclientid", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) return true }, false, ) if v { - if registerErr != nil { - t.Fatalf("expected valid register with clientID: %v, got error: %v", k, registerErr) + if err != nil { + t.Fatalf("expected valid register with clientID: %v, got error: %v", k, err) } continue } - if registerErr == nil { + if err == nil { t.Fatalf("expected invalid register with clientID: %v, but got no error", k) } - if !strings.HasPrefix(registerErr.Error(), "client ID is not allowed") { - t.Fatalf("register error has invalid prefix: %v", registerErr.Error()) + if !strings.HasPrefix(err.Error(), "client ID is not allowed") { + t.Fatalf("register error has invalid prefix: %v", err.Error()) } } } diff --git a/client/server.go b/client/server.go index 6c0b4d2..7a14d00 100644 --- a/client/server.go +++ b/client/server.go @@ -9,8 +9,8 @@ import ( "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) @@ -264,6 +264,15 @@ func (c *Client) AddInstituteServer(url string) (err error) { // Indicate that we're loading the server c.FSM.GoTransition(StateLoadingServer) + // Check if we are able to fetch discovery, and log if something went wrong + if _, err := c.DiscoServers(); err != nil { + log.Logger.Warningf("Failed to get discovery servers: %v", err) + } + + if _, err := c.DiscoOrganizations(); err != nil { + log.Logger.Warningf("Failed to get discovery organizations: %v", err) + } + // FIXME: Do nothing with discovery here as the client already has it // So pass a server as the parameter var dSrv *discotypes.Server @@ -317,6 +326,15 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) { // Indicate that we're loading the server c.FSM.GoTransition(StateLoadingServer) + // Check if we are able to fetch discovery, and log if something went wrong + if _, err := c.DiscoServers(); err != nil { + log.Logger.Warningf("Failed to get discovery servers: %v", err) + } + + if _, err := c.DiscoOrganizations(); err != nil { + log.Logger.Warningf("Failed to get discovery organizations: %v", err) + } + // Get the secure internet URL from discovery org, dSrv, err := c.Discovery.SecureHomeArgs(orgID) if err != nil { |
