Skip to content

Commit 7e0fe27

Browse files
pavankyumar456
authored andcommitted
Add scalar support for operators
1 parent 8f4cfc1 commit 7e0fe27

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

examples/autograd.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ void test_divide_add()
8383
VERIFY(dy.array() - (1.0 - x.array() / (y.array() * y.array())));
8484
}
8585

86+
void test_multiply_add_scalar()
87+
{
88+
auto x = Variable(af::randu(5), true);
89+
auto y = Variable(af::randu(5), true);
90+
auto z = 2 * x + x * y + y;
91+
auto dz = Variable(af::constant(1.0, 5), false);
92+
z.backward(dz);
93+
auto dx = x.grad();
94+
auto dy = y.grad();
95+
VERIFY(dx.array() - (2.0 + y.array()));
96+
VERIFY(dy.array() - (1.0 + x.array()));
97+
}
98+
8699
int main()
87100
{
88101
af::info();
@@ -91,5 +104,6 @@ int main()
91104
test_no_calc_grad();
92105
test_multiply_sub();
93106
test_divide_add();
107+
test_multiply_add_scalar();
94108
return 0;
95109
}

include/af/autograd/Functions.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ namespace af {
1818
Variable operator -(const Variable &lhs, const Variable &rhs);
1919
Variable operator /(const Variable &lhs, const Variable &rhs);
2020

21+
Variable operator +(const double &lhs, const Variable &rhs);
22+
Variable operator *(const double &lhs, const Variable &rhs);
23+
Variable operator -(const double &lhs, const Variable &rhs);
24+
Variable operator /(const double &lhs, const Variable &rhs);
25+
26+
Variable operator +(const Variable &lhs, const double &rhs);
27+
Variable operator *(const Variable &lhs, const double &rhs);
28+
Variable operator -(const Variable &lhs, const double &rhs);
29+
Variable operator /(const Variable &lhs, const double &rhs);
30+
2131
Variable negate(const Variable &input);
2232
Variable reciprocal(const Variable &input);
2333
}

src/autograd/Functions.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,29 @@ namespace af {
7474
return Variable(result, {input}, grad_func);
7575
}
7676

77+
78+
#define INSTANTIATE_OPERATOR(OP) \
79+
Variable operator OP(const double &lhs_val, const Variable &rhs) \
80+
{ \
81+
auto lhs = Variable( \
82+
af::constant(lhs_val, \
83+
rhs.array().dims(), \
84+
rhs.array().type()), \
85+
false); \
86+
return lhs OP rhs; \
87+
} \
88+
Variable operator OP(const Variable &lhs, const double &rhs_val) \
89+
{ \
90+
auto rhs = Variable( \
91+
af::constant(rhs_val, \
92+
lhs.array().dims(), lhs.array().type()), \
93+
false); \
94+
return lhs OP rhs; \
95+
} \
96+
97+
INSTANTIATE_OPERATOR(+)
98+
INSTANTIATE_OPERATOR(-)
99+
INSTANTIATE_OPERATOR(*)
100+
INSTANTIATE_OPERATOR(/)
77101
}
78102
}

0 commit comments

Comments
 (0)