Skip to content

Commit 0f05dbc

Browse files
axingingannxingyuan
authored andcommitted
Add depthwiseConv2D op for webgpu (tensorflow#1843)
FEATURE Benchamrks: CPU: 'Mean time: 1.041 ms' 'Min time: 0.760 ms' WebGL 1 'Mean time: 12.942 ms' 'Min time: 9.250 ms' WebGL 2 'Mean time: 10.862 ms' 'Min time: 5.965 ms' WebGPU 'Mean time: 4.423 ms' 'Min time: 3.950 ms' Test device: Chrome 78.0.3887/Macbook Pro/Intel Core I5/Mojave.
1 parent 4858434 commit 0f05dbc

7 files changed

Lines changed: 128 additions & 11 deletions

File tree

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import {BinaryOpProgram} from './kernels/binary_op_webgpu';
2929
import {ConcatProgram} from './kernels/concat_webgpu';
3030
import {Conv2DMMProgram} from './kernels/conv2d_mm_webgpu';
3131
import {Conv2DNaiveProgram} from './kernels/conv2d_naive_webgpu';
32+
import {DepthwiseConv2DProgram} from './kernels/depthwise_conv2d_webgpu';
3233
import {MatMulPackedProgram} from './kernels/matmul_packed_webgpu';
3334
import {MatMulProgram} from './kernels/matmul_webgpu';
3435
import {MaxPoolProgram} from './kernels/maxpool_webgpu';
@@ -558,6 +559,13 @@ export class WebGPUBackend extends KernelBackend {
558559
Tensor4D;
559560
}
560561

562+
depthwiseConv2D(
563+
x: Tensor4D, filter: Tensor4D,
564+
convInfo: backend_util.Conv2DInfo): Tensor4D {
565+
const program = new DepthwiseConv2DProgram(convInfo);
566+
return this.compileAndRun(program, [x, filter]);
567+
}
568+
561569
private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):
562570
Tensor {
563571
const program = new ArgMinMaxProgram(x.shape, axis, reduceType);

tfjs-backend-webgpu/src/benchmark_ops_test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,14 @@ describeWebGPU('Ops benchmarks', () => {
125125

126126
await time(() => tf.conv2d(a, b, 1, 'same'));
127127
});
128+
129+
it('depthwiseconv2d', async () => {
130+
const x = tf.randomNormal<tf.Rank.R4>([1, 128, 128, 1]);
131+
const w = tf.tensor4d(
132+
[0.303873, 0.229223, 0.144333, 0.803373],
133+
[2, 2, 1, 1],
134+
);
135+
136+
await time(() => tf.depthwiseConv2d(x, w, 1, 'valid'));
137+
});
128138
});

tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,6 @@ export class Conv2DMMProgram implements WebGPUProgram {
6767
this.userCode = `
6868
${matMulSource}
6969
70-
bool coordIsValid(ivec4 coord, ivec4 shape) {
71-
return all(greaterThanEqual(coord, ivec4(0))) &&
72-
all(lessThan(coord, shape));
73-
}
74-
7570
int batch;
7671
7772
float mm_readA(uint row, uint col) {

tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
4444
() => 'TODO: Dilation is unimplemented');
4545

4646
this.userCode = `
47-
bool coordIsValid(ivec4 coord, ivec4 shape) {
48-
return all(greaterThanEqual(coord, ivec4(0))) &&
49-
all(lessThan(coord, shape));
50-
}
51-
5247
float readInp(uint batch, uint row, uint col, uint chan) {
5348
ivec4 coord = ivec4(batch, row, col, chan);
5449
return coordIsValid(coord, xShape) ? getX(coord) : 0;
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, util} from '@tensorflow/tfjs-core';
19+
import {computeDispatch} from '../webgpu_util';
20+
import {WebGPUProgram} from './webgpu_program';
21+
22+
export class DepthwiseConv2DProgram implements WebGPUProgram {
23+
outputShape: number[];
24+
userCode: string;
25+
dispatchLayout: {x: number[], y: number[], z: number[]};
26+
dispatch: [number, number, number];
27+
variableNames = ['x', 'W'];
28+
uniforms = 'ivec2 filterDims, pad, stride;';
29+
workGroupSize: [number, number, number] = [4, 8, 1];
30+
31+
constructor(convInfo: backend_util.Conv2DInfo) {
32+
this.outputShape = convInfo.outShape;
33+
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
34+
this.dispatch = computeDispatch(
35+
this.dispatchLayout, this.outputShape, this.workGroupSize);
36+
const xNumRows = convInfo.inHeight;
37+
const xNumCols = convInfo.inWidth;
38+
const padTop = convInfo.padInfo.top;
39+
const padLeft = convInfo.padInfo.left;
40+
const strideHeight = convInfo.strideHeight;
41+
const strideWidth = convInfo.strideWidth;
42+
const dilationHeight = convInfo.dilationHeight;
43+
const dilationWidth = convInfo.dilationWidth;
44+
const filterHeight = convInfo.filterHeight;
45+
const filterWidth = convInfo.filterWidth;
46+
const channelMul = convInfo.outChannels / convInfo.inChannels;
47+
48+
util.assert(
49+
convInfo.dataFormat === 'channelsLast',
50+
() => 'TODO: NCHW is unimplemented');
51+
util.assert(
52+
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1,
53+
() => 'TODO: Dilation is unimplemented');
54+
55+
this.userCode = `
56+
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
57+
const ivec2 pads = ivec2(${padTop}, ${padLeft});
58+
59+
void writeResult(int batch, int row, int col, int chan, float value) {
60+
ivec4 coord = ivec4(batch, row, col, chan);
61+
if (coordIsValid(coord, outShape)) {
62+
setOutput(batch, row, col, chan, value);
63+
}
64+
}
65+
66+
void main() {
67+
ivec4 coords = getOutputCoords();
68+
int batch = coords[0];
69+
ivec2 xRCCorner = coords.yz * strides - pads;
70+
int d2 = coords[3];
71+
int d1 = d2 / ${channelMul};
72+
int q = d2 - d1 * ${channelMul};
73+
74+
int xRCorner = xRCCorner.x;
75+
int xCCorner = xRCCorner.y;
76+
77+
// Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
78+
// ? = to be determined. : = across all values in that axis.
79+
float dotProd = 0.0;
80+
// TODO(xing.xu): Flatten the two for loops and vec4 the operations.
81+
for (int wR = 0; wR < ${filterHeight}; wR++) {
82+
int xR = xRCorner + wR * ${dilationHeight};
83+
84+
if (xR < 0 || xR >= ${xNumRows}) {
85+
continue;
86+
}
87+
88+
for (int wC = 0; wC < ${filterWidth}; wC++) {
89+
int xC = xCCorner + wC * ${dilationWidth};
90+
91+
if (xC < 0 || xC >= ${xNumCols}) {
92+
continue;
93+
}
94+
95+
float xVal = getX(batch, xR, xC, d1);
96+
float wVal = getW(wR, wC, d1, q);
97+
dotProd += xVal * wVal;
98+
}
99+
}
100+
writeResult(batch, coords[1], coords[2], d2, dotProd);
101+
}
102+
`;
103+
}
104+
}

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ const env = jasmine.getEnv();
3030
const INCLUDE_LIST: string[] = [
3131
'matmul', 'add ', 'subtract ', 'mul ', 'conv2d', 'pad', 'pool', 'maxPool',
3232
'floor divide ', 'resizeBilinear', 'relu', 'transpose', 'concat', 'argmax',
33-
'fromPixels'
33+
'fromPixels', 'depthwise'
3434
];
3535
/** Tests that have these substrings in their name will be excluded. */
3636
const EXCLUDE_LIST: string[] = [

tfjs-backend-webgpu/src/shader_preprocessor.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ const SHADER_PREFIX = `#version 450
136136
}
137137
return res;
138138
}
139+
140+
bool coordIsValid(ivec4 coord, ivec4 shape) {
141+
return all(greaterThanEqual(coord, ivec4(0))) &&
142+
all(lessThan(coord, shape));
143+
}
139144
`;
140145

141146
const SAMPLING_SNIPPETS = `

0 commit comments

Comments
 (0)