11using NumSharp ;
22using System ;
33using System . Collections . Generic ;
4+ using System . Diagnostics ;
5+ using System . Linq ;
46using System . Text ;
57using Tensorflow ;
68using Tensorflow . Clustering ;
@@ -26,7 +28,7 @@ public class KMeansClustering : Python, IExample
2628
2729 Datasets mnist ;
2830 NDArray full_data_x ;
29- int num_steps = 50 ; // Total steps to train
31+ int num_steps = 10 ; // Total steps to train
3032 int k = 25 ; // The number of clusters
3133 int num_classes = 10 ; // The 10 digits
3234 int num_features = 784 ; // Each image is 28x28 pixels
@@ -63,22 +65,43 @@ public bool Run()
6365
6466 // Training
6567 NDArray result = null ;
66- foreach ( var i in range ( 1 , num_steps + 1 ) )
68+ var sw = new Stopwatch ( ) ;
69+
70+ foreach ( var i in range ( 1 , num_steps + 1 ) )
6771 {
72+ sw . Start ( ) ;
6873 result = sess . run ( new ITensorOrOperation [ ] { train_op , avg_distance , cluster_idx } , new FeedItem ( X , full_data_x ) ) ;
69- if ( i % 2 == 0 || i == 1 )
70- print ( $ "Step { i } , Avg Distance: { result [ 1 ] } ") ;
74+ sw . Stop ( ) ;
75+
76+ if ( i % 5 == 0 || i == 1 )
77+ print ( $ "Step { i } , Avg Distance: { result [ 1 ] } Elapse: { sw . ElapsedMilliseconds } ms") ;
78+
79+ sw . Reset ( ) ;
7180 }
7281
73- var idx = result [ 2 ] ;
82+ var idx = result [ 2 ] . Data < int > ( ) ;
7483
7584 // Assign a label to each centroid
7685 // Count total number of labels per centroid, using the label of each training
7786 // sample to their closest centroid (given by 'idx')
78- var counts = np . zeros ( k , num_classes ) ;
79- foreach ( var i in range ( idx . len ) )
80- counts [ idx [ i ] ] += mnist . train . labels [ i ] ;
87+ var counts = np . zeros ( ( k , num_classes ) , np . float32 ) ;
88+
89+ sw . Start ( ) ;
90+ foreach ( var i in range ( idx . Length ) )
91+ {
92+ var x = mnist . train . labels [ i ] ;
93+ counts [ idx [ i ] ] += x ;
94+ }
95+
96+ sw . Stop ( ) ;
97+ print ( $ "Assign a label to each centroid took { sw . ElapsedMilliseconds } ms") ;
98+
99+ // Assign the most frequent label to the centroid
100+ var labels_map_array = np . argmax ( counts , 1 ) ;
101+ var labels_map = tf . convert_to_tensor ( labels_map_array ) ;
81102
103+ // Evaluation ops
104+ // Lookup: centroid_id -> label
82105 } ) ;
83106
84107 return false ;
0 commit comments