Skip to content

Commit ecc55a3

Browse files
Validate the hidden size of GRU and LSTM operators (#644)
* Validate the maximum hidden size of GRU and LSTM operators For gru()/gruCell() and lstm()/lstmCell(), a hiddenSize parameter is passed, a multiple of which defines a dimension of the output. - This is sometimes implicitly validated by the presence of an option with the same dimension - but not always. - Some underlying platforms operate on a single bias tensor, rather than the two bias/recurrentBias options present in the WebNN API. So the combined size needs to also be a valid dimension. Introduce validation for all cases, validate the combined size, and add an explanation inline since this is subtle. Fixes #625 * fix merge residue
1 parent 1c9948d commit ecc55a3

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

index.bs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,6 +3060,11 @@ partial interface MLGraphBuilder {
30603060
1. If |input|'s [=MLOperand/dataType=] is not {{MLOperandDataType/"float32"}} or {{MLOperandDataType/"float16"}}, then [=exception/throw=] a {{TypeError}}.
30613061
1. If the [=MLOperand/rank=] of any of |input|, |weight| or |recurrentWeight| is not 3, then [=exception/throw=] a {{TypeError}}.
30623062
1. If the [=MLOperand/dataType=] of either |weight| or |recurrentWeight| is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
3063+
1. If |hiddenSize| * 6 is not a [=valid dimension=], then [=exception/throw=] a {{TypeError}}.
3064+
<details class=note>
3065+
<summary>Why |hiddenSize| * 6 ?</summary>
3066+
Some underlying platforms operate on a single bias tensor which is a concatenation of {{MLGruOptions/bias}} and {{MLGruOptions/recurrentBias}}. Therefore, 3 * |hiddenSize| + 3 * |hiddenSize| must also be a [=valid dimension=].
3067+
</details>
30633068
1. If |options|.{{MLGruOptions/bias}} [=map/exists=]:
30643069
1. If its [=MLOperand/dataType=] is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
30653070
1. If its [=MLOperand/shape=][1] is not equal to 3 * |hiddenSize|, then [=exception/throw=] a {{TypeError}}.
@@ -3240,6 +3245,11 @@ partial interface MLGraphBuilder {
32403245
1. If |input|'s [=MLOperand/dataType=] is not {{MLOperandDataType/"float32"}} or {{MLOperandDataType/"float16"}}, then [=exception/throw=] a {{TypeError}}.
32413246
1. If the [=MLOperand/rank=] of any of |input|, |weight|, |recurrentWeight| or |hiddenState| is not 2, then [=exception/throw=] a {{TypeError}}.
32423247
1. If the [=MLOperand/dataType=] of any of |weight|, |recurrentWeight|, or |hiddenState| is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
3248+
1. If |hiddenSize| * 6 is not a [=valid dimension=], then [=exception/throw=] a {{TypeError}}.
3249+
<details class=note>
3250+
<summary>Why |hiddenSize| * 6 ?</summary>
3251+
Some underlying platforms operate on a single bias tensor which is a concatenation of {{MLGruCellOptions/bias}} and {{MLGruCellOptions/recurrentBias}}. Therefore, 3 * |hiddenSize| + 3 * |hiddenSize| must also be a [=valid dimension=].
3252+
</details>
32433253
1. If |weight|'s [=MLOperand/shape=][0] is not equal to 3 * |hiddenSize|, then [=exception/throw=] a {{TypeError}}.
32443254
1. If |recurrentWeight|'s [=MLOperand/shape=][0] is not equal to 3 * |hiddenSize|, then [=exception/throw=] a {{TypeError}}.
32453255
1. If |options|.{{MLGruOptions/bias}} [=map/exists=]:
@@ -3978,6 +3988,11 @@ partial interface MLGraphBuilder {
39783988
1. If the [=MLOperand/rank=] of any of |input|, |weight| or |recurrentWeight| is not 3, then [=exception/throw=] a {{TypeError}}.
39793989
1. If |input|'s [=MLOperand/shape=][0] is not equal to |steps|, then [=exception/throw=] a {{TypeError}}.
39803990
1. If the [=MLOperand/dataType=] of either |weight| or |recurrentWeight| is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
3991+
1. If |hiddenSize| * 8 is not a [=valid dimension=], then [=exception/throw=] a {{TypeError}}.
3992+
<details class=note>
3993+
<summary>Why |hiddenSize| * 8 ?</summary>
3994+
Some underlying platforms operate on a single bias tensor which is a concatenation of {{MLLstmOptions/bias}} and {{MLLstmOptions/recurrentBias}}. Therefore, 4 * |hiddenSize| + 4 * |hiddenSize| must also be a [=valid dimension=].
3995+
</details>
39813996
1. Let |batchSize| be |input|'s [=MLOperand/shape=][1].
39823997
1. If |options|.{{MLLstmOptions/bias}} [=map/exists=]:
39833998
1. If its [=MLOperand/dataType=] is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
@@ -4195,6 +4210,11 @@ partial interface MLGraphBuilder {
41954210
1. If |input|'s [=MLOperand/dataType=] is not {{MLOperandDataType/"float32"}} or {{MLOperandDataType/"float16"}}, then [=exception/throw=] a {{TypeError}}.
41964211
1. If the [=MLOperand/rank=] of any of |input|, |weight|, |recurrentWeight|, |hiddenState| or |cellState| is not 2, then [=exception/throw=] a {{TypeError}}.
41974212
1. If the [=MLOperand/dataType=] of any of |weight|, |recurrentWeight|, |hiddenState| or |cellState| is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.
4213+
1. If |hiddenSize| * 8 is not a [=valid dimension=], then [=exception/throw=] a {{TypeError}}.
4214+
<details class=note>
4215+
<summary>Why |hiddenSize| * 8 ?</summary>
4216+
Some underlying platforms operate on a single bias tensor which is a concatenation of {{MLLstmCellOptions/bias}} and {{MLLstmCellOptions/recurrentBias}}. Therefore, 4 * |hiddenSize| + 4 * |hiddenSize| must also be a [=valid dimension=].
4217+
</details>
41984218
1. Let |batchSize| be |input|'s [=MLOperand/shape=][0].
41994219
1. If |options|.{{MLLstmCellOptions/bias}} [=map/exists=]:
42004220
1. If its [=MLOperand/dataType=] is not equal to |input|'s [=MLOperand/dataType=], then [=exception/throw=] a {{TypeError}}.

0 commit comments

Comments
 (0)