@@ -118,110 +118,10 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
118118 if ( values == null )
119119 throw new ValueError ( "None values not supported." ) ;
120120
121- if ( np_dt == null )
122- {
123- switch ( values )
124- {
125- case bool boolVal :
126- nparray = boolVal ;
127- break ;
128- case int intVal :
129- nparray = intVal ;
130- break ;
131- case int [ ] intVals :
132- nparray = np . array ( intVals ) ;
133- break ;
134- case int [ , ] intVals :
135- nparray = np . array ( intVals ) ;
136- break ;
137- case long intVal :
138- nparray = intVal ;
139- break ;
140- case long [ ] intVals :
141- nparray = np . array ( intVals ) ;
142- break ;
143- case long [ , ] intVals :
144- nparray = np . array ( intVals ) ;
145- break ;
146- case float floatVal :
147- nparray = floatVal ;
148- break ;
149- case float [ ] floatVals :
150- nparray = floatVals ;
151- break ;
152- case float [ , ] floatVals :
153- nparray = np . array ( floatVals ) ;
154- break ;
155- case double doubleVal :
156- nparray = doubleVal ;
157- break ;
158- case double [ ] doubleVals :
159- nparray = np . array ( doubleVals ) ;
160- break ;
161- case double [ , ] doubleVals :
162- nparray = np . array ( doubleVals ) ;
163- break ;
164- case string strVal :
165- nparray = strVal ;
166- break ;
167- case string [ ] strVals :
168- nparray = strVals ;
169- break ;
170- case byte [ ] byteValues :
171- nparray = byteValues ;
172- break ;
173- case byte [ , ] byteValues :
174- nparray = np . array ( byteValues ) ;
175- break ;
176- default :
177- throw new NotImplementedException ( $ "make_tensor_proto: Support for type { values . GetType ( ) } Not Implemented") ;
178- }
179- }
180- else
181- {
182- // convert data type
183- switch ( np_dt . Name )
184- {
185- case "Int32" :
186- if ( values . GetType ( ) . IsArray )
187- nparray = np . array ( ( int [ ] ) values , np_dt ) ;
188- else
189- nparray = Converts . ToInt32 ( values ) ;
190- break ;
191- case "Int64" :
192- if ( values . GetType ( ) . IsArray )
193- nparray = np . array ( ( int [ ] ) values , np_dt ) ;
194- else
195- nparray = Converts . ToInt64 ( values ) ;
196- break ;
197- case "Single" :
198- if ( values . GetType ( ) . IsArray )
199- nparray = np . array ( ( float [ ] ) values , np_dt ) ;
200- else
201- nparray = Converts . ToSingle ( values ) ;
202- break ;
203- case "Double" :
204- if ( values . GetType ( ) . IsArray )
205- nparray = np . array ( ( double [ ] ) values , np_dt ) ;
206- else
207- nparray = Converts . ToDouble ( values ) ;
208- break ;
209- case "String" :
210- if ( values . GetType ( ) . IsArray )
211- nparray = np . array ( ( string [ ] ) values , np_dt ) ;
212- else
213- nparray = NDArray . FromString ( Converts . ToString ( values ) ) ;
214- break ;
215- case "Boolean" :
216- if ( values . GetType ( ) . IsArray )
217- nparray = np . array ( ( bool [ ] ) values , np_dt ) ;
218- else
219- nparray = Converts . ToBoolean ( values ) ;
220- break ;
221- default :
222- throw new NotImplementedException ( $ "make_tensor_proto: Support for type { np_dt . Name } Not Implemented") ;
223- }
224- }
121+ nparray = convert_to_numpy_ndarray ( values ) ;
122+
123+ if ( np_dt != null && np_dt != typeof ( string ) )
124+ nparray = nparray . astype ( np_dt ) ;
225125 }
226126
227127 var numpy_dtype = nparray . dtype . as_dtype ( dtype : dtype ) ;
@@ -316,23 +216,59 @@ public static NDArray convert_to_numpy_ndarray(object values)
316216 case NDArray val :
317217 nd = val ;
318218 break ;
319- case int val :
320- nd = np . asarray ( val ) ;
219+ case bool boolVal :
220+ nd = boolVal ;
221+ break ;
222+ case int intVal :
223+ nd = intVal ;
224+ break ;
225+ case int [ ] intVals :
226+ nd = np . array ( intVals ) ;
227+ break ;
228+ case int [ , ] intVals :
229+ nd = np . array ( intVals ) ;
230+ break ;
231+ case long intVal :
232+ nd = intVal ;
233+ break ;
234+ case long [ ] intVals :
235+ nd = np . array ( intVals ) ;
236+ break ;
237+ case long [ , ] intVals :
238+ nd = np . array ( intVals ) ;
239+ break ;
240+ case float floatVal :
241+ nd = floatVal ;
242+ break ;
243+ case float [ ] floatVals :
244+ nd = floatVals ;
245+ break ;
246+ case float [ , ] floatVals :
247+ nd = np . array ( floatVals ) ;
248+ break ;
249+ case double doubleVal :
250+ nd = doubleVal ;
251+ break ;
252+ case double [ ] doubleVals :
253+ nd = np . array ( doubleVals ) ;
254+ break ;
255+ case double [ , ] doubleVals :
256+ nd = np . array ( doubleVals ) ;
321257 break ;
322- case int [ ] val :
323- nd = np . array ( val ) ;
258+ case string strVal :
259+ nd = NDArray . FromString ( strVal ) ;
324260 break ;
325- case float val :
326- nd = np . asarray ( val ) ;
261+ case string [ ] strVals :
262+ nd = strVals ;
327263 break ;
328- case double val :
329- nd = np . asarray ( val ) ;
264+ case byte [ ] byteValues :
265+ nd = byteValues ;
330266 break ;
331- case string val :
332- nd = np . asarray ( val ) ;
267+ case byte [ , ] byteValues :
268+ nd = np . array ( byteValues ) ;
333269 break ;
334270 default :
335- throw new Exception ( " Not Implemented") ;
271+ throw new NotImplementedException ( $ "convert_to_numpy_ndarray: Support for type { values . GetType ( ) } Not Implemented") ;
336272 }
337273
338274 return nd ;
0 commit comments