Commit 4f46b7b4 authored by Jack Poulson's avatar Jack Poulson
Browse files

Adding ability to query all nearest neighbors.

parent 3c2f3a80
......@@ -676,15 +676,25 @@ def nearest_neighbors(embeddings, key, key_index_bijection, num_neighbors=20):
'''
query_index = key_index_bijection.key_to_index[key]
cosines = embeddings @ embeddings[query_index, :].transpose()
neighbor_indices = (np.argsort(cosines))[::-1]
neighbor_indices = neighbor_indices[:num_neighbors]
indices = (np.argsort(cosines))[::-1]
# Leave space for the potential later removal of the key itself.
indices = indices[:num_neighbors + 1]
neighbor_indices = []
neighbor_keys = []
neighbor_cosines = []
for neighbor_index in neighbor_indices:
neighbor_key = key_index_bijection.index_to_key[neighbor_index]
for index in indices:
neighbor_key = key_index_bijection.index_to_key[index]
if neighbor_key == key:
continue
neighbor_indices.append(index)
neighbor_keys.append(neighbor_key)
neighbor_cosines.append(cosines[neighbor_index])
neighbor_cosines.append(cosines[index])
neighbor_indices = neighbor_indices[:num_neighbors]
neighbor_keys = neighbor_keys[:num_neighbors]
neighbor_cosines = neighbor_cosines[:num_neighbors]
return (neighbor_keys, neighbor_indices, neighbor_cosines)
......
......@@ -914,6 +914,7 @@ def retrieve_fpds(start_date='2019/03/17',
url += ' DEPARTMENT_NAME:{}'.format(agency)
url += '&start={}'.format(0)
print('url: {}'.format(url))
data = xmltodict.parse(requests.get(url).text)
# Extract the offset range from the URL.
......@@ -1698,6 +1699,36 @@ def nearest_neighbors(embeddings,
return (keys, indices, cosines)
def get_all_nearest_neighbors(embeddings,
key_index_bijection,
num_neighbors=20,
only_vendors=True):
'''Returns a dictionary from every vendor to its nearest neighbors.
Args:
embeddings: The tall-skinny matrix of embeddings.
key_index_bijection: The KeyIndexBijection between keys and indices.
num_neighbors: The maximum number of neighbors to return.
only_vendors: If true, only vendors are preserved in the results.
Returns:
The tuple of the list of keys and indices of the nearest neighbors and
their cosine similarity to the query embedding. If only_vendors is True,
the vendor names are returned, rather than their corresponding (prefixed)
keys.
'''
neighbors = {}
for key in key_index_bijection.key_to_index:
vendor = vendor_from_key(key)
if vendor is None:
continue
keys, indices, cosines = nearest_neighbors(
embeddings, key, key_index_bijection, num_neighbors, only_vendors)
neighbors[key] = keys
return neighbors
def append_duns_to_vendor(vendor, duns):
'''Appends the DUNS number to the vendor string.'''
return '{} DUNS={}'.format(vendor, duns)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment