Skip to content

Commit d21efb0

Browse files
committed
add modularized gradient for squared difference
1 parent 1fca6c3 commit d21efb0

6 files changed

Lines changed: 69 additions & 4 deletions

File tree

tfjs-core/src/backends/cpu/all_kernels.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import './non_max_suppression_v5';
2222
import './square';
2323

24-
// TODO this will be a dist import from types to avoid a circular dependency
24+
// TODO this will be a dist import from types
2525
import {KernelConfig} from '../../kernel_registry';
2626

2727
// Import Kernel Configs here.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* @license
3+
* Copyright 2010 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {SquaredDifference} from '../kernel_names';
19+
import {GradConfig} from '../kernel_registry';
20+
import {mul, sub} from '../ops/binary_ops';
21+
import {scalar} from '../ops/tensor_ops';
22+
import {Tensor} from '../tensor';
23+
24+
export const squaredDifferenceGrad_: GradConfig = {
25+
kernelName: SquaredDifference,
26+
gradFunc: (dy: Tensor, saved: Tensor[]) => {
27+
const [$a, $b] = saved;
28+
const two = scalar(2);
29+
const derA = () => mul(dy, mul(two, sub($a, $b)));
30+
const derB = () => mul(dy, mul(two, sub($b, $a)));
31+
return {$a: derA, $b: derB};
32+
}
33+
};

tfjs-core/src/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ import './backends/cpu/backend_cpu';
3737
import './backends/cpu/register_all_kernels';
3838
// Import all kernels from webgl.
3939
import './backends/webgl/register_all_kernels';
40+
41+
// Register all the gradients
42+
import './register_all_gradients';
43+
4044
import './platforms/platform_browser';
4145
import './platforms/platform_node';
4246

tfjs-core/src/ops/all_gradients.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,12 @@
2020
// the contents of this file and import only the gradients that are needed.
2121

2222
import './square_grad';
23+
24+
// Import Grad Configs here.
25+
import {squaredDifferenceGrad_} from '../gradients/SquaredDifference_grad';
26+
import {GradConfig} from '../kernel_registry';
27+
28+
// Export all kernel configs here so that the package can auto register them
29+
export const gradConfigs: GradConfig[] = [
30+
squaredDifferenceGrad_,
31+
];

tfjs-core/src/ops/ops.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
* =============================================================================
1616
*/
1717

18-
// Importing this file registers gradients in the global registry.
19-
import './all_gradients';
20-
2118
// Modularized ops.
2219
export {square} from './square';
2320
export {squaredDifference} from './squared_difference';
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {registerGradient} from './kernel_registry';
18+
import {gradConfigs} from './ops/all_gradients';
19+
20+
for (const gradientConfig of gradConfigs) {
21+
registerGradient(gradientConfig);
22+
}

0 commit comments

Comments
 (0)