@@ -107,7 +107,7 @@ private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
107107 foreach ( var subfeed in feed_dict )
108108 {
109109 var subfeed_t = _graph . as_graph_element ( subfeed . Key , allow_tensor : true , allow_operation : false ) ;
110- //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype (); // subfeed_dtype was never used
110+ //var target_dtype = subfeed_t.dtype.as_numpy_typecode (); // subfeed_dtype was never used
111111 feed_dict_tensor [ subfeed_t ] = subfeed . Value ;
112112 //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
113113 }
@@ -150,58 +150,64 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
150150 int i = 0 ;
151151 foreach ( var x in feed_dict )
152152 {
153- if ( x . Key is Tensor tensor )
153+ if ( x . Key is Tensor key )
154154 {
155155 switch ( x . Value )
156156 {
157157 case Tensor v :
158- feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , v ) ;
158+ if ( v . dtype != key . dtype )
159+ throw new ValueError ( $ "Tensor { v } does not match the expected dtype { key . dtype } , actual dtype: { v . dtype } ") ;
160+ feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , v ) ;
159161 break ;
160162 case NDArray v :
161- feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v , tensor . dtype ) ) ;
163+ feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ;
162164 break ;
163165 case IntPtr v :
164- feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
166+ var tensor = new Tensor ( v ) ;
167+ if ( tensor . dtype != key . dtype )
168+ throw new ValueError ( $ "Tensor { v } does not match the expected dtype { key . dtype } , actual dtype: { tensor . dtype } ") ;
169+
170+ feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , tensor ) ;
165171 break ;
166172#if _REGEN
167173 // @formatter:off — disable formatter after this line
168- % types = [ "sbyte" , "byte" , "short" , "ushort" , "int" , "uint" , "long" , "ulong" , "float" , "double" , "Complex" ]
169- % foreach types%
170- case #1 v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
171- case #1 [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
172- %
174+ % types = [ "bool" , "sbyte" , "byte" , "short" , "ushort" , "int" , "uint" , "long" , "ulong" , "float" , "double" , "Complex" ]
175+ % foreach types%
176+ case #1 v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
177+ case #1 [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
178+ %
173179 // @formatter:on — enable formatter after this line
174180#else
175181 // @formatter:off — disable formatter after this line
176- case sbyte v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
177- case sbyte [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
178- case byte v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
179- case byte [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
180- case short v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
181- case short [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
182- case ushort v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
183- case ushort [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
184- case int v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
185- case int [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
186- case uint v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
187- case uint [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
188- case long v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
189- case long [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
190- case ulong v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
191- case ulong [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
192- case float v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
193- case float [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
194- case double v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
195- case double [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
196- case Complex v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
197- case Complex [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ; break ;
182+ case bool v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
183+ case bool [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
184+ case sbyte v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
185+ case sbyte [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
186+ case byte v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
187+ case byte [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
188+ case short v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
189+ case short [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
190+ case ushort v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
191+ case ushort [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
192+ case int v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
193+ case int [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
194+ case uint v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
195+ case uint [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
196+ case long v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
197+ case long [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
198+ case ulong v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
199+ case ulong [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
200+ case float v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
201+ case float [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
202+ case double v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
203+ case double [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
204+ case Complex v: feeds[ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
205+ case Complex [ ] v : feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ; break ;
198206 // @formatter:on — enable formatter after this line
199207#endif
200- case bool v :
201- feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( ( byte ) ( v ? 1 : 0 ) , TF_DataType . TF_BOOL ) ) ;
202- break ;
208+
203209 case string v :
204- feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( v ) ) ;
210+ feeds [ i ++ ] = new KeyValuePair < TF_Output , Tensor > ( key . _as_tf_output ( ) , TensorConverter . ToTensor ( v , key . dtype ) ) ;
205211 break ;
206212 default :
207213 throw new NotImplementedException ( $ "feed_dict data type { x . Value ? . GetType ( ) . Name ?? "<null>" } ") ;
@@ -214,6 +220,7 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
214220 return _call_tf_sessionrun ( feeds , fetches , target_list ) ;
215221 }
216222
223+
217224 private unsafe NDArray [ ] _call_tf_sessionrun ( KeyValuePair < TF_Output , Tensor > [ ] feed_dict , TF_Output [ ] fetch_list , List < Operation > target_list )
218225 {
219226 // Ensure any changes to the graph are reflected in the runtime.
0 commit comments