@@ -6,83 +6,70 @@ namespace Tensorflow
66{
77 public partial class Tensor
88 {
9- public static Tensor operator + ( Tensor x , Tensor y )
10- {
11- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "add" , new Tensor [ ] { x , y } ) , scope =>
12- {
13- return gen_math_ops . add ( x , y , scope ) ;
14- } ) ;
15- }
16-
17- public static Tensor operator + ( Tensor x , int y )
18- {
19- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "add" , new object [ ] { x , y } ) , scope =>
20- {
21- var y1 = ops . convert_to_tensor ( y , x . dtype . as_base_dtype ( ) , name : "y" ) ;
22- return gen_math_ops . add ( x , y1 , scope ) ;
23- } ) ;
24- }
9+ public static Tensor operator + ( Tensor x , Tensor y ) => BinaryOpWrapper ( "add" , x , y ) ;
10+ public static Tensor operator + ( Tensor x , int y ) => BinaryOpWrapper ( "add" , x , y ) ;
2511
2612 public static Tensor operator - ( Tensor t1 ) => gen_math_ops . neg ( t1 ) ;
27- public static Tensor operator - ( Tensor t1 , Tensor t2 ) => gen_math_ops . sub ( t1 , t2 ) ;
28- public static Tensor operator - ( Tensor t1 , int t2 ) => gen_math_ops . sub ( t1 , t2 ) ;
29- public static Tensor operator - ( Tensor t1 , double t2 ) => gen_math_ops . sub ( t1 , t2 ) ;
3013
31- public static Tensor operator * ( double x , Tensor y )
32- {
33- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "mul" , new { x , y } ) ,
34- scope =>
35- {
36- var x1 = ops . convert_to_tensor ( x , y . dtype . as_base_dtype ( ) , name : "x" ) ;
37- return gen_math_ops . mul ( x1 , y , name : scope ) ;
38- } ) ;
39- }
14+ public static Tensor operator - ( Tensor x , Tensor y ) => BinaryOpWrapper ( "sub" , x , y ) ;
15+ public static Tensor operator - ( Tensor x , int y ) => BinaryOpWrapper ( "sub" , x , y ) ;
16+ public static Tensor operator - ( Tensor x , double y ) => BinaryOpWrapper ( "sub" , x , y ) ;
4017
41- public static Tensor operator * ( Tensor x , Tensor y )
42- {
43- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "mul" , new Tensor [ ] { x , y } ) , scope =>
44- {
45- return gen_math_ops . mul ( x , y , name : scope ) ;
46- } ) ;
47- }
18+ public static Tensor operator * ( float x , Tensor y ) => BinaryOpWrapper ( "mul" , x , y ) ;
19+ public static Tensor operator * ( double x , Tensor y ) => BinaryOpWrapper ( "mul" , x , y ) ;
20+ public static Tensor operator * ( Tensor x , Tensor y ) => BinaryOpWrapper ( "mul" , x , y ) ;
21+ public static Tensor operator * ( Tensor x , int y ) => BinaryOpWrapper ( "mul" , x , y ) ;
4822
49- public static Tensor operator * ( Tensor x , int y )
50- {
51- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "mul" , new object [ ] { x , y } ) , scope =>
52- {
53- var y1 = ops . convert_to_tensor ( y , x . dtype . as_base_dtype ( ) , name : "y" ) ;
54- return gen_math_ops . mul ( x , y1 , name : scope ) ;
55- } ) ;
56- }
23+ public static Tensor operator / ( Tensor x , Tensor y ) => BinaryOpWrapper ( "truediv" , x , y ) ;
24+ public static Tensor operator / ( Tensor x , float y ) => BinaryOpWrapper ( "truediv" , x , y ) ;
25+ public static Tensor operator / ( Tensor x , double y ) => BinaryOpWrapper ( "truediv" , x , y ) ;
5726
58- public static Tensor operator / ( Tensor x , Tensor y )
59- {
60- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "truediv/" , "truediv" , new Tensor [ ] { x , y } ) , scope =>
61- {
62- return gen_math_ops . real_div ( x , y , scope ) ;
63- } ) ;
64- }
27+ public static Tensor operator % ( Tensor x , Tensor y ) => BinaryOpWrapper ( "mod" , x , y ) ;
6528
66- public static Tensor operator / ( Tensor x , double y )
67- {
68- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "truediv/" , "truediv" , new object [ ] { x , y } ) , scope =>
69- {
70- var y1 = ops . convert_to_tensor ( y , dtype : x . dtype . as_base_dtype ( ) , name : "y" ) ;
71- return gen_math_ops . real_div ( x , y1 , scope ) ;
72- } ) ;
73- }
29+ public static Tensor operator > ( Tensor x , int y ) => gen_array_ops . greater ( x , y ) ;
30+ public static Tensor operator > ( Tensor x , double y ) => gen_array_ops . greater ( x , y ) ;
31+ public static Tensor operator < ( Tensor x , int y ) => gen_array_ops . less ( x , y ) ;
32+ public static Tensor operator < ( Tensor x , double y ) => gen_array_ops . less ( x , y ) ;
7433
75- public static Tensor operator % ( Tensor x , Tensor y )
34+ private static Tensor BinaryOpWrapper < Tx , Ty > ( string name , Tx x , Ty y )
7635 {
77- return Python . with < ops . name_scope , Tensor > ( new ops . name_scope ( "" , "mod" , new object [ ] { x , y } ) , scope =>
36+ TF_DataType dtype = TF_DataType . DtInvalid ;
37+ if ( x is Tensor tl )
38+ dtype = tl . dtype . as_base_dtype ( ) ;
39+ if ( y is Tensor tr )
40+ dtype = tr . dtype . as_base_dtype ( ) ;
41+
42+ var namescope = new ops . name_scope ( "" , name , new { x , y } ) ;
43+ return Python . with < ops . name_scope , Tensor > ( namescope , scope =>
7844 {
79- return gen_math_ops . floor_mod ( x , y , scope ) ;
45+ Tensor result = null ;
46+ var x1 = ops . convert_to_tensor ( x , dtype : dtype , name : "x" ) ;
47+ var y1 = ops . convert_to_tensor ( y , dtype : dtype , name : "y" ) ;
48+
49+ switch ( name )
50+ {
51+ case "add" :
52+ result = gen_math_ops . add ( x1 , y1 , name : scope ) ;
53+ break ;
54+ case "truediv" :
55+ result = gen_math_ops . real_div ( x1 , y1 , name : scope ) ;
56+ break ;
57+ case "mul" :
58+ result = gen_math_ops . mul ( x1 , y1 , name : scope ) ;
59+ break ;
60+ case "sub" :
61+ result = gen_math_ops . sub ( x1 , y1 , name : scope ) ;
62+ break ;
63+ case "mod" :
64+ result = gen_math_ops . floor_mod ( x1 , y1 , name : scope ) ;
65+ break ;
66+ default :
67+ throw new NotImplementedException ( $ "BinaryOpWrapper: { name } - { typeof ( Tx ) . Name } , { typeof ( Ty ) } ") ;
68+ }
69+
70+ return result ;
8071 } ) ;
72+
8173 }
82-
83- public static Tensor operator > ( Tensor x , int y ) => gen_array_ops . greater ( x , y ) ;
84- public static Tensor operator > ( Tensor x , double y ) => gen_array_ops . greater ( x , y ) ;
85- public static Tensor operator < ( Tensor x , int y ) => gen_array_ops . less ( x , y ) ;
86- public static Tensor operator < ( Tensor x , double y ) => gen_array_ops . less ( x , y ) ;
8774 }
8875}
0 commit comments