@@ -5554,28 +5554,38 @@ partial dictionary MLOpSupportLimits {
55545554 builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
55555555 const batchSize = input.shape[1] ;
55565556 const inputSize = input.shape[2] ;
5557- const numDirections = (options.direction == 'both' ? 2 : 1);
5557+ const direction = options.direction || 'forward' ;
5558+ const numDirections = (direction == 'both' ? 2 : 1);
55585559 let hiddenState = options.initialHiddenState;
55595560 let cellState = options.initialCellState;
55605561
55615562 if (!hiddenState) {
5562- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5563- const totalSize = numDirections * hiddenSize;
5563+ const desc = {
5564+ dataType: 'float32' ,
5565+ shape: [numDirections, batchSize, hiddenSize]
5566+ };
5567+ const totalSize = numDirections * batchSize * hiddenSize;
55645568 hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
55655569 }
55665570
55675571 if (!cellState) {
5568- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5569- const totalSize = numDirections * hiddenSize;
5572+ const desc = {
5573+ dataType: 'float32' ,
5574+ shape: [numDirections, batchSize, hiddenSize]
5575+ };
5576+ const totalSize = numDirections * batchSize * hiddenSize;
55705577 cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
55715578 }
55725579
5573- let sequence = null;
55745580 let currentWeight = [];
55755581 let currentRecurrentWeight = [];
55765582 let currentBias = [];
55775583 let currentRecurrentBias = [];
55785584 let currentPeepholeWeight = [];
5585+ let forwardSequence = null;
5586+ let backwardSequence = null;
5587+ let outputHidden = null;
5588+ let outputCell = null;
55795589
55805590 for (let dir = 0; dir < numDirections; ++dir) {
55815591 currentWeight.push(squeeze(
@@ -5605,36 +5615,26 @@ partial dictionary MLOpSupportLimits {
56055615 builder.slice(
56065616 options.peepholeWeight, [dir, 0] , [1, 3 * hiddenSize] ))) :
56075617 null);
5608- }
5609-
5610- for (let step = 0; step < steps; ++step) {
5611- let currentHidden = [];
5612- let currentCell = [];
5613- let nextHidden = null;
5614- let nextCell = null;
56155618
5616- for (let dir = 0; dir < numDirections; ++dir) {
5617- currentHidden.push(squeeze(
5618- builder,
5619- builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5620- currentCell.push(squeeze(
5621- builder,
5622- builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5623- }
5619+ let currentHidden = squeeze(
5620+ builder,
5621+ builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
5622+ let currentCell = squeeze(
5623+ builder,
5624+ builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
56245625
5625- for (let dir = 0; dir < numDirections; ++dir) {
5626- let slice =
5627- (dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
5628- let currentInput = squeeze(
5626+ for (let step = 0; step < steps; ++step) {
5627+ const slice = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
5628+ const currentInput = squeeze(
56295629 builder,
56305630 builder.slice(input, [slice, 0, 0] , [1, batchSize, inputSize] ));
56315631
5632- let results = builder.lstmCell(
5632+ [currentHidden, currentCell] = builder.lstmCell(
56335633 currentInput,
56345634 currentWeight[dir] ,
56355635 currentRecurrentWeight[dir] ,
5636- currentHidden[dir] ,
5637- currentCell[dir] ,
5636+ currentHidden,
5637+ currentCell,
56385638 hiddenSize,
56395639 {
56405640 bias: currentBias[dir] ,
@@ -5644,27 +5644,56 @@ partial dictionary MLOpSupportLimits {
56445644 activations: options.activations
56455645 });
56465646
5647- let output = builder.reshape(results[0] , [1, batchSize, hiddenSize] );
5648- let cell = builder.reshape(results[1] , [1, batchSize, hiddenSize] );
5649-
5650- nextHidden =
5651- (nextHidden ? builder.concat([nextHidden, output] , 0) : output);
5652- nextCell = (nextCell ? builder.concat([nextCell, cell] , 0) : cell);
5647+ if (options.returnSequence) {
5648+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5649+ // to 4D([steps, numDirections, batchSize, hiddenSize] )
5650+ const expandedHiddenAs4D = builder.reshape(
5651+ currentHidden, [1, 1, batchSize, hiddenSize] );
5652+
5653+ if (direction == 'forward' || (dir == 0 && direction == 'both' )) {
5654+ forwardSequence = forwardSequence ?
5655+ builder.concat([forwardSequence, expandedHiddenAs4D] , 0) :
5656+ expandedHiddenAs4D;
5657+ } else if (direction == 'backward' || (dir == 1 && direction == 'both' )) {
5658+ backwardSequence = backwardSequence ?
5659+ builder.concat([expandedHiddenAs4D, backwardSequence] , 0) :
5660+ expandedHiddenAs4D;
5661+ }
5662+ }
56535663 }
56545664
5655- hiddenState = nextHidden;
5656- cellState = nextCell;
5665+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5666+ // to 3D([numDirections, batchSize, hiddenSize] )
5667+ const expandedHiddenAs3D = builder.reshape(
5668+ currentHidden, [1, batchSize, hiddenSize] );
5669+ outputHidden = outputHidden ?
5670+ builder.concat([outputHidden, expandedHiddenAs3D] , 0) :
5671+ expandedHiddenAs3D;
5672+
5673+ // Expand currentCell of 2D([batchSize, hiddenSize] )
5674+ // to 3D([numDirections, batchSize, hiddenSize] )
5675+ const expandedCellAs3D = builder.reshape(
5676+ currentCell, [1, batchSize, hiddenSize] );
5677+ outputCell = outputCell ?
5678+ builder.concat([outputCell, expandedCellAs3D] , 0) : expandedCellAs3D;
5679+ }
56575680
5658- if (options.returnSequence) {
5659- nextHidden =
5660- builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize] );
5661- sequence =
5662- (sequence ? builder.concat([sequence, nextHidden] , 0) : nextHidden);
5681+ if (options.returnSequence) {
5682+ let outputSequence = null;
5683+
5684+ if (direction == 'forward' ) {
5685+ outputSequence = forwardSequence;
5686+ } else if (direction == 'backward' ) {
5687+ outputSequence = backwardSequence;
5688+ } else if (direction == 'both' ) {
5689+ // Concat along axis 1 (numDirections dimension)
5690+ outputSequence = builder.concat([forwardSequence, backwardSequence] , 1);
56635691 }
5664- }
56655692
5666- return (
5667- sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState] );
5693+ return [outputHidden, outputCell, outputSequence] ;
5694+ } else {
5695+ return [outputHidden, outputCell] ;
5696+ }
56685697 }
56695698 </pre>
56705699</details>
0 commit comments