Skip to content

Commit a9330d2

Browse files
committed
Froze trained model is completed. SciSharp#248.
1 parent ed6fd47 commit a9330d2

7 files changed

Lines changed: 240 additions & 17 deletions

File tree

graph/InceptionV3.meta

324 KB
Binary file not shown.

src/TensorFlowNET.Core/Framework/graph_util_impl.cs

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Collections.Generic;
4+
using System.Linq;
35
using System.Text;
6+
using static Tensorflow.Python;
47

58
namespace Tensorflow
69
{
@@ -23,7 +26,181 @@ public GraphDef convert_variables_to_constants(Session sess,
2326
{
2427
// This graph only includes the nodes needed to evaluate the output nodes, and
2528
// removes unneeded nodes like those involved in saving and assignment.
26-
throw new NotImplementedException("");
29+
var inference_graph = extract_sub_graph(input_graph_def, output_node_names);
30+
31+
// Identify the ops in the graph.
32+
var map_name_to_node = new Dictionary<string, NodeDef>();
33+
inference_graph.Node.Select(x => map_name_to_node[x.Name] = x).ToArray();
34+
35+
// Get list of variables.
36+
var variable_names = new List<string>();
37+
var variable_dict_names = new List<string>();
38+
39+
foreach (var node in inference_graph.Node)
40+
{
41+
if(new string[] { "Variable", "VariableV2", "VarHandleOp" }.Contains(node.Op))
42+
{
43+
var variable_name = node.Name;
44+
45+
variable_dict_names.Add(variable_name);
46+
if (node.Op == "VarHandleOp")
47+
variable_names.Add(variable_name + "/Read/ReadVariableOp:0");
48+
else
49+
variable_names.Add(variable_name + ":0");
50+
}
51+
else if (new string[] { "ReadVariableOp", "ResourceGather" }.Contains(node.Op))
52+
{
53+
// There can be one or more Identity ops in between the ReadVariableOp and
54+
// VarHandleOp. Store the Identity ops with the associated dtypes.
55+
var source_op_name = get_input_name(node);
56+
while(map_name_to_node[source_op_name].Op == "Identity")
57+
{
58+
throw new NotImplementedException("map_name_to_node[source_op_name].Op");
59+
/*resource_identity_types[source_op_name] = node.attr["dtype"];
60+
source_op_name = get_input_name(map_name_to_node[source_op_name]);*/
61+
}
62+
}
63+
}
64+
65+
// Gets map of variables and the associated data.
66+
NDArray returned_variables = null;
67+
if (variable_names != null)
68+
returned_variables = sess.run(variable_names);
69+
70+
var variables_data_map = new Dictionary<string, NDArray>();
71+
foreach(var (i, name) in enumerate(variable_dict_names))
72+
variables_data_map[name] = returned_variables[i];
73+
print($"Froze {len(returned_variables)} variables.");
74+
75+
// Reconstruct the graph with constants in place of variables.
76+
var output_graph_def = new GraphDef();
77+
int how_many_converted = 0;
78+
foreach(var input_node in inference_graph.Node)
79+
{
80+
var output_node = new NodeDef();
81+
if (variables_data_map.ContainsKey(input_node.Name))
82+
{
83+
var data = variables_data_map[input_node.Name];
84+
output_node = create_const_op(input_node.Name, input_node.Attr["dtype"],
85+
data, data.shape);
86+
how_many_converted += 1;
87+
}
88+
// else if (resource_identity_types.ContainsKey(input_node.Name))
89+
else if(input_node.Op == "ReadVariableOp")
90+
{
91+
output_node.Op = "Identity";
92+
output_node.Name = input_node.Name;
93+
output_node.Input.AddRange(new[] { input_node.Input[0] });
94+
output_node.Attr["T"] = input_node.Attr["dtype"];
95+
}
96+
else if (input_node.Op == "ResourceGather")
97+
{
98+
99+
}
100+
else
101+
{
102+
output_node.MergeFrom(input_node);
103+
}
104+
105+
output_graph_def.Node.AddRange(new[] { output_node });
106+
}
107+
108+
output_graph_def.Library = inference_graph.Library;
109+
print($"Converted {how_many_converted} variables to const ops.");
110+
return output_graph_def;
111+
}
112+
113+
/// <summary>
/// Creates a node whose op is "Const" and whose "value" attribute holds
/// the tensor proto built from <paramref name="data"/>.
/// </summary>
/// <param name="node_name">Name given to the new node.</param>
/// <param name="dtype">Stored as the node's "dtype" attribute; its Type is also passed to make_tensor_proto.</param>
/// <param name="data">Array handed to make_tensor_proto.</param>
/// <param name="data_shape">Shape handed to make_tensor_proto (null by default).</param>
/// <returns>The populated node.</returns>
private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null)
{
    var const_node = new NodeDef();
    const_node.Op = "Const";
    const_node.Name = node_name;

    // The dtype attribute is assigned before the tensor proto is built,
    // so the order of evaluation matches the previous implementation.
    const_node.Attr["dtype"] = dtype;

    var value_attr = new AttrValue();
    value_attr.Tensor = tensor_util.make_tensor_proto(
        data, dtype: dtype.Type.as_tf_dtype(), shape: data_shape);
    const_node.Attr["value"] = value_attr;

    return const_node;
}
129+
130+
/// <summary>
/// Gets the op name of the node's first input. The input is split on ':';
/// the part before the first ':' is returned when there is no suffix or
/// when the suffix parses to 0.
/// </summary>
/// <param name="node">Node whose first entry in Input is inspected.</param>
/// <returns>The first input's name without its output-index suffix.</returns>
/// <exception cref="ValueError">
/// The suffix is a non-zero index, or is not a parsable integer.
/// </exception>
private string get_input_name(NodeDef node)
{
    var first_input = node.Input[0];
    var details = first_input.Split(':');

    if (details.Length == 1)
        return details[0];

    // TryParse (instead of Parse) so that a malformed or out-of-range suffix
    // is reported through the same ValueError as a non-zero suffix, rather
    // than escaping as FormatException / OverflowException.
    if (int.TryParse(details[1], out var output_index) && output_index == 0)
        return details[0];

    // While it is valid for input tensors to have a suffix that is not :0, this
    // method is used to find the associated ops, not tensors, and therefore it
    // is not valid.
    throw new ValueError($"Tensor name '{first_input}' is invalid.");
}
145+
146+
147+
/// <summary>
/// Builds a new GraphDef containing only the nodes reachable from
/// <paramref name="dest_nodes"/> by following input edges, kept in the
/// order in which they appear in <paramref name="graph_def"/>.
/// </summary>
/// <param name="graph_def">Graph to extract from.</param>
/// <param name="dest_nodes">Names of the nodes the traversal starts from.</param>
/// <returns>
/// A GraphDef holding the kept nodes, with Library and Versions taken
/// from <paramref name="graph_def"/>.
/// </returns>
private GraphDef extract_sub_graph(GraphDef graph_def, string[] dest_nodes)
{
    var (inputs_by_name, node_by_name, seq_by_name) = _extract_graph_summary(graph_def);

    // Reachable names, put back into the source graph's node order.
    var reachable = _bfs_for_reachable_nodes(dest_nodes, inputs_by_name);
    var ordered_names = reachable.OrderBy(name => seq_by_name[name]).ToArray();

    var sub_graph = new GraphDef();
    // NOTE(review): nodes are added by reference, not cloned - confirm whether
    // a deep copy is required by callers.
    sub_graph.Node.AddRange(ordered_names.Select(name => node_by_name[name]));
    sub_graph.Library = graph_def.Library;
    sub_graph.Versions = graph_def.Versions;

    return sub_graph;
}
163+
164+
/// <summary>
/// Breadth-first search from <paramref name="target_nodes"/> along the
/// edges described by <paramref name="name_to_input_name"/>.
/// </summary>
/// <param name="target_nodes">Names the search starts from.</param>
/// <param name="name_to_input_name">Maps a node name to the names of its inputs.</param>
/// <returns>
/// Every visited name exactly once, in the order it was first visited.
/// A name with no entry in <paramref name="name_to_input_name"/> is still
/// returned, but is not expanded further.
/// </returns>
private string[] _bfs_for_reachable_nodes(string[] target_nodes, Dictionary<string, string[]> name_to_input_name)
{
    // The list preserves first-visit order for the result; the set gives O(1)
    // membership tests (List.Contains made the whole search quadratic).
    var nodes_to_keep = new List<string>();
    var visited = new HashSet<string>();

    // Queue gives O(1) dequeue, unlike List.RemoveAt(0).
    var next_to_visit = new Queue<string>(target_nodes);
    while (next_to_visit.Count > 0)
    {
        var node = next_to_visit.Dequeue();

        // HashSet.Add returns false when the name was already visited.
        if (!visited.Add(node))
            continue;
        nodes_to_keep.Add(node);

        if (name_to_input_name.TryGetValue(node, out var input_names))
        {
            foreach (var input_name in input_names)
                next_to_visit.Enqueue(input_name);
        }
    }

    return nodes_to_keep.ToArray();
}
181+
182+
/// <summary>
/// Indexes the nodes of <paramref name="graph_def"/> by name (as normalized
/// by _node_name).
/// </summary>
/// <param name="graph_def">Graph whose Node collection is scanned in order.</param>
/// <returns>
/// Three dictionaries keyed by node name: the normalized names of the
/// node's inputs, the node itself, and the node's zero-based position in
/// <paramref name="graph_def"/>. If two nodes normalize to the same name,
/// the later one overwrites the earlier one.
/// </returns>
private (Dictionary<string, string[]>, Dictionary<string, NodeDef>, Dictionary<string, int>) _extract_graph_summary(GraphDef graph_def)
{
    var inputs_by_name = new Dictionary<string, string[]>();
    var node_by_name = new Dictionary<string, NodeDef>();
    var position_by_name = new Dictionary<string, int>();

    var nodes = graph_def.Node;
    for (int position = 0; position < nodes.Count; position++)
    {
        var current = nodes[position];
        var key = _node_name(current.Name);

        node_by_name[key] = current;
        inputs_by_name[key] = current.Input.Select(input => _node_name(input)).ToArray();
        position_by_name[key] = position;
    }

    return (inputs_by_name, node_by_name, position_by_name);
}
200+
201+
/// <summary>
/// Normalizes an input reference to a node name: a leading '^' is removed,
/// otherwise everything from the first ':' onward is dropped.
/// </summary>
/// <param name="n">Input reference such as "^name", "name:0" or "name".</param>
/// <returns>
/// For a '^'-prefixed value, the remainder of the string unchanged;
/// otherwise the part before the first ':'.
/// </returns>
private string _node_name(string n)
{
    // Ordinal comparison: node names are identifiers, so the check must not
    // depend on the current culture.
    if (n.StartsWith("^", StringComparison.Ordinal))
        return n.Substring(1);

    return n.Split(':')[0];
}
28205

29206
private string get_input_name(string node)

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@ public class _ElementFetchMapper : _FetchMapper
1515
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
1616
{
1717
var g = ops.get_default_graph();
18-
ITensorOrOperation el = null;
1918

2019
foreach(var fetch in fetches)
2120
{
22-
el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
21+
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
22+
_unique_fetches.Add(el);
2323
}
24-
25-
_unique_fetches.Add(el);
24+
2625
_contraction_fn = contraction_fn;
2726
}
2827

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> fee
3333
_fetches.Add(val);
3434
_ops.Add(false);
3535
break;
36+
default:
37+
throw new NotImplementedException("_FetchHandler fetch");
3638
}
3739

3840
}

src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ namespace Tensorflow
88
{
99
public class _FetchMapper
1010
{
11-
protected List<object> _unique_fetches = new List<object>();
12-
11+
protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>();
12+
protected List<int[]> _value_indices = new List<int[]>();
1313
public static _FetchMapper for_fetch(object fetch)
1414
{
1515
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };
1616

17+
if(fetch is List<string> fetches1)
18+
return new _ListFetchMapper(fetches1.ToArray());
1719
if (fetch.GetType().IsArray)
1820
return new _ListFetchMapper(fetches);
1921
else
@@ -28,7 +30,7 @@ public virtual NDArray build_results(List<NDArray> values)
2830
return nd;
2931
}
3032

31-
public virtual List<object> unique_fetches()
33+
public virtual List<ITensorOrOperation> unique_fetches()
3234
{
3335
return _unique_fetches;
3436
}

src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,46 @@ public class _ListFetchMapper : _FetchMapper
1212
public _ListFetchMapper(object[] fetches)
1313
{
1414
_mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray();
15-
_unique_fetches.AddRange(fetches);
15+
(_unique_fetches, _value_indices) = _uniquify_fetches(_mappers);
16+
}
17+
18+
/// <summary>
/// Merges the unique fetches of several mappers into one de-duplicated list.
/// </summary>
/// <param name="fetch_mappers">Mappers whose unique_fetches() are combined.</param>
/// <returns>
/// The de-duplicated fetches in first-seen order, and - one array per
/// mapper - the position in that list of each of the mapper's fetches.
/// </returns>
/// <exception cref="NotImplementedException">
/// A fetch is neither a Tensor nor an Operation.
/// </exception>
private (List<ITensorOrOperation>, List<int[]>) _uniquify_fetches(_FetchMapper[] fetch_mappers)
{
    var unique_fetches = new List<ITensorOrOperation>();
    var value_indices = new List<int[]>();
    var seen_fetches = new Dictionary<ITensorOrOperation, int>();

    foreach (var m in fetch_mappers)
    {
        var m_value_indices = new List<int>();
        foreach (var uf in m.unique_fetches())
        {
            if (!(uf is Tensor) && !(uf is Operation))
                throw new NotImplementedException("_uniquify_fetches");

            // Look up the slot of an already-seen fetch; only a new fetch
            // gets the next free slot. (Recording "Count - 1" for a repeated
            // fetch pointed it at the most recently added entry instead of
            // its own.)
            if (!seen_fetches.TryGetValue(uf, out var index))
            {
                index = seen_fetches.Count;
                seen_fetches[uf] = index;
                unique_fetches.Add(uf);
            }
            m_value_indices.Add(index);
        }
        value_indices.Add(m_value_indices.ToArray());
    }

    return (unique_fetches, value_indices);
}
1756
}
1857
}

test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
using System;
44
using System.Collections.Generic;
55
using System.Diagnostics;
6+
using System.Drawing;
67
using System.IO;
78
using System.Linq;
89
using System.Text;
910
using Tensorflow;
1011
using TensorFlowNET.Examples.Utility;
1112
using static Tensorflow.Python;
13+
using Console = Colorful.Console;
1214

1315
namespace TensorFlowNET.Examples.ImageProcess
1416
{
@@ -84,7 +86,7 @@ public bool Run()
8486

8587
var sw = new Stopwatch();
8688

87-
with(tf.Session(graph), sess =>
89+
return with(tf.Session(graph), sess =>
8890
{
8991
// Initialize all weights: for the module to their pretrained values,
9092
// and for the newly added retraining layer to random initial values.
@@ -111,6 +113,7 @@ public bool Run()
111113
// Create a train saver that is used to restore values into an eval graph
112114
// when exporting models.
113115
var train_saver = tf.train.Saver();
116+
sw.Restart();
114117

115118
for (int i = 0; i < how_many_training_steps; i++)
116119
{
@@ -140,8 +143,7 @@ public bool Run()
140143
new FeedItem(bottleneck_input, train_bottlenecks),
141144
new FeedItem(ground_truth_input, train_ground_truth));
142145
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
143-
print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
144-
print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");
146+
print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");
145147

146148
var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
147149
sess, image_lists, validation_batch_size, "validation",
@@ -158,7 +160,8 @@ public bool Run()
158160
(string validation_summary, float validation_accuracy) = (results[0], results[1]);
159161

160162
validation_writer.add_summary(validation_summary, i);
161-
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
163+
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
164+
sw.Restart();
162165
}
163166

164167
// Store intermediate results
@@ -180,12 +183,11 @@ public bool Run()
180183

181184
// Write out the trained graph and labels with the weights stored as
182185
// constants.
183-
print($"final test accuracy: {test_accuracy}");
184186
print($"Save final result to : {output_graph}");
185187
save_graph_to_file(output_graph, class_count);
188+
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
189+
return test_accuracy > 0.75f;
186190
});
187-
188-
return false;
189191
}
190192

191193
/// <summary>
@@ -215,6 +217,8 @@ public bool Run()
215217
new FeedItem(bottleneck_input, test_bottlenecks),
216218
new FeedItem(ground_truth_input, test_ground_truth));
217219

220+
print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})");
221+
218222
return (results[0], results[1]);
219223
}
220224

0 commit comments

Comments
 (0)