diff --git a/cmd/server/issue115_test.go b/cmd/server/issue115_test.go index a3c1533d..c8bcd8dc 100644 --- a/cmd/server/issue115_test.go +++ b/cmd/server/issue115_test.go @@ -8,14 +8,15 @@ import ( "testing" "github.com/moov-io/base/log" + "github.com/moov-io/watchman/internal/stringscore" "github.com/moov-io/watchman/pkg/ofac" ) func TestIssue115__TopSDNs(t *testing.T) { - score := jaroWinkler("georgehabbash", "georgebush") - eql(t, "george bush jaroWinkler", score, 0.896) + score := stringscore.JaroWinkler("georgehabbash", "georgebush") + eql(t, "george bush stringscore.JaroWinkler", score, 0.896) - score = jaroWinkler("g", "geoergebush") + score = stringscore.JaroWinkler("g", "geoergebush") eql(t, "g vs geoergebush", score, 0.070) pipe := noLogPipeliner diff --git a/cmd/server/issue326_test.go b/cmd/server/issue326_test.go index ec3187ab..0c51181d 100644 --- a/cmd/server/issue326_test.go +++ b/cmd/server/issue326_test.go @@ -7,26 +7,29 @@ package main import ( "testing" + "github.com/moov-io/watchman/internal/prepare" + "github.com/moov-io/watchman/internal/stringscore" + "github.com/stretchr/testify/assert" ) func TestIssue326(t *testing.T) { - india := precompute("Huawei Technologies India Private Limited") - investment := precompute("Huawei Technologies Investment Co. Ltd.") + india := prepare.LowerAndRemovePunctuation("Huawei Technologies India Private Limited") + investment := prepare.LowerAndRemovePunctuation("Huawei Technologies Investment Co. 
Ltd.") // Cuba - score := jaroWinkler(precompute("Huawei Cuba"), precompute("Huawei")) + score := stringscore.JaroWinkler(prepare.LowerAndRemovePunctuation("Huawei Cuba"), prepare.LowerAndRemovePunctuation("Huawei")) assert.Equal(t, 0.7444444444444445, score) // India - score = jaroWinkler(india, precompute("Huawei")) + score = stringscore.JaroWinkler(india, prepare.LowerAndRemovePunctuation("Huawei")) assert.Equal(t, 0.4846031746031746, score) - score = jaroWinkler(india, precompute("Huawei Technologies")) + score = stringscore.JaroWinkler(india, prepare.LowerAndRemovePunctuation("Huawei Technologies")) assert.Equal(t, 0.6084415584415584, score) // Investment - score = jaroWinkler(investment, precompute("Huawei")) + score = stringscore.JaroWinkler(investment, prepare.LowerAndRemovePunctuation("Huawei")) assert.Equal(t, 0.3788888888888889, score) - score = jaroWinkler(investment, precompute("Huawei Technologies")) + score = stringscore.JaroWinkler(investment, prepare.LowerAndRemovePunctuation("Huawei Technologies")) assert.Equal(t, 0.5419191919191919, score) } diff --git a/cmd/server/main.go b/cmd/server/main.go index e060cac0..14c7cba6 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -24,6 +24,7 @@ import ( "github.com/moov-io/base/http/bind" "github.com/moov-io/base/log" "github.com/moov-io/watchman" + "github.com/moov-io/watchman/internal/prepare" searchv2 "github.com/moov-io/watchman/internal/search" "github.com/moov-io/watchman/pkg/ofac" pubsearch "github.com/moov-io/watchman/pkg/search" @@ -129,11 +130,11 @@ func main() { }() defer adminServer.Shutdown() - var pipeline *pipeliner + var pipeline *prepare.Pipeliner if debug, err := strconv.ParseBool(os.Getenv("DEBUG_NAME_PIPELINE")); debug && err == nil { - pipeline = newPipeliner(logger, true) + pipeline = prepare.NewPipeliner(logger, true) } else { - pipeline = newPipeliner(log.NewNopLogger(), false) + pipeline = prepare.NewPipeliner(log.NewNopLogger(), false) } searchWorkers := 
readInt(os.Getenv("SEARCH_MAX_WORKERS"), *flagWorkers) @@ -299,3 +300,14 @@ func generalizeOFACSDNs(input []*SDN, ofacAddresses []*Address) []pubsearch.Enti func addSearchV2Routes(logger log.Logger, r *mux.Router, service searchv2.Service) { searchv2.NewController(logger, service).AppendRoutes(r) } + +func readInt(override string, value int) int { + if override != "" { + n, err := strconv.ParseInt(override, 10, 32) + if err != nil { + panic(fmt.Errorf("unable to parse %q as int", override)) //nolint:forbidigo + } + return int(n) + } + return value +} diff --git a/cmd/server/search.go b/cmd/server/search.go index 63e07d69..6c103562 100644 --- a/cmd/server/search.go +++ b/cmd/server/search.go @@ -8,24 +8,20 @@ import ( "bytes" "encoding/json" "errors" - "fmt" - "math" - "os" - "sort" "strconv" "strings" "sync" "time" "github.com/moov-io/base/log" - "github.com/moov-io/base/strx" + "github.com/moov-io/watchman/internal/prepare" + "github.com/moov-io/watchman/internal/stringscore" "github.com/moov-io/watchman/pkg/csl_eu" "github.com/moov-io/watchman/pkg/csl_uk" "github.com/moov-io/watchman/pkg/csl_us" "github.com/moov-io/watchman/pkg/dpl" "github.com/moov-io/watchman/pkg/ofac" - "github.com/xrash/smetrics" "go4.org/syncutil" ) @@ -35,7 +31,7 @@ var ( softResultsLimit, hardResultsLimit = 10, 100 ) -// searcher holds precomputed data for each object available to search against. 
// This data comes from various US and EU Federal agencies type searcher struct { // OFAC @@ -76,12 +72,12 @@ type searcher struct { sync.RWMutex // protects all above fields *syncutil.Gate // limits concurrent processing - pipe *pipeliner + pipe *prepare.Pipeliner logger log.Logger } -func newSearcher(logger log.Logger, pipeline *pipeliner, workers int) *searcher { +func newSearcher(logger log.Logger, pipeline *prepare.Pipeliner, workers int) *searcher { logger.Logf("allowing only %d workers for search", workers) return &searcher{ logger: logger.With(log.Fields{ @@ -121,7 +117,7 @@ var ( return func(add *Address) *item { return &item{ value: add, - weight: jaroWinkler(add.address, precompute(needleAddr)), + weight: stringscore.JaroWinkler(add.address, prepare.LowerAndRemovePunctuation(needleAddr)), } } } @@ -133,7 +129,7 @@ var ( return func(add *Address) *item { return &item{ value: add, - weight: jaroWinkler(add.citystate, precompute(needleCityState)), + weight: stringscore.JaroWinkler(add.citystate, prepare.LowerAndRemovePunctuation(needleCityState)), } } } @@ -143,7 +139,7 @@ var ( return func(add *Address) *item { return &item{ value: add, - weight: jaroWinkler(add.country, precompute(needleCountry)), + weight: stringscore.JaroWinkler(add.country, prepare.LowerAndRemovePunctuation(needleCountry)), } } } @@ -251,7 +247,7 @@ func (s *searcher) FindAlts(limit int, id string) []*ofac.AlternateIdentity { } func (s *searcher) TopAltNames(limit int, minMatch float64, alt string) []Alt { - alt = precompute(alt) + alt = prepare.LowerAndRemovePunctuation(alt) altTokens := strings.Fields(alt) s.RLock() @@ -273,7 +269,7 @@ func (s *searcher) TopAltNames(limit int, minMatch float64, alt string) []Alt { xs.add(&item{ matched: s.Alts[i].name, value: s.Alts[i], - weight: bestPairsJaroWinkler(altTokens, s.Alts[i].name), + weight: stringscore.BestPairsJaroWinkler(altTokens, s.Alts[i].name), }) }(i) } @@ -295,123 +291,6 @@ func (s *searcher) TopAltNames(limit int, minMatch 
float64, alt string) []Alt { return out } -// bestPairsJaroWinkler compares a search query to an indexed term (name, address, etc) and returns a decimal fraction -// score. -// -// The algorithm splits each string into tokens, and does a pairwise Jaro-Winkler score of all token combinations -// (outer product). The best match for each search token is chosen, such that each index token can be matched at most -// once. -// -// The pairwise scores are combined into an average in a way that corrects for character length, and the fraction of the -// indexed term that didn't match. -func bestPairsJaroWinkler(searchTokens []string, indexed string) float64 { - type Score struct { - score float64 - searchTokenIdx int - indexTokenIdx int - } - - indexedTokens := strings.Fields(indexed) - searchTokensLength := sumLength(searchTokens) - indexTokensLength := sumLength(indexedTokens) - - disablePhoneticFiltering := strx.Yes(os.Getenv("DISABLE_PHONETIC_FILTERING")) - - //Compare each search token to each indexed token. 
Sort the results in descending order - scoresCapacity := (len(searchTokens) + len(indexedTokens)) - if !disablePhoneticFiltering { - scoresCapacity /= 5 // reduce the capacity as many terms don't phonetically match - } - scores := make([]Score, 0, scoresCapacity) - for searchIdx, searchToken := range searchTokens { - for indexIdx, indexedToken := range indexedTokens { - // Compare the first letters phonetically and only run jaro-winkler on those which are similar - if disablePhoneticFiltering || firstCharacterSoundexMatch(indexedToken, searchToken) { - score := customJaroWinkler(indexedToken, searchToken) - scores = append(scores, Score{score, searchIdx, indexIdx}) - } - } - } - sort.Slice(scores[:], func(i, j int) bool { - return scores[i].score > scores[j].score - }) - - //Pick the highest score for each search term, where the indexed token hasn't yet been matched - matchedSearchTokens := make([]bool, len(searchTokens)) - matchedIndexTokens := make([]bool, len(indexedTokens)) - matchedIndexTokensLength := 0 - totalWeightedScores := 0.0 - for _, score := range scores { - //If neither the search token nor index token have been matched so far - if !matchedSearchTokens[score.searchTokenIdx] && !matchedIndexTokens[score.indexTokenIdx] { - //Weight the importance of this word score by its character length - searchToken := searchTokens[score.searchTokenIdx] - indexToken := indexedTokens[score.indexTokenIdx] - totalWeightedScores += score.score * float64(len(searchToken)+len(indexToken)) - - matchedSearchTokens[score.searchTokenIdx] = true - matchedIndexTokens[score.indexTokenIdx] = true - matchedIndexTokensLength += len(indexToken) - } - } - lengthWeightedAverageScore := totalWeightedScores / float64(searchTokensLength+matchedIndexTokensLength) - - //If some index tokens weren't matched by any search token, penalise this search a small amount. If this isn't done, - //a query of "John Doe" will match "John Doe" and "John Bartholomew Doe" equally well. 
- //Calculate the fraction of the index name that wasn't matched, apply a weighting to reduce the importance of - //unmatched portion, then scale down the final score. - matchedIndexLength := 0 - for i, str := range indexedTokens { - if matchedIndexTokens[i] { - matchedIndexLength += len(str) - } - } - matchedFraction := float64(matchedIndexLength) / float64(indexTokensLength) - return lengthWeightedAverageScore * scalingFactor(matchedFraction, unmatchedIndexPenaltyWeight) -} - -func customJaroWinkler(s1 string, s2 string) float64 { - score := smetrics.JaroWinkler(s1, s2, boostThreshold, prefixSize) - - if lengthMetric := lengthDifferenceFactor(s1, s2); lengthMetric < lengthDifferenceCutoffFactor { - //If there's a big difference in matched token lengths, punish the score. Jaro-Winkler is quite permissive about - //different lengths - score = score * scalingFactor(lengthMetric, lengthDifferencePenaltyWeight) - } - if s1[0] != s2[0] { - //Penalise words that start with a different characters. Jaro-Winkler is too lenient on this - //TODO should use a phonetic comparison here, like Soundex - score = score * differentLetterPenaltyWeight - } - return score -} - -// scalingFactor returns a float [0,1] that can be used to scale another number down, given some metric and a desired -// weight -// e.g. 
If a score has a 50% value according to a metric, and we want a 10% weight to the metric: -// -// scaleFactor := scalingFactor(0.5, 0.1) // 0.95 -// scaledScore := score * scaleFactor -func scalingFactor(metric float64, weight float64) float64 { - return 1.0 - (1.0-metric)*weight -} - -func sumLength(strs []string) int { - totalLength := 0 - for _, str := range strs { - totalLength += len(str) - } - return totalLength -} - -func lengthDifferenceFactor(s1 string, s2 string) float64 { - ls1 := float64(len(s1)) - ls2 := float64(len(s2)) - min := math.Min(ls1, ls2) - max := math.Max(ls1, ls2) - return min / max -} - func (s *searcher) FindSDN(entityID string) *ofac.SDN { if sdn := s.debugSDN(entityID); sdn != nil { return sdn.SDN @@ -490,7 +369,7 @@ func (s *searcher) FindSDNsByRemarksID(limit int, id string) []*SDN { } func (s *searcher) TopSDNs(limit int, minMatch float64, name string, keepSDN func(*SDN) bool) []*SDN { - name = precompute(name) + name = prepare.LowerAndRemovePunctuation(name) nameTokens := strings.Fields(name) s.RLock() @@ -516,7 +395,7 @@ func (s *searcher) TopSDNs(limit int, minMatch float64, name string, keepSDN fun xs.add(&item{ matched: s.SDNs[i].name, value: s.SDNs[i], - weight: bestPairsJaroWinkler(nameTokens, s.SDNs[i].name), + weight: stringscore.BestPairsJaroWinkler(nameTokens, s.SDNs[i].name), }) }(i) } @@ -539,7 +418,7 @@ func (s *searcher) TopSDNs(limit int, minMatch float64, name string, keepSDN fun } func (s *searcher) TopDPs(limit int, minMatch float64, name string) []DP { - name = precompute(name) + name = prepare.LowerAndRemovePunctuation(name) nameTokens := strings.Fields(name) s.RLock() @@ -561,7 +440,7 @@ func (s *searcher) TopDPs(limit int, minMatch float64, name string) []DP { xs.add(&item{ matched: s.DPs[i].name, value: s.DPs[i], - weight: bestPairsJaroWinkler(nameTokens, s.DPs[i].name), + weight: stringscore.BestPairsJaroWinkler(nameTokens, s.DPs[i].name), }) }(i) } @@ -583,7 +462,7 @@ func (s *searcher) TopDPs(limit int, 
minMatch float64, name string) []DP { return out } -// SDN is ofac.SDN wrapped with precomputed search metadata +// SDN is ofac.SDN wrapped with precomputed search metadata type SDN struct { *ofac.SDN @@ -593,7 +472,7 @@ type SDN struct { // matchedName holds the highest scoring term from the search query matchedName string - // name is precomputed for speed + // name is precomputed for speed name string // id is the parseed ID value from an SDN's remarks field. Often this @@ -627,13 +506,12 @@ func findAddresses(entityID string, addrs []*ofac.Address) []*ofac.Address { return out } -func precomputeSDNs(sdns []*ofac.SDN, addrs []*ofac.Address, pipe *pipeliner) []*SDN { +func precomputeSDNs(sdns []*ofac.SDN, addrs []*ofac.Address, pipe *prepare.Pipeliner) []*SDN { out := make([]*SDN, len(sdns)) for i := range sdns { - nn := sdnName(sdns[i], findAddresses(sdns[i].EntityID, addrs)) + nn := prepare.SdnName(sdns[i], findAddresses(sdns[i].EntityID, addrs)) if err := pipe.Do(nn); err != nil { - pipe.logger.Logf("pipeline", fmt.Sprintf("problem pipelining SDN: %v", err)) continue } @@ -646,13 +524,13 @@ func precomputeSDNs(sdns []*ofac.SDN, addrs []*ofac.Address, pipe *pipeliner) [] return out } -// Address is ofac.Address wrapped with precomputed search metadata +// Address is ofac.Address wrapped with precomputed search metadata type Address struct { Address *ofac.Address match float64 // match % - // precomputed fields for speed + // precomputed fields for speed address, citystate, country string } @@ -672,15 +550,15 @@ func precomputeAddresses(adds []*ofac.Address) []*Address { for i := range adds { out[i] = &Address{ Address: adds[i], - address: precompute(adds[i].Address), - citystate: precompute(adds[i].CityStateProvincePostalCode), - country: precompute(adds[i].Country), + address: prepare.LowerAndRemovePunctuation(adds[i].Address), + citystate: 
prepare.LowerAndRemovePunctuation(adds[i].CityStateProvincePostalCode), + country: prepare.LowerAndRemovePunctuation(adds[i].Country), } } return out } -// Alt is an ofac.AlternateIdentity wrapped with precomputed search metadata +// Alt is an ofac.AlternateIdentity wrapped with precomputed search metadata type Alt struct { AlternateIdentity *ofac.AlternateIdentity @@ -690,7 +568,7 @@ type Alt struct { // matchedName holds the highest scoring term from the search query matchedName string - // name is precomputed for speed + // name is precomputed for speed name string } @@ -707,13 +585,12 @@ func (a Alt) MarshalJSON() ([]byte, error) { }) } -func precomputeAlts(alts []*ofac.AlternateIdentity, pipe *pipeliner) []*Alt { +func precomputeAlts(alts []*ofac.AlternateIdentity, pipe *prepare.Pipeliner) []*Alt { out := make([]*Alt, len(alts)) for i := range alts { - an := altName(alts[i]) + an := prepare.AltName(alts[i]) if err := pipe.Do(an); err != nil { - pipe.logger.LogErrorf("problem pipelining SDN: %v", err) continue } @@ -725,7 +602,7 @@ func precomputeAlts(alts []*ofac.AlternateIdentity, pipe *pipeliner) []*Alt { return out } -// DP is a BIS Denied Person wrapped with precomputed search metadata +// DP is a BIS Denied Person wrapped with precomputed search metadata type DP struct { DeniedPerson *dpl.DPL match float64 @@ -746,12 +623,11 @@ func (d DP) MarshalJSON() ([]byte, error) { }) } -func precomputeDPs(persons []*dpl.DPL, pipe *pipeliner) []*DP { +func precomputeDPs(persons []*dpl.DPL, pipe *prepare.Pipeliner) []*DP { out := make([]*DP, len(persons)) for i := range persons { - nn := dpName(persons[i]) + nn := prepare.DPName(persons[i]) if err := pipe.Do(nn); err != nil { - pipe.logger.LogErrorf("problem pipelining DP: %v", err) continue } out[i] = &DP{ @@ -762,141 +638,6 @@ func precomputeDPs(persons []*dpl.DPL, pipe *pipeliner) []*DP { return out } -var ( - // Jaro-Winkler parameters - 
boostThreshold = readFloat(os.Getenv("JARO_WINKLER_BOOST_THRESHOLD"), 0.7) - prefixSize = readInt(os.Getenv("JARO_WINKLER_PREFIX_SIZE"), 4) - // Customised Jaro-Winkler parameters - lengthDifferenceCutoffFactor = readFloat(os.Getenv("LENGTH_DIFFERENCE_CUTOFF_FACTOR"), 0.9) - lengthDifferencePenaltyWeight = readFloat(os.Getenv("LENGTH_DIFFERENCE_PENALTY_WEIGHT"), 0.3) - differentLetterPenaltyWeight = readFloat(os.Getenv("DIFFERENT_LETTER_PENALTY_WEIGHT"), 0.9) - - // Watchman parameters - exactMatchFavoritism = readFloat(os.Getenv("EXACT_MATCH_FAVORITISM"), 0.0) - unmatchedIndexPenaltyWeight = readFloat(os.Getenv("UNMATCHED_INDEX_TOKEN_WEIGHT"), 0.15) -) - -func readFloat(override string, value float64) float64 { - if override != "" { - n, err := strconv.ParseFloat(override, 32) - if err != nil { - panic(fmt.Errorf("unable to parse %q as float64", override)) //nolint:forbidigo - } - return n - } - return value -} - -func readInt(override string, value int) int { - if override != "" { - n, err := strconv.ParseInt(override, 10, 32) - if err != nil { - panic(fmt.Errorf("unable to parse %q as int", override)) //nolint:forbidigo - } - return int(n) - } - return value -} - -// jaroWinkler runs the similarly named algorithm over the two input strings and averages their match percentages -// according to the second string (assumed to be the user's query) -// -// Terms are compared between a few adjacent terms and accumulate the highest near-neighbor match. 
-// -// For more details see https://en.wikipedia.org/wiki/Jaro%E2%80%93Winkler_distance -func jaroWinkler(s1, s2 string) float64 { - return jaroWinklerWithFavoritism(s1, s2, exactMatchFavoritism) -} - -var ( - adjacentSimilarityPositions = readInt(os.Getenv("ADJACENT_SIMILARITY_POSITIONS"), 3) -) - -func jaroWinklerWithFavoritism(indexedTerm, query string, favoritism float64) float64 { - maxMatch := func(indexedWord string, indexedWordIdx int, queryWords []string) (float64, string) { - if indexedWord == "" || len(queryWords) == 0 { - return 0.0, "" - } - - // We're only looking for the highest match close - start := indexedWordIdx - adjacentSimilarityPositions - end := indexedWordIdx + adjacentSimilarityPositions - - var max float64 - var maxTerm string - for i := start; i < end; i++ { - if i >= 0 && len(queryWords) > i { - score := smetrics.JaroWinkler(indexedWord, queryWords[i], boostThreshold, prefixSize) - if score > max { - max = score - maxTerm = queryWords[i] - } - } - } - return max, maxTerm - } - - indexedWords, queryWords := strings.Fields(indexedTerm), strings.Fields(query) - if len(indexedWords) == 0 || len(queryWords) == 0 { - return 0.0 // avoid returning NaN later on - } - - var scores []float64 - for i := range indexedWords { - max, term := maxMatch(indexedWords[i], i, queryWords) - //fmt.Printf("%s maxMatch %s %f\n", indexedWords[i], term, max) - if max >= 1.0 { - // If the query is longer than our indexed term (and EITHER are longer than most names) - // we want to reduce the maximum weight proportionally by the term difference, which - // forces more terms to match instead of one or two dominating the weight. - if (len(queryWords) > len(indexedWords)) && (len(indexedWords) > 3 || len(queryWords) > 3) { - max *= (float64(len(indexedWords)) / float64(len(queryWords))) - goto add - } - // If the indexed term is really short cap the match at 90%. - // This sill allows names to match highly with a couple different characters. 
- if len(indexedWords) == 1 && len(queryWords) > 1 { - max *= 0.9 - goto add - } - // Otherwise, apply Perfect match favoritism - max += favoritism - add: - scores = append(scores, max) - } else { - // If there are more terms in the user's query than what's indexed then - // adjust the max lower by the proportion of different terms. - // - // We do this to decrease the importance of a short (often common) term. - if len(queryWords) > len(indexedWords) { - scores = append(scores, max*float64(len(indexedWords))/float64(len(queryWords))) - continue - } - - // Apply an additional weight based on similarity of term lengths, - // so terms which are closer in length match higher. - s1 := float64(len(indexedWords[i])) - t := float64(len(term)) - 1 - weight := math.Min(math.Abs(s1/t), 1.0) - - scores = append(scores, max*weight) - } - } - - // average the highest N scores where N is the words in our query (query). - // Only truncate scores if there are enough words (aka more than First/Last). - sort.Float64s(scores) - if len(indexedWords) > len(queryWords) && len(queryWords) > 5 { - scores = scores[len(indexedWords)-len(queryWords):] - } - - var sum float64 - for i := range scores { - sum += scores[i] - } - return math.Min(sum/float64(len(scores)), 1.00) -} - // extractIDFromRemark attempts to parse out a National ID or similar governmental ID value // from an SDN's remarks property. 
// diff --git a/cmd/server/search_generic.go b/cmd/server/search_generic.go index df977fb1..4359df83 100644 --- a/cmd/server/search_generic.go +++ b/cmd/server/search_generic.go @@ -9,6 +9,9 @@ import ( "reflect" "strings" "sync" + + "github.com/moov-io/watchman/internal/prepare" + "github.com/moov-io/watchman/internal/stringscore" ) type Result[T any] struct { @@ -52,7 +55,7 @@ func topResults[T any](limit int, minMatch float64, name string, data []*Result[ return nil } - name = precompute(name) + name = prepare.LowerAndRemovePunctuation(name) nameTokens := strings.Fields(name) xs := newLargest(limit, minMatch) @@ -67,7 +70,7 @@ func topResults[T any](limit int, minMatch float64, name string, data []*Result[ it := &item{ matched: data[i].precomputedName, value: data[i], - weight: bestPairsJaroWinkler(nameTokens, data[i].precomputedName), + weight: stringscore.BestPairsJaroWinkler(nameTokens, data[i].precomputedName), } for _, alt := range data[i].precomputedAlts { @@ -75,7 +78,7 @@ func topResults[T any](limit int, minMatch float64, name string, data []*Result[ continue } - score := bestPairsJaroWinkler(nameTokens, alt) + score := stringscore.BestPairsJaroWinkler(nameTokens, alt) if score > it.weight { it.matched = alt it.weight = score diff --git a/cmd/server/search_handlers_bench_test.go b/cmd/server/search_handlers_bench_test.go index 0a22735d..5db3348b 100644 --- a/cmd/server/search_handlers_bench_test.go +++ b/cmd/server/search_handlers_bench_test.go @@ -5,20 +5,13 @@ package main import ( - "crypto/rand" "fmt" - "io" - "math/big" "net/http" "net/http/httptest" "net/url" - "os" - "path/filepath" - "strings" "testing" "github.com/moov-io/base/log" - "github.com/moov-io/watchman/pkg/ofac" "github.com/gorilla/mux" "github.com/stretchr/testify/require" @@ -56,31 +49,3 @@ func BenchmarkSearchHandler(b *testing.B) { } require.NoError(b, g.Wait()) } - -func BenchmarkJaroWinkler(b *testing.B) { - fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", 
"sdn.csv")) - if err != nil { - b.Error(err) - } - results, err := ofac.Read(map[string]io.ReadCloser{"sdn.csv": fd}) - require.NoError(b, err) - require.Len(b, results.SDNs, 7379) - - randomIndex := func(length int) int { - n, err := rand.Int(rand.Reader, big.NewInt(1e9)) - if err != nil { - panic(err) - } - return int(n.Int64()) % length - } - - b.Run("bestPairsJaroWinkler", func(b *testing.B) { - for i := 0; i < b.N; i++ { - nameTokens := strings.Fields(fake.Person().Name()) - idx := randomIndex(len(results.SDNs)) - - score := bestPairsJaroWinkler(nameTokens, results.SDNs[idx].SDNName) - require.Greater(b, score, -0.01) - } - }) -} diff --git a/cmd/server/search_test.go b/cmd/server/search_test.go index 57892175..24f9db68 100644 --- a/cmd/server/search_test.go +++ b/cmd/server/search_test.go @@ -10,16 +10,17 @@ import ( "net/http/httptest" "net/url" "path/filepath" - "strings" "sync" "testing" "github.com/moov-io/base/log" + "github.com/moov-io/watchman/internal/prepare" "github.com/moov-io/watchman/pkg/csl_eu" "github.com/moov-io/watchman/pkg/csl_uk" "github.com/moov-io/watchman/pkg/csl_us" "github.com/moov-io/watchman/pkg/dpl" "github.com/moov-io/watchman/pkg/ofac" + "github.com/stretchr/testify/require" ) @@ -29,6 +30,8 @@ var ( testSearcherStats *DownloadStats testSearcherOnce sync.Once + noLogPipeliner = prepare.NewPipeliner(log.NewNopLogger(), false) + // Mock Searchers addressSearcher = newSearcher(log.NewNopLogger(), noLogPipeliner, 1) altSearcher = newSearcher(log.NewNopLogger(), noLogPipeliner, 1) @@ -431,152 +434,6 @@ func verifyDownloadStats(b *testing.B) { require.Equal(b, 0, testSearcherStats.UKSanctionsList) } -func TestJaroWinkler(t *testing.T) { - cases := []struct { - indexed, search string - match float64 - }{ - // examples - {"wei, zhao", "wei, Zhao", 0.875}, - {"WEI, Zhao", "WEI, Zhao", 1.0}, - {"WEI Zhao", "WEI Zhao", 1.0}, - {strings.ToLower("WEI Zhao"), precompute("WEI, Zhao"), 1.0}, - - // apply jaroWinkler in both directions - {"jane 
doe", "jan lahore", 0.439}, - {"jan lahore", "jane doe", 0.549}, - - // real world case - {"john doe", "paul john", 0.624}, - {"john doe", "john othername", 0.440}, - - // close match - {"jane doe", "jane doe2", 0.940}, - - // real-ish world examples - {"kalamity linden", "kala limited", 0.687}, - {"kala limited", "kalamity linden", 0.687}, - - // examples used in demos / commonly - {"nicolas", "nicolas", 1.0}, - {"nicolas moros maduro", "nicolas maduro", 0.958}, - {"nicolas maduro", "nicolas moros maduro", 0.839}, - - // customer examples - {"ian", "ian mckinley", 0.429}, - {"iap", "ian mckinley", 0.352}, - {"ian mckinley", "ian", 0.891}, - {"ian mckinley", "iap", 0.733}, - {"ian mckinley", "tian xiang 7", 0.000}, - {"bindaree food group pty", precompute("independent insurance group ltd"), 0.269}, // precompute removes ltd - {"bindaree food group pty ltd", "independent insurance group ltd", 0.401}, // only matches higher from 'ltd' - {"p.c.c. (singapore) private limited", "culver max entertainment private limited", 0.514}, - {"zincum llc", "easy verification inc.", 0.000}, - {"transpetrochart co ltd", "jx metals trading co.", 0.431}, - {"technolab", "moomoo technologies inc", 0.565}, - {"sewa security services", "sesa - safety & environmental services australia pty ltd", 0.480}, - {"bueno", "20/f rykadan capital twr135 hoi bun rd, kwun tong 135 hoi bun rd., kwun tong", 0.094}, - - // example cases - {"nicolas maduro", "nicolás maduro", 0.937}, - {"nicolas maduro", precompute("nicolás maduro"), 1.0}, - {"nic maduro", "nicolas maduro", 0.872}, - {"nick maduro", "nicolas maduro", 0.859}, - {"nicolas maduroo", "nicolas maduro", 0.966}, - {"nicolas maduro", "nicolas maduro", 1.0}, - {"maduro, nicolas", "maduro, nicolas", 1.0}, - {"maduro moros, nicolas", "maduro moros, nicolas", 1.0}, - {"maduro moros, nicolas", "nicolas maduro", 0.953}, - {"nicolas maduro moros", "maduro", 0.900}, - {"nicolas maduro moros", "nicolás maduro", 0.898}, - {"nicolas, maduro moros", 
"maduro", 0.897}, - {"nicolas, maduro moros", "nicolas maduro", 0.928}, - {"nicolas, maduro moros", "nicolás", 0.822}, - {"nicolas, maduro moros", "maduro", 0.897}, - {"nicolas, maduro moros", "nicolás maduro", 0.906}, - {"africada financial services bureau change", "skylight", 0.441}, - {"africada financial services bureau change", "skylight financial inc", 0.658}, - {"africada financial services bureau change", "skylight services inc", 0.599}, - {"africada financial services bureau change", "skylight financial services", 0.761}, - {"africada financial services bureau change", "skylight financial services inc", 0.730}, - - // stopwords tests - {"the group for the preservation of the holy sites", "the bridgespan group", 0.682}, - {precompute("the group for the preservation of the holy sites"), precompute("the bridgespan group"), 0.682}, - {"group preservation holy sites", "bridgespan group", 0.652}, - - {"the group for the preservation of the holy sites", "the logan group", 0.670}, - {precompute("the group for the preservation of the holy sites"), precompute("the logan group"), 0.670}, - {"group preservation holy sites", "logan group", 0.586}, - - {"the group for the preservation of the holy sites", "the anything group", 0.546}, - {precompute("the group for the preservation of the holy sites"), precompute("the anything group"), 0.546}, - {"group preservation holy sites", "anything group", 0.488}, - - {"the group for the preservation of the holy sites", "the hello world group", 0.637}, - {precompute("the group for the preservation of the holy sites"), precompute("the hello world group"), 0.637}, - {"group preservation holy sites", "hello world group", 0.577}, - - {"the group for the preservation of the holy sites", "the group", 0.880}, - {precompute("the group for the preservation of the holy sites"), precompute("the group"), 0.880}, - {"group preservation holy sites", "group", 0.879}, - - {"the group for the preservation of the holy sites", "The flibbity jibbity 
flobbity jobbity grobbity zobbity group", 0.345}, - { - precompute("the group for the preservation of the holy sites"), - precompute("the flibbity jibbity flobbity jobbity grobbity zobbity group"), - 0.366, - }, - {"group preservation holy sites", "flibbity jibbity flobbity jobbity grobbity zobbity group", 0.263}, - - // precompute - {"i c sogo kenkyusho", precompute("A.I.C. SOGO KENKYUSHO"), 0.858}, - {precompute("A.I.C. SOGO KENKYUSHO"), "sogo kenkyusho", 0.972}, - } - for i := range cases { - v := cases[i] - // Only need to call chomp on s1, see jaroWinkler doc - eql(t, fmt.Sprintf("#%d %s vs %s", i, v.indexed, v.search), bestPairsJaroWinkler(strings.Fields(v.search), v.indexed), v.match) - } -} - -func TestJaroWinklerWithFavoritism(t *testing.T) { - favoritism := 1.0 - delta := 0.01 - - score := jaroWinklerWithFavoritism("Vladimir Putin", "PUTIN, Vladimir Vladimirovich", favoritism) - require.InDelta(t, score, 1.00, delta) - - score = jaroWinklerWithFavoritism("nicolas, maduro moros", "nicolás maduro", 0.25) - require.InDelta(t, score, 0.96, delta) - - score = jaroWinklerWithFavoritism("Vladimir Putin", "A.I.C. SOGO KENKYUSHO", favoritism) - require.InDelta(t, score, 0.00, delta) -} - -func TestJaroWinklerErr(t *testing.T) { - v := jaroWinkler("", "hello") - eql(t, "NaN #1", v, 0.0) - - v = jaroWinkler("hello", "") - eql(t, "NaN #1", v, 0.0) -} - -func eql(t *testing.T, desc string, x, y float64) { - t.Helper() - if math.IsNaN(x) || math.IsNaN(y) { - t.Fatalf("%s: x=%.2f y=%.2f", desc, x, y) - } - if math.Abs(x-y) > 0.01 { - t.Errorf("%s: %.3f != %.3f", desc, x, y) - } -} - -func TestEql(t *testing.T) { - eql(t, "", 0.1, 0.1) - eql(t, "", 0.0001, 0.00002) -} - // TestSearch_liveData will download the real data and run searches against the corpus. // This test is designed to tweak match percents and results. 
func TestSearch_liveData(t *testing.T) { @@ -867,3 +724,18 @@ func TestSearch__FindSDNsByRemarksID(t *testing.T) { t.Fatalf("sdns=%#v", sdns) } } + +func eql(t *testing.T, desc string, x, y float64) { + t.Helper() + if math.IsNaN(x) || math.IsNaN(y) { + t.Fatalf("%s: x=%.2f y=%.2f", desc, x, y) + } + if math.Abs(x-y) > 0.01 { + t.Errorf("%s: %.3f != %.3f", desc, x, y) + } +} + +func TestEql(t *testing.T) { + eql(t, "", 0.1, 0.1) + eql(t, "", 0.0001, 0.00002) +} diff --git a/cmd/server/search_us_csl.go b/cmd/server/search_us_csl.go index ac640314..f6ef91fe 100644 --- a/cmd/server/search_us_csl.go +++ b/cmd/server/search_us_csl.go @@ -11,6 +11,7 @@ import ( moovhttp "github.com/moov-io/base/http" "github.com/moov-io/base/log" + "github.com/moov-io/watchman/internal/prepare" "github.com/moov-io/watchman/pkg/csl_us" ) @@ -37,16 +38,15 @@ func searchUSCSL(logger log.Logger, searcher *searcher) http.HandlerFunc { } } -func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] { +func precomputeCSLEntities[T any](items []*T, pipe *prepare.Pipeliner) []*Result[T] { out := make([]*Result[T], len(items)) if items == nil { return out } for i, item := range items { - name := cslName(item) + name := prepare.CSLName(item) if err := pipe.Do(name); err != nil { - pipe.logger.LogErrorf("problem pipelining %T: %v", item, err) continue } @@ -63,7 +63,7 @@ func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] { continue } for j := range alts { - alt := &Name{Processed: alts[j]} + alt := &prepare.Name{Processed: alts[j]} pipe.Do(alt) altNames = append(altNames, alt.Processed) } @@ -73,7 +73,7 @@ func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] { continue } for j := range alts { - alt := &Name{Processed: alts[j]} + alt := &prepare.Name{Processed: alts[j]} pipe.Do(alt) altNames = append(altNames, alt.Processed) } @@ -83,7 +83,7 @@ func precomputeCSLEntities[T any](items []*T, pipe *pipeliner) []*Result[T] { continue } for 
j := range alts { - alt := &Name{Processed: alts[j]} + alt := &prepare.Name{Processed: alts[j]} pipe.Do(alt) altNames = append(altNames, alt.Processed) } diff --git a/internal/largest/items.go b/internal/largest/items.go new file mode 100644 index 00000000..2781c768 --- /dev/null +++ b/internal/largest/items.go @@ -0,0 +1,66 @@ +// Copyright The Moov Authors +// Use of this source code is governed by an Apache License +// license that can be found in the LICENSE file. + +package largest + +import ( + "slices" + "sync" +) + +// Item represents an arbitrary value with an associated weight +type Item struct { + Value interface{} + Weight float64 +} + +// NewItems returns a structure which can be used to track items with the highest weights +func NewItems(capacity int, minMatch float64) *Items { + return &Items{ + items: make([]*Item, capacity), + capacity: capacity, + minMatch: minMatch, + } +} + +// Items keeps track of a set of items with the highest weights. This is used to +// find the largest weighted values out of a much larger set. 
+type Items struct { + items []*Item + capacity int + minMatch float64 + mu sync.Mutex +} + +func (xs *Items) Add(it *Item) { + if it.Weight < xs.minMatch { + return // skip item as it's below our threshold + } + + xs.mu.Lock() + defer xs.mu.Unlock() + + for i := range xs.items { + if xs.items[i] == nil { + xs.items[i] = it // insert if we found empty slot + break + } + if xs.items[i].Weight < it.Weight { + xs.items = slices.Insert(xs.items, i, it) + break + } + } + if len(xs.items) > xs.capacity { + xs.items = xs.items[:xs.capacity] + } +} + +func (xs *Items) Items() []*Item { + xs.mu.Lock() + defer xs.mu.Unlock() + + out := make([]*Item, len(xs.items)) + copy(out, xs.items) + return out +} diff --git a/internal/largest/items_test.go b/internal/largest/items_test.go new file mode 100644 index 00000000..5f564ed1 --- /dev/null +++ b/internal/largest/items_test.go @@ -0,0 +1,139 @@ +// Copyright The Moov Authors +// Use of this source code is governed by an Apache License +// license that can be found in the LICENSE file. 
+ +package largest + +import ( + "crypto/rand" + "fmt" + "math" + "math/big" + "testing" + + "github.com/moov-io/watchman/pkg/ofac" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func randomWeight() float64 { + n, _ := rand.Int(rand.Reader, big.NewInt(1000)) + return float64(n.Int64()) / 100.0 +} + +func TestLargest(t *testing.T) { + xs := NewItems(10, 0.0) + + min := 10000.0 + for i := 0; i < 1000; i++ { + it := &Item{ + Value: i, + Weight: randomWeight(), + } + xs.Add(it) + min = math.Min(min, it.Weight) + } + + // Check we didn't overflow + items := xs.Items() + require.Equal(t, len(items), 10) + + for i := range items { + if i+1 > len(items)-1 { + continue // don't hit index out of bounds + } + if items[i].Weight < 0.0001 { + t.Fatalf("weight of %.2f is too low", items[i].Weight) + } + if items[i].Weight < items[i+1].Weight { + t.Errorf("items[%d].Weight=%.2f < items[%d].Weight=%.2f", i, items[i].Weight, i+1, items[i+1].Weight) + } + } +} + +// TestLargest_MaxOrdering will test the ordering of 1.0 values to see +// if they hold their insert ordering. 
+func TestLargest_MaxOrdering(t *testing.T) { + xs := NewItems(10, 0.0) + + xs.Add(&Item{Value: "A", Weight: 0.99}) + xs.Add(&Item{Value: "B", Weight: 1.0}) + xs.Add(&Item{Value: "C", Weight: 1.0}) + xs.Add(&Item{Value: "D", Weight: 1.0}) + xs.Add(&Item{Value: "E", Weight: 0.97}) + + if n := len(xs.items); n != 10 { + t.Fatalf("found %d items: %#v", n, xs.items) + } + + if s, ok := xs.items[0].Value.(string); !ok || s != "B" { + t.Errorf("xs.items[0]=%#v", xs.items[0]) + } + if s, ok := xs.items[1].Value.(string); !ok || s != "C" { + t.Errorf("xs.items[1]=%#v", xs.items[1]) + } + if s, ok := xs.items[2].Value.(string); !ok || s != "D" { + t.Errorf("xs.items[2]=%#v", xs.items[2]) + } + if s, ok := xs.items[3].Value.(string); !ok || s != "A" { + t.Errorf("xs.items[3]=%#v", xs.items[3]) + } + if s, ok := xs.items[4].Value.(string); !ok || s != "E" { + t.Errorf("xs.items[4]=%#v", xs.items[4]) + } + for i := 5; i < 10; i++ { + if xs.items[i] != nil { + t.Errorf("#%d was non-nil: %#v", i, xs.items[i]) + } + } +} + +func TestLargest__MinMatch(t *testing.T) { + xs := NewItems(2, 0.96) + + xs.Add(&Item{Value: "A", Weight: 0.94}) + xs.Add(&Item{Value: "B", Weight: 1.0}) + xs.Add(&Item{Value: "C", Weight: 0.95}) + xs.Add(&Item{Value: "D", Weight: 0.09}) + + require.Equal(t, "B", xs.items[0].Value) + require.Nil(t, xs.items[1]) +} + +func BenchmarkLargest(b *testing.B) { + size := b.N * 500_000 + + scores := make([]float64, size) + for i := 0; i < b.N; i++ { + n, err := rand.Int(rand.Reader, big.NewInt(100)) + if err != nil { + b.Fatal(err) + } + scores[i] = float64(n.Int64()) / 100.0 + } + + limit := 20 + matches := []float64{0.1, 0.25, 0.5, 0.75, 0.9, 0.99} + for i := range matches { + b.Run(fmt.Sprintf("%.2f%%", matches[i]*100), func(b *testing.B) { + // accumulate scores + xs := NewItems(limit, matches[i]) + + g := &errgroup.Group{} + for i := range scores { + score := scores[i] + g.Go(func() error { + xs.Add(&Item{ + Value: ofac.SDN{}, + Weight: score, + }) + return nil + 
}) + } + require.NoError(b, g.Wait()) + require.Len(b, xs.items, limit) + require.Equal(b, limit, cap(xs.items)) + }) + } +} diff --git a/cmd/server/pipeline.go b/internal/prepare/pipeline.go similarity index 89% rename from cmd/server/pipeline.go rename to internal/prepare/pipeline.go index 6d2951a4..f85c25e7 100644 --- a/cmd/server/pipeline.go +++ b/internal/prepare/pipeline.go @@ -1,12 +1,11 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "errors" - "fmt" "github.com/moov-io/base/log" "github.com/moov-io/watchman/pkg/csl_eu" @@ -51,7 +50,7 @@ type Name struct { altNames []string } -func sdnName(sdn *ofac.SDN, addrs []*ofac.Address) *Name { +func SdnName(sdn *ofac.SDN, addrs []*ofac.Address) *Name { return &Name{ Original: sdn.SDNName, Processed: sdn.SDNName, @@ -60,7 +59,7 @@ func sdnName(sdn *ofac.SDN, addrs []*ofac.Address) *Name { } } -func altName(alt *ofac.AlternateIdentity) *Name { +func AltName(alt *ofac.AlternateIdentity) *Name { return &Name{ Original: alt.AlternateName, Processed: alt.AlternateName, @@ -68,7 +67,7 @@ func altName(alt *ofac.AlternateIdentity) *Name { } } -func dpName(dp *dpl.DPL) *Name { +func DPName(dp *dpl.DPL) *Name { return &Name{ Original: dp.Name, Processed: dp.Name, @@ -76,7 +75,7 @@ func dpName(dp *dpl.DPL) *Name { } } -func cslName(item interface{}) *Name { +func CSLName(item interface{}) *Name { switch v := item.(type) { case *csl_us.EL: return &Name{ @@ -212,7 +211,7 @@ func (ds *debugStep) apply(in *Name) error { return nil } -func newPipeliner(logger log.Logger, debug bool) *pipeliner { +func NewPipeliner(logger log.Logger, debug bool) *Pipeliner { steps := []step{ &reorderSDNStep{}, &companyNameCleanupStep{}, @@ -224,27 +223,27 @@ func newPipeliner(logger log.Logger, debug bool) *pipeliner { steps[i] = &debugStep{logger: logger, step: steps[i]} } } - 
return &pipeliner{ + return &Pipeliner{ logger: logger, steps: steps, } } -type pipeliner struct { +type Pipeliner struct { logger log.Logger steps []step } -func (p *pipeliner) Do(name *Name) error { +func (p *Pipeliner) Do(name *Name) error { if p == nil || p.steps == nil || p.logger == nil || name == nil { - return errors.New("nil pipeliner or Name") + return errors.New("nil Pipeliner or Name") } for i := range p.steps { if name == nil { - return fmt.Errorf("%T: nil Name", p.steps[i]) + return p.logger.Error().LogErrorf("%T: nil Name", p.steps[i]).Err() } if err := p.steps[i].apply(name); err != nil { - return fmt.Errorf("pipeline: %v", err) + return p.logger.Error().LogErrorf("pipeline: %v", err).Err() } } return nil diff --git a/cmd/server/pipeline_company_name_cleanup.go b/internal/prepare/pipeline_company_name_cleanup.go similarity index 95% rename from cmd/server/pipeline_company_name_cleanup.go rename to internal/prepare/pipeline_company_name_cleanup.go index b1e6069a..716ff812 100644 --- a/cmd/server/pipeline_company_name_cleanup.go +++ b/internal/prepare/pipeline_company_name_cleanup.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "strings" diff --git a/cmd/server/pipeline_company_name_cleanup_test.go b/internal/prepare/pipeline_company_name_cleanup_test.go similarity index 98% rename from cmd/server/pipeline_company_name_cleanup_test.go rename to internal/prepare/pipeline_company_name_cleanup_test.go index f8ed9d1b..374a5bfd 100644 --- a/cmd/server/pipeline_company_name_cleanup_test.go +++ b/internal/prepare/pipeline_company_name_cleanup_test.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. 
-package main +package prepare import ( "testing" diff --git a/cmd/server/pipeline_normalize.go b/internal/prepare/pipeline_normalize.go similarity index 86% rename from cmd/server/pipeline_normalize.go rename to internal/prepare/pipeline_normalize.go index 7495bdc5..9a2dbf7a 100644 --- a/cmd/server/pipeline_normalize.go +++ b/internal/prepare/pipeline_normalize.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "strings" @@ -21,17 +21,17 @@ var ( type normalizeStep struct{} func (s *normalizeStep) apply(in *Name) error { - in.Processed = precompute(in.Processed) + in.Processed = LowerAndRemovePunctuation(in.Processed) return nil } -// precompute will lowercase each substring and remove punctuation +// LowerAndRemovePunctuation will lowercase each substring and remove punctuation // // This function is called on every record from the flat files and all // search requests (i.e. HTTP and searcher.TopNNNs methods). // See: https://godoc.org/golang.org/x/text/unicode/norm#Form // See: https://withblue.ink/2019/03/11/why-you-need-to-normalize-unicode-strings.html -func precompute(s string) string { +func LowerAndRemovePunctuation(s string) string { trimmed := strings.TrimSpace(strings.ToLower(punctuationReplacer.Replace(s))) // UTF-8 normalization diff --git a/cmd/server/pipeline_normalize_test.go b/internal/prepare/pipeline_normalize_test.go similarity index 72% rename from cmd/server/pipeline_normalize_test.go rename to internal/prepare/pipeline_normalize_test.go index 27767ce8..677cbc7b 100644 --- a/cmd/server/pipeline_normalize_test.go +++ b/internal/prepare/pipeline_normalize_test.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. 
-package main +package prepare import ( "testing" @@ -21,9 +21,9 @@ func TestPipeline__normalizeStep(t *testing.T) { } } -// TestPrecompute ensures we are trimming and UTF-8 normalizing strings +// TestLowerAndRemovePunctuation ensures we are trimming and UTF-8 normalizing strings // as expected. This is needed since our datafiles are normalized for us. -func TestPrecompute(t *testing.T) { +func TestLowerAndRemovePunctuation(t *testing.T) { tests := []struct { name, input, expected string }{ @@ -36,9 +36,9 @@ func TestPrecompute(t *testing.T) { {"issue 483 #2", "11,420.2-1 CORP.", "114202 1 corp"}, } for i, tc := range tests { - guess := precompute(tc.input) + guess := LowerAndRemovePunctuation(tc.input) if guess != tc.expected { - t.Errorf("case: %d name: %s precompute(%q)=%q expected %q", i, tc.name, tc.input, guess, tc.expected) + t.Errorf("case: %d name: %s LowerAndRemovePunctuation(%q)=%q expected %q", i, tc.name, tc.input, guess, tc.expected) } } } diff --git a/cmd/server/pipeline_reorder.go b/internal/prepare/pipeline_reorder.go similarity index 96% rename from cmd/server/pipeline_reorder.go rename to internal/prepare/pipeline_reorder.go index d52160ce..88737710 100644 --- a/cmd/server/pipeline_reorder.go +++ b/internal/prepare/pipeline_reorder.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. 
-package main +package prepare import ( "fmt" diff --git a/cmd/server/pipeline_reorder_test.go b/internal/prepare/pipeline_reorder_test.go similarity index 97% rename from cmd/server/pipeline_reorder_test.go rename to internal/prepare/pipeline_reorder_test.go index 5d1e670a..59567449 100644 --- a/cmd/server/pipeline_reorder_test.go +++ b/internal/prepare/pipeline_reorder_test.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "testing" diff --git a/cmd/server/pipeline_stopwords.go b/internal/prepare/pipeline_stopwords.go similarity index 98% rename from cmd/server/pipeline_stopwords.go rename to internal/prepare/pipeline_stopwords.go index 68da6e48..29febd2c 100644 --- a/cmd/server/pipeline_stopwords.go +++ b/internal/prepare/pipeline_stopwords.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "os" diff --git a/cmd/server/pipeline_stopwords_test.go b/internal/prepare/pipeline_stopwords_test.go similarity index 99% rename from cmd/server/pipeline_stopwords_test.go rename to internal/prepare/pipeline_stopwords_test.go index 099957b4..ab9ec6de 100644 --- a/cmd/server/pipeline_stopwords_test.go +++ b/internal/prepare/pipeline_stopwords_test.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. 
-package main +package prepare import ( "testing" diff --git a/cmd/server/pipeline_test.go b/internal/prepare/pipeline_test.go similarity index 93% rename from cmd/server/pipeline_test.go rename to internal/prepare/pipeline_test.go index 71657142..dd0e7491 100644 --- a/cmd/server/pipeline_test.go +++ b/internal/prepare/pipeline_test.go @@ -1,8 +1,8 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package prepare import ( "testing" @@ -12,12 +12,12 @@ import ( ) var ( - noopPipeliner = &pipeliner{ + noopPipeliner = &Pipeliner{ logger: log.NewNopLogger(), steps: []step{}, } - noLogPipeliner = newPipeliner(log.NewNopLogger(), false) + noLogPipeliner = NewPipeliner(log.NewNopLogger(), false) ) func TestPipelineNoop(t *testing.T) { diff --git a/internal/search/model_searched_entity.go b/internal/search/model_searched_entity.go new file mode 100644 index 00000000..59b3ba0d --- /dev/null +++ b/internal/search/model_searched_entity.go @@ -0,0 +1,11 @@ +package search + +import ( + "github.com/moov-io/watchman/pkg/search" +) + +type SearchedEntity[T any] struct { + search.Entity[T] + + Match float64 `json:"match"` +} diff --git a/internal/search/model_searched_entity_test.go b/internal/search/model_searched_entity_test.go new file mode 100644 index 00000000..4679cd0f --- /dev/null +++ b/internal/search/model_searched_entity_test.go @@ -0,0 +1,50 @@ +package search + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/moov-io/watchman/pkg/search" + + "github.com/stretchr/testify/require" +) + +func TestSearchedEntityJSON(t *testing.T) { + type SDN struct { + EntityID string `json:"entityID"` + } + + bs, err := json.MarshalIndent(SearchedEntity[SDN]{ + Entity: search.Entity[SDN]{ + SourceData: SDN{ + EntityID: "12345", + }, + }, + Match: 0.6401, + }, "", " ") + require.NoError(t, err) + + expected := 
strings.TrimSpace(`{ + "name": "", + "entityType": "", + "sourceList": "", + "sourceID": "", + "person": null, + "business": null, + "organization": null, + "aircraft": null, + "vessel": null, + "cryptoAddresses": null, + "addresses": null, + "affiliations": null, + "sanctionsInfo": null, + "historicalInfo": null, + "titles": null, + "sourceData": { + "entityID": "12345" + }, + "match": 0.6401 +}`) + require.Equal(t, expected, string(bs)) +} diff --git a/internal/search/service.go b/internal/search/service.go index 57a53841..b62cc3a0 100644 --- a/internal/search/service.go +++ b/internal/search/service.go @@ -14,6 +14,9 @@ type Service interface { } func NewService[T any](logger log.Logger, entities []search.Entity[T]) Service { + + fmt.Printf("v2search NewService(%d entities)\n", len(entities)) //nolint:forbidigo + return &service[T]{ logger: logger, entities: entities, @@ -33,4 +36,9 @@ func (s *service[T]) Search(ctx context.Context) { return } } + + // TODO(adam): use SearchedEntity + // type SearchedEntity[T any] struct { + // search.Entity[T] + // Match float64 `json:"match"` } diff --git a/internal/stringscore/jaro_winkler.go b/internal/stringscore/jaro_winkler.go new file mode 100644 index 00000000..b806d76a --- /dev/null +++ b/internal/stringscore/jaro_winkler.go @@ -0,0 +1,266 @@ +package stringscore + +import ( + "fmt" + "math" + "os" + "sort" + "strconv" + "strings" + + "github.com/moov-io/base/strx" + + "github.com/xrash/smetrics" +) + +var ( + // Jaro-Winkler parameters + boostThreshold = readFloat(os.Getenv("JARO_WINKLER_BOOST_THRESHOLD"), 0.7) + prefixSize = readInt(os.Getenv("JARO_WINKLER_PREFIX_SIZE"), 4) + // Customised Jaro-Winkler parameters + lengthDifferenceCutoffFactor = readFloat(os.Getenv("LENGTH_DIFFERENCE_CUTOFF_FACTOR"), 0.9) + lengthDifferencePenaltyWeight = readFloat(os.Getenv("LENGTH_DIFFERENCE_PENALTY_WEIGHT"), 0.3) + differentLetterPenaltyWeight = readFloat(os.Getenv("DIFFERENT_LETTER_PENALTY_WEIGHT"), 0.9) + + // Watchman 
parameters + exactMatchFavoritism = readFloat(os.Getenv("EXACT_MATCH_FAVORITISM"), 0.0) + unmatchedIndexPenaltyWeight = readFloat(os.Getenv("UNMATCHED_INDEX_TOKEN_WEIGHT"), 0.15) +) + +func readFloat(override string, value float64) float64 { + if override != "" { + n, err := strconv.ParseFloat(override, 32) + if err != nil { + panic(fmt.Errorf("unable to parse %q as float64", override)) //nolint:forbidigo + } + return n + } + return value +} + +func readInt(override string, value int) int { + if override != "" { + n, err := strconv.ParseInt(override, 10, 32) + if err != nil { + panic(fmt.Errorf("unable to parse %q as int", override)) //nolint:forbidigo + } + return int(n) + } + return value +} + +// BestPairsJaroWinkler compares a search query to an indexed term (name, address, etc) and returns a decimal fraction +// score. +// +// The algorithm splits each string into tokens, and does a pairwise Jaro-Winkler score of all token combinations +// (outer product). The best match for each search token is chosen, such that each index token can be matched at most +// once. +// +// The pairwise scores are combined into an average in a way that corrects for character length, and the fraction of the +// indexed term that didn't match. +func BestPairsJaroWinkler(searchTokens []string, indexed string) float64 { + type Score struct { + score float64 + searchTokenIdx int + indexTokenIdx int + } + + indexedTokens := strings.Fields(indexed) + searchTokensLength := sumLength(searchTokens) + indexTokensLength := sumLength(indexedTokens) + + disablePhoneticFiltering := strx.Yes(os.Getenv("DISABLE_PHONETIC_FILTERING")) + + //Compare each search token to each indexed token. 
Sort the results in descending order + scoresCapacity := (len(searchTokens) + len(indexedTokens)) + if !disablePhoneticFiltering { + scoresCapacity /= 5 // reduce the capacity as many terms don't phonetically match + } + scores := make([]Score, 0, scoresCapacity) + for searchIdx, searchToken := range searchTokens { + for indexIdx, indexedToken := range indexedTokens { + // Compare the first letters phonetically and only run jaro-winkler on those which are similar + if disablePhoneticFiltering || firstCharacterSoundexMatch(indexedToken, searchToken) { + score := customJaroWinkler(indexedToken, searchToken) + scores = append(scores, Score{score, searchIdx, indexIdx}) + } + } + } + sort.Slice(scores[:], func(i, j int) bool { + return scores[i].score > scores[j].score + }) + + //Pick the highest score for each search term, where the indexed token hasn't yet been matched + matchedSearchTokens := make([]bool, len(searchTokens)) + matchedIndexTokens := make([]bool, len(indexedTokens)) + matchedIndexTokensLength := 0 + totalWeightedScores := 0.0 + for _, score := range scores { + //If neither the search token nor index token have been matched so far + if !matchedSearchTokens[score.searchTokenIdx] && !matchedIndexTokens[score.indexTokenIdx] { + //Weight the importance of this word score by its character length + searchToken := searchTokens[score.searchTokenIdx] + indexToken := indexedTokens[score.indexTokenIdx] + totalWeightedScores += score.score * float64(len(searchToken)+len(indexToken)) + + matchedSearchTokens[score.searchTokenIdx] = true + matchedIndexTokens[score.indexTokenIdx] = true + matchedIndexTokensLength += len(indexToken) + } + } + lengthWeightedAverageScore := totalWeightedScores / float64(searchTokensLength+matchedIndexTokensLength) + + //If some index tokens weren't matched by any search token, penalise this search a small amount. If this isn't done, + //a query of "John Doe" will match "John Doe" and "John Bartholomew Doe" equally well. 
+ //Calculate the fraction of the index name that wasn't matched, apply a weighting to reduce the importance of + //unmatched portion, then scale down the final score. + matchedIndexLength := 0 + for i, str := range indexedTokens { + if matchedIndexTokens[i] { + matchedIndexLength += len(str) + } + } + matchedFraction := float64(matchedIndexLength) / float64(indexTokensLength) + return lengthWeightedAverageScore * scalingFactor(matchedFraction, unmatchedIndexPenaltyWeight) +} + +func customJaroWinkler(s1 string, s2 string) float64 { + score := smetrics.JaroWinkler(s1, s2, boostThreshold, prefixSize) + + if lengthMetric := lengthDifferenceFactor(s1, s2); lengthMetric < lengthDifferenceCutoffFactor { + //If there's a big difference in matched token lengths, punish the score. Jaro-Winkler is quite permissive about + //different lengths + score = score * scalingFactor(lengthMetric, lengthDifferencePenaltyWeight) + } + if s1[0] != s2[0] { + //Penalise words that start with a different characters. Jaro-Winkler is too lenient on this + //TODO should use a phonetic comparison here, like Soundex + score = score * differentLetterPenaltyWeight + } + return score +} + +// scalingFactor returns a float [0,1] that can be used to scale another number down, given some metric and a desired +// weight +// e.g. 
If a score has a 50% value according to a metric, and we want a 10% weight to the metric: +// +// scaleFactor := scalingFactor(0.5, 0.1) // 0.95 +// scaledScore := score * scaleFactor +func scalingFactor(metric float64, weight float64) float64 { + return 1.0 - (1.0-metric)*weight +} + +func sumLength(strs []string) int { + totalLength := 0 + for _, str := range strs { + totalLength += len(str) + } + return totalLength +} + +func lengthDifferenceFactor(s1 string, s2 string) float64 { + ls1 := float64(len(s1)) + ls2 := float64(len(s2)) + min := math.Min(ls1, ls2) + max := math.Max(ls1, ls2) + return min / max +} + +// jaroWinkler runs the similarly named algorithm over the two input strings and averages their match percentages +// according to the second string (assumed to be the user's query) +// +// Terms are compared between a few adjacent terms and accumulate the highest near-neighbor match. +// +// For more details see https://en.wikipedia.org/wiki/Jaro%E2%80%93Winkler_distance +func JaroWinkler(s1, s2 string) float64 { + return JaroWinklerWithFavoritism(s1, s2, exactMatchFavoritism) +} + +var ( + adjacentSimilarityPositions = readInt(os.Getenv("ADJACENT_SIMILARITY_POSITIONS"), 3) +) + +func JaroWinklerWithFavoritism(indexedTerm, query string, favoritism float64) float64 { + maxMatch := func(indexedWord string, indexedWordIdx int, queryWords []string) (float64, string) { + if indexedWord == "" || len(queryWords) == 0 { + return 0.0, "" + } + + // We're only looking for the highest match close + start := indexedWordIdx - adjacentSimilarityPositions + end := indexedWordIdx + adjacentSimilarityPositions + + var max float64 + var maxTerm string + for i := start; i < end; i++ { + if i >= 0 && len(queryWords) > i { + score := smetrics.JaroWinkler(indexedWord, queryWords[i], boostThreshold, prefixSize) + if score > max { + max = score + maxTerm = queryWords[i] + } + } + } + return max, maxTerm + } + + indexedWords, queryWords := strings.Fields(indexedTerm), 
strings.Fields(query) + if len(indexedWords) == 0 || len(queryWords) == 0 { + return 0.0 // avoid returning NaN later on + } + + var scores []float64 + for i := range indexedWords { + max, term := maxMatch(indexedWords[i], i, queryWords) + //fmt.Printf("%s maxMatch %s %f\n", indexedWords[i], term, max) + if max >= 1.0 { + // If the query is longer than our indexed term (and EITHER are longer than most names) + // we want to reduce the maximum weight proportionally by the term difference, which + // forces more terms to match instead of one or two dominating the weight. + if (len(queryWords) > len(indexedWords)) && (len(indexedWords) > 3 || len(queryWords) > 3) { + max *= (float64(len(indexedWords)) / float64(len(queryWords))) + goto add + } + // If the indexed term is really short cap the match at 90%. + // This sill allows names to match highly with a couple different characters. + if len(indexedWords) == 1 && len(queryWords) > 1 { + max *= 0.9 + goto add + } + // Otherwise, apply Perfect match favoritism + max += favoritism + add: + scores = append(scores, max) + } else { + // If there are more terms in the user's query than what's indexed then + // adjust the max lower by the proportion of different terms. + // + // We do this to decrease the importance of a short (often common) term. + if len(queryWords) > len(indexedWords) { + scores = append(scores, max*float64(len(indexedWords))/float64(len(queryWords))) + continue + } + + // Apply an additional weight based on similarity of term lengths, + // so terms which are closer in length match higher. + s1 := float64(len(indexedWords[i])) + t := float64(len(term)) - 1 + weight := math.Min(math.Abs(s1/t), 1.0) + + scores = append(scores, max*weight) + } + } + + // average the highest N scores where N is the words in our query (query). + // Only truncate scores if there are enough words (aka more than First/Last). 
+ sort.Float64s(scores) + if len(indexedWords) > len(queryWords) && len(queryWords) > 5 { + scores = scores[len(indexedWords)-len(queryWords):] + } + + var sum float64 + for i := range scores { + sum += scores[i] + } + return math.Min(sum/float64(len(scores)), 1.00) +} diff --git a/internal/stringscore/jaro_winkler_test.go b/internal/stringscore/jaro_winkler_test.go new file mode 100644 index 00000000..33aea850 --- /dev/null +++ b/internal/stringscore/jaro_winkler_test.go @@ -0,0 +1,200 @@ +// Copyright The Moov Authors +// Use of this source code is governed by an Apache License +// license that can be found in the LICENSE file. + +package stringscore_test + +import ( + "crypto/rand" + "fmt" + "io" + "math" + "math/big" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/moov-io/watchman/internal/prepare" + "github.com/moov-io/watchman/internal/stringscore" + "github.com/moov-io/watchman/pkg/ofac" + + "github.com/jaswdr/faker" + "github.com/stretchr/testify/require" +) + +func TestJaroWinkler(t *testing.T) { + cases := []struct { + indexed, search string + match float64 + }{ + // examples + {"wei, zhao", "wei, Zhao", 0.875}, + {"WEI, Zhao", "WEI, Zhao", 1.0}, + {"WEI Zhao", "WEI Zhao", 1.0}, + {strings.ToLower("WEI Zhao"), prepare.LowerAndRemovePunctuation("WEI, Zhao"), 1.0}, + + // apply jaroWinkler in both directions + {"jane doe", "jan lahore", 0.439}, + {"jan lahore", "jane doe", 0.549}, + + // real world case + {"john doe", "paul john", 0.624}, + {"john doe", "john othername", 0.440}, + + // close match + {"jane doe", "jane doe2", 0.940}, + + // real-ish world examples + {"kalamity linden", "kala limited", 0.687}, + {"kala limited", "kalamity linden", 0.687}, + + // examples used in demos / commonly + {"nicolas", "nicolas", 1.0}, + {"nicolas moros maduro", "nicolas maduro", 0.958}, + {"nicolas maduro", "nicolas moros maduro", 0.839}, + + // customer examples + {"ian", "ian mckinley", 0.429}, + {"iap", "ian mckinley", 0.352}, + {"ian mckinley", 
"ian", 0.891}, + {"ian mckinley", "iap", 0.733}, + {"ian mckinley", "tian xiang 7", 0.000}, + {"bindaree food group pty", prepare.LowerAndRemovePunctuation("independent insurance group ltd"), 0.269}, // removes ltd + {"bindaree food group pty ltd", "independent insurance group ltd", 0.401}, // only matches higher from 'ltd' + {"p.c.c. (singapore) private limited", "culver max entertainment private limited", 0.514}, + {"zincum llc", "easy verification inc.", 0.000}, + {"transpetrochart co ltd", "jx metals trading co.", 0.431}, + {"technolab", "moomoo technologies inc", 0.565}, + {"sewa security services", "sesa - safety & environmental services australia pty ltd", 0.480}, + {"bueno", "20/f rykadan capital twr135 hoi bun rd, kwun tong 135 hoi bun rd., kwun tong", 0.094}, + + // example cases + {"nicolas maduro", "nicolás maduro", 0.937}, + {"nicolas maduro", prepare.LowerAndRemovePunctuation("nicolás maduro"), 1.0}, + {"nic maduro", "nicolas maduro", 0.872}, + {"nick maduro", "nicolas maduro", 0.859}, + {"nicolas maduroo", "nicolas maduro", 0.966}, + {"nicolas maduro", "nicolas maduro", 1.0}, + {"maduro, nicolas", "maduro, nicolas", 1.0}, + {"maduro moros, nicolas", "maduro moros, nicolas", 1.0}, + {"maduro moros, nicolas", "nicolas maduro", 0.953}, + {"nicolas maduro moros", "maduro", 0.900}, + {"nicolas maduro moros", "nicolás maduro", 0.898}, + {"nicolas, maduro moros", "maduro", 0.897}, + {"nicolas, maduro moros", "nicolas maduro", 0.928}, + {"nicolas, maduro moros", "nicolás", 0.822}, + {"nicolas, maduro moros", "maduro", 0.897}, + {"nicolas, maduro moros", "nicolás maduro", 0.906}, + {"africada financial services bureau change", "skylight", 0.441}, + {"africada financial services bureau change", "skylight financial inc", 0.658}, + {"africada financial services bureau change", "skylight services inc", 0.599}, + {"africada financial services bureau change", "skylight financial services", 0.761}, + {"africada financial services bureau change", "skylight financial 
services inc", 0.730}, + + // stopwords tests + {"the group for the preservation of the holy sites", "the bridgespan group", 0.682}, + {prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), prepare.LowerAndRemovePunctuation("the bridgespan group"), 0.682}, + {"group preservation holy sites", "bridgespan group", 0.652}, + + {"the group for the preservation of the holy sites", "the logan group", 0.670}, + {prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), prepare.LowerAndRemovePunctuation("the logan group"), 0.670}, + {"group preservation holy sites", "logan group", 0.586}, + + {"the group for the preservation of the holy sites", "the anything group", 0.546}, + {prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), prepare.LowerAndRemovePunctuation("the anything group"), 0.546}, + {"group preservation holy sites", "anything group", 0.488}, + + {"the group for the preservation of the holy sites", "the hello world group", 0.637}, + {prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), prepare.LowerAndRemovePunctuation("the hello world group"), 0.637}, + {"group preservation holy sites", "hello world group", 0.577}, + + {"the group for the preservation of the holy sites", "the group", 0.880}, + {prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), prepare.LowerAndRemovePunctuation("the group"), 0.880}, + {"group preservation holy sites", "group", 0.879}, + + {"the group for the preservation of the holy sites", "The flibbity jibbity flobbity jobbity grobbity zobbity group", 0.345}, + { + prepare.LowerAndRemovePunctuation("the group for the preservation of the holy sites"), + prepare.LowerAndRemovePunctuation("the flibbity jibbity flobbity jobbity grobbity zobbity group"), + 0.366, + }, + {"group preservation holy sites", "flibbity jibbity flobbity jobbity grobbity zobbity group", 0.263}, + + // 
prepare.LowerAndRemovePunctuation + {"i c sogo kenkyusho", prepare.LowerAndRemovePunctuation("A.I.C. SOGO KENKYUSHO"), 0.858}, + {prepare.LowerAndRemovePunctuation("A.I.C. SOGO KENKYUSHO"), "sogo kenkyusho", 0.972}, + } + for i := range cases { + v := cases[i] + // Only need to call chomp on s1, see stringscore.JaroWinkler doc + eql(t, fmt.Sprintf("#%d %s vs %s", i, v.indexed, v.search), stringscore.BestPairsJaroWinkler(strings.Fields(v.search), v.indexed), v.match) + } +} + +func TestJaroWinklerWithFavoritism(t *testing.T) { + favoritism := 1.0 + delta := 0.01 + + score := stringscore.JaroWinklerWithFavoritism("Vladimir Putin", "PUTIN, Vladimir Vladimirovich", favoritism) + require.InDelta(t, score, 1.00, delta) + + score = stringscore.JaroWinklerWithFavoritism("nicolas, maduro moros", "nicolás maduro", 0.25) + require.InDelta(t, score, 0.96, delta) + + score = stringscore.JaroWinklerWithFavoritism("Vladimir Putin", "A.I.C. SOGO KENKYUSHO", favoritism) + require.InDelta(t, score, 0.00, delta) +} + +func TestJaroWinklerErr(t *testing.T) { + v := stringscore.JaroWinkler("", "hello") + eql(t, "NaN #1", v, 0.0) + + v = stringscore.JaroWinkler("hello", "") + eql(t, "NaN #2", v, 0.0) +} + +func eql(t *testing.T, desc string, x, y float64) { + t.Helper() + if math.IsNaN(x) || math.IsNaN(y) { + t.Fatalf("%s: x=%.2f y=%.2f", desc, x, y) + } + if math.Abs(x-y) > 0.01 { + t.Errorf("%s: %.3f != %.3f", desc, x, y) + } +} + +func TestEql(t *testing.T) { + eql(t, "", 0.1, 0.1) + eql(t, "", 0.0001, 0.00002) +} + +func BenchmarkJaroWinkler(b *testing.B) { + fd, err := os.Open(filepath.Join("..", "..", "test", "testdata", "sdn.csv")) + if err != nil { + b.Error(err) + } + results, err := ofac.Read(map[string]io.ReadCloser{"sdn.csv": fd}) + require.NoError(b, err) + require.Len(b, results.SDNs, 7379) + + randomIndex := func(length int) int { + n, err := rand.Int(rand.Reader, big.NewInt(1e9)) + if err != nil { + panic(err) + } + return int(n.Int64()) % length + } + + fake := faker.New() + 
b.Run("BestPairsJaroWinkler", func(b *testing.B) { + for i := 0; i < b.N; i++ { + nameTokens := strings.Fields(fake.Person().Name()) + idx := randomIndex(len(results.SDNs)) + + score := stringscore.BestPairsJaroWinkler(nameTokens, results.SDNs[idx].SDNName) + require.Greater(b, score, -0.01) + } + }) +} diff --git a/cmd/server/new_algorithm_test.go b/internal/stringscore/new_algorithm_test.go similarity index 92% rename from cmd/server/new_algorithm_test.go rename to internal/stringscore/new_algorithm_test.go index 2a24cc68..8c60cd7e 100644 --- a/cmd/server/new_algorithm_test.go +++ b/internal/stringscore/new_algorithm_test.go @@ -1,12 +1,14 @@ -// Copyright 2022 The Moov Authors +// Copyright The Moov Authors // Use of this source code is governed by an Apache License // license that can be found in the LICENSE file. -package main +package stringscore_test import ( "strings" "testing" + + "github.com/moov-io/watchman/internal/stringscore" ) func TestBestPairsJaroWinkler__FalsePositives(t *testing.T) { @@ -70,7 +72,7 @@ func TestBestPairsJaroWinkler__TruePositives(t *testing.T) { } func compareAlgorithms(indexedName string, query string) (float64, float64) { - oldScore := jaroWinkler(indexedName, query) - newScore := bestPairsJaroWinkler(strings.Fields(query), indexedName) + oldScore := stringscore.JaroWinkler(indexedName, query) + newScore := stringscore.BestPairsJaroWinkler(strings.Fields(query), indexedName) return oldScore, newScore } diff --git a/cmd/server/phonetics.go b/internal/stringscore/phonetics.go similarity index 98% rename from cmd/server/phonetics.go rename to internal/stringscore/phonetics.go index 18107a33..cec3d6ae 100644 --- a/cmd/server/phonetics.go +++ b/internal/stringscore/phonetics.go @@ -1,4 +1,4 @@ -package main +package stringscore import ( "unicode" diff --git a/cmd/server/phonetics_test.go b/internal/stringscore/phonetics_test.go similarity index 89% rename from cmd/server/phonetics_test.go rename to 
internal/stringscore/phonetics_test.go index e25ef4b0..0e17d1cd 100644 --- a/cmd/server/phonetics_test.go +++ b/internal/stringscore/phonetics_test.go @@ -1,4 +1,4 @@ -package main +package stringscore import ( "strings" @@ -24,12 +24,12 @@ func TestDisablePhoneticFiltering(t *testing.T) { indexed := "tian xiang 7" t.Setenv("DISABLE_PHONETIC_FILTERING", "no") - score := bestPairsJaroWinkler(search, indexed) + score := BestPairsJaroWinkler(search, indexed) require.InDelta(t, 0.00, score, 0.01) // Disable filtering (force the compare) t.Setenv("DISABLE_PHONETIC_FILTERING", "yes") - score = bestPairsJaroWinkler(search, indexed) + score = BestPairsJaroWinkler(search, indexed) require.InDelta(t, 0.544, score, 0.01) } diff --git a/pkg/search/similarity.go b/pkg/search/similarity.go new file mode 100644 index 00000000..c50e4e36 --- /dev/null +++ b/pkg/search/similarity.go @@ -0,0 +1,388 @@ +package search + +import ( + "math" + "strings" + "time" +) + +func Similarity[Q any, I any](query Entity[Q], index Entity[I]) float64 { + var parts []partial + + // 1) Compare top-level entity fields + score, weight := compareStringField(query.Name, index.Name, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + score, weight = compareStringField(string(query.Type), string(index.Type), 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + score, weight = compareStringField(string(query.Source), string(index.Source), 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + score, weight = compareStringField(query.SourceID, index.SourceID, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + // Titles (slice of strings) + score, weight = compareStringSlice(query.Titles, index.Titles, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + // Person + if query.Person != nil && index.Person != nil { + // Name + score, weight = compareStringField(query.Person.Name, index.Person.Name, 2.0) + parts = 
append(parts, partial{Score: score, Weight: weight}) + // Gender + score, weight = compareStringField(string(query.Person.Gender), string(index.Person.Gender), 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // BirthDate + score, weight = compareDateField(query.Person.BirthDate, index.Person.BirthDate, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // DeathDate + score, weight = compareDateField(query.Person.DeathDate, index.Person.DeathDate, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + // AltNames + score, weight = compareStringSlice(query.Person.AltNames, index.Person.AltNames, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + // Titles + score, weight = compareStringSlice(query.Person.Titles, index.Person.Titles, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // GovernmentIDs (naive) + if len(query.Person.GovernmentIDs) > 0 && len(index.Person.GovernmentIDs) > 0 { + var qIDs, iIDs []string + for _, gid := range query.Person.GovernmentIDs { + qIDs = append(qIDs, string(gid.Type), gid.Country, gid.Identifier) + } + for _, gid := range index.Person.GovernmentIDs { + iIDs = append(iIDs, string(gid.Type), gid.Country, gid.Identifier) + } + score, weight = compareStringSlice(qIDs, iIDs, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + } + } + + // Business + if query.Business != nil && index.Business != nil { + score, weight = compareStringField(query.Business.Name, index.Business.Name, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareDateField(query.Business.Created, index.Business.Created, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareDateField(query.Business.Dissolved, index.Business.Dissolved, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // Compare Identifiers + if len(query.Business.Identifier) > 0 && 
len(index.Business.Identifier) > 0 { + var qIDs, iIDs []string + for _, id := range query.Business.Identifier { + qIDs = append(qIDs, id.Name, id.Country, id.Identifier) + } + for _, id := range index.Business.Identifier { + iIDs = append(iIDs, id.Name, id.Country, id.Identifier) + } + score, weight = compareStringSlice(qIDs, iIDs, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + } + } + + // Organization + if query.Organization != nil && index.Organization != nil { + score, weight = compareStringField(query.Organization.Name, index.Organization.Name, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareDateField(query.Organization.Created, index.Organization.Created, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareDateField(query.Organization.Dissolved, index.Organization.Dissolved, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + if len(query.Organization.Identifier) > 0 && len(index.Organization.Identifier) > 0 { + var qIDs, iIDs []string + for _, id := range query.Organization.Identifier { + qIDs = append(qIDs, id.Name, id.Country, id.Identifier) + } + for _, id := range index.Organization.Identifier { + iIDs = append(iIDs, id.Name, id.Country, id.Identifier) + } + score, weight = compareStringSlice(qIDs, iIDs, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + } + } + + // Aircraft + if query.Aircraft != nil && index.Aircraft != nil { + score, weight = compareStringField(query.Aircraft.Name, index.Aircraft.Name, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(string(query.Aircraft.Type), string(index.Aircraft.Type), 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Aircraft.Flag, index.Aircraft.Flag, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = 
compareDateField(query.Aircraft.Built, index.Aircraft.Built, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Aircraft.ICAOCode, index.Aircraft.ICAOCode, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Aircraft.Model, index.Aircraft.Model, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Aircraft.SerialNumber, index.Aircraft.SerialNumber, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + } + + // Vessel + if query.Vessel != nil && index.Vessel != nil { + score, weight = compareStringField(query.Vessel.Name, index.Vessel.Name, 2.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Vessel.IMONumber, index.Vessel.IMONumber, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(string(query.Vessel.Type), string(index.Vessel.Type), 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Vessel.Flag, index.Vessel.Flag, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareDateField(query.Vessel.Built, index.Vessel.Built, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Vessel.Model, index.Vessel.Model, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // Tonnage + if query.Vessel.Tonnage > 0 && index.Vessel.Tonnage > 0 { + diff := math.Abs(float64(query.Vessel.Tonnage) - float64(index.Vessel.Tonnage)) + var matchScore float64 + if diff == 0 { + matchScore = 1.0 + } else if diff < 500 { + matchScore = 0.5 + } else { + matchScore = 0 + } + parts = append(parts, partial{Score: matchScore, Weight: 1.0}) + } + score, weight = compareStringField(query.Vessel.MMSI, index.Vessel.MMSI, 1.0) + parts 
= append(parts, partial{Score: score, Weight: weight}) + score, weight = compareStringField(query.Vessel.CallSign, index.Vessel.CallSign, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + // GrossRegisteredTonnage + if query.Vessel.GrossRegisteredTonnage > 0 && index.Vessel.GrossRegisteredTonnage > 0 { + diff := math.Abs(float64(query.Vessel.GrossRegisteredTonnage) - + float64(index.Vessel.GrossRegisteredTonnage)) + var matchScore float64 + if diff == 0 { + matchScore = 1.0 + } else if diff < 500 { + matchScore = 0.5 + } else { + matchScore = 0 + } + parts = append(parts, partial{Score: matchScore, Weight: 1.0}) + } + score, weight = compareStringField(query.Vessel.Owner, index.Vessel.Owner, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + } + + // CryptoAddresses + score, weight = compareCryptoAddresses(query.CryptoAddresses, index.CryptoAddresses, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + + // Addresses + score, weight = compareAddresses(query.Addresses, index.Addresses, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + + // Affiliations + score, weight = compareAffiliations(query.Affiliations, index.Affiliations, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + // SanctionsInfo + score, weight = compareSanctionsInfo(query.SanctionsInfo, index.SanctionsInfo, 1.5) + parts = append(parts, partial{Score: score, Weight: weight}) + + // HistoricalInfo + score, weight = compareHistoricalInfo(query.HistoricalInfo, index.HistoricalInfo, 1.0) + parts = append(parts, partial{Score: score, Weight: weight}) + + // SourceData (T) fields are not included in the scoring + + return combineScores(parts) +} + +// jaroWinklerDistance is a placeholder for your advanced string comparison. 
+func jaroWinklerDistance(a, b string) float64 { + a = strings.TrimSpace(strings.ToLower(a)) + b = strings.TrimSpace(strings.ToLower(b)) + if a == "" || b == "" { + return 0.0 + } + if a == b { + return 1.0 + } + // Replace with real logic + return 0.5 +} + +// compareStringField returns (score, weight). +// If the query is empty, the field is skipped (no penalty, no weight). +func compareStringField(queryVal, indexVal string, weight float64) (float64, float64) { + q := strings.TrimSpace(queryVal) + i := strings.TrimSpace(indexVal) + + if q == "" { + return 0, 0 + } + if i == "" { + return 0, weight + } + dist := jaroWinklerDistance(q, i) + return dist * weight, weight +} + +// compareDateField treats nil query as "skip," nil index as mismatch, and +// otherwise uses a simplistic "within 1 year => full match, within 5 => partial" strategy. +func compareDateField(queryVal, indexVal *time.Time, weight float64) (float64, float64) { + if queryVal == nil { + return 0, 0 + } + if indexVal == nil { + return 0, weight + } + diffYears := math.Abs(queryVal.Sub(*indexVal).Hours() / 24 / 365) + switch { + case diffYears < 1: + return weight, weight + case diffYears < 5: + return 0.5 * weight, weight + default: + return 0, weight + } +} + +// compareStringSlice does a naive match by concatenating slices and comparing as one big string. +// In real systems, you might do a best-match approach or measure how many elements overlap, etc. +func compareStringSlice(queryVals, indexVals []string, weight float64) (float64, float64) { + if len(queryVals) == 0 { + return 0, 0 + } + if len(indexVals) == 0 { + return 0, weight + } + query := strings.Join(queryVals, " ") + index := strings.Join(indexVals, " ") + return compareStringField(query, index, weight) +} + +// compareAddresses is a placeholder. In reality, you'd want something more sophisticated +// (e.g., best-match across addresses, geospatial closeness, etc.). 
+func compareAddresses(qAddrs, iAddrs []Address, weight float64) (float64, float64) { + if len(qAddrs) == 0 { + return 0, 0 + } + if len(iAddrs) == 0 { + return 0, weight + } + // Naive approach: just compare the first address line1, line2, city, etc. as a concatenated string + var qParts, iParts []string + for _, a := range qAddrs { + qParts = append(qParts, a.Line1, a.Line2, a.City, a.State, a.PostalCode, a.Country) + } + for _, a := range iAddrs { + iParts = append(iParts, a.Line1, a.Line2, a.City, a.State, a.PostalCode, a.Country) + } + query := strings.Join(qParts, " ") + index := strings.Join(iParts, " ") + return compareStringField(query, index, weight) +} + +// compareCryptoAddresses is a placeholder that just compares them all as one big string. +func compareCryptoAddresses(qAddrs, iAddrs []CryptoAddress, weight float64) (float64, float64) { + if len(qAddrs) == 0 { + return 0, 0 + } + if len(iAddrs) == 0 { + return 0, weight + } + var qParts, iParts []string + for _, ca := range qAddrs { + qParts = append(qParts, ca.Currency, ca.Address) + } + for _, ca := range iAddrs { + iParts = append(iParts, ca.Currency, ca.Address) + } + query := strings.Join(qParts, " ") + index := strings.Join(iParts, " ") + return compareStringField(query, index, weight) +} + +// compareAffiliations is another naive approach. +// You could do more advanced logic to match entity names, types, etc. 
+func compareAffiliations(qAffs, iAffs []Affiliation, weight float64) (float64, float64) { + if len(qAffs) == 0 { + return 0, 0 + } + if len(iAffs) == 0 { + return 0, weight + } + // Just combine them all + var qParts, iParts []string + for _, aff := range qAffs { + qParts = append(qParts, aff.EntityName, aff.Type, aff.Details) + } + for _, aff := range iAffs { + iParts = append(iParts, aff.EntityName, aff.Type, aff.Details) + } + query := strings.Join(qParts, " ") + index := strings.Join(iParts, " ") + return compareStringField(query, index, weight) +} + +// compareSanctionsInfo is naive. Potentially you'd do fuzzy set matching of programs, etc. +func compareSanctionsInfo(qInfo, iInfo *SanctionsInfo, weight float64) (float64, float64) { + if qInfo == nil { + return 0, 0 + } + if iInfo == nil { + return 0, weight + } + // Combine programs and description + query := strings.Join(qInfo.Programs, " ") + " " + qInfo.Description + index := strings.Join(iInfo.Programs, " ") + " " + iInfo.Description + score, w := compareStringField(query, index, weight) + // If one is "secondary" and the other isn't, reduce score + if qInfo.Secondary != iInfo.Secondary && score > 0 { + score *= 0.5 + } + return score, w +} + +// compareHistoricalInfo is naive. You might want date checks, type matching, etc. 
+func compareHistoricalInfo(qHist, iHist []HistoricalInfo, weight float64) (float64, float64) { + if len(qHist) == 0 { + return 0, 0 + } + if len(iHist) == 0 { + return 0, weight + } + var qParts, iParts []string + for _, h := range qHist { + qParts = append(qParts, h.Type, h.Value) + // If you want to compare the date, you'd do it similarly to compareDateField + } + for _, h := range iHist { + iParts = append(iParts, h.Type, h.Value) + } + query := strings.Join(qParts, " ") + index := strings.Join(iParts, " ") + return compareStringField(query, index, weight) +} + +type partial struct { + Score float64 + Weight float64 +} + +// combineScores sums partials into a final ratio in [0..1]. +func combineScores(partials []partial) float64 { + var totalScore, totalWeight float64 + for _, p := range partials { + totalScore += p.Score + totalWeight += p.Weight + } + if totalWeight == 0 { + return 0 + } + ratio := totalScore / totalWeight + if ratio < 0 { + return 0 + } else if ratio > 1 { + return 1 + } + return ratio +} diff --git a/pkg/search/similarity_ofac_test.go b/pkg/search/similarity_ofac_test.go new file mode 100644 index 00000000..4b319507 --- /dev/null +++ b/pkg/search/similarity_ofac_test.go @@ -0,0 +1,87 @@ +package search_test + +import ( + "testing" + + "github.com/moov-io/watchman/pkg/ofac" + "github.com/moov-io/watchman/pkg/search" + + "github.com/stretchr/testify/require" +) + +func TestSimilarity_OFAC_SDN_Vessel(t *testing.T) { + fullSDN := ofac.SDN{ + EntityID: "123", + SDNName: "TANKER VESSEL", + SDNType: "vessel", + Programs: []string{"SDGT", "IRGC"}, + CallSign: "ABCD1234", + VesselType: "Cargo", + Tonnage: "10000", + GrossRegisteredTonnage: "12000", + VesselFlag: "US", + VesselOwner: "BIG SHIPPING INC.", + Remarks: "Test remarks", + } + + indexEntity := ofac.ToEntity(fullSDN, nil, nil) + + testCases := []struct { + name string + query search.Entity[any] + expected float64 + }{ + { + name: "Perfect match", + query: search.Entity[any]{ + Name: "TANKER 
VESSEL", + Type: search.EntityVessel, + Vessel: &search.Vessel{ + Name: "TANKER VESSEL", + Type: search.VesselType("Cargo"), + Flag: "US", + Tonnage: 10000, + CallSign: "ABCD1234", + GrossRegisteredTonnage: 12000, + Owner: "BIG SHIPPING INC.", + }, + }, + expected: 0.80, + }, + { + name: "Partial match (some fields differ)", + query: search.Entity[any]{ + Name: "Tanker Vessel", // close match (capitalization differs) + Type: search.EntityVessel, + Vessel: &search.Vessel{ + Name: "Tanker Vessel", + Type: search.VesselType("Cargo"), + Flag: "GB", // mismatch + Tonnage: 9500, // partial mismatch + CallSign: "ABCD1234", + GrossRegisteredTonnage: 12000, + Owner: "BIG SHIPPING Inc", // minor difference + }, + }, + expected: 0.75, + }, + { + name: "Mismatch (completely different)", + query: search.Entity[any]{ + Name: "Random Business", + Type: search.EntityBusiness, + Business: &search.Business{ + Name: "Random Business", + }, + }, + expected: 0.5, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + score := search.Similarity(tc.query, indexEntity) + require.InDelta(t, tc.expected, score, 0.001) + }) + } +}