forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregularizers_test.ts
More file actions
100 lines (92 loc) · 3.6 KB
/
regularizers_test.ts
File metadata and controls
100 lines (92 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* Unit tests for regularizers */
import {scalar, serialization, Tensor, tensor1d} from '@tensorflow/tfjs-core';
import * as tfl from './index';
import {deserializeRegularizer, getRegularizer, serializeRegularizer} from './regularizers';
import {describeMathCPU, expectTensorsClose} from './utils/test_utils';
/**
 * Checks the built-in regularizer factories (l1, l2, l1l2) against
 * hand-computed penalties on a fixed 4-element input. Also checks that
 * passing a bare number instead of a config object is rejected.
 */
describeMathCPU('Built-in Regularizers', () => {
  // Each test builds its own input tensor; these values give
  // sum(|x|) = 1 + 2 + 3 + 4 and sum(x^2) = 1 + 4 + 9 + 16.
  const makeInput = () => tensor1d([1, -2, 3, -4]);
  const absSum = 1 + 2 + 3 + 4;
  const squareSum = 1 + 4 + 9 + 16;

  it('l1_l2', () => {
    // The expected value implies default coefficients of 0.01 for both terms.
    const penalty = tfl.regularizers.l1l2().apply(makeInput());
    expectTensorsClose(penalty, scalar(0.01 * absSum + 0.01 * squareSum));
  });

  it('l1', () => {
    const penalty = tfl.regularizers.l1().apply(makeInput());
    expectTensorsClose(penalty, scalar(0.01 * absSum));
  });

  it('l2', () => {
    const penalty = tfl.regularizers.l2().apply(makeInput());
    expectTensorsClose(penalty, scalar(0.01 * squareSum));
  });

  it('l1_l2 non default', () => {
    const penalty =
        tfl.regularizers.l1l2({l1: 1, l2: 2}).apply(makeInput());
    expectTensorsClose(penalty, scalar(1 * absSum + 2 * squareSum));
  });

  it('Using number arg for constructor leads to error', () => {
    const expectedError = /expected.*object.*received.*0\.001/;
    // Deliberately bypass the type checker to simulate a JS caller passing a
    // number where a config object is required.
    // tslint:disable-next-line:no-any
    const numericArg = 0.001 as any;
    expect(() => tfl.regularizers.l1(numericArg)).toThrowError(expectedError);
    expect(() => tfl.regularizers.l2(numericArg)).toThrowError(expectedError);
    expect(() => tfl.regularizers.l1l2(numericArg))
        .toThrowError(expectedError);
  });
});
/**
 * Tests for `getRegularizer`, covering each kind of identifier it accepts in
 * this file: a string name (in either casing), an existing regularizer
 * instance, and a serialized config dict.
 */
describeMathCPU('regularizers.get', () => {
  let x: Tensor;

  // A fresh input tensor for each test.
  beforeEach(() => {
    x = tensor1d([1, -2, 3, -4]);
  });

  it('by string - lower camel', () => {
    const regularizer = getRegularizer('l1l2');
    expectTensorsClose(regularizer.apply(x), tfl.regularizers.l1l2().apply(x));
  });

  it('by string - upper camel', () => {
    const regularizer = getRegularizer('L1L2');
    expectTensorsClose(regularizer.apply(x), tfl.regularizers.l1l2().apply(x));
  });

  it('by existing object', () => {
    const origReg = tfl.regularizers.l1l2({l1: 1, l2: 2});
    const regularizer = getRegularizer(origReg);
    // An existing instance is expected to pass through unchanged. `toBe`
    // asserts identity; `toEqual` would also accept a structurally equal copy.
    expect(regularizer).toBe(origReg);
  });

  it('by config dict', () => {
    const origReg = tfl.regularizers.l1l2({l1: 1, l2: 2});
    // Round-trip through the serialized form. The rebuilt regularizer must
    // produce the same penalty as the original.
    const config = serializeRegularizer(origReg) as serialization.ConfigDict;
    const regularizer = getRegularizer(config);
    expectTensorsClose(regularizer.apply(x), origReg.apply(x));
  });
});
/**
 * Checks that serialize -> deserialize -> serialize on a built-in regularizer
 * keeps its class name and its l1/l2 coefficients.
 */
describeMathCPU('Regularizer Serialization', () => {
  it('Built-ins', () => {
    const original = tfl.regularizers.l1l2({l1: 1, l2: 2});

    const firstPass =
        serializeRegularizer(original) as serialization.ConfigDict;
    const restored = deserializeRegularizer(firstPass);
    const secondPass =
        serializeRegularizer(restored) as serialization.ConfigDict;
    // The coefficients are nested under the `config` key of the outer dict.
    const inner = secondPass.config as serialization.ConfigDict;

    expect(secondPass.className).toEqual('L1L2');
    expect(inner.l1).toEqual(1);
    expect(inner.l2).toEqual(2);
  });
});