@@ -15,8 +15,12 @@ limitations under the License.
1515******************************************************************************/
1616
1717using System ;
18+ using System . Collections . Generic ;
19+ using System . Linq ;
1820using Tensorflow . Operations . ControlFlows ;
21+ using Tensorflow . Util ;
1922using static Tensorflow . Python ;
23+ using static Tensorflow . control_flow_ops ;
2024
2125namespace Tensorflow . Operations
2226{
@@ -32,10 +36,14 @@ public class WhileContext : ControlFlowContext
3236 bool _swap_memory ;
3337 Tensor _pivot_for_pred ;
3438 Tensor _pivot_for_body ;
35- Tensor [ ] _loop_exits ;
36- Tensor [ ] _loop_enters ;
39+ List < Tensor > _loop_exits ;
40+ List < Tensor > _loop_enters ;
41+ Graph _graph ;
42+ public override GradLoopState grad_state => _grad_state ;
43+ public override bool back_prop => _back_prop ;
3744
38- public WhileContext ( int parallel_iterations = 10 ,
45+ public WhileContext ( int ? maximum_iterations = null ,
46+ int parallel_iterations = 10 ,
3947 bool back_prop = true ,
4048 bool swap_memory = false ,
4149 string name = "while_context" ,
@@ -49,12 +57,27 @@ public WhileContext(int parallel_iterations = 10,
4957 }
5058 else
5159 {
52-
60+ __init__ ( ) ;
61+ _init_from_args ( maximum_iterations , parallel_iterations , back_prop , swap_memory , name ) ;
5362 }
5463
5564 _grad_state = grad_state ;
5665 }
5766
67+ private void _init_from_args ( int ? maximum_iterations ,
68+ int parallel_iterations ,
69+ bool back_prop ,
70+ bool swap_memory ,
71+ string name )
72+ {
73+ _name = ops . get_default_graph ( ) . unique_name ( name ) ;
74+ _back_prop = back_prop ;
75+ _swap_memory = swap_memory ;
76+ _loop_exits = new List < Tensor > ( ) ;
77+ _loop_enters = new List < Tensor > ( ) ;
78+ _graph = ops . get_default_graph ( ) ;
79+ }
80+
5881 private void _init_from_proto ( WhileContextDef context_def , string import_scope = null )
5982 {
6083 var g = ops . get_default_graph ( ) ;
@@ -70,26 +93,156 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
7093 // The boolean tensor for loop termination condition.
7194 _pivot = g . as_graph_element ( ops . prepend_name_scope ( context_def . PivotName , import_scope ) ) as Tensor ;
7295 // The list of exit tensors for loop variables.
73- _loop_exits = new Tensor [ context_def . LoopExitNames . Count ] ;
96+ _loop_exits = new List < Tensor > ( ) ;
7497 foreach ( var ( i , exit_name ) in enumerate ( context_def . LoopExitNames ) )
75- _loop_exits [ i ] = g . as_graph_element ( ops . prepend_name_scope ( exit_name , import_scope ) ) as Tensor ;
98+ _loop_exits . Add ( g . as_graph_element ( ops . prepend_name_scope ( exit_name , import_scope ) ) as Tensor ) ;
7699 // The list of enter tensors for loop variables.
77- _loop_enters = new Tensor [ context_def . LoopEnterNames . Count ] ;
100+ _loop_enters = new List < Tensor > ( ) ;
78101 foreach ( var ( i , enter_name ) in enumerate ( context_def . LoopEnterNames ) )
79- _loop_enters [ i ] = g . as_graph_element ( ops . prepend_name_scope ( enter_name , import_scope ) ) as Tensor ;
102+ _loop_enters . Add ( g . as_graph_element ( ops . prepend_name_scope ( enter_name , import_scope ) ) as Tensor ) ;
80103
81104 __init__ ( values_def : context_def . ValuesDef , import_scope : import_scope ) ;
82105 }
83106
84- public override WhileContext GetWhileContext ( )
107+ /// <summary>
108+ /// Add the loop termination condition and body to the graph.
109+ /// </summary>
110+ public Tensor [ ] BuildLoop ( Func < Tensor , Tensor > pred ,
111+ Func < Tensor , Tensor > body ,
112+ Tensor [ ] loop_vars ,
113+ TensorShape shape_invariants ,
114+ bool return_same_structure )
85115 {
86- return this ;
116+ // Keep original_loop_vars to identify which are TensorArrays
117+ var original_loop_vars = loop_vars ;
118+ // Convert TensorArrays to their flow variables
119+ Enter ( ) ;
120+ var ( original_body_result , exit_vars ) = _BuildLoop (
121+ pred , body , original_loop_vars , loop_vars , shape_invariants ) ;
122+ Exit ( ) ;
123+
124+ var flat_result = original_body_result ;
125+
126+ var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_result , exit_vars ) ;
127+ var packed_exit_vars = nest . pack_sequence_as (
128+ structure : original_body_result ,
129+ flat_sequence : exit_vars_with_tensor_arrays ) ;
130+
131+ return packed_exit_vars as Tensor [ ] ;
87132 }
88133
134+ private ( Tensor [ ] , Tensor [ ] ) _BuildLoop ( Func < Tensor , Tensor > pred ,
135+ Func < Tensor , Tensor > body ,
136+ Tensor [ ] original_loop_vars ,
137+ Tensor [ ] loop_vars ,
138+ TensorShape shape_invariants )
139+ {
140+ var flat_loop_vars = original_loop_vars ;
89141
90- public override GradLoopState grad_state => _grad_state ;
142+ // Let the context know the loop variables so the loop variables
143+ // would be added in the outer contexts properly.
144+ _InitializeValues ( loop_vars ) ;
145+ var real_vars = loop_vars ;
146+ Tensor [ ] enter_vars = null ;
147+ tf_with ( ops . control_dependencies ( null ) , delegate
148+ {
149+ enter_vars = real_vars . Select ( x => _Enter ( x ,
150+ _name ,
151+ is_constant : false ,
152+ parallel_iterations : _parallel_iterations ,
153+ use_input_shape : shape_invariants == null ) )
154+ . ToArray ( ) ;
91155
92- public override bool back_prop => _back_prop ;
156+ foreach ( var x in enter_vars )
157+ {
158+ x . graph . prevent_feeding ( x ) ;
159+ if ( _outer_context != null )
160+ _outer_context . AddInnerOp ( x . op ) ;
161+ }
162+ } ) ;
163+
164+ // Finds the closest enclosing non-None control pivot.
165+ var outer_context = _outer_context ;
166+ while ( outer_context != null )
167+ {
168+
169+ }
170+
171+ _SetShapeInvariants ( real_vars , enter_vars , shape_invariants ) ;
172+
173+ // Fix the control inputs and control flow context of these enter ops.
174+ _FixControlInputsAndContext ( enter_vars ) ;
175+ _InitializeValues ( enter_vars ) ;
176+ _loop_enters = enter_vars . ToList ( ) ;
177+
178+ var merge_vars = enter_vars
179+ . Select ( x => merge ( new [ ] { x , x } ) )
180+ . ToArray ( ) ;
181+
182+ _pivot_for_pred = merge_vars [ 0 ] ;
183+
184+ // Build the graph for pred.
185+ var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_loop_vars , merge_vars ) ;
186+ // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
187+ var c = ops . convert_to_tensor ( pred ( merge_vars_with_tensor_arrays [ 0 ] ) ) ;
188+ _pivot = gen_control_flow_ops . loop_cond ( c , name : "LoopCond" ) ;
189+ var switch_vars = merge_vars . Select ( x => _SwitchRefOrTensor ( x , _pivot ) )
190+ . ToArray ( ) ;
191+
192+ // Build the graph for body.
193+ var vars_for_body = switch_vars . Select ( x => _Identity ( x [ 1 ] ) ) . ToArray ( ) ;
194+ // Convert TensorArray flow variables inside the context back into
195+ // their associated TensorArrays for calling the body.
196+ var packed_vars_for_body = _convert_flows_to_tensorarrays ( flat_loop_vars , vars_for_body ) ;
197+ var body_result = body ( packed_vars_for_body [ 0 ] ) ;
198+ var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
199+
200+ // Store body_result to keep track of TensorArrays returned by body
201+ var original_body_result = new [ ] { body_result } ;
202+ // Convert TensorArrays returned by body into their flow variables
203+ var result = new [ ] { body_result } ;
204+
205+ var next_vars = new List < Tensor > ( ) ;
206+ foreach ( var ( m , v ) in zip ( merge_vars , result ) )
207+ next_vars . Add ( _AddNextAndBackEdge ( m , v ) ) ;
208+
209+ // Add the exit ops.
210+ var exit_vars = switch_vars . Select ( x => exit ( x [ 0 ] ) ) . ToList ( ) ;
211+ _loop_exits = exit_vars ;
212+
213+ // Exit the loop.
214+ // ExitResult(exit_vars);
215+ return ( original_body_result , exit_vars . ToArray ( ) ) ;
216+ }
217+
218+ private void _FixControlInputsAndContext ( Tensor [ ] enters )
219+ {
220+ var graph = ops . get_default_graph ( ) ;
221+ foreach ( var e in enters )
222+ {
223+ var inp_op = e . op . inputs [ 0 ] . op ;
224+ var control_inputs = graph . _control_dependencies_for_inputs ( new [ ] { inp_op } ) ;
225+ // op for op in control_inputs if self._IsInOuterContext(op)
226+ var outer_control_inputs = control_inputs . Where ( x => _IsInOuterContext ( x . op ) )
227+ . Select ( x => x . op )
228+ . ToArray ( ) ;
229+ e . op . _set_control_flow_context ( this ) ;
230+ e . op . _add_control_inputs ( outer_control_inputs ) ;
231+ graph . _record_op_seen_by_control_dependencies ( e . op ) ;
232+ }
233+ }
234+
235+ private void _InitializeValues ( Tensor [ ] values )
236+ {
237+ _values = new HashSet < string > ( ) ;
238+ foreach ( var x in values )
239+ _values . Add ( x . name ) ;
240+ }
241+
242+ public override WhileContext GetWhileContext ( )
243+ {
244+ return this ;
245+ }
93246
94247 public WhileContext from_proto ( WhileContextDef proto , string import_scope )
95248 {
0 commit comments