@@ -28,7 +28,7 @@ public class KMeansClustering : Python, IExample
2828
2929 Datasets mnist ;
3030 NDArray full_data_x ;
31- int num_steps = 10 ; // Total steps to train
31+ int num_steps = 20 ; // Total steps to train
3232 int k = 25 ; // The number of clusters
3333 int num_classes = 10 ; // The 10 digits
3434 int num_features = 784 ; // Each image is 28x28 pixels
@@ -42,9 +42,9 @@ public bool Run()
4242 tf . train . import_meta_graph ( "graph/kmeans.meta" ) ;
4343
4444 // Input images
45- var X = graph . get_operation_by_name ( "Placeholder" ) . output ; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
45+ Tensor X = graph . get_operation_by_name ( "Placeholder" ) ; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features));
4646 // Labels (for assigning a label to a centroid and testing)
47- var Y = graph . get_operation_by_name ( "Placeholder_1" ) . output ; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
47+ Tensor Y = graph . get_operation_by_name ( "Placeholder_1" ) ; // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes));
4848
4949 // K-Means Parameters
5050 //var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true);
@@ -57,26 +57,24 @@ public bool Run()
5757 var train_op = graph . get_operation_by_name ( "group_deps" ) ;
5858 Tensor avg_distance = graph . get_operation_by_name ( "Mean" ) ;
5959 Tensor cluster_idx = graph . get_operation_by_name ( "Squeeze_1" ) ;
60+ NDArray result = null ;
6061
6162 with ( tf . Session ( graph ) , sess =>
6263 {
6364 sess . run ( init_vars , new FeedItem ( X , full_data_x ) ) ;
6465 sess . run ( init_op , new FeedItem ( X , full_data_x ) ) ;
6566
6667 // Training
67- NDArray result = null ;
6868 var sw = new Stopwatch ( ) ;
6969
7070 foreach ( var i in range ( 1 , num_steps + 1 ) )
7171 {
72- sw . Start ( ) ;
72+ sw . Restart ( ) ;
7373 result = sess . run ( new ITensorOrOperation [ ] { train_op , avg_distance , cluster_idx } , new FeedItem ( X , full_data_x ) ) ;
7474 sw . Stop ( ) ;
7575
76- if ( i % 5 == 0 || i == 1 )
76+ if ( i % 4 == 0 || i == 1 )
7777 print ( $ "Step { i } , Avg Distance: { result [ 1 ] } Elapse: { sw . ElapsedMilliseconds } ms") ;
78-
79- sw . Reset ( ) ;
8078 }
8179
8280 var idx = result [ 2 ] . Data < int > ( ) ;
@@ -102,9 +100,20 @@ public bool Run()
102100
103101 // Evaluation ops
104102 // Lookup: centroid_id -> label
103+ var cluster_label = tf . nn . embedding_lookup ( labels_map , cluster_idx ) ;
104+
105+ // Compute accuracy
106+ var correct_prediction = tf . equal ( cluster_label , tf . cast ( tf . argmax ( Y , 1 ) , tf . int32 ) ) ;
107+ var cast = tf . cast ( correct_prediction , tf . float32 ) ;
108+ var accuracy_op = tf . reduce_mean ( cast ) ;
109+
110+ // Test Model
111+ var ( test_x , test_y ) = ( mnist . test . images , mnist . test . labels ) ;
112+ result = sess . run ( accuracy_op , new FeedItem ( X , test_x ) , new FeedItem ( Y , test_y ) ) ;
113+ print ( $ "Test Accuracy: { result } ") ;
105114 } ) ;
106115
107- return false ;
116+ return ( float ) result > 0.70 ;
108117 }
109118
110119 public void PrepareData ( )
0 commit comments