Skip to content

Commit 4cdf6b9

Browse files
committed
implementing NB classifier
1 parent f8b618c commit 4cdf6b9

6 files changed

Lines changed: 100 additions & 17 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.tile.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

@@ -9,6 +10,9 @@ public static partial class tf
910
public static Tensor tile(Tensor input,
1011
Tensor multiples,
1112
string name = null) => gen_array_ops.tile(input, multiples, name);
13+
public static Tensor tile(NDArray input,
14+
int[] multiples,
15+
string name = null) => gen_array_ops.tile(input, multiples, name);
1216

1317
}
1418
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,10 @@ public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.D
224224

225225
/// <summary>
226226
/// When building ops to compute gradients, this op prevents the contribution of
227-
/// its inputs to be taken into account.Normally, the gradient generator adds ops /// to a graph to compute the derivatives of a specified 'loss' by recursively /// finding out inputs that contributed to its computation.If you insert this op /// in the graph it inputs are masked from the gradient generator. They are not
227+
/// its inputs to be taken into account.Normally, the gradient generator adds ops
228+
/// to a graph to compute the derivatives of a specified 'loss' by recursively
229+
/// finding out inputs that contributed to its computation.If you insert this op
230+
/// in the graph it inputs are masked from the gradient generator. They are not
228231
/// taken into account for computing gradients.
229232
/// </summary>
230233
/// <param name="input"></param>

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.IO;
45
using System.Text;
@@ -156,6 +157,11 @@ public static Tensor tile(Tensor input, Tensor multiples, string name = null)
156157
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
157158
return _op.outputs[0];
158159
}
160+
public static Tensor tile(NDArray input, int[] multiples, string name = null)
161+
{
162+
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
163+
return _op.outputs[0];
164+
}
159165

160166
public static Tensor transpose(Tensor x, int[] perm, string name = null)
161167
{

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public static TF_DataType as_dtype(Type type)
4747
dtype = TF_DataType.TF_STRING;
4848
break;
4949
default:
50-
throw new Exception("Not Implemented");
50+
throw new Exception("as_dtype Not Implemented");
5151
}
5252

5353
return dtype;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype
441441

442442
switch (value)
443443
{
444+
case NDArray nd:
445+
return constant_op.constant(nd, dtype: dtype, name: name);
444446
case Tensor tensor:
445447
return tensor;
446448
case string str:

test/TensorFlowNET.Examples/NaiveBayesClassifier.cs

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,107 @@ public class NaiveBayesClassifier : Python, IExample
1515
public Normal dist { get; set; }
1616
public void Run()
1717
{
18-
np.array<float>(1.0f, 1.0f);
19-
var X = np.array<float>(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
20-
var y = np.array<int>(0,0,1,1,2,2);
18+
var X = np.array<double>(new double[][] { new double[] { 5.1, 3.5},new double[] { 4.9, 3.0 },new double[] { 4.7, 3.2 },
19+
new double[] { 4.6, 3.1 },new double[] { 5.0, 3.6 },new double[] { 5.4, 3.9 },
20+
new double[] { 4.6, 3.4 },new double[] { 5.0, 3.4 },new double[] { 4.4, 2.9 },
21+
new double[] { 4.9, 3.1 },new double[] { 5.4, 3.7 },new double[] {4.8, 3.4 },
22+
new double[] {4.8, 3.0 },new double[] {4.3, 3.0 },new double[] {5.8, 4.0 },
23+
new double[] {5.7, 4.4 },new double[] {5.4, 3.9 },new double[] {5.1, 3.5 },
24+
new double[] {5.7, 3.8 },new double[] {5.1, 3.8 },new double[] {5.4, 3.4 },
25+
new double[] {5.1, 3.7 },new double[] {5.1, 3.3 },new double[] {4.8, 3.4 },
26+
new double[] {5.0 , 3.0 },new double[] {5.0 , 3.4 },new double[] {5.2, 3.5 },
27+
new double[] {5.2, 3.4 },new double[] {4.7, 3.2 },new double[] {4.8, 3.1 },
28+
new double[] {5.4, 3.4 },new double[] {5.2, 4.1},new double[] {5.5, 4.2 },
29+
new double[] {4.9, 3.1 },new double[] {5.0 , 3.2 },new double[] {5.5, 3.5 },
30+
new double[] {4.9, 3.6 },new double[] {4.4, 3.0 },new double[] {5.1, 3.4 },
31+
new double[] {5.0 , 3.5 },new double[] {4.5, 2.3 },new double[] {4.4, 3.2 },
32+
new double[] {5.0 , 3.5 },new double[] {5.1, 3.8 },new double[] {4.8, 3.0},
33+
new double[] {5.1, 3.8 },new double[] {4.6, 3.2 },new double[] { 5.3, 3.7 },
34+
new double[] {5.0 , 3.3 },new double[] {7.0 , 3.2 },new double[] {6.4, 3.2 },
35+
new double[] {6.9, 3.1 },new double[] {5.5, 2.3 },new double[] {6.5, 2.8 },
36+
new double[] {5.7, 2.8 },new double[] {6.3, 3.3 },new double[] {4.9, 2.4 },
37+
new double[] {6.6, 2.9 },new double[] {5.2, 2.7 },new double[] {5.0 , 2.0 },
38+
new double[] {5.9, 3.0 },new double[] {6.0 , 2.2 },new double[] {6.1, 2.9 },
39+
new double[] {5.6, 2.9 },new double[] {6.7, 3.1 },new double[] {5.6, 3.0 },
40+
new double[] {5.8, 2.7 },new double[] {6.2, 2.2 },new double[] {5.6, 2.5 },
41+
new double[] {5.9, 3.0},new double[] {6.1, 2.8},new double[] {6.3, 2.5},
42+
new double[] {6.1, 2.8},new double[] {6.4, 2.9},new double[] {6.6, 3.0 },
43+
new double[] {6.8, 2.8},new double[] {6.7, 3.0 },new double[] {6.0 , 2.9},
44+
new double[] {5.7, 2.6},new double[] {5.5, 2.4},new double[] {5.5, 2.4},
45+
new double[] {5.8, 2.7},new double[] {6.0 , 2.7},new double[] {5.4, 3.0 },
46+
new double[] {6.0 , 3.4},new double[] {6.7, 3.1},new double[] {6.3, 2.3},
47+
new double[] {5.6, 3.0 },new double[] {5.5, 2.5},new double[] {5.5, 2.6},
48+
new double[] {6.1, 3.0 },new double[] {5.8, 2.6},new double[] {5.0 , 2.3},
49+
new double[] {5.6, 2.7},new double[] {5.7, 3.0 },new double[] {5.7, 2.9},
50+
new double[] {6.2, 2.9},new double[] {5.1, 2.5},new double[] {5.7, 2.8},
51+
new double[] {6.3, 3.3},new double[] {5.8, 2.7},new double[] {7.1, 3.0 },
52+
new double[] {6.3, 2.9},new double[] {6.5, 3.0 },new double[] {7.6, 3.0 },
53+
new double[] {4.9, 2.5},new double[] {7.3, 2.9},new double[] {6.7, 2.5},
54+
new double[] {7.2, 3.6},new double[] {6.5, 3.2},new double[] {6.4, 2.7},
55+
new double[] {6.8, 3.00 },new double[] {5.7, 2.5},new double[] {5.8, 2.8},
56+
new double[] {6.4, 3.2},new double[] {6.5, 3.0 },new double[] {7.7, 3.8},
57+
new double[] {7.7, 2.6},new double[] {6.0 , 2.2},new double[] {6.9, 3.2},
58+
new double[] {5.6, 2.8},new double[] {7.7, 2.8},new double[] {6.3, 2.7},
59+
new double[] {6.7, 3.3},new double[] {7.2, 3.2},new double[] {6.2, 2.8},
60+
new double[] {6.1, 3.0 },new double[] {6.4, 2.8},new double[] {7.2, 3.0 },
61+
new double[] {7.4, 2.8},new double[] {7.9, 3.8},new double[] {6.4, 2.8},
62+
new double[] {6.3, 2.8},new double[] {6.1, 2.6},new double[] {7.7, 3.0 },
63+
new double[] {6.3, 3.4},new double[] {6.4, 3.1},new double[] {6.0, 3.0},
64+
new double[] {6.9, 3.1},new double[] {6.7, 3.1},new double[] {6.9, 3.1},
65+
new double[] {5.8, 2.7},new double[] {6.8, 3.2},new double[] {6.7, 3.3},
66+
new double[] {6.7, 3.0 },new double[] {6.3, 2.5},new double[] {6.5, 3.0 },
67+
new double[] {6.2, 3.4},new double[] {5.9, 3.0 }, new double[] {5.8, 3.0 }});
68+
69+
var y = np.array<int>(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
70+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71+
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
72+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
73+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
74+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
75+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
2176
fit(X, y);
2277
// Create a regular grid and classify each point
78+
double x_min = (double) X.amin(0)[0] - 0.5;
79+
double y_min = (double) X.amin(0)[1] - 0.5;
80+
double x_max = (double) X.amax(0)[0] + 0.5;
81+
double y_max = (double) X.amax(0)[1] + 0.5;
82+
83+
var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30));
84+
var s = tf.Session();
85+
var samples = np.vstack(xx.ravel(), yy.ravel());
86+
var Z = s.run(predict(samples));
87+
2388
}
2489

2590
public void fit(NDArray X, NDArray y)
2691
{
2792
NDArray unique_y = y.unique<long>();
2893

29-
Dictionary<long, List<List<float>>> dic = new Dictionary<long, List<List<float>>>();
94+
Dictionary<long, List<List<double>>> dic = new Dictionary<long, List<List<double>>>();
3095
// Init uy in dic
3196
foreach (int uy in unique_y.Data<int>())
3297
{
33-
dic.Add(uy, new List<List<float>>());
98+
dic.Add(uy, new List<List<double>>());
3499
}
35100
// Separate training points by class
36101
// Shape : nb_classes * nb_samples * nb_features
37102
int maxCount = 0;
38103
for (int i = 0; i < y.size; i++)
39104
{
40105
long curClass = (long)y[i];
41-
List<List<float>> l = dic[curClass];
42-
List<float> pair = new List<float>();
43-
pair.Add((float)X[i,0]);
44-
pair.Add((float)X[i, 1]);
106+
List<List<double>> l = dic[curClass];
107+
List<double> pair = new List<double>();
108+
pair.Add((double)X[i,0]);
109+
pair.Add((double)X[i, 1]);
45110
l.Add(pair);
46111
if (l.Count > maxCount)
47112
{
48113
maxCount = l.Count;
49114
}
50115
dic[curClass] = l;
51116
}
52-
float[,,] points = new float[dic.Count, maxCount, X.shape[1]];
53-
foreach (KeyValuePair<long, List<List<float>>> kv in dic)
117+
double[,,] points = new double[dic.Count, maxCount, X.shape[1]];
118+
foreach (KeyValuePair<long, List<List<double>>> kv in dic)
54119
{
55120
int j = (int) kv.Key;
56121
for (int i = 0; i < maxCount; i++)
@@ -62,7 +127,7 @@ public void fit(NDArray X, NDArray y)
62127
}
63128

64129
}
65-
NDArray points_by_class = np.array<float>(points);
130+
NDArray points_by_class = np.array<double>(points);
66131
// estimate mean and variance for each class / feature
67132
// shape : nb_classes * nb_features
68133
var cons = tf.constant(points_by_class);
@@ -87,7 +152,10 @@ public Tensor predict (NDArray X)
87152

88153
// Conditional probabilities log P(x|c) with shape
89154
// (nb_samples, nb_classes)
90-
Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features }));
155+
var t1= ops.convert_to_tensor(X, TF_DataType.TF_DOUBLE);
156+
//var t2 = ops.convert_to_tensor(new int[] { 1, nb_classes });
157+
//Tensor tile = tf.tile(t1, t2);
158+
Tensor tile = tf.tile(X, new int[] { 1, nb_classes });
91159
Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features }));
92160
var cond_probs = tf.reduce_sum(dist.log_prob(r));
93161
// uniform priors

0 commit comments

Comments
 (0)