Skip to content

Commit 94d41fe

Browse files
authored
tfjs-node: Update fusedMatMul interface to take config, exclude… (tensorflow#1890)
FEATURE
1 parent b42d122 commit 94d41fe

5 files changed

Lines changed: 194 additions & 177 deletions

File tree

tfjs-core/src/ops/diag_test.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,4 @@ describeWithFlags('diag', ALL_ENVS, () => {
101101
expect(diag.dtype).toBe('bool');
102102
expectArraysEqual(await diag.data(), [1, 0, 0, 1]);
103103
});
104-
it('complex', () => {
105-
const real = tf.tensor1d([2.25]);
106-
const imag = tf.tensor1d([4.75]);
107-
const m = tf.complex(real, imag);
108-
expect(() => tf.diag(m)).toThrowError();
109-
});
110104
});

tfjs-node/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"yalc": "~1.0.0-pre.21"
5050
},
5151
"dependencies": {
52-
"@tensorflow/tfjs": "1.2.7",
52+
"@tensorflow/tfjs": "1.2.8",
5353
"adm-zip": "^0.4.11",
5454
"https-proxy-agent": "^2.2.1",
5555
"node-pre-gyp": "0.13.0",

tfjs-node/src/nodejs_kernel_backend.ts

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import {BackendTimingInfo, DataMover, DataType, fill, KernelBackend, ones, Rank, rsqrt, Scalar, scalar, ShapeMap, Tensor, Tensor1D, tensor1d, Tensor2D, tensor2d, Tensor3D, tensor3d, Tensor4D, tidy, util} from '@tensorflow/tfjs-core';
2020
import {EPSILON_FLOAT32} from '@tensorflow/tfjs-core/dist/backends/backend';
2121
import {Conv2DInfo, Conv3DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
22-
import {Activation} from '@tensorflow/tfjs-core/dist/ops/fused_util';
22+
// tslint:disable-next-line:max-line-length
23+
import {Activation, FusedBatchMatMulConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util';
2324
import {Tensor5D} from '@tensorflow/tfjs-core/dist/tensor';
2425
import {BackendValues, upcastType} from '@tensorflow/tfjs-core/dist/types';
2526
import {isNullOrUndefined} from 'util';
@@ -369,9 +370,8 @@ export class NodeJSKernelBackend extends KernelBackend {
369370
}
370371

371372
fusedBatchMatMul(
372-
a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
373-
bias?: Tensor, activation?: Activation,
374-
preluActivationWeights?: Tensor): Tensor3D {
373+
{a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
374+
FusedBatchMatMulConfig): Tensor3D {
375375
// Core TensorFlow does not have a fused BatchMatMul op. Combine calls to
376376
// achieve the same results:
377377
let result = this.batchMatMul(a, b, transposeA, transposeB);
@@ -431,6 +431,10 @@ export class NodeJSKernelBackend extends KernelBackend {
431431
return this.executeSingleInput('Neg', a) as T;
432432
}
433433

434+
diag(x: Tensor): Tensor {
435+
return this.executeSingleInput('Diag', x);
436+
}
437+
434438
add(a: Tensor, b: Tensor): Tensor {
435439
const opAttrs = [createTypeOpAttr('T', upcastType(a.dtype, b.dtype))];
436440
return this.executeSingleOutput('Add', opAttrs, [a, b]);
@@ -1162,13 +1166,15 @@ export class NodeJSKernelBackend extends KernelBackend {
11621166
`TF Backend supports only 'valid' and 'same' padding ` +
11631167
`while padding was ${convInfo.padInfo.type}`);
11641168
}
1165-
const ksize = [1, convInfo.filterDepth, convInfo.filterHeight,
1166-
convInfo.filterWidth, 1];
1167-
const strides = [1, convInfo.strideDepth, convInfo.strideHeight,
1168-
convInfo.strideWidth, 1];
1169+
const ksize = [
1170+
1, convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth, 1
1171+
];
1172+
const strides = [
1173+
1, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, 1
1174+
];
11691175
const padding = convInfo.padInfo.type;
1170-
const dataFormat = convInfo.dataFormat === 'channelsLast' ?
1171-
'NDHWC' : 'NCDHW';
1176+
const dataFormat =
1177+
convInfo.dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW';
11721178
const opAttrs = [
11731179
createTypeOpAttr('T', x.dtype),
11741180
{name: 'ksize', type: this.binding.TF_ATTR_INT, value: ksize},
@@ -1183,20 +1189,21 @@ export class NodeJSKernelBackend extends KernelBackend {
11831189
return this.executeSingleOutput('AvgPool3D', opAttrs, [x]) as Tensor5D;
11841190
}
11851191

1186-
avgPool3dBackprop(
1187-
dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
1192+
avgPool3dBackprop(dy: Tensor5D, x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
11881193
if (convInfo.padInfo.type !== 'VALID' && convInfo.padInfo.type !== 'SAME') {
11891194
throw new Error(
11901195
`TF Backend supports only 'valid' and 'same' padding ` +
11911196
`while padding type was ${convInfo.padInfo.type}`);
11921197
}
1193-
const ksize = [1, convInfo.filterDepth, convInfo.filterHeight,
1194-
convInfo.filterWidth, 1];
1195-
const strides = [1, convInfo.strideDepth, convInfo.strideHeight,
1196-
convInfo.strideWidth, 1];
1198+
const ksize = [
1199+
1, convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth, 1
1200+
];
1201+
const strides = [
1202+
1, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, 1
1203+
];
11971204
const padding = convInfo.padInfo.type;
1198-
const dataFormat = convInfo.dataFormat === 'channelsLast' ?
1199-
'NDHWC' : 'NCDHW';
1205+
const dataFormat =
1206+
convInfo.dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW';
12001207
const opAttrs = [
12011208
createTypeOpAttr('T', x.dtype),
12021209
{name: 'ksize', type: this.binding.TF_ATTR_INT, value: ksize},
@@ -1210,7 +1217,7 @@ export class NodeJSKernelBackend extends KernelBackend {
12101217
];
12111218
const origInputShape = tensor1d(x.shape, 'int32');
12121219
return this.executeSingleOutput(
1213-
'AvgPool3DGrad', opAttrs, [origInputShape, dy]) as Tensor5D;
1220+
'AvgPool3DGrad', opAttrs, [origInputShape, dy]) as Tensor5D;
12141221
}
12151222

12161223
maxPool3d(x: Tensor5D, convInfo: Conv3DInfo): Tensor5D {
@@ -1219,13 +1226,15 @@ export class NodeJSKernelBackend extends KernelBackend {
12191226
`TF Backend supports only 'valid' and 'same' padding ` +
12201227
`while padding was ${convInfo.padInfo.type}`);
12211228
}
1222-
const ksize = [1, convInfo.filterDepth, convInfo.filterHeight,
1223-
convInfo.filterWidth, 1];
1224-
const strides = [1, convInfo.strideDepth, convInfo.strideHeight,
1225-
convInfo.strideWidth, 1];
1229+
const ksize = [
1230+
1, convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth, 1
1231+
];
1232+
const strides = [
1233+
1, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, 1
1234+
];
12261235
const padding = convInfo.padInfo.type;
1227-
const dataFormat = convInfo.dataFormat === 'channelsLast' ?
1228-
'NDHWC' : 'NCDHW';
1236+
const dataFormat =
1237+
convInfo.dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW';
12291238
const opAttrs = [
12301239
createTypeOpAttr('T', x.dtype),
12311240
{name: 'ksize', type: this.binding.TF_ATTR_INT, value: ksize},
@@ -1246,13 +1255,15 @@ export class NodeJSKernelBackend extends KernelBackend {
12461255
`TF Backend supports only 'valid' and 'same' padding ` +
12471256
`while padding type was ${convInfo.padInfo.type}`);
12481257
}
1249-
const ksize = [1, convInfo.filterDepth, convInfo.filterHeight,
1250-
convInfo.filterWidth, 1];
1251-
const strides = [1, convInfo.strideDepth, convInfo.strideHeight,
1252-
convInfo.strideWidth, 1];
1258+
const ksize = [
1259+
1, convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth, 1
1260+
];
1261+
const strides = [
1262+
1, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, 1
1263+
];
12531264
const padding = convInfo.padInfo.type;
1254-
const dataFormat = convInfo.dataFormat === 'channelsLast' ?
1255-
'NDHWC' : 'NCDHW';
1265+
const dataFormat =
1266+
convInfo.dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW';
12561267
const opAttrs = [
12571268
createTypeOpAttr('T', x.dtype),
12581269
{name: 'ksize', type: this.binding.TF_ATTR_INT, value: ksize},

tfjs-node/src/run_tests.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,20 @@ const IGNORE_LIST: string[] = [
5555
// https://github.com/tensorflow/tfjs/issues/1077
5656
'maxPool test-tensorflow {} x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor',
5757
'avgPool test-tensorflow {} x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor',
58+
// tslint:disable-next-line:max-line-length
59+
'avgPool3d test-tensorflow {} x=[1,2,2,2,1] f=[2,2,2] s=1 p=1 roundingMode=floor',
60+
// tslint:disable-next-line:max-line-length
61+
'maxPool3d test-tensorflow {} x=[1,2,2,2,1] f=[2,2,2] s=1 p=1 roundingMode=floor',
5862
// libtensorflow doesn't support 6D ArgMax yet.
59-
'Reduction: argmax test-tensorflow {} 6D, axis=0'
63+
'Reduction: argmax test-tensorflow {} 6D, axis=0',
64+
'diag test-tensorflow {} complex', 'diag test-tensorflow {} bool',
65+
// See https://github.com/tensorflow/tfjs/issues/1891
66+
'conv2d test-tensorflow {} x=[2,1,2,2] f=[1,1,1,1] s=1 d=1 p=0 NCHW',
67+
'conv2d test-tensorflow {} x=[1,2,2] f=[2,2,1,1] s=1 d=1 p=same NCHW',
68+
'conv2d test-tensorflow {} x=[2,2,2] f=[2,2,2,1] s=1 d=1 p=same NCHW',
69+
'conv2d test-tensorflow {} x=[2,1,2,2] f=[2,2,1,1] s=1 d=1 p=same NCHW',
70+
'conv2d test-tensorflow {} gradient x=[1,1,3,3] f=[2,2,1,1] s=1 p=0 NCHW',
71+
'conv2d test-tensorflow {} gradient x=[2,1,3,3] f=[2,2,1,1] s=1 p=0 NCHW'
6072
];
6173

6274
if (process.platform === 'win32') {

0 commit comments

Comments
 (0)