File tree Expand file tree Collapse file tree 3 files changed +48
-0
lines changed
Expand file tree Collapse file tree 3 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -83,6 +83,19 @@ void test_divide_add()
8383 VERIFY (dy.array () - (1.0 - x.array () / (y.array () * y.array ())));
8484}
8585
86+ void test_multiply_add_scalar ()
87+ {
88+ auto x = Variable (af::randu (5 ), true );
89+ auto y = Variable (af::randu (5 ), true );
90+ auto z = 2 * x + x * y + y;
91+ auto dz = Variable (af::constant (1.0 , 5 ), false );
92+ z.backward (dz);
93+ auto dx = x.grad ();
94+ auto dy = y.grad ();
95+ VERIFY (dx.array () - (2.0 + y.array ()));
96+ VERIFY (dy.array () - (1.0 + x.array ()));
97+ }
98+
8699int main ()
87100{
88101 af::info ();
@@ -91,5 +104,6 @@ int main()
91104 test_no_calc_grad ();
92105 test_multiply_sub ();
93106 test_divide_add ();
107+ test_multiply_add_scalar ();
94108 return 0 ;
95109}
Original file line number Diff line number Diff line change @@ -18,6 +18,16 @@ namespace af {
1818 Variable operator -(const Variable &lhs, const Variable &rhs);
1919 Variable operator /(const Variable &lhs, const Variable &rhs);
2020
21+ Variable operator +(const double &lhs, const Variable &rhs);
22+ Variable operator *(const double &lhs, const Variable &rhs);
23+ Variable operator -(const double &lhs, const Variable &rhs);
24+ Variable operator /(const double &lhs, const Variable &rhs);
25+
26+ Variable operator +(const Variable &lhs, const double &rhs);
27+ Variable operator *(const Variable &lhs, const double &rhs);
28+ Variable operator -(const Variable &lhs, const double &rhs);
29+ Variable operator /(const Variable &lhs, const double &rhs);
30+
2131 Variable negate (const Variable &input);
2232 Variable reciprocal (const Variable &input);
2333 }
Original file line number Diff line number Diff line change @@ -74,5 +74,29 @@ namespace af {
7474 return Variable (result, {input}, grad_func);
7575 }
7676
77+
78+ #define INSTANTIATE_OPERATOR (OP ) \
79+ Variable operator OP (const double &lhs_val, const Variable &rhs) \
80+ { \
81+ auto lhs = Variable ( \
82+ af::constant (lhs_val, \
83+ rhs.array ().dims (), \
84+ rhs.array ().type ()), \
85+ false ); \
86+ return lhs OP rhs; \
87+ } \
88+ Variable operator OP (const Variable &lhs, const double &rhs_val) \
89+ { \
90+ auto rhs = Variable ( \
91+ af::constant (rhs_val, \
92+ lhs.array ().dims (), lhs.array ().type ()), \
93+ false ); \
94+ return lhs OP rhs; \
95+ } \
96+
97+ INSTANTIATE_OPERATOR (+)
98+ INSTANTIATE_OPERATOR (-)
99+ INSTANTIATE_OPERATOR (*)
100+ INSTANTIATE_OPERATOR (/)
77101 }
78102}
You can’t perform that action at this time.
0 commit comments