@@ -184,5 +184,68 @@ namespace af {
184184 };
185185 return Variable (result, {input}, grad_func);
186186 }
187+
188+ Variable matmul (const Variable &lhs, const Variable &rhs)
189+ {
190+ // lhs:Input[0] -- [M, N]
191+ // rhs:Input[1] -- [N, K]
192+ // matmul(lhs, rhs)
193+ // -- matmul([M, N], [N, K]) -- [M, K]
194+ // result:grad_output -- [M, K]
195+ auto result = matmul (lhs.array (), rhs.array ());
196+ auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
197+ // matmulNT(grad_output, inputs[1])
198+ // -- matmulNT([M, K], [N, K])
199+ // -- matmul([M, K], [K, N]) -- [M, K]
200+ inputs[0 ].addGrad (matmulNT (grad_output, inputs[1 ]));
201+ // matmulTN(inputs[0], grad_output)
202+ // -- matmulTN([M, N], [M, K])
203+ // -- matmul([N, M], [M, K]) -- [N, K]
204+ inputs[1 ].addGrad (matmulTN (inputs[0 ], grad_output));
205+ };
206+ return Variable (result, {lhs, rhs}, grad_func);
207+ }
208+
209+ Variable matmulTN (const Variable &lhs, const Variable &rhs)
210+ {
211+ // lhs:Input[0] -- [N, M]
212+ // rhs:Input[1] -- [N, K]
213+ // matmulTN(lhs, rhs)
214+ // -- matmulTN([N, M], [N, K])
215+ // -- matmul([M, N], [N, K]) -- [M, K]
216+ // result:grad_output -- [M, K]
217+ auto result = matmulTN (lhs.array (), rhs.array ());
218+ auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
219+ // matmulNT(inputs[1], grad_output)
220+ // -- matmulNT([N, K], [M, K])
221+ // -- matmul([N, K], [K, M]) -- [N, M]
222+ inputs[0 ].addGrad (matmulNT (inputs[1 ], grad_output));
223+ // matmul(inputs[0], grad_output)
224+ // -- matmulNT([N, M], [M, K]) -- [N, K]
225+ inputs[1 ].addGrad (matmul (inputs[0 ], grad_output));
226+ };
227+ return Variable (result, {lhs, rhs}, grad_func);
228+ }
229+
230+ Variable matmulNT (const Variable &lhs, const Variable &rhs)
231+ {
232+ // lhs:Input[0] -- [M, N]
233+ // rhs:Input[1] -- [K, N]
234+ // matmulNT(lhs, rhs)
235+ // -- matmulNT([M, N], [K, N])
236+ // -- matmul([M, N], [N, K]) -- [M, K]
237+ // result:grad_output -- [M, K]
238+ auto result = matmulNT (lhs.array (), rhs.array ());
239+ auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
240+ // matmul(grad_output, inputs[1])
241+ // -- matmul([M, K], [K, N]) -- [M, N]
242+ inputs[0 ].addGrad (matmul (grad_output, inputs[1 ]));
243+ // matmulTN(grad_output, inputs[0])
244+ // -- matmulTN([M, K], [M, N])
245+ // -- matmul([K, M], [M, N]) -- [K, N]
246+ inputs[1 ].addGrad (matmulTN (grad_output, inputs[0 ]));
247+ };
248+ return Variable (result, {lhs, rhs}, grad_func);
249+ }
187250 }
188251}
0 commit comments