Skip to content

Commit 1bfb22e

Browse files
pavankyumar456
authored and committed
Changing expandAs, reduceAs to tileAs, sumAs
1 parent 4ad39d8 commit 1bfb22e

File tree

5 files changed

+25
-25
lines changed

5 files changed

+25
-25
lines changed

examples/autograd.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ void test_tanh()
128128
VERIFY(dx.array() - (1 + af::tanh(x.array())) * (1 - af::tanh(x.array())));
129129
}
130130

131-
void test_expand()
131+
void test_tile()
132132
{
133133
auto x = Variable(af::randu(5), true);
134134
auto y = Variable(af::randu(5, 2), true);
135-
auto z = y * expandAs(x, y);
135+
auto z = y * tileAs(x, y);
136136
auto dz = Variable(af::constant(1.0, 5, 2), false);
137137
z.backward(dz);
138138
auto dy = y.grad();
@@ -141,11 +141,11 @@ void test_expand()
141141
VERIFY(dx.array() - af::sum(y.array(), 1));
142142
}
143143

144-
void test_reduce()
144+
void test_sum()
145145
{
146146
auto x = Variable(af::randu(5), true);
147147
auto y = Variable(af::randu(5, 2), true);
148-
auto z = x * reduceAs(y, x);
148+
auto z = x * sumAs(y, x);
149149
auto dz = Variable(af::constant(1.0, 5), false);
150150
z.backward(dz);
151151
auto dy = y.grad();
@@ -166,7 +166,7 @@ int main()
166166
test_exp();
167167
test_sigmoid();
168168
test_tanh();
169-
test_expand();
170-
test_reduce();
169+
test_tile();
170+
test_sum();
171171
return 0;
172172
}

include/af/autograd/Functions.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99
#pragma once
10+
#include <arrayfire.h>
1011

1112
namespace af {
1213
namespace autograd {
@@ -50,18 +51,18 @@ namespace af {
5051
Variable cos(const Variable &input);
5152
Variable tanh(const Variable &input);
5253
Variable sigmoid(const Variable &input);
53-
54+
5455
Variable max(const Variable &lhs, const Variable &rhs);
5556
Variable max(const Variable &lhs, const double &rhs);
5657
Variable max(const double &lhs, const Variable &rhs);
57-
58+
5859
Variable min(const Variable &lhs, const Variable &rhs);
5960
Variable min(const Variable &lhs, const double &rhs);
6061
Variable min(const double &lhs, const Variable &rhs);
61-
62+
6263
Variable transpose(const Variable &input);
63-
Variable expandAs(const Variable &input, const Variable &reference);
64-
Variable reduceAs(const Variable &input, const Variable &reference);
64+
Variable tileAs(const Variable &input, const Variable &reference);
65+
Variable sumAs(const Variable &input, const Variable &reference);
6566

6667
Variable matmul(const Variable &lhs, const Variable &rhs);
6768
Variable matmulTN(const Variable &lhs, const Variable &rhs);

src/autograd/Functions.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace af {
5454
};
5555
return Variable(result, {lhs, rhs}, grad_func);
5656
}
57-
57+
5858
Variable operator >(const Variable &lhs, const Variable &rhs)
5959
{
6060
auto result = lhs.array() > rhs.array();
@@ -116,7 +116,7 @@ namespace af {
116116
auto result = !input.array();
117117
return Variable(result, false);
118118
}
119-
119+
120120
Variable max(const Variable &lhs, const Variable &rhs)
121121
{
122122
auto mask = lhs > rhs;
@@ -165,7 +165,7 @@ namespace af {
165165
INSTANTIATE_FUNCTION(min);
166166

167167
#undef INSTANTIATE_FUNCTION
168-
168+
169169
Variable negate(const Variable &input)
170170
{
171171
auto result = 0.0 - input.array();
@@ -241,31 +241,31 @@ namespace af {
241241
return Variable(result, {input}, grad_func);
242242
}
243243

244-
Variable expandAs(const Variable &input, const Variable &reference)
244+
Variable tileAs(const Variable &input, const Variable &reference)
245245
{
246246
dim4 dims(1,1,1,1);
247-
dim4 idims = input.array().dims();
248-
dim4 rdims = reference.array().dims();
247+
dim4 rdims = reference.dims();
248+
dim4 idims = input.dims();
249249
for (int i = 0; i < 4; i++) {
250250
dims[i] = rdims[i] / idims[i];
251251
}
252252
auto result = tile(input.array(), dims);
253253
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
254-
inputs[0].addGrad(reduceAs(grad_output, inputs[0]));
254+
inputs[0].addGrad(sumAs(grad_output, inputs[0]));
255255
};
256256
return Variable(result, {input}, grad_func);
257257
}
258258

259-
Variable reduceAs(const Variable &input, const Variable &reference)
259+
Variable sumAs(const Variable &input, const Variable &reference)
260260
{
261-
dim4 idims = input.array().dims();
262-
dim4 rdims = reference.array().dims();
261+
dim4 rdims = reference.dims();
262+
dim4 idims = input.dims();
263263
auto result = input.array();
264264
for (int i = 0; i < 4; i++) {
265265
if (idims[i] != rdims[i]) result = sum(result, i);
266266
}
267267
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
268-
inputs[0].addGrad(expandAs(grad_output, inputs[0]));
268+
inputs[0].addGrad(tileAs(grad_output, inputs[0]));
269269
};
270270
return Variable(result, {input}, grad_func);
271271
}

src/nn/Modules/Activations.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ namespace af
6161
Variable PReLU::forward(const Variable &input)
6262
{
6363
auto mask = input >= 0.0;
64-
return (input * mask) + (input * !mask * expandAs(m_parameters[0],input));
64+
return (input * mask) + (input * !mask * tileAs(m_parameters[0], input));
6565
}
6666

6767
ELU::ELU(double alpha) :
@@ -85,6 +85,5 @@ namespace af
8585
auto mask = input >= m_threshold;
8686
return input * mask;
8787
}
88-
8988
}
9089
}

src/nn/Modules/Linear.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace af
5151
{
5252
auto res = matmul(m_parameters[0], input);
5353
if (m_bias) {
54-
res = res + expandAs(m_parameters[1], res);
54+
res = res + tileAs(m_parameters[1], res);
5555
}
5656
return res;
5757
}

0 commit comments

Comments
 (0)