Skip to content

Commit 546a87a

Browse files
pavankyumar456
authored andcommitted
Adding expandAs, reduceAs, and transpose
1 parent 75ca803 commit 546a87a

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

examples/autograd.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,32 @@ void test_tanh()
128128
VERIFY(dx.array() - (1 + af::tanh(x.array())) * (1 - af::tanh(x.array())));
129129
}
130130

131+
void test_expand()
132+
{
133+
auto x = Variable(af::randu(5), true);
134+
auto y = Variable(af::randu(5, 2), true);
135+
auto z = y * expandAs(x, y);
136+
auto dz = Variable(af::constant(1.0, 5, 2), false);
137+
z.backward(dz);
138+
auto dy = y.grad();
139+
auto dx = x.grad();
140+
VERIFY(dy.array() - af::tile(x.array(), 1, 2));
141+
VERIFY(dx.array() - af::sum(y.array(), 1));
142+
}
143+
144+
void test_reduce()
145+
{
146+
auto x = Variable(af::randu(5), true);
147+
auto y = Variable(af::randu(5, 2), true);
148+
auto z = x * reduceAs(y, x);
149+
auto dz = Variable(af::constant(1.0, 5), false);
150+
z.backward(dz);
151+
auto dy = y.grad();
152+
auto dx = x.grad();
153+
VERIFY(dy.array() - af::tile(x.array(), 1, 2));
154+
VERIFY(dx.array() - af::sum(y.array(), 1));
155+
}
156+
131157
int main()
132158
{
133159
af::info();
@@ -140,5 +166,7 @@ int main()
140166
test_exp();
141167
test_sigmoid();
142168
test_tanh();
169+
test_expand();
170+
test_reduce();
143171
return 0;
144172
}

include/af/autograd/Functions.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,9 @@ namespace af {
3636
Variable cos(const Variable &input);
3737
Variable tanh(const Variable &input);
3838
Variable sigmoid(const Variable &input);
39+
40+
Variable transpose(const Variable &input);
41+
Variable expandAs(const Variable &input, const Variable &reference);
42+
Variable reduceAs(const Variable &input, const Variable &reference);
3943
}
4044
}

src/autograd/Functions.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,42 @@ namespace af {
147147
return Variable(result, {input}, grad_func);
148148
}
149149

150+
Variable transpose(const Variable &input)
151+
{
152+
auto result = transpose(input.array());
153+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
154+
inputs[0].addGrad(transpose(grad_output));
155+
};
156+
return Variable(result, {input}, grad_func);
157+
}
158+
159+
Variable expandAs(const Variable &input, const Variable &reference)
160+
{
161+
dim4 dims(1,1,1,1);
162+
dim4 idims = input.array().dims();
163+
dim4 rdims = reference.array().dims();
164+
for (int i = 0; i < 4; i++) {
165+
dims[i] = rdims[i] / idims[i];
166+
}
167+
auto result = tile(input.array(), dims);
168+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
169+
inputs[0].addGrad(reduceAs(grad_output, inputs[0]));
170+
};
171+
return Variable(result, {input}, grad_func);
172+
}
173+
174+
Variable reduceAs(const Variable &input, const Variable &reference)
175+
{
176+
dim4 idims = input.array().dims();
177+
dim4 rdims = reference.array().dims();
178+
auto result = input.array();
179+
for (int i = 0; i < 4; i++) {
180+
if (idims[i] != rdims[i]) result = sum(result, i);
181+
}
182+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
183+
inputs[0].addGrad(expandAs(grad_output, inputs[0]));
184+
};
185+
return Variable(result, {input}, grad_func);
186+
}
150187
}
151188
}

0 commit comments

Comments
 (0)