@@ -32,31 +32,41 @@ public Tensor execute(Context ctx, string op_name, Tensor[] inputs, object[] att
3232 ctx . ensure_initialized ( ) ;
3333 using ( var status = new Status ( ) )
3434 {
35- var retVals = wrap_tfe_src . TFE_Py_Execute ( ctx , ctx . device_name , op_name , inputs , attrs , 1 , status ) ;
35+ var retVals = wrap_tfe_src . TFE_Execute ( ctx , ctx . device_name , op_name , inputs , attrs , 1 , status ) ;
3636
37- var t = c_api . TFE_TensorHandleResolve ( retVals [ 0 ] , status ) ;
38- status . Check ( true ) ;
39-
40- return new EagerTensor ( t ) ;
37+ return new EagerTensor ( retVals [ 0 ] ) ;
4138 }
4239 }
4340
44- public ( TF_DataType , Tensor ) args_to_matching_eager ( Tensor [ ] l , Context ctx , TF_DataType default_dtype = TF_DataType . DtInvalid )
41+ public ( TF_DataType , Tensor [ ] ) args_to_matching_eager ( Context ctx , TF_DataType default_dtype = TF_DataType . DtInvalid , object [ ] args = null )
4542 {
46- var dtype = default_dtype ;
47- if ( dtype == TF_DataType . DtInvalid )
48- {
49- var tensor = ops . convert_to_tensor ( l , dtype , preferred_dtype : default_dtype , ctx : ctx ) ;
43+ if ( args . Length == 0 && default_dtype != TF_DataType . DtInvalid )
44+ return ( default_dtype , null ) ;
5045
51- if ( dtype == TF_DataType . DtInvalid )
52- dtype = tensor . dtype ;
46+ if ( args . Count ( x => x is EagerTensor ) == args . Length )
47+ return ( ( args [ 0 ] as EagerTensor ) . dtype , args . Select ( x => x as EagerTensor ) . ToArray ( ) ) ;
5348
54- return ( dtype , tensor ) ;
49+ var dtype = TF_DataType . DtInvalid ;
50+ foreach ( var x in args )
51+ {
52+ if ( x is EagerTensor et )
53+ dtype = et . dtype ;
5554 }
56- else
55+
56+ if ( dtype == TF_DataType . DtInvalid )
5757 {
58- return ( dtype , l [ 0 ] ) ;
58+ var ret = new List < Tensor > ( ) ;
59+ foreach ( var t in args )
60+ {
61+ ret . Add ( ops . convert_to_tensor ( t , dtype , preferred_dtype : default_dtype , ctx : ctx ) ) ;
62+ if ( dtype == TF_DataType . DtInvalid )
63+ dtype = ret . Last ( ) . dtype ;
64+ }
65+
66+ return ( dtype , ret . ToArray ( ) ) ;
5967 }
68+ else
69+ throw new NotImplementedException ( "" ) ;
6070 }
6171
6272 public void record_gradient ( string op_name , InputList inputs , Dictionary < string , object > attrs , Tensor [ ] results , string name = null )
0 commit comments