@@ -31,11 +31,19 @@ public class RetrainImageClassifier : IExample
3131 string summaries_dir = Path . Join ( data_dir , "retrain_logs" ) ;
3232 string image_dir = Path . Join ( data_dir , "flower_photos" ) ;
3333 string bottleneck_dir = Path . Join ( data_dir , "bottleneck" ) ;
34+ // The location where variable checkpoints will be stored.
35+ string CHECKPOINT_NAME = Path . Join ( data_dir , "_retrain_checkpoint" ) ;
3436 string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3" ;
3537 float testing_percentage = 0.1f ;
3638 float validation_percentage = 0.1f ;
3739 Tensor resized_image_tensor ;
3840 Dictionary < string , Dictionary < string , string [ ] > > image_lists ;
41+ int how_many_training_steps = 200 ;
42+ int eval_step_interval = 10 ;
43+ int train_batch_size = 100 ;
44+ int validation_batch_size = 100 ;
45+ int intermediate_store_frequency = 0 ;
46+ const int MAX_NUM_IMAGES_PER_CLASS = 134217727 ;
3947
4048 public bool Run ( )
4149 {
@@ -47,6 +55,9 @@ public bool Run()
4755 Tensor resized_image_tensor = graph . OperationByName ( "Placeholder" ) ;
4856 Tensor final_tensor = graph . OperationByName ( "final_result" ) ;
4957 Tensor ground_truth_input = graph . OperationByName ( "input/GroundTruthInput" ) ;
58+ Operation train_step = graph . OperationByName ( "train/GradientDescent" ) ;
59+ Tensor bottleneck_input = graph . OperationByName ( "input/BottleneckInputPlaceholder" ) ;
60+ Tensor cross_entropy = graph . OperationByName ( "cross_entropy/sparse_softmax_cross_entropy_loss/value" ) ;
5061
5162 var sw = new Stopwatch ( ) ;
5263
@@ -72,11 +83,104 @@ public bool Run()
7283 // Merge all the summaries and write them out to the summaries_dir
7384 var merged = tf . summary . merge_all ( ) ;
7485 var train_writer = tf . summary . FileWriter ( summaries_dir + "/train" , sess . graph ) ;
86+ var validation_writer = tf . summary . FileWriter ( summaries_dir + "/validation" , sess . graph ) ;
87+
88+ // Create a train saver that is used to restore values into an eval graph
89+ // when exporting models.
90+ var train_saver = tf . train . Saver ( ) ;
91+
92+ for ( int i = 0 ; i < how_many_training_steps ; i ++ )
93+ {
94+ var ( train_bottlenecks , train_ground_truth , _) = get_random_cached_bottlenecks (
95+ sess , image_lists , train_batch_size , "training" ,
96+ bottleneck_dir , image_dir , jpeg_data_tensor ,
97+ decoded_image_tensor , resized_image_tensor , bottleneck_tensor ,
98+ tfhub_module ) ;
99+
100+ // Feed the bottlenecks and ground truth into the graph, and run a training
101+ // step. Capture training summaries for TensorBoard with the `merged` op.
102+ var results = sess . run (
103+ new ITensorOrOperation [ ] { merged , train_step } ,
104+ new FeedItem ( bottleneck_input , train_bottlenecks ) ,
105+ new FeedItem ( ground_truth_input , train_ground_truth ) ) ;
106+ var train_summary = results [ 0 ] ;
107+
108+ // TODO
109+ train_writer . add_summary ( train_summary , i ) ;
110+
111+ // Every so often, print out how well the graph is training.
112+ bool is_last_step = ( i + 1 == how_many_training_steps ) ;
113+ if ( ( i % eval_step_interval ) == 0 || is_last_step )
114+ {
115+ results = sess . run (
116+ new Tensor [ ] { evaluation_step , cross_entropy } ,
117+ new FeedItem ( bottleneck_input , train_bottlenecks ) ,
118+ new FeedItem ( ground_truth_input , train_ground_truth ) ) ;
119+ ( float train_accuracy , float cross_entropy_value ) = ( results [ 0 ] , results [ 1 ] ) ;
120+ print ( $ "{ DateTime . Now } : Step { i } : Train accuracy = { train_accuracy * 100 } %") ;
121+ print ( $ "{ DateTime . Now } : Step { i } : Cross entropy = { cross_entropy_value } ") ;
122+
123+ var ( validation_bottlenecks , validation_ground_truth , _) = get_random_cached_bottlenecks (
124+ sess , image_lists , validation_batch_size , "validation" ,
125+ bottleneck_dir , image_dir , jpeg_data_tensor ,
126+ decoded_image_tensor , resized_image_tensor , bottleneck_tensor ,
127+ tfhub_module ) ;
128+
129+ // Run a validation step and capture training summaries for TensorBoard
130+ // with the `merged` op.
131+ results = sess . run ( new Tensor [ ] { merged , evaluation_step } ,
132+ new FeedItem ( bottleneck_input , validation_bottlenecks ) ,
133+ new FeedItem ( ground_truth_input , validation_ground_truth ) ) ;
134+
135+ ( string validation_summary , float validation_accuracy ) = ( results [ 0 ] , results [ 1 ] ) ;
136+
137+ validation_writer . add_summary ( validation_summary , i ) ;
138+ print ( $ "{ DateTime . Now } : Step { i } : Validation accuracy = { validation_accuracy * 100 } % (N={ len ( validation_bottlenecks ) } )") ;
139+ }
140+
141+ // Store intermediate results
142+ int intermediate_frequency = intermediate_store_frequency ;
143+ if ( intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0 )
144+ {
145+
146+ }
147+ }
148+
149+ // After training is complete, force one last save of the train checkpoint.
150+ train_saver . save ( sess , CHECKPOINT_NAME ) ;
75151 } ) ;
76152
77153 return false ;
78154 }
79155
156+ private ( NDArray , long [ ] , string [ ] ) get_random_cached_bottlenecks ( Session sess , Dictionary < string , Dictionary < string , string [ ] > > image_lists ,
157+ int how_many , string category , string bottleneck_dir , string image_dir ,
158+ Tensor jpeg_data_tensor , Tensor decoded_image_tensor , Tensor resized_input_tensor ,
159+ Tensor bottleneck_tensor , string module_name )
160+ {
161+ var bottlenecks = new List < float [ ] > ( ) ;
162+ var ground_truths = new List < long > ( ) ;
163+ var filenames = new List < string > ( ) ;
164+ int class_count = image_lists . Keys . Count ;
165+ foreach ( var unused_i in range ( how_many ) )
166+ {
167+ int label_index = new Random ( ) . Next ( class_count ) ;
168+ string label_name = image_lists . Keys . ToArray ( ) [ label_index ] ;
169+ int image_index = new Random ( ) . Next ( MAX_NUM_IMAGES_PER_CLASS ) ;
170+ string image_name = get_image_path ( image_lists , label_name , image_index ,
171+ image_dir , category ) ;
172+ var bottleneck = get_or_create_bottleneck (
173+ sess , image_lists , label_name , image_index , image_dir , category ,
174+ bottleneck_dir , jpeg_data_tensor , decoded_image_tensor ,
175+ resized_input_tensor , bottleneck_tensor , module_name ) ;
176+ bottlenecks . Add ( bottleneck ) ;
177+ ground_truths . Add ( label_index ) ;
178+ filenames . Add ( image_name ) ;
179+ }
180+
181+ return ( bottlenecks . ToArray ( ) , ground_truths . ToArray ( ) , filenames . ToArray ( ) ) ;
182+ }
183+
80184 /// <summary>
81185 /// Inserts the operations we need to evaluate the accuracy of our results.
82186 /// </summary>
0 commit comments