Skip to content

Commit f7e61b0

Browse files
sharwellOceania2018
authored andcommitted
Implement SafeImportGraphDefResultsHandle as a wrapper for TF_ImportGraphDefResults
1 parent 09600a8 commit f7e61b0

7 files changed

Lines changed: 103 additions & 47 deletions

File tree

src/TensorFlowNET.Core/Framework/importer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
6262
{
6363
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
6464
// need to create a class ImportGraphDefWithResults with IDisposal
65-
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle);
65+
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle));
6666
status.Check(true);
6767
}
6868

src/TensorFlowNET.Core/Graphs/Graph.Operation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public OperationDescription NewOperation(string opType, string opName)
3333
return c_api.TF_NewOperation(_handle, opType, opName);
3434
}
3535

36-
public Operation[] ReturnOperations(IntPtr results)
36+
public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results)
3737
{
3838
TF_Operation return_oper_handle = new TF_Operation();
3939
int num_return_opers = 0;

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ public string unique_name(string name, bool mark_as_used = true)
413413
return name;
414414
}
415415

416-
public TF_Output[] ReturnOutputs(IntPtr results)
416+
public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results)
417417
{
418418
IntPtr return_output_handle = IntPtr.Zero;
419419
int num_return_outputs = 0;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using Tensorflow.Util;
19+
20+
namespace Tensorflow
21+
{
22+
public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle
23+
{
24+
private SafeImportGraphDefResultsHandle()
25+
{
26+
}
27+
28+
public SafeImportGraphDefResultsHandle(IntPtr handle)
29+
: base(handle)
30+
{
31+
}
32+
33+
protected override bool ReleaseHandle()
34+
{
35+
c_api.TF_DeleteImportGraphDefResults(handle);
36+
SetHandle(IntPtr.Zero);
37+
return true;
38+
}
39+
}
40+
}

src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
1-
using System;
2-
using System.Runtime.InteropServices;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
318

419
namespace Tensorflow
520
{
6-
public class TF_ImportGraphDefResults : DisposableObject
21+
public sealed class TF_ImportGraphDefResults : IDisposable
722
{
823
/*public IntPtr return_nodes;
924
public IntPtr missing_unused_key_names;
1025
public IntPtr missing_unused_key_indexes;
1126
public IntPtr missing_unused_key_names_data;*/
1227

13-
public TF_ImportGraphDefResults(IntPtr handle)
28+
private SafeImportGraphDefResultsHandle Handle { get; }
29+
30+
public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle)
1431
{
15-
_handle = handle;
32+
Handle = handle;
1633
}
1734

1835
public TF_Output[] return_tensors
@@ -21,7 +38,7 @@ public TF_Output[] return_tensors
2138
{
2239
IntPtr return_output_handle = IntPtr.Zero;
2340
int num_outputs = -1;
24-
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
41+
c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle);
2542
TF_Output[] return_outputs = new TF_Output[num_outputs];
2643
unsafe
2744
{
@@ -52,13 +69,7 @@ public TF_Operation[] return_opers
5269
}
5370
}
5471

55-
public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
56-
=> new TF_ImportGraphDefResults(handle);
57-
58-
public static implicit operator IntPtr(TF_ImportGraphDefResults results)
59-
=> results._handle;
60-
61-
protected override void DisposeUnmanagedResources(IntPtr handle)
62-
=> c_api.TF_DeleteImportGraphDefResults(handle);
72+
public void Dispose()
73+
=> Handle.Dispose();
6374
}
6475
}

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public partial class c_api
9292
/// <param name="status">TF_Status*</param>
9393
/// <returns>TF_ImportGraphDefResults*</returns>
9494
[DllImport(TensorFlowLibName)]
95-
public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
95+
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
9696

9797
/// <summary>
9898
/// Import the graph serialized in `graph_def` into `graph`.
@@ -258,7 +258,7 @@ public partial class c_api
258258
/// <param name="num_opers">int*</param>
259259
/// <param name="opers">TF_Operation***</param>
260260
[DllImport(TensorFlowLibName)]
261-
public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref TF_Operation opers);
261+
public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers);
262262

263263
/// <summary>
264264
/// Fetches the return outputs requested via
@@ -270,7 +270,7 @@ public partial class c_api
270270
/// <param name="num_outputs">int*</param>
271271
/// <param name="outputs">TF_Output**</param>
272272
[DllImport(TensorFlowLibName)]
273-
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs);
273+
public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs);
274274

275275
/// <summary>
276276
/// This function creates a new TF_Session (which is created on success) using

test/TensorFlowNET.UnitTest/GraphTest.cs

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -258,44 +258,49 @@ public void ImportGraphDef()
258258
EXPECT_EQ(0, neg.NumControlOutputs);
259259
EXPECT_EQ(0, neg.GetControlOutputs().Length);
260260

261-
// Import it again, with an input mapping, return outputs, and a return
262-
// operation, into the same graph.
263-
IntPtr results;
264-
using (var opts = c_api.TF_NewImportGraphDefOptions())
261+
static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar)
265262
{
263+
using var opts = c_api.TF_NewImportGraphDefOptions();
266264
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
267265
c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
268266
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
269267
c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
270268
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
271269
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
272270
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
273-
results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
271+
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
274272
EXPECT_EQ(TF_Code.TF_OK, s.Code);
275-
}
276-
277-
Operation scalar2 = graph.OperationByName("imported2/scalar");
278-
Operation feed2 = graph.OperationByName("imported2/feed");
279-
Operation neg2 = graph.OperationByName("imported2/neg");
280-
281-
// Check input mapping
282-
neg_input = neg.Input(0);
283-
EXPECT_EQ(scalar, neg_input.oper);
284-
EXPECT_EQ(0, neg_input.index);
285273

286-
// Check return outputs
287-
var return_outputs = graph.ReturnOutputs(results);
288-
ASSERT_EQ(2, return_outputs.Length);
289-
EXPECT_EQ(feed2, return_outputs[0].oper);
290-
EXPECT_EQ(0, return_outputs[0].index);
291-
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
292-
EXPECT_EQ(0, return_outputs[1].index);
274+
return results;
275+
}
293276

294-
// Check return operation
295-
var return_opers = graph.ReturnOperations(results);
296-
ASSERT_EQ(1, return_opers.Length);
297-
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
298-
c_api.TF_DeleteImportGraphDefResults(results);
277+
// Import it again, with an input mapping, return outputs, and a return
278+
// operation, into the same graph.
279+
Operation feed2;
280+
using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar))
281+
{
282+
Operation scalar2 = graph.OperationByName("imported2/scalar");
283+
feed2 = graph.OperationByName("imported2/feed");
284+
Operation neg2 = graph.OperationByName("imported2/neg");
285+
286+
// Check input mapping
287+
neg_input = neg.Input(0);
288+
EXPECT_EQ(scalar, neg_input.oper);
289+
EXPECT_EQ(0, neg_input.index);
290+
291+
// Check return outputs
292+
var return_outputs = graph.ReturnOutputs(results);
293+
ASSERT_EQ(2, return_outputs.Length);
294+
EXPECT_EQ(feed2, return_outputs[0].oper);
295+
EXPECT_EQ(0, return_outputs[0].index);
296+
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
297+
EXPECT_EQ(0, return_outputs[1].index);
298+
299+
// Check return operation
300+
var return_opers = graph.ReturnOperations(results);
301+
ASSERT_EQ(1, return_opers.Length);
302+
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
303+
}
299304

300305
// Import again, with control dependencies, into the same graph.
301306
using (var opts = c_api.TF_NewImportGraphDefOptions())

0 commit comments

Comments
 (0)