Commit 84dbd2b1 authored by graham

fastText for blog post

parents
fastText/
data/pretrained/
\ No newline at end of file
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
"94cd631" = {path = "./fastText", editable = true}
pandas = "*"
numpy = "*"
[dev-packages]
[requires]
python_version = "3.6"
{
"_meta": {
"hash": {
"sha256": "06f4016621f25eb6faa1e3a9b66f08e8228e15477bf817b20d7c2cb697c8483b"
},
"pipfile-spec": 6,
"requires": {
"python_version": "3.6"
},
"sources": [
{
"name": "pypi",
"url": "https://pypi.org/simple",
"verify_ssl": true
}
]
},
"default": {
"94cd631": {
"editable": true,
"path": "./fastText"
},
"numpy": {
"hashes": [
"sha256:07379fe0b450f6fd6e5934a9bc015025bb4ce1c8fbed3ca8bef29328b1bc9570",
"sha256:085afac75bbc97a096744fcfc97a4b321c5a87220286811e85089ae04885acdd",
"sha256:2d6481c6bdab1c75affc0fc71eb1bd4b3ecef620d06f2f60c3f00521d54be04f",
"sha256:2df854df882d322d5c23087a4959e145b953dfff2abe1774fec4f639ac2f3160",
"sha256:381ad13c30cd1d0b2f3da8a0c1a4aa697487e8bb0e9e0cbeb7439776bcb645f8",
"sha256:385f1ce46e08676505b692bfde918c1e0b350963a15ef52d77691c2cf0f5dbf6",
"sha256:4d278c2261be6423c5e63d8f0ceb1b0c6db3ff83f2906f4b860db6ae99ca1bb5",
"sha256:51c5dcb51cf88b34b7d04c15f600b07c6ccbb73a089a38af2ab83c02862318da",
"sha256:589336ba5199c8061239cf446ee2f2f1fcc0c68e8531ee1382b6fc0c66b2d388",
"sha256:5edf1acc827ed139086af95ce4449b7b664f57a8c29eb755411a634be280d9f2",
"sha256:6b82b81c6b3b70ed40bc6d0b71222ebfcd6b6c04a6e7945a936e514b9113d5a3",
"sha256:6c57f973218b776195d0356e556ec932698f3a563e2f640cfca7020086383f50",
"sha256:758d1091a501fd2d75034e55e7e98bfd1370dc089160845c242db1c760d944d9",
"sha256:8622db292b766719810e0cb0f62ef6141e15fe32b04e4eb2959888319e59336b",
"sha256:8b8dcfcd630f1981f0f1e3846fae883376762a0c1b472baa35b145b911683b7b",
"sha256:97fa8f1dceffab782069b291e38c4c2227f255cdac5f1e3346666931df87373e",
"sha256:9d69967673ab7b028c2df09cae05ba56bf4e39e3cb04ebe452b6035c3b49848e",
"sha256:9e1f53afae865cc32459ad211493cf9e2a3651a7295b7a38654ef3d123808996",
"sha256:a4a433b3a264dbc9aa9c7c241e87c0358a503ea6394f8737df1683c7c9a102ac",
"sha256:baadc5f770917ada556afb7651a68176559f4dca5f4b2d0947cd15b9fb84fb51",
"sha256:c725d11990a9243e6ceffe0ab25a07c46c1cc2c5dc55e305717b5afe856c9608",
"sha256:d696a8c87315a83983fc59dd27efe034292b9e8ad667aeae51a68b4be14690d9",
"sha256:e1864a4e9f93ddb2dc6b62ccc2ec1f8250ff4ac0d3d7a15c8985dd4e1fbd6418"
],
"index": "pypi",
"version": "==1.14.5"
},
"pandas": {
"hashes": [
"sha256:211cfdb9f72f26d2ede21c751d27e08fed4434d47fb9bb82ebc8ff753888b8b6",
"sha256:28fd087514616549a0e3259cd68ac88d7eaed6bd3062017a7f312e27941266bd",
"sha256:2fb7c63138bd5ead296b18b2cb6abd3a394f7581e5ae052b02b27df8244b03ca",
"sha256:372435456c349a8d39ff001967b161f6bd29d4c3de145a4cf9b366648defbb1f",
"sha256:3790a3348ab0f416e58061d21693cb662fbb2f638001b94bf2b2199fedc1b1c2",
"sha256:437a6e906a6717a9ed2627cf6e7895b63dfaa0172567cbd75a553f55cf78cc17",
"sha256:50b52af2af2e15f4aeb2fe196da073a8c131fa02e433e105d95ce40016df5690",
"sha256:720daad75b5d35dd1b446842210c4f3fd447464c9c0884972f3f12b213a9edd1",
"sha256:b4fb71acbc2709b8f5993cb4b5445d8182864f11c39787e317aae39f21206270",
"sha256:b704fd73022342cce612996de495a16954311e0c0cf077c1b83d5cf0b9656a60",
"sha256:cbbecca0c7af6a2160b2d6ba30becc286824a98c61dcc6a41fada664f226424c",
"sha256:d2a071de755cc8ee7784e1b4c7b9b643d951d35c8adea7d64fe7c57cff9c47a7",
"sha256:d8154c5c68713a82461aba735832f0b4692be8a45a0a340a303bf90d6f80f36f",
"sha256:e1b86f7c55467ce1f6c12715f2fd1817f4a909b5c8c39bd4b5d2415ef2b04bd8",
"sha256:fcc63e8134516e93e16eb4ceac9afaa51f4adc5bf58efddae7cbc562f5b77dd0"
],
"index": "pypi",
"version": "==0.23.1"
},
"pybind11": {
"hashes": [
"sha256:7f2847016313068f6fc24e8996b30345b1b8ceb74de7ea45eb2c0fa9f8fa639d",
"sha256:87ff3ae777d9326349af5272974581270b2a0909b2392dc0cc57eb28ce23bcc3"
],
"version": "==2.2.3"
},
"python-dateutil": {
"hashes": [
"sha256:1adb80e7a782c12e52ef9a8182bebeb73f1d7e24e374397af06fb4956c8dc5c0",
"sha256:e27001de32f627c22380a688bcc43ce83504a7bc5da472209b4c70f02829f0b8"
],
"version": "==2.7.3"
},
"pytz": {
"hashes": [
"sha256:65ae0c8101309c45772196b21b74c46b2e5d11b6275c45d251b150d5da334555",
"sha256:c06425302f2cf668f1bba7a0a03f3c1d34d4ebeef2c72003da308b3947c7f749"
],
"version": "==2018.4"
},
"six": {
"hashes": [
"sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9",
"sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb"
],
"version": "==1.11.0"
}
},
"develop": {}
}
import pandas as pd
import numpy as np
import csv
def read_data(file='./data/data.csv', sep='|'):
    """Load the raw dataset into a DataFrame.

    Args:
        file: Path or file-like object passed straight to
            ``pandas.read_csv``.
        sep: Field delimiter. Defaults to ``'|'``, the delimiter this
            function always used before the parameter existed.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(file, sep=sep)
def add_label(s):
    """Turn a raw category name into a fastText ``__label__`` token.

    Everything from the first ``:`` or ``(`` onwards is dropped, the rest
    is lower-cased, and whitespace-separated words are joined with ``-``.

    Args:
        s: Raw category string, e.g. ``"Refund (partial)"``.

    Returns:
        The label token, e.g. ``"__label__refund"``.
    """
    # str.split returns the whole string when the delimiter is absent,
    # so no membership test is needed before cutting off the qualifier.
    head = s.split(':')[0].split('(')[0]
    # Splitting on whitespace and re-joining strips the edges and collapses
    # inner runs; a plain replace(' ', '-') would turn the space left in
    # front of a removed "(...)" into a trailing hyphen.
    return '__label__' + '-'.join(head.lower().split())
def get_labels(df, column='category_label', none_label='__label__none'):
    """Convert the raw categories in ``column`` into fastText label tokens.

    Missing values in *every* column of ``df`` are first replaced with
    ``none_label``. Each remaining entry of ``column`` is then split on
    ``*`` (the separator between multiple categories) and every part is
    passed through ``add_label``; the resulting tokens are joined with a
    single space.

    Args:
        df: Input frame; it is not modified.
        column: Name of the column holding the raw categories.
        none_label: Token used for missing values; cells equal to it are
            left untouched.

    Returns:
        A new DataFrame with the converted label column.
    """
    # DataFrame.where returns a new frame, so the caller's df is untouched.
    df = df.where(pd.notnull(df), none_label)

    def _to_labels(value):
        if value == none_label:
            return value
        # split('*') yields a one-element list when there is no '*', so a
        # single expression covers both single and multi-label cells.
        return ' '.join(add_label(part) for part in value.split('*'))

    # Series.map replaces the former iteritems() loop: iteritems() was
    # removed in pandas 2.0, and assigning through df.loc[idx, column]
    # overwrote every row sharing a duplicated index label.
    df[column] = df[column].map(_to_labels)
    return df
def save_df(df, columns=('category_label', 'content'),
            path='data/labeled_data.txt'):
    """Write the labelled examples as a fastText training file.

    Each row becomes one line of the form ``<labels> <lower-cased text>``.

    Args:
        df: Frame containing at least the requested columns.
        columns: Two column names: the label column first, then the text
            column. Any sequence is accepted.
        path: Destination file. Defaults to the location this function
            always wrote to; its directory must already exist.
    """
    # list(...) matters: indexing a DataFrame with a tuple looks up a
    # single (multi-level) key rather than selecting several columns.
    rows = df[list(columns)].values
    # fastText reads UTF-8, so do not rely on the platform default encoding.
    with open(path, 'w', encoding='utf-8') as f:
        for row in rows:
            # Collapse all whitespace, including embedded newlines, to
            # single spaces: the file must keep one example per line,
            # otherwise the tail of a multi-line text becomes a separate,
            # unlabelled example.
            text = ' '.join(row[1].lower().split())
            f.write(f"{row[0]} {text}\n")
if __name__ == '__main__':
    # Pipeline: load the raw CSV, convert its categories to fastText
    # labels, then write the training file.
    save_df(get_labels(read_data()))
This diff is collapsed.
import fastText
import numpy as np
from fastText import train_supervised
class Model:
    """Thin wrapper around a supervised fastText classifier.

    Besides training and prediction it can build a matrix of all word
    vectors and use it for cosine-similarity nearest-neighbour queries.
    """

    # Vectors used when no ``pretrained`` path is given to the constructor.
    DEFAULT_PRETRAINED = "./data/pretrained/crawl-300d-2M.vec"

    def __init__(self, train_data, pretrained=None):
        """Store the configuration; nothing is loaded or trained here.

        Args:
            train_data: Path to the fastText-formatted training file.
            pretrained: Optional path to a ``.vec`` file of pretrained
                word vectors; falls back to ``DEFAULT_PRETRAINED``.
        """
        self.train_data = train_data
        self.pretrained = pretrained if pretrained else self.DEFAULT_PRETRAINED
        # Populated by train() and create_ft_matrix() respectively.
        self.model = None
        self.ft_words = None
        self.word_frequencies = None
        self.ft_matrix = None

    @staticmethod
    def _vector_file_dim(path):
        """Return the vector width declared in a ``.vec`` file header.

        The first line of the text format is ``<word count> <dimension>``,
        so only that one line is read.

        Raises:
            ValueError: If the header does not have that shape.
        """
        with open(path, encoding='utf-8', errors='ignore') as f:
            header = f.readline().split()
        if len(header) != 2:
            raise ValueError(
                f"{path}: expected '<count> <dim>' header, got {header!r}")
        return int(header[1])

    def train(self):
        """Train the supervised model and keep it on ``self.model``."""
        self.model = train_supervised(
            self.train_data,
            epoch=100,
            wordNgrams=2,
            loss="softmax",
            # fastText refuses pretrained vectors whose width differs from
            # `dim`, so take the width from the file instead of relying on
            # the library default.
            dim=self._vector_file_dim(self.pretrained),
            pretrainedVectors=self.pretrained)

    def create_ft_matrix(self, ft_matrix=None):
        """Cache the vocabulary, its frequencies and the word-vector matrix.

        Args:
            ft_matrix: Optional precomputed matrix with one row per
                vocabulary word; when given it is stored as-is instead of
                being rebuilt from the model.
        """
        words, freqs = self.model.get_words(include_freq=True)
        self.ft_words = words
        self.word_frequencies = dict(zip(words, freqs))
        if ft_matrix is None:
            ft_matrix = np.empty((len(words), self.model.get_dimension()))
            for i, word in enumerate(words):
                ft_matrix[i, :] = self.model.get_word_vector(word)
        self.ft_matrix = ft_matrix

    def find_nearest_neighbor(self, query, vectors, n=10, cossims=None):
        """Rank the rows of ``vectors`` by cosine similarity to ``query``.

        The best-ranked row is always dropped. NOTE(review): presumably
        because the query is expected to be one of the rows itself; confirm
        this is also wanted for out-of-vocabulary queries.

        Args:
            query: 1-D vector.
            vectors: 2-D array with one candidate per row.
            n: Maximum number of neighbours to return.
            cossims: Optional preallocated 1-D buffer with one slot per
                row; the raw dot products are written into it.

        Returns:
            A list of ``(row_index, cosine_similarity)`` pairs, most
            similar first, with at most ``n`` entries.
        """
        query = np.asarray(query)
        vectors = np.asarray(vectors)
        # Always compute the products: with out=None numpy allocates the
        # result, otherwise it fills the caller's buffer. Skipping this
        # when a buffer was passed would rank whatever the buffer held.
        cossims = np.matmul(vectors, query, out=cossims)
        norms = np.sqrt((query ** 2).sum() * (vectors ** 2).sum(axis=1))
        cossims = cossims / norms
        # argpartition rejects positions past the end, so clamp the number
        # of ranked entries to the number of rows.
        k = min(n + 1, cossims.shape[0])
        if k <= 1:
            return []
        result_i = np.argpartition(-cossims, list(range(k)))[1:k]
        return list(zip(result_i, cossims[result_i]))

    def nearest_words(self, word, n=10, word_freq=None):
        """Return the vocabulary words closest to ``word``.

        Args:
            word: Query word.
            n: Number of neighbours requested.
            word_freq: Optional minimum frequency. The filter runs after
                the top ``n`` are chosen, so fewer than ``n`` results may
                come back.

        Returns:
            A list of ``(word, cosine_similarity)`` pairs.
        """
        if self.ft_matrix is None:
            # Build the lookup matrix on first use.
            self.create_ft_matrix()
        hits = self.find_nearest_neighbor(
            self.model.get_word_vector(word), self.ft_matrix, n=n)
        named = [(self.ft_words[i], score) for i, score in hits]
        if word_freq:
            return [(w, score) for w, score in named
                    if self.word_frequencies[w] >= word_freq]
        return named

    def predict_sentence(self, sent, n=5):
        """Return the model's top ``n`` predictions for ``sent``."""
        return self.model.predict(sent, k=n)
if __name__ == '__main__':
    # Smoke test: train on the labelled file, then classify one sentence.
    classifier = Model('data/labeled_data.txt')
    classifier.train()
    prediction = classifier.predict_sentence('im super upset this is bullshit')
    print(prediction)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment