Skip to content

Commit 03cd3c5

Browse files
shiyi9801chromium-wpt-export-bot
authored andcommitted
webnn: Enforce input data type constraints for gemm and matmul
As specified in webmachinelearning/webnn#646 Bug: 328567884 Change-Id: Ia55a214e7ad281ec3c8911e9116f388fac209d05
1 parent 78709a8 commit 03cd3c5

2 files changed

Lines changed: 24 additions & 12 deletions

File tree

webnn/validation_tests/gemm.https.any.js

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ const tests = [
3434
b: {dataType: 'float32', dimensions: [2, 4]},
3535
},
3636
{
37-
name: 'Test building gemm with aTranspose=true.',
37+
name: '[gemm] Test building gemm with aTranspose=true.',
3838
a: {dataType: 'float32', dimensions: [2, 3]},
3939
b: {dataType: 'float32', dimensions: [2, 4]},
4040
options: {
@@ -44,15 +44,15 @@ const tests = [
4444
},
4545
{
4646
name:
47-
'Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
47+
'[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
4848
a: {dataType: 'float32', dimensions: [2, 3]},
4949
b: {dataType: 'float32', dimensions: [3, 4]},
5050
options: {
5151
aTranspose: true,
5252
},
5353
},
5454
{
55-
name: 'Test building gemm with bTranspose=true.',
55+
name: '[gemm] Test building gemm with bTranspose=true.',
5656
a: {dataType: 'float32', dimensions: [2, 3]},
5757
b: {dataType: 'float32', dimensions: [4, 3]},
5858
options: {
@@ -62,30 +62,30 @@ const tests = [
6262
},
6363
{
6464
name:
65-
'Throw if inputShapeA[0] is not equal to inputShapeB[0] with bTranspose=true.',
65+
'[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with bTranspose=true.',
6666
a: {dataType: 'float32', dimensions: [2, 3]},
6767
b: {dataType: 'float32', dimensions: [3, 4]},
6868
options: {
6969
bTranspose: true,
7070
},
7171
},
7272
{
73-
name: 'Throw if the rank of inputA is not 2.',
73+
name: '[gemm] Throw if the rank of inputA is not 2.',
7474
a: {dataType: 'float32', dimensions: [2, 3, 1]},
7575
b: {dataType: 'float32', dimensions: [2, 4]},
7676
},
7777
{
78-
name: 'Throw if the rank of inputB is not 2.',
78+
name: '[gemm] Throw if the rank of inputB is not 2.',
7979
a: {dataType: 'float32', dimensions: [2, 4]},
8080
b: {dataType: 'float32', dimensions: [2, 3, 1]},
8181
},
8282
{
83-
name: 'Throw if data types of two inputs do not match.',
83+
name: '[gemm] Throw if data types of two inputs do not match.',
8484
a: {dataType: 'float32', dimensions: [2, 3]},
8585
b: {dataType: 'int32', dimensions: [3, 4]},
8686
},
8787
{
88-
name: 'Test building gemm with inputC.',
88+
name: '[gemm] Test building gemm with inputC.',
8989
a: {dataType: 'float32', dimensions: [2, 3]},
9090
b: {dataType: 'float32', dimensions: [3, 4]},
9191
options: {
@@ -94,7 +94,7 @@ const tests = [
9494
output: {dataType: 'float32', dimensions: [2, 4]}
9595
},
9696
{
97-
name: 'Test building gemm with scalar inputC.',
97+
name: '[gemm] Test building gemm with scalar inputC.',
9898
a: {dataType: 'float32', dimensions: [2, 3]},
9999
b: {dataType: 'float32', dimensions: [3, 4]},
100100
options: {
@@ -104,16 +104,21 @@ const tests = [
104104
},
105105
{
106106
name:
107-
'Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
107+
'[gemm] Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
108108
a: {dataType: 'float32', dimensions: [2, 3]},
109109
b: {dataType: 'float32', dimensions: [3, 4]},
110110
options: {
111111
c: {dataType: 'float32', dimensions: [2, 3]},
112112
},
113113
},
114+
{
115+
name: '[gemm] Throw if the input data type is not floating point.',
116+
a: {dataType: 'int32', dimensions: [2, 3]},
117+
b: {dataType: 'int32', dimensions: [3, 4]}
118+
},
114119
{
115120
name:
116-
'Throw if data type of inputC does not match ones of inputA and inputB.',
121+
'[gemm] Throw if data type of inputC does not match ones of inputA and inputB.',
117122
a: {dataType: 'float32', dimensions: [3, 2]},
118123
b: {dataType: 'float32', dimensions: [4, 3]},
119124
options: {
@@ -123,7 +128,7 @@ const tests = [
123128
},
124129
},
125130
{
126-
name: 'Throw if the rank of inputC is 3.',
131+
name: '[gemm] Throw if the rank of inputC is 3.',
127132
a: {dataType: 'float32', dimensions: [3, 2]},
128133
b: {dataType: 'float32', dimensions: [4, 3]},
129134
options: {

webnn/validation_tests/matmul.https.any.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ const tests = [
6767
},
6868
output: {dataType: 'float32', dimensions: [2, 3, 5]}
6969
},
70+
{
71+
name: '[matmul] Throw if the input data type is not floating point',
72+
inputs: {
73+
a: {dataType: 'uint32', dimensions: [2, 3, 4]},
74+
b: {dataType: 'uint32', dimensions: [2, 4, 5]}
75+
}
76+
},
7077
{
7178
name: '[matmul] Throw if data type of two inputs don\'t match',
7279
inputs: {

0 commit comments

Comments
 (0)