Skip to content

Commit 9b14b46

Browse files
committed
fixed vd_cnn.meta
1 parent 0a82e58 commit 9b14b46

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

graph/vd_cnn.meta

3.04 KB
Binary file not shown.

test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,38 +52,37 @@ public bool Run()
5252

5353
protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
5454
{
55+
var stopwatch = Stopwatch.StartNew();
5556
Console.WriteLine("Building dataset...");
5657
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null);
57-
Console.WriteLine("\tDONE");
58+
Console.WriteLine("\tDONE ");
5859

5960
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
6061

6162
Console.WriteLine("Import graph...");
6263
var meta_file = model_name + ".meta";
6364
tf.train.import_meta_graph(Path.Join("graph", meta_file));
64-
Console.WriteLine("\tDONE");
65-
// definitely necessary, otherwize will get the exception of "use uninitialized variable"
65+
Console.WriteLine("\tDONE " + stopwatch.Elapsed);
66+
6667
sess.run(tf.global_variables_initializer());
6768

6869
var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
69-
var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1
70+
var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
7071
double max_accuracy = 0;
7172

7273
Tensor is_training = graph.get_operation_by_name("is_training");
7374
Tensor model_x = graph.get_operation_by_name("x");
7475
Tensor model_y = graph.get_operation_by_name("y");
75-
Tensor loss = graph.get_operation_by_name("loss/loss");
76+
Tensor loss = graph.get_operation_by_name("loss/value");
7677
//var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
7778
Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
7879
Tensor global_step = graph.get_operation_by_name("global_step");
79-
Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
80-
var stopwatch = Stopwatch.StartNew();
80+
Tensor accuracy = graph.get_operation_by_name("accuracy/value");
81+
stopwatch = Stopwatch.StartNew();
8182
int i = 0;
8283
foreach (var (x_batch, y_batch, total) in train_batches)
8384
{
8485
i++;
85-
var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
86-
Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
8786
var train_feed_dict = new Hashtable
8887
{
8988
[model_x] = x_batch,
@@ -94,9 +93,14 @@ protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
9493
//_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
9594
var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
9695
//loss_value = result[2];
97-
var step = result[1];
96+
var step = result[1];
9897
if (step % 10 == 0)
98+
{
99+
var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
100+
Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
99101
Console.WriteLine($"Step {step} loss: {result[2]}");
102+
}
103+
100104
if (step % 100 == 0)
101105
{
102106
continue;
@@ -198,6 +202,8 @@ public void PrepareData()
198202
{
199203
// download graph meta data
200204
var meta_file = model_name + ".meta";
205+
if (File.GetLastWriteTime(meta_file) < new DateTime(2019,05,11)) // delete old cached file which contains errors
206+
File.Delete(meta_file);
201207
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
202208
Web.Download(url, "graph", meta_file);
203209
}

0 commit comments

Comments
 (0)