Skip to content

Commit ffdebe2

Browse files
committed
Finished MNIST CNN example.
1 parent 2b630fb commit ffdebe2

3 files changed

Lines changed: 87 additions & 65 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Example runner will download all the required files like training data and model
149149
* [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs)
150150
* [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs)
151151
* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
152+
* [MNIST CNN](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs)
152153
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
153154
* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
154155

test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs

Lines changed: 85 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ public class DigitRecognitionCNN : IExample
7373
float accuracy_test = 0f;
7474
float loss_test = 1f;
7575

76-
NDArray x_train;
76+
NDArray x_train, y_train;
77+
NDArray x_valid, y_valid;
78+
NDArray x_test, y_test;
7779

7880
public bool Run()
7981
{
@@ -135,6 +137,62 @@ public Graph BuildGraph()
135137
return graph;
136138
}
137139

140+
public void Train(Session sess)
141+
{
142+
// Number of training iterations in each epoch
143+
var num_tr_iter = y_train.len / batch_size;
144+
145+
var init = tf.global_variables_initializer();
146+
sess.run(init);
147+
148+
float loss_val = 100.0f;
149+
float accuracy_val = 0f;
150+
151+
foreach (var epoch in range(epochs))
152+
{
153+
print($"Training epoch: {epoch + 1}");
154+
// Randomly shuffle the training data at the beginning of each epoch
155+
(x_train, y_train) = mnist.Randomize(x_train, y_train);
156+
157+
foreach (var iteration in range(num_tr_iter))
158+
{
159+
var start = iteration * batch_size;
160+
var end = (iteration + 1) * batch_size;
161+
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
162+
163+
// Run optimization op (backprop)
164+
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
165+
166+
if (iteration % display_freq == 0)
167+
{
168+
// Calculate and display the batch loss and accuracy
169+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
170+
loss_val = result[0];
171+
accuracy_val = result[1];
172+
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
173+
}
174+
}
175+
176+
// Run validation after every epoch
177+
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
178+
loss_val = results1[0];
179+
accuracy_val = results1[1];
180+
print("---------------------------------------------------------");
181+
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
182+
print("---------------------------------------------------------");
183+
}
184+
}
185+
186+
public void Test(Session sess)
187+
{
188+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
189+
loss_test = result[0];
190+
accuracy_test = result[1];
191+
print("---------------------------------------------------------");
192+
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
193+
print("---------------------------------------------------------");
194+
}
195+
138196
/// <summary>
139197
/// Create a 2D convolution layer
140198
/// </summary>
@@ -219,6 +277,14 @@ private RefVariable bias_variable(string name, int[] shape)
219277
initializer: initial);
220278
}
221279

280+
/// <summary>
281+
/// Create a fully-connected layer
282+
/// </summary>
283+
/// <param name="x">input from previous layer</param>
284+
/// <param name="num_units">number of hidden units in the fully-connected layer</param>
285+
/// <param name="name">layer name</param>
286+
/// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
287+
/// <returns>The output array</returns>
222288
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
223289
{
224290
return with(tf.variable_scope(name), delegate
@@ -235,81 +301,36 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
235301
return layer;
236302
});
237303
}
238-
239-
public Graph ImportGraph() => throw new NotImplementedException();
240-
241-
public void Predict(Session sess) => throw new NotImplementedException();
242304

243305
public void PrepareData()
244306
{
245307
mnist = MNIST.read_data_sets("mnist", one_hot: true);
246-
x_train = Reformat(mnist.train.data, mnist.train.labels);
308+
(x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels);
309+
(x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels);
310+
(x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels);
311+
247312
print("Size of:");
248313
print($"- Training-set:\t\t{len(mnist.train.data)}");
249314
print($"- Validation-set:\t{len(mnist.validation.data)}");
250315
}
251316

252-
private NDArray Reformat(NDArray x, NDArray y)
317+
/// <summary>
318+
/// Reformats the data to the format acceptable for convolutional layers
319+
/// </summary>
320+
/// <param name="x"></param>
321+
/// <param name="y"></param>
322+
/// <returns></returns>
323+
private (NDArray, NDArray) Reformat(NDArray x, NDArray y)
253324
{
254-
var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, np.unique<int>(np.argmax(y, 1)));
255-
256-
return x;
325+
var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, len(np.unique<int>(np.argmax(y, 1))));
326+
var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32);
327+
//y[0] = np.arange(num_class) == y[0];
328+
//var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32);
329+
return (dataset, y);
257330
}
258331

259-
public void Train(Session sess)
260-
{
261-
// Number of training iterations in each epoch
262-
var num_tr_iter = mnist.train.labels.len / batch_size;
263-
264-
var init = tf.global_variables_initializer();
265-
sess.run(init);
266-
267-
float loss_val = 100.0f;
268-
float accuracy_val = 0f;
269-
270-
foreach (var epoch in range(epochs))
271-
{
272-
print($"Training epoch: {epoch + 1}");
273-
// Randomly shuffle the training data at the beginning of each epoch
274-
var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels);
275-
276-
foreach (var iteration in range(num_tr_iter))
277-
{
278-
var start = iteration * batch_size;
279-
var end = (iteration + 1) * batch_size;
280-
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
281-
282-
// Run optimization op (backprop)
283-
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
284-
285-
if (iteration % display_freq == 0)
286-
{
287-
// Calculate and display the batch loss and accuracy
288-
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
289-
loss_val = result[0];
290-
accuracy_val = result[1];
291-
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
292-
}
293-
}
294-
295-
// Run validation after every epoch
296-
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
297-
loss_val = results1[0];
298-
accuracy_val = results1[1];
299-
print("---------------------------------------------------------");
300-
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
301-
print("---------------------------------------------------------");
302-
}
303-
}
332+
public Graph ImportGraph() => throw new NotImplementedException();
304333

305-
public void Test(Session sess)
306-
{
307-
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
308-
loss_test = result[0];
309-
accuracy_test = result[1];
310-
print("---------------------------------------------------------");
311-
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
312-
print("---------------------------------------------------------");
313-
}
334+
public void Predict(Session sess) => throw new NotImplementedException();
314335
}
315336
}

test/TensorFlowNET.Examples/Utility/Datasets.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public Datasets(T train, T validation, T test)
2828
var perm = np.random.permutation(y.shape[0]);
2929

3030
np.random.shuffle(perm);
31-
return (train.data[perm], train.labels[perm]);
31+
return (x[perm], y[perm]);
3232
}
3333

3434
/// <summary>

0 commit comments

Comments
 (0)