Skip to content

Commit 8f4cfc1

Browse files
pavankyumar456
authored and committed
Adding negate, reciprocal, subtract and divide
1 parent f9ec214 commit 8f4cfc1

File tree

5 files changed

+128
-48
lines changed

5 files changed

+128
-48
lines changed

examples/autograd.cpp

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,57 +9,87 @@
99

1010
#include <af/autograd.h>
1111

12+
// Reports whether af::allTrue holds for (|VAL| < 1E-5) and prints
// "<function>:<line> PASS" or "<function>:<line> FAIL" accordingly.
// Wrapped in do { ... } while(0) so it can be used like a single statement.
#define VERIFY(VAL) do {                                            \
        const bool within_tolerance =                               \
            af::allTrue<bool>(af::abs(VAL) < 1E-5);                 \
        const char *verdict = within_tolerance ? "PASS" : "FAIL";   \
        printf("%s:%d %s\n", __FUNCTION__, __LINE__, verdict);      \
    } while(0)
17+
1218
using af::autograd::Variable;
13-
void test1()
19+
void test_multiply()
1420
{
1521
auto x = Variable(af::randu(5), true);
16-
af_print(x.array());
1722
auto y = x * x;
18-
af_print(y.array());
1923
auto dy = Variable(af::constant(1.0, 5), false);
2024
y.backward(dy);
2125
auto dx = x.grad();
22-
af_print(dx.array() - 2 * x.array());
26+
VERIFY(dx.array() - 2 * x.array());
2327
}
2428

25-
void test2()
29+
void test_multipl_add()
2630
{
2731
auto x = Variable(af::randu(5), true);
28-
af_print(x.array());
2932
auto y = Variable(af::randu(5), true);
30-
af_print(y.array());
3133
auto z = x * x + x * y + y * y;
3234
auto dz = Variable(af::constant(1.0, 5), false);
3335
z.backward(dz);
3436
auto dx = x.grad();
3537
auto dy = y.grad();
36-
af_print(dx.array() - 2 * x.array() - y.array());
37-
af_print(dy.array() - 2 * y.array() - x.array());
38+
VERIFY(dx.array() - 2 * x.array() - y.array());
39+
VERIFY(dy.array() - 2 * y.array() - x.array());
3840
}
3941

40-
void test3()
42+
void test_no_calc_grad()
4143
{
4244
auto x = Variable(af::randu(5), false);
43-
af_print(x.array());
4445
auto y = Variable(af::randu(5), true);
45-
af_print(y.array());
4646
auto z = x * x + x * y + y * y;
4747
auto dz = Variable(af::constant(1.0, 5), false);
4848
z.backward(dz);
4949
auto dy = y.grad();
50-
af_print(dy.array() - 2 * y.array() - x.array());
50+
VERIFY(dy.array() - 2 * y.array() - x.array());
5151
try {
5252
auto dx = x.grad();
5353
} catch(af::exception &ex) {
5454
std::cout << ex.what() << std::endl;
55+
return;
5556
}
57+
printf("%s:%d No Gradient check Failed\n");
58+
}
59+
60+
void test_multiply_sub()
61+
{
62+
auto x = Variable(af::randu(5), true);
63+
auto y = Variable(af::randu(5), true);
64+
auto z = x * x - x * y;
65+
auto dz = Variable(af::constant(1.0, 5), false);
66+
z.backward(dz);
67+
auto dx = x.grad();
68+
auto dy = y.grad();
69+
VERIFY(dx.array() - (2 * x.array() - y.array()));
70+
VERIFY(dy.array() - (-x.array()));
71+
}
72+
73+
void test_divide_add()
74+
{
75+
auto x = Variable(af::randu(5), true);
76+
auto y = Variable(af::randu(5), true);
77+
auto z = x + x / y + y;
78+
auto dz = Variable(af::constant(1.0, 5), false);
79+
z.backward(dz);
80+
auto dx = x.grad();
81+
auto dy = y.grad();
82+
VERIFY(dx.array() - (1.0 + 1.0 / y.array()));
83+
VERIFY(dy.array() - (1.0 - x.array() / (y.array() * y.array())));
5684
}
5785

5886
int main()
5987
{
6088
af::info();
61-
test1();
62-
test2();
63-
test3();
89+
test_multiply();
90+
test_multipl_add();
91+
test_no_calc_grad();
92+
test_multiply_sub();
93+
test_divide_add();
6494
return 0;
6595
}

include/af/autograd/Functions.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ namespace af {
1313

1414
class Variable;
1515

16-
Variable operator +(const Variable lhs, const Variable rhs);
17-
Variable operator *(const Variable lhs, const Variable rhs);
16+
// Differentiable operations on Variable; definitions live in
// src/autograd/Functions.cpp. Each returns a new Variable whose inputs are
// the operands and whose gradient function feeds gradients back to them.

// Sum of the arrays held by lhs and rhs.
Variable operator +(const Variable &lhs, const Variable &rhs);
17+
// Product of the arrays held by lhs and rhs.
Variable operator *(const Variable &lhs, const Variable &rhs);
18+
// Difference lhs.array() - rhs.array().
Variable operator -(const Variable &lhs, const Variable &rhs);
19+
// Quotient lhs.array() / rhs.array().
Variable operator /(const Variable &lhs, const Variable &rhs);
20+
21+
// Computes 0.0 - input.array().
Variable negate(const Variable &input);
22+
// Computes 1.0 / input.array().
Variable reciprocal(const Variable &input);
1823
}
1924
}

include/af/autograd/Variable.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ namespace af {
2222
class Variable
2323
{
2424
public:
25-
typedef std::function<void(std::vector<Variable>, Variable)> GradFunc_t;
25+
typedef std::function<void(std::vector<Variable> &, const Variable &)> GradFunc_t;
2626
typedef std::unordered_map<std::ptrdiff_t, bool> Cache_t;
2727
typedef std::vector<Variable> DAG_t;
2828

2929
private:
3030
struct Shared {
3131
Shared();
32-
Shared(af::array data, bool calc_grad);
33-
Shared(af::array data,
34-
std::vector<Variable> inputs,
32+
Shared(const af::array &data, bool calc_grad);
33+
Shared(const af::array &data,
34+
const std::vector<Variable> &inputs,
3535
GradFunc_t grad_func,
3636
bool calc_grad);
3737

@@ -45,26 +45,26 @@ namespace af {
4545
public:
4646

4747
Variable();
48-
Variable(af::array data, bool calc_grad);
49-
Variable(af::array data,
50-
std::vector<Variable> inputs,
48+
Variable(const af::array &data, bool calc_grad);
49+
Variable(const af::array &data,
50+
const std::vector<Variable> &inputs,
5151
GradFunc_t grad_func);
5252

5353
af::array array() const;
5454

5555
Variable grad() const;
5656

57-
bool isCalcGrad();
57+
bool isCalcGrad() const;
5858

5959
void setCalcGrad(bool calc_grad);
6060

61-
void addGrad(Variable child_grad);
61+
void addGrad(const Variable &child_grad);
6262

6363
void evalGrad();
6464

6565
void calcGradInputs();
6666

67-
void backward(Variable grad);
67+
void backward(const Variable &grad);
6868

6969
DAG_t build();
7070

src/autograd/Functions.cpp

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,66 @@
1313
namespace af {
1414
namespace autograd {
1515

16-
Variable operator +(const Variable lhs, const Variable rhs)
16+
// Adds the arrays held by lhs and rhs.
// Backward: the incoming gradient is passed unchanged to both operands.
Variable operator +(const Variable &lhs, const Variable &rhs)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        operands[0].addGrad(upstream);
        operands[1].addGrad(upstream);
    };
    const af::array sum = lhs.array() + rhs.array();
    return Variable(sum, {lhs, rhs}, backward_fn);
}
2525

26-
Variable operator *(const Variable lhs, const Variable rhs)
26+
// Subtracts rhs.array() from lhs.array().
// Backward: lhs receives the incoming gradient, rhs receives its negation.
Variable operator -(const Variable &lhs, const Variable &rhs)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        operands[0].addGrad(upstream);
        operands[1].addGrad(negate(upstream));
    };
    const af::array difference = lhs.array() - rhs.array();
    return Variable(difference, {lhs, rhs}, backward_fn);
}
35+
36+
// Multiplies the arrays held by lhs and rhs.
// Backward: each operand receives the incoming gradient times the other operand.
Variable operator *(const Variable &lhs, const Variable &rhs)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        Variable &first = operands[0];
        Variable &second = operands[1];
        first.addGrad(upstream * second);
        second.addGrad(upstream * first);
    };
    const af::array product = lhs.array() * rhs.array();
    return Variable(product, {lhs, rhs}, backward_fn);
}
3545

46+
// Divides lhs.array() by rhs.array().
// Backward, with g the incoming gradient:
//   numerator   receives g * (1 / rhs)
//   denominator receives g * (1 / rhs) * (-lhs) * (1 / rhs)
Variable operator /(const Variable &lhs, const Variable &rhs)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        Variable &numerator = operands[0];
        Variable &denominator = operands[1];

        // The reciprocal is computed once and used for both gradients.
        const Variable inv_denominator = reciprocal(denominator);
        const Variable grad_numerator = upstream * inv_denominator;

        numerator.addGrad(grad_numerator);
        denominator.addGrad(grad_numerator * negate(numerator) * inv_denominator);
    };
    const af::array quotient = lhs.array() / rhs.array();
    return Variable(quotient, {lhs, rhs}, backward_fn);
}
57+
58+
// Computes 0.0 - input.array().
// Backward: the input receives the negated incoming gradient.
Variable negate(const Variable &input)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        operands[0].addGrad(negate(upstream));
    };
    const af::array negated = 0.0 - input.array();
    return Variable(negated, {input}, backward_fn);
}
66+
67+
// Computes 1.0 / input.array().
// Backward: the input receives -g * (1/input) * (1/input), with g the
// incoming gradient.
Variable reciprocal(const Variable &input)
{
    auto backward_fn = [](std::vector<Variable> &operands, const Variable &upstream) {
        const Variable inv = reciprocal(operands[0]);
        operands[0].addGrad(negate(upstream) * inv * inv);
    };
    const af::array inverted = 1.0 / input.array();
    return Variable(inverted, {input}, backward_fn);
}
76+
3677
}
3778
}

src/autograd/Variable.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ namespace af {
2121
m_grad_func(nullptr)
2222
{}
2323

24-
Variable::Shared::Shared(af::array data, bool calc_grad) :
24+
// Shared state for a variable built directly from data: stores the array and
// the calc_grad flag, leaves inputs and gradients empty, and sets no
// gradient function.
Variable::Shared::Shared(const af::array &data, bool calc_grad)
    : m_calc_grad(calc_grad)
    , m_data(data)
    , m_inputs()
    , m_grads()
    , m_grad_func(nullptr)
{
}
3131

32-
Variable::Shared::Shared(af::array data,
33-
std::vector<Variable> inputs,
32+
Variable::Shared::Shared(const af::array &data,
33+
const std::vector<Variable> &inputs,
3434
GradFunc_t grad_func,
3535
bool calc_grad) :
3636
m_calc_grad(calc_grad),
@@ -45,13 +45,13 @@ namespace af {
4545
{
4646
}
4747

48-
Variable::Variable(af::array data, bool calc_grad) :
48+
// Wraps data in a Variable by allocating a fresh Shared state holding the
// array and the calc_grad flag.
Variable::Variable(const af::array &data, bool calc_grad)
    : m_shared(new Shared(data, calc_grad))
{
}
5151

52-
Variable::Variable(af::array data,
53-
std::vector<Variable> inputs,
54-
GradFunc_t grad_func) :
52+
Variable::Variable(const af::array &data,
53+
const std::vector<Variable> &inputs,
54+
GradFunc_t grad_func) :
5555
m_shared(nullptr)
5656
{
5757
bool calc_grad = false;
@@ -81,7 +81,7 @@ namespace af {
8181
return m_shared->m_grads[0];
8282
}
8383

84-
bool Variable::isCalcGrad()
84+
bool Variable::isCalcGrad() const
8585
{
8686
return m_shared->m_calc_grad;
8787
}
@@ -96,7 +96,7 @@ namespace af {
9696
}
9797
}
9898

99-
void Variable::addGrad(Variable child_grad)
99+
void Variable::addGrad(const Variable &child_grad)
100100
{
101101
if (m_shared->m_calc_grad) {
102102
m_shared->m_grads.push_back(child_grad);
@@ -107,13 +107,17 @@ namespace af {
107107
{
108108
// Flag asking not to calculate gradients
109109
if (!m_shared->m_calc_grad) return;
110-
Variable grad = m_shared->m_grads[0];
111-
for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
112-
grad = grad + m_shared->m_grads[i];
110+
111+
// Best not to evaluate the JIT immediately if there's only a single gradient
112+
if (m_shared->m_grads.size() > 1) {
113+
Variable grad = m_shared->m_grads[0];
114+
for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
115+
grad = grad + m_shared->m_grads[i];
116+
}
117+
grad.array().eval();
118+
m_shared->m_grads.clear();
119+
m_shared->m_grads.push_back(grad);
113120
}
114-
grad.array().eval();
115-
m_shared->m_grads.clear();
116-
m_shared->m_grads.push_back(grad);
117121
}
118122

119123
void Variable::calcGradInputs()
@@ -124,7 +128,7 @@ namespace af {
124128
}
125129
}
126130

127-
void Variable::backward(Variable grad)
131+
void Variable::backward(const Variable &grad)
128132
{
129133
this->addGrad(grad);
130134
Variable::DAG_t dag = this->build();

0 commit comments

Comments
 (0)