Skip to content

Commit fa00e25

Browse files
karikeradsmilkov
authored andcommitted
fix padding_test.ts & 'for .. in' to 'for .. of' (tensorflow#2388)
DEV Fix for .. in for Array.
1 parent 0f31b6b commit fa00e25

3 files changed

Lines changed: 12 additions & 11 deletions

File tree

tfjs-core/src/engine.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -935,8 +935,8 @@ export class Engine implements TensorTracker, DataMover {
935935
// This means that we are not computing higher-order gradients
936936
// and can clean up the tape.
937937
this.state.activeTape.forEach(node => {
938-
for (const key in node.saved) {
939-
node.saved[key].dispose();
938+
for (const tensor of node.saved) {
939+
tensor.dispose();
940940
}
941941
});
942942
this.state.activeTape = null;

tfjs-layers/src/engine/training_dataset.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ function standardizeDataIteratorOutput(
238238
`provides ${flattenedYs.length} outputs. (Expected output keys: ` +
239239
`${JSON.stringify(model.outputNames)})`);
240240

241-
for (const xIndex in flattenedXs) {
241+
for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
242242
tfc.util.assert(
243243
flattenedXs[xIndex].shape[0] === batchSize,
244244
() => `Batch size mismatch: input ` +
@@ -247,7 +247,7 @@ function standardizeDataIteratorOutput(
247247
`expected ${batchSize} based on input ${model.inputNames[0]}.`);
248248
}
249249

250-
for (const yIndex in flattenedYs) {
250+
for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
251251
tfc.util.assert(
252252
flattenedYs[yIndex].shape[0] === batchSize,
253253
() => `Batch size mismatch: output ` +

tfjs-layers/src/layers/padding_test.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ describeMathCPU('ZeroPadding2D: Symbolic', () => {
7878
const dataFormats: DataFormat[] =
7979
[undefined, 'channelsFirst', 'channelsLast'];
8080

81-
for (const dataFormat in dataFormats) {
81+
for (const dataFormat of dataFormats) {
8282
it('Default padding 1-1-1-1, dataFormat=' + dataFormat, () => {
8383
const x = new SymbolicTensor('float32', [1, 2, 3, 4], null, [], null);
84-
const layer = tfl.layers.zeroPadding2d();
84+
const layer = tfl.layers.zeroPadding2d({dataFormat});
8585
const y = layer.apply(x) as SymbolicTensor;
8686
expect(y.dtype).toEqual('float32');
8787
if (dataFormat === 'channelsFirst') {
@@ -93,19 +93,19 @@ describeMathCPU('ZeroPadding2D: Symbolic', () => {
9393

9494
it('All symmetric padding 2, dataFormat=' + dataFormat, () => {
9595
const x = new SymbolicTensor('float32', [1, 2, 3, 4], null, [], null);
96-
const layer = tfl.layers.zeroPadding2d({padding: 2});
96+
const layer = tfl.layers.zeroPadding2d({dataFormat, padding: 2});
9797
const y = layer.apply(x) as SymbolicTensor;
9898
expect(y.dtype).toEqual('float32');
9999
if (dataFormat === 'channelsFirst') {
100-
expect(y.shape).toEqual([1, 6, 7, 8]);
100+
expect(y.shape).toEqual([1, 2, 7, 8]);
101101
} else {
102102
expect(y.shape).toEqual([1, 6, 7, 4]);
103103
}
104104
});
105105

106106
it('Symmetric padding 2-3, dataFormat=' + dataFormat, () => {
107107
const x = new SymbolicTensor('float32', [1, 2, 3, 4], null, [], null);
108-
const layer = tfl.layers.zeroPadding2d({padding: [2, 3]});
108+
const layer = tfl.layers.zeroPadding2d({dataFormat, padding: [2, 3]});
109109
const y = layer.apply(x) as SymbolicTensor;
110110
expect(y.dtype).toEqual('float32');
111111
if (dataFormat === 'channelsFirst') {
@@ -117,11 +117,12 @@ describeMathCPU('ZeroPadding2D: Symbolic', () => {
117117

118118
it('Asymmetric padding 2-3-4-5, dataFormat=' + dataFormat, () => {
119119
const x = new SymbolicTensor('float32', [1, 2, 3, 4], null, [], null);
120-
const layer = tfl.layers.zeroPadding2d({padding: [[2, 3], [4, 5]]});
120+
const layer =
121+
tfl.layers.zeroPadding2d({dataFormat, padding: [[2, 3], [4, 5]]});
121122
const y = layer.apply(x) as SymbolicTensor;
122123
expect(y.dtype).toEqual('float32');
123124
if (dataFormat === 'channelsFirst') {
124-
expect(y.shape).toEqual([1, 2, 7, 13]);
125+
expect(y.shape).toEqual([1, 2, 8, 13]);
125126
} else {
126127
expect(y.shape).toEqual([1, 7, 12, 4]);
127128
}

0 commit comments

Comments
 (0)