1919import { BackendTimingInfo , DataMover , DataType , fill , KernelBackend , ones , Rank , rsqrt , Scalar , scalar , ShapeMap , Tensor , Tensor1D , tensor1d , Tensor2D , tensor2d , Tensor3D , tensor3d , Tensor4D , tidy , util } from '@tensorflow/tfjs-core' ;
2020import { EPSILON_FLOAT32 } from '@tensorflow/tfjs-core/dist/backends/backend' ;
2121import { Conv2DInfo , Conv3DInfo } from '@tensorflow/tfjs-core/dist/ops/conv_util' ;
22- import { Activation } from '@tensorflow/tfjs-core/dist/ops/fused_util' ;
22+ // tslint:disable-next-line:max-line-length
23+ import { Activation , FusedBatchMatMulConfig } from '@tensorflow/tfjs-core/dist/ops/fused_util' ;
2324import { Tensor5D } from '@tensorflow/tfjs-core/dist/tensor' ;
2425import { BackendValues , upcastType } from '@tensorflow/tfjs-core/dist/types' ;
2526import { isNullOrUndefined } from 'util' ;
@@ -369,9 +370,8 @@ export class NodeJSKernelBackend extends KernelBackend {
369370 }
370371
371372 fusedBatchMatMul (
372- a : Tensor3D , b : Tensor3D , transposeA : boolean , transposeB : boolean ,
373- bias ?: Tensor , activation ?: Activation ,
374- preluActivationWeights ?: Tensor ) : Tensor3D {
373+ { a, b, transposeA, transposeB, bias, activation, preluActivationWeights} :
374+ FusedBatchMatMulConfig ) : Tensor3D {
375375 // Core TensorFlow does not have a fused BatchMatMul op. Combine calls to
376376 // achieve the same results:
377377 let result = this . batchMatMul ( a , b , transposeA , transposeB ) ;
@@ -431,6 +431,10 @@ export class NodeJSKernelBackend extends KernelBackend {
431431 return this . executeSingleInput ( 'Neg' , a ) as T ;
432432 }
433433
434+ diag ( x : Tensor ) : Tensor {
435+ return this . executeSingleInput ( 'Diag' , x ) ;
436+ }
437+
434438 add ( a : Tensor , b : Tensor ) : Tensor {
435439 const opAttrs = [ createTypeOpAttr ( 'T' , upcastType ( a . dtype , b . dtype ) ) ] ;
436440 return this . executeSingleOutput ( 'Add' , opAttrs , [ a , b ] ) ;
@@ -1162,13 +1166,15 @@ export class NodeJSKernelBackend extends KernelBackend {
11621166 `TF Backend supports only 'valid' and 'same' padding ` +
11631167 `while padding was ${ convInfo . padInfo . type } ` ) ;
11641168 }
1165- const ksize = [ 1 , convInfo . filterDepth , convInfo . filterHeight ,
1166- convInfo . filterWidth , 1 ] ;
1167- const strides = [ 1 , convInfo . strideDepth , convInfo . strideHeight ,
1168- convInfo . strideWidth , 1 ] ;
1169+ const ksize = [
1170+ 1 , convInfo . filterDepth , convInfo . filterHeight , convInfo . filterWidth , 1
1171+ ] ;
1172+ const strides = [
1173+ 1 , convInfo . strideDepth , convInfo . strideHeight , convInfo . strideWidth , 1
1174+ ] ;
11691175 const padding = convInfo . padInfo . type ;
1170- const dataFormat = convInfo . dataFormat === 'channelsLast' ?
1171- 'NDHWC' : 'NCDHW' ;
1176+ const dataFormat =
1177+ convInfo . dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW' ;
11721178 const opAttrs = [
11731179 createTypeOpAttr ( 'T' , x . dtype ) ,
11741180 { name : 'ksize' , type : this . binding . TF_ATTR_INT , value : ksize } ,
@@ -1183,20 +1189,21 @@ export class NodeJSKernelBackend extends KernelBackend {
11831189 return this . executeSingleOutput ( 'AvgPool3D' , opAttrs , [ x ] ) as Tensor5D ;
11841190 }
11851191
1186- avgPool3dBackprop (
1187- dy : Tensor5D , x : Tensor5D , convInfo : Conv3DInfo ) : Tensor5D {
1192+ avgPool3dBackprop ( dy : Tensor5D , x : Tensor5D , convInfo : Conv3DInfo ) : Tensor5D {
11881193 if ( convInfo . padInfo . type !== 'VALID' && convInfo . padInfo . type !== 'SAME' ) {
11891194 throw new Error (
11901195 `TF Backend supports only 'valid' and 'same' padding ` +
11911196 `while padding type was ${ convInfo . padInfo . type } ` ) ;
11921197 }
1193- const ksize = [ 1 , convInfo . filterDepth , convInfo . filterHeight ,
1194- convInfo . filterWidth , 1 ] ;
1195- const strides = [ 1 , convInfo . strideDepth , convInfo . strideHeight ,
1196- convInfo . strideWidth , 1 ] ;
1198+ const ksize = [
1199+ 1 , convInfo . filterDepth , convInfo . filterHeight , convInfo . filterWidth , 1
1200+ ] ;
1201+ const strides = [
1202+ 1 , convInfo . strideDepth , convInfo . strideHeight , convInfo . strideWidth , 1
1203+ ] ;
11971204 const padding = convInfo . padInfo . type ;
1198- const dataFormat = convInfo . dataFormat === 'channelsLast' ?
1199- 'NDHWC' : 'NCDHW' ;
1205+ const dataFormat =
1206+ convInfo . dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW' ;
12001207 const opAttrs = [
12011208 createTypeOpAttr ( 'T' , x . dtype ) ,
12021209 { name : 'ksize' , type : this . binding . TF_ATTR_INT , value : ksize } ,
@@ -1210,7 +1217,7 @@ export class NodeJSKernelBackend extends KernelBackend {
12101217 ] ;
12111218 const origInputShape = tensor1d ( x . shape , 'int32' ) ;
12121219 return this . executeSingleOutput (
1213- 'AvgPool3DGrad' , opAttrs , [ origInputShape , dy ] ) as Tensor5D ;
1220+ 'AvgPool3DGrad' , opAttrs , [ origInputShape , dy ] ) as Tensor5D ;
12141221 }
12151222
12161223 maxPool3d ( x : Tensor5D , convInfo : Conv3DInfo ) : Tensor5D {
@@ -1219,13 +1226,15 @@ export class NodeJSKernelBackend extends KernelBackend {
12191226 `TF Backend supports only 'valid' and 'same' padding ` +
12201227 `while padding was ${ convInfo . padInfo . type } ` ) ;
12211228 }
1222- const ksize = [ 1 , convInfo . filterDepth , convInfo . filterHeight ,
1223- convInfo . filterWidth , 1 ] ;
1224- const strides = [ 1 , convInfo . strideDepth , convInfo . strideHeight ,
1225- convInfo . strideWidth , 1 ] ;
1229+ const ksize = [
1230+ 1 , convInfo . filterDepth , convInfo . filterHeight , convInfo . filterWidth , 1
1231+ ] ;
1232+ const strides = [
1233+ 1 , convInfo . strideDepth , convInfo . strideHeight , convInfo . strideWidth , 1
1234+ ] ;
12261235 const padding = convInfo . padInfo . type ;
1227- const dataFormat = convInfo . dataFormat === 'channelsLast' ?
1228- 'NDHWC' : 'NCDHW' ;
1236+ const dataFormat =
1237+ convInfo . dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW' ;
12291238 const opAttrs = [
12301239 createTypeOpAttr ( 'T' , x . dtype ) ,
12311240 { name : 'ksize' , type : this . binding . TF_ATTR_INT , value : ksize } ,
@@ -1246,13 +1255,15 @@ export class NodeJSKernelBackend extends KernelBackend {
12461255 `TF Backend supports only 'valid' and 'same' padding ` +
12471256 `while padding type was ${ convInfo . padInfo . type } ` ) ;
12481257 }
1249- const ksize = [ 1 , convInfo . filterDepth , convInfo . filterHeight ,
1250- convInfo . filterWidth , 1 ] ;
1251- const strides = [ 1 , convInfo . strideDepth , convInfo . strideHeight ,
1252- convInfo . strideWidth , 1 ] ;
1258+ const ksize = [
1259+ 1 , convInfo . filterDepth , convInfo . filterHeight , convInfo . filterWidth , 1
1260+ ] ;
1261+ const strides = [
1262+ 1 , convInfo . strideDepth , convInfo . strideHeight , convInfo . strideWidth , 1
1263+ ] ;
12531264 const padding = convInfo . padInfo . type ;
1254- const dataFormat = convInfo . dataFormat === 'channelsLast' ?
1255- 'NDHWC' : 'NCDHW' ;
1265+ const dataFormat =
1266+ convInfo . dataFormat === 'channelsLast' ? 'NDHWC' : 'NCDHW' ;
12561267 const opAttrs = [
12571268 createTypeOpAttr ( 'T' , x . dtype ) ,
12581269 { name : 'ksize' , type : this . binding . TF_ATTR_INT , value : ksize } ,
0 commit comments