train.py 7.98 KB
Newer Older
Rémi Bèges's avatar
Rémi Bèges committed
1 2 3
import tensorflow as tf
from tensorflow import keras
import numpy as np
Rémi Bèges's avatar
Rémi Bèges committed
4
from autosync.dataset import *
Rémi Bèges's avatar
Rémi Bèges committed
5
from tensorflow.keras.utils import plot_model
Rémi Bèges's avatar
Rémi Bèges committed
6
from tqdm import tqdm
Rémi Bèges's avatar
Rémi Bèges committed
7 8

def _cache_slices(train_audio, train_subs, test_audio, test_subs):
    """Persist the dataset slices under ``cache/`` so a later run can reload
    them with ``train(..., use_cache=True)`` instead of fetching again."""
    import os

    # np.save does not create parent directories; without this a fresh
    # checkout fails with FileNotFoundError.
    os.makedirs('cache', exist_ok=True)
    np.save('cache/train_audio.npy', train_audio)
    np.save('cache/train_subs.npy', train_subs)
    np.save('cache/test_audio.npy', test_audio)
    np.save('cache/test_subs.npy', test_subs)


def _load_cached_slices():
    """Reload the arrays written by ``_cache_slices``.

    Returns ``((train_audio, train_subs), (test_audio, test_subs))``, the same
    layout ``fetch_dataset_slices`` returns.
    """
    return (
        (np.load('cache/train_audio.npy'), np.load('cache/train_subs.npy')),
        (np.load('cache/test_audio.npy'), np.load('cache/test_subs.npy')),
    )


def _balance_percent(subs):
    """Return the share of non-zero labels in ``subs``, as a percentage."""
    return np.count_nonzero(subs) / subs.shape[0] * 100


def _build_model(input_shape):
    """Build the network: two Conv1D layers followed by a small dense head.

    The final tanh output is rescaled from [-1, 1] to [0, 1], so it can be fed
    to binary cross-entropy and rounded to a 0/1 prediction.
    """
    return keras.Sequential([
        keras.layers.Conv1D(filters=14, kernel_size=3, activation='relu',
                            input_shape=input_shape),
        # Only the first layer's input_shape is used by Sequential.
        keras.layers.Conv1D(filters=14, kernel_size=3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Flatten(),
        keras.layers.Dense(56, activation="relu"),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(6, activation="relu"),
        keras.layers.Dense(1, activation="tanh"),
        keras.layers.Lambda(lambda x: (x + 1.0) / 2.0),
    ])


def _accuracy_op(model, dataset):
    """Return an accuracy tensor over one batch drawn from ``dataset``.

    Each ``sess.run`` of the result pulls a fresh batch from a one-shot
    iterator, so ``dataset`` must repeat enough for every evaluation.
    """
    # TODO(review): predictions are (batch, 1) because of the final Dense(1);
    # confirm create_dataset yields labels of a compatible shape.
    features, labels = dataset.make_one_shot_iterator().get_next()
    predicted = model(features, training=False)
    return tf.contrib.metrics.accuracy(
        tf.cast(tf.round(predicted), tf.int32),
        tf.cast(labels, tf.int32),
    )


def _scalar_summary_placeholder(tag):
    """Return ``(placeholder, summary_op)`` for logging a Python float.

    Feeding an already-computed value avoids re-running the accuracy graph,
    which would consume another batch from its iterator.
    """
    value = tf.placeholder(tf.float32, shape=())
    return value, tf.summary.scalar(tag, value)


def train(dir='.', *, epochs=200, batch_size=32,
          model_id='m17_dataclean_300k_balanced', use_cache=False):
    """Train the model and export it as a SavedModel.

    Args:
        dir: Dataset location, passed unchanged to ``fetch_dataset_slices``.
        epochs: Number of passes over the training slices.
        batch_size: Training mini-batch size.
        model_id: Name of the TensorBoard log directory (``logs/<model_id>``)
            and of the export directory (``saved_models/<model_id>``).
        use_cache: If true, load the slices previously written to
            ``cache/*.npy`` instead of fetching (and re-caching) them.

    Raises:
        ValueError: If the training or the test split is empty.
    """
    if use_cache:
        (train_audio, train_subs), (test_audio, test_subs) = _load_cached_slices()
    else:
        (train_audio, train_subs), (test_audio, test_subs) = fetch_dataset_slices(dir)
        _cache_slices(train_audio, train_subs, test_audio, test_subs)

    print('{} items in training, {} items in test'.format(
        len(train_audio),
        len(test_audio),
    ))
    if not len(train_audio) or not len(test_audio):
        raise ValueError('training and test splits must both be non-empty')

    print('Training dataset balance: {}%'.format(_balance_percent(train_subs)))
    print('Test dataset balance: {}%'.format(_balance_percent(test_subs)))

    logsess = 'logs/' + model_id
    n_train = len(train_audio)
    n_test = len(test_subs)

    training_set = create_dataset(train_audio, train_subs)
    training_set = training_set.repeat(epochs).batch(batch_size)

    # Evaluation sets are consumed one full-size batch per epoch.
    test_set = create_dataset(test_audio, test_subs)
    test_set = test_set.repeat(epochs).batch(n_test)

    # Training slice of the same length as the test split, used only to
    # report a comparable train accuracy.
    fit_set = create_dataset(train_audio[:n_test], train_subs[:n_test])
    fit_set = fit_set.repeat(epochs).batch(n_test)

    input_shape = (train_audio.shape[1], 1)
    print('Building model... {}'.format(input_shape))
    model = _build_model(input_shape)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=0.001,
    )

    # Training graph. training=True makes BatchNormalization normalise with
    # batch statistics; otherwise Keras uses the learning-phase placeholder,
    # which defaults to inference mode in a hand-written session loop.
    train_features, train_labels = training_set.make_one_shot_iterator().get_next()
    train_predicted = model(train_features, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(train_labels, train_predicted))
    # Run this summary directly rather than through merge_all(), which would
    # also pick up the accuracy summaries and require their placeholders.
    loss_summary = tf.summary.scalar('loss', loss)

    acc_train = _accuracy_op(model, fit_set)
    accuracy_value_train, accuracy_summary_train = _scalar_summary_placeholder(
        'accuracy:train')
    acc = _accuracy_op(model, test_set)
    accuracy_value_, accuracy_summary = _scalar_summary_placeholder(
        'accuracy:test')

    train_vars = model.trainable_variables
    print('Found {} vars to train'.format(len(train_vars)))

    # BatchNormalization's moving mean/variance assignments are not part of
    # the gradient, so run them with every step or they never change.
    with tf.control_dependencies(model.get_updates_for(train_features)):
        train_op = optimizer.minimize(loss, var_list=train_vars)

    # Serving input that does not depend on any dataset iterator.
    export_input = tf.placeholder(train_features.dtype,
                                  shape=(None,) + input_shape)
    export_output = model(export_input, training=False)

    samples_seen = 0      # TensorBoard step: training samples consumed so far
    samples_in_epoch = 0  # samples since the last end-of-epoch evaluation
    epoch_count = 0
    with tf.Session() as sess, tqdm(total=epochs * n_train) as pbar:
        writer = tf.summary.FileWriter(logsess, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            try:
                while True:
                    samples_seen += batch_size
                    samples_in_epoch += batch_size
                    _, loss_value, summary = sess.run(
                        [train_op, loss, loss_summary])
                    writer.add_summary(summary, samples_seen)

                    # set_postfix_str refreshes the bar itself.
                    pbar.set_postfix_str("loss={:.3f}".format(loss_value))
                    pbar.update(batch_size)

                    if samples_in_epoch < n_train:
                        continue
                    samples_in_epoch -= n_train
                    epoch_count += 1

                    acc_value = sess.run(acc_train)
                    print('Train acc (epoch: {}): {}'.format(epoch_count, acc_value))
                    summary = sess.run(accuracy_summary_train,
                                       feed_dict={accuracy_value_train: acc_value})
                    writer.add_summary(summary, samples_seen)

                    acc_value = sess.run(acc)
                    print('Test acc (epoch: {}): {}'.format(epoch_count, acc_value))
                    summary = sess.run(accuracy_summary,
                                       feed_dict={accuracy_value_: acc_value})
                    writer.add_summary(summary, samples_seen)
            except tf.errors.OutOfRangeError:
                # An iterator is exhausted: training is over, export the model.
                tf.saved_model.simple_save(
                    sess, 'saved_models/' + model_id,
                    {"audio_mfcc": export_input},
                    {"has_speech": export_output},
                )
        finally:
            # FileWriter buffers events; close() flushes whatever is pending.
            writer.close()
Rémi Bèges's avatar
Rémi Bèges committed
230

Rémi Bèges's avatar
Rémi Bèges committed
231
def main():
    """Train on the dataset files matched by the hard-coded patterns below."""
    # Named data_globs rather than `dir` so the builtin is not shadowed.
    data_globs = ["datasets/10*.avi", "datasets/Game*.mkv"]
    print('Beginning training, data dir: {}'.format(data_globs))
    train(dir=data_globs)


if __name__ == '__main__':
    # Only train when run as a script, not when the module is imported.
    main()