summaryrefslogtreecommitdiff
path: root/internal/levenshtein/levenshtein.go
blob: cf28607cd77f25da0e733c9da892470ebb683295 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package levenshtein

import (
	"strings"
	"unicode"
	"unicode/utf8"

	"golang.org/x/text/runes"
	"golang.org/x/text/transform"
	"golang.org/x/text/unicode/norm"
)

// min returns the smaller of the two ints a and b.
// Note: this package-level helper shadows the builtin min (Go 1.21+) within
// this package.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

// levenshtein is an algorithm that returns the "distance" between two strings
// the distance for hello and helloxd is 2 because it takes two inserts to go from hello to helloxd
// the distance between hello and hello is 0 because the strings are equal
// apart from insertions, the levenshtein algorithm also takes substitutions and deletions into account
// levenshtein implementation from https://en.wikipedia.org/wiki/Levenshtein_distance#Iterative_with_two_matrix_rows
func levenshtein(os, ot string) int {
	n := utf8.RuneCountInString(os)
	m := utf8.RuneCountInString(ot)
	s := []rune(os)
	t := []rune(ot)
	v0 := make([]int, m+1)
	v1 := make([]int, m+1)
	for i := 0; i <= m; i++ {
		v0[i] = i
	}

	// loop through every word in the first string
	for i := 0; i < n; i++ {
		v1[0] = i + 1
		for j := 0; j < m; j++ {
			// calculate deletion cost,
			// insertion cost and
			// substitution cost to get from the string
			// to the target
			dc := v0[j+1] + 1
			ic := v1[j] + 1
			var sc int
			if s[i] == t[j] {
				sc = v0[j]
			} else {
				sc = v0[j] + 1
			}
			// take the min of all the costs
			v1[j+1] = min(min(dc, ic), sc)
		}
		v0, v1 = v1, v0
	}
	return v0[m]
}

// adjusted is a filtered variant of levenshtein.
//
// substr is split on single spaces. If any resulting piece does not occur
// somewhere in full, the pair is rejected and -1 is returned. Otherwise the
// plain levenshtein distance between substr and full is returned. Empty
// pieces (from repeated spaces) are always considered present, because
// strings.Contains reports true for an empty substring.
func adjusted(substr, full string) int {
	for _, word := range strings.Split(substr, " ") {
		if strings.Contains(full, word) {
			continue
		}
		return -1
	}
	return levenshtein(substr, full)
}

// KeywordPenalty is the factor DiscoveryScore applies to the score of a match
// on a keyword rather than on a display name. Keyword matches are penalized
// because many of the current keywords are generic.
const (
	KeywordPenalty = 2
)

// DiscoveryScore computes the score of a discovery entry for the given search
// query. The query and every display name and keyword are normalized first
// (diacritics removed, lower-cased) and then scored with adjusted.
//
// Keyword scores are multiplied by KeywordPenalty. An exact keyword match
// (distance 0) is scored as KeywordPenalty instead of 0, so it never ties with
// an exact display-name match. The smallest non-negative score over all
// display names and keywords is returned. A negative score means the entry
// should be excluded from the results.
//
// Because only the minimum is kept, the result does not depend on map
// iteration order.
func DiscoveryScore(search string, displays map[string]string, keywords map[string]string) int {
	search = normalize(search)

	best := -1
	for _, v := range displays {
		best = betterScore(best, adjusted(search, normalize(v)))
	}
	for _, v := range keywords {
		score := adjusted(search, normalize(v))
		if score < 0 {
			// Not a match. Skip it before applying the penalty: multiplying the
			// -1 sentinel would yield -2. That value must never become the
			// running best, or it would block later, real keyword matches.
			continue
		}
		score *= KeywordPenalty
		if score == 0 {
			score = KeywordPenalty
		}
		best = betterScore(best, score)
	}
	return best
}

// betterScore merges a candidate score into a running best score, where a
// negative value means "no match". It returns the smaller of the two
// non-negative values, or best unchanged when score is negative.
func betterScore(best, score int) int {
	if score < 0 {
		return best
	}
	if best < 0 || score < best {
		return score
	}
	return best
}

// removeDiacritics strips combining diacritical marks from text, e.g. "GÉANT"
// becomes "GEANT". The text is decomposed (NFD), every rune in the Unicode
// category Mn (nonspacing mark) is dropped, and the result is recomposed (NFC).
// If the transform fails, the unmodified input is returned along with the error.
func removeDiacritics(text string) (string, error) {
	// A new chain is built per call; the transformers keep internal state.
	chain := transform.Chain(
		norm.NFD,
		runes.Remove(runes.In(unicode.Mn)),
		norm.NFC,
	)
	stripped, _, err := transform.String(chain, text)
	if err == nil {
		return stripped, nil
	}
	return text, err
}

// normalize folds text for comparison by stripping diacritics and then
// lower-casing the result.
func normalize(text string) string {
	// Best effort: when removeDiacritics fails, it hands back the input
	// unchanged, so the original text is lower-cased as-is.
	if stripped, err := removeDiacritics(text); err == nil {
		text = stripped
	}
	return strings.ToLower(text)
}