diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-05-23 15:33:11 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-05-29 14:36:10 +0200 |
| commit | be74ef76b30e7ad9fc74294c8e94ff1f40e87b4e (patch) | |
| tree | b92150ebaed0d37f4baf822adc7979c846ba5091 | |
| parent | ef7f44e4bb7713b18e6c0ab1b6e3510075b6623b (diff) | |
Discovery: Improve search using levenshtein distance and sorting
| -rw-r--r-- | client/discovery.go | 29 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 33 | ||||
| -rw-r--r-- | internal/levenshtein/levenshtein.go | 117 | ||||
| -rw-r--r-- | types/discovery/discovery.go | 4 |
4 files changed, 153 insertions, 30 deletions
diff --git a/client/discovery.go b/client/discovery.go index 49df8b2..72f0ad5 100644 --- a/client/discovery.go +++ b/client/discovery.go @@ -1,6 +1,7 @@ package client import ( + "sort" "strings" "github.com/eduvpn/eduvpn-common/i18nerr" @@ -34,11 +35,23 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie, search string) (*discotyp // convert to public subset var retOrgs []discotypes.Organization for _, v := range orgs.List { - if !v.Matches(search) { + if search == "" { + retOrgs = append(retOrgs, v.Organization) + continue + } + score := v.Score(search) + if score < 0 { continue } + v.Organization.Score = score retOrgs = append(retOrgs, v.Organization) } + if search != "" { + sort.Slice(retOrgs, func(i, j int) bool { + // lower score is better + return retOrgs[i].Score < retOrgs[j].Score + }) + } return &discotypes.Organizations{ List: retOrgs, }, err @@ -65,11 +78,23 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser // convert to public subset var retServs []discotypes.Server for _, v := range servs.List { - if !v.Matches(search) { + if search == "" { + retServs = append(retServs, v.Server) + continue + } + score := v.Score(search) + if score < 0 { continue } + v.Server.Score = score retServs = append(retServs, v.Server) } + if search != "" { + sort.Slice(retServs, func(i, j int) bool { + // lower score is better + return retServs[i].Score < retServs[j].Score + }) + } return &discotypes.Servers{ List: retServs, }, err diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 783febf..383a696 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -5,10 +5,10 @@ import ( "context" "encoding/json" "fmt" - "strings" "time" "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/levenshtein" "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/verify" discotypes 
"github.com/eduvpn/eduvpn-common/types/discovery" @@ -40,20 +40,8 @@ type Organization struct { KeywordList discotypes.MapOrString `json:"keyword_list,omitempty"` } -// Matches returns if the search query `str` matches with this organization -func (s *Organization) Matches(str string) bool { - var catalog strings.Builder - for _, v := range s.DisplayName { - // length and nil error is returned - _, _ = catalog.WriteString(strings.ToLower(v)) - _, _ = catalog.WriteString(" ") - } - for _, v := range s.KeywordList { - // length and nil error is returned - _, _ = catalog.WriteString(strings.ToLower(v)) - _, _ = catalog.WriteString(" ") - } - return strings.Contains(catalog.String(), strings.ToLower(str)) +func (o *Organization) Score(search string) int { + return levenshtein.DiscoveryScore(search, o.DisplayName, o.KeywordList) } // Servers are the list of servers from https://disco.eduvpn.org/v2/server_list.json @@ -82,19 +70,8 @@ type Server struct { } // Matches returns if the search query `str` matches with this server -func (s *Server) Matches(str string) bool { - var catalog strings.Builder - for _, v := range s.DisplayName { - // length and nil error is returned - _, _ = catalog.WriteString(strings.ToLower(v)) - _, _ = catalog.WriteString(" ") - } - for _, v := range s.KeywordList { - // length and nil error is returned - _, _ = catalog.WriteString(strings.ToLower(v)) - _, _ = catalog.WriteString(" ") - } - return strings.Contains(catalog.String(), strings.ToLower(str)) +func (s *Server) Score(search string) int { + return levenshtein.DiscoveryScore(search, s.DisplayName, s.KeywordList) } // Discovery is the main structure used for this package. 
diff --git a/internal/levenshtein/levenshtein.go b/internal/levenshtein/levenshtein.go new file mode 100644 index 0000000..0a2689a --- /dev/null +++ b/internal/levenshtein/levenshtein.go @@ -0,0 +1,117 @@ +package levenshtein + +import ( + "unicode/utf8" + "unicode" + "strings" + "golang.org/x/text/runes" + "golang.org/x/text/transform" + "golang.org/x/text/unicode/norm" +) + +// min returns the min of a and b +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// levenshtein is an algorithm that returns the "distance" between two strings +// the distance for hello and helloxd is 2 because it takes two inserts to go from hello to helloxd +// the distance between hello and hello is 0 because the strings are equal +// apart from insertions, the levenshtein algorithm also takes substitutions and deletions into account +// levenshtein implementation from https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_two_matrix_rows +func levenshtein(os, ot string) int { + n := utf8.RuneCountInString(os) + m := utf8.RuneCountInString(ot) + s := []rune(os) + t := []rune(ot) + v0 := make([]int, m+1) + v1 := make([]int, m+1) + for i := 0; i <= m; i++ { + v0[i] = i + } + + for i := 0; i < n; i++ { + v1[0] = i + 1 + for j := 0; j < m; j++ { + dc := v0[j+1] + 1 + ic := v1[j] + 1 + var sc int + if s[i] == t[j] { + sc = v0[j] + } else { + sc = v0[j] + 1 + } + v1[j+1] = min(min(dc, ic), sc) + } + v0, v1 = v1, v0 + } + return v0[m] +} + +// adjusted creates and adjusted version of the levenshtein algorithm +// where it filters entries where one of the words in the substr is not contained in `full` +// for these a score of -1 returned +// for all others it is the normal levenshtein distance +func adjusted(substr, full string) int { + substr = normalize(substr) + full = normalize(full) + s_sub := strings.Split(substr, " ") + for _, v_sub := range s_sub { + if !strings.Contains(full, v_sub) { + return -1 + } + } + return levenshtein(substr, full) +} + +// 
KeywordPenalty is the penalty for matching on keywords instead of display names +const KeywordPenalty = 2 + +// DiscoveryScore computes the score of a discovery entry with the given search query +// a negative score means exclude the entry from the results +func DiscoveryScore(search string, displays map[string]string, keywords map[string]string) int { + var catalogDN strings.Builder + for _, v := range displays { + // length and nil error is returned + _, _ = catalogDN.WriteString(v) + } + scoreDN := adjusted(search, catalogDN.String()) + var catalogKW strings.Builder + for _, v := range keywords { + // length and nil error is returned + _, _ = catalogKW.WriteString(v) + } + scoreKW := 3*adjusted(search, catalogKW.String()) + + // if both scores are positive, return the min + if scoreDN >= 0 && scoreKW >= 0 { + return min(scoreDN, scoreKW) + } + + // scoreKW is negative, return scoreDN + if scoreDN >= 0 { + return scoreDN + } + // scoreDN is negative, return scoreKW + return scoreKW +} + +// removeDiacritics removes "diacritics" :^) +// diacritics are special characters, e.g. 
GÉANT, becomes GEANT +func removeDiacritics(text string) (string, error) { + t := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC) + result, _, err := transform.String(t, text) + if err != nil { + return text, err + } + return result, nil +} + +// normalize removes diacritics and converts to lower case +func normalize(text string) string { + dt, _ := removeDiacritics(text) + return strings.ToLower(dt) +} diff --git a/types/discovery/discovery.go b/types/discovery/discovery.go index 4d92766..f1f8ed6 100644 --- a/types/discovery/discovery.go +++ b/types/discovery/discovery.go @@ -17,6 +17,8 @@ type Organization struct { DisplayName MapOrString `json:"display_name,omitempty"` // OrgID is the organization ID for the server OrgID string `json:"org_id"` + // Score is the score used internally for sorting search results; it is ignored by JSON encoding and decoding + Score int `json:"-"` } // Servers is the type that defines the upstream discovery format for the list of servers @@ -36,6 +38,8 @@ type Server struct { Type string `json:"server_type"` // CountryCode is the country code of the server if Type is "secure_internet", e.g. nl CountryCode string `json:"country_code"` + // Score is the score used internally for sorting search results; it is ignored by JSON encoding and decoding + Score int `json:"-"` } // MapOrString is a custom type as the upstream discovery format is a map or a value. |
