// adam_solver.cpp
#include <vector>
#include "caffe/sgd_solvers.hpp"
namespace caffe {
// Appends one additional history blob per learnable parameter, each created
// with the same shape as its parameter. These entries are pushed after the
// ones that SGDSolver::PreSolve is expected to have added already;
// ComputeUpdateValue reads them back at index
// param_id + learnable_params().size().
template <typename Dtype>
void AdamSolver<Dtype>::AdamPreSolve() {
  const vector<Blob<Dtype>*>& params = this->net_->learnable_params();
  const size_t num_params = params.size();
  for (size_t idx = 0; idx < num_params; ++idx) {
    // One extra buffer shaped exactly like the parameter blob.
    shared_ptr<Blob<Dtype> > extra(new Blob<Dtype>(params[idx]->shape()));
    this->history_.push_back(extra);
  }
}
#ifndef CPU_ONLY
// Declaration only; compiled out when CPU_ONLY is defined. The definition is
// not in this file (presumably in the CUDA sources -- confirm).
// ComputeUpdateValue's GPU branch calls it with:
//   N                    - element count of the parameter blob
//   g                    - the parameter's mutable gpu diff
//   m, v                 - mutable gpu data of the two history blobs
//   beta1, beta2         - solver momentum() and momentum2()
//   eps_hat              - solver delta()
//   corrected_local_rate - local_rate * correction
template <typename Dtype>
void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate);
#endif
// Computes the Adam update for one learnable parameter and writes it into
// that parameter's diff, replacing the gradient that was stored there.
//
//   param_id - index into net_->learnable_params()
//   rate     - global learning rate; it is multiplied by the parameter's
//              entry in net_->params_lr() to get the local rate
//
// With g the incoming diff, m = history_[param_id] and
// v = history_[param_id + P] (P = number of learnable parameters):
//   m    <- beta1 * m + (1 - beta1) * g
//   v    <- beta2 * v + (1 - beta2) * g^2
//   diff <- local_rate * correction * m / (sqrt(v) + eps_hat)
// where correction = sqrt(1 - beta2^t) / (1 - beta1^t) and t = iter_ + 1.
// temp_[param_id] is used as scratch space on the CPU path.
template <typename Dtype>
void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& learnable = this->net_->learnable_params();
  Blob<Dtype>* const param = learnable[param_id];
  const int num_elems = param->count();

  // Solver hyper-parameters.
  const Dtype beta1 = this->param_.momentum();
  const Dtype beta2 = this->param_.momentum2();
  const Dtype eps_hat = this->param_.delta();

  // The second-moment blobs sit after the first-moment ones in history_.
  const size_t second_moment_base = learnable.size();
  Blob<Dtype>* const first_moment = this->history_[param_id].get();
  Blob<Dtype>* const second_moment =
      this->history_[param_id + second_moment_base].get();
  Blob<Dtype>* const scratch = this->temp_[param_id].get();

  // Bias correction folded into the step size.
  const int step = this->iter_ + 1;
  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, step)) /
      (Dtype(1.) - pow(beta1, step));
  const vector<float>& lr_mults = this->net_->params_lr();
  const Dtype local_rate = rate * lr_mults[param_id];
  const Dtype step_size = local_rate * correction;

  switch (Caffe::mode()) {
    case Caffe::CPU: {
      // m <- beta1 * m_{t-1} + (1 - beta1) * g_t
      caffe_cpu_axpby(num_elems, Dtype(1) - beta1, param->cpu_diff(), beta1,
          first_moment->mutable_cpu_data());

      // v <- beta2 * v_{t-1} + (1 - beta2) * g_t^2
      // (g_t^2 is staged in the scratch blob first)
      caffe_mul(num_elems, param->cpu_diff(), param->cpu_diff(),
          scratch->mutable_cpu_data());
      caffe_cpu_axpby(num_elems, Dtype(1) - beta2, scratch->cpu_data(), beta2,
          second_moment->mutable_cpu_data());

      // scratch <- m / (v^0.5 + eps_hat); the scratch blob is overwritten
      // in place at each stage.
      caffe_powx(num_elems, second_moment->cpu_data(), Dtype(0.5),
          scratch->mutable_cpu_data());
      caffe_add_scalar(num_elems, eps_hat, scratch->mutable_cpu_data());
      caffe_div(num_elems, first_moment->cpu_data(), scratch->cpu_data(),
          scratch->mutable_cpu_data());

      // diff <- step_size * scratch
      caffe_cpu_scale(num_elems, step_size, scratch->cpu_data(),
          param->mutable_cpu_diff());
      break;
    }
    case Caffe::GPU: {
#ifndef CPU_ONLY
      // The whole update is delegated to adam_update_gpu.
      adam_update_gpu(num_elems, param->mutable_gpu_diff(),
          first_moment->mutable_gpu_data(), second_moment->mutable_gpu_data(),
          beta1, beta2, eps_hat, step_size);
#else
      NO_GPU;
#endif
      break;
    }
    default:
      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
// Macro invocations; both macros are defined outside this file. Presumably
// INSTANTIATE_CLASS emits the explicit template instantiations of AdamSolver
// and REGISTER_SOLVER_CLASS registers the solver under the type name "Adam"
// -- confirm against the macro definitions.
INSTANTIATE_CLASS(AdamSolver);
REGISTER_SOLVER_CLASS(Adam);
} // namespace caffe