diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-02-06 16:26:59 +0100 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-02-19 14:15:07 +0100 |
| commit | 3152078aec8334357a61171838f664eb03299211 (patch) | |
| tree | 57da9dd39a70e44f05f104adc442b0166053b85c /internal/config/v2 | |
| parent | 819d7f9914cbb34abb76b932c05b030a34986ec2 (diff) | |
Config: New state file
Caches less. Also convert the V1 state file
Diffstat (limited to 'internal/config/v2')
| -rw-r--r-- | internal/config/v2/convert.go | 88 | ||||
| -rw-r--r-- | internal/config/v2/v2.go | 245 | ||||
| -rw-r--r-- | internal/config/v2/v2_test.go | 139 |
3 files changed, 472 insertions, 0 deletions
diff --git a/internal/config/v2/convert.go b/internal/config/v2/convert.go new file mode 100644 index 0000000..0212749 --- /dev/null +++ b/internal/config/v2/convert.go @@ -0,0 +1,88 @@ +package v2 + +import ( + "time" + + "github.com/eduvpn/eduvpn-common/internal/config/v1" + "github.com/eduvpn/eduvpn-common/types/server" +) + +func v1AuthTime(st time.Time, ost time.Time) time.Time { + // OAuth start time can be zero + if ost.IsZero() { + return st + } + return ost +} + +func convertV1Server(list v1.InstituteServers, iscurrent bool, t server.Type) (map[ServerType]*Server, *ServerType) { + ret := make(map[ServerType]*Server) + var lc *ServerType + for k, v := range list.Map { + key := ServerType{ + T: t, + ID: k, + } + if iscurrent && k == list.CurrentURL { + lc = &key + } + prfs := v.Profiles.Public() + prfs.Current = v.Profiles.Current + ret[key] = &Server{ + Profiles: prfs, + LastAuthorizeTime: v1AuthTime(v.Base.StartTime, v.Base.StartTimeOAuth), + ExpireTime: v.Base.ExpireTime, + } + } + return ret, lc +} + +func FromV1(ver1 *v1.V1) *V2 { + gsrvs := ver1.Servers + + var lc *ServerType + cust, glc := convertV1Server(gsrvs.Custom, gsrvs.IsType == server.TypeCustom, server.TypeCustom) + if lc == nil { + lc = glc + } + res, glc := convertV1Server(gsrvs.Institute, gsrvs.IsType == server.TypeInstituteAccess, server.TypeInstituteAccess) + if lc == nil { + lc = glc + } + + for k, v := range cust { + res[k] = v + } + sec := gsrvs.SecureInternetHome + // if the home organization ID is filled we have secure internet present + if sec.HomeOrganizationID == "" { + return &V2{ + Discovery: ver1.Discovery, + List: res, + LastChosen: lc, + } + } + v, ok := sec.BaseMap[sec.CurrentLocation] + if v != nil && ok { + t := ServerType{ + T: server.TypeSecureInternet, + ID: sec.HomeOrganizationID, + } + if gsrvs.IsType == server.TypeSecureInternet { + lc = &t + } + prfs := v.Profiles.Public() + prfs.Current = v.Profiles.Current + res[t] = &Server{ + CountryCode: sec.CurrentLocation, + Profiles: prfs, + LastAuthorizeTime: v1AuthTime(v.StartTime, v.StartTimeOAuth), + ExpireTime: v.ExpireTime, + } + } + return &V2{ + Discovery: ver1.Discovery, + List: res, + LastChosen: lc, + } +} diff --git a/internal/config/v2/v2.go b/internal/config/v2/v2.go new file mode 100644 index 0000000..9608d54 --- /dev/null +++ b/internal/config/v2/v2.go @@ -0,0 +1,245 @@ +package v2 + +import ( + "errors" + "fmt" + "net/url" + "time" + + "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Server struct { + Profiles server.Profiles `json:"profiles"` + LastAuthorizeTime time.Time `json:"last_authorize_time,omitempty"` + ExpireTime time.Time `json:"expire_time,omitempty"` + + // In case of secure internet: + CountryCode string `json:"country_code"` +} + +type ServerType struct { + T server.Type + ID string +} + +const keyFormat = "%d,%s" + +func newServerType(key string) (*ServerType, error) { + var t server.Type + var id string + if _, err := fmt.Sscanf(key, keyFormat, &t, &id); err != nil { + return nil, err + } + + return &ServerType{ + T: t, + ID: id, + }, nil +} + +func (st ServerType) MarshalText() ([]byte, error) { + k := fmt.Sprintf(keyFormat, st.T, st.ID) + return []byte(k), nil +} + +func (st *ServerType) UnmarshalText(text []byte) error { + k := string(text) + g, err := newServerType(k) + if err != nil { + return err + } + *st = *g + return nil +} + +type V2 struct { + List map[ServerType]*Server `json:"server_list,omitempty"` + LastChosen *ServerType `json:"last_chosen_id,omitempty"` + Discovery discovery.Discovery `json:"discovery"` +} + +func (cfg *V2) RemoveServer(id string, t server.Type) error { + k := ServerType{ + ID: id, + T: t, + } + + if _, ok := cfg.List[k]; ok { + delete(cfg.List, k) + + // reset the last chosen + if cfg.LastChosen != nil && *cfg.LastChosen == k { + cfg.LastChosen = nil + } + return nil + } + return errors.New("server does not exist") +} + +func (cfg *V2) getServerWithKey(k ServerType) (*Server, error) { + if v, ok := cfg.List[k]; ok { + return v, nil + } + return nil, errors.New("server does not exist") +} + +func (cfg *V2) GetServer(id string, t server.Type) (*Server, error) { + k := ServerType{ + ID: id, + T: t, + } + return cfg.getServerWithKey(k) +} + +func (cfg *V2) CurrentServer() (*Server, *ServerType, error) { + if cfg.LastChosen == nil { + return nil, nil, errors.New("no server chosen before") + } + srv, err := cfg.getServerWithKey(*cfg.LastChosen) + if err != nil { + return nil, nil, err + } + return srv, cfg.LastChosen, nil +} + +func (cfg *V2) HasSecureInternet() bool { + for k := range cfg.List { + if k.T == server.TypeSecureInternet { + return true + } + } + return false +} + +func (cfg *V2) AddServer(id string, t server.Type, srv Server) error { + if cfg.HasSecureInternet() && t == server.TypeSecureInternet { + return errors.New("a secure internet server already exists, remove the other secure internet server first") + } + k := ServerType{ + ID: id, + T: t, + } + if cfg.List == nil { + cfg.List = make(map[ServerType]*Server) + } + cfg.List[k] = &srv + return nil +} + +func (cfg *V2) PublicCurrent(disco *discovery.Discovery) (*server.Current, error) { + curr, _, err := cfg.CurrentServer() + if err != nil { + return nil, err + } + rcurr := &server.Current{} + // SAFETY: LastChosen is guaranteed to be non-nil here + switch cfg.LastChosen.T { + case server.TypeInstituteAccess: + g, err := convertInstitute(cfg.LastChosen.ID, disco) + if err != nil { + return nil, err + } + g.Profiles = curr.Profiles + rcurr.Institute = g + case server.TypeSecureInternet: + g, err := convertSecure(cfg.LastChosen.ID, curr.CountryCode, disco) + if err != nil { + return nil, err + } + g.Profiles = curr.Profiles + rcurr.SecureInternet = g + case server.TypeCustom: + g, err := convertCustom(cfg.LastChosen.ID) + if err != nil { + return nil, err + } + g.Profiles = curr.Profiles + rcurr.Custom = g + default: + return nil, fmt.Errorf("unknown connected type: %d", cfg.LastChosen.T) + } + rcurr.Type = cfg.LastChosen.T + return rcurr, nil +} + +func convertInstitute(url string, disco *discovery.Discovery) (*server.Institute, error) { + dsrv, err := disco.ServerByURL(url, "institute_access") + if err != nil { + return nil, err + } + + return &server.Institute{ + Server: server.Server{ + DisplayName: dsrv.DisplayName, + Identifier: url, + }, + SupportContacts: dsrv.SupportContact, + }, nil +} + +func convertCustom(u string) (*server.Server, error) { + pu, err := url.Parse(u) + if err != nil { + return nil, err + } + return &server.Server{ + DisplayName: map[string]string{ + "en": pu.Hostname(), + }, + Identifier: u, + }, nil +} + +func convertSecure(orgID string, countryCode string, disco *discovery.Discovery) (*server.SecureInternet, error) { + dorg, _, err := disco.SecureHomeArgs(orgID) + if err != nil { + return nil, err + } + return &server.SecureInternet{ + Server: server.Server{ + DisplayName: dorg.DisplayName, + Identifier: dorg.OrgID, + }, + CountryCode: countryCode, + Locations: disco.SecureLocationList(), + }, nil +} + +func (cfg *V2) PublicList(disco *discovery.Discovery) *server.List { + ret := &server.List{} + // TODO: profile information? + for k, v := range cfg.List { + switch k.T { + case server.TypeInstituteAccess: + g, err := convertInstitute(k.ID, disco) + if err != nil || g == nil { + // TODO: log/delisted? + continue + } + g.Profiles = v.Profiles + ret.Institutes = append(ret.Institutes, *g) + case server.TypeSecureInternet: + g, err := convertSecure(k.ID, v.CountryCode, disco) + if err != nil || g == nil { + // TODO: log/delisted? + continue + } + g.Profiles = v.Profiles + ret.SecureInternet = g + case server.TypeCustom: + g, err := convertCustom(k.ID) + if err != nil || g == nil { + // TODO: log/delisted? + continue + } + g.Profiles = v.Profiles + ret.Custom = append(ret.Custom, *g) + default: + // TODO: log + continue + } + } + return ret +} diff --git a/internal/config/v2/v2_test.go b/internal/config/v2/v2_test.go new file mode 100644 index 0000000..5a4c2ea --- /dev/null +++ b/internal/config/v2/v2_test.go @@ -0,0 +1,139 @@ +package v2 + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/eduvpn/eduvpn-common/types/server" +) + +func TestLoad(t *testing.T) { + cases := []struct { + json string + want *V2 + wantErr string + }{ + // normal v2 config + { + json: ` +{ + "server_list": { + "1,a": { + "profiles": { + "current": "a", + "map": { + "a": { + "display_name": { + "en": "a" + } + } + } + } + } + } +} +`, + want: &V2{ + List: map[ServerType]*Server{ + {ID: "a", T: server.TypeInstituteAccess}: { + Profiles: server.Profiles{ + Map: map[string]server.Profile{ + "a": {DisplayName: map[string]string{"en": "a"}}, + }, + Current: "a", + }, + }, + }, + }, + wantErr: "", + }, + { + json: ` +{ + "server_list": { + "a,1": { + "profiles": { + "current": "a", + "map": { + "a": { + "display_name": { + "en": "a" + } + } + } + } + } + } +} +`, + want: nil, + wantErr: "expected integer", + }, + { + json: ` +{ + "server_list": { + "1,a": { + "profiles": { + "current": "a", + "map": { + "a": { + "display_name": { + "en": "a" + } + } + } + } + }, + "2,a": { + "profiles": { + "current": "a", + "map": { + "a": { + "display_name": { + "en": "a" + } + } + } + } + } + } +} +`, + want: &V2{ + List: map[ServerType]*Server{ + {ID: "a", T: server.TypeInstituteAccess}: { + Profiles: server.Profiles{ + Map: map[string]server.Profile{ + "a": {DisplayName: map[string]string{"en": "a"}}, + }, + Current: "a", + }, + }, + {ID: "a", T: server.TypeSecureInternet}: { + Profiles: server.Profiles{ + Map: map[string]server.Profile{ + "a": {DisplayName: map[string]string{"en": "a"}}, + }, + Current: "a", + }, + }, + }, + }, + wantErr: "", + }, + } + + for _, v := range cases { + var g *V2 + err := json.Unmarshal([]byte(v.json), &g) + test.AssertError(t, err, v.wantErr) + if err == nil { + if !reflect.DeepEqual(g, v.want) { + t.Fatalf("structs not equal") + } + } + } +} |
