Skip to content

Commit f409c99

Browse files
lisa0314chromium-wpt-export-bot
authored andcommitted
webnn: Enforce input data type constraints for prelu
As specified in webmachinelearning/webnn#646. Besides, this CL also fixes the throwing TypeError and migrates the unit tests to WPT for prelu. Bug: 328567884, 327337526, 328026885 Change-Id: I2c6c0097b8d4fdc8c92bd66d21abd2cbda91e030 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5495874 Reviewed-by: Austin Sullivan <asully@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Commit-Queue: Lisha Guo <lisha.guo@intel.com> Cr-Commit-Position: refs/heads/main@{#1297344}
1 parent b5ab7ed commit f409c99

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

webnn/validation_tests/prelu.https.any.js

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,87 @@
55
'use strict';
66

77
validateTwoInputsFromMultipleBuilders('prelu');
8+
9+
const tests = [
10+
{
11+
name:
12+
'[prelu] Test slope\'s shape = [3, 2, 5] which is the same as input\'s shape.',
13+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
14+
slope: {dataType: 'float32', dimensions: [3, 2, 5]},
15+
output: {dataType: 'float32', dimensions: [3, 2, 5]},
16+
},
17+
{
18+
name:
19+
'[prelu] Test slope\'s shape = [5] which is unidirectionally broadcastable to input\'s shape.',
20+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
21+
slope: {dataType: 'float32', dimensions: [5]},
22+
output: {dataType: 'float32', dimensions: [3, 2, 5]},
23+
},
24+
{
25+
name:
26+
'[prelu] Test slope\'s shape = [] which is unidirectionally broadcastable to input\'s shape.',
27+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
28+
slope: {dataType: 'float32', dimensions: []},
29+
output: {dataType: 'float32', dimensions: [3, 2, 5]},
30+
},
31+
{
32+
name:
33+
'[prelu] Test slope\'s shape = [2, 5] which is unidirectionally broadcastable to input\'s shape.',
34+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
35+
slope: {dataType: 'float32', dimensions: [2, 5]},
36+
output: {dataType: 'float32', dimensions: [3, 2, 5]},
37+
},
38+
{
39+
name:
40+
'[prelu] Test with input\'s dataType = int32 and slope\'s dataType = int32.',
41+
input: {dataType: 'int32', dimensions: [3, 2, 5]},
42+
slope: {dataType: 'int32', dimensions: [2, 5]},
43+
output: {dataType: 'int32', dimensions: [3, 2, 5]},
44+
},
45+
{
46+
name:
47+
'[prelu] Test with input\'s dataType = int8 and slope\'s dataType = int8.',
48+
input: {dataType: 'int8', dimensions: [3, 2, 5]},
49+
slope: {dataType: 'int8', dimensions: [2, 5]},
50+
output: {dataType: 'int8', dimensions: [3, 2, 5]},
51+
},
52+
{
53+
name:
54+
'[prelu] Throw if the shape of slope is not broadcastable to the shape of input.',
55+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
56+
slope: {dataType: 'float32', dimensions: [2]},
57+
},
58+
{
59+
name:
60+
'[prelu] Throw if the data type of slope does not match the data type of input.',
61+
input: {dataType: 'float32', dimensions: [3, 2, 5]},
62+
slope: {dataType: 'int32', dimensions: [3, 2, 5]},
63+
},
64+
{
65+
name: '[prelu] Throw if the data type of input is int64.',
66+
input: {dataType: 'int64', dimensions: [3, 2, 5]},
67+
slope: {dataType: 'int64', dimensions: [3, 2, 5]},
68+
},
69+
{
70+
name: '[prelu] Throw if the data type of input is uint32.',
71+
input: {dataType: 'uint32', dimensions: [3, 2, 5]},
72+
slope: {dataType: 'uint32', dimensions: [3, 2, 5]},
73+
},
74+
];
75+
76+
tests.forEach(
77+
test => promise_test(async t => {
78+
const input = builder.input(
79+
'input',
80+
{dataType: test.input.dataType, dimensions: test.input.dimensions});
81+
const slope = builder.input(
82+
'input',
83+
{dataType: test.slope.dataType, dimensions: test.slope.dimensions});
84+
if (test.output) {
85+
const output = builder.prelu(input, slope);
86+
assert_equals(output.dataType(), test.output.dataType);
87+
assert_array_equals(output.shape(), test.output.dimensions);
88+
} else {
89+
assert_throws_js(TypeError, () => builder.prelu(input, slope));
90+
}
91+
}, test.name));

0 commit comments

Comments
 (0)