Skip to content

Commit 93e738d

Browse files
pavankyumar456
authored and committed
Adding matmul, matmulTN, and matmulNT functions
1 parent 546a87a commit 93e738d

2 files changed

Lines changed: 67 additions & 0 deletions

File tree

include/af/autograd/Functions.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,9 @@ namespace af {
4040
Variable transpose(const Variable &input);
4141
Variable expandAs(const Variable &input, const Variable &reference);
4242
Variable reduceAs(const Variable &input, const Variable &reference);
43+
44+
Variable matmul(const Variable &lhs, const Variable &rhs); // differentiable lhs * rhs
45+
Variable matmulTN(const Variable &lhs, const Variable &rhs); // differentiable transpose(lhs) * rhs
46+
Variable matmulNT(const Variable &lhs, const Variable &rhs); // differentiable lhs * transpose(rhs)
4347
}
4448
}

src/autograd/Functions.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,5 +184,68 @@ namespace af {
184184
};
185185
return Variable(result, {input}, grad_func);
186186
}
187+
188+
Variable matmul(const Variable &lhs, const Variable &rhs)
189+
{
190+
// Forward pass (shape notation) -- lhs:Input[0] -- [M, N]
191+
// rhs:Input[1] -- [N, K]
192+
// matmul(lhs, rhs)
193+
// -- matmul([M, N], [N, K]) -- [M, K]
194+
// result:grad_output -- [M, K] (grad_output has the same shape as result)
195+
auto result = matmul(lhs.array(), rhs.array());
196+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
197+
// Gradient for lhs (inputs[0]): matmulNT(grad_output, inputs[1])
198+
// -- matmulNT([M, K], [N, K])
199+
// -- matmul([M, K], [K, N]) -- [M, N] (matches the shape of lhs)
200+
inputs[0].addGrad(matmulNT(grad_output, inputs[1]));
201+
// Gradient for rhs (inputs[1]): matmulTN(inputs[0], grad_output)
202+
// -- matmulTN([M, N], [M, K])
203+
// -- matmul([N, M], [M, K]) -- [N, K] (matches the shape of rhs)
204+
inputs[1].addGrad(matmulTN(inputs[0], grad_output));
205+
};
206+
// lhs and rhs are recorded as inputs[0] and inputs[1] for grad_func.
return Variable(result, {lhs, rhs}, grad_func);
207+
}
208+
209+
Variable matmulTN(const Variable &lhs, const Variable &rhs)
210+
{
211+
// Forward pass (shape notation) -- lhs:Input[0] -- [N, M]
212+
// rhs:Input[1] -- [N, K]
213+
// matmulTN(lhs, rhs) multiplies the transpose of lhs by rhs
214+
// -- matmulTN([N, M], [N, K])
215+
// -- matmul([M, N], [N, K]) -- [M, K]
216+
// result:grad_output -- [M, K] (grad_output has the same shape as result)
217+
auto result = matmulTN(lhs.array(), rhs.array());
218+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
219+
// Gradient for lhs (inputs[0]): matmulNT(inputs[1], grad_output)
220+
// -- matmulNT([N, K], [M, K])
221+
// -- matmul([N, K], [K, M]) -- [N, M] (matches the shape of lhs)
222+
inputs[0].addGrad(matmulNT(inputs[1], grad_output));
223+
// Gradient for rhs (inputs[1]): matmul(inputs[0], grad_output)
224+
// -- matmul([N, M], [M, K]) -- [N, K] (matches the shape of rhs)
225+
inputs[1].addGrad(matmul(inputs[0], grad_output));
226+
};
227+
// lhs and rhs are recorded as inputs[0] and inputs[1] for grad_func.
return Variable(result, {lhs, rhs}, grad_func);
228+
}
229+
230+
Variable matmulNT(const Variable &lhs, const Variable &rhs)
231+
{
232+
// Forward pass (shape notation) -- lhs:Input[0] -- [M, N]
233+
// rhs:Input[1] -- [K, N]
234+
// matmulNT(lhs, rhs) multiplies lhs by the transpose of rhs
235+
// -- matmulNT([M, N], [K, N])
236+
// -- matmul([M, N], [N, K]) -- [M, K]
237+
// result:grad_output -- [M, K] (grad_output has the same shape as result)
238+
auto result = matmulNT(lhs.array(), rhs.array());
239+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
240+
// Gradient for lhs (inputs[0]): matmul(grad_output, inputs[1])
241+
// -- matmul([M, K], [K, N]) -- [M, N] (matches the shape of lhs)
242+
inputs[0].addGrad(matmul(grad_output, inputs[1]));
243+
// Gradient for rhs (inputs[1]): matmulTN(grad_output, inputs[0])
244+
// -- matmulTN([M, K], [M, N])
245+
// -- matmul([K, M], [M, N]) -- [K, N] (matches the shape of rhs)
246+
inputs[1].addGrad(matmulTN(grad_output, inputs[0]));
247+
};
248+
// lhs and rhs are recorded as inputs[0] and inputs[1] for grad_func.
return Variable(result, {lhs, rhs}, grad_func);
249+
}
187250
}
188251
}

0 commit comments

Comments
 (0)