# mnist_cnn.py
# Forked from SciSharp/SciSharp-Stack-Examples.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools.freeze_graph import freeze_graph
# Geometry of a single MNIST sample: 28x28 pixels, one channel.
img_h, img_w = 28, 28
n_channels = 1

# Total number of pixels in one flattened image (28 * 28 = 784).
img_size_flat = img_w * img_h

# Number of classes, one class per digit.
n_classes = 10
def load_data(mode='train'):
    """
    Function to (download and) load the MNIST data
    :param mode: 'train' or 'test'
    :return: for 'train': (x_train, y_train, x_valid, y_valid);
             for 'test': (x_test, y_test).
             Images are passed through ``reformat``; labels are returned
             exactly as read with ``one_hot=True``.
    :raises ValueError: if mode is neither 'train' nor 'test'
    """
    # Fail fast on an unknown mode instead of loading the dataset and then
    # silently returning None.
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    if mode == 'train':
        x_train, y_train = mnist.train.images, mnist.train.labels
        x_valid, y_valid = mnist.validation.images, mnist.validation.labels
        # Only the reshaped images are kept; the labels from reformat are discarded.
        x_train, _ = reformat(x_train, y_train)
        x_valid, _ = reformat(x_valid, y_valid)
        return x_train, y_train, x_valid, y_valid

    x_test, y_test = mnist.test.images, mnist.test.labels
    x_test, _ = reformat(x_test, y_test)
    return x_test, y_test
def reformat(x, y):
    """
    Reformats the data to the format acceptable for convolutional layers
    :param x: input array whose last axis holds the flattened pixels of a
              square, single-channel image
    :param y: corresponding labels, either one-hot rows (2-D) or integer
              class indices (1-D)
    :return: images reshaped to (-1, img_size, img_size, 1) as float32, and
             one-hot float32 labels of shape (num_samples, num_class)
    :raises ValueError: if the number of pixels per sample is not a perfect square
    """
    x = np.asarray(x)
    y = np.asarray(y)

    num_pixels = x.shape[-1]
    img_size, num_ch = int(round(np.sqrt(num_pixels))), 1
    # Without this check a non-square pixel count would be reshaped into a
    # different number of "samples" instead of failing.
    if img_size * img_size != num_pixels:
        raise ValueError(
            "expected a square image, got {} pixels per sample".format(num_pixels))
    dataset = x.reshape((-1, img_size, img_size, num_ch)).astype(np.float32)

    if y.ndim == 1:
        # Integer class indices: the class count is inferred from the largest index.
        class_ids = y.astype(np.int64)
        num_class = int(class_ids.max()) + 1 if class_ids.size else 0
    else:
        # One-hot rows: the class count is the row width, so classes that do
        # not occur in this particular array are still represented.
        class_ids = np.argmax(y, 1)
        num_class = y.shape[1]

    # Compare each class index against 0..num_class-1 to build one-hot rows.
    labels = (np.arange(num_class) == class_ids[:, None]).astype(np.float32)
    return dataset, labels
def randomize(x, y):
    """
    Shuffle samples and labels together with one shared random ordering
    :param x: 4-D array of samples, indexed along its first axis
    :param y: labels with one entry per sample along the first axis
    :return: (shuffled x, shuffled y)
    """
    order = np.random.permutation(y.shape[0])
    return x[order, :, :, :], y[order]
def get_next_batch(x, y, start, end):
    """
    Cut the same [start:end] window out of the samples and the labels
    :param x: sequence of samples
    :param y: sequence of labels
    :param start: first index of the batch (inclusive)
    :param end: last index of the batch (exclusive)
    :return: (x batch, y batch)
    """
    window = slice(start, end)
    return x[window], y[window]
# Load the training and validation splits and report their sizes.
x_train, y_train, x_valid, y_valid = load_data(mode='train')
print("Size of:")
print("- Training-set:\t\t{}".format(len(y_train)))
print("- Validation-set:\t{}".format(len(y_valid)))

# Folder where the Tensorboard logs are written.
logs_path = "./logs"

# Optimization settings.
lr = 0.001          # initial learning rate
epochs = 1          # total number of training epochs
batch_size = 100    # training batch size
display_freq = 100  # display the training results every this many iterations

# 1st convolutional layer: 16 filters of 5 x 5 pixels, sliding-window stride 1.
filter_size1, num_filters1, stride1 = 5, 16, 1

# 2nd convolutional layer: 32 filters of 5 x 5 pixels, sliding-window stride 1.
filter_size2, num_filters2, stride2 = 5, 32, 1

# Fully-connected layer: number of neurons.
h1 = 128
# weight and bias wrappers
def weight_variable(shape):
    """
    Create a weight variable with appropriate initialization
    :param shape: weight shape
    :return: float32 variable named 'W', initialized from a truncated normal
             distribution with stddev 0.01
    """
    return tf.get_variable(
        name='W',
        shape=shape,
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.01),
    )
def bias_variable(shape):
    """
    Create a bias variable with appropriate initialization
    :param shape: bias variable shape
    :return: float32 variable named 'b', initialized to zeros
    """
    # The initial value is a constant tensor, so it already carries the
    # variable's shape.
    zeros = tf.constant(0., shape=shape, dtype=tf.float32)
    return tf.get_variable(name='b', initializer=zeros, dtype=tf.float32)
def conv_layer(x, filter_size, num_filters, stride, name):
    """
    Create a 2D convolution layer followed by ReLU
    :param x: input from previous layer
    :param filter_size: size of each (square) filter
    :param num_filters: number of filters (or output feature maps)
    :param stride: filter stride
    :param name: layer name, used as the variable scope
    :return: The output array
    """
    with tf.variable_scope(name):
        # The last axis of the input gives the number of input channels.
        in_channels = x.get_shape().as_list()[-1]
        kernel = weight_variable(
            shape=[filter_size, filter_size, in_channels, num_filters])
        tf.summary.histogram('weight', kernel)
        bias = bias_variable(shape=[num_filters])
        tf.summary.histogram('bias', bias)
        conv = tf.nn.conv2d(x, kernel,
                            strides=[1, stride, stride, 1],
                            padding="SAME")
        return tf.nn.relu(conv + bias)
def max_pool(x, ksize, stride, name):
    """
    Create a max pooling layer
    :param x: input to max-pooling layer
    :param ksize: size of the (square) max-pooling filter
    :param stride: stride of the max-pooling filter
    :param name: layer name
    :return: The output array
    """
    window = [1, ksize, ksize, 1]
    step = [1, stride, stride, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding="SAME", name=name)
def flatten_layer(layer):
    """
    Flattens the output of the convolutional layer to be fed into fully-connected layer
    :param layer: input array
    :return: flattened array of shape [-1, num_features]
    """
    with tf.variable_scope('Flatten_layer'):
        # Number of features = product of the dimensions at axes 1..3.
        num_features = layer.get_shape()[1:4].num_elements()
        return tf.reshape(layer, [-1, num_features])
def fc_layer(x, num_units, name, use_relu=True):
    """
    Create a fully-connected layer
    :param x: input from previous layer
    :param num_units: number of hidden units in the fully-connected layer
    :param name: layer name, used as the variable scope
    :param use_relu: boolean to add ReLU non-linearity (or not)
    :return: The output array
    """
    with tf.variable_scope(name):
        in_dim = x.get_shape()[1]
        weights = weight_variable(shape=[in_dim, num_units])
        tf.summary.histogram('weight', weights)
        bias = bias_variable(shape=[num_units])
        tf.summary.histogram('bias', bias)
        out = tf.matmul(x, weights) + bias
        return tf.nn.relu(out) if use_relu else out
# Placeholders for a batch of images and their one-hot labels.
with tf.name_scope('Input'):
    x = tf.placeholder(tf.float32, shape=[None, img_h, img_w, n_channels], name='X')
    y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y')

# Network: conv -> pool -> conv -> pool -> flatten -> FC (ReLU) -> FC (logits).
conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name='conv1')
pool1 = max_pool(conv1, ksize=2, stride=2, name='pool1')
conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name='conv2')
pool2 = max_pool(conv2, ksize=2, stride=2, name='pool2')
layer_flat = flatten_layer(pool2)
fc1 = fc_layer(layer_flat, h1, 'FC1', use_relu=True)
output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False)

# Loss, optimizer, accuracy and prediction ops. The scope nesting fixes the
# tensor names used when the graph is restored later in this file
# ("Train/Loss/loss", "Train/Accuracy/accuracy", "Train/Prediction/predictions").
with tf.variable_scope('Train'):
    with tf.variable_scope('Loss'):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits),
            name='loss')
    tf.summary.scalar('loss', loss)
    with tf.variable_scope('Optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='Adam-op').minimize(loss)
    with tf.variable_scope('Accuracy'):
        correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1),
                                      name='correct_pred')
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
    tf.summary.scalar('accuracy', accuracy)
    with tf.variable_scope('Prediction'):
        cls_prediction = tf.argmax(output_logits, axis=1, name='predictions')
# Initialize the variables
init = tf.global_variables_initializer()
# Merge all summaries
merged = tf.summary.merge_all()
sess = tf.Session()
sess.run(init)
global_step = 0
summary_writer = tf.summary.FileWriter(logs_path, sess.graph)
# Number of training iterations in each epoch (a trailing partial batch is dropped)
num_tr_iter = len(y_train) // batch_size
try:
    for epoch in range(epochs):
        print('Training epoch: {}'.format(epoch + 1))
        # Reshuffle the training set at the start of every epoch
        x_train, y_train = randomize(x_train, y_train)
        for iteration in range(num_tr_iter):
            global_step += 1
            start = iteration * batch_size
            end = (iteration + 1) * batch_size
            x_batch, y_batch = get_next_batch(x_train, y_train, start, end)

            # Run optimization op (backprop)
            feed_dict_batch = {x: x_batch, y: y_batch}
            sess.run(optimizer, feed_dict=feed_dict_batch)

            if iteration % display_freq == 0:
                # Calculate and display the batch loss and accuracy
                loss_batch, acc_batch, summary_tr = sess.run([loss, accuracy, merged],
                                                             feed_dict=feed_dict_batch)
                summary_writer.add_summary(summary_tr, global_step)

                print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}".
                      format(iteration, loss_batch, acc_batch))

        # Run validation after every epoch
        feed_dict_valid = {x: x_valid, y: y_valid}
        loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid)
        print('---------------------------------------------------------')
        print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".
              format(epoch + 1, loss_valid, acc_valid))
        print('---------------------------------------------------------')
finally:
    # Close the writer so pending summaries are flushed to the event file even
    # if training is interrupted. The session stays open for the steps below.
    summary_writer.close()
def plot_images(images, cls_true, cls_pred=None, title=None):
    """
    Create figure with 3x3 sub-plots showing up to 9 images.
    :param images: array of images to be plotted; only the first 9 are used and
                   each one is passed through np.squeeze before display
    :param cls_true: corresponding true labels
    :param cls_pred: corresponding predicted labels (optional)
    :param title: optional figure title
    """
    fig, axes = plt.subplots(3, 3, figsize=(9, 9))
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    num_images = len(images)
    for i, ax in enumerate(axes.flat):
        if i >= num_images:
            # Fewer than 9 images were supplied (e.g. only a handful of
            # misclassified samples): leave the remaining cells blank
            # instead of indexing past the end of the array.
            ax.axis('off')
            continue

        # Plot image.
        ax.imshow(np.squeeze(images[i]), cmap='binary')

        # Show true and predicted classes.
        if cls_pred is None:
            ax_title = "True: {0}".format(cls_true[i])
        else:
            ax_title = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
        ax.set_title(ax_title)

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    if title:
        plt.suptitle(title, size=20)
    # Non-blocking so the caller can open further figures before showing them all.
    plt.show(block=False)
def plot_example_errors(images, cls_true, cls_pred, title=None):
    """
    Function for plotting examples of images that have been mis-classified
    :param images: array of all images, indexed along its first axis
    :param cls_true: corresponding true labels, (#imgs,)
    :param cls_pred: corresponding predicted labels, (#imgs,)
    :param title: optional figure title
    """
    # Boolean mask: True wherever the prediction differs from the true class.
    wrong = np.not_equal(cls_pred, cls_true)

    # Keep only the misclassified samples, then plot the first 9 of them.
    wrong_images = images[wrong]
    wrong_true = cls_true[wrong]
    wrong_pred = cls_pred[wrong]
    plot_images(images=wrong_images[0:9],
                cls_true=wrong_true[0:9],
                cls_pred=wrong_pred[0:9],
                title=title)
def test(sess):
    """
    Evaluate the model on the MNIST test set and plot example predictions.
    Reads the module-level ``x``, ``y``, ``loss``, ``accuracy`` and
    ``cls_prediction`` at call time, so it uses whichever graph tensors those
    names are currently bound to.
    :param sess: session in which those tensors can be run
    """
    x_test, y_test = load_data(mode='test')
    feed_dict_test = {x: x_test, y: y_test}
    # Fetch loss, accuracy and predictions in a single run so the test set is
    # only pushed through the network once.
    loss_test, acc_test, cls_pred = sess.run([loss, accuracy, cls_prediction],
                                             feed_dict=feed_dict_test)
    print('---------------------------------------------------------')
    print("Test loss: {0:.2f}, test accuracy: {1:.01%}".format(loss_test, acc_test))
    print('---------------------------------------------------------')

    # Plot some of the correct and misclassified examples
    cls_true = np.argmax(y_test, axis=1)
    # Restrict the first figure to correctly classified samples so that it
    # matches its 'Correct Examples' title.
    correct = np.equal(cls_pred, cls_true)
    plot_images(x_test[correct], cls_true[correct], cls_pred[correct],
                title='Correct Examples')
    plot_example_errors(x_test, cls_true, cls_pred, title='Misclassified Examples')
    plt.show()
# Test the network when training is done
test(sess)

# Save the trained variables to a checkpoint; the (optional) freezing step
# below reads this checkpoint.
# https://medium.com/@prasadpal107/saving-freezing-optimizing-for-inference-restoring-of-tensorflow-models-b4146deb21b5
saver = tf.train.Saver()
saver.save(sess, save_path='./tensorflowModel.ckpt')
def freeze_model(sess):
    """
    Write the session's graph as text and freeze it together with the checkpoint
    :param sess: session whose graph is exported

    Writes 'tensorflowModel.pbtxt' to the current directory, then calls
    freeze_graph with './tensorflowModel.ckpt' as the checkpoint and
    'frozentensorflowModel.pb' as the output file.
    """
    graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(graph_def, '.', 'tensorflowModel.pbtxt', as_text=True)

    freeze_graph(
        # inputs: the text graph written above plus the saved checkpoint
        input_graph='tensorflowModel.pbtxt',
        input_binary=False,
        input_saver='',
        input_checkpoint='./tensorflowModel.ckpt',
        # outputs: the prediction node is the one kept in the frozen graph
        output_node_names='Train/Prediction/predictions',
        output_graph='frozentensorflowModel.pb',
        # remaining options
        restore_op_name='',
        filename_tensor_name='',
        clear_devices=True,
        initializer_nodes='',
    )
# Optional: uncomment to also export a frozen graph before the session is closed.
# freeze_model(sess)
# Close the training session once testing (and optional freezing) is done.
sess.close()
# Helper for importing a frozen model
def load_pb(frozen_graph_filename):
    """
    Read a serialized GraphDef from disk
    :param frozen_graph_filename: path of the binary graph file
    :return: the parsed tf.GraphDef
    """
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        serialized = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
# Rebuild the model in a fresh graph from the saved checkpoint and test it again.
graph = tf.Graph()
sess = tf.Session(graph=graph)
try:
    with graph.as_default(), sess.as_default():
        # restoring the model
        saver = tf.train.import_meta_graph('tensorflowModel.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # Alternative: import the frozen graph instead of the checkpoint.
        # tf.import_graph_def(load_pb('frozentensorflowModel.pb'), name = '')

        # Rebind the module-level names that test() reads to the tensors of
        # the restored graph.
        x = graph.get_tensor_by_name("Input/X:0")
        y = graph.get_tensor_by_name("Input/Y:0")
        loss = graph.get_tensor_by_name("Train/Loss/loss:0")
        accuracy = graph.get_tensor_by_name("Train/Accuracy/accuracy:0")
        cls_prediction = graph.get_tensor_by_name("Train/Prediction/predictions:0")
    test(sess)
finally:
    # Release the second session too, even when restoring or testing fails.
    sess.close()