Skip to content

Commit 5751c20

Browse files
committed
Training and cross-validation now work, but train_saver.save still fails. SciSharp#248
1 parent 5e5ec52 commit 5751c20

8 files changed

Lines changed: 202 additions & 23 deletions

File tree

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
9090
feed_dict_tensor[subfeed_t] = (NDArray)val;
9191
break;
9292
case int val:
93+
feed_dict_tensor[subfeed_t] = (NDArray)val;
94+
break;
95+
case long val:
96+
feed_dict_tensor[subfeed_t] = (NDArray)val;
97+
break;
98+
case long[] val:
9399
feed_dict_tensor[subfeed_t] = (NDArray)val;
94100
break;
95101
case int[] val:

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ public NDArray build_results(BaseSession session, NDArray[] tensor_values)
6262
case "Single":
6363
full_values.Add(float.NaN);
6464
break;
65+
case "String":
66+
full_values.Add(float.NaN);
67+
break;
6568
default:
6669
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}");
6770
}

src/TensorFlowNET.Core/Summaries/EventFileWriter.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,35 @@
55

66
namespace Tensorflow.Summaries
77
{
8+
/// <summary>
/// Creates an `EventFileWriter` and an event file to write to.
/// </summary>
public class EventFileWriter
{
    string _logdir;
    // FIFO buffer of pending events; the same instance is handed to the worker.
    Queue<Event> _event_queue;
    EventsWriter _ev_writer;
    int _flush_secs;
    Event _sentinel_event;
    bool _closed;
    EventLoggerThread _worker;

    /// <summary>
    /// Creates the log directory, sets up the events writer and starts the
    /// worker that consumes the event queue.
    /// </summary>
    /// <param name="logdir">Directory that will contain the "events" file; created if missing.</param>
    /// <param name="max_queue">Initial capacity of the event queue (the queue itself is not bounded by it).</param>
    /// <param name="flush_secs">Interval, in seconds, forwarded to the worker.</param>
    /// <param name="filename_suffix">Not supported yet; must be null or empty.</param>
    /// <exception cref="NotImplementedException">
    /// Thrown when <paramref name="filename_suffix"/> is not null or empty.
    /// </exception>
    public EventFileWriter(string logdir, int max_queue = 10, int flush_secs = 120,
        string filename_suffix = null)
    {
        // Reject the unsupported option first so that a failed construction
        // does not leave a freshly created log directory behind.
        if (!string.IsNullOrEmpty(filename_suffix))
        {
            // self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix)))
            throw new NotImplementedException("EventFileWriter filename_suffix is not null");
        }

        _logdir = logdir;
        Directory.CreateDirectory(_logdir);
        _event_queue = new Queue<Event>(max_queue);
        _ev_writer = new EventsWriter(Path.Combine(_logdir, "events"));
        _flush_secs = flush_secs;
        // Unique instance; the worker receives it together with the queue.
        _sentinel_event = new Event();
        _closed = false;
        _worker = new EventLoggerThread(_event_queue, _ev_writer, _flush_secs, _sentinel_event);
        _worker.start();
    }
}
2139
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Google.Protobuf;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
namespace Tensorflow.Summaries
10+
{
11+
/// <summary>
/// Thread that logs events: drains a queue of events and passes each one,
/// serialized, to an <see cref="EventsWriter"/>.
/// </summary>
public class EventLoggerThread
{
    Queue<Event> _queue;
    bool daemon;
    EventsWriter _ev_writer;
    int _flush_secs;
    Event _sentinel_event;

    /// <summary>
    /// Stores the collaborators used by the logging loop; nothing runs until
    /// <see cref="start"/> or <see cref="run"/> is called.
    /// </summary>
    /// <param name="queue">Queue the loop dequeues events from.</param>
    /// <param name="ev_writer">Writer that receives each serialized event.</param>
    /// <param name="flush_secs">Seconds to sleep whenever the queue is found empty.</param>
    /// <param name="sentinel_event">Marker instance; dequeuing it stops the loop.</param>
    public EventLoggerThread(Queue<Event> queue, EventsWriter ev_writer, int flush_secs, Event sentinel_event)
    {
        daemon = true;
        _queue = queue;
        _ev_writer = ev_writer;
        _flush_secs = flush_secs;
        _sentinel_event = sentinel_event;
    }

    /// <summary>Starts the logging loop (alias of <see cref="run"/>).</summary>
    public void start() => run();

    /// <summary>
    /// Launches the logging loop in the background and returns immediately.
    /// The loop ends once the sentinel event is dequeued.
    /// </summary>
    public void run()
    {
        // The loop blocks in Thread.Sleep for long stretches, so give it a
        // dedicated long-running task instead of occupying a pool worker.
        Task.Factory.StartNew(() =>
        {
            while (true)
            {
                // NOTE(review): Queue<T> is not thread-safe; this assumes
                // producers synchronize their Enqueue calls - confirm.
                if (_queue.Count == 0)
                {
                    Thread.Sleep(_flush_secs * 1000);
                    continue;
                }

                var @event = _queue.Dequeue();

                // Compare by reference: only the exact sentinel instance
                // stops the loop, never an event that merely equals it.
                if (ReferenceEquals(@event, _sentinel_event))
                {
                    break;
                }

                _ev_writer._WriteSerializedEvent(@event.ToByteArray());
                Thread.Sleep(1000);
            }
        }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default);
    }
}
52+
}
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Text;
45

56
namespace Tensorflow.Summaries
67
{
78
/// <summary>
/// Writes serialized events to a single file on disk.
/// </summary>
public class EventsWriter
{
    string _file_prefix;

    /// <summary>
    /// Remembers the target path; no file is touched until the first write.
    /// </summary>
    /// <param name="file_prefix">Path of the file that events are written to.</param>
    public EventsWriter(string file_prefix)
    {
        _file_prefix = file_prefix;
    }

    /// <summary>
    /// Appends the serialized event to the end of the events file, creating
    /// the file when it does not exist yet.
    /// </summary>
    /// <param name="event_str">Serialized event bytes.</param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="event_str"/> is null.
    /// </exception>
    public void _WriteSerializedEvent(byte[] event_str)
    {
        if (event_str == null)
        {
            throw new ArgumentNullException(nameof(event_str));
        }

        // Append rather than overwrite: replacing the file on every call
        // would discard every previously logged event.
        // NOTE(review): bytes are written without any record framing;
        // confirm whether readers of this file expect length/CRC headers.
        using (var stream = new FileStream(_file_prefix, FileMode.Append, FileAccess.Write))
        {
            stream.Write(event_str, 0, event_str.Length);
        }
    }
}
1922
}

src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using Google.Protobuf;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

@@ -9,5 +10,10 @@ namespace Tensorflow.Summaries
910
/// </summary>
1011
public abstract class SummaryToEventTransformer
{
    /// <summary>
    /// Placeholder for adding a summary to the event file. At present it only
    /// encodes <paramref name="summary"/> to bytes and discards the result;
    /// nothing is parsed or written.
    /// </summary>
    /// <param name="summary">Summary text to encode.</param>
    /// <param name="global_step">Training step of the summary; currently unused.</param>
    public void add_summary(string summary, int global_step = 0)
    {
        // Encoding.Unicode is UTF-16 LE. The previous spelling
        // `UTF8Encoding.Unicode` resolved to this same inherited static member,
        // so the produced bytes are unchanged.
        // NOTE(review): confirm which encoding round-trips the serialized
        // summary before enabling the parser call below.
        var bytes = Encoding.Unicode.GetBytes(summary);
        // var summ = Tensorflow.Summary.Parser.ParseFrom(bytes);
    }
}
1319
}

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ private IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
7171
case "Int32":
7272
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
7373
break;
74+
case "Int64":
75+
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size);
76+
break;
7477
case "Single":
7578
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
7679
break;
@@ -80,24 +83,8 @@ private IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
8083
case "Byte":
8184
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
8285
break;
83-
//case "String":
84-
/*string ss = nd.Data<string>()[0];
85-
var str = Marshal.StringToHGlobalAnsi(ss);
86-
ulong dst_len = c_api.TF_StringEncodedSize((ulong)ss.Length);
87-
var dataType1 = ToTFDataType(nd.dtype);
88-
// shape
89-
var dims1 = nd.shape.Select(x => (long)x).ToArray();
90-
91-
var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
92-
dims1,
93-
nd.ndim,
94-
dst_len + sizeof(Int64));
95-
96-
dotHandle = c_api.TF_TensorData(tfHandle1);
97-
Marshal.WriteInt64(dotHandle, 0);
98-
c_api.TF_StringEncode(str, (ulong)ss.Length, dotHandle + sizeof(Int64), dst_len, status);
99-
return tfHandle1;*/
100-
break;
86+
case "String":
87+
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)));
10188
default:
10289
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
10390
}

test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,19 @@ public class RetrainImageClassifier : IExample
3131
string summaries_dir = Path.Join(data_dir, "retrain_logs");
3232
string image_dir = Path.Join(data_dir, "flower_photos");
3333
string bottleneck_dir = Path.Join(data_dir, "bottleneck");
34+
// The location where variable checkpoints will be stored.
35+
string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
3436
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
3537
float testing_percentage = 0.1f;
3638
float validation_percentage = 0.1f;
3739
Tensor resized_image_tensor;
3840
Dictionary<string, Dictionary<string, string[]>> image_lists;
41+
int how_many_training_steps = 200;
42+
int eval_step_interval = 10;
43+
int train_batch_size = 100;
44+
int validation_batch_size = 100;
45+
int intermediate_store_frequency = 0;
46+
const int MAX_NUM_IMAGES_PER_CLASS = 134217727;
3947

4048
public bool Run()
4149
{
@@ -47,6 +55,9 @@ public bool Run()
4755
Tensor resized_image_tensor = graph.OperationByName("Placeholder");
4856
Tensor final_tensor = graph.OperationByName("final_result");
4957
Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
58+
Operation train_step = graph.OperationByName("train/GradientDescent");
59+
Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
60+
Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");
5061

5162
var sw = new Stopwatch();
5263

@@ -72,11 +83,104 @@ public bool Run()
7283
// Merge all the summaries and write them out to the summaries_dir
7384
var merged = tf.summary.merge_all();
7485
var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
86+
var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);
87+
88+
// Create a train saver that is used to restore values into an eval graph
89+
// when exporting models.
90+
var train_saver = tf.train.Saver();
91+
92+
for (int i = 0; i < how_many_training_steps; i++)
93+
{
94+
var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
95+
sess, image_lists, train_batch_size, "training",
96+
bottleneck_dir, image_dir, jpeg_data_tensor,
97+
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
98+
tfhub_module);
99+
100+
// Feed the bottlenecks and ground truth into the graph, and run a training
101+
// step. Capture training summaries for TensorBoard with the `merged` op.
102+
var results = sess.run(
103+
new ITensorOrOperation[] { merged, train_step },
104+
new FeedItem(bottleneck_input, train_bottlenecks),
105+
new FeedItem(ground_truth_input, train_ground_truth));
106+
var train_summary = results[0];
107+
108+
// TODO
109+
train_writer.add_summary(train_summary, i);
110+
111+
// Every so often, print out how well the graph is training.
112+
bool is_last_step = (i + 1 == how_many_training_steps);
113+
if ((i % eval_step_interval) == 0 || is_last_step)
114+
{
115+
results = sess.run(
116+
new Tensor[] { evaluation_step, cross_entropy },
117+
new FeedItem(bottleneck_input, train_bottlenecks),
118+
new FeedItem(ground_truth_input, train_ground_truth));
119+
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
120+
print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
121+
print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");
122+
123+
var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
124+
sess, image_lists, validation_batch_size, "validation",
125+
bottleneck_dir, image_dir, jpeg_data_tensor,
126+
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
127+
tfhub_module);
128+
129+
// Run a validation step and capture training summaries for TensorBoard
130+
// with the `merged` op.
131+
results = sess.run(new Tensor[] { merged, evaluation_step },
132+
new FeedItem(bottleneck_input, validation_bottlenecks),
133+
new FeedItem(ground_truth_input, validation_ground_truth));
134+
135+
(string validation_summary, float validation_accuracy) = (results[0], results[1]);
136+
137+
validation_writer.add_summary(validation_summary, i);
138+
print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
139+
}
140+
141+
// Store intermediate results
142+
int intermediate_frequency = intermediate_store_frequency;
143+
if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
144+
{
145+
146+
}
147+
}
148+
149+
// After training is complete, force one last save of the train checkpoint.
150+
train_saver.save(sess, CHECKPOINT_NAME);
75151
});
76152

77153
return false;
78154
}
79155

156+
/// <summary>
/// Builds a batch of bottleneck values for randomly chosen images of the
/// given category, together with their label indices and image paths.
/// </summary>
/// <param name="sess">Session forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="image_lists">Image lists keyed by label name.</param>
/// <param name="how_many">Number of samples to draw.</param>
/// <param name="category">Category name forwarded to the helpers (callers pass "training" or "validation").</param>
/// <param name="bottleneck_dir">Bottleneck directory forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="image_dir">Image directory forwarded to the helpers.</param>
/// <param name="jpeg_data_tensor">Forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="decoded_image_tensor">Forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="resized_input_tensor">Forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="bottleneck_tensor">Forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <param name="module_name">Forwarded to <c>get_or_create_bottleneck</c>.</param>
/// <returns>
/// The collected bottlenecks, the label index of each sample, and the image
/// path of each sample, all in the order the samples were drawn.
/// </returns>
private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
    int how_many, string category, string bottleneck_dir, string image_dir,
    Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
    Tensor bottleneck_tensor, string module_name)
{
    var bottlenecks = new List<float[]>();
    var ground_truths = new List<long>();
    var filenames = new List<string>();

    // One generator per call: instances created back-to-back with
    // `new Random()` can share a clock-based seed on .NET Framework and then
    // return identical values, which would collapse the random sample.
    var random = new Random();

    // Snapshot the label names once instead of rebuilding the array on every
    // iteration.
    string[] label_names = image_lists.Keys.ToArray();
    int class_count = label_names.Length;

    foreach (var unused_i in range(how_many))
    {
        int label_index = random.Next(class_count);
        string label_name = label_names[label_index];
        // Arbitrary index below MAX_NUM_IMAGES_PER_CLASS; presumably the
        // helpers map it onto the images actually available - confirm.
        int image_index = random.Next(MAX_NUM_IMAGES_PER_CLASS);
        string image_name = get_image_path(image_lists, label_name, image_index,
            image_dir, category);
        var bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name);
        bottlenecks.Add(bottleneck);
        ground_truths.Add(label_index);
        filenames.Add(image_name);
    }

    return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
}
183+
80184
/// <summary>
81185
/// Inserts the operations we need to evaluate the accuracy of our results.
82186
/// </summary>

0 commit comments

Comments
 (0)