@@ -45,7 +45,24 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
4545 switch ( op_ctxt )
4646 {
4747 case WhileContext cwhile :
48- throw new NotImplementedException ( "_SwitchGrad WhileContext" ) ;
48+ {
49+ var merge_grad = grad_ctxt . grad_state . switch_map . get ( op ) ;
50+ if ( merge_grad != null )
51+ {
52+ if ( grads [ 1 ] != null )
53+ control_flow_ops . _AddNextAndBackEdge ( merge_grad , grads [ 1 ] ,
54+ enforce_shape_invariant : false ) ;
55+ return new Tensor [ ] { null , null } ;
56+ }
57+ else if ( grads [ 0 ] != null )
58+ {
59+ merge_grad = merge ( new [ ] { grads [ 0 ] , grads [ 0 ] } , name : "b_switch" ) [ 0 ] ;
60+ grad_ctxt . grad_state . switch_map [ op ] = merge_grad ;
61+ return new Tensor [ ] { merge_grad , null } ;
62+ }
63+ else
64+ return new Tensor [ ] { null , null } ;
65+ }
4966 case CondContext ccond :
5067 {
5168 var zero_grad = grads [ 1 - op_ctxt . branch ] ;
@@ -74,7 +91,7 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
7491 /// <param name="inputs"></param>
7592 /// <param name="name"></param>
7693 /// <returns></returns>
77- internal static Tensor [ ] merge ( Tensor [ ] inputs , string name = null )
94+ internal static MergeOutput merge ( Tensor [ ] inputs , string name = null )
7895 {
7996 return tf_with ( ops . name_scope ( name , "Merge" , inputs ) , scope =>
8097 {
@@ -146,7 +163,7 @@ public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
146163 }
147164
148165 [ RegisterGradient ( "RefMerge" ) ]
149- public Tensor [ ] _RefMergeGrad ( Operation op , Tensor [ ] grads )
166+ public static Tensor [ ] _RefMergeGrad ( Operation op , Tensor [ ] grads )
150167 {
151168 return _MergeGrad ( op , grads ) ;
152169 }
@@ -155,43 +172,32 @@ public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
155172 /// Gradients for an exit op are calculated using an Enter op.
156173 /// </summary>
157174 [ RegisterGradient ( "Exit" ) ]
158- public Tensor [ ] _ExitGrad ( Operation op , Tensor [ ] grads )
175+ public static Tensor [ ] _ExitGrad ( Operation op , Tensor [ ] grads )
159176 {
160- throw new NotImplementedException ( "_ExitGrad" ) ;
161- // graph = ops.get_default_graph()
162- //# pylint: disable=protected-access
163- // op_ctxt = op._get_control_flow_context()
164- // grad_ctxt = graph._get_control_flow_context()
165- // # pylint: enable=protected-access
166- // if not grad_ctxt.back_prop:
167- // # The flag `back_prop` is set by users to suppress gradient
168- // # computation for this loop. If the attribute `back_prop` is false,
169- // # no gradient computation.
170- // return None
177+ var grad = grads [ 0 ] ;
178+ var graph = ops . get_default_graph ( ) ;
179+ var op_ctxt = op . _get_control_flow_context ( ) ;
180+ var grad_ctxt = graph . _get_control_flow_context ( ) as WhileContext ;
181+ // The flag `back_prop` is set by users to suppress gradient
182+ // computation for this loop. If the attribute `back_prop` is false,
183+ // no gradient computation.
184+ if ( ! grad_ctxt . back_prop )
185+ return null ;
171186
172- // if op_ctxt.grad_state:
173- // raise TypeError("Second-order gradient for while loops not supported.")
187+ if ( op_ctxt . grad_state != null )
188+ throw new TypeError ( "Second-order gradient for while loops not supported." ) ;
174189
175- // if isinstance(grad, ops.Tensor) :
176- // grad_ctxt.AddName(grad.name)
177- // else:
178- // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
179- // raise TypeError("Type %s not supported" % type(grad))
180- // grad_ctxt.AddName(grad.values.name)
181- // grad_ctxt.AddName(grad.indices.name)
182- // dense_shape = grad.dense_shape
183- // if dense_shape is not None:
184- // grad_ctxt.AddName(dense_shape.name)
185- // grad_ctxt.Enter()
186- // # pylint: disable=protected-access
187- // result = control_flow_ops._Enter(
188- // grad, grad_ctxt.name, is_constant=False,
189- // parallel_iterations=grad_ctxt.parallel_iterations,
190- // name="b_exit")
191- // # pylint: enable=protected-access
192- // grad_ctxt.loop_enters.append(result)
193- // grad_ctxt.Exit()
194- // return result
190+ grad_ctxt . AddName ( grad . name ) ;
191+
192+ grad_ctxt . Enter ( ) ;
193+ var result = control_flow_ops . _Enter (
194+ grad , grad_ctxt . name , is_constant : false ,
195+ parallel_iterations : grad_ctxt . parallel_iterations ,
196+ name : "b_exit" ) ;
197+
198+ grad_ctxt . loop_enters . append ( result ) ;
199+ grad_ctxt . Exit ( ) ;
200+ return new [ ] { result } ;
195201 }
196202
197203 /// <summary>
@@ -200,15 +206,15 @@ public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
200206 /// Note that the backprop next_iteration is added in switch grad.
201207 /// </summary>
202208 [ RegisterGradient ( "NextIteration" ) ]
203- public Tensor [ ] _NextIterationGrad ( object _ , Tensor [ ] grad )
209+ public static Tensor [ ] _NextIterationGrad ( Operation op , Tensor [ ] grads )
204210 {
205- return grad ;
211+ return grads ;
206212 }
207213
208214 [ RegisterGradient ( "RefNextIteration" ) ]
209- public Tensor [ ] _RefNextIterationGrad ( object _ , Tensor [ ] grad )
215+ public static Tensor [ ] _RefNextIterationGrad ( Operation op , Tensor [ ] grads )
210216 {
211- return grad ;
217+ return grads ;
212218 }
213219
214220 /// <summary>
@@ -218,33 +224,31 @@ public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
218224 /// For loop invariants, we need to add an accumulator loop.
219225 /// </summary>
220226 [ RegisterGradient ( "Enter" ) ]
221- public Tensor [ ] _EnterGrad ( Tensor op , Tensor [ ] grad )
227+ public static Tensor [ ] _EnterGrad ( Operation op , Tensor [ ] grads )
222228 {
223- throw new NotImplementedException ( "_EnterGrad" ) ;
224- // graph = ops.get_default_graph()
225- //# pylint: disable=protected-access
226- // grad_ctxt = graph._get_control_flow_context()
227- // # pylint: enable=protected-access
228- // if not grad_ctxt.back_prop:
229- // # Skip gradient computation, if the attribute `back_prop` is false.
230- // return grad
231- // if grad_ctxt.grad_state is None:
232- // # Pass the gradient through if we are not in a gradient while context.
233- // return grad
234- // if op.get_attr("is_constant"):
235- // # Add a gradient accumulator for each loop invariant.
236- // if isinstance(grad, ops.Tensor) :
237- // result = grad_ctxt.AddBackpropAccumulator(op, grad)
238- // elif isinstance(grad, ops.IndexedSlices) :
239- // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
240- // else:
241- // # TODO(yuanbyu, lukasr): Add support for SparseTensor.
242- // raise TypeError("Type %s not supported" % type(grad))
243- // else:
244- // result = exit(grad)
245- // grad_ctxt.loop_exits.append(result)
246- // grad_ctxt.ExitResult([result])
247- // return result
229+ Tensor result = null ;
230+ var grad = grads [ 0 ] ;
231+ var graph = ops . get_default_graph ( ) ;
232+ var grad_ctxt = graph . _get_control_flow_context ( ) as WhileContext ;
233+ if ( ! grad_ctxt . back_prop )
234+ // Skip gradient computation, if the attribute `back_prop` is false.
235+ return grads ;
236+ if ( grad_ctxt . grad_state == null )
237+ // Pass the gradient through if we are not in a gradient while context.
238+ return grads ;
239+ if ( op . get_attr < bool > ( "is_constant" ) )
240+ {
241+ // Add a gradient accumulator for each loop invariant.
242+ result = grad_ctxt . AddBackpropAccumulator ( op , grad ) ;
243+ }
244+ else
245+ {
246+ result = control_flow_ops . exit ( grad ) ;
247+ grad_ctxt . loop_exits . append ( result ) ;
248+ grad_ctxt . ExitResult ( new [ ] { result } ) ;
249+ }
250+
251+ return new Tensor [ ] { result } ;
248252 }
249253
250254
0 commit comments