Skip to content

Commit e9d5d8f

Browse files
authored
fixing the tensor leak by add tidy block for sync executors and dispose intermediate tensors fro async executors (tensorflow#2756)
BUG
1 parent fe39e34 commit e9d5d8f

6 files changed

Lines changed: 70 additions & 29 deletions

File tree

tfjs-converter/src/operations/executors/control_executor.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ import {scalar} from '@tensorflow/tfjs-core';
2121
import {NamedTensorsMap} from '../../data/types';
2222
import {ExecutionContext} from '../../executor/execution_context';
2323
import {TensorArray} from '../../executor/tensor_array';
24-
import {Node} from '../types';
24+
import {InternalOpAsyncExecutor, Node} from '../types';
2525

2626
import {getParamValue, getTensor} from './utils';
2727

28-
export async function executeOp(
28+
export const executeOp: InternalOpAsyncExecutor = async(
2929
node: Node, tensorMap: NamedTensorsMap,
30-
context: ExecutionContext): Promise<tfc.Tensor[]> {
30+
context: ExecutionContext): Promise<tfc.Tensor[]> => {
3131
switch (node.op) {
3232
case 'LoopCond':
3333
return [
@@ -161,6 +161,6 @@ export async function executeOp(
161161
default:
162162
throw TypeError(`Node type ${node.op} is not implemented`);
163163
}
164-
}
164+
};
165165

166166
export const CATEGORY = 'control';

tfjs-converter/src/operations/executors/dynamic_executor.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ import * as tfc from '@tensorflow/tfjs-core';
1919

2020
import {NamedTensorsMap} from '../../data/types';
2121
import {ExecutionContext} from '../../executor/execution_context';
22-
import {Node} from '../types';
22+
import {InternalOpAsyncExecutor, Node} from '../types';
23+
2324
import {getParamValue} from './utils';
2425

25-
export async function executeOp(
26+
export const executeOp: InternalOpAsyncExecutor = async(
2627
node: Node, tensorMap: NamedTensorsMap,
27-
context: ExecutionContext): Promise<tfc.Tensor[]> {
28+
context: ExecutionContext): Promise<tfc.Tensor[]> => {
2829
switch (node.op) {
2930
case 'NonMaxSuppressionV5':
3031
case 'NonMaxSuppressionV3':
@@ -56,9 +57,12 @@ export async function executeOp(
5657
iouThreshold, scoreThreshold)];
5758
}
5859
case 'Where': {
59-
return [await tfc.whereAsync(
60+
const condition =
6061
(getParamValue('condition', node, tensorMap, context) as tfc.Tensor)
61-
.asType('bool'))];
62+
.asType('bool');
63+
const result = [await tfc.whereAsync(condition)];
64+
condition.dispose();
65+
return result;
6266
}
6367
case 'ListDiff': {
6468
return tfc.setdiff1dAsync(
@@ -68,6 +72,6 @@ export async function executeOp(
6872
default:
6973
throw TypeError(`Node type ${node.op} is not implemented`);
7074
}
71-
}
75+
};
7276

7377
export const CATEGORY = 'dynamic';

tfjs-converter/src/operations/executors/dynamic_executor_test.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,17 @@ describe('dynamic', () => {
171171

172172
expect(validateParam(node, dynamic.json)).toBeTruthy();
173173
});
174+
it('should not have memory leak', async () => {
175+
node.op = 'Where';
176+
node.inputParams = {'condition': createTensorAttr(0)};
177+
const input1 = [tfc.scalar(1)];
178+
spyOn(tfc, 'whereAsync').and.callThrough();
179+
180+
const prevCount = tfc.memory().numTensors;
181+
await executeOp(node, {input1}, context);
182+
const afterCount = tfc.memory().numTensors;
183+
expect(afterCount).toEqual(prevCount + 1);
184+
});
174185
});
175186

176187
describe('ListDiff', () => {

tfjs-converter/src/operations/operation_executor.ts

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,37 +52,45 @@ export function executeOp(
5252
((node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) => {
5353
switch (node.category) {
5454
case 'arithmetic':
55-
return arithmetic.executeOp(node, tensorMap, context);
55+
return tfc.tidy(
56+
() => arithmetic.executeOp(node, tensorMap, context));
5657
case 'basic_math':
57-
return basicMath.executeOp(node, tensorMap, context);
58+
return tfc.tidy(
59+
() => basicMath.executeOp(node, tensorMap, context));
5860
case 'control':
5961
return control.executeOp(node, tensorMap, context);
6062
case 'convolution':
61-
return convolution.executeOp(node, tensorMap, context);
63+
return tfc.tidy(
64+
() => convolution.executeOp(node, tensorMap, context));
6265
case 'creation':
63-
return creation.executeOp(node, tensorMap, context);
66+
return tfc.tidy(() => creation.executeOp(node, tensorMap, context));
6467
case 'dynamic':
6568
return dynamic.executeOp(node, tensorMap, context);
6669
case 'evaluation':
67-
return evaluation.executeOp(node, tensorMap, context);
70+
return tfc.tidy(
71+
() => evaluation.executeOp(node, tensorMap, context));
6872
case 'image':
69-
return image.executeOp(node, tensorMap, context);
73+
return tfc.tidy(() => image.executeOp(node, tensorMap, context));
7074
case 'graph':
71-
return graph.executeOp(node, tensorMap, context);
75+
return tfc.tidy(() => graph.executeOp(node, tensorMap, context));
7276
case 'logical':
73-
return logical.executeOp(node, tensorMap, context);
77+
return tfc.tidy(() => logical.executeOp(node, tensorMap, context));
7478
case 'matrices':
75-
return matrices.executeOp(node, tensorMap, context);
79+
return tfc.tidy(() => matrices.executeOp(node, tensorMap, context));
7680
case 'normalization':
77-
return normalization.executeOp(node, tensorMap, context);
81+
return tfc.tidy(
82+
() => normalization.executeOp(node, tensorMap, context));
7883
case 'reduction':
79-
return reduction.executeOp(node, tensorMap, context);
84+
return tfc.tidy(
85+
() => reduction.executeOp(node, tensorMap, context));
8086
case 'slice_join':
81-
return sliceJoin.executeOp(node, tensorMap, context);
87+
return tfc.tidy(
88+
() => sliceJoin.executeOp(node, tensorMap, context));
8289
case 'spectral':
83-
return spectral.executeOp(node, tensorMap, context);
90+
return tfc.tidy(() => spectral.executeOp(node, tensorMap, context));
8491
case 'transformation':
85-
return transformation.executeOp(node, tensorMap, context);
92+
return tfc.tidy(
93+
() => transformation.executeOp(node, tensorMap, context));
8694
case 'custom':
8795
const opMapper = getRegisteredOp(node.op);
8896
if (opMapper && opMapper.customExecutor) {

tfjs-converter/src/operations/operation_executor_test.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
* =============================================================================
1616
*/
1717

18+
import * as tfc from '@tensorflow/tfjs-core';
1819
import {add, mul, scalar, Tensor, test_util} from '@tensorflow/tfjs-core';
1920

2021
import {ExecutionContext} from '../executor/execution_context';
2122

2223
import {deregisterOp, registerOp} from './custom_op/register';
2324
import * as arithmetic from './executors/arithmetic_executor';
2425
import * as basic_math from './executors/basic_math_executor';
26+
import * as control from './executors/control_executor';
2527
import * as convolution from './executors/convolution_executor';
2628
import * as creation from './executors/creation_executor';
2729
import * as dynamic from './executors/dynamic_executor';
@@ -56,9 +58,9 @@ describe('OperationExecutor', () => {
5658
});
5759

5860
describe('executeOp', () => {
59-
[arithmetic, basic_math, convolution, creation, dynamic, evaluation, image,
60-
graph, logical, matrices, normalization, reduction, slice_join, spectral,
61-
transformation]
61+
[arithmetic, basic_math, convolution, control, creation, dynamic,
62+
evaluation, image, graph, logical, matrices, normalization, reduction,
63+
slice_join, spectral, transformation]
6264
.forEach(category => {
6365
it('should call ' + category.CATEGORY + ' executor', () => {
6466
spyOn(category, 'executeOp');
@@ -67,6 +69,17 @@ describe('OperationExecutor', () => {
6769
expect(category.executeOp).toHaveBeenCalledWith(node, {}, context);
6870
});
6971
});
72+
[arithmetic, basic_math, convolution, creation, evaluation, image, graph,
73+
logical, matrices, normalization, reduction, slice_join, spectral,
74+
transformation]
75+
.forEach(category => {
76+
it('should call tidy around executor', () => {
77+
spyOn(tfc, 'tidy');
78+
node.category = category.CATEGORY;
79+
executeOp(node, {}, context);
80+
expect(tfc.tidy).toHaveBeenCalled();
81+
});
82+
});
7083
});
7184

7285
describe('custom op executeOp', () => {

tfjs-converter/src/operations/types.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,17 @@ export declare interface AttrParamMapper extends ParamMapper {
7575

7676
export interface InternalOpExecutor {
7777
(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Tensor
78-
|Tensor[]|Promise<Tensor|Tensor[]>;
78+
|Tensor[];
79+
}
80+
81+
export interface InternalOpAsyncExecutor {
82+
(node: Node, tensorMap: NamedTensorsMap,
83+
context: ExecutionContext): Promise<Tensor[]>;
7984
}
8085

8186
export declare interface OpMapper {
8287
tfOpName: string;
83-
category: Category;
88+
category?: Category;
8489
inputs?: InputParamMapper[];
8590
attrs?: AttrParamMapper[];
8691
customExecutor?: OpExecutor;

0 commit comments

Comments
 (0)