Commit 0a7e3da7 authored by Ben Milde's avatar Ben Milde

added static LSTM transform function

parent 5d6e7b21
......@@ -104,7 +104,8 @@ flags.DEFINE_integer("right_contexts", 2, "How many right context windows")
flags.DEFINE_integer("embedding_size", 100 , "Fully connected size at the end of the network.")
flags.DEFINE_integer("fc_size", 512 , "Fully connected size at the end of the network.")
flags.DEFINE_integer("fc_size", 1024 , "Fully connected size at the end of the network.")
flags.DEFINE_integer("rnn_hidden_cells", 1024 , "Size of hidden cells for recurrent neural networks, e.g. if LSTMs or GRUs are used in the embedding transform.")
flags.DEFINE_boolean("first_layer_tanh", True, "Whether tanh should be used for the output conv1d filters in end-to-end networks.")
flags.DEFINE_boolean("first_layer_log1p", True, "Whether log1p should be applied to the output of the conv1d filters.")
......@@ -680,6 +681,17 @@ class UnsupSeech(object):
print('pool shape after inception_resnet_v2 block:', pooled.get_shape())
print('is_training: ', is_training)
if FLAGS.embedding_transformation == "Static_LSTM":
cell = tf.contrib.rnn.LSTMCell(FLAGS.rnn_hidden_cells)
outputs, state = tf.nn.static_rnn(cell, pooled, dtype=tf.float32) #, sequence_length=[seq_len])
pooled = outputs[-1]
needs_flattening = False
# if FLAGS.embedding_transformation == "Static_GRU":
# needs_flattening = False
# add summaries for res net and moving variance / moving averages of batch norm.
if FLAGS.embedding_transformation.startswith("Resnet") and FLAGS.log_tensorboard:
for var in list(self.end_points.values()) + [var for var in tf.global_variables() if 'moving' in var.name]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment