Skip to content

Commit 40af0c5

Browse files
committed
Graph.unique_name fixed + test case
1 parent fe4a06f commit 40af0c5

4 files changed

Lines changed: 106 additions & 85 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ private Tensor _as_graph_element(object obj)
7373
return var._as_graph_element();
7474

7575
return null;
76-
}
77-
76+
}
77+
7878
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
7979
{
8080
string types_str = "";
@@ -99,15 +99,15 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
9999
// If obj appears to be a name...
100100
if (obj is string name)
101101
{
102-
if(name.Contains(":") && allow_tensor)
102+
if (name.Contains(":") && allow_tensor)
103103
{
104104
string op_name = name.Split(':')[0];
105105
int out_n = int.Parse(name.Split(':')[1]);
106106

107107
if (_nodes_by_name.ContainsKey(op_name))
108108
return _nodes_by_name[op_name].outputs[out_n];
109109
}
110-
else if(!name.Contains(":") & allow_operation)
110+
else if (!name.Contains(":") & allow_operation)
111111
{
112112
if (!_nodes_by_name.ContainsKey(name))
113113
throw new KeyError($"The name {name} refers to an Operation not in the graph.");
@@ -166,8 +166,8 @@ private void _check_not_finalized()
166166
throw new RuntimeError("Graph is finalized and cannot be modified.");
167167
}
168168

169-
public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
170-
TF_DataType[] input_types = null, string name = null,
169+
public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
170+
TF_DataType[] input_types = null, string name = null,
171171
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
172172
{
173173
if (inputs == null)
@@ -188,7 +188,7 @@ public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[]
188188
var input_ops = inputs.Select(x => x.op).ToArray();
189189
var control_inputs = _control_dependencies_for_inputs(input_ops);
190190

191-
var op = new Operation(node_def,
191+
var op = new Operation(node_def,
192192
this,
193193
inputs: inputs,
194194
output_types: dtypes,
@@ -259,54 +259,61 @@ public string name_scope(string name)
259259
_name_stack = new_stack;
260260

261261
return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
262-
}
263-
262+
}
263+
264+
/// <summary>
265+
/// Return a unique operation name for `name`.
266+
///
267+
/// Note: You rarely need to call `unique_name()` directly.Most of
268+
/// the time you just need to create `with g.name_scope()` blocks to
269+
/// generate structured names.
270+
///
271+
/// `unique_name` is used to generate structured names, separated by
272+
/// `"/"`, to help identify operations when debugging a graph.
273+
/// Operation names are displayed in error messages reported by the
274+
/// TensorFlow runtime, and in various visualization tools such as
275+
/// TensorBoard.
276+
///
277+
/// If `mark_as_used` is set to `True`, which is the default, a new
278+
/// unique name is created and marked as in use.If it's set to `False`,
279+
/// the unique name is returned without actually being marked as used.
280+
/// This is useful when the caller simply wants to know what the name
281+
/// to be created will be.
282+
/// </summary>
283+
/// <param name="name">The name for an operation.</param>
284+
/// <param name="mark_as_used"> Whether to mark this name as being used.</param>
285+
/// <returns>A string to be passed to `create_op()` that will be used
286+
/// to name the operation being created.</returns>
264287
public string unique_name(string name, bool mark_as_used = true)
265288
{
266289
if (!String.IsNullOrEmpty(_name_stack))
267-
{
268290
name = _name_stack + "/" + name;
269-
}
270-
291+
// For the sake of checking for names in use, we treat names as case
292+
// insensitive (e.g. foo = Foo).
271293
var name_key = name.ToLower();
272294
int i = 0;
273295
if (_names_in_use.ContainsKey(name_key))
274-
{
275-
foreach (var item in _names_in_use)
276-
{
277-
if (item.Key == name_key)
278-
{
279-
i = _names_in_use[name_key];
280-
break;
281-
}
282-
283-
i++;
284-
}
285-
}
286-
296+
i = _names_in_use[name_key];
297+
// Increment the number for "name_key".
287298
if (mark_as_used)
288-
if (_names_in_use.ContainsKey(name_key))
289-
_names_in_use[name_key]++;
290-
else
291-
_names_in_use[name_key] = i + 1;
292-
299+
_names_in_use[name_key] = i + 1;
293300
if (i > 0)
294301
{
295-
var base_name_key = name_key;
296-
297302
// Make sure the composed name key is not already used.
298-
if (_names_in_use.ContainsKey(name_key))
303+
var base_name_key = name_key;
304+
while (_names_in_use.ContainsKey(name_key))
299305
{
300306
name_key = $"{base_name_key}_{i}";
301307
i += 1;
302308
}
303-
309+
// Mark the composed name_key as used in case someone wants
310+
// to call unique_name("name_1").
304311
if (mark_as_used)
305312
_names_in_use[name_key] = 1;
306313

307-
name = $"{name}_{i - 1}";
314+
// Return the new name with the original capitalization of the given name.
315+
name = $"{name}_{i-1}";
308316
}
309-
310317
return name;
311318
}
312319

@@ -375,8 +382,8 @@ public void prevent_feeding(Tensor tensor)
375382
public void prevent_fetching(Operation op)
376383
{
377384
_unfetchable_ops.Add(op);
378-
}
379-
385+
}
386+
380387
public void Dispose()
381388
{
382389
c_api.TF_DeleteGraph(_handle);
@@ -387,8 +394,8 @@ public void __enter__()
387394
}
388395

389396
public void __exit__()
390-
{
391-
397+
{
398+
392399
}
393400

394401
public static implicit operator IntPtr(Graph graph)

test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ public void TestNested()
157157
});
158158
});
159159
});
160-
AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
161-
AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
160+
assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
161+
assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
162162
}
163163

164164
[TestMethod]
@@ -200,12 +200,12 @@ public void TestClear()
200200
b_none2 = constant_op.constant(12.0);
201201
});
202202
});
203-
AssertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
204-
AssertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
205-
AssertItemsEqual(new object[0], b_none.op.control_inputs);
206-
AssertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
207-
AssertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
208-
AssertItemsEqual(new object[0], b_none2.op.control_inputs);
203+
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
204+
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
205+
assertItemsEqual(new object[0], b_none.op.control_inputs);
206+
assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
207+
assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
208+
assertItemsEqual(new object[0], b_none2.op.control_inputs);
209209
}
210210

211211
[TestMethod]
@@ -256,25 +256,25 @@ public void TestComplex()
256256
});
257257
});
258258

259-
AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
260-
AssertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
261-
AssertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
262-
AssertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
259+
assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
260+
assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
261+
assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
262+
assertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
263263

264-
AssertItemsEqual(new object[0], c_1.op.control_inputs);
265-
AssertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
266-
AssertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
267-
AssertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
264+
assertItemsEqual(new object[0], c_1.op.control_inputs);
265+
assertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
266+
assertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
267+
assertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
268268

269-
AssertItemsEqual(new object[0], d_1.op.control_inputs);
270-
AssertItemsEqual(new object[0], d_2.op.control_inputs);
271-
AssertItemsEqual(new object[0], d_3.op.control_inputs);
272-
AssertItemsEqual(new object[0], d_4.op.control_inputs);
269+
assertItemsEqual(new object[0], d_1.op.control_inputs);
270+
assertItemsEqual(new object[0], d_2.op.control_inputs);
271+
assertItemsEqual(new object[0], d_3.op.control_inputs);
272+
assertItemsEqual(new object[0], d_4.op.control_inputs);
273273

274-
AssertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
275-
AssertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
276-
AssertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
277-
AssertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
274+
assertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
275+
assertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
276+
assertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
277+
assertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
278278
}
279279

280280
[Ignore("Don't know how to create an operation with two outputs")]

test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,35 @@ public void TestShape()
3333
Assert.AreEqual("myop", op.name);
3434
Assert.AreEqual("Identity", op.type);
3535
Assert.AreEqual(1, len(op.outputs));
36-
AssertItemsEqual(new []{2, 3}, op.outputs[0].shape);
36+
assertItemsEqual(new []{2, 3}, op.outputs[0].shape);
3737
});
3838
}
3939

40-
/*def testUniqueName(self):
41-
g = ops.Graph()
42-
with g.as_default():
43-
c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
44-
c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
45-
op = g._create_op_from_tf_operation(c_op)
46-
op2 = g._create_op_from_tf_operation(c_op2)
47-
48-
# Create ops with same names as op1 and op2. We expect the new names to be
49-
# uniquified.
50-
op3 = test_ops.int_output(name="myop").op
51-
op4 = test_ops.int_output(name="myop_1").op
52-
53-
self.assertEqual(op.name, "myop")
54-
self.assertEqual(op2.name, "myop_1")
55-
self.assertEqual(op3.name, "myop_2")
56-
self.assertEqual(op4.name, "myop_1_1")
57-
40+
[TestMethod]
41+
public void TestUniqueName()
42+
{
43+
var graph = tf.Graph().as_default();
44+
with<Graph>(graph, g =>
45+
{
46+
//var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
47+
//var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
48+
//var op = g._create_op_from_tf_operation(c_op);
49+
//var op2 = g._create_op_from_tf_operation(c_op2);
50+
var op = constant_op.constant(0, name:"myop").op;
51+
var op2 = constant_op.constant(0, name: "myop_1").op;
52+
53+
// Create ops with same names as op1 and op2. We expect the new names to be
54+
// uniquified.
55+
var op3 = constant_op.constant(0, name: "myop").op;
56+
var op4 = constant_op.constant(0, name: "myop_1").op;
57+
58+
self.assertEqual(op.name, "myop");
59+
self.assertEqual(op2.name, "myop_1");
60+
self.assertEqual(op3.name, "myop_2");
61+
self.assertEqual(op4.name, "myop_1_1");
62+
});
63+
}
64+
/*
5865
@test_util.run_v1_only("b/120545219")
5966
def testCond(self):
6067
g = ops.Graph()
@@ -164,5 +171,5 @@ with ops.control_dependencies([c]):
164171
165172
166173
*/
167-
}
174+
}
168175
}

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
1313
/// </summary>
1414
public class PythonTest : Python
1515
{
16-
public void AssertItemsEqual(ICollection expected, ICollection given)
16+
public void assertItemsEqual(ICollection expected, ICollection given)
1717
{
1818
Assert.IsNotNull(expected);
1919
Assert.IsNotNull(given);
@@ -23,5 +23,12 @@ public void AssertItemsEqual(ICollection expected, ICollection given)
2323
for(int i=0; i<e.Length; i++)
2424
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
2525
}
26+
27+
public void assertEqual(object given, object expected)
28+
{
29+
Assert.AreEqual(expected, given);
30+
}
31+
32+
protected PythonTest self { get => this; }
2633
}
2734
}

0 commit comments

Comments
 (0)