forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAdamW.cs
More file actions
64 lines (57 loc) · 2.37 KB
/
AdamW.cs
File metadata and controls
64 lines (57 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Adam optimizer with decoupled weight decay (AdamW).
    /// Before each dense Adam update the variable is shrunk in place by
    /// <c>learning_rate * weight_decay * var</c>, unless its name contains one of the
    /// substrings in <c>no_decay_params</c> or <c>weight_decay</c> is zero.
    /// </summary>
    public class AdamW : Adam
    {
        // NOTE(review): stored but not forwarded to the base Adam constructor, so it does not
        // affect the base optimizer; confirm whether the optimizer should report this name.
        readonly string name;
        readonly float weight_decay;
        // Last (device, dtype) key handed to _prepare_local. Only used as a fallback when the
        // slot matching a variable's own device/dtype cannot be found in apply_state.
        DeviceDType deType;
        readonly List<string> no_decay_params = null;

        /// <summary>
        /// Creates an AdamW optimizer.
        /// </summary>
        /// <param name="learning_rate">Adam learning rate; also scales the weight-decay step.</param>
        /// <param name="weight_decay">Decoupled weight-decay rate; 0 disables decay entirely.</param>
        /// <param name="beta_1">Passed through to <see cref="Adam"/>.</param>
        /// <param name="beta_2">Passed through to <see cref="Adam"/>.</param>
        /// <param name="epsilon">Passed through to <see cref="Adam"/>.</param>
        /// <param name="amsgrad">Passed through to <see cref="Adam"/>.</param>
        /// <param name="no_decay_params">
        /// Optional substrings; a variable whose name contains any of them is not decayed.
        /// </param>
        /// <param name="name">Stored on this instance (see the note on the field).</param>
        public AdamW(float learning_rate= 0.001f,
            float weight_decay= 0.004f,
            float beta_1= 0.9f,
            float beta_2= 0.999f,
            float epsilon= 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name= "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
        {
            this.name = name;
            this.weight_decay = weight_decay;
            this.no_decay_params = no_decay_params;
        }

        /// <summary>
        /// Builds the op that applies decoupled weight decay to <paramref name="var"/>:
        /// <c>var += -learning_rate * var * weight_decay</c>, or a no-op when decay is
        /// disabled for this variable.
        /// </summary>
        /// <param name="var">Variable to decay.</param>
        /// <param name="learning_rate">Scale applied to the decay step.</param>
        /// <param name="apply_state">Per-(device, dtype) tensors filled in by <see cref="_prepare_local"/>.</param>
        /// <returns>The assign-add op, or <c>tf.no_op()</c> when no decay applies.</returns>
        protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            if (!_do_use_weight_decay(var.Name))
                return tf.no_op();

            var decay_rate = _weight_decay_rate_for(var, apply_state);
            return var.assign_add(-learning_rate * var.AsTensor() * decay_rate);
        }

        /// <summary>
        /// Returns the "weight_decay" tensor prepared for the slot whose Device and DType
        /// match <paramref name="var"/>.
        /// Falls back to the slot last prepared by <see cref="_prepare_local"/>, which was
        /// the original behaviour.
        /// </summary>
        // Keys are matched on their Device/DType values instead of by dictionary lookup on a
        // fresh key. Presumably this mirrors how the base Adam resolves its own
        // coefficients — confirm.
        private Tensor _weight_decay_rate_for(IVariableV1 var, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            var var_device = var.Device;
            var var_dtype = var.dtype.as_base_dtype();
            foreach (var entry in apply_state)
            {
                if (entry.Key.Device == var_device
                    && entry.Key.DType == var_dtype
                    && entry.Value.TryGetValue("weight_decay", out var rate))
                    return rate;
            }
            return apply_state[deType]["weight_decay"];
        }

        /// <summary>
        /// Whether L2 weight decay should be applied to the parameter named
        /// <paramref name="param_name"/>.
        /// </summary>
        /// <returns>
        /// False when <c>weight_decay</c> is zero or the name contains any non-null entry of
        /// <c>no_decay_params</c>; true otherwise.
        /// </returns>
        protected bool _do_use_weight_decay(string param_name)
        {
            if (this.weight_decay == 0)
                return false;
            if (this.no_decay_params == null)
                return true;

            foreach (var excluded in no_decay_params)
            {
                // A null entry would make string.Contains throw; treat it as "no pattern".
                if (excluded != null && param_name.Contains(excluded))
                    return false;
            }
            return true;
        }

        /// <summary>
        /// Applies weight decay to <paramref name="var"/>, then the regular Adam dense update.
        /// The Adam update is built under a control dependency on the decay op.
        /// </summary>
        // NOTE(review): only the dense path is overridden here, so sparse updates are not
        // decayed by this class. The decay also uses the float hyperparameter
        // _hyper["learning_rate"]; confirm this is correct if a learning-rate schedule is used.
        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);
            // The controller returned by control_dependencies must be entered (tf_with) for the
            // dependency to be recorded. Discarding it, as before, left nothing depending on
            // `decay`.
            return tf_with(tf.control_dependencies(new[] { decay }), _ =>
            {
                return base._resource_apply_dense(var, grad, apply_state);
            });
        }

        /// <summary>
        /// Prepares per-(device, dtype) state.
        /// Records <paramref name="device_dtype"/> as the fallback key, lets Adam fill in its
        /// entries, then adds the weight-decay rate under "weight_decay".
        /// </summary>
        // NOTE(review): tf.constant(float) is presumably float32; variables of other dtypes
        // may need a cast here — confirm.
        protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            this.deType = device_dtype;
            base._prepare_local(device_dtype, apply_state);
            apply_state[device_dtype]["weight_decay"] = tf.constant(
                weight_decay, name: "adam_weight_decay_rate");
        }
    }
}