Skip to content

Commit e152782

Browse files
committed
Finished kmean model.
1 parent 8503a52 commit e152782

7 files changed

Lines changed: 55 additions & 18 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ public static Tensor embedding_lookup(RefVariable @params,
2626
partition_strategy: partition_strategy,
2727
name: name);
2828

29+
public static Tensor embedding_lookup(Tensor @params,
30+
Tensor ids,
31+
string partition_strategy = "mod",
32+
string name = null) => embedding_ops._embedding_lookup_and_transform(new Tensor[] { @params },
33+
ids,
34+
partition_strategy: partition_strategy,
35+
name: name);
36+
2937
public static IActivation relu() => new relu();
3038

3139
public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,15 @@ private static Tensor _ReductionDims(Tensor x, int[] axis)
333333
else
334334
{
335335
var rank = common_shapes.rank(x);
336-
if (rank != null)
336+
337+
// we rely on Range and Rank to do the right thing at run-time.
338+
if (rank == -1) return range(0, array_ops.rank(x));
339+
340+
if (rank.HasValue && rank.Value > -1)
337341
{
338342
return constant_op.constant(np.arange(rank.Value), TF_DataType.TF_INT32);
339343
}
344+
340345
return range(0, rank, 1);
341346
}
342347
}

test/TensorFlowNET.Examples/KMeansClustering.cs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public class KMeansClustering : Python, IExample
2828

2929
Datasets mnist;
3030
NDArray full_data_x;
31-
int num_steps = 10; // Total steps to train
31+
int num_steps = 20; // Total steps to train
3232
int k = 25; // The number of clusters
3333
int num_classes = 10; // The 10 digits
3434
int num_features = 784; // Each image is 28x28 pixels
@@ -42,9 +42,9 @@ public bool Run()
4242
tf.train.import_meta_graph("graph/kmeans.meta");
4343

4444
// Input images
45-
var X = graph.get_operation_by_name("Placeholder").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
45+
Tensor X = graph.get_operation_by_name("Placeholder"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
4646
// Labels (for assigning a label to a centroid and testing)
47-
var Y = graph.get_operation_by_name("Placeholder_1").output; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
47+
Tensor Y = graph.get_operation_by_name("Placeholder_1"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
4848

4949
// K-Means Parameters
5050
//var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);
@@ -57,26 +57,24 @@ public bool Run()
5757
var train_op = graph.get_operation_by_name("group_deps");
5858
Tensor avg_distance = graph.get_operation_by_name("Mean");
5959
Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1");
60+
NDArray result = null;
6061

6162
with(tf.Session(graph), sess =>
6263
{
6364
sess.run(init_vars, new FeedItem(X, full_data_x));
6465
sess.run(init_op, new FeedItem(X, full_data_x));
6566

6667
// Training
67-
NDArray result = null;
6868
var sw = new Stopwatch();
6969

7070
foreach (var i in range(1, num_steps + 1))
7171
{
72-
sw.Start();
72+
sw.Restart();
7373
result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x));
7474
sw.Stop();
7575

76-
if (i % 5 == 0 || i == 1)
76+
if (i % 4 == 0 || i == 1)
7777
print($"Step {i}, Avg Distance: {result[1]} Elapse: {sw.ElapsedMilliseconds}ms");
78-
79-
sw.Reset();
8078
}
8179

8280
var idx = result[2].Data<int>();
@@ -102,9 +100,20 @@ public bool Run()
102100

103101
// Evaluation ops
104102
// Lookup: centroid_id -> label
103+
var cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx);
104+
105+
// Compute accuracy
106+
var correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32));
107+
var cast = tf.cast(correct_prediction, tf.float32);
108+
var accuracy_op = tf.reduce_mean(cast);
109+
110+
// Test Model
111+
var (test_x, test_y) = (mnist.test.images, mnist.test.labels);
112+
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
113+
print($"Test Accuracy: {result}");
105114
});
106115

107-
return false;
116+
return (float)result > 0.70;
108117
}
109118

110119
public void PrepareData()

test/TensorFlowNET.Examples/NearestNeighbor.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ public bool Run()
5151
long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]));
5252
// Get nearest neighbor class label and compare it to its true label
5353
int index = (int)nn_index;
54-
print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");
54+
55+
if (i % 10 == 0 || i == 0)
56+
print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");
57+
5558
// Calculate accuracy
5659
if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i]))
5760
accuracy += 1f/ Xte.shape[0];
@@ -60,7 +63,7 @@ public bool Run()
6063
print($"Accuracy: {accuracy}");
6164
});
6265

63-
return accuracy > 0.9;
66+
return accuracy > 0.8;
6467
}
6568

6669
public void PrepareData()

test/TensorFlowNET.Examples/ObjectDetection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples
1616
public class ObjectDetection : Python, IExample
1717
{
1818
public int Priority => 11;
19-
public bool Enabled { get; set; } = true;
19+
public bool Enabled { get; set; } = false;
2020
public string Name => "Object Detection";
2121
public float MIN_SCORE = 0.5f;
2222

test/TensorFlowNET.Examples/Program.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34
using System.Drawing;
45
using System.Linq;
56
using System.Reflection;
@@ -21,28 +22,39 @@ static void Main(string[] args)
2122
.OrderBy(x => x.Priority)
2223
.ToArray();
2324

25+
var sw = new Stopwatch();
2426
foreach (IExample example in examples)
2527
{
2628
if (args.Length > 0 && !args.Contains(example.Name))
2729
continue;
2830

2931
Console.WriteLine($"{DateTime.UtcNow} Starting {example.Name}", Color.White);
3032

33+
3134
try
3235
{
3336
if (example.Enabled)
34-
if (example.Run())
35-
success.Add($"Example {example.Priority}: {example.Name}");
37+
{
38+
sw.Restart();
39+
bool isSuccess = example.Run();
40+
sw.Stop();
41+
42+
if (isSuccess)
43+
success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
3644
else
37-
errors.Add($"Example {example.Priority}: {example.Name}");
45+
errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
46+
}
3847
else
39-
disabled.Add($"Example {example.Priority}: {example.Name}");
48+
{
49+
disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms");
50+
}
4051
}
4152
catch (Exception ex)
4253
{
4354
errors.Add($"Example {example.Priority}: {example.Name}");
4455
Console.WriteLine(ex);
4556
}
57+
4658

4759
Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White);
4860
}

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public void InceptionArchGoogLeNet()
4545
public void KMeansClustering()
4646
{
4747
tf.Graph().as_default();
48-
new KMeansClustering() { Enabled = false, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
48+
new KMeansClustering() { Enabled = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
4949
}
5050

5151
[TestMethod]

0 commit comments

Comments
 (0)