forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNearestNeighbor.cs
More file actions
74 lines (65 loc) · 2.92 KB
/
NearestNeighbor.cs
File metadata and controls
74 lines (65 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
namespace TensorFlowNET.Examples
{
    /// <summary>
    /// A nearest neighbor learning algorithm example
    /// This example is using the MNIST database of handwritten digits
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py
    /// </summary>
    public class NearestNeighbor : Python, IExample
    {
        public int Priority => 5;
        public bool Enabled { get; set; } = true;
        public string Name => "Nearest Neighbor";

        Datasets mnist;
        NDArray Xtr, Ytr, Xte, Yte;

        /// <summary>
        /// Passed as <c>train_size</c> to the MNIST loader. When set, <c>TrainSize / 100</c>
        /// training samples are drawn as nearest-neighbor candidates; when null, 5000 are drawn.
        /// </summary>
        public int? TrainSize = null;

        /// <summary>Passed as <c>validation_size</c> to the MNIST loader.</summary>
        public int ValidationSize = 5000;

        /// <summary>
        /// Passed as <c>test_size</c> to the MNIST loader. When set, <c>TestSize / 100</c>
        /// test samples are classified; when null, 200 are classified.
        /// </summary>
        public int? TestSize = null;

        /// <summary>
        /// Builds an L1-distance nearest-neighbor graph, classifies every prepared test sample
        /// against all prepared training samples, and prints each prediction and the final accuracy.
        /// </summary>
        /// <returns><c>true</c> when the measured accuracy is greater than 0.9.</returns>
        public bool Run()
        {
            // tf Graph Input: all training candidates (rows of 784 features) and a single test sample.
            var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784));
            var xte = tf.placeholder(tf.float32, new TensorShape(784));

            // Nearest Neighbor calculation using L1 Distance:
            // sum over axis 1 of |xtr - xte|, i.e. one distance per training row.
            var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1);

            // Prediction: index of the training row with the minimum distance (the nearest neighbor).
            var pred = tf.arg_min(distance, 0);

            float accuracy = 0f;

            // Initialize the variables (i.e. assign their default value)
            var init = tf.global_variables_initializer();

            with(tf.Session(), sess =>
            {
                // Run the initializer
                sess.run(init);

                PrepareData();

                var total = Xte.shape[0];

                // Count hits as an integer and divide once at the end; repeatedly adding
                // 1f/total to a float accumulates rounding error (e.g. 0.99999 instead of 1).
                int correct = 0;

                foreach (int i in range(total))
                {
                    // Get nearest neighbor
                    long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]));
                    int index = (int)nn_index;

                    // Nearest neighbor's class label vs. the test sample's true label
                    // (labels are one-hot, so argmax recovers the class); computed once and reused.
                    var predicted = np.argmax(Ytr[(NDArray)index]);
                    var actual = np.argmax(Yte[i] as NDArray);

                    print($"Test {i} Prediction: {predicted} True Class: {actual}");

                    if (predicted == actual)
                    {
                        correct++;
                    }
                }

                // Guard the empty-test-set case explicitly instead of dividing by zero.
                accuracy = total > 0 ? (float)correct / total : 0f;
                print($"Accuracy: {accuracy}");
            });

            return accuracy > 0.9;
        }

        /// <summary>
        /// Loads MNIST (one-hot labels) and draws a limited batch of training candidates into
        /// <c>Xtr</c>/<c>Ytr</c> and test samples into <c>Xte</c>/<c>Yte</c>.
        /// </summary>
        public void PrepareData()
        {
            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);

            // In this example, we limit mnist data.
            // NOTE(review): when TrainSize/TestSize is set, only 1/100th of it is drawn here;
            // a value below 100 yields a zero-sized batch — confirm this scaling is intended.
            (Xtr, Ytr) = mnist.train.next_batch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
            (Xte, Yte) = mnist.test.next_batch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing
        }
    }
}