Skip to content

Commit fdf8231

Browse files
committed
import_graph_def
1 parent c7cf8b6 commit fdf8231

12 files changed

Lines changed: 300 additions & 23 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class c_api_util
8+
{
9+
public static TF_Output tf_output(IntPtr c_op, int index) => new TF_Output(c_op, index);
10+
11+
public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();
12+
13+
public static IntPtr tf_buffer(byte[] data)
14+
{
15+
if (data != null)
16+
throw new NotImplementedException("");
17+
// var buf = c_api.TF_NewBufferFromString(data);
18+
else
19+
throw new NotImplementedException("");
20+
}
21+
}
22+
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
using Google.Protobuf;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using static Tensorflow.OpDef.Types;
7+
8+
namespace Tensorflow
9+
{
10+
public class importer
11+
{
12+
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
13+
Dictionary<string, Tensor> input_map = null,
14+
string[] return_elements = null,
15+
string name = "",
16+
OpList producer_op_list = null)
17+
{
18+
var op_dict = op_def_registry.get_registered_ops();
19+
20+
graph_def = _ProcessGraphDefParam(graph_def, op_dict);
21+
input_map = _ProcessInputMapParam(input_map);
22+
return_elements = _ProcessReturnElementsParam(return_elements);
23+
24+
if (producer_op_list != null)
25+
_RemoveDefaultAttrs(op_dict, producer_op_list, graph_def);
26+
27+
string prefix = "";
28+
var graph = ops.get_default_graph();
29+
Python.with<ops.name_scope>(new ops.name_scope(name, "import", input_map.Values), scope =>
30+
{
31+
/*prefix = scope;
32+
if (!string.IsNullOrEmpty(prefix))
33+
prefix = prefix.Substring(0, prefix.Length - 1);
34+
else
35+
prefix = "";*/
36+
37+
// Generate any input map tensors inside name scope
38+
input_map = _ConvertInputMapValues(name, input_map);
39+
});
40+
41+
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
42+
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
43+
44+
var bytes = graph_def.ToByteString().ToArray();
45+
46+
var status = new Status();
47+
c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status);
48+
49+
throw new NotImplementedException("importer.import_graph_def");
50+
}
51+
52+
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
53+
string prefix,
54+
Dictionary<string, Tensor> input_map,
55+
string[] return_elements)
56+
{
57+
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
58+
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);
59+
60+
foreach(var input in input_map)
61+
{
62+
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
63+
}
64+
65+
if (return_elements == null)
66+
return_elements = new string[0];
67+
68+
foreach (var name in return_elements)
69+
{
70+
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
71+
}
72+
}
73+
74+
public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map)
75+
{
76+
return input_map;
77+
}
78+
79+
public static GraphDef _ProcessGraphDefParam(GraphDef graph_def, Dictionary<string, OpDef> op_dict)
80+
{
81+
foreach(var node in graph_def.Node)
82+
{
83+
if (!op_dict.ContainsKey(node.Op))
84+
continue;
85+
86+
var op_def = op_dict[node.Op];
87+
_SetDefaultAttrValues(node, op_def);
88+
}
89+
90+
return graph_def;
91+
}
92+
93+
private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
94+
{
95+
foreach(var attr_def in op_def.Attr)
96+
{
97+
var key = attr_def.Name;
98+
if(attr_def.DefaultValue != null)
99+
{
100+
var value = node_def.Attr[key];
101+
if (value == null)
102+
node_def.Attr[key] = attr_def.DefaultValue;
103+
}
104+
}
105+
}
106+
107+
private static Dictionary<string, Tensor> _ProcessInputMapParam(Dictionary<string, Tensor> input_map)
108+
{
109+
if (input_map == null)
110+
return new Dictionary<string, Tensor>();
111+
112+
return input_map;
113+
}
114+
115+
private static string[] _ProcessReturnElementsParam(string[] return_elements)
116+
{
117+
if (return_elements == null)
118+
return null;
119+
120+
return return_elements;
121+
}
122+
123+
private static void _RemoveDefaultAttrs(Dictionary<string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
124+
{
125+
var producer_op_dict = new Dictionary<string, OpDef>();
126+
producer_op_list.Op.Select(op =>
127+
{
128+
producer_op_dict[op.Name] = op;
129+
return op;
130+
}).ToArray();
131+
132+
foreach(var node in graph_def.Node)
133+
{
134+
// Remove any default attr values that aren't in op_def.
135+
if (producer_op_dict.ContainsKey(node.Op))
136+
{
137+
var op_def = op_dict[node.Op];
138+
var producer_op_def = producer_op_dict[node.Op];
139+
foreach(var key in node.Attr)
140+
{
141+
if(_FindAttrInOpDef(key.Key, op_def) == null)
142+
{
143+
var attr_def = _FindAttrInOpDef(key.Key, producer_op_def);
144+
if (attr_def != null && attr_def.DefaultValue != null &&
145+
node.Attr[key.Key] == attr_def.DefaultValue)
146+
node.Attr[key.Key].ClearValue();
147+
}
148+
}
149+
}
150+
}
151+
}
152+
153+
private static AttrDef _FindAttrInOpDef(string name, OpDef op_def)
154+
{
155+
return op_def.Attr.FirstOrDefault(x => x.Name == name);
156+
}
157+
}
158+
}

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Linq;
45
using System.Text;
56
using static Tensorflow.MetaGraphDef.Types;
@@ -8,6 +9,59 @@ namespace Tensorflow
89
{
910
public class meta_graph
1011
{
12+
public static MetaGraphDef read_meta_graph_file(string filename)
13+
{
14+
var bytes = File.ReadAllBytes(filename);
15+
var meta_graph_def = MetaGraphDef.Parser.ParseFrom(bytes);
16+
return meta_graph_def;
17+
}
18+
19+
public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
20+
bool clear_devices = false,
21+
string import_scope = "",
22+
Dictionary<string, Tensor> input_map = null,
23+
string unbound_inputs_col_name = "unbound_inputs",
24+
string[] return_elements = null)
25+
{
26+
var meta_graph_def = meta_graph_or_file;
27+
28+
if (!string.IsNullOrEmpty(unbound_inputs_col_name))
29+
{
30+
foreach(var col in meta_graph_def.CollectionDef)
31+
{
32+
if(col.Key == unbound_inputs_col_name)
33+
{
34+
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
35+
}
36+
}
37+
}
38+
39+
// Sets graph to default graph if it's not passed in.
40+
var graph = ops.get_default_graph();
41+
42+
// Gathers the list of nodes we are interested in.
43+
OpList producer_op_list = null;
44+
if (meta_graph_def.MetaInfoDef.StrippedOpList != null)
45+
producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;
46+
var input_graph_def = meta_graph_def.GraphDef;
47+
// Remove all the explicit device specifications for this node. This helps to
48+
// make the graph more portable.
49+
if (clear_devices)
50+
foreach (var node in input_graph_def.Node)
51+
node.Device = "";
52+
53+
var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
54+
importer.import_graph_def(input_graph_def,
55+
name: scope_to_prepend_to_names,
56+
input_map: input_map,
57+
producer_op_list: producer_op_list,
58+
return_elements: return_elements);
59+
60+
// Restores all the other collections.
61+
var variable_objects = new Dictionary<string, RefVariable>();
62+
63+
}
64+
1165
/// <summary>
1266
/// Returns `MetaGraphDef` proto. Optionally writes it to filename.
1367
/// </summary>

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,18 @@ public partial class c_api
218218
[DllImport(TensorFlowLibName)]
219219
public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix);
220220

221+
/// <summary>
222+
/// Set whether to uniquify imported operation names. If true, imported operation
223+
/// names will be modified if their name already exists in the graph. If false,
224+
/// conflicting names will be treated as an error. Note that this option has no
225+
/// effect if a prefix is set, since the prefix will guarantee all names are
226+
/// unique. Defaults to false.
227+
/// </summary>
228+
/// <param name="ops">TF_ImportGraphDefOptions*</param>
229+
/// <param name="uniquify_prefix">unsigned char</param>
230+
[DllImport(TensorFlowLibName)]
231+
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix);
232+
221233
/// <summary>
222234
/// Fetches the return operations requested via
223235
/// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<TargetFramework>netstandard2.0</TargetFramework>
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
7-
<Version>0.1.0</Version>
7+
<Version>0.2.0</Version>
88
<Authors>Haiping Chen</Authors>
99
<Company>SciSharp STACK</Company>
1010
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,10 +16,13 @@
1616
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags>
1717
<Description>Google's TensorFlow binding in .NET Standard.
1818
Docs: https://tensorflownet.readthedocs.io</Description>
19-
<AssemblyVersion>0.1.0.0</AssemblyVersion>
20-
<PackageReleaseNotes>Implemented the tf.Variable().
21-
TensorFlow 1.13 RC.</PackageReleaseNotes>
19+
<AssemblyVersion>0.2.0.0</AssemblyVersion>
20+
<PackageReleaseNotes>Added a bunch of APIs.
21+
Fixed String tensor creation bug.
22+
Upgraded to TensorFlow 1.13 RC-1.
23+
</PackageReleaseNotes>
2224
<LangVersion>7.2</LangVersion>
25+
<FileVersion>0.2.0.0</FileVersion>
2326
</PropertyGroup>
2427

2528
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,13 @@ public string save(Session sess,
193193
return _is_empty ? string.Empty : model_checkpoint_path;
194194
}
195195

196+
public Saver import_meta_graph(string meta_graph_or_file,
197+
bool clear_devices = false,
198+
string import_scope = "")
199+
{
200+
return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope);
201+
}
202+
196203
/// <summary>
197204
/// Writes `MetaGraphDef` to save_path/filename.
198205
/// </summary>
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class saver
8+
{
9+
public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file,
10+
bool clear_devices = false,
11+
string import_scope = "",
12+
string[] return_elements = null)
13+
{
14+
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);
15+
16+
meta_graph.import_scoped_meta_graph_with_return_elements(
17+
meta_graph_def,
18+
clear_devices: clear_devices,
19+
import_scope: import_scope,
20+
return_elements: return_elements);
21+
22+
return null;
23+
/*var (imported_vars, imported_return_elements) = (
24+
, false);*/
25+
}
26+
}
27+
}

src/TensorFlowNET.Core/Train/tf.optimizers.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ public static class train
1414
public static Saver Saver() => new Saver();
1515

1616
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
17+
18+
public static Saver import_meta_graph(string meta_graph_or_file,
19+
bool clear_devices = false,
20+
string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file,
21+
clear_devices,
22+
import_scope);
1723
}
1824
}
1925
}

src/TensorFlowNET.Core/c_api_util.cs

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
<ItemGroup>
99
<PackageReference Include="NumSharp" Version="0.7.1" />
10-
<PackageReference Include="TensorFlow.NET" Version="0.1.0" />
10+
<PackageReference Include="TensorFlow.NET" Version="0.2.0" />
1111
</ItemGroup>
1212

1313
<ItemGroup>

0 commit comments

Comments
 (0)