@@ -185,8 +185,50 @@ public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, stri
185185 ops . colocate_with ( data , ignore_existing : true ) ;
186186
187187 return @switch ( data , pred , name : name ) ;
188- }
189-
188+ }
189+
190+ /// <summary>
191+ /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
192+ ///
193+ /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
194+ /// `false_fn` must have the same non-zero number and type of outputs.
195+ ///
196+ /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
197+ /// `false_fn` will be executed regardless of which branch is selected at runtime.
198+ ///
199+ /// Although this behavior is consistent with the dataflow model of TensorFlow,
200+ /// it has frequently surprised users who expected a lazier semantics.
201+ /// Consider the following simple program:
202+ ///
203+ /// z = tf.multiply(a, b)
204+ /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y))
205+ ///
206+ /// If `x<y`, the `tf.add` operation will be executed and `tf.square`
207+ /// operation will not be executed.Since `z` is needed for at least one
208+ /// branch of the `cond`, the `tf.multiply` operation is always executed,
209+ /// unconditionally.
210+ ///
211+ /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
212+ /// call to `cond`, and not at all during `Session.run()`). `cond`
213+ /// stitches together the graph fragments created during the `true_fn` and
214+ /// `false_fn` calls with some additional graph nodes to ensure that the right
215+ /// branch gets executed depending on the value of `pred`.
216+ ///
217+ /// `tf.cond` supports nested structures as implemented in
218+ /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
219+ /// same(possibly nested) value structure of lists, tuples, and/or named tuples.
220+ /// Singleton lists and tuples form the only exceptions to this: when returned by
221+ /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
222+ /// This behavior is disabled by passing `strict= True`.
223+ /// </summary>
224+ /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
225+ /// `false_fn`.</param>
226+ /// <param name="true_fn">The callable to be performed if pred is true.</param>
227+ /// <param name="false_fn">The callable to be performed if pred is false.</param>
228+ /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
229+ /// <param name="name">Optional name prefix for the returned tensors.</param>
230+ /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
231+ /// callables return a singleton list, the element is extracted from the list.</returns>
190232 public static Tensor cond ( Tensor pred ,
191233 Func < ITensorOrOperation > true_fn = null ,
192234 Func < ITensorOrOperation > false_fn = null ,
@@ -195,6 +237,37 @@ public static Tensor cond(Tensor pred,
195237 {
196238 return with ( ops . name_scope ( name , "cond" , new { pred } ) , delegate
197239 {
240+ // TODO: here a chunk of original code is missing
241+ /*
242+ if fn1 is not None:
243+ if true_fn is not None:
244+ raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
245+ true_fn = fn1
246+ elif true_fn is None:
247+ raise TypeError("cond(): true_fn argument required")
248+ if fn2 is not None:
249+ if false_fn is not None:
250+ raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
251+ false_fn = fn2
252+ elif false_fn is None:
253+ raise TypeError("cond(): false_fn argument required")
254+
255+ if not callable(true_fn):
256+ raise TypeError("true_fn must be callable.")
257+ if not callable(false_fn):
258+ raise TypeError("false_fn must be callable.")
259+
260+ with ops.name_scope(name, "cond", [pred]):
261+ if context.executing_eagerly():
262+ if pred:
263+ return _UnpackIfSingleton(true_fn())
264+ return _UnpackIfSingleton(false_fn())
265+
266+ # Add the Switch to the graph.
267+ if isinstance(pred, bool):
268+ raise TypeError("pred must not be a Python bool")
269+ */
270+
198271 // Add the Switch to the graph.
199272 var ( p_2 , p_1 ) = @switch ( pred , pred ) ;
200273 var pivot_1 = array_ops . identity ( p_1 , name : "switch_t" ) ;
@@ -207,30 +280,63 @@ public static Tensor cond(Tensor pred,
207280
208281 // Build the graph for the true branch in a new context.
209282 var context_t = new CondContext ( pred , pivot_1 , branch : 1 ) ;
210- context_t . Enter ( ) ;
211- var ( orig_res_t , res_t ) = context_t . BuildCondBranch ( true_fn ) ;
212- context_t . Exit ( ) ;
213-
283+ ITensorOrOperation orig_res_t ;
284+ Tensor res_t ;
285+ try
286+ {
287+ context_t . Enter ( ) ;
288+ ( orig_res_t , res_t ) = context_t . BuildCondBranch ( true_fn ) ;
289+ }
290+ finally
291+ {
292+ context_t . Exit ( ) ;
293+ }
214294 // Build the graph for the false branch in a new context.
215295 var context_f = new CondContext ( pred , pivot_2 , branch : 0 ) ;
216- context_f . Enter ( ) ;
217- var ( orig_res_f , res_f ) = context_f . BuildCondBranch ( false_fn ) ;
218- context_f . Exit ( ) ;
296+ ITensorOrOperation orig_res_f ;
297+ Tensor res_f ;
298+ try
299+ {
300+ context_f . Enter ( ) ;
301+ ( orig_res_f , res_f ) = context_f . BuildCondBranch ( false_fn ) ;
302+ }
303+ finally
304+ {
305+ context_f . Exit ( ) ;
306+ }
219307
220- var res_t_flat = res_t ;
221- var res_f_flat = res_f ;
308+ //TODO: missing original code
309+ //if not strict:
310+ // orig_res_t = _UnpackIfSingleton(orig_res_t)
311+ // orig_res_f = _UnpackIfSingleton(orig_res_f)
312+ /*
313+ # Check that the return values of the two branches have the same structure.
314+ try:
315+ nest.assert_same_structure(orig_res_t, orig_res_f)
316+ except TypeError as e:
317+ raise TypeError(
318+ "Incompatible return types of true_fn and false_fn: {}".format(e))
319+ except ValueError as e:
320+ raise ValueError(
321+ "Incompatible return values of true_fn and false_fn: {}".format(e))
322+
323+ # Add the final merge to the graph.
324+ if not res_t:
325+ raise ValueError("true_fn and false_fn must return at least one result.
326+ */
327+ var res_t_flat = new [ ] { res_t } ;
328+ var res_f_flat = new [ ] { res_f } ;
222329
223- return new Tensor ( IntPtr . Zero ) ;
224- /*var merges = zip(res_f_flat, res_t_flat)
330+ var merges = zip ( res_f_flat , res_t_flat )
225331 . Select ( pair => merge ( new Tensor [ ] { pair . Item1 , pair . Item2 } ) )
226332 . ToArray ( ) ;
227333
228- merges = _convert_flows_to_tensorarrays(orig_res_t, merges);
229-
334+ merges = _convert_flows_to_tensorarrays ( new [ ] { orig_res_t } , merges ) ;
335+
230336 ops . add_to_collection ( ops . GraphKeys . COND_CONTEXT , context_t ) ;
231337 ops . add_to_collection ( ops . GraphKeys . COND_CONTEXT , context_f ) ;
232338
233- return merges;*/
339+ return merges [ 0 ] ;
234340 } ) ;
235341 }
236342
0 commit comments