Commit 9c0b5b7a authored by Rob Tomsick's avatar Rob Tomsick

Make substantial improvements to scoring/sorting

parent 6d8fd0ea
...@@ -53,7 +53,7 @@ import java.util.stream.Collectors; ...@@ -53,7 +53,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair;
import org.jooq.Condition; import org.jooq.Condition;
import org.jooq.DSLContext; import org.jooq.DSLContext;
import org.jooq.Field; import org.jooq.Field;
...@@ -196,7 +196,14 @@ implements DictionaryService ...@@ -196,7 +196,14 @@ implements DictionaryService
} }
return drugs.stream() return drugs.stream()
.sorted(comparatorFor(query, d -> d.canonicalName())) .sorted((a, b) -> {
double score =
biScore(query.toUpperCase(), b.canonicalName().toUpperCase())
-
biScore(query.toUpperCase(), a.canonicalName().toUpperCase());
return (int) Math.round(score * 1000);
})
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
...@@ -241,12 +248,6 @@ implements DictionaryService ...@@ -241,12 +248,6 @@ implements DictionaryService
.fetch(field(name("id"), UUID.class))); .fetch(field(name("id"), UUID.class)));
/* search NPNs */ /* search NPNs */
cond = DSL.falseCondition();
for (String hash : hashGroup)
{
cond = cond.or(hashField.like(hash));
}
ids.addAll(this.ctx ids.addAll(this.ctx
.select(field(name("id__entries"), UUID.class)) .select(field(name("id__entries"), UUID.class))
...@@ -258,37 +259,14 @@ implements DictionaryService ...@@ -258,37 +259,14 @@ implements DictionaryService
List<NDCProduct> products = this.loadProducts(new ArrayList<>(ids)); List<NDCProduct> products = this.loadProducts(new ArrayList<>(ids));
final String ucQuery = query.toUpperCase(); return products
.parallelStream()
products.sort(comparatorFor(query, p -> { .map(r -> Pair.of(r, scoreProduct(r, query)))
.filter(r -> r.getRight() >= 0.5d)
if (0.2d < StringUtils.getJaroWinklerDistance(ucQuery, .sorted((a, b) -> (int) Math.round((b.getRight() - a.getRight()) * 1000))
p.name().toUpperCase()) .limit(limit)
|| p.nonProprietaryNames().isEmpty()) .map(p -> p.getLeft())
{ .collect(Collectors.toList());
return p.name().toUpperCase();
}
/* use npns if the product name is totally dissimilar to the product name */
final List<String> names =
p.nonProprietaryNames()
.stream()
.map(s -> s.toUpperCase())
.collect(Collectors.toList());
names.sort((a, b) ->
(int) ((StringUtils.getJaroWinklerDistance(ucQuery, b)
-
StringUtils.getJaroWinklerDistance(ucQuery, a)) * 1000d));
return names.get(0);
}));
int max = (products.size() < limit ? products.size() : limit);
return new ArrayList<Product>(products.subList(0, max));
} }
@Override @Override
...@@ -395,25 +373,40 @@ implements DictionaryService ...@@ -395,25 +373,40 @@ implements DictionaryService
}); });
} }
private static final <T> Comparator<T> private static final double
comparatorFor(String reference, Function<T, String> accessor) scoreProduct(NDCProduct p, String query)
{ {
final String hash = PhoneticHash.hash(reference); double score =
biScore(p.name().toUpperCase(), query.toUpperCase());
/* NPNs */
for (String npn : p.nonProprietaryNames())
{
double ns = biScore(npn, query);
if (ns > score)
{
score = ns;
}
}
/* component drugs */
double cs =
p.components()
.stream()
.map(c -> biScore(c.drug().canonicalName().toUpperCase(), query.toUpperCase()))
.max(Double :: compare)
.orElse(0.0d);
score = Math.max(score, cs);
return (a, b) -> return score;
{ }
final String ah = PhoneticHash.hash(accessor.apply(a));
final String bh = PhoneticHash.hash(accessor.apply(b)); private static final double
biScore(String reference, String variation)
return (StringUtils.getLevenshteinDistance(reference.toUpperCase(), accessor.apply(a).toUpperCase()) {
+ return score(reference, variation)
StringUtils.getLevenshteinDistance(hash, ah)) + score(PhoneticHash.hash(reference),
- PhoneticHash.hash(variation));
(StringUtils.getLevenshteinDistance(reference.toUpperCase(), accessor.apply(b).toUpperCase())
+
StringUtils.getLevenshteinDistance(hash, bh));
};
} }
private static final Stream<String> private static final Stream<String>
...@@ -471,4 +464,36 @@ implements DictionaryService ...@@ -471,4 +464,36 @@ implements DictionaryService
} }
private static final double
score(String a, String b)
{
/* TODO a.length < n-gram size */
List<String> ang = shingle(a, 2);
List<String> bng = shingle(b, 2);
/* jaccard is size of intersection / size of union */
Set<String> intersection = new HashSet<>(ang);
intersection.retainAll(bng);
Set<String> union = new HashSet<>();
union.addAll(ang);
union.addAll(bng);
return ((double) intersection.size() / (double) union.size());
}
private static final List<String>
shingle(String a, int n)
{
List<String> ngrams = new ArrayList<>();
for (int i = 0; i < a.length() - n + 1; i++)
{
ngrams.add(a.substring(i, i + n));
}
return ngrams;
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment