33using System . Linq ;
44using System . Text ;
55using Tensorflow . Operations . ControlFlows ;
6+ using static Tensorflow . Python ;
67
78namespace Tensorflow . Operations
89{
@@ -46,9 +47,9 @@ public CondContext(Tensor pred = null,
4647 if ( pred == null && context_def == null ) return ;
4748
4849 _name = ops . get_default_graph ( ) . unique_name ( name ) ;
49- if ( context_def != null )
50- {
51- _init_from_proto ( context_def , import_scope : import_scope ) ;
50+ if ( context_def != null )
51+ {
52+ _init_from_proto ( context_def , import_scope : import_scope ) ;
5253 }
5354 else
5455 {
@@ -66,16 +67,16 @@ public CondContext(Tensor pred = null,
6667 }
6768 }
6869
69- private void _init_from_proto ( CondContextDef context_def , string import_scope = null )
70- {
71- var g = ops . get_default_graph ( ) ;
72- _name = ops . prepend_name_scope ( context_def . ContextName , import_scope ) ;
73- var p1 = ops . prepend_name_scope ( context_def . PredName , import_scope ) ;
74- _pred = g . as_graph_element ( p1 ) as Tensor ;
75- var p2 = ops . prepend_name_scope ( context_def . PivotName , import_scope ) ;
76- _pivot = g . as_graph_element ( p2 ) as Tensor ;
77- _branch = context_def . Branch ;
78- __init__ ( values_def : context_def . ValuesDef , import_scope : import_scope ) ;
70+ private void _init_from_proto ( CondContextDef context_def , string import_scope = null )
71+ {
72+ var g = ops . get_default_graph ( ) ;
73+ _name = ops . prepend_name_scope ( context_def . ContextName , import_scope ) ;
74+ var p1 = ops . prepend_name_scope ( context_def . PredName , import_scope ) ;
75+ _pred = g . as_graph_element ( p1 ) as Tensor ;
76+ var p2 = ops . prepend_name_scope ( context_def . PivotName , import_scope ) ;
77+ _pivot = g . as_graph_element ( p2 ) as Tensor ;
78+ _branch = context_def . Branch ;
79+ __init__ ( values_def : context_def . ValuesDef , import_scope : import_scope ) ;
7980 }
8081
8182 /// <summary>
@@ -90,8 +91,8 @@ public override Tensor AddValue(Tensor val)
9091 // Use the real value if it comes from outer context. This is needed in
9192 // particular for nested conds.
9293 if ( _external_values . ContainsKey ( val . name ) )
93- result = _external_values [ val . name ] ;
94-
94+ result = _external_values [ val . name ] ;
95+
9596 result = result == null ? val : result ;
9697 }
9798 else
@@ -107,10 +108,10 @@ public override Tensor AddValue(Tensor val)
107108 }
108109
109110 with ( ops . control_dependencies ( null ) , ctrl =>
110- {
111- var results = control_flow_ops . _SwitchRefOrTensor ( result , _pred ) ;
112- result = results [ _branch ] ;
113- if ( _outer_context != null )
111+ {
112+ var results = control_flow_ops . _SwitchRefOrTensor ( result , _pred ) ;
113+ result = results [ _branch ] ;
114+ if ( _outer_context != null )
114115 _outer_context . AddInnerOp ( result . op ) ;
115116 } ) ;
116117
@@ -127,87 +128,87 @@ public override Tensor AddValue(Tensor val)
127128 }
128129 _external_values [ val . name ] = result ;
129130 }
130- return result ;
131- }
132-
131+ return result ;
132+ }
133+
133134 /// <summary>
134135 /// Add the subgraph defined by fn() to the graph.
135136 /// </summary>
136- public ( T , Tensor ) BuildCondBranch < T > ( Func < T > fn )
137- {
138- // Add the subgraph defined by fn() to the graph.
139- var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
140- var original_result = fn ( ) ;
141- var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
142-
143- //TODO: port this chunck of missing code:
144- /*
145- if len(post_summaries) > len(pre_summaries):
146- new_summaries = post_summaries[len(pre_summaries):]
147- summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
148- summary_ref[:] = pre_summaries
149- with ops.control_dependencies(new_summaries):
150- if original_result is None:
151- return no_op(), None
152- else:
153- original_result = nest.map_structure(array_ops.identity,
154- original_result)
155- */
156- if ( original_result == null )
157- return ( original_result , null ) ;
158-
159- switch ( original_result )
160- {
161- case Tensor result :
162- return ( original_result , _BuildCondTensor ( result ) ) ;
163- case Operation op :
164- return ( original_result , _BuildCondTensor ( op ) ) ;
137+ public ( T , Tensor ) BuildCondBranch < T > ( Func < T > fn )
138+ {
139+ // Add the subgraph defined by fn() to the graph.
140+ var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
141+ var original_result = fn ( ) ;
142+ var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
143+
144+ //TODO: port this chunck of missing code:
145+ /*
146+ if len(post_summaries) > len(pre_summaries):
147+ new_summaries = post_summaries[len(pre_summaries):]
148+ summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
149+ summary_ref[:] = pre_summaries
150+ with ops.control_dependencies(new_summaries):
151+ if original_result is None:
152+ return no_op(), None
153+ else:
154+ original_result = nest.map_structure(array_ops.identity,
155+ original_result)
156+ */
157+ if ( original_result == null )
158+ return ( original_result , null ) ;
159+
160+ switch ( original_result )
161+ {
162+ case Tensor result :
163+ return ( original_result , _BuildCondTensor ( result ) ) ;
164+ case Operation op :
165+ return ( original_result , _BuildCondTensor ( op ) ) ;
165166 case float [ ] fv :
166167 {
167168 var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
168169 return ( original_result , _BuildCondTensor ( result ) ) ;
169- }
170- default :
171- return ( original_result , null ) ;
172- }
173- }
174-
175- public ( T [ ] , Tensor [ ] ) BuildCondBranch < T > ( Func < T [ ] > fn )
176- {
177- // Add the subgraph defined by fn() to the graph.
178- var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
179- var original_result = fn ( ) ;
180- var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
181-
182- switch ( original_result )
183- {
184- case Tensor [ ] results :
185- return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
186- case Operation [ ] results :
187- return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
188- case float [ ] fv :
189- var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
190- return ( original_result , new Tensor [ ] { result } ) ;
191- default :
192- return ( original_result , new Tensor [ 0 ] ) ;
193- }
194- }
195-
196- private Tensor _BuildCondTensor ( ITensorOrOperation v )
197- {
198- switch ( v )
199- {
200- case Operation op :
201- // Use pivot as the proxy for this op.
202- return control_flow_ops . with_dependencies ( new Operation [ ] { op } , _pivot ) ;
203- case Tensor t :
204- return _ProcessOutputTensor ( t ) ;
205- default :
206- return _ProcessOutputTensor ( ops . convert_to_tensor ( v ) ) ;
207-
208- }
209- }
210-
170+ }
171+ default :
172+ return ( original_result , null ) ;
173+ }
174+ }
175+
176+ public ( T [ ] , Tensor [ ] ) BuildCondBranch < T > ( Func < T [ ] > fn )
177+ {
178+ // Add the subgraph defined by fn() to the graph.
179+ var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
180+ var original_result = fn ( ) ;
181+ var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
182+
183+ switch ( original_result )
184+ {
185+ case Tensor [ ] results :
186+ return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
187+ case Operation [ ] results :
188+ return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
189+ case float [ ] fv :
190+ var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
191+ return ( original_result , new Tensor [ ] { result } ) ;
192+ default :
193+ return ( original_result , new Tensor [ 0 ] ) ;
194+ }
195+ }
196+
197+ private Tensor _BuildCondTensor ( ITensorOrOperation v )
198+ {
199+ switch ( v )
200+ {
201+ case Operation op :
202+ // Use pivot as the proxy for this op.
203+ return control_flow_ops . with_dependencies ( new Operation [ ] { op } , _pivot ) ;
204+ case Tensor t :
205+ return _ProcessOutputTensor ( t ) ;
206+ default :
207+ return _ProcessOutputTensor ( ops . convert_to_tensor ( v ) ) ;
208+
209+ }
210+ }
211+
211212 /// <summary>
212213 /// Process an output tensor of a conditional branch.
213214 /// </summary>
@@ -238,7 +239,7 @@ private Tensor _ProcessOutputTensor(Tensor val)
238239 }
239240 return real_val ;
240241 }
241-
242+
242243 protected override void _AddOpInternal ( Operation op )
243244 {
244245 if ( op . inputs . Length == 0 )
@@ -324,20 +325,20 @@ public override bool back_prop
324325 }
325326 }
326327
327- public CondContextDef to_proto ( string export_scope )
328- {
329- throw new NotImplementedException ( ) ;
330- }
331-
332- public CondContext from_proto ( CondContextDef proto , string import_scope )
333- {
334- var ret = new CondContext ( context_def : proto , import_scope : import_scope ) ;
335-
336- ret . Enter ( ) ;
337- foreach ( var nested_def in proto . NestedContexts )
338- from_control_flow_context_def ( nested_def , import_scope : import_scope ) ;
339- ret . Exit ( ) ;
340- return ret ;
341- }
342- }
328+ public CondContextDef to_proto ( string export_scope )
329+ {
330+ throw new NotImplementedException ( ) ;
331+ }
332+
333+ public CondContext from_proto ( CondContextDef proto , string import_scope )
334+ {
335+ var ret = new CondContext ( context_def : proto , import_scope : import_scope ) ;
336+
337+ ret . Enter ( ) ;
338+ foreach ( var nested_def in proto . NestedContexts )
339+ from_control_flow_context_def ( nested_def , import_scope : import_scope ) ;
340+ ret . Exit ( ) ;
341+ return ret ;
342+ }
343+ }
343344}
0 commit comments