1414#include < memory>
1515#include < vector>
1616#include < unordered_map>
17+ #include < stdexcept>
1718
1819#include < arrayfire.h>
1920
2021namespace af {
2122 namespace autograd {
23+
24+ // Forward declare the function
25+ class Variable ;
26+ Variable operator +(const Variable lhs, const Variable rhs);
27+
2228 class Variable
2329 {
2430 public:
@@ -31,25 +37,22 @@ namespace af {
3137 public:
3238 Shared () :
3339 m_data (),
34- m_grad (),
3540 m_inputs (),
36- m_grad_parts (),
41+ m_grads (),
3742 m_backward (nullptr )
3843 {}
3944
4045 Shared (af::array data) :
4146 m_data (data),
42- m_grad (af::constant(0 , data.dims(), data.type())),
4347 m_inputs (),
44- m_grad_parts (),
48+ m_grads (),
4549 m_backward (nullptr )
4650 {}
4751
4852 Shared (af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
4953 m_data (data),
50- m_grad (af::constant(0 , data.dims(), data.type())),
5154 m_inputs (inputs.begin(), inputs.end()),
52- m_grad_parts (),
55+ m_grads (),
5356 m_backward (backward)
5457 {}
5558
@@ -58,19 +61,17 @@ namespace af {
5861 return m_data;
5962 }
6063
61- af::array getGrad () const
64+ Variable getGrad () const
6265 {
63- return m_grad;
66+ if (m_grads.size () == 0 ) {
67+ throw std::runtime_error (" Gradient hasn't been calculated" );
68+ }
69+ return m_grads[0 ];
6470 }
6571
6672 void addGrad (Variable grad)
6773 {
68- m_grad_parts.push_back (grad);
69- }
70-
71- std::vector<Variable> getGradParts ()
72- {
73- return m_grad_parts;
74+ m_grads.push_back (grad);
7475 }
7576
7677 std::vector<Variable> getInputs ()
@@ -80,24 +81,26 @@ namespace af {
8081
8182 void evalGrad ()
8283 {
83- m_grad = m_grad_parts[0 ].getData ();
84- for (int i = 1 ; i < (int )m_grad_parts.size (); i++) {
85- m_grad += m_grad_parts[i].getData ();
84+ if (m_grads.size () == 1 ) return ;
85+ Variable grad = m_grads[0 ];
86+ for (int i = 1 ; i < (int )m_grads.size (); i++) {
87+ grad = grad + m_grads[i];
8688 }
87- af::eval (m_grad);
89+ grad.getData ().eval ();
90+ m_grads.clear ();
91+ m_grads.push_back (grad);
8892 }
8993
9094 void backward ()
9195 {
9296 this ->evalGrad ();
93- if (m_backward) m_backward (m_inputs, m_grad );
97+ if (m_backward) m_backward (m_inputs, m_grads[ 0 ] );
9498 }
9599
96100 private:
97101 af::array m_data;
98- af::array m_grad;
99102 std::vector<Variable> m_inputs;
100- std::vector<Variable> m_grad_parts ;
103+ std::vector<Variable> m_grads ;
101104 BackwardFunc_t m_backward;
102105 };
103106
@@ -123,7 +126,7 @@ namespace af {
123126 return m_shared->getData ();
124127 }
125128
126- af::array getGrad () const
129+ Variable getGrad () const
127130 {
128131 return m_shared->getGrad ();
129132 }
0 commit comments