Skip to content

Commit 3615dba

Browse files
authored
Merge pull request SciSharp#90 from Esther2013/master
add FinishOperation to OperationDescription
2 parents ccce438 + 119f0c5 commit 3615dba

7 files changed

Lines changed: 37 additions & 47 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.Operation.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,7 @@ public OpDef GetOpDef(string type)
2020

2121
public OperationDescription NewOperation(string opType, string opName)
2222
{
23-
OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName);
24-
return desc;
25-
26-
/*c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
27-
28-
Status.Check();
29-
30-
c_api.TF_SetAttrType(desc, "dtype", tensor.dtype);
31-
32-
var op = c_api.TF_FinishOperation(desc, Status);
33-
Status.Check();
34-
35-
return op;*/
23+
return c_api.TF_NewOperation(_handle, opType, opName);
3624
}
3725
}
3826
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ public object get_attr(string name)
145145
return ret;
146146
}
147147

148+
public NodeDef GetNodeDef()
149+
{
150+
using (var s = new Status())
151+
using (var buffer = new Buffer())
152+
{
153+
c_api.TF_OperationToNodeDef(_handle, buffer, s);
154+
s.Check();
155+
return NodeDef.Parser.ParseFrom(buffer);
156+
}
157+
}
158+
148159
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
149160
public static implicit operator IntPtr(Operation op) => op._handle;
150161

src/TensorFlowNET.Core/Operations/OperationDescription.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ public void AddInputList(params TF_Output[] inputs)
1818
c_api.TF_AddInputList(_handle, inputs, inputs.Length);
1919
}
2020

21+
public Operation FinishOperation(Status status)
22+
{
23+
return c_api.TF_FinishOperation(_handle, status);
24+
}
25+
2126
public static implicit operator OperationDescription(IntPtr handle)
2227
{
2328
return new OperationDescription(handle);

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ public static partial class c_api
232232
/// <param name="lengths"></param>
233233
/// <param name="num_values"></param>
234234
[DllImport(TensorFlowLibName)]
235-
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, string[] values, uint[] lengths, int num_values);
235+
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values);
236236

237237
[DllImport(TensorFlowLibName)]
238238
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);

test/TensorFlowNET.UnitTest/CApiColocationTest.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,34 +30,32 @@ public void SetUp()
3030
s_.Check();
3131
constant_ = c_test_util.ScalarConst(10, graph_, s_);
3232
s_.Check();
33-
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add");
34-
s_.Check();
3533

34+
desc_ = graph_.NewOperation("AddN", "add");
3635
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) };
3736
desc_.AddInputList(inputs);
38-
s_.Check();
3937
}
4038

4139
private void SetViaStringList(OperationDescription desc, string[] list)
4240
{
43-
string[] list_ptrs = new string[list.Length];
44-
uint[] list_lens = new uint[list.Length];
41+
var list_ptrs = new IntPtr[list.Length];
42+
var list_lens = new uint[list.Length];
4543
StringVectorToArrays(list, list_ptrs, list_lens);
4644
c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length);
4745
}
4846

49-
private void StringVectorToArrays(string[] v, string[] ptrs, uint[] lens)
47+
private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens)
5048
{
5149
for (int i = 0; i < v.Length; ++i)
5250
{
53-
ptrs[i] = v[i];// Marshal.StringToHGlobalAnsi(v[i]);
51+
ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]);
5452
lens[i] = (uint)v[i].Length;
5553
}
5654
}
5755

5856
private void FinishAndVerify(OperationDescription desc, string[] expected)
5957
{
60-
Operation op = c_api.TF_FinishOperation(desc_, s_);
58+
var op = desc_.FinishOperation(s_);
6159
ASSERT_EQ(TF_Code.TF_OK, s_.Code);
6260
VerifyCollocation(op, expected);
6361
}

test/TensorFlowNET.UnitTest/GraphTest.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public void Graph()
130130
EXPECT_EQ(TF_Code.TF_OK, s.Code);
131131

132132
// Serialize to NodeDef.
133-
var node_def = c_test_util.GetNodeDef(neg);
133+
var node_def = neg.GetNodeDef();
134134

135135
// Validate NodeDef is what we expect.
136136
ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
@@ -145,13 +145,13 @@ public void Graph()
145145
// Look up some nodes by name.
146146
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
147147
EXPECT_EQ(neg, neg2);
148-
var node_def2 = c_test_util.GetNodeDef(neg2);
148+
var node_def2 = neg2.GetNodeDef();
149149
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
150150

151151
Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
152152
EXPECT_EQ(feed, feed2);
153-
node_def = c_test_util.GetNodeDef(feed);
154-
node_def2 = c_test_util.GetNodeDef(feed2);
153+
node_def = feed.GetNodeDef();
154+
node_def2 = feed2.GetNodeDef();
155155
EXPECT_EQ(node_def.ToString(), node_def2.ToString());
156156

157157
// Test iterating through the nodes of a graph.
@@ -162,7 +162,7 @@ public void Graph()
162162
uint pos = 0;
163163
Operation oper;
164164

165-
while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
165+
while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
166166
{
167167
if (oper.Equals(feed))
168168
{
@@ -186,7 +186,7 @@ public void Graph()
186186
}
187187
else
188188
{
189-
node_def = c_test_util.GetNodeDef(oper);
189+
node_def = oper.GetNodeDef();
190190
Assert.Fail($"Unexpected Node: {node_def.ToString()}");
191191
}
192192
}
@@ -256,7 +256,7 @@ public void ImportGraphDef()
256256
EXPECT_EQ(0, neg.GetControlInputs().Length);
257257
EXPECT_EQ(0, neg.NumControlOutputs);
258258
EXPECT_EQ(0, neg.GetControlOutputs().Length);
259-
259+
260260
// Import it again, with an input mapping, return outputs, and a return
261261
// operation, into the same graph.
262262
c_api.TF_DeleteImportGraphDefOptions(opts);
@@ -270,7 +270,7 @@ public void ImportGraphDef()
270270
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
271271
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
272272
EXPECT_EQ(TF_Code.TF_OK, s.Code);
273-
273+
274274
Operation scalar2 = graph.OperationByName("imported2/scalar");
275275
Operation feed2 = graph.OperationByName("imported2/feed");
276276
Operation neg2 = graph.OperationByName("imported2/neg");
@@ -287,7 +287,7 @@ public void ImportGraphDef()
287287
EXPECT_EQ(0, return_outputs[0].index);
288288
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
289289
EXPECT_EQ(0, return_outputs[1].index);
290-
290+
291291
// Check return operation
292292
var return_opers = graph.ReturnOperations(results);
293293
ASSERT_EQ(1, return_opers.Length);
@@ -302,26 +302,26 @@ public void ImportGraphDef()
302302
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
303303
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
304304
EXPECT_EQ(TF_Code.TF_OK, s.Code);
305-
305+
306306
var scalar3 = graph.OperationByName("imported3/scalar");
307307
var feed3 = graph.OperationByName("imported3/feed");
308308
var neg3 = graph.OperationByName("imported3/neg");
309309
ASSERT_TRUE(scalar3 != IntPtr.Zero);
310310
ASSERT_TRUE(feed3 != IntPtr.Zero);
311311
ASSERT_TRUE(neg3 != IntPtr.Zero);
312-
312+
313313
// Check that newly-imported scalar and feed have control deps (neg3 will
314314
// inherit them from input)
315315
var control_inputs = scalar3.GetControlInputs();
316316
ASSERT_EQ(2, scalar3.NumControlInputs);
317317
EXPECT_EQ(feed, control_inputs[0]);
318318
EXPECT_EQ(feed2, control_inputs[1]);
319-
319+
320320
control_inputs = feed3.GetControlInputs();
321321
ASSERT_EQ(2, feed3.NumControlInputs);
322322
EXPECT_EQ(feed, control_inputs[0]);
323323
EXPECT_EQ(feed2, control_inputs[1]);
324-
324+
325325
// Export to a graph def so we can import a graph with control dependencies
326326
graph_def.Dispose();
327327
graph_def = new Buffer();

test/TensorFlowNET.UnitTest/c_test_util.cs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,6 @@ public static GraphDef GetGraphDef(Graph graph)
5151
return def;
5252
}
5353

54-
public static NodeDef GetNodeDef(Operation oper)
55-
{
56-
var s = new Status();
57-
var buffer = new Buffer();
58-
c_api.TF_OperationToNodeDef(oper, buffer, s);
59-
s.Check();
60-
var ret = NodeDef.Parser.ParseFrom(buffer);
61-
buffer.Dispose();
62-
s.Dispose();
63-
return ret;
64-
}
65-
6654
public static bool IsAddN(NodeDef node_def, int n)
6755
{
6856
if (node_def.Op != "AddN" || node_def.Name != "add" ||

0 commit comments

Comments
 (0)