Skip to content

Commit e488c67

Browse files
committed
TF_ImportGraphDefResults.return_tensors
1 parent 9d5bb8f commit e488c67

6 files changed

Lines changed: 117 additions & 13 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public partial class tensorflow
2424
public GFile gfile = new GFile();
2525
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
2626

27-
public void import_graph_def(GraphDef graph_def,
27+
public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
2828
Dictionary<string, Tensor> input_map = null,
2929
string[] return_elements = null,
3030
string name = null,

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ public static int len(object a)
9595
throw new NotImplementedException("len() not implemented for type: " + a.GetType());
9696
}
9797

98+
public static float min(float a, float b)
99+
=> Math.Min(a, b);
100+
98101
public static T[] list<T>(IEnumerable<T> list)
99102
=> list.ToArray();
100103

src/TensorFlowNET.Core/Framework/importer.py.cs renamed to src/TensorFlowNET.Core/Framework/importer.cs

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,51 @@ public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
5454
input_map = _ConvertInputMapValues(name, input_map);
5555
});
5656

57+
TF_ImportGraphDefResults results = null;
5758
var bytes = graph_def.ToByteString().ToArray();
5859
using (var buffer = c_api_util.tf_buffer(bytes))
5960
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
6061
using (var status = new Status())
6162
{
6263
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
6364
// need to create a class ImportGraphDefWithResults with IDisposal
64-
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
65+
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
6566
status.Check(true);
66-
c_api.TF_DeleteImportGraphDefResults(results);
6767
}
6868

6969
_ProcessNewOps(graph);
7070

7171
if (return_elements == null)
7272
return null;
7373
else
74-
throw new NotImplementedException("import_graph_def return_elements");
74+
return _GatherReturnElements(return_elements, graph, results);
75+
}
76+
77+
private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
78+
Graph graph,
79+
TF_ImportGraphDefResults results)
80+
{
81+
var return_outputs = results.return_tensors;
82+
var return_opers = results.return_opers;
83+
84+
var combined_return_elements = new List<ITensorOrOperation>();
85+
int outputs_idx = 0;
86+
int opers_idx = 0;
87+
foreach(var name in requested_return_elements)
88+
{
89+
if (name.Contains(":"))
90+
{
91+
combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx]));
92+
outputs_idx += 1;
93+
}
94+
else
95+
{
96+
throw new NotImplementedException("_GatherReturnElements");
97+
// combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx]));
98+
}
99+
}
100+
101+
return combined_return_elements.ToArray();
75102
}
76103

77104
private static void _ProcessNewOps(Graph graph)
@@ -100,8 +127,29 @@ public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions option
100127

101128
foreach (var name in return_elements)
102129
{
103-
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
130+
if(name.Contains(":"))
131+
{
132+
var (op_name, index) = _ParseTensorName(name);
133+
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
134+
}
135+
else
136+
{
137+
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
138+
}
104139
}
140+
141+
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints);
142+
}
143+
144+
private static (string, int) _ParseTensorName(string tensor_name)
145+
{
146+
var components = tensor_name.Split(':');
147+
if (components.Length == 2)
148+
return (components[0], int.Parse(components[1]));
149+
else if (components.Length == 1)
150+
return (components[0], 0);
151+
else
152+
throw new ValueError($"Cannot convert {tensor_name} to a tensor name.");
105153
}
106154

107155
public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map)

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,12 @@ protected override void DisposeUnmanagedResources(IntPtr handle)
494494
c_api.TF_DeleteGraph(handle);
495495
}
496496

497+
public Tensor get_tensor_by_tf_output(TF_Output tf_output)
498+
{
499+
var op = _get_operation_by_tf_operation(tf_output.oper);
500+
return op.outputs[tf_output.index];
501+
}
502+
497503
/// <summary>
498504
/// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
499505
/// This method may be called concurrently from multiple threads.

src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,62 @@
33

44
namespace Tensorflow
55
{
6-
[StructLayout(LayoutKind.Sequential)]
7-
public struct TF_ImportGraphDefResults
6+
public class TF_ImportGraphDefResults : DisposableObject
87
{
9-
public IntPtr return_tensors;
10-
public IntPtr return_nodes;
8+
/*public IntPtr return_nodes;
119
public IntPtr missing_unused_key_names;
1210
public IntPtr missing_unused_key_indexes;
13-
public IntPtr missing_unused_key_names_data;
11+
public IntPtr missing_unused_key_names_data;*/
12+
13+
public TF_ImportGraphDefResults(IntPtr handle)
14+
{
15+
_handle = handle;
16+
}
17+
18+
public TF_Output[] return_tensors
19+
{
20+
get
21+
{
22+
IntPtr return_output_handle = IntPtr.Zero;
23+
int num_outputs = -1;
24+
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
25+
TF_Output[] return_outputs = new TF_Output[num_outputs];
26+
unsafe
27+
{
28+
var tf_output_ptr = (TF_Output*)return_output_handle;
29+
for (int i = 0; i < num_outputs; i++)
30+
return_outputs[i] = *(tf_output_ptr + i);
31+
return return_outputs;
32+
}
33+
}
34+
}
35+
36+
public TF_Operation[] return_opers
37+
{
38+
get
39+
{
40+
return new TF_Operation[0];
41+
/*TF_Operation return_output_handle = new TF_Operation();
42+
int num_outputs = -1;
43+
c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle);
44+
TF_Operation[] return_outputs = new TF_Operation[num_outputs];
45+
unsafe
46+
{
47+
var tf_output_ptr = (TF_Operation*)return_output_handle;
48+
for (int i = 0; i < num_outputs; i++)
49+
return_outputs[i] = *(tf_output_ptr + i);
50+
return return_outputs;
51+
}*/
52+
}
53+
}
54+
55+
public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
56+
=> new TF_ImportGraphDefResults(handle);
57+
58+
public static implicit operator IntPtr(TF_ImportGraphDefResults results)
59+
=> results._handle;
60+
61+
protected override void DisposeUnmanagedResources(IntPtr handle)
62+
=> c_api.TF_DeleteImportGraphDefResults(handle);
1463
}
1564
}

src/TensorFlowNET.Core/Status/Status.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ public void Check(bool throwException = false)
6565
}
6666

6767
public static implicit operator IntPtr(Status status)
68-
{
69-
return status._handle;
70-
}
68+
=> status._handle;
7169

7270
protected override void DisposeUnmanagedResources(IntPtr handle)
7371
=> TF_DeleteStatus(handle);

0 commit comments

Comments
 (0)