|
5 | 5 | 'use strict'; |
6 | 6 |
|
7 | 7 | validateTwoInputsFromMultipleBuilders('prelu'); |
| 8 | + |
| 9 | +const tests = [ |
| 10 | + { |
| 11 | + name: |
| 12 | + '[prelu] Test slope\'s shape = [3, 2, 5] which is the same as input\'s shape.', |
| 13 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 14 | + slope: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 15 | + output: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 16 | + }, |
| 17 | + { |
| 18 | + name: |
| 19 | + '[prelu] Test slope\'s shape = [5] which is unidirectionally broadcastable to input\'s shape.', |
| 20 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 21 | + slope: {dataType: 'float32', dimensions: [5]}, |
| 22 | + output: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 23 | + }, |
| 24 | + { |
| 25 | + name: |
| 26 | + '[prelu] Test slope\'s shape = [] which is unidirectionally broadcastable to input\'s shape.', |
| 27 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 28 | + slope: {dataType: 'float32', dimensions: []}, |
| 29 | + output: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 30 | + }, |
| 31 | + { |
| 32 | + name: |
| 33 | + '[prelu] Test slope\'s shape = [2, 5] which is unidirectionally broadcastable to input\'s shape.', |
| 34 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 35 | + slope: {dataType: 'float32', dimensions: [2, 5]}, |
| 36 | + output: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 37 | + }, |
| 38 | + { |
| 39 | + name: |
| 40 | + '[prelu] Test with input\'s dataType = int32 and slope\'s dataType = int32.', |
| 41 | + input: {dataType: 'int32', dimensions: [3, 2, 5]}, |
| 42 | + slope: {dataType: 'int32', dimensions: [2, 5]}, |
| 43 | + output: {dataType: 'int32', dimensions: [3, 2, 5]}, |
| 44 | + }, |
| 45 | + { |
| 46 | + name: |
| 47 | + '[prelu] Test with input\'s dataType = int8 and slope\'s dataType = int8.', |
| 48 | + input: {dataType: 'int8', dimensions: [3, 2, 5]}, |
| 49 | + slope: {dataType: 'int8', dimensions: [2, 5]}, |
| 50 | + output: {dataType: 'int8', dimensions: [3, 2, 5]}, |
| 51 | + }, |
| 52 | + { |
| 53 | + name: |
| 54 | + '[prelu] Throw if the shape of slope is not broadcastable to the shape of input.', |
| 55 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 56 | + slope: {dataType: 'float32', dimensions: [2]}, |
| 57 | + }, |
| 58 | + { |
| 59 | + name: |
| 60 | + '[prelu] Throw if the data type of slope does not match the data type of input.', |
| 61 | + input: {dataType: 'float32', dimensions: [3, 2, 5]}, |
| 62 | + slope: {dataType: 'int32', dimensions: [3, 2, 5]}, |
| 63 | + }, |
| 64 | + { |
| 65 | + name: '[prelu] Throw if the data type of input is int64.', |
| 66 | + input: {dataType: 'int64', dimensions: [3, 2, 5]}, |
| 67 | + slope: {dataType: 'int64', dimensions: [3, 2, 5]}, |
| 68 | + }, |
| 69 | + { |
| 70 | + name: '[prelu] Throw if the data type of input is uint32.', |
| 71 | + input: {dataType: 'uint32', dimensions: [3, 2, 5]}, |
| 72 | + slope: {dataType: 'uint32', dimensions: [3, 2, 5]}, |
| 73 | + }, |
| 74 | +]; |
| 75 | + |
| 76 | +tests.forEach( |
| 77 | + test => promise_test(async t => { |
| 78 | + const input = builder.input( |
| 79 | + 'input', |
| 80 | + {dataType: test.input.dataType, dimensions: test.input.dimensions}); |
| 81 | + const slope = builder.input( |
| 82 | + 'input', |
| 83 | + {dataType: test.slope.dataType, dimensions: test.slope.dimensions}); |
| 84 | + if (test.output) { |
| 85 | + const output = builder.prelu(input, slope); |
| 86 | + assert_equals(output.dataType(), test.output.dataType); |
| 87 | + assert_array_equals(output.shape(), test.output.dimensions); |
| 88 | + } else { |
| 89 | + assert_throws_js(TypeError, () => builder.prelu(input, slope)); |
| 90 | + } |
| 91 | + }, test.name)); |
0 commit comments