Skip to content

Commit f677439

Browse files
kerryjiangOceania2018
authored andcommitted
started to use MnistModelLoader in Tensorflow.Hub (SciSharp#330)
1 parent 787c772 commit f677439

12 files changed

Lines changed: 69 additions & 356 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,7 @@ src/TensorFlowNET.Native/bazel-*
332332
src/TensorFlowNET.Native/c_api.h
333333
/.vscode
334334
test/TensorFlowNET.Examples/mnist
335+
336+
337+
# training model resources
338+
.resources

test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818
using System;
1919
using System.Diagnostics;
2020
using Tensorflow;
21-
using TensorFlowNET.Examples.Utility;
21+
using Tensorflow.Hub;
2222
using static Tensorflow.Python;
2323

2424
namespace TensorFlowNET.Examples
@@ -39,7 +39,7 @@ public class KMeansClustering : IExample
3939
public int? test_size = null;
4040
public int batch_size = 1024; // The number of samples per batch
4141

42-
Datasets<DataSetMnist> mnist;
42+
Datasets<MnistDataSet> mnist;
4343
NDArray full_data_x;
4444
int num_steps = 20; // Total steps to train
4545
int k = 25; // The number of clusters
@@ -62,19 +62,31 @@ public bool Run()
6262

6363
public void PrepareData()
6464
{
65-
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
66-
full_data_x = mnist.train.data;
65+
var loader = new MnistModelLoader();
66+
67+
var setting = new ModelLoadSetting
68+
{
69+
TrainDir = ".resources/mnist",
70+
OneHot = true,
71+
TrainSize = train_size,
72+
ValidationSize = validation_size,
73+
TestSize = test_size
74+
};
75+
76+
mnist = loader.LoadAsync(setting).Result;
77+
78+
full_data_x = mnist.Train.Data;
6779

6880
// download graph meta data
6981
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
70-
Web.Download(url, "graph", "kmeans.meta");
82+
loader.DownloadAsync(url, ".resources/graph", "kmeans.meta").Wait();
7183
}
7284

7385
public Graph ImportGraph()
7486
{
7587
var graph = tf.Graph().as_default();
7688

77-
tf.train.import_meta_graph("graph/kmeans.meta");
89+
tf.train.import_meta_graph(".resources/graph/kmeans.meta");
7890

7991
return graph;
8092
}
@@ -132,7 +144,7 @@ public void Train(Session sess)
132144
sw.Start();
133145
foreach (var i in range(idx.Count))
134146
{
135-
var x = mnist.train.labels[i];
147+
var x = mnist.Train.Labels[i];
136148
counts[idx[i]] += x;
137149
}
138150

@@ -153,7 +165,7 @@ public void Train(Session sess)
153165
var accuracy_op = tf.reduce_mean(cast);
154166

155167
// Test Model
156-
var (test_x, test_y) = (mnist.test.data, mnist.test.labels);
168+
var (test_x, test_y) = (mnist.Test.Data, mnist.Test.Labels);
157169
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
158170
accuray_test = result;
159171
print($"Test Accuracy: {accuray_test}");

test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License.
1919
using System.Diagnostics;
2020
using System.IO;
2121
using Tensorflow;
22-
using TensorFlowNET.Examples.Utility;
22+
using Tensorflow.Hub;
2323
using static Tensorflow.Python;
2424

2525
namespace TensorFlowNET.Examples
@@ -45,7 +45,7 @@ public class LogisticRegression : IExample
4545
private float learning_rate = 0.01f;
4646
private int display_step = 1;
4747

48-
Datasets<DataSetMnist> mnist;
48+
Datasets<MnistDataSet> mnist;
4949

5050
public bool Run()
5151
{
@@ -84,11 +84,11 @@ public bool Run()
8484
sw.Start();
8585

8686
var avg_cost = 0.0f;
87-
var total_batch = mnist.train.num_examples / batch_size;
87+
var total_batch = mnist.Train.NumOfExamples / batch_size;
8888
// Loop over all batches
8989
foreach (var i in range(total_batch))
9090
{
91-
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
91+
var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size);
9292
// Run optimization op (backprop) and cost op (to get loss value)
9393
var result = sess.run(new object[] { optimizer, cost },
9494
new FeedItem(x, batch_xs),
@@ -115,7 +115,7 @@ public bool Run()
115115
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
116116
// Calculate accuracy
117117
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
118-
float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
118+
float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
119119
print($"Accuracy: {acc.ToString("F4")}");
120120

121121
return acc > 0.9;
@@ -124,31 +124,31 @@ public bool Run()
124124

125125
public void PrepareData()
126126
{
127-
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
127+
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result;
128128
}
129129

130130
public void SaveModel(Session sess)
131131
{
132132
var saver = tf.train.Saver();
133-
var save_path = saver.save(sess, "logistic_regression/model.ckpt");
134-
tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true);
133+
var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt");
134+
tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true);
135135

136-
FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt",
136+
FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt",
137137
input_saver: "",
138138
input_binary: false,
139-
input_checkpoint: "logistic_regression/model.ckpt",
139+
input_checkpoint: ".resources/logistic_regression/model.ckpt",
140140
output_node_names: "Softmax",
141141
restore_op_name: "save/restore_all",
142142
filename_tensor_name: "save/Const:0",
143-
output_graph: "logistic_regression/model.pb",
143+
output_graph: ".resources/logistic_regression/model.pb",
144144
clear_devices: true,
145145
initializer_nodes: "");
146146
}
147147

148148
public void Predict(Session sess)
149149
{
150150
var graph = new Graph().as_default();
151-
graph.Import(Path.Join("logistic_regression", "model.pb"));
151+
graph.Import(Path.Join(".resources/logistic_regression", "model.pb"));
152152

153153
// restoring the model
154154
// var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
@@ -159,7 +159,7 @@ public void Predict(Session sess)
159159
var input = x.outputs[0];
160160

161161
// predict
162-
var (batch_xs, batch_ys) = mnist.train.next_batch(10);
162+
var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(10);
163163
var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)]));
164164

165165
if (results.argmax() == (batch_ys[0] as NDArray).argmax())

test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using Tensorflow;
20-
using TensorFlowNET.Examples.Utility;
20+
using Tensorflow.Hub;
2121
using static Tensorflow.Python;
2222

2323
namespace TensorFlowNET.Examples
@@ -31,7 +31,7 @@ public class NearestNeighbor : IExample
3131
{
3232
public bool Enabled { get; set; } = true;
3333
public string Name => "Nearest Neighbor";
34-
Datasets<DataSetMnist> mnist;
34+
Datasets<MnistDataSet> mnist;
3535
NDArray Xtr, Ytr, Xte, Yte;
3636
public int? TrainSize = null;
3737
public int ValidationSize = 5000;
@@ -84,10 +84,10 @@ public bool Run()
8484

8585
public void PrepareData()
8686
{
87-
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
87+
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result;
8888
// In this example, we limit mnist data
89-
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
90-
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
89+
(Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
90+
(Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing
9191
}
9292

9393
public Graph ImportGraph()

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818
using System;
1919
using System.Diagnostics;
2020
using Tensorflow;
21-
using TensorFlowNET.Examples.Utility;
21+
using Tensorflow.Hub;
2222
using static Tensorflow.Python;
2323

2424
namespace TensorFlowNET.Examples.ImageProcess
@@ -46,7 +46,7 @@ public class DigitRecognitionCNN : IExample
4646
int epochs = 5; // accuracy > 98%
4747
int batch_size = 100;
4848
float learning_rate = 0.001f;
49-
Datasets<DataSetMnist> mnist;
49+
Datasets<MnistDataSet> mnist;
5050

5151
// Network configuration
5252
// 1st Convolutional Layer
@@ -310,14 +310,14 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
310310

311311
public void PrepareData()
312312
{
313-
mnist = MNIST.read_data_sets("mnist", one_hot: true);
314-
(x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels);
315-
(x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels);
316-
(x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels);
313+
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
314+
(x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels);
315+
(x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels);
316+
(x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels);
317317

318318
print("Size of:");
319-
print($"- Training-set:\t\t{len(mnist.train.data)}");
320-
print($"- Validation-set:\t{len(mnist.validation.data)}");
319+
print($"- Training-set:\t\t{len(mnist.Train.Data)}");
320+
print($"- Validation-set:\t{len(mnist.Validation.Data)}");
321321
}
322322

323323
/// <summary>

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using Tensorflow;
20-
using TensorFlowNET.Examples.Utility;
20+
using Tensorflow.Hub;
2121
using static Tensorflow.Python;
2222

2323
namespace TensorFlowNET.Examples.ImageProcess
@@ -44,7 +44,7 @@ public class DigitRecognitionNN : IExample
4444
int batch_size = 100;
4545
float learning_rate = 0.001f;
4646
int h1 = 200; // number of nodes in the 1st hidden layer
47-
Datasets<DataSetMnist> mnist;
47+
Datasets<MnistDataSet> mnist;
4848

4949
Tensor x, y;
5050
Tensor loss, accuracy;
@@ -121,13 +121,13 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
121121

122122
public void PrepareData()
123123
{
124-
mnist = MNIST.read_data_sets("mnist", one_hot: true);
124+
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
125125
}
126126

127127
public void Train(Session sess)
128128
{
129129
// Number of training iterations in each epoch
130-
var num_tr_iter = mnist.train.labels.shape[0] / batch_size;
130+
var num_tr_iter = mnist.Train.Labels.shape[0] / batch_size;
131131

132132
var init = tf.global_variables_initializer();
133133
sess.run(init);
@@ -139,13 +139,13 @@ public void Train(Session sess)
139139
{
140140
print($"Training epoch: {epoch + 1}");
141141
// Randomly shuffle the training data at the beginning of each epoch
142-
var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels);
142+
var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);
143143

144144
foreach (var iteration in range(num_tr_iter))
145145
{
146146
var start = iteration * batch_size;
147147
var end = (iteration + 1) * batch_size;
148-
var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end);
148+
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
149149

150150
// Run optimization op (backprop)
151151
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
@@ -161,7 +161,8 @@ public void Train(Session sess)
161161
}
162162

163163
// Run validation after every epoch
164-
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
164+
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels));
165+
165166
loss_val = results1[0];
166167
accuracy_val = results1[1];
167168
print("---------------------------------------------------------");
@@ -172,35 +173,12 @@ public void Train(Session sess)
172173

173174
public void Test(Session sess)
174175
{
175-
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
176+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
176177
loss_test = result[0];
177178
accuracy_test = result[1];
178179
print("---------------------------------------------------------");
179180
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
180181
print("---------------------------------------------------------");
181182
}
182-
183-
private (NDArray, NDArray) randomize(NDArray x, NDArray y)
184-
{
185-
var perm = np.random.permutation(y.shape[0]);
186-
187-
np.random.shuffle(perm);
188-
return (mnist.train.data[perm], mnist.train.labels[perm]);
189-
}
190-
191-
/// <summary>
192-
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
193-
/// </summary>
194-
/// <param name="x"></param>
195-
/// <param name="y"></param>
196-
/// <param name="start"></param>
197-
/// <param name="end"></param>
198-
/// <returns></returns>
199-
private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end)
200-
{
201-
var x_batch = x[$"{start}:{end}"];
202-
var y_batch = y[$"{start}:{end}"];
203-
return (x_batch, y_batch);
204-
}
205183
}
206184
}

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using Tensorflow;
20-
using TensorFlowNET.Examples.Utility;
20+
using Tensorflow.Hub;
2121
using static Tensorflow.Python;
2222

2323
namespace TensorFlowNET.Examples.ImageProcess
@@ -45,7 +45,7 @@ public class DigitRecognitionRNN : IExample
4545
int n_inputs = 28;
4646
int n_outputs = 10;
4747

48-
Datasets<DataSetMnist> mnist;
48+
Datasets<MnistDataSet> mnist;
4949

5050
Tensor x, y;
5151
Tensor loss, accuracy, cls_prediction;
@@ -143,15 +143,15 @@ public void Test(Session sess)
143143

144144
public void PrepareData()
145145
{
146-
mnist = MNIST.read_data_sets("mnist", one_hot: true);
147-
(x_train, y_train) = (mnist.train.data, mnist.train.labels);
148-
(x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels);
149-
(x_test, y_test) = (mnist.test.data, mnist.test.labels);
146+
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
147+
(x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels);
148+
(x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels);
149+
(x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);
150150

151151
print("Size of:");
152-
print($"- Training-set:\t\t{len(mnist.train.data)}");
153-
print($"- Validation-set:\t{len(mnist.validation.data)}");
154-
print($"- Test-set:\t\t{len(mnist.test.data)}");
152+
print($"- Training-set:\t\t{len(mnist.Train.Data)}");
153+
print($"- Validation-set:\t{len(mnist.Validation.Data)}");
154+
print($"- Test-set:\t\t{len(mnist.Test.Data)}");
155155
}
156156

157157
public Graph ImportGraph() => throw new NotImplementedException();

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
<ProjectReference Include="..\..\src\KerasNET.Core\Keras.Core.csproj" />
1919
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
2020
<ProjectReference Include="..\..\src\TensorFlowText\TensorFlowText.csproj" />
21+
<ProjectReference Include="..\..\src\TensorFlowHub\TensorFlowHub.csproj" />
2122
</ItemGroup>
2223
</Project>

0 commit comments

Comments
 (0)