Skip to content

Commit 30dde0f

Browse files
committed
TestSuite: added all examples with very small training sets (runs through within seconds)
1 parent 4bc97cd commit 30dde0f

7 files changed

Lines changed: 40 additions & 26 deletions

File tree

test/TensorFlowNET.Examples/KMeansClustering.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ public class KMeansClustering : Python, IExample
1818
public int Priority => 8;
1919
public bool Enabled { get; set; } = true;
2020
public string Name => "K-means Clustering";
21+
public int DataSize = 5000;
22+
public int TestSize = 5000;
23+
public int BatchSize = 100;
2124

2225
Datasets mnist;
2326
NDArray full_data_x;
@@ -45,7 +48,7 @@ public bool Run()
4548

4649
public void PrepareData()
4750
{
48-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
51+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size:TestSize);
4952
full_data_x = mnist.train.images;
5053
}
5154
}

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ public class LogisticRegression : Python, IExample
2121
public string Name => "Logistic Regression";
2222

2323
private float learning_rate = 0.01f;
24-
private int training_epochs = 10;
25-
private int batch_size = 100;
24+
public int TrainingEpochs = 10;
25+
public int DataSize = 5000;
26+
public int TestSize = 5000;
27+
public int BatchSize = 100;
2628
private int display_step = 1;
2729

2830
Datasets mnist;
@@ -57,14 +59,14 @@ public bool Run()
5759
sess.run(init);
5860

5961
// Training cycle
60-
foreach (var epoch in range(training_epochs))
62+
foreach (var epoch in range(TrainingEpochs))
6163
{
6264
var avg_cost = 0.0f;
63-
var total_batch = mnist.train.num_examples / batch_size;
65+
var total_batch = mnist.train.num_examples / BatchSize;
6466
// Loop over all batches
6567
foreach (var i in range(total_batch))
6668
{
67-
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
69+
var (batch_xs, batch_ys) = mnist.train.next_batch(BatchSize);
6870
// Run optimization op (backprop) and cost op (to get loss value)
6971
var result = sess.run(new object[] { optimizer, cost },
7072
new FeedItem(x, batch_xs),
@@ -96,7 +98,7 @@ public bool Run()
9698

9799
public void PrepareData()
98100
{
99-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
101+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
100102
}
101103

102104
public void SaveModel(Session sess)

test/TensorFlowNET.Examples/NearestNeighbor.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class NearestNeighbor : Python, IExample
1919
public string Name => "Nearest Neighbor";
2020
Datasets mnist;
2121
NDArray Xtr, Ytr, Xte, Yte;
22+
public int DataSize = 5000;
23+
public int TestBatchSize = 200;
2224

2325
public bool Run()
2426
{
@@ -62,10 +64,10 @@ public bool Run()
6264

6365
public void PrepareData()
6466
{
65-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
67+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
6668
// In this example, we limit mnist data
67-
(Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates)
68-
(Xte, Yte) = mnist.test.next_batch(200); // 200 for testing
69+
(Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates)
70+
(Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing
6971
}
7072
}
7173
}

test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class DataHelpers
1313
private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
1414
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
1515

16-
public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
16+
public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit=null)
1717
{
1818
string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
1919
/*if (step == "train")
@@ -25,10 +25,11 @@ public static (int[][], int[], int) build_char_dataset(string step, string model
2525
char_dict[c.ToString()] = char_dict.Count;
2626

2727
var contents = File.ReadAllLines(TRAIN_PATH);
28-
29-
var x = new int[contents.Length][];
30-
var y = new int[contents.Length];
31-
for (int i = 0; i < contents.Length; i++)
28+
var size = limit == null ? contents.Length : limit.Value;
29+
30+
var x = new int[size][];
31+
var y = new int[size];
32+
for (int i = 0; i < size; i++)
3233
{
3334
string[] parts = contents[i].ToLower().Split(",\"").ToArray();
3435
string content = parts[2];

test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public class TextClassificationTrain : Python, IExample
1717
public int Priority => 100;
1818
public bool Enabled { get; set; }= false;
1919
public string Name => "Text Classification";
20+
public int? DataLimit = null;
2021

2122
private string dataDir = "text_classification";
2223
private string dataFileName = "dbpedia_csv.tar.gz";
@@ -28,7 +29,7 @@ public bool Run()
2829
{
2930
PrepareData();
3031
Console.WriteLine("Building dataset...");
31-
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
32+
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);
3233

3334
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
3435

test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,26 @@ public static Datasets read_data_sets(string train_dir,
2121
TF_DataType dtype = TF_DataType.TF_FLOAT,
2222
bool reshape = true,
2323
int validation_size = 5000,
24+
int test_size = 5000,
2425
string source_url = DEFAULT_SOURCE_URL)
2526
{
27+
var train_size = validation_size * 2;
28+
2629
Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
2730
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
28-
var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));
31+
var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);
2932

3033
Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
3134
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
32-
var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);
35+
var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);
3336

3437
Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
3538
Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
36-
var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));
39+
var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);
3740

3841
Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
3942
Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
40-
var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);
43+
var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit:test_size);
4144

4245
int end = train_images.shape[0];
4346
var validation_images = train_images[np.arange(validation_size)];
@@ -52,14 +55,15 @@ public static Datasets read_data_sets(string train_dir,
5255
return new Datasets(train, validation, test);
5356
}
5457

55-
public static NDArray extract_images(string file)
58+
public static NDArray extract_images(string file, int? limit=null)
5659
{
5760
using (var bytestream = new FileStream(file, FileMode.Open))
5861
{
5962
var magic = _read32(bytestream);
6063
if (magic != 2051)
6164
throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
62-
var num_images = _read32(bytestream);
65+
var num_images = _read32(bytestream);
66+
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
6367
var rows = _read32(bytestream);
6468
var cols = _read32(bytestream);
6569
var buf = new byte[rows * cols * num_images];
@@ -70,14 +74,15 @@ public static NDArray extract_images(string file)
7074
}
7175
}
7276

73-
public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
77+
public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
7478
{
7579
using (var bytestream = new FileStream(file, FileMode.Open))
7680
{
7781
var magic = _read32(bytestream);
7882
if (magic != 2049)
7983
throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
8084
var num_items = _read32(bytestream);
85+
num_items = limit == null ? num_items : Math.Min(num_items,(uint) limit);
8186
var buf = new byte[num_items];
8287
bytestream.Read(buf, 0, buf.Length);
8388
var labels = np.frombuffer(buf, np.uint8);

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void LinearRegression()
5151
[TestMethod]
5252
public void LogisticRegression()
5353
{
54-
new LogisticRegression() { Enabled = true }.Run();
54+
new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run();
5555
}
5656

5757
[Ignore]
@@ -78,14 +78,14 @@ public void NamedEntityRecognition()
7878
[TestMethod]
7979
public void NearestNeighbor()
8080
{
81-
new NearestNeighbor() { Enabled = true }.Run();
81+
new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
8282
}
8383

8484
[Ignore]
8585
[TestMethod]
8686
public void TextClassificationTrain()
8787
{
88-
new TextClassificationTrain() { Enabled = true }.Run();
88+
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
8989
}
9090

9191
[Ignore]

0 commit comments

Comments
 (0)