@@ -73,7 +73,9 @@ public class DigitRecognitionCNN : IExample
7373 float accuracy_test = 0f ;
7474 float loss_test = 1f ;
7575
76- NDArray x_train ;
76+ NDArray x_train , y_train ;
77+ NDArray x_valid , y_valid ;
78+ NDArray x_test , y_test ;
7779
7880 public bool Run ( )
7981 {
@@ -135,6 +137,62 @@ public Graph BuildGraph()
135137 return graph ;
136138 }
137139
140+ public void Train ( Session sess )
141+ {
142+ // Number of training iterations in each epoch
143+ var num_tr_iter = y_train . len / batch_size ;
144+
145+ var init = tf . global_variables_initializer ( ) ;
146+ sess . run ( init ) ;
147+
148+ float loss_val = 100.0f ;
149+ float accuracy_val = 0f ;
150+
151+ foreach ( var epoch in range ( epochs ) )
152+ {
153+ print ( $ "Training epoch: { epoch + 1 } ") ;
154+ // Randomly shuffle the training data at the beginning of each epoch
155+ ( x_train , y_train ) = mnist . Randomize ( x_train , y_train ) ;
156+
157+ foreach ( var iteration in range ( num_tr_iter ) )
158+ {
159+ var start = iteration * batch_size ;
160+ var end = ( iteration + 1 ) * batch_size ;
161+ var ( x_batch , y_batch ) = mnist . GetNextBatch ( x_train , y_train , start , end ) ;
162+
163+ // Run optimization op (backprop)
164+ sess . run ( optimizer , new FeedItem ( x , x_batch ) , new FeedItem ( y , y_batch ) ) ;
165+
166+ if ( iteration % display_freq == 0 )
167+ {
168+ // Calculate and display the batch loss and accuracy
169+ var result = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , x_batch ) , new FeedItem ( y , y_batch ) ) ;
170+ loss_val = result [ 0 ] ;
171+ accuracy_val = result [ 1 ] ;
172+ print ( $ "iter { iteration . ToString ( "000" ) } : Loss={ loss_val . ToString ( "0.0000" ) } , Training Accuracy={ accuracy_val . ToString ( "P" ) } ") ;
173+ }
174+ }
175+
176+ // Run validation after every epoch
177+ var results1 = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , x_valid ) , new FeedItem ( y , y_valid ) ) ;
178+ loss_val = results1 [ 0 ] ;
179+ accuracy_val = results1 [ 1 ] ;
180+ print ( "---------------------------------------------------------" ) ;
181+ print ( $ "Epoch: { epoch + 1 } , validation loss: { loss_val . ToString ( "0.0000" ) } , validation accuracy: { accuracy_val . ToString ( "P" ) } ") ;
182+ print ( "---------------------------------------------------------" ) ;
183+ }
184+ }
185+
186+ public void Test ( Session sess )
187+ {
188+ var result = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , x_test ) , new FeedItem ( y , y_test ) ) ;
189+ loss_test = result [ 0 ] ;
190+ accuracy_test = result [ 1 ] ;
191+ print ( "---------------------------------------------------------" ) ;
192+ print ( $ "Test loss: { loss_test . ToString ( "0.0000" ) } , test accuracy: { accuracy_test . ToString ( "P" ) } ") ;
193+ print ( "---------------------------------------------------------" ) ;
194+ }
195+
138196 /// <summary>
139197 /// Create a 2D convolution layer
140198 /// </summary>
@@ -219,6 +277,14 @@ private RefVariable bias_variable(string name, int[] shape)
219277 initializer : initial ) ;
220278 }
221279
280+ /// <summary>
281+ /// Create a fully-connected layer
282+ /// </summary>
283+ /// <param name="x">input from previous layer</param>
284+ /// <param name="num_units">number of hidden units in the fully-connected layer</param>
285+ /// <param name="name">layer name</param>
286+ /// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
287+ /// <returns>The output array</returns>
222288 private Tensor fc_layer ( Tensor x , int num_units , string name , bool use_relu = true )
223289 {
224290 return with ( tf . variable_scope ( name ) , delegate
@@ -235,81 +301,36 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
235301 return layer ;
236302 } ) ;
237303 }
238-
239- public Graph ImportGraph ( ) => throw new NotImplementedException ( ) ;
240-
241- public void Predict ( Session sess ) => throw new NotImplementedException ( ) ;
242304
243305 public void PrepareData ( )
244306 {
245307 mnist = MNIST . read_data_sets ( "mnist" , one_hot : true ) ;
246- x_train = Reformat ( mnist . train . data , mnist . train . labels ) ;
308+ ( x_train , y_train ) = Reformat ( mnist . train . data , mnist . train . labels ) ;
309+ ( x_valid , y_valid ) = Reformat ( mnist . validation . data , mnist . validation . labels ) ;
310+ ( x_test , y_test ) = Reformat ( mnist . test . data , mnist . test . labels ) ;
311+
247312 print ( "Size of:" ) ;
248313 print ( $ "- Training-set:\t \t { len ( mnist . train . data ) } ") ;
249314 print ( $ "- Validation-set:\t { len ( mnist . validation . data ) } ") ;
250315 }
251316
252- private NDArray Reformat ( NDArray x , NDArray y )
317+ /// <summary>
318+ /// Reformats the data to the format acceptable for convolutional layers
319+ /// </summary>
320+ /// <param name="x"></param>
321+ /// <param name="y"></param>
322+ /// <returns></returns>
323+ private ( NDArray , NDArray ) Reformat ( NDArray x , NDArray y )
253324 {
254- var ( img_size , num_ch , num_class ) = ( np . sqrt ( x . shape [ 1 ] ) , 1 , np . unique < int > ( np . argmax ( y , 1 ) ) ) ;
255-
256- return x ;
325+ var ( img_size , num_ch , num_class ) = ( np . sqrt ( x . shape [ 1 ] ) , 1 , len ( np . unique < int > ( np . argmax ( y , 1 ) ) ) ) ;
326+ var dataset = x . reshape ( x . shape [ 0 ] , img_size , img_size , num_ch ) . astype ( np . float32 ) ;
327+ //y[0] = np.arange(num_class) == y[0];
328+ //var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32);
329+ return ( dataset , y ) ;
257330 }
258331
259- public void Train ( Session sess )
260- {
261- // Number of training iterations in each epoch
262- var num_tr_iter = mnist . train . labels . len / batch_size ;
263-
264- var init = tf . global_variables_initializer ( ) ;
265- sess . run ( init ) ;
266-
267- float loss_val = 100.0f ;
268- float accuracy_val = 0f ;
269-
270- foreach ( var epoch in range ( epochs ) )
271- {
272- print ( $ "Training epoch: { epoch + 1 } ") ;
273- // Randomly shuffle the training data at the beginning of each epoch
274- var ( x_train , y_train ) = mnist . Randomize ( mnist . train . data , mnist . train . labels ) ;
275-
276- foreach ( var iteration in range ( num_tr_iter ) )
277- {
278- var start = iteration * batch_size ;
279- var end = ( iteration + 1 ) * batch_size ;
280- var ( x_batch , y_batch ) = mnist . GetNextBatch ( x_train , y_train , start , end ) ;
281-
282- // Run optimization op (backprop)
283- sess . run ( optimizer , new FeedItem ( x , x_batch ) , new FeedItem ( y , y_batch ) ) ;
284-
285- if ( iteration % display_freq == 0 )
286- {
287- // Calculate and display the batch loss and accuracy
288- var result = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , x_batch ) , new FeedItem ( y , y_batch ) ) ;
289- loss_val = result [ 0 ] ;
290- accuracy_val = result [ 1 ] ;
291- print ( $ "iter { iteration . ToString ( "000" ) } : Loss={ loss_val . ToString ( "0.0000" ) } , Training Accuracy={ accuracy_val . ToString ( "P" ) } ") ;
292- }
293- }
294-
295- // Run validation after every epoch
296- var results1 = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , mnist . validation . data ) , new FeedItem ( y , mnist . validation . labels ) ) ;
297- loss_val = results1 [ 0 ] ;
298- accuracy_val = results1 [ 1 ] ;
299- print ( "---------------------------------------------------------" ) ;
300- print ( $ "Epoch: { epoch + 1 } , validation loss: { loss_val . ToString ( "0.0000" ) } , validation accuracy: { accuracy_val . ToString ( "P" ) } ") ;
301- print ( "---------------------------------------------------------" ) ;
302- }
303- }
332+ public Graph ImportGraph ( ) => throw new NotImplementedException ( ) ;
304333
305- public void Test ( Session sess )
306- {
307- var result = sess . run ( new [ ] { loss , accuracy } , new FeedItem ( x , mnist . test . data ) , new FeedItem ( y , mnist . test . labels ) ) ;
308- loss_test = result [ 0 ] ;
309- accuracy_test = result [ 1 ] ;
310- print ( "---------------------------------------------------------" ) ;
311- print ( $ "Test loss: { loss_test . ToString ( "0.0000" ) } , test accuracy: { accuracy_test . ToString ( "P" ) } ") ;
312- print ( "---------------------------------------------------------" ) ;
313- }
334+ public void Predict ( Session sess ) => throw new NotImplementedException ( ) ;
314335 }
315336}
0 commit comments