@@ -15,42 +15,107 @@ public class NaiveBayesClassifier : Python, IExample
1515 public Normal dist { get ; set ; }
1616 public void Run ( )
1717 {
18- np . array < float > ( 1.0f , 1.0f ) ;
19- var X = np . array < float > ( new float [ ] [ ] { new float [ ] { 1.0f , 1.0f } , new float [ ] { 2.0f , 2.0f } , new float [ ] { - 1.0f , - 1.0f } , new float [ ] { - 2.0f , - 2.0f } , new float [ ] { 1.0f , - 1.0f } , new float [ ] { 2.0f , - 2.0f } , } ) ;
20- var y = np . array < int > ( 0 , 0 , 1 , 1 , 2 , 2 ) ;
18+ var X = np . array < double > ( new double [ ] [ ] { new double [ ] { 5.1 , 3.5 } , new double [ ] { 4.9 , 3.0 } , new double [ ] { 4.7 , 3.2 } ,
19+ new double [ ] { 4.6 , 3.1 } , new double [ ] { 5.0 , 3.6 } , new double [ ] { 5.4 , 3.9 } ,
20+ new double [ ] { 4.6 , 3.4 } , new double [ ] { 5.0 , 3.4 } , new double [ ] { 4.4 , 2.9 } ,
21+ new double [ ] { 4.9 , 3.1 } , new double [ ] { 5.4 , 3.7 } , new double [ ] { 4.8 , 3.4 } ,
22+ new double [ ] { 4.8 , 3.0 } , new double [ ] { 4.3 , 3.0 } , new double [ ] { 5.8 , 4.0 } ,
23+ new double [ ] { 5.7 , 4.4 } , new double [ ] { 5.4 , 3.9 } , new double [ ] { 5.1 , 3.5 } ,
24+ new double [ ] { 5.7 , 3.8 } , new double [ ] { 5.1 , 3.8 } , new double [ ] { 5.4 , 3.4 } ,
25+ new double [ ] { 5.1 , 3.7 } , new double [ ] { 5.1 , 3.3 } , new double [ ] { 4.8 , 3.4 } ,
26+ new double [ ] { 5.0 , 3.0 } , new double [ ] { 5.0 , 3.4 } , new double [ ] { 5.2 , 3.5 } ,
27+ new double [ ] { 5.2 , 3.4 } , new double [ ] { 4.7 , 3.2 } , new double [ ] { 4.8 , 3.1 } ,
28+ new double [ ] { 5.4 , 3.4 } , new double [ ] { 5.2 , 4.1 } , new double [ ] { 5.5 , 4.2 } ,
29+ new double [ ] { 4.9 , 3.1 } , new double [ ] { 5.0 , 3.2 } , new double [ ] { 5.5 , 3.5 } ,
30+ new double [ ] { 4.9 , 3.6 } , new double [ ] { 4.4 , 3.0 } , new double [ ] { 5.1 , 3.4 } ,
31+ new double [ ] { 5.0 , 3.5 } , new double [ ] { 4.5 , 2.3 } , new double [ ] { 4.4 , 3.2 } ,
32+ new double [ ] { 5.0 , 3.5 } , new double [ ] { 5.1 , 3.8 } , new double [ ] { 4.8 , 3.0 } ,
33+ new double [ ] { 5.1 , 3.8 } , new double [ ] { 4.6 , 3.2 } , new double [ ] { 5.3 , 3.7 } ,
34+ new double [ ] { 5.0 , 3.3 } , new double [ ] { 7.0 , 3.2 } , new double [ ] { 6.4 , 3.2 } ,
35+ new double [ ] { 6.9 , 3.1 } , new double [ ] { 5.5 , 2.3 } , new double [ ] { 6.5 , 2.8 } ,
36+ new double [ ] { 5.7 , 2.8 } , new double [ ] { 6.3 , 3.3 } , new double [ ] { 4.9 , 2.4 } ,
37+ new double [ ] { 6.6 , 2.9 } , new double [ ] { 5.2 , 2.7 } , new double [ ] { 5.0 , 2.0 } ,
38+ new double [ ] { 5.9 , 3.0 } , new double [ ] { 6.0 , 2.2 } , new double [ ] { 6.1 , 2.9 } ,
39+ new double [ ] { 5.6 , 2.9 } , new double [ ] { 6.7 , 3.1 } , new double [ ] { 5.6 , 3.0 } ,
40+ new double [ ] { 5.8 , 2.7 } , new double [ ] { 6.2 , 2.2 } , new double [ ] { 5.6 , 2.5 } ,
41+ new double [ ] { 5.9 , 3.0 } , new double [ ] { 6.1 , 2.8 } , new double [ ] { 6.3 , 2.5 } ,
42+ new double [ ] { 6.1 , 2.8 } , new double [ ] { 6.4 , 2.9 } , new double [ ] { 6.6 , 3.0 } ,
43+ new double [ ] { 6.8 , 2.8 } , new double [ ] { 6.7 , 3.0 } , new double [ ] { 6.0 , 2.9 } ,
44+ new double [ ] { 5.7 , 2.6 } , new double [ ] { 5.5 , 2.4 } , new double [ ] { 5.5 , 2.4 } ,
45+ new double [ ] { 5.8 , 2.7 } , new double [ ] { 6.0 , 2.7 } , new double [ ] { 5.4 , 3.0 } ,
46+ new double [ ] { 6.0 , 3.4 } , new double [ ] { 6.7 , 3.1 } , new double [ ] { 6.3 , 2.3 } ,
47+ new double [ ] { 5.6 , 3.0 } , new double [ ] { 5.5 , 2.5 } , new double [ ] { 5.5 , 2.6 } ,
48+ new double [ ] { 6.1 , 3.0 } , new double [ ] { 5.8 , 2.6 } , new double [ ] { 5.0 , 2.3 } ,
49+ new double [ ] { 5.6 , 2.7 } , new double [ ] { 5.7 , 3.0 } , new double [ ] { 5.7 , 2.9 } ,
50+ new double [ ] { 6.2 , 2.9 } , new double [ ] { 5.1 , 2.5 } , new double [ ] { 5.7 , 2.8 } ,
51+ new double [ ] { 6.3 , 3.3 } , new double [ ] { 5.8 , 2.7 } , new double [ ] { 7.1 , 3.0 } ,
52+ new double [ ] { 6.3 , 2.9 } , new double [ ] { 6.5 , 3.0 } , new double [ ] { 7.6 , 3.0 } ,
53+ new double [ ] { 4.9 , 2.5 } , new double [ ] { 7.3 , 2.9 } , new double [ ] { 6.7 , 2.5 } ,
54+ new double [ ] { 7.2 , 3.6 } , new double [ ] { 6.5 , 3.2 } , new double [ ] { 6.4 , 2.7 } ,
55+ new double [ ] { 6.8 , 3.00 } , new double [ ] { 5.7 , 2.5 } , new double [ ] { 5.8 , 2.8 } ,
56+ new double [ ] { 6.4 , 3.2 } , new double [ ] { 6.5 , 3.0 } , new double [ ] { 7.7 , 3.8 } ,
57+ new double [ ] { 7.7 , 2.6 } , new double [ ] { 6.0 , 2.2 } , new double [ ] { 6.9 , 3.2 } ,
58+ new double [ ] { 5.6 , 2.8 } , new double [ ] { 7.7 , 2.8 } , new double [ ] { 6.3 , 2.7 } ,
59+ new double [ ] { 6.7 , 3.3 } , new double [ ] { 7.2 , 3.2 } , new double [ ] { 6.2 , 2.8 } ,
60+ new double [ ] { 6.1 , 3.0 } , new double [ ] { 6.4 , 2.8 } , new double [ ] { 7.2 , 3.0 } ,
61+ new double [ ] { 7.4 , 2.8 } , new double [ ] { 7.9 , 3.8 } , new double [ ] { 6.4 , 2.8 } ,
62+ new double [ ] { 6.3 , 2.8 } , new double [ ] { 6.1 , 2.6 } , new double [ ] { 7.7 , 3.0 } ,
63+ new double [ ] { 6.3 , 3.4 } , new double [ ] { 6.4 , 3.1 } , new double [ ] { 6.0 , 3.0 } ,
64+ new double [ ] { 6.9 , 3.1 } , new double [ ] { 6.7 , 3.1 } , new double [ ] { 6.9 , 3.1 } ,
65+ new double [ ] { 5.8 , 2.7 } , new double [ ] { 6.8 , 3.2 } , new double [ ] { 6.7 , 3.3 } ,
66+ new double [ ] { 6.7 , 3.0 } , new double [ ] { 6.3 , 2.5 } , new double [ ] { 6.5 , 3.0 } ,
67+ new double [ ] { 6.2 , 3.4 } , new double [ ] { 5.9 , 3.0 } , new double [ ] { 5.8 , 3.0 } } ) ;
68+
69+ var y = np . array < int > ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
70+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
71+ 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
72+ 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
73+ 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ,
74+ 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ,
75+ 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ) ;
2176 fit ( X , y ) ;
2277 // Create a regular grid and classify each point
78+ double x_min = ( double ) X . amin ( 0 ) [ 0 ] - 0.5 ;
79+ double y_min = ( double ) X . amin ( 0 ) [ 1 ] - 0.5 ;
80+ double x_max = ( double ) X . amax ( 0 ) [ 0 ] + 0.5 ;
81+ double y_max = ( double ) X . amax ( 0 ) [ 1 ] + 0.5 ;
82+
83+ var ( xx , yy ) = np . meshgrid ( np . linspace ( x_min , x_max , 30 ) , np . linspace ( y_min , y_max , 30 ) ) ;
84+ var s = tf . Session ( ) ;
85+ var samples = np . vstack ( xx . ravel ( ) , yy . ravel ( ) ) ;
86+ var Z = s . run ( predict ( samples ) ) ;
87+
2388 }
2489
2590 public void fit ( NDArray X , NDArray y )
2691 {
2792 NDArray unique_y = y . unique < long > ( ) ;
2893
29- Dictionary < long , List < List < float > > > dic = new Dictionary < long , List < List < float > > > ( ) ;
94+ Dictionary < long , List < List < double > > > dic = new Dictionary < long , List < List < double > > > ( ) ;
3095 // Init uy in dic
3196 foreach ( int uy in unique_y . Data < int > ( ) )
3297 {
33- dic . Add ( uy , new List < List < float > > ( ) ) ;
98+ dic . Add ( uy , new List < List < double > > ( ) ) ;
3499 }
35100 // Separate training points by class
36101 // Shape : nb_classes * nb_samples * nb_features
37102 int maxCount = 0 ;
38103 for ( int i = 0 ; i < y . size ; i ++ )
39104 {
40105 long curClass = ( long ) y [ i ] ;
41- List < List < float > > l = dic [ curClass ] ;
42- List < float > pair = new List < float > ( ) ;
43- pair . Add ( ( float ) X [ i , 0 ] ) ;
44- pair . Add ( ( float ) X [ i , 1 ] ) ;
106+ List < List < double > > l = dic [ curClass ] ;
107+ List < double > pair = new List < double > ( ) ;
108+ pair . Add ( ( double ) X [ i , 0 ] ) ;
109+ pair . Add ( ( double ) X [ i , 1 ] ) ;
45110 l . Add ( pair ) ;
46111 if ( l . Count > maxCount )
47112 {
48113 maxCount = l . Count ;
49114 }
50115 dic [ curClass ] = l ;
51116 }
52- float [ , , ] points = new float [ dic . Count , maxCount , X . shape [ 1 ] ] ;
53- foreach ( KeyValuePair < long , List < List < float > > > kv in dic )
117+ double [ , , ] points = new double [ dic . Count , maxCount , X . shape [ 1 ] ] ;
118+ foreach ( KeyValuePair < long , List < List < double > > > kv in dic )
54119 {
55120 int j = ( int ) kv . Key ;
56121 for ( int i = 0 ; i < maxCount ; i ++ )
@@ -62,7 +127,7 @@ public void fit(NDArray X, NDArray y)
62127 }
63128
64129 }
65- NDArray points_by_class = np . array < float > ( points ) ;
130+ NDArray points_by_class = np . array < double > ( points ) ;
66131 // estimate mean and variance for each class / feature
67132 // shape : nb_classes * nb_features
68133 var cons = tf . constant ( points_by_class ) ;
@@ -87,7 +152,10 @@ public Tensor predict (NDArray X)
87152
88153 // Conditional probabilities log P(x|c) with shape
89154 // (nb_samples, nb_classes)
90- Tensor tile = tf . tile ( new Tensor ( X ) , new Tensor ( new int [ ] { - 1 , nb_classes , nb_features } ) ) ;
155+ var t1 = ops . convert_to_tensor ( X , TF_DataType . TF_DOUBLE ) ;
156+ //var t2 = ops.convert_to_tensor(new int[] { 1, nb_classes });
157+ //Tensor tile = tf.tile(t1, t2);
158+ Tensor tile = tf . tile ( X , new int [ ] { 1 , nb_classes } ) ;
91159 Tensor r = tf . reshape ( tile , new Tensor ( new int [ ] { - 1 , nb_classes , nb_features } ) ) ;
92160 var cond_probs = tf . reduce_sum ( dist . log_prob ( r ) ) ;
93161 // uniform priors
0 commit comments