@@ -13,53 +13,81 @@ namespace Tensorflow.Graphs
1313 public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
1414 {
1515 FuncGraph graph ;
16- Tensor [ ] originalInputs ;
16+ Tensors originalInputs ;
1717 string func_name ;
18- static Dictionary < string , Func < Tensor [ ] , Tensor > > functions = new Dictionary < string , Func < Tensor [ ] , Tensor > > ( ) ;
18+ static Dictionary < string , Func < Tensors , Tensors > > functions = new Dictionary < string , Func < Tensors , Tensors > > ( ) ;
1919
2020 public override void OnEntry ( MethodExecutionArgs args )
2121 {
22- if ( args . Instance is TensorFlowOpLayer op )
23- func_name = $ "autograph_{ op . OpType } .{ args . Method . Name } ";
24- else
25- func_name = $ "autograph_{ args . Instance } .{ args . Method . Name } ";
22+ func_name = $ "autograph_{ args . Instance . GetHashCode ( ) } .{ args . Method . Name } ";
2623
2724 if ( functions . ContainsKey ( func_name ) )
2825 {
29- args . ReturnValue = functions [ func_name ] ( args . Arguments . Select ( x => x as Tensor ) . ToArray ( ) ) ;
26+ if ( args . Arguments [ 0 ] is Tensors tensor_inputs )
27+ args . ReturnValue = functions [ func_name ] ( tensor_inputs . ToArray ( ) ) ;
28+ else
29+ args . ReturnValue = functions [ func_name ] ( args . Arguments . Select ( x => x as Tensor ) . ToArray ( ) ) ;
3030 args . FlowBehavior = FlowBehavior . Return ;
3131 return ;
3232 }
3333
3434 // make function as an Operation by autograph
3535 graph = new FuncGraph ( func_name ) ;
3636
37- originalInputs = new Tensor [ args . Arguments . Length ] ;
38- // convert args to placeholder
39- for ( var i = 0 ; i < args . Arguments . Length ; i ++ )
37+ // convert to Tensors
38+ if ( args . Arguments [ 0 ] is Tensors inputs )
39+ {
40+ originalInputs = inputs ;
41+ var new_inputs = inputs . Select ( x => tf . placeholder ( x . dtype , shape : x . TensorShape ) ) . ToArray ( ) ;
42+ args . Arguments [ 0 ] = new Tensors ( new_inputs ) ;
43+ }
44+ else
4045 {
41- if ( args . Arguments [ i ] is EagerTensor tensor )
46+ originalInputs = new Tensors ( args . Arguments . Length ) ;
47+ // convert args to placeholder
48+ for ( var i = 0 ; i < args . Arguments . Length ; i ++ )
4249 {
43- originalInputs [ i ] = tensor ;
44- args . Arguments [ i ] = tf . placeholder ( tensor . dtype , shape : tensor . TensorShape ) ;
50+ if ( args . Arguments [ i ] is EagerTensor tensor )
51+ {
52+ originalInputs [ i ] = tensor ;
53+ args . Arguments [ i ] = tf . placeholder ( tensor . dtype , shape : tensor . TensorShape ) ;
54+ }
4555 }
4656 }
4757 }
4858
4959 public override void OnExit ( MethodExecutionArgs args )
5060 {
51- var output = ( Tensor ) args . ReturnValue ;
52- var inputs = args . Arguments . Select ( x => x as Tensor ) . ToArray ( ) ;
5361 var opers = graph . _nodes_by_name . Values . Select ( x => x as Operation ) . ToArray ( ) ;
5462
55- graph . ToGraph ( opers ,
56- inputs . Select ( x => x . op ) . ToArray ( ) ,
57- new Operation [ ] { output . op } ,
58- null ) ;
63+ if ( args . ReturnValue is Tensors outputs )
64+ {
65+ if ( args . Arguments [ 0 ] is Tensors inputs )
66+ {
67+ graph . ToGraph ( opers ,
68+ inputs . Select ( x => x . op ) . ToArray ( ) ,
69+ outputs . Select ( x => x . op ) . ToArray ( ) ,
70+ null ) ;
71+ }
72+ else
73+ {
74+ graph . ToGraph ( opers ,
75+ args . Arguments . Select ( x => ( x as Tensor ) . op ) . ToArray ( ) ,
76+ outputs . Select ( x => x . op ) . ToArray ( ) ,
77+ null ) ;
78+ }
79+ }
80+ else
81+ {
82+ graph . ToGraph ( opers ,
83+ args . Arguments . Select ( x => ( x as Tensor ) . op ) . ToArray ( ) ,
84+ new Operation [ ] { ( args . ReturnValue as Tensor ) . op } ,
85+ null ) ;
86+ }
5987
6088 graph . Dispose ( ) ;
6189
62- Func < Tensor [ ] , Tensor > function = ( x ) =>
90+ Func < Tensors , Tensors > function = ( x ) =>
6391 {
6492 var result = tf . Runner . TFE_Execute ( tf . Context ,
6593 tf . Context . DeviceName ,
0 commit comments