summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-05-23 15:33:11 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-05-29 14:36:10 +0200
commitbe74ef76b30e7ad9fc74294c8e94ff1f40e87b4e (patch)
treeb92150ebaed0d37f4baf822adc7979c846ba5091
parentef7f44e4bb7713b18e6c0ab1b6e3510075b6623b (diff)
Discovery: Improve search using levenshtein distance and sorting
-rw-r--r--client/discovery.go29
-rw-r--r--internal/discovery/discovery.go33
-rw-r--r--internal/levenshtein/levenshtein.go117
-rw-r--r--types/discovery/discovery.go4
4 files changed, 153 insertions, 30 deletions
diff --git a/client/discovery.go b/client/discovery.go
index 49df8b2..72f0ad5 100644
--- a/client/discovery.go
+++ b/client/discovery.go
@@ -1,6 +1,7 @@
package client
import (
+ "sort"
"strings"
"github.com/eduvpn/eduvpn-common/i18nerr"
@@ -34,11 +35,23 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie, search string) (*discotyp
// convert to public subset
var retOrgs []discotypes.Organization
for _, v := range orgs.List {
- if !v.Matches(search) {
+ if search == "" {
+ retOrgs = append(retOrgs, v.Organization)
+ continue
+ }
+ score := v.Score(search)
+ if score < 0 {
continue
}
+ v.Organization.Score = score
retOrgs = append(retOrgs, v.Organization)
}
+ if search != "" {
+ sort.Slice(retOrgs, func(i, j int) bool {
+ // lower score is better
+ return retOrgs[i].Score < retOrgs[j].Score
+ })
+ }
return &discotypes.Organizations{
List: retOrgs,
}, err
@@ -65,11 +78,23 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser
// convert to public subset
var retServs []discotypes.Server
for _, v := range servs.List {
- if !v.Matches(search) {
+ if search == "" {
+ retServs = append(retServs, v.Server)
+ continue
+ }
+ score := v.Score(search)
+ if score < 0 {
continue
}
+ v.Server.Score = score
retServs = append(retServs, v.Server)
}
+ if search != "" {
+ sort.Slice(retServs, func(i, j int) bool {
+ // lower score is better
+ return retServs[i].Score < retServs[j].Score
+ })
+ }
return &discotypes.Servers{
List: retServs,
}, err
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 783febf..383a696 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -5,10 +5,10 @@ import (
"context"
"encoding/json"
"fmt"
- "strings"
"time"
"github.com/eduvpn/eduvpn-common/internal/http"
+ "github.com/eduvpn/eduvpn-common/internal/levenshtein"
"github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/internal/verify"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
@@ -40,20 +40,8 @@ type Organization struct {
KeywordList discotypes.MapOrString `json:"keyword_list,omitempty"`
}
-// Matches returns if the search query `str` matches with this organization
-func (s *Organization) Matches(str string) bool {
- var catalog strings.Builder
- for _, v := range s.DisplayName {
- // length and nil error is returned
- _, _ = catalog.WriteString(strings.ToLower(v))
- _, _ = catalog.WriteString(" ")
- }
- for _, v := range s.KeywordList {
- // length and nil error is returned
- _, _ = catalog.WriteString(strings.ToLower(v))
- _, _ = catalog.WriteString(" ")
- }
- return strings.Contains(catalog.String(), strings.ToLower(str))
+// Score returns the discovery score of this organization for the search query `search`
+// it is computed from the display names and the keyword list of the organization
+// a negative score means the organization should be excluded from the results
+func (o *Organization) Score(search string) int {
+ return levenshtein.DiscoveryScore(search, o.DisplayName, o.KeywordList)
}
// Servers are the list of servers from https://disco.eduvpn.org/v2/server_list.json
@@ -82,19 +70,8 @@ type Server struct {
}
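-// Matches returns if the search query `str` matches with this server
+// Score returns the discovery score of this server for the search query `search`
+// a negative score means the server should be excluded from the results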
-func (s *Server) Matches(str string) bool {
- var catalog strings.Builder
- for _, v := range s.DisplayName {
- // length and nil error is returned
- _, _ = catalog.WriteString(strings.ToLower(v))
- _, _ = catalog.WriteString(" ")
- }
- for _, v := range s.KeywordList {
- // length and nil error is returned
- _, _ = catalog.WriteString(strings.ToLower(v))
- _, _ = catalog.WriteString(" ")
- }
- return strings.Contains(catalog.String(), strings.ToLower(str))
+func (s *Server) Score(search string) int {
+ return levenshtein.DiscoveryScore(search, s.DisplayName, s.KeywordList)
}
// Discovery is the main structure used for this package.
diff --git a/internal/levenshtein/levenshtein.go b/internal/levenshtein/levenshtein.go
new file mode 100644
index 0000000..0a2689a
--- /dev/null
+++ b/internal/levenshtein/levenshtein.go
@@ -0,0 +1,117 @@
+package levenshtein
+
+import (
+	"sort"
+	"strings"
+	"unicode"
+	"unicode/utf8"
+
+	"golang.org/x/text/runes"
+	"golang.org/x/text/transform"
+	"golang.org/x/text/unicode/norm"
+)
+
+// min reports the smaller of the two integers a and b
+// when both are equal that shared value is returned
+func min(a, b int) int {
+	if b < a {
+		return b
+	}
+	return a
+}
+
+// levenshtein computes the edit distance between the strings os and ot
+// the distance is the minimum number of single rune insertions, deletions and
+// substitutions that turn one string into the other, e.g.
+// hello -> helloxd is 2 (two insertions) and hello -> hello is 0 (equal strings)
+// comparison happens per rune, not per byte, so multi-byte characters count once
+// this is the two-row iterative variant described at
+// https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_two_matrix_rows
+func levenshtein(os, ot string) int {
+	src := []rune(os)
+	dst := []rune(ot)
+	rows := utf8.RuneCountInString(os)
+	cols := utf8.RuneCountInString(ot)
+
+	// prev holds the distances of the previous row, cur the row being filled in
+	prev := make([]int, cols+1)
+	cur := make([]int, cols+1)
+	// turning the empty prefix of os into the first k runes of ot takes k insertions
+	for k := range prev {
+		prev[k] = k
+	}
+
+	for r := 0; r < rows; r++ {
+		// turning the first r+1 runes of os into the empty string takes r+1 deletions
+		cur[0] = r + 1
+		for c := 0; c < cols; c++ {
+			// a substitution is free when the runes already agree
+			subst := prev[c]
+			if src[r] != dst[c] {
+				subst++
+			}
+			del := prev[c+1] + 1
+			ins := cur[c] + 1
+			cur[c+1] = min(subst, min(del, ins))
+		}
+		// the finished row becomes the previous row of the next iteration
+		prev, cur = cur, prev
+	}
+	return prev[cols]
+}
+
+// adjusted is an adjusted version of the levenshtein algorithm
+// both arguments are normalized first (diacritics removed, lower cased)
+// if any of the space separated words in `substr` is not contained in `full`
+// the entry is filtered out and a score of -1 is returned
+// otherwise the plain levenshtein distance between the two normalized strings is returned
+func adjusted(substr, full string) int {
+	needle := normalize(substr)
+	haystack := normalize(full)
+	for _, word := range strings.Split(needle, " ") {
+		if strings.Contains(haystack, word) {
+			continue
+		}
+		// one missing word is enough to reject the entry
+		return -1
+	}
+	return levenshtein(needle, haystack)
+}
+
+// KeywordPenalty is the penalty for matching on keywords instead of display names
+// the keyword score is multiplied by this factor before it is compared with the display name score
+const KeywordPenalty = 2
+
+// sortedCatalog concatenates the values of m in ascending key order
+// Go does not define the iteration order of a map, so ranging over m directly
+// would make the catalog, and with it the score, differ from call to call
+func sortedCatalog(m map[string]string) string {
+	keys := make([]string, 0, len(m))
+	for k := range m {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	var b strings.Builder
+	for _, k := range keys {
+		// length and nil error is returned
+		_, _ = b.WriteString(m[k])
+	}
+	return b.String()
+}
+
+// DiscoveryScore computes the score of a discovery entry with the given search query
+// `displays` are the display names and `keywords` the keyword lists of the entry, both keyed per language
+// a lower score is a better match
+// a negative score means exclude the entry from the results
+func DiscoveryScore(search string, displays map[string]string, keywords map[string]string) int {
+	scoreDN := adjusted(search, sortedCatalog(displays))
+	// multiplying keeps the sign, so a filtered out (-1) keyword score stays negative
+	scoreKW := KeywordPenalty * adjusted(search, sortedCatalog(keywords))
+
+	switch {
+	case scoreDN >= 0 && scoreKW >= 0:
+		// both matched, the best (lowest) score wins
+		return min(scoreDN, scoreKW)
+	case scoreDN >= 0:
+		// only the display names matched
+		return scoreDN
+	default:
+		// the display names did not match, the keyword score decides
+		// it is negative when the keywords did not match either
+		return scoreKW
+	}
+}
+
+// removeDiacritics strips diacritics from text, e.g. GÉANT becomes GEANT
+// the text is decomposed (NFD), every rune in the unicode.Mn category is dropped
+// and the remainder is composed again (NFC)
+// if the transformation fails the original text is returned together with the error
+func removeDiacritics(text string) (string, error) {
+	stripMarks := runes.Remove(runes.In(unicode.Mn))
+	chain := transform.Chain(norm.NFD, stripMarks, norm.NFC)
+	stripped, _, err := transform.String(chain, text)
+	if err == nil {
+		return stripped, nil
+	}
+	return text, err
+}
+
+// normalize prepares text for comparison: diacritics are removed and the result is lower cased
+func normalize(text string) string {
+	// on failure removeDiacritics hands back the input unchanged, so the error is ignored on purpose
+	stripped, _ := removeDiacritics(text)
+	return strings.ToLower(stripped)
+}
diff --git a/types/discovery/discovery.go b/types/discovery/discovery.go
index 4d92766..f1f8ed6 100644
--- a/types/discovery/discovery.go
+++ b/types/discovery/discovery.go
@@ -17,6 +17,8 @@ type Organization struct {
DisplayName MapOrString `json:"display_name,omitempty"`
// OrgID is the organization ID for the server
OrgID string `json:"org_id"`
+ // Score is the score internally used for sorting, lower is better; it is excluded from JSON
+ Score int `json:"-"`
}
// Servers is the type that defines the upstream discovery format for the list of servers
@@ -36,6 +38,8 @@ type Server struct {
Type string `json:"server_type"`
// CountryCode is the country code of the server if Type is "secure_internet", e.g. nl
CountryCode string `json:"country_code"`
+ // Score is the score internally used for sorting, lower is better; it is excluded from JSON
+ Score int `json:"-"`
}
// MapOrString is a custom type as the upstream discovery format is a map or a value.