@@ -40,13 +40,15 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
4040 }
4141
4242 var attrs = new Dictionary < string , object > ( ) ;
43- var inferred_from = new Dictionary < string , object > ( ) ;
4443 var inputs = new List < Tensor > ( ) ;
4544 var input_types = new List < TF_DataType > ( ) ;
46- var base_types = new List < TF_DataType > ( ) ;
47-
45+
4846 return Python . with < ops . name_scope , Operation > ( new ops . name_scope ( name ) , scope =>
4947 {
48+ var inferred_from = new Dictionary < string , object > ( ) ;
49+ var base_types = new List < TF_DataType > ( ) ;
50+ var types = new List < TF_DataType > ( ) ;
51+
5052 // Perform input type inference
5153 foreach ( var input_arg in op_def . InputArg )
5254 {
@@ -72,20 +74,14 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
7274 if ( ! _IsListValue ( values ) )
7375 throw new TypeError ( $ "Expected list for '{ input_name } ' argument to '{ op_type_name } ' Op, not { values } .") ;
7476 if ( input_arg . Type != DataType . DtInvalid )
75- {
7677 dtype = input_arg . Type ;
77- }
7878 else if ( ! String . IsNullOrEmpty ( input_arg . NumberAttr ) )
7979 {
8080 if ( attrs . ContainsKey ( input_arg . TypeAttr ) )
81- {
8281 dtype = ( DataType ) attrs [ input_arg . TypeAttr ] ;
83- }
8482 else
85- {
8683 if ( values is Tensor [ ] values1 )
8784 dtype = values1 [ 0 ] . dtype . as_datatype_enum ( ) ;
88- }
8985
9086 if ( dtype == DataType . DtInvalid && default_type_attr_map . ContainsKey ( input_arg . TypeAttr ) )
9187 default_dtype = ( DataType ) default_type_attr_map [ input_arg . TypeAttr ] ;
@@ -94,86 +90,48 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
9490 if ( input_arg . IsRef && dtype != DataType . DtInvalid )
9591 dtype = dtype . as_base_dtype ( ) ;
9692
97- values = ops . internal_convert_n_to_tensor ( values , name : input_arg . Name , dtype : dtype , preferred_dtype : default_dtype , as_ref : input_arg . IsRef ) ;
93+ values = ops . internal_convert_n_to_tensor ( values ,
94+ name : input_arg . Name ,
95+ dtype : dtype ,
96+ preferred_dtype : default_dtype ,
97+ as_ref : input_arg . IsRef ) ;
9898 }
9999 else
100100 {
101- if ( default_type_attr_map . ContainsKey ( input_arg . TypeAttr ) )
101+ if ( input_arg . Type != DataType . DtInvalid )
102+ dtype = input_arg . Type ;
103+ else if ( attrs . ContainsKey ( input_arg . TypeAttr ) )
104+ dtype = ( DataType ) attrs [ input_arg . TypeAttr ] ;
105+ else if ( default_type_attr_map . ContainsKey ( input_arg . TypeAttr ) )
102106 default_dtype = ( DataType ) default_type_attr_map [ input_arg . TypeAttr ] ;
103107
104- if ( keywords [ input_name ] is Tensor )
105- {
106- }
107- else
108- {
109- keywords [ input_name ] = ops . internal_convert_to_tensor ( values , name : input_name , as_ref : input_arg . IsRef ) ;
110- }
111-
112- if ( ! String . IsNullOrEmpty ( input_arg . TypeAttr ) )
113- {
114- attrs [ input_arg . TypeAttr ] = ( keywords [ input_name ] as Tensor ) . dtype ;
115- }
116- values = new Tensor [ ] { keywords [ input_name ] as Tensor } ;
117- }
118-
119- inputs . AddRange ( values as Tensor [ ] ) ;
120- base_types . AddRange ( ( values as Tensor [ ] ) . Select ( x => x . dtype . as_base_dtype ( ) ) ) ;
121- input_types . AddRange ( base_types ) ;
122-
123- if ( ! string . IsNullOrEmpty ( input_arg . NumberAttr ) )
124- {
125- if ( attrs . ContainsKey ( input_arg . NumberAttr ) )
126- {
127-
128- }
129- else
130- {
131- attrs [ input_arg . NumberAttr ] = ( values as Tensor [ ] ) . Length ;
132- inferred_from [ input_arg . NumberAttr ] = input_name ;
133- var num_attr = op_def . Attr . First ( x => x . Name == input_arg . NumberAttr ) ;
134- if ( num_attr . HasMinimum && ( values as Tensor [ ] ) . Length < num_attr . Minimum )
135- throw new ValueError ( $ "List argument '{ input_name } ' to '{ op_type_name } ' Op with length { ( values as Tensor [ ] ) . Length } shorter " +
136- $ "than minimum length { num_attr . Minimum } ") ;
137- }
108+ values = ops . internal_convert_to_tensor ( values ,
109+ name : input_name ,
110+ as_ref : input_arg . IsRef ) ;
138111
139- // All tensors must have the same base type.
140- if ( input_arg . Type != DataType . DtInvalid )
141- {
112+ //if (!String.IsNullOrEmpty(input_arg.TypeAttr))
113+ //attrs[input_arg.TypeAttr] = values.dtype;
142114
143- }
144- else
145- {
146- attrs [ input_arg . TypeAttr ] = base_types [ 0 ] ;
147- inferred_from [ input_arg . TypeAttr ] = input_name ;
148- var type_attr = op_def . Attr . First ( x => x . Name == input_arg . TypeAttr ) ;
149- }
115+ values = new Tensor [ ] { values } ;
150116 }
151- else if ( ! string . IsNullOrEmpty ( input_arg . TypeAttr ) )
152- {
153- var attr_value = base_types [ 0 ] ;
154- if ( attrs . ContainsKey ( input_arg . TypeAttr ) )
155- {
156117
157- }
158- else
159- {
160- attrs [ input_arg . TypeAttr ] = attr_value ;
161- inferred_from [ input_arg . TypeAttr ] = input_name ;
162- }
163- }
164- else if ( ! string . IsNullOrEmpty ( input_arg . TypeListAttr ) )
118+ if ( values is Tensor [ ] values2 )
165119 {
166- var attr_value = base_types ;
167- if ( attrs . ContainsKey ( input_arg . TypeListAttr ) )
168- {
169-
170- }
171- else
172- {
173- attrs [ input_arg . TypeListAttr ] = attr_value ;
174- inferred_from [ input_arg . TypeListAttr ] = input_name ;
175- }
120+ types = values2 . Select ( x => x . dtype ) . ToList ( ) ;
121+ inputs . AddRange ( values2 ) ;
122+ base_types = values2 . Select ( x => x . dtype . as_base_dtype ( ) ) . ToList ( ) ;
176123 }
124+ else throw new NotImplementedException ( "_IsListParameter" ) ;
125+
126+ SetAttrs ( op_type_name ,
127+ input_arg ,
128+ op_def ,
129+ attrs ,
130+ inferred_from ,
131+ types ,
132+ base_types ,
133+ input_types ,
134+ values ) ;
177135 }
178136
179137 // Process remaining attrs
@@ -190,22 +148,26 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
190148 foreach ( var attr_def in op_def . Attr )
191149 {
192150 var key = attr_def . Name ;
151+ var value = attrs [ key ] ;
152+
193153 if ( ! attrs . ContainsKey ( key ) )
194154 Console . WriteLine ( $ "_apply_op_helper: key '{ key } ' is not found in '{ op_def . Name } ' operation's attr_def.") ;
195155
196- attr_protos [ key ] = SetAttrValue ( op_def , attr_def , attrs [ key ] ) ;
156+ attr_protos [ key ] = SetAttrValue ( op_def , attr_def , value ) ;
197157 }
198158
159+ attrs . Clear ( ) ;
160+
199161 // Determine output types (possibly using attrs)
200162 var output_types = new List < TF_DataType > ( ) ;
201163
202164 foreach ( var arg in op_def . OutputArg )
203165 {
204- if ( ! String . IsNullOrEmpty ( arg . NumberAttr ) )
166+ if ( ! string . IsNullOrEmpty ( arg . NumberAttr ) )
205167 {
206168
207169 }
208- else if ( ! String . IsNullOrEmpty ( arg . TypeAttr ) )
170+ else if ( ! string . IsNullOrEmpty ( arg . TypeAttr ) )
209171 {
210172 output_types . Add ( ( TF_DataType ) attr_protos [ arg . TypeAttr ] . Type ) ;
211173 }
@@ -222,6 +184,79 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
222184 } ) ;
223185 }
224186
187+ private void SetAttrs ( string op_type_name ,
188+ ArgDef input_arg ,
189+ OpDef op_def ,
190+ Dictionary < string , object > attrs ,
191+ Dictionary < string , object > inferred_from ,
192+ List < TF_DataType > types ,
193+ List < TF_DataType > base_types ,
194+ List < TF_DataType > input_types ,
195+ dynamic values )
196+ {
197+ var input_name = input_arg . Name ;
198+
199+ if ( ! string . IsNullOrEmpty ( input_arg . NumberAttr ) )
200+ {
201+ if ( attrs . ContainsKey ( input_arg . NumberAttr ) )
202+ {
203+
204+ }
205+ else
206+ {
207+ attrs [ input_arg . NumberAttr ] = ( values as Tensor [ ] ) . Length ;
208+ inferred_from [ input_arg . NumberAttr ] = input_name ;
209+ var num_attr = op_def . Attr . First ( x => x . Name == input_arg . NumberAttr ) ;
210+ if ( num_attr . HasMinimum && ( values as Tensor [ ] ) . Length < num_attr . Minimum )
211+ throw new ValueError ( $ "List argument '{ input_name } ' to '{ op_type_name } ' Op with length { ( values as Tensor [ ] ) . Length } shorter " +
212+ $ "than minimum length { num_attr . Minimum } ") ;
213+ }
214+
215+ // All tensors must have the same base type.
216+ if ( input_arg . Type != DataType . DtInvalid )
217+ {
218+
219+ }
220+ else
221+ {
222+ attrs [ input_arg . TypeAttr ] = base_types [ 0 ] ;
223+ inferred_from [ input_arg . TypeAttr ] = input_name ;
224+ var type_attr = op_def . Attr . First ( x => x . Name == input_arg . TypeAttr ) ;
225+ }
226+ }
227+ else if ( ! string . IsNullOrEmpty ( input_arg . TypeAttr ) )
228+ {
229+ var attr_value = base_types [ 0 ] ;
230+ if ( attrs . ContainsKey ( input_arg . TypeAttr ) )
231+ {
232+
233+ }
234+ else
235+ {
236+ attrs [ input_arg . TypeAttr ] = attr_value ;
237+ inferred_from [ input_arg . TypeAttr ] = input_name ;
238+ }
239+ }
240+ else if ( ! string . IsNullOrEmpty ( input_arg . TypeListAttr ) )
241+ {
242+ var attr_value = base_types ;
243+ if ( attrs . ContainsKey ( input_arg . TypeListAttr ) )
244+ {
245+
246+ }
247+ else
248+ {
249+ attrs [ input_arg . TypeListAttr ] = attr_value ;
250+ inferred_from [ input_arg . TypeListAttr ] = input_name ;
251+ }
252+ }
253+
254+ if ( input_arg . IsRef )
255+ input_types . AddRange ( types ) ;
256+ else
257+ input_types . AddRange ( base_types ) ;
258+ }
259+
225260 public DataType _MakeType ( TF_DataType v , AttrDef attr_def )
226261 {
227262 return v . as_base_dtype ( ) . as_datatype_enum ( ) ;
@@ -231,6 +266,13 @@ private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
231266 {
232267 var attr_value = new AttrValue ( ) ;
233268
269+ if ( attr_def . Type . StartsWith ( "list(" ) )
270+ {
271+ if ( attr_def . HasMinimum )
272+ ;
273+ attr_value . List = new AttrValue . Types . ListValue ( ) ;
274+ }
275+
234276 switch ( attr_def . Type )
235277 {
236278 case "string" :
@@ -240,8 +282,6 @@ private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
240282 attr_value . Type = _MakeType ( ( TF_DataType ) value , attr_def ) ;
241283 break ;
242284 case "list(type)" :
243- if ( attr_value . List == null )
244- attr_value . List = new AttrValue . Types . ListValue ( ) ;
245285 attr_value . List . Type . AddRange ( ( value as IList < TF_DataType > ) . Select ( x => _MakeType ( x , attr_def ) ) ) ;
246286 break ;
247287 case "bool" :
0 commit comments