@@ -52,38 +52,37 @@ public bool Run()
5252
5353 protected virtual bool RunWithImportedGraph ( Session sess , Graph graph )
5454 {
55+ var stopwatch = Stopwatch . StartNew ( ) ;
5556 Console . WriteLine ( "Building dataset..." ) ;
5657 var ( x , y , alphabet_size ) = DataHelpers . build_char_dataset ( "train" , model_name , CHAR_MAX_LEN , DataLimit = null ) ;
57- Console . WriteLine ( "\t DONE" ) ;
58+ Console . WriteLine ( "\t DONE " ) ;
5859
5960 var ( train_x , valid_x , train_y , valid_y ) = train_test_split ( x , y , test_size : 0.15f ) ;
6061
6162 Console . WriteLine ( "Import graph..." ) ;
6263 var meta_file = model_name + ".meta" ;
6364 tf . train . import_meta_graph ( Path . Join ( "graph" , meta_file ) ) ;
64- Console . WriteLine ( "\t DONE" ) ;
65- // definitely necessary, otherwize will get the exception of "use uninitialized variable"
65+ Console . WriteLine ( "\t DONE " + stopwatch . Elapsed ) ;
66+
6667 sess . run ( tf . global_variables_initializer ( ) ) ;
6768
6869 var train_batches = batch_iter ( train_x , train_y , BATCH_SIZE , NUM_EPOCHS ) ;
69- var num_batches_per_epoch = ( len ( train_x ) - 1 ) ; // BATCH_SIZE + 1
70+ var num_batches_per_epoch = ( len ( train_x ) - 1 ) / BATCH_SIZE + 1 ;
7071 double max_accuracy = 0 ;
7172
7273 Tensor is_training = graph . get_operation_by_name ( "is_training" ) ;
7374 Tensor model_x = graph . get_operation_by_name ( "x" ) ;
7475 Tensor model_y = graph . get_operation_by_name ( "y" ) ;
75- Tensor loss = graph . get_operation_by_name ( "loss/loss " ) ;
76+ Tensor loss = graph . get_operation_by_name ( "loss/value " ) ;
7677 //var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
7778 Tensor optimizer = graph . get_operation_by_name ( "loss/optimizer" ) ;
7879 Tensor global_step = graph . get_operation_by_name ( "global_step" ) ;
79- Tensor accuracy = graph . get_operation_by_name ( "accuracy/accuracy " ) ;
80- var stopwatch = Stopwatch . StartNew ( ) ;
80+ Tensor accuracy = graph . get_operation_by_name ( "accuracy/value " ) ;
81+ stopwatch = Stopwatch . StartNew ( ) ;
8182 int i = 0 ;
8283 foreach ( var ( x_batch , y_batch , total ) in train_batches )
8384 {
8485 i ++ ;
85- var estimate = TimeSpan . FromSeconds ( ( stopwatch . Elapsed . TotalSeconds / i ) * total ) ;
86- Console . WriteLine ( $ "Training on batch { i } /{ total } . Estimated training time: { estimate } ") ;
8786 var train_feed_dict = new Hashtable
8887 {
8988 [ model_x ] = x_batch ,
@@ -94,9 +93,14 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
9493 //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
9594 var result = sess . run ( new ITensorOrOperation [ ] { optimizer , global_step , loss } , train_feed_dict ) ;
9695 //loss_value = result[2];
97- var step = result [ 1 ] ;
96+ var step = result [ 1 ] ;
9897 if ( step % 10 == 0 )
98+ {
99+ var estimate = TimeSpan . FromSeconds ( ( stopwatch . Elapsed . TotalSeconds / i ) * total ) ;
100+ Console . WriteLine ( $ "Training on batch { i } /{ total } . Estimated training time: { estimate } ") ;
99101 Console . WriteLine ( $ "Step { step } loss: { result [ 2 ] } ") ;
102+ }
103+
100104 if ( step % 100 == 0 )
101105 {
102106 continue ;
@@ -198,6 +202,8 @@ public void PrepareData()
198202 {
199203 // download graph meta data
200204 var meta_file = model_name + ".meta" ;
205+ if ( File . GetLastWriteTime ( meta_file ) < new DateTime ( 2019 , 05 , 11 ) ) // delete old cached file which contains errors
206+ File . Delete ( meta_file ) ;
201207 url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file ;
202208 Web . Download ( url , "graph" , meta_file ) ;
203209 }
0 commit comments