Commit 7e017801 authored by EC2 Default User's avatar EC2 Default User
Browse files

Initial version

parent 431554a3
%% Cell type:markdown id: tags:
### Training a Keras CNN on Fashion-MNIST
%% Cell type:markdown id: tags:
Fashion-MNIST is a Zalando dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. It's a drop-in replacement for MNIST.
https://github.com/zalandoresearch/fashion-mnist/
In this notebook, we'll train a simple CNN built with Keras, using the built-in Tensorflow and Apache MXNet containers provided by Amazon SageMaker.
%% Cell type:code id: tags:
``` python
from IPython.display import Image
Image("fashion-mnist-sprite.png")
```
%% Cell type:code id: tags:
``` python
import sagemaker
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
```
%% Cell type:markdown id: tags:
## Download the Fashion-MNIST dataset
%% Cell type:code id: tags:
``` python
import os
import keras
import numpy as np
from keras.datasets import fashion_mnist
(x_train, y_train), (x_val, y_val) = fashion_mnist.load_data()
os.makedirs("./data", exist_ok = True)
np.savez('./data/training', image=x_train, label=y_train)
np.savez('./data/validation', image=x_val, label=y_val)
```
%% Cell type:markdown id: tags:
## Upload Fashion-MNIST data to S3
%% Cell type:code id: tags:
``` python
prefix = 'keras-fashion-mnist'
training_input_path = sess.upload_data('data/training.npz', key_prefix=prefix+'/training')
validation_input_path = sess.upload_data('data/validation.npz', key_prefix=prefix+'/validation')
print(training_input_path)
print(validation_input_path)
```
%% Cell type:markdown id: tags:
## Train with Tensorflow on the notebook instance (aka 'local mode')
%% Cell type:code id: tags:
``` python
!pygmentize mnist_keras_tf.py
```
%% Cell type:code id: tags:
``` python
from sagemaker.tensorflow import TensorFlow
tf_estimator = TensorFlow(entry_point='mnist_keras_tf.py',
role=role,
train_instance_count=1,
train_instance_type='local',
framework_version='1.12',
py_version='py3',
script_mode=True,
hyperparameters={'epochs': 1}
)
```
%% Cell type:code id: tags:
``` python
tf_estimator.fit({'training': training_input_path, 'validation': validation_input_path})
```
%% Cell type:markdown id: tags:
## Train with Tensorflow on a GPU instance
%% Cell type:code id: tags:
``` python
tf_estimator = TensorFlow(entry_point='mnist_keras_tf.py',
role=role,
train_instance_count=1,
train_instance_type='ml.p3.2xlarge',
framework_version='1.12',
py_version='py3',
script_mode=True,
hyperparameters={
'epochs': 20,
'batch-size': 256,
'learning-rate': 0.01}
)
```
%% Cell type:code id: tags:
``` python
tf_estimator.fit({'training': training_input_path, 'validation': validation_input_path})
```
%% Cell type:markdown id: tags:
## Deploy
%% Cell type:code id: tags:
``` python
import time
tf_endpoint_name = 'keras-tf-fmnist-'+time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
#tf_predictor = tf_estimator.deploy(initial_instance_count=1,
# instance_type='ml.p2.xlarge') # $1.361/hour in eu-west-1
tf_predictor = tf_estimator.deploy(initial_instance_count=1,
instance_type='ml.c5.large', # $0.134/hour in eu-west-1
accelerator_type='ml.eia1.medium', # + $0.140/hour in eu-west-1
endpoint_name=tf_endpoint_name) # = 80% discount!
```
%% Cell type:markdown id: tags:
## Predict
%% Cell type:code id: tags:
``` python
%matplotlib inline
import random
import matplotlib.pyplot as plt
num_samples = 5
indices = random.sample(range(x_val.shape[0] - 1), num_samples)
images = x_val[indices]/255
labels = y_val[indices]
for i in range(num_samples):
plt.subplot(1,num_samples,i+1)
plt.imshow(images[i].reshape(28, 28), cmap='gray')
plt.title(labels[i])
plt.axis('off')
prediction = tf_predictor.predict(images.reshape(num_samples, 28, 28, 1))['predictions']
prediction = np.array(prediction)
predicted_label = prediction.argmax(axis=1)
print('Predicted labels are: {}'.format(predicted_label))
```
%% Cell type:markdown id: tags:
## Train with Apache MXNet on a GPU instance
%% Cell type:code id: tags:
``` python
!pygmentize mnist_keras_mxnet.py
```
%% Cell type:code id: tags:
``` python
from sagemaker.mxnet import MXNet
mxnet_estimator = MXNet(entry_point='mnist_keras_mxnet.py',
role=role,
train_instance_count=1,
#train_instance_type='local_gpu',
train_instance_type='ml.p3.2xlarge',