-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathadagrad_solver.cpp
More file actions
69 lines (58 loc) · 2.09 KB
/
adagrad_solver.cpp
File metadata and controls
69 lines (58 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <vector>
#include "caffe/sgd_solvers.hpp"
namespace caffe {
// Forward declaration of the GPU AdaGrad update routine. It is only declared
// (not defined) in the visible part of this file, and the declaration is
// compiled out when CPU_ONLY is defined.
//
// Arguments, as passed by AdaGradSolver::ComputeUpdateValue in this file:
//   N          - element count of the parameter blob being updated
//   g          - mutable GPU diff (gradient) buffer of that parameter blob
//   h          - mutable GPU data buffer of the matching history_ blob
//   delta      - value of this->param_.delta()
//   local_rate - global rate multiplied by the parameter's lr multiplier
//
// NOTE(review): presumably performs the same history/gradient update as the
// CPU branch of ComputeUpdateValue -- confirm against the definition.
#ifndef CPU_ONLY
template <typename Dtype>
void adagrad_update_gpu(int N, Dtype* g, Dtype* h, Dtype delta,
Dtype local_rate);
#endif
/**
 * @brief Replaces the gradient of learnable parameter @p param_id with its
 *        AdaGrad update value.
 *
 * With g the parameter's current diff, h the matching history_ blob and
 * delta taken from the solver parameters, the CPU path computes
 *
 *   h    <- h + g .^ 2
 *   diff <- local_rate * g ./ (sqrt(h) + delta)
 *
 * using update_[param_id] as scratch space. The GPU path hands the diff and
 * history buffers to adagrad_update_gpu().
 *
 * @param param_id index into the net's learnable parameters (and into the
 *                 per-parameter lr multipliers, history_ and update_).
 * @param rate     global learning rate; it is multiplied by the parameter's
 *                 own lr multiplier before use.
 */
template <typename Dtype>
void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  const Dtype delta = this->param_.delta();
  const Dtype local_rate = rate * net_params_lr[param_id];

  // The parameter blob and its element count are shared by every step below.
  Blob<Dtype>* const param = net_params[param_id];
  const int count = param->count();

  switch (Caffe::mode()) {
    case Caffe::CPU: {
      Blob<Dtype>& history = *this->history_[param_id];
      Blob<Dtype>& update = *this->update_[param_id];

      // update <- g .^ 2
      caffe_powx(count, param->cpu_diff(), Dtype(2),
                 update.mutable_cpu_data());

      // history <- history + update  (accumulate squared gradients)
      caffe_add(count, update.cpu_data(), history.cpu_data(),
                history.mutable_cpu_data());

      // update <- sqrt(history) + delta  (the denominator)
      caffe_powx(count, history.cpu_data(), Dtype(0.5),
                 update.mutable_cpu_data());
      caffe_add_scalar(count, delta, update.mutable_cpu_data());

      // update <- g ./ update
      caffe_div(count, param->cpu_diff(), update.cpu_data(),
                update.mutable_cpu_data());

      // diff <- local_rate * update  (beta == 0 discards the old diff)
      caffe_cpu_axpby(count, local_rate, update.cpu_data(), Dtype(0),
                      param->mutable_cpu_diff());
      break;
    }
    case Caffe::GPU: {
#ifndef CPU_ONLY
      Blob<Dtype>& history = *this->history_[param_id];
      adagrad_update_gpu(count, param->mutable_gpu_diff(),
                         history.mutable_gpu_data(), delta, local_rate);
#else
      NO_GPU;
#endif
      break;
    }
    default:
      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
// Macro invocations whose definitions are not visible in this file.
// NOTE(review): presumably INSTANTIATE_CLASS emits explicit instantiations of
// the AdaGradSolver template and REGISTER_SOLVER_CLASS makes the solver
// creatable under the name "AdaGrad" -- confirm against the macro definitions.
INSTANTIATE_CLASS(AdaGradSolver);
REGISTER_SOLVER_CLASS(AdaGrad);
} // namespace caffe