...
 
Commits (4)
This diff is collapsed.
......@@ -27,6 +27,7 @@ import numpy as np
import argparse
import kaldi_io
import utils
import random
from sklearn.manifold import TSNE
......@@ -88,17 +89,110 @@ def visualize_stats(feat_filename, max_feats, abs_feats=True, reverse_sort=True)
plt.matshow([finalsum_sorted])
plt.show()
def visualize_classes_tsne(feat_filename, utt_2_class):
feats, utt_ids = kaldi_io.readArk(feat_filename)
def visualize_classes_tsne(feat_filename, utt_2_class_filename, half_index=-1, normalize=True, class_mean_vector=False):
feats, utt_ids = kaldi_io.readArk(feat_filename, limit=25000)
feats_len=len(feats)
print("Loaded:" + str(feats_len) + "feats")
assert(len(utt_ids)==len(feats))
print("Loaded:" + str(feats_len) + " feats.")
feats = [feat.mean(0) for feat in feats]
if half_index != -1:
print('Cutting vectors at ', half_index, 'and normalize to unit length' if normalize else '')
feats = [feat[:half_index]/(np.linalg.norm(feat[:half_index]) if normalize else 1.0) for feat in feats]
else:
if normalize:
print('Normalize to unit length.')
feats = [feat/np.linalg.norm(feat) for feat in feats]
utt_2_class = utils.loadUtt2Spk(utt_2_class_filename)
ground_truth_utt_2_class = [utt_2_class[utt_id] for utt_id in utt_ids if utt_id in utt_2_class]
utt_ids_filtered = [utt_id for utt_id in utt_ids if utt_id in utt_2_class]
#feats_filtered = [feat for feat,utt_id in zip(feats, utt_ids) if utt_id in utt_2_class]
assert(len(ground_truth_utt_2_class) == len(utt_ids_filtered))
#assert(len(utt_ids_filtered) == len(feats_filtered) )
dataset = {}
for feat,utt in zip(feats, utt_ids):
if utt in utt_2_class:
dataset[utt] = feat
myclass_2_utt = {}
myclass_2_samples = {}
for myclass in set(ground_truth_utt_2_class):
my_class_filtered_utts = [utt_id for utt_id, gd_class in zip(utt_ids_filtered, ground_truth_utt_2_class) if gd_class == myclass]
if len(my_class_filtered_utts) > 100:
myclass_2_utt[myclass] = my_class_filtered_utts
myclass_2_samples[myclass] = random.sample(myclass_2_utt[myclass], min(1000,len(myclass_2_utt[myclass])))
feats_samples = []
feats_samples_classes = []
if class_mean_vector:
for myclass in myclass_2_samples:
feats_samples += [np.vstack(dataset[utt] for utt in myclass_2_samples[myclass]).mean(0)]
feats_samples_classes += [myclass]
else:
for myclass in myclass_2_samples:
feats_samples += [dataset[utt] for utt in myclass_2_samples[myclass]]
feats_samples_classes += [myclass]*len(myclass_2_samples[myclass])
print('Added',len(myclass_2_samples[myclass]),'entries for',myclass)
print([utt.replace('train-sample','train/sample') + '.mp3' for utt in myclass_2_samples[myclass]])
class_2_num = dict([(a,b) for b,a in enumerate(list(myclass_2_samples.keys()))])
print(class_2_num)
feats_samples_classes_num = [class_2_num[myclass] for myclass in feats_samples_classes]
#print(feats_samples_classes_num)
num_classes = max(feats_samples_classes_num)
print('Num classes=', num_classes)
print(feats_samples)
print('shape:',feats_samples[0].shape)
print('Calculating TSNE:')
model = TSNE(n_components=2, random_state=0, metric='euclidean')
tsne_data = model.fit_transform(np.vstack(feats_samples))
#model = TSNE(n_components=2, random_state=0, metric='cosine')
#tsne_data = model.fit_transform([feat[100:] for feat in feats])
colormap = plt.cm.gist_ncar #nipy_spectral, Set1,Paired
colorst = colormap(np.linspace(0, 0.9, num_classes+1)) #[colormap(i) for i in np.linspace(0, 0.9, num_speakers)]
cs = [colorst[feats_samples_classes_num[i]] for i in range(len(feats_samples_classes_num))]
#print(tsne_data[:,0])
#print(tsne_data[:,1])
plt.scatter(tsne_data[:,0], tsne_data[:,1], color=cs)
#for i,elem in enumerate(tsne_data):
# print(cs[0])
# print(ground_truth_utt_2_spk[0])
# plt.scatter(elem[0], elem[1], color=cs[i], label=ground_truth_utt_2_spk[i])
plt.legend()
# for i in range(tsne_data.shape[0]):
# plt.text(tsne_data[i,0], tsne_data[i,1], uttids[i], fontsize=8, color=cs[i])
print('Now showing tsne plot:')
plt.show()
def visualize_kaldi_bin_feats(feat_filename, max_frames, num_feat=0, phn_file='', phn_offset=5, wav_file='', do_tsne=False):
feats, utt_ids = kaldi_io.readArk(feat_filename , limit=10)
feats, utt_ids = kaldi_io.readArk(feat_filename , limit=10000)
print([feat.shape for feat in feats], utt_ids)
......@@ -123,8 +217,7 @@ def visualize_kaldi_bin_feats(feat_filename, max_frames, num_feat=0, phn_file=''
for xc in xpositions:
plt.axvline(x=xc, color='k', linestyle='--')
plt.show()
if do_tsne:
plt.figure(1)
......@@ -144,13 +237,14 @@ if __name__ == '__main__':
#'/Users/milde/inspect/feats_transVgg16big_win50_neg_samples4_lcontexts2_rcontexts2_flts40_embsize100_fc_size1024_dropout_keep0.9_batchnorm_bndecay0.95_l2_reg0.0005_dot_combine/dev/feats.ark')
#'/Users/milde/inspect/kaldi_train/feats.normalized.ark')
#'/Users/milde/inspect/feats_transVgg16big_win50_neg_samples4_lcontexts2_rcontexts2_flts40_embsize100_fc_size1024_unit_norm_var_dropout_keep0.9_batchnorm_bndecay0.999_l2_reg0.0005_dot_combine/dev/feats.ark')
'feats/tedlium_ivectors/tedlium_ivector_online_test.ark')
'feats/feats_sp_transVgg16big_nsampling_rnd_win64_neg_samples4_lcontexts2_rcontexts2_flts40_embsize100_fc_size512_unit_norm_var_dropout_keep0.9_l2_reg0.0005_featinput_unnormalized.feats.ark_dot_combine_tied_embs/dev/feats.ark')
#'feats_vgg.ark')
parser.add_argument('-f', '--format', dest='format', help='Format of the feature file (raw,kaldi_ark)', type=str, default = 'kaldi_ark')
parser.add_argument('-m', '--max_frames', dest='max_frames', help='Maximum frames', type=int, default = 10000)
parser.add_argument('-n', '--num_feat', dest='num_feat', help='feat file to visualize', type=int, default = 2)
parser.add_argument('-p', '--phn_file', dest='phn_file', help='Phoneme annotation file', type=str, default = '')
parser.add_argument('-u', '--utt_2_class', dest='utt_2_class', help='File with meta classes, e.g. male / female / age etc.', type=str, default = '')
parser.add_argument('--mode', dest='mode', help='(featshow|stats|classes_tsne)', type=str, default = 'featshow')
......@@ -161,7 +255,7 @@ if __name__ == '__main__':
elif args.mode=='featshow':
visualize_kaldi_bin_feats(args.featfile, args.max_frames, phn_file= args.phn_file, num_feat=args.num_feat)
elif args.mode=='classes_tsne':
visualize_classes_tsne(args.featfile, utt_2_class)
visualize_classes_tsne(args.featfile, utt_2_class_filename=args.utt_2_class, half_index=100)
else:
print("mode not supported.")
#visualize_stats(args.featfile, args.max_frames)
......