diff options
| -rw-r--r-- | client/client.go | 59 | ||||
| -rw-r--r-- | client/client_test.go | 67 | ||||
| -rw-r--r-- | client/server.go | 20 | ||||
| -rw-r--r-- | cmd/cli/main.go | 6 | ||||
| -rw-r--r-- | docs/src/api/overview/README.md | 10 | ||||
| -rw-r--r-- | exports/exports.go | 42 | ||||
| -rw-r--r-- | types/discovery/discovery.go | 34 | ||||
| -rw-r--r-- | types/server/server.go | 44 |
8 files changed, 155 insertions, 127 deletions
diff --git a/client/client.go b/client/client.go index 6df078d..70adb71 100644 --- a/client/client.go +++ b/client/client.go @@ -14,9 +14,9 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) @@ -121,36 +121,23 @@ type Client struct { profileWg sync.WaitGroup } -// Register initializes the clientwith the following parameters: +// New creates a new client with the following parameters: // - name: the name of the client // - directory: the directory where the config files are stored. Absolute or relative // - stateCallback: the callback function for the FSM that takes two states (old and new) and the data as an interface // - debug: whether or not we want to enable debugging // // It returns an error if initialization failed, for example when discovery cannot be obtained and when there are no servers. -func (c *Client) Register( - name string, - version string, - directory string, - stateCallback func(FSMStateID, FSMStateID, interface{}) bool, - debug bool, -) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if !c.InFSMState(StateDeregistered) { - return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) - } +func New(name string, version string, directory string, stateCallback func(FSMStateID, FSMStateID, interface{}) bool, debug bool) (c *Client, err error) { + // We create the client by filling fields one by one + c = &Client{} if !isAllowedClientID(name) { - return errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name) + return nil, errors.Errorf("client ID is not allowed: '%v', see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php for a list of allowed IDs", name) } if len([]rune(version)) > 20 { - return errors.Errorf("version is not allowed: '%s', must be max 20 characters", version) + return nil, errors.Errorf("version is not allowed: '%s', must be max 20 characters", version) } // Initialize the logger @@ -160,7 +147,7 @@ func (c *Client) Register( } if err = log.Logger.Init(lvl, directory); err != nil { - return err + return nil, err } // set client name @@ -187,23 +174,15 @@ func (c *Client) Register( log.Logger.Infof("Previous configuration not found") } - // Go to the No Server state after we're done - defer c.FSM.GoTransition(StateNoServer) - - // Let's Connect! doesn't care about discovery - if c.isLetsConnect() { - return nil - } - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } + return c, nil +} - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) +// Registering means updating the FSM to get to the initial state correctly +func (c *Client) Register() error { + if !c.InFSMState(StateDeregistered) { + return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } - + c.FSM.GoTransition(StateNoServer) return nil } @@ -386,8 +365,8 @@ func (c *Client) ServerList() (*srvtypes.List, error) { } cc := c.Servers.SecureInternetHomeServer.CurrentLocation secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + Server: generic, + CountryCode: cc, // TODO: delisted Delisted: false, } @@ -449,8 +428,8 @@ func (c *Client) CurrentServer() (*srvtypes.Current, error) { cc := c.Servers.SecureInternetHomeServer.CurrentLocation return &srvtypes.Current{ SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + Server: generic, + CountryCode: cc, // TODO: delisted Delisted: false, }, diff --git a/client/client_test.go b/client/client_test.go index 772be4b..56c38ff 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -77,7 +77,7 @@ func TestServer(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configstest", @@ -87,10 +87,15 @@ func TestServer(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } + addErr := state.AddCustomServer(serverURI) if addErr != nil { t.Fatalf("Add error: %v", addErr) @@ -110,7 +115,7 @@ func testConnectOAuthParameter( state := &Client{} configDirectory := "test_oauth_parameters" - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", configDirectory, @@ -152,11 +157,15 @@ func testConnectOAuthParameter( }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } - err := state.AddCustomServer(serverURI) + err = state.AddCustomServer(serverURI) if errPrefix == "" { if err != nil { @@ -240,7 +249,7 @@ func TestTokenExpired(t *testing.T) { // Get a vpn state state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsexpired", @@ -250,8 +259,12 @@ func TestTokenExpired(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -302,7 +315,7 @@ func TestInvalidProfileCorrected(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configscancelprofile", @@ -312,8 +325,12 @@ func TestInvalidProfileCorrected(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -359,7 +376,7 @@ func TestPreferTCP(t *testing.T) { serverURI := getServerURI(t) state := &Client{} - registerErr := state.Register( + state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsprefertcp", @@ -369,8 +386,12 @@ func TestPreferTCP(t *testing.T) { }, false, ) - if registerErr != nil { - t.Fatalf("Register error: %v", registerErr) + if err != nil { + t.Fatalf("Creating client error: %v", err) + } + err = state.Register() + if err != nil { + t.Fatalf("Registering error: %v", err) } addErr := state.AddCustomServer(serverURI) @@ -419,28 +440,26 @@ func TestInvalidClientID(t *testing.T) { } for k, v := range tests { - state := &Client{} - registerErr := state.Register( + _, err := New( k, "0.1.0-test", "configsclientid", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) return true }, false, ) if v { - if registerErr != nil { - t.Fatalf("expected valid register with clientID: %v, got error: %v", k, registerErr) + if err != nil { + t.Fatalf("expected valid register with clientID: %v, got error: %v", k, err) } continue } - if registerErr == nil { + if err == nil { t.Fatalf("expected invalid register with clientID: %v, but got no error", k) } - if !strings.HasPrefix(registerErr.Error(), "client ID is not allowed") { - t.Fatalf("register error has invalid prefix: %v", registerErr.Error()) + if !strings.HasPrefix(err.Error(), "client ID is not allowed") { + t.Fatalf("register error has invalid prefix: %v", err.Error()) } } } diff --git a/client/server.go b/client/server.go index 6c0b4d2..7a14d00 100644 --- a/client/server.go +++ b/client/server.go @@ -9,8 +9,8 @@ import ( "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) @@ -264,6 +264,15 @@ func (c *Client) AddInstituteServer(url string) (err error) { // Indicate that we're loading the server c.FSM.GoTransition(StateLoadingServer) + // Check if we are able to fetch discovery, and log if something went wrong + if _, err := c.DiscoServers(); err != nil { + log.Logger.Warningf("Failed to get discovery servers: %v", err) + } + + if _, err := c.DiscoOrganizations(); err != nil { + log.Logger.Warningf("Failed to get discovery organizations: %v", err) + } + // FIXME: Do nothing with discovery here as the client already has it // So pass a server as the parameter var dSrv *discotypes.Server @@ -317,6 +326,15 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) { // Indicate that we're loading the server c.FSM.GoTransition(StateLoadingServer) + // Check if we are able to fetch discovery, and log if something went wrong + if _, err := c.DiscoServers(); err != nil { + log.Logger.Warningf("Failed to get discovery servers: %v", err) + } + + if _, err := c.DiscoOrganizations(); err != nil { + log.Logger.Warningf("Failed to get discovery organizations: %v", err) + } + // Get the secure internet URL from discovery org, dSrv, err := c.Discovery.SecureHomeArgs(orgID) if err != nil { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 9ff880d..ff99a15 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -173,9 +173,8 @@ func getConfig(state *client.Client, url string, srvType ServerTypes) (*srvtypes // Get a config for a single server, Institute Access or Secure Internet. func printConfig(url string, srvType ServerTypes) { - c := &client.Client{} - - err := c.Register( + var c *client.Client + c, err := client.New( "org.eduvpn.app.linux", "1.1.2-cli", "configs", @@ -189,6 +188,7 @@ func printConfig(url string, srvType ServerTypes) { fmt.Printf("Register error: %v", err) return } + _ = c.Register() defer c.Deregister() diff --git a/docs/src/api/overview/README.md b/docs/src/api/overview/README.md index 73ea215..32e0893 100644 --- a/docs/src/api/overview/README.md +++ b/docs/src/api/overview/README.md @@ -141,8 +141,14 @@ func stateCallback(oldState int, newState int, data interface{}) { // do something } -c := client.Client{} -c.Register("org.eduvpn.app.linux", "1.0.0", "/home/eduvpn/.config/eduvpn", stateCallback, true) +c, err := client.New("org.eduvpn.app.linux", "1.0.0", "/home/eduvpn/.config/eduvpn", stateCallback, true) +if err != nil { + // handle error +} +err := c.Register() +if err != nil { + // handle error +} ``` </details> diff --git a/exports/exports.go b/exports/exports.go index 0779104..b0c13e0 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -30,10 +30,7 @@ import ( srvtypes "github.com/eduvpn/eduvpn-common/types/server" ) -var ( - PStateCallback C.StateCB - VPNState *client.Client -) +var VPNState *client.Client func getTokens(tokens *C.char) (t srvtypes.Tokens, err error) { err = json.Unmarshal([]byte(C.GoString(tokens)), &t) @@ -62,13 +59,11 @@ func getReturnData(data interface{}) (string, error) { } func StateCallback( + stateCallback C.StateCB, oldState client.FSMStateID, newState client.FSMStateID, data interface{}, ) bool { - if PStateCallback == nil { - return false - } oldStateC := C.int(oldState) newStateC := C.int(newState) d, err := getReturnData(data) @@ -76,7 +71,7 @@ func StateCallback( return false } dataC := C.CString(d) - handled := C.call_callback(PStateCallback, oldStateC, newStateC, unsafe.Pointer(dataC)) + handled := C.call_callback(stateCallback, oldStateC, newStateC, unsafe.Pointer(dataC)) FreeString(dataC) return handled != C.int(0) } @@ -88,6 +83,8 @@ func getVPNState() (*client.Client, error) { return VPNState, nil } +// Register creates a new client and also registers the FSM to go to the initial state +// //export Register func Register( name *C.char, @@ -100,22 +97,31 @@ func Register( if stateErr == nil { return getCError(errors.New("failed to register, a VPN state is already present")) } - state := &client.Client{} - registerErr := state.Register( + c, err := client.New( C.GoString(name), C.GoString(version), C.GoString(configDirectory), - StateCallback, + func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool { + return StateCallback(stateCallback, old, new, data) + }, debug != 0, ) - // Only update the VPN state if we get no error when registering - if registerErr == nil { - VPNState = state - PStateCallback = stateCallback - return nil + // Only update the state if we get no error + if err == nil { + // Update the global client such that other functions can retrieve it + // TODO: Use a sync.Once or return a CGO handler instead of a global state? + VPNState = c + // finally register the newly created client + err = c.Register() + if err != nil { + // Note: Registering can only fail for non-newly created clients + // We have obtained a fresh copy here + // This error is only there for the Go API where you can call register multiple times on an already client + panic(err) + } } - return getCError(registerErr) + return getCError(err) } //export ExpiryTimes @@ -189,7 +195,7 @@ func AddServer(_type C.int, id *C.char) *C.char { var err error switch t { case int(srvtypes.TypeInstituteAccess): - err = state.AddInstituteServer(C.GoString(id)) + err = state.AddInstituteServer(C.GoString(id)) case int(srvtypes.TypeSecureInternet): err = state.AddSecureInternetHomeServer(C.GoString(id)) case int(srvtypes.TypeCustom): diff --git a/types/discovery/discovery.go b/types/discovery/discovery.go index 5f54721..0d3495a 100644 --- a/types/discovery/discovery.go +++ b/types/discovery/discovery.go @@ -11,36 +11,36 @@ import ( // Defined in URL: "https://disco.eduvpn.org/v2/organization_list.json" type Organizations struct { // Version is the version field. The Go library internally already checks for rollbacks, you can use this for logging - Version uint64 `json:"v"` + Version uint64 `json:"v"` // List is the list/slice of organizations. Omitted if none are there - List []Organization `json:"organization_list,omitempty"` + List []Organization `json:"organization_list,omitempty"` // Timestamp is a timestamp that is internally used by the Go library to keep track of when the organizations was last updated // You can also use this for logging - Timestamp time.Time `json:"go_timestamp"` + Timestamp time.Time `json:"go_timestamp"` } // Organization is the type that defines the upstream discovery format for a single organization type Organization struct { // DisplayName is the map of strings from language tags to display names // Omitted if none is defined - DisplayName MapOrString `json:"display_name,omitempty"` + DisplayName MapOrString `json:"display_name,omitempty"` // OrgID is the organization ID for the server - OrgID string `json:"org_id"` + OrgID string `json:"org_id"` // SecureInternetHome is the secure internet home server that belongs to this organization // Omitted if none is defined - SecureInternetHome string `json:"secure_internet_home,omitempty"` + SecureInternetHome string `json:"secure_internet_home,omitempty"` // KeywordList is the list of keywords // Omitted if none is defined - KeywordList MapOrString `json:"keyword_list,omitempty"` + KeywordList MapOrString `json:"keyword_list,omitempty"` } // Servers is the type that defines the upstream discovery format for the list of servers // url: "https://disco.eduvpn.org/v2/server_list.json" type Servers struct { // Version is the version field in discovery. The Go library already checks for rollbacks, use this for logging - Version uint64 `json:"v"` + Version uint64 `json:"v"` // List is the actual list of servers, omitted from the JSON if empty - List []Server `json:"server_list,omitempty"` + List []Server `json:"server_list,omitempty"` // Timestamp is a timestamp that is internally used by the Go library to keep track of when the organizations was last updated // You can also use this for logging Timestamp time.Time `json:"go_timestamp"` @@ -49,21 +49,21 @@ type Servers struct { // Server is a signle discovery server type Server struct { // AuthenticationURLTemplate is the template to be used for authentication to skip WAYF - AuthenticationURLTemplate string `json:"authentication_url_template"` + AuthenticationURLTemplate string `json:"authentication_url_template"` // BaseURL is the base URL of the server which is used as an identifier for the server by the Go library - BaseURL string `json:"base_url"` + BaseURL string `json:"base_url"` // CountryCode is the country code for the server in case of secure internet, e.g. NL - CountryCode string `json:"country_code"` + CountryCode string `json:"country_code"` // DisplayName is the display name of the server, omitted if empty - DisplayName MapOrString `json:"display_name,omitempty"` + DisplayName MapOrString `json:"display_name,omitempty"` // DisplayName are the keywords of the server, omitted if empty - KeywordList MapOrString `json:"keyword_list,omitempty"` + KeywordList MapOrString `json:"keyword_list,omitempty"` // PublicKeyList are the public keys of the server. Currently not used in this lib but returned by the upstream discovery server - PublicKeyList []string `json:"public_key_list"` + PublicKeyList []string `json:"public_key_list"` // Type is the type of the server, "secure_internet" or "institute_access" - Type string `json:"server_type"` + Type string `json:"server_type"` // SupportContact is the list/slice of support contacts - SupportContact []string `json:"support_contact"` + SupportContact []string `json:"support_contact"` } // MapOrString is a custom type as the upstream discovery format is a map or a value. diff --git a/types/server/server.go b/types/server/server.go index ae73f45..82730ab 100644 --- a/types/server/server.go +++ b/types/server/server.go @@ -20,14 +20,14 @@ const ( // Expiry is the struct that gives the time at which certain expiry elements should be shown type Expiry struct { // StartTime is the start time of the VPN in Unix - StartTime int64 `json:"start_time"` + StartTime int64 `json:"start_time"` // EndTime is the end time of the VPN in Unix. - EndTime int64 `json:"end_time"` + EndTime int64 `json:"end_time"` // ButtonTime is the Unix time at which to start showing the renew button in the UI - ButtonTime int64 `json:"button_time"` + ButtonTime int64 `json:"button_time"` // CountdownTime is the Unix time at which to start showing more detailed countdown timer. // E.g. first start with days (7 days left), and when the current time is after this time, show e.g. 9 minutes and 59 seconds - CountdownTime int64 `json:"countdown_time"` + CountdownTime int64 `json:"countdown_time"` // NotificationTimes is the slice/list of times at which to show a notification that the VPN is about to expire NotificationTimes []int64 `json:"notification_times"` } @@ -38,28 +38,28 @@ type Profile struct { // It is a map where country codes are mapped to names, this is to be consistent with the format of other display names // E.g. {"en": "Default Profile"} // If this is empty, the field is omitted from the JSON - DisplayName map[string]string `json:"display_name,omitempty"` + DisplayName map[string]string `json:"display_name,omitempty"` // Protocols is the list of protocols that this profile supports - Protocols []protocol.Protocol `json:"supported_protocols"` + Protocols []protocol.Protocol `json:"supported_protocols"` } // Profiles is the map of profiles with the current defined type Profiles struct { // Map, the map of profiles from profile ID to the profile contents // If this is empty, the field is omitted from the JSON - Map map[string]Profile `json:"map,omitempty"` + Map map[string]Profile `json:"map,omitempty"` // Current is the current profile ID that is defined - Current string `json:"current"` + Current string `json:"current"` } // Tokens are the OAuth tokens for the server type Tokens struct { // Access is the access token - Access string `json:"access_token"` + Access string `json:"access_token"` // Refresh is the refresh token Refresh string `json:"refresh_token"` // Expires is the Unix timestamp when the token expires - Expires int64 `json:"expires_in"` + Expires int64 `json:"expires_in"` } // Server is the basic type for a server. This is the base for secure internet and institute access. Custom servers are equal to this type @@ -68,10 +68,10 @@ type Server struct { DisplayName map[string]string `json:"display_name,omitempty"` // Identifier is the Base URL for Institute Access and Custom Server. For Secure Internet this is the organization ID // This identifier should be passed to the Go library for e.g. getting a config - Identifier string `json:"identifier"` + Identifier string `json:"identifier"` // Profiles is the profiles that this server has defined // It could be that this is empty if the library has not discovered the profiles just yet - Profiles Profiles `json:"profiles"` + Profiles Profiles `json:"profiles"` } // Institute defines an institute access server @@ -91,17 +91,17 @@ type SecureInternet struct { CountryCode string `json:"country_code"` // Delisted is a boolean that indicates whether or not this server is delisted from discovery // If it is, the UI should show a warning symbol or move the server to a new category, which is up to the client - Delisted bool `json:"delisted"` + Delisted bool `json:"delisted"` } // List is the list of servers type List struct { // Institutes is the list/slice of institute access servers. If none are defined, this is omitted in the JSON - Institutes []Institute `json:"institute_access_servers,omitempty"` + Institutes []Institute `json:"institute_access_servers,omitempty"` // Secure Internet is the secure internet server if any. If none is there, it is omitted in the JSON SecureInternet *SecureInternet `json:"secure_internet_server,omitempty"` // Custom is the list/slice of custom servers. If none are defined, this is omitted in the JSON - Custom []Server `json:"custom_servers,omitempty"` + Custom []Server `json:"custom_servers,omitempty"` } // Configuration is the configuration that you get back when you call the get config function @@ -109,14 +109,14 @@ type Configuration struct { // VPNConfig is the VPN Configuration, a WireGuard or OpenVPN Configuration // In case of OpenVPN, we append "script-security 0" to disable scripts from being run by default. // A client may override this, e.g. for, very trusted, pre-provisioned VPNs - VPNConfig string `json:"config"` + VPNConfig string `json:"config"` // Protocol defines which protocol the configuration is for, OpenVPN or WireGuard - Protocol protocol.Protocol `json:"protocol"` + Protocol protocol.Protocol `json:"protocol"` // DefaultGateway is a boolean that indicates whether or not this configuration should be configured as a default gateway - DefaultGateway bool `json:"default_gateway"` + DefaultGateway bool `json:"default_gateway"` // Tokens is the updated tokens that we get back from the VPN configuration // They should be used by the client to save them in e.g. the keyring - Tokens Tokens `json:"tokens"` + Tokens Tokens `json:"tokens"` } // Current is the struct that defines the current server @@ -125,11 +125,11 @@ type Current struct { // The following three are mutually exclusive // Institute is the institute access server if any, if none is there this field is omitted in the JSON - Institute *Institute `json:"institute_access_server,omitempty"` + Institute *Institute `json:"institute_access_server,omitempty"` // Secure Internet is the secure internet server if any, if none is there this field is omitted in the JSON SecureInternet *SecureInternet `json:"secure_internet_server,omitempty"` // Custom is the custom server if any, if none is there this field is omitted in the JSON - Custom *Server `json:"custom_server,omitempty"` + Custom *Server `json:"custom_server,omitempty"` // Type is the type of server that is there to check which of the three types should be non-nil - Type Type `json:"server_type"` + Type Type `json:"server_type"` } |
