Skip to content

Commit 88a35ff

Browse files
committed
Fix emulation error of lstm by 'backward' and 'both' direction options
1 parent df555e1 commit 88a35ff

1 file changed

Lines changed: 73 additions & 44 deletions

File tree

index.bs

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5554,28 +5554,38 @@ partial dictionary MLOpSupportLimits {
55545554
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
55555555
const batchSize = input.shape[1];
55565556
const inputSize = input.shape[2];
5557-
const numDirections = (options.direction == 'both' ? 2 : 1);
5557+
const direction = options.direction || 'forward';
5558+
const numDirections = (direction == 'both' ? 2 : 1);
55585559
let hiddenState = options.initialHiddenState;
55595560
let cellState = options.initialCellState;
55605561

55615562
if (!hiddenState) {
5562-
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
5563-
const totalSize = numDirections * hiddenSize;
5563+
const desc = {
5564+
dataType: 'float32',
5565+
shape: [numDirections, batchSize, hiddenSize]
5566+
};
5567+
const totalSize = numDirections * batchSize * hiddenSize;
55645568
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
55655569
}
55665570

55675571
if (!cellState) {
5568-
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
5569-
const totalSize = numDirections * hiddenSize;
5572+
const desc = {
5573+
dataType: 'float32',
5574+
shape: [numDirections, batchSize, hiddenSize]
5575+
};
5576+
const totalSize = numDirections * batchSize * hiddenSize;
55705577
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
55715578
}
55725579

5573-
let sequence = null;
55745580
let currentWeight = [];
55755581
let currentRecurrentWeight = [];
55765582
let currentBias = [];
55775583
let currentRecurrentBias = [];
55785584
let currentPeepholeWeight = [];
5585+
let forwardSequence = null;
5586+
let backwardSequence = null;
5587+
let outputHidden = null;
5588+
let outputCell = null;
55795589

55805590
for (let dir = 0; dir < numDirections; ++dir) {
55815591
currentWeight.push(squeeze(
@@ -5605,36 +5615,26 @@ partial dictionary MLOpSupportLimits {
56055615
builder.slice(
56065616
options.peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) :
56075617
null);
5608-
}
5609-
5610-
for (let step = 0; step < steps; ++step) {
5611-
let currentHidden = [];
5612-
let currentCell = [];
5613-
let nextHidden = null;
5614-
let nextCell = null;
56155618

5616-
for (let dir = 0; dir < numDirections; ++dir) {
5617-
currentHidden.push(squeeze(
5618-
builder,
5619-
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
5620-
currentCell.push(squeeze(
5621-
builder,
5622-
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
5623-
}
5619+
let currentHidden = squeeze(
5620+
builder,
5621+
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
5622+
let currentCell = squeeze(
5623+
builder,
5624+
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]));
56245625

5625-
for (let dir = 0; dir < numDirections; ++dir) {
5626-
let slice =
5627-
(dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
5628-
let currentInput = squeeze(
5626+
for (let step = 0; step < steps; ++step) {
5627+
const slice = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
5628+
const currentInput = squeeze(
56295629
builder,
56305630
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));
56315631

5632-
let results = builder.lstmCell(
5632+
[currentHidden, currentCell] = builder.lstmCell(
56335633
currentInput,
56345634
currentWeight[dir],
56355635
currentRecurrentWeight[dir],
5636-
currentHidden[dir],
5637-
currentCell[dir],
5636+
currentHidden,
5637+
currentCell,
56385638
hiddenSize,
56395639
{
56405640
bias: currentBias[dir],
@@ -5644,27 +5644,56 @@ partial dictionary MLOpSupportLimits {
56445644
activations: options.activations
56455645
});
56465646

5647-
let output = builder.reshape(results[0], [1, batchSize, hiddenSize]);
5648-
let cell = builder.reshape(results[1], [1, batchSize, hiddenSize]);
5649-
5650-
nextHidden =
5651-
(nextHidden ? builder.concat([nextHidden, output], 0) : output);
5652-
nextCell = (nextCell ? builder.concat([nextCell, cell], 0) : cell);
5647+
if (options.returnSequence) {
5648+
// Expand currentHidden of 2D([batchSize, hiddenSize])
5649+
// to 4D([steps, numDirections, batchSize, hiddenSize])
5650+
const expandedHiddenAs4D = builder.reshape(
5651+
currentHidden, [1, 1, batchSize, hiddenSize]);
5652+
5653+
if (direction == 'forward' || (dir == 0 && direction == 'both')) {
5654+
forwardSequence = forwardSequence ?
5655+
builder.concat([forwardSequence, expandedHiddenAs4D], 0) :
5656+
expandedHiddenAs4D;
5657+
} else if (direction == 'backward' || (dir == 1 && direction == 'both')) {
5658+
backwardSequence = backwardSequence ?
5659+
builder.concat([expandedHiddenAs4D, backwardSequence], 0) :
5660+
expandedHiddenAs4D;
5661+
}
5662+
}
56535663
}
56545664

5655-
hiddenState = nextHidden;
5656-
cellState = nextCell;
5665+
// Expand currentHidden of 2D([batchSize, hiddenSize])
5666+
// to 3D([numDirections, batchSize, hiddenSize])
5667+
const expandedHiddenAs3D = builder.reshape(
5668+
currentHidden, [1, batchSize, hiddenSize]);
5669+
outputHidden = outputHidden ?
5670+
builder.concat([outputHidden, expandedHiddenAs3D], 0) :
5671+
expandedHiddenAs3D;
5672+
5673+
// Expand currentCell of 2D([batchSize, hiddenSize])
5674+
// to 3D([numDirections, batchSize, hiddenSize])
5675+
const expandedCellAs3D = builder.reshape(
5676+
currentCell, [1, batchSize, hiddenSize]);
5677+
outputCell = outputCell ?
5678+
builder.concat([outputCell, expandedCellAs3D], 0) : expandedCellAs3D;
5679+
}
56575680

5658-
if (options.returnSequence) {
5659-
nextHidden =
5660-
builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]);
5661-
sequence =
5662-
(sequence ? builder.concat([sequence, nextHidden], 0) : nextHidden);
5681+
if (options.returnSequence) {
5682+
let outputSequence = null;
5683+
5684+
if (direction == 'forward') {
5685+
outputSequence = forwardSequence;
5686+
} else if (direction == 'backward') {
5687+
outputSequence = backwardSequence;
5688+
} else if (direction == 'both') {
5689+
// Concat along axis 1 (numDirections dimension)
5690+
outputSequence = builder.concat([forwardSequence, backwardSequence], 1);
56635691
}
5664-
}
56655692

5666-
return (
5667-
sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
5693+
return [outputHidden, outputCell, outputSequence];
5694+
} else {
5695+
return [outputHidden, outputCell];
5696+
}
56685697
}
56695698
</pre>
56705699
</details>

0 commit comments

Comments
 (0)