Commit 9c0b5b7a authored by Rob Tomsick

Make substantial improvements to scoring/sorting

parent 6d8fd0ea
......@@ -53,7 +53,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
......@@ -196,7 +196,14 @@ implements DictionaryService
}
return drugs.stream()
.sorted(comparatorFor(query, d -> d.canonicalName()))
.sorted((a, b) -> {
double score =
biScore(query.toUpperCase(), b.canonicalName().toUpperCase())
-
biScore(query.toUpperCase(), a.canonicalName().toUpperCase());
return (int) Math.round(score * 1000);
})
.collect(Collectors.toList());
}
......@@ -241,12 +248,6 @@ implements DictionaryService
.fetch(field(name("id"), UUID.class)));
/* search NPNs */
cond = DSL.falseCondition();
for (String hash : hashGroup)
{
cond = cond.or(hashField.like(hash));
}
ids.addAll(this.ctx
.select(field(name("id__entries"), UUID.class))
......@@ -258,37 +259,14 @@ implements DictionaryService
List<NDCProduct> products = this.loadProducts(new ArrayList<>(ids));
final String ucQuery = query.toUpperCase();
products.sort(comparatorFor(query, p -> {
if (0.2d < StringUtils.getJaroWinklerDistance(ucQuery,
p.name().toUpperCase())
|| p.nonProprietaryNames().isEmpty())
{
return p.name().toUpperCase();
}
/* use NPNs if the product name is totally dissimilar to the query */
final List<String> names =
p.nonProprietaryNames()
.stream()
.map(s -> s.toUpperCase())
return products
.parallelStream()
.map(r -> Pair.of(r, scoreProduct(r, query)))
.filter(r -> r.getRight() >= 0.5d)
.sorted((a, b) -> (int) Math.round((b.getRight() - a.getRight()) * 1000))
.limit(limit)
.map(p -> p.getLeft())
.collect(Collectors.toList());
names.sort((a, b) ->
(int) ((StringUtils.getJaroWinklerDistance(ucQuery, b)
-
StringUtils.getJaroWinklerDistance(ucQuery, a)) * 1000d));
return names.get(0);
}));
int max = (products.size() < limit ? products.size() : limit);
return new ArrayList<Product>(products.subList(0, max));
}
@Override
......@@ -395,25 +373,40 @@ implements DictionaryService
});
}
private static final <T> Comparator<T>
comparatorFor(String reference, Function<T, String> accessor)
private static final double
scoreProduct(NDCProduct p, String query)
{
final String hash = PhoneticHash.hash(reference);
double score =
biScore(p.name().toUpperCase(), query.toUpperCase());
return (a, b) ->
/* NPNs */
for (String npn : p.nonProprietaryNames())
{
double ns = biScore(npn, query);
if (ns > score)
{
final String ah = PhoneticHash.hash(accessor.apply(a));
final String bh = PhoneticHash.hash(accessor.apply(b));
score = ns;
}
}
/* component drugs */
double cs =
p.components()
.stream()
.map(c -> biScore(c.drug().canonicalName().toUpperCase(), query.toUpperCase()))
.max(Double :: compare)
.orElse(0.0d);
return (StringUtils.getLevenshteinDistance(reference.toUpperCase(), accessor.apply(a).toUpperCase())
+
StringUtils.getLevenshteinDistance(hash, ah))
-
(StringUtils.getLevenshteinDistance(reference.toUpperCase(), accessor.apply(b).toUpperCase())
+
StringUtils.getLevenshteinDistance(hash, bh));
};
score = Math.max(score, cs);
return score;
}
/**
 * Combined similarity of two strings: the {@code score} of the strings as
 * given, plus the {@code score} of their {@code PhoneticHash.hash} values.
 *
 * No case normalization happens here; callers pass the strings in whatever
 * case they want compared.
 *
 * @param reference the string being compared against
 * @param variation the candidate string
 * @return the sum of the literal score and the phonetic-hash score
 */
private static final double
biScore(String reference, String variation)
{
    /* similarity of the strings exactly as passed in */
    final double literal = score(reference, variation);

    /* similarity of their phonetic hashes */
    final String referenceHash = PhoneticHash.hash(reference);
    final String variationHash = PhoneticHash.hash(variation);
    final double phonetic = score(referenceHash, variationHash);

    return literal + phonetic;
}
private static final Stream<String>
......@@ -471,4 +464,36 @@ implements DictionaryService
}
/**
 * Scores the similarity of two strings as the Jaccard index of their
 * character bigram sets: the number of distinct bigrams common to both,
 * divided by the number of distinct bigrams found in either.
 *
 * When neither string is long enough to yield a bigram, the Jaccard index
 * is undefined (0 / 0, which evaluates to NaN in floating point). In that
 * case the strings are compared directly: 1.0 if they are equal, 0.0 if
 * they are not.
 *
 * @param a first string
 * @param b second string
 * @return a similarity between 0.0 and 1.0 inclusive; never NaN
 */
private static final double
score(String a, String b)
{
    final Set<String> aGrams = new HashSet<>(shingle(a, 2));
    final Set<String> bGrams = new HashSet<>(shingle(b, 2));

    /* jaccard is size of intersection / size of union */
    final Set<String> union = new HashSet<>(aGrams);
    union.addAll(bGrams);

    if (union.isEmpty())
    {
        /*
         * both inputs are shorter than the n-gram size, so there is
         * nothing to intersect; avoid the 0 / 0 division
         */
        return a.equals(b) ? 1.0d : 0.0d;
    }

    final Set<String> intersection = new HashSet<>(aGrams);
    intersection.retainAll(bGrams);

    return (double) intersection.size() / (double) union.size();
}
/**
 * Splits a string into its overlapping substrings of length {@code n},
 * in order of starting position (duplicates are kept).
 *
 * A string shorter than {@code n} yields an empty list.
 *
 * @param a the string to split
 * @param n the length of each substring
 * @return the substrings of length {@code n}, one per starting index
 */
private static final List<String>
shingle(String a, int n)
{
    /* number of start positions at which a full n-length window fits */
    final int windows = a.length() - n + 1;

    final List<String> grams = new ArrayList<>();
    int start = 0;
    while (start < windows)
    {
        grams.add(a.substring(start, start + n));
        start++;
    }
    return grams;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment