Skip to content

Commit 0484aab

Browse files
committed
optimize image dataset loading
1 parent 8cf46d2 commit 0484aab

4 files changed

Lines changed: 14 additions & 5 deletions

File tree

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using NumSharp;
22
using System;
33
using System.Collections.Generic;
4+
using System.Diagnostics;
45
using System.IO;
56
using System.Linq;
67
using System.Text;
@@ -55,6 +56,8 @@ public bool Run()
5556
// Initialize the variables (i.e. assign their default value)
5657
var init = tf.global_variables_initializer();
5758

59+
var sw = new Stopwatch();
60+
5861
return with(tf.Session(), sess =>
5962
{
6063
// Run the initializer
@@ -63,6 +66,8 @@ public bool Run()
6366
// Training cycle
6467
foreach (var epoch in range(training_epochs))
6568
{
69+
sw.Start();
70+
6671
var avg_cost = 0.0f;
6772
var total_batch = mnist.train.num_examples / batch_size;
6873
// Loop over all batches
@@ -79,9 +84,13 @@ public bool Run()
7984
avg_cost += c / total_batch;
8085
}
8186

87+
sw.Stop();
88+
8289
// Display logs per epoch step
8390
if ((epoch + 1) % display_step == 0)
84-
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
91+
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")} elapse= {sw.ElapsedMilliseconds}ms");
92+
93+
sw.Reset();
8594
}
8695

8796
print("Optimization Finished!");

test/TensorFlowNET.Examples/NeuralNetXor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples
1212
/// </summary>
1313
public class NeuralNetXor : Python, IExample
1414
{
15-
public int Priority => 2;
15+
public int Priority => 10;
1616
public bool Enabled { get; set; } = true;
1717
public string Name => "NN XOR";
1818

test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
5454

5555
// Get the rest examples in this epoch
5656
var rest_num_examples = _num_examples - start;
57-
var images_rest_part = _images[np.arange(start, _num_examples)];
58-
var labels_rest_part = _labels[np.arange(start, _num_examples)];
57+
//var images_rest_part = _images[np.arange(start, _num_examples)];
58+
//var labels_rest_part = _labels[np.arange(start, _num_examples)];
5959
// Shuffle the data
6060
if (shuffle)
6161
{

test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
102102
for(int row = 0; row < num_labels; row++)
103103
{
104104
var col = labels_dense.Data<byte>(row);
105-
labels_one_hot.SetData(1, row, col);
105+
labels_one_hot.SetData(1.0, row, col);
106106
}
107107

108108
return labels_one_hot;

0 commit comments

Comments
 (0)