forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregularizers.ts
More file actions
149 lines (128 loc) · 4.3 KB
/
regularizers.ts
File metadata and controls
149 lines (128 loc) · 4.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* original source: keras/regularizers.py */
import * as tfc from '@tensorflow/tfjs-core';
import {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensorflow/tfjs-core';
import * as K from './backend/tfjs_backend';
import {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils';
/**
 * Validates the argument handed to a regularizer constructor/factory.
 *
 * `null` and `undefined` are accepted (callers fall back to defaults), as is
 * anything whose `typeof` is `'object'`. Every other value is rejected.
 *
 * @param args The candidate configuration value.
 * @throws Error if `args` is neither nullish nor an object.
 */
function assertObjectArgs(args: L1Args | L2Args | L1L2Args): void {
  // Guard clause: nothing to complain about for nullish or object values.
  if (args == null || typeof args === 'object') {
    return;
  }
  const prefix =
      `Argument to L1L2 regularizer's constructor is expected to be an `;
  throw new Error(prefix + `object, but received: ${args}`);
}
/**
 * Abstract base class shared by all regularizers.
 *
 * It extends `serialization.Serializable`; concrete subclasses supply the
 * actual penalty computation through `apply`.
 */
export abstract class Regularizer extends serialization.Serializable {
  /**
   * Computes the regularization score for a tensor.
   *
   * @param x Tensor to score.
   * @returns The score as a `Scalar`.
   */
  abstract apply(x: Tensor): Scalar;
}
/**
 * Configuration accepted by the `L1L2` constructor. Both rates are optional.
 */
export interface L1L2Args {
  /** Rate of the L1 term. Defaults to 0.01. */
  l1?: number;

  /** Rate of the L2 term. Defaults to 0.01. */
  l2?: number;
}
/**
 * Configuration accepted by the `l1` factory function.
 */
export interface L1Args {
  /**
   * Rate of the L1 term. The type requires it, but a null/undefined value at
   * runtime falls back to 0.01.
   */
  l1: number;
}
/**
 * Configuration accepted by the `l2` factory function.
 */
export interface L2Args {
  /**
   * Rate of the L2 term. The type requires it, but a null/undefined value at
   * runtime falls back to 0.01.
   */
  l2: number;
}
/**
 * Regularizer that adds an L1 penalty, an L2 penalty, or both.
 *
 * For a tensor `x` the score is `sum(l1 * |x|) + sum(l2 * x^2)`. A term whose
 * rate is exactly 0 is skipped entirely.
 */
export class L1L2 extends Regularizer {
  /** @nocollapse */
  static className = 'L1L2';

  private readonly l1: number;
  private readonly l2: number;
  // Cached flags so `apply` can skip a term whose rate is zero.
  private readonly hasL1: boolean;
  private readonly hasL2: boolean;

  /**
   * @param args Optional rates. A missing `args`, or a null/undefined rate,
   *     results in the default rate of 0.01 for that term.
   * @throws Error if `args` is non-nullish and not an object.
   */
  constructor(args?: L1L2Args) {
    super();
    assertObjectArgs(args);

    const l1Rate = args == null ? null : args.l1;
    const l2Rate = args == null ? null : args.l2;

    this.l1 = l1Rate == null ? 0.01 : l1Rate;
    this.l2 = l2Rate == null ? 0.01 : l2Rate;
    this.hasL1 = this.l1 !== 0;
    this.hasL2 = this.l2 !== 0;
  }

  /**
   * Porting note: Renamed from __call__.
   *
   * @param x Variable of which to calculate the regularization score.
   * @returns The accumulated penalty as a scalar.
   */
  apply(x: Tensor): Scalar {
    return tidy(() => {
      // Start from a one-element zero tensor and accumulate the active terms.
      let penalty: Tensor = zeros([1]);
      if (this.hasL1) {
        const l1Term = sum(tfc.mul(this.l1, abs(x)));
        penalty = add(penalty, l1Term);
      }
      if (this.hasL2) {
        const l2Term = sum(tfc.mul(this.l2, K.square(x)));
        penalty = add(penalty, l2Term);
      }
      return penalty.asScalar();
    });
  }

  /** Returns the two rates under the keys `l1` and `l2`. */
  getConfig(): serialization.ConfigDict {
    const config: serialization.ConfigDict = {'l1': this.l1, 'l2': this.l2};
    return config;
  }

  /**
   * Builds an instance of `cls` from the `l1` and `l2` entries of `config`.
   *
   * @nocollapse
   */
  static fromConfig<T extends serialization.Serializable>(
      cls: serialization.SerializableConstructor<T>,
      config: serialization.ConfigDict): T {
    const l1 = config['l1'] as number;
    const l2 = config['l2'] as number;
    return new cls({l1, l2});
  }
}
serialization.registerClass(L1L2);
/**
 * Creates an `L1L2` regularizer with the L2 rate fixed at 0, so only the L1
 * term is active.
 *
 * @param args Optional configuration. When it is omitted, `null` is passed as
 *     the L1 rate and `L1L2` picks its own default.
 * @throws Error if `args` is non-nullish and not an object.
 */
export function l1(args?: L1Args) {
  assertObjectArgs(args);
  const rate = args == null ? null : args.l1;
  return new L1L2({l1: rate, l2: 0});
}
/**
 * Creates an `L1L2` regularizer with the L1 rate fixed at 0, so only the L2
 * term is active.
 *
 * `args` is optional, matching the sibling `l1` factory: the body already
 * handles a missing argument by passing `null` as the L2 rate, which lets
 * `L1L2` pick its own default.
 *
 * @param args Optional configuration holding the L2 rate.
 * @throws Error if `args` is non-nullish and not an object.
 */
export function l2(args?: L2Args) {
  assertObjectArgs(args);
  const rate = args == null ? null : args.l2;
  return new L1L2({l2: rate, l1: 0});
}
/** @docinline */
export type RegularizerIdentifier = 'l1l2'|string;

/**
 * Lookup table translating JavaScript-style identifier strings into the
 * matching Keras class names.
 */
export const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP:
    {[identifier in RegularizerIdentifier]: string} = {
      'l1l2': 'L1L2',
    };
/**
 * Serializes a regularizer by delegating to `serializeKerasObject`.
 *
 * @param constraint The regularizer instance to serialize (the parameter name
 *     is historical; it is a `Regularizer`, not a constraint).
 * @returns Whatever `serializeKerasObject` produces for the instance.
 */
export function serializeRegularizer(constraint: Regularizer):
    serialization.ConfigDictValue {
  const serialized = serializeKerasObject(constraint);
  return serialized;
}
/**
 * Rebuilds a regularizer from its config by delegating to
 * `deserializeKerasObject`, using the global serialization class-name map.
 *
 * @param config Config dictionary describing the regularizer.
 * @param customObjects Extra objects forwarded to `deserializeKerasObject`.
 *     Defaults to an empty dictionary.
 * @returns The regularizer produced by `deserializeKerasObject`.
 */
export function deserializeRegularizer(
    config: serialization.ConfigDict,
    customObjects: serialization.ConfigDict = {}): Regularizer {
  const registeredClasses =
      serialization.SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(
      config, registeredClasses, customObjects, 'regularizer');
}
/**
 * Resolves a regularizer from a flexible identifier.
 *
 * - `null`/`undefined`: returns `null`.
 * - string: treated as a class name (after translating known aliases via
 *   `REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP`) and deserialized with an
 *   empty config.
 * - `Regularizer` instance: returned unchanged.
 * - anything else: treated as a config dictionary and deserialized.
 *
 * @param identifier Alias/class name, config dictionary, or instance.
 * @returns The resolved regularizer, or `null` for a nullish identifier.
 */
export function getRegularizer(identifier: RegularizerIdentifier|
                               serialization.ConfigDict|
                               Regularizer): Regularizer {
  if (identifier == null) {
    return null;
  }
  if (typeof identifier === 'string') {
    // Use an own-property check rather than `in`: `in` also matches keys
    // inherited from Object.prototype (e.g. 'constructor', 'toString'), which
    // would turn `className` into a function instead of a string.
    const isAlias = Object.prototype.hasOwnProperty.call(
        REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP, identifier);
    const className = isAlias ?
        REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
        identifier;
    const config = {className, config: {}};
    return deserializeRegularizer(config);
  }
  if (identifier instanceof Regularizer) {
    return identifier;
  }
  return deserializeRegularizer(identifier);
}