Skip to content

Commit ecc4923

Browse files
committed
minor changes in examples
1 parent 1ff31db commit ecc4923

4 files changed

Lines changed: 21 additions & 19 deletions

File tree

test/TensorFlowNET.Examples/KMeansClustering.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ public class KMeansClustering : Python, IExample
1818
public int Priority => 8;
1919
public bool Enabled { get; set; } = true;
2020
public string Name => "K-means Clustering";
21-
public int DataSize = 5000;
22-
public int TestSize = 5000;
23-
public int BatchSize = 100;
21+
22+
public int? train_size = null;
23+
public int validation_size = 5000;
24+
public int? test_size = null;
25+
public int batch_size = 1024; // The number of samples per batch
2426

2527
Datasets mnist;
2628
NDArray full_data_x;
2729
int num_steps = 50; // Total steps to train
28-
int batch_size = 1024; // The number of samples per batch
2930
int k = 25; // The number of clusters
3031
int num_classes = 10; // The 10 digits
3132
int num_features = 784; // Each image is 28x28 pixels
@@ -48,7 +49,7 @@ public bool Run()
4849

4950
public void PrepareData()
5051
{
51-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size:TestSize);
52+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
5253
full_data_x = mnist.train.images;
5354
}
5455
}

test/TensorFlowNET.Examples/LinearRegression.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ public class LinearRegression : Python, IExample
1616
public bool Enabled { get; set; } = true;
1717
public string Name => "Linear Regression";
1818

19-
NumPyRandom rng = np.random;
19+
public int training_epochs = 1000;
2020

2121
// Parameters
2222
float learning_rate = 0.01f;
23-
public int TrainingEpochs = 1000;
2423
int display_step = 50;
2524

25+
NumPyRandom rng = np.random;
2626
NDArray train_X, train_Y;
2727
int n_samples;
2828

@@ -62,7 +62,7 @@ public bool Run()
6262
sess.run(init);
6363

6464
// Fit all training data
65-
for (int epoch = 0; epoch < TrainingEpochs; epoch++)
65+
for (int epoch = 0; epoch < training_epochs; epoch++)
6666
{
6767
foreach (var (x, y) in zip<float>(train_X, train_Y))
6868
{

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ public class LogisticRegression : Python, IExample
2020
public bool Enabled { get; set; } = true;
2121
public string Name => "Logistic Regression";
2222

23+
public int training_epochs = 10;
24+
public int? train_size = null;
25+
public int validation_size = 5000;
26+
public int? test_size = null;
27+
public int batch_size = 100;
28+
2329
private float learning_rate = 0.01f;
24-
public int TrainingEpochs = 10;
25-
public int? TrainSize = null;
26-
public int ValidationSize = 5000;
27-
public int? TestSize = null;
28-
public int BatchSize = 100;
2930
private int display_step = 1;
3031

3132
Datasets mnist;
@@ -60,14 +61,14 @@ public bool Run()
6061
sess.run(init);
6162

6263
// Training cycle
63-
foreach (var epoch in range(TrainingEpochs))
64+
foreach (var epoch in range(training_epochs))
6465
{
6566
var avg_cost = 0.0f;
66-
var total_batch = mnist.train.num_examples / BatchSize;
67+
var total_batch = mnist.train.num_examples / batch_size;
6768
// Loop over all batches
6869
foreach (var i in range(total_batch))
6970
{
70-
var (batch_xs, batch_ys) = mnist.train.next_batch(BatchSize);
71+
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
7172
// Run optimization op (backprop) and cost op (to get loss value)
7273
var result = sess.run(new object[] { optimizer, cost },
7374
new FeedItem(x, batch_xs),
@@ -99,7 +100,7 @@ public bool Run()
99100

100101
public void PrepareData()
101102
{
102-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);
103+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
103104
}
104105

105106
public void SaveModel(Session sess)

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public void InceptionArchGoogLeNet()
3939
[TestMethod]
4040
public void KMeansClustering()
4141
{
42-
new KMeansClustering() { Enabled = true }.Run();
42+
new KMeansClustering() { Enabled = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
4343
}
4444

4545
[TestMethod]
@@ -51,7 +51,7 @@ public void LinearRegression()
5151
[TestMethod]
5252
public void LogisticRegression()
5353
{
54-
new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
54+
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
5555
}
5656

5757
[Ignore]

0 commit comments

Comments
 (0)