Skip to content

Commit 06bcaad

Browse files
miaobinchromium-wpt-export-bot
authored andcommitted
WebNN: Support axis for softmax operator
This CL adds axis parameter into the IDL and mojo definitions of softmax operator [1]. It also updates the backends implementation to support the new axis parameter. In addition, the corresponding tests have also been updated. [1] webmachinelearning/webnn#649 Bug: 338094927 Change-Id: Ib08ecbba61c27c94256953a952357eeda80241e6 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel,mac14.arm64-blink-rel,mac14-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5495877 Reviewed-by: ningxin hu <ningxin.hu@intel.com> Commit-Queue: Bin Miao <bin.miao@intel.com> Reviewed-by: Austin Sullivan <asully@chromium.org> Cr-Commit-Position: refs/heads/main@{#1314418}
1 parent aa5fb60 commit 06bcaad

4 files changed

Lines changed: 196 additions & 2 deletions

File tree

webnn/conformance_tests/softmax.https.any.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010

1111
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-softmax
1212

13-
runWebNNConformanceTests('softmax', buildOperationWithSingleInput);
13+
runWebNNConformanceTests('softmax', buildSoftmax);

webnn/resources/test_data/softmax.json

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,93 @@
198198
],
199199
"type": "float32"
200200
}
201+
},
202+
{
203+
"name": "softmax float32 3D constant tensor",
204+
"inputs": {
205+
"x": {
206+
"shape": [1, 3, 4],
207+
"data": [
208+
0.4301910996437073,
209+
0.5471914410591125,
210+
-1.1637765169143677,
211+
0.18390046060085297,
212+
0.583903968334198,
213+
0.17356790602207184,
214+
0.5397239923477173,
215+
-0.9535139799118042,
216+
-0.5920282602310181,
217+
-0.17344485223293304,
218+
0.14395014941692352,
219+
-0.37920907139778137
220+
],
221+
"type": "float32",
222+
"constant": true
223+
}
224+
},
225+
"axis": 1,
226+
"expected": {
227+
"name": "output",
228+
"shape": [1, 3, 4],
229+
"data": [
230+
0.39589041471481323,
231+
0.45983806252479553,
232+
0.09812675416469574,
233+
0.529077410697937,
234+
0.4616699814796448,
235+
0.31647709012031555,
236+
0.5390242338180542,
237+
0.16964708268642426,
238+
0.142439603805542,
239+
0.22368484735488892,
240+
0.36284899711608887,
241+
0.3012755215167999
242+
],
243+
"type": "float32"
244+
}
245+
},
246+
{
247+
"name": "softmax float32 4D tensor",
248+
"inputs": {
249+
"x": {
250+
"shape": [3, 4, 1, 1],
251+
"data": [
252+
0.4301910996437073,
253+
0.5471914410591125,
254+
-1.1637765169143677,
255+
0.18390046060085297,
256+
0.583903968334198,
257+
0.17356790602207184,
258+
0.5397239923477173,
259+
-0.9535139799118042,
260+
-0.5920282602310181,
261+
-0.17344485223293304,
262+
0.14395014941692352,
263+
-0.37920907139778137
264+
],
265+
"type": "float32"
266+
}
267+
},
268+
"axis": 1,
269+
"expected": {
270+
"name": "output",
271+
"shape": [3, 4, 1, 1],
272+
"data": [
273+
0.3216537833213806,
274+
0.3615773916244507,
275+
0.06533370912075043,
276+
0.25143513083457947,
277+
0.35271573066711426,
278+
0.23400123417377472,
279+
0.33747196197509766,
280+
0.07581108063459396,
281+
0.17110128700733185,
282+
0.26004093885421753,
283+
0.3571779429912567,
284+
0.2116798311471939
285+
],
286+
"type": "float32"
287+
}
201288
}
202289
]
203290
}

webnn/resources/utils.js

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,20 @@ const buildSlice = (operationName, builder, resources) => {
785785
return namedOutputOperand;
786786
};
787787

788+
const buildSoftmax = (operationName, builder, resources) => {
789+
// MLOperand softmax(MLOperand input, [EnforceRange] unsigned long axis);
790+
const namedOutputOperand = {};
791+
const inputOperand = createSingleInputOperand(builder, resources);
792+
if (resources.axis !== undefined) {
793+
// invoke builder.softmax(input, axis)
794+
namedOutputOperand[resources.expected.name] = builder[operationName](inputOperand, resources.axis);
795+
} else {
796+
// invoke builder.softmax(input)
797+
namedOutputOperand[resources.expected.name] = builder[operationName](inputOperand);
798+
}
799+
return namedOutputOperand;
800+
};
801+
788802
const buildSplit = (operationName, builder, resources) => {
789803
// sequence<MLOperand> split(MLOperand input,
790804
// (unsigned long or sequence<unsigned long>) splits,

webnn/validation_tests/softmax.https.any.js

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,97 @@
44

55
'use strict';
66

7-
validateInputFromAnotherBuilder('softmax');
7+
const tests_without_axis = [
8+
{
9+
name: '[softmax] Test building Softmax with float32 input without axis.',
10+
input: { dataType: 'float32', dimensions: [4, 3] },
11+
output: { dataType: 'float32', dimensions: [4, 3] }
12+
},
13+
{
14+
name: '[softmax] Test building Softmax with float16 input without axis.',
15+
input: { dataType: 'float16', dimensions: [3, 5] },
16+
output: { dataType: 'float16', dimensions: [3, 5] }
17+
},
18+
{
19+
name: '[softmax] Throw if the input is not a non-floating point data.',
20+
input: { dataType: 'int32', dimensions: [3, 2] }
21+
},
22+
{
23+
name: '[softmax] Throw if the input dimensions is not 2.',
24+
input: { dataType: 'float32', dimensions: [1, 4, 3] }
25+
}
26+
];
27+
28+
tests_without_axis.forEach(test =>
29+
promise_test(async t => {
30+
let input = builder.input(
31+
`input`, { dataType: test.input.dataType, dimensions: test.input.dimensions }
32+
);
33+
if (test.output) {
34+
const output = builder.softmax(input);
35+
assert_equals(output.dataType(), test.output.dataType);
36+
assert_array_equals(output.shape(), test.output.dimensions);
37+
} else {
38+
assert_throws_js(TypeError, () => builder.softmax(input));
39+
}
40+
}, test.name)
41+
);
42+
43+
multi_builder_test(async (t, builder, otherBuilder) => {
44+
const operandDescriptor = { dataType: 'float32', dimensions: [2, 3] };
45+
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
46+
47+
assert_throws_js(
48+
TypeError,
49+
() => builder.softmax(inputFromOtherBuilder));
50+
}, '[softmax without axis] throw if any input is from another builder');
51+
52+
const tests = [
53+
{
54+
name: '[softmax] Test building Softmax with float32 input.',
55+
input: { dataType: 'float32', dimensions: [4, 4, 3] },
56+
axis: 1,
57+
output: { dataType: 'float32', dimensions: [4, 4, 3] }
58+
},
59+
{
60+
name: '[softmax] Test building Softmax with float16 input.',
61+
input: { dataType: 'float16', dimensions: [3, 1, 5, 2] },
62+
axis: 2,
63+
output: { dataType: 'float16', dimensions: [3, 1, 5, 2] }
64+
},
65+
{
66+
name: '[softmax] Throw if the input is not a non-floating-point data.',
67+
input: { dataType: 'int32', dimensions: [3, 1, 5, 2] },
68+
axis: 3
69+
},
70+
{
71+
name: '[softmax] Throw if the axis is greater than input rank - 1.',
72+
input: { dataType: 'float16', dimensions: [3, 1, 5, 2] },
73+
axis: 4
74+
}
75+
];
76+
77+
tests.forEach(test =>
78+
promise_test(async t => {
79+
let input = builder.input(
80+
`input`, { dataType: test.input.dataType, dimensions: test.input.dimensions }
81+
);
82+
if (test.output) {
83+
const output = builder.softmax(input, test.axis);
84+
assert_equals(output.dataType(), test.output.dataType);
85+
assert_array_equals(output.shape(), test.output.dimensions);
86+
} else {
87+
assert_throws_js(TypeError, () => builder.softmax(input, test.axis));
88+
}
89+
}, test.name)
90+
);
91+
92+
multi_builder_test(async (t, builder, otherBuilder) => {
93+
const operandDescriptor = { dataType: 'float32', dimensions: [1, 2, 3] };
94+
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
95+
const axis = 1;
96+
97+
assert_throws_js(
98+
TypeError,
99+
() => builder.softmax(inputFromOtherBuilder, axis));
100+
}, '[softmax] throw if any input is from another builder');

0 commit comments

Comments
 (0)