Skip to content

Commit 6a820d8

Browse files
committed
changes in CondContext and control_flow_ops.cond
1 parent 40af0c5 commit 6a820d8

7 files changed

Lines changed: 337 additions & 140 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace Tensorflow
88
{
99
public partial class Graph
1010
{
11+
// Current control flow context. It could be either CondContext or WhileContext
1112
public IControlFlowContext _control_flow_context;
1213

1314
// represents the nested with(...) statements

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,41 @@ public CondContext(Tensor pred,
6464
}
6565
}
6666

67+
/// <summary>
68+
/// Add the subgraph defined by fn() to the graph.
69+
/// </summary>
6770
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
6871
{
6972
// Add the subgraph defined by fn() to the graph.
7073
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
7174
var original_result = fn();
7275
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
7376

77+
//TODO: port this chunck of missing code:
78+
/*
79+
if len(post_summaries) > len(pre_summaries):
80+
new_summaries = post_summaries[len(pre_summaries):]
81+
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
82+
summary_ref[:] = pre_summaries
83+
with ops.control_dependencies(new_summaries):
84+
if original_result is None:
85+
return no_op(), None
86+
else:
87+
original_result = nest.map_structure(array_ops.identity,
88+
original_result)
89+
*/
90+
if (original_result == null)
91+
return (original_result, null);
92+
7493
switch (original_result)
7594
{
76-
case Operation[] results:
95+
case Operation[] results:
96+
// Python code:
97+
// result = nest.map_structure(self._BuildCondTensor, original_result)
7798
return (original_result, _BuildCondTensor(results));
99+
case Tensor t:
100+
// TODO: should this be (original_result, t) instead?
101+
return (original_result, _BuildCondTensor(new []{t.op}));
78102
case float[] fv:
79103
var result = ops.convert_to_tensor(fv[0]);
80104
return (original_result, result );

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,24 @@
33
using System.Text;
44

55
namespace Tensorflow.Operations
6-
{
6+
{
7+
/// <summary>
8+
/// The base class for control flow context.
9+
///
10+
/// The usage pattern is a sequence of(Enter, Exit) followed by a final
11+
/// ExitResult.
12+
///
13+
/// We maintain the following state for control flow contexts during graph
14+
/// construction:
15+
/// 1. graph has _control_flow_context: the current context used to
16+
/// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit()
17+
/// 2. op has _control_flow_context: the context to which the op belongs.
18+
/// Set at the time the op is created.Immutable.
19+
/// 3. A ControlFlowContext has _outer_context: the context in which this
20+
/// context is created.Set at the time a context is created.Immutable.
21+
/// 4. A ControlFlowContext has _context_stack.
22+
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
23+
/// </summary>
724
public abstract class ControlFlowContext : IPython, IControlFlowContext
825
{
926
/// <summary>
@@ -17,6 +34,8 @@ public ControlFlowContext()
1734
_context_stack = new Stack<IControlFlowContext>();
1835
}
1936

37+
public string name { get; set; }
38+
2039
public void __init__()
2140
{
2241

@@ -26,13 +45,30 @@ public void __enter__()
2645
{
2746
}
2847

48+
public void __exit__()
49+
{
50+
}
51+
52+
/// <summary>
53+
/// Enter this control flow context.
54+
/// </summary>
2955
public virtual void Enter()
3056
{
3157
var graph = ops.get_default_graph();
3258
_context_stack.Push(graph._get_control_flow_context());
3359
graph._set_control_flow_context(this);
3460
}
3561

62+
/// <summary>
63+
/// Exit this control flow context.
64+
/// </summary>
65+
public virtual void Exit()
66+
{
67+
var graph = ops.get_default_graph();
68+
var last_context = _context_stack.Pop();
69+
graph._set_control_flow_context(last_context);
70+
}
71+
3672
public void AddOp(Operation op)
3773
{
3874
_AddOpInternal(op);
@@ -56,17 +92,6 @@ protected virtual void _RemoveExternalControlEdges(Operation op)
5692
var internal_control_inputs = op.control_inputs;
5793
}
5894

59-
public void Exit()
60-
{
61-
var graph = ops.get_default_graph();
62-
var last_context = _context_stack.Pop();
63-
graph._set_control_flow_context(last_context);
64-
}
65-
66-
public void __exit__()
67-
{
68-
}
69-
7095
public void Dispose()
7196
{
7297
}

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 122 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,50 @@ public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, stri
185185
ops.colocate_with(data, ignore_existing: true);
186186

187187
return @switch(data, pred, name: name);
188-
}
189-
188+
}
189+
190+
/// <summary>
191+
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
192+
///
193+
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
194+
/// `false_fn` must have the same non-zero number and type of outputs.
195+
///
196+
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
197+
/// `false_fn` will be executed regardless of which branch is selected at runtime.
198+
///
199+
/// Although this behavior is consistent with the dataflow model of TensorFlow,
200+
/// it has frequently surprised users who expected a lazier semantics.
201+
/// Consider the following simple program:
202+
///
203+
/// z = tf.multiply(a, b)
204+
/// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
205+
///
206+
/// If `x&lt;y`, the `tf.add` operation will be executed and `tf.square`
207+
/// operation will not be executed.Since `z` is needed for at least one
208+
/// branch of the `cond`, the `tf.multiply` operation is always executed,
209+
/// unconditionally.
210+
///
211+
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
212+
/// call to `cond`, and not at all during `Session.run()`). `cond`
213+
/// stitches together the graph fragments created during the `true_fn` and
214+
/// `false_fn` calls with some additional graph nodes to ensure that the right
215+
/// branch gets executed depending on the value of `pred`.
216+
///
217+
/// `tf.cond` supports nested structures as implemented in
218+
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
219+
/// same(possibly nested) value structure of lists, tuples, and/or named tuples.
220+
/// Singleton lists and tuples form the only exceptions to this: when returned by
221+
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
222+
/// This behavior is disabled by passing `strict= True`.
223+
/// </summary>
224+
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
225+
/// `false_fn`.</param>
226+
/// <param name="true_fn">The callable to be performed if pred is true.</param>
227+
/// <param name="false_fn">The callable to be performed if pred is false.</param>
228+
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
229+
/// <param name="name">Optional name prefix for the returned tensors.</param>
230+
/// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
231+
/// callables return a singleton list, the element is extracted from the list.</returns>
190232
public static Tensor cond(Tensor pred,
191233
Func<ITensorOrOperation> true_fn = null,
192234
Func<ITensorOrOperation> false_fn = null,
@@ -195,6 +237,37 @@ public static Tensor cond(Tensor pred,
195237
{
196238
return with(ops.name_scope(name, "cond", new { pred }), delegate
197239
{
240+
// TODO: here a chunk of original code is missing
241+
/*
242+
if fn1 is not None:
243+
if true_fn is not None:
244+
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
245+
true_fn = fn1
246+
elif true_fn is None:
247+
raise TypeError("cond(): true_fn argument required")
248+
if fn2 is not None:
249+
if false_fn is not None:
250+
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
251+
false_fn = fn2
252+
elif false_fn is None:
253+
raise TypeError("cond(): false_fn argument required")
254+
255+
if not callable(true_fn):
256+
raise TypeError("true_fn must be callable.")
257+
if not callable(false_fn):
258+
raise TypeError("false_fn must be callable.")
259+
260+
with ops.name_scope(name, "cond", [pred]):
261+
if context.executing_eagerly():
262+
if pred:
263+
return _UnpackIfSingleton(true_fn())
264+
return _UnpackIfSingleton(false_fn())
265+
266+
# Add the Switch to the graph.
267+
if isinstance(pred, bool):
268+
raise TypeError("pred must not be a Python bool")
269+
*/
270+
198271
// Add the Switch to the graph.
199272
var (p_2, p_1) = @switch(pred, pred);
200273
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
@@ -207,30 +280,63 @@ public static Tensor cond(Tensor pred,
207280

208281
// Build the graph for the true branch in a new context.
209282
var context_t = new CondContext(pred, pivot_1, branch: 1);
210-
context_t.Enter();
211-
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
212-
context_t.Exit();
213-
283+
ITensorOrOperation orig_res_t;
284+
Tensor res_t;
285+
try
286+
{
287+
context_t.Enter();
288+
(orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
289+
}
290+
finally
291+
{
292+
context_t.Exit();
293+
}
214294
// Build the graph for the false branch in a new context.
215295
var context_f = new CondContext(pred, pivot_2, branch: 0);
216-
context_f.Enter();
217-
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
218-
context_f.Exit();
296+
ITensorOrOperation orig_res_f;
297+
Tensor res_f;
298+
try
299+
{
300+
context_f.Enter();
301+
(orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
302+
}
303+
finally
304+
{
305+
context_f.Exit();
306+
}
219307

220-
var res_t_flat = res_t;
221-
var res_f_flat = res_f;
308+
//TODO: missing original code
309+
//if not strict:
310+
// orig_res_t = _UnpackIfSingleton(orig_res_t)
311+
// orig_res_f = _UnpackIfSingleton(orig_res_f)
312+
/*
313+
# Check that the return values of the two branches have the same structure.
314+
try:
315+
nest.assert_same_structure(orig_res_t, orig_res_f)
316+
except TypeError as e:
317+
raise TypeError(
318+
"Incompatible return types of true_fn and false_fn: {}".format(e))
319+
except ValueError as e:
320+
raise ValueError(
321+
"Incompatible return values of true_fn and false_fn: {}".format(e))
322+
323+
# Add the final merge to the graph.
324+
if not res_t:
325+
raise ValueError("true_fn and false_fn must return at least one result.
326+
*/
327+
var res_t_flat = new[] { res_t };
328+
var res_f_flat = new[] { res_f };
222329

223-
return new Tensor(IntPtr.Zero);
224-
/*var merges = zip(res_f_flat, res_t_flat)
330+
var merges = zip(res_f_flat, res_t_flat)
225331
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
226332
.ToArray();
227333

228-
merges = _convert_flows_to_tensorarrays(orig_res_t, merges);
229-
334+
merges = _convert_flows_to_tensorarrays(new [] { orig_res_t}, merges);
335+
230336
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
231337
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);
232338

233-
return merges;*/
339+
return merges[0];
234340
});
235341
}
236342

src/TensorFlowNET.Core/Python.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using NumSharp;
22
using System;
3+
using System.Collections;
34
using System.Collections.Generic;
45
using System.ComponentModel;
56
using System.Linq;
@@ -17,8 +18,8 @@ protected void print(object obj)
1718
Console.WriteLine(obj.ToString());
1819
}
1920

20-
protected int len(Array a)
21-
=> a.Length;
21+
protected int len<T>(IEnumerable<T> a)
22+
=> a.Count();
2223

2324
protected IEnumerable<int> range(int end)
2425
{

0 commit comments

Comments
 (0)