@@ -43,7 +43,7 @@ public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
4343
4444 private NDArray _run < T > ( T fetches , FeedItem [ ] feed_dict = null )
4545 {
46- var feed_dict_tensor = new Dictionary < Tensor , NDArray > ( ) ;
46+ var feed_dict_tensor = new Dictionary < object , object > ( ) ;
4747
4848 if ( feed_dict != null )
4949 feed_dict . ToList ( ) . ForEach ( x => feed_dict_tensor . Add ( x . Key , x . Value ) ) ;
@@ -79,9 +79,30 @@ private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
7979 /// name of an operation, the first Tensor output of that operation
8080 /// will be returned for that element.
8181 /// </returns>
82- private NDArray [ ] _do_run ( List < Operation > target_list , List < Tensor > fetch_list , Dictionary < Tensor , NDArray > feed_dict )
82+ private NDArray [ ] _do_run ( List < Operation > target_list , List < Tensor > fetch_list , Dictionary < object , object > feed_dict )
8383 {
84- var feeds = feed_dict . Select ( x => new KeyValuePair < TF_Output , Tensor > ( x . Key . _as_tf_output ( ) , new Tensor ( x . Value ) ) ) . ToArray ( ) ;
84+ var feeds = feed_dict . Select ( x =>
85+ {
86+ if ( x . Key is Tensor tensor )
87+ {
88+ switch ( x . Value )
89+ {
90+ case Tensor t1 :
91+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , t1 ) ;
92+ case NDArray nd :
93+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( nd ) ) ;
94+ case int intVal :
95+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( intVal ) ) ;
96+ case float floatVal :
97+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( floatVal ) ) ;
98+ case double doubleVal :
99+ return new KeyValuePair < TF_Output , Tensor > ( tensor . _as_tf_output ( ) , new Tensor ( doubleVal ) ) ;
100+ default :
101+ break ;
102+ }
103+ }
104+ throw new NotImplementedException ( "_do_run.feed_dict" ) ;
105+ } ) . ToArray ( ) ;
85106 var fetches = fetch_list . Select ( x => x . _as_tf_output ( ) ) . ToArray ( ) ;
86107 var targets = target_list ;
87108
0 commit comments