1+ #include " NN_RLSolverBase.h"
2+
3+ using namespace ReinforcementLearning ;
4+ using namespace NeuralNet ;
5+ using namespace DeepLearning ;
6+ NN_RLSolverBase::NN_RLSolverBase (std::shared_ptr<BaseModel> m,
7+ std::shared_ptr<Net> net0,
8+ std::shared_ptr<Trainer> trainer0,
9+ int Dim, DeepLearning::QLearningSolverParameter para):
10+ model(m), net(net0), trainer(trainer0), stateDim(Dim),
11+ trainingPara(para){
12+ netInputDim = stateDim + 1 ;
13+ randChoice = std::make_shared<RandomStream>(0 , model->getNumActions ()-1 );
14+ }
15+ void NN_RLSolverBase::train (){
16+ std::shared_ptr<arma::mat> trainingSampleX (new arma::mat);
17+ std::shared_ptr<arma::mat> trainingSampleY (new arma::mat);
18+ int maxIter = trainingPara.numtrainingepisodes ();
19+ for (int iter = 0 ; iter < maxIter; iter++){
20+ std::cout << " RLsolver iteration: " << iter << std::endl;
21+ this ->generateExperience ();
22+ if (iter > 20 ) {
23+ this ->generateTrainingSample (trainingSampleX, trainingSampleY);
24+ // trainingSampleX->print("X");
25+ // trainingSampleY->print("Y");
26+ trainer->setTrainingSamples (trainingSampleX, trainingSampleY);
27+ trainer->train ();
28+ }
29+ }
30+ }
31+
32+ void NN_RLSolverBase::getMaxQ (const State& S, double * Q, int * action) {
33+ double maxQ;
34+ int a = 0 ;
35+ std::shared_ptr<arma::mat> inputTemp (new arma::mat (netInputDim, 1 ));
36+ maxQ = -std::numeric_limits<double >::max ();
37+ for (int k = 0 ; k < this ->stateDim ; k++)
38+ inputTemp->at (k) = S[k] / this ->state_norm [k];
39+ for (int j = 0 ; j < model->getNumActions (); j++) {
40+ inputTemp->at (stateDim) = j / state_norm[stateDim];
41+ net->setTrainingSamples (inputTemp, nullptr );
42+ net->forward ();
43+ double tempQ = arma::as_scalar (*(net->netOutput ()));
44+ if (maxQ < tempQ) {
45+ maxQ = tempQ;
46+ a = j;
47+ }
48+ }
49+ *Q = maxQ;
50+ *action = a;
51+ return ;
52+ }
53+
54+ void NN_RLSolverBase::generateTrainingSample (std::shared_ptr<arma::mat> trainingX, std::shared_ptr<arma::mat> trainingY){
55+ trainingX->set_size (netInputDim, experienceSet.size ());
56+ trainingY->set_size (1 , experienceSet.size ());
57+ double maxQ;
58+ int action;
59+ std::shared_ptr<arma::mat> inputTemp (new arma::mat (netInputDim, 1 ));
60+ for (int i = 0 ; i < this ->experienceSet .size (); i++) {
61+ this ->getMaxQ (experienceSet[i].newState ,&maxQ,&action);
62+ double targetQ = experienceSet[i].reward + trainingPara.discount ()*maxQ;;
63+ for ( int k = 0 ; k < this ->stateDim ; k++)
64+ inputTemp->at (k) = experienceSet[i].oldState [k] / this ->state_norm [k];
65+ inputTemp->at (stateDim) = experienceSet[i].action / state_norm[stateDim];
66+
67+ trainingX->col (i) = *inputTemp;
68+ trainingY->at (i) = targetQ;
69+ }
70+ }
71+
72+ void NN_RLSolverBase::generateExperience (){
73+ double maxQ;
74+ int action;
75+ double epi = trainingPara.epsilon ();
76+ arma::mat outputTemp (1 ,1 );
77+ std::shared_ptr<arma::mat> inputTemp (new arma::mat (netInputDim, 1 ));
78+ model->createInitialState ();
79+ int i;
80+ for (i = 0 ; i < trainingPara.episodelength (); i++){
81+ if ( this ->terminate (model->getCurrState ()) ) {
82+ break ;
83+ }
84+ State oldState = model->getCurrState ();
85+ if (randChoice->nextDou ()< epi){
86+ this ->getMaxQ (oldState, &maxQ, &action);
87+ } else {
88+ action = randChoice->nextInt ();
89+ }
90+ model->run (action);
91+ State currState = model->getCurrState ();
92+ double r = this ->getRewards (currState);
93+ oldState.shrink_to_fit ();
94+ currState.shrink_to_fit ();
95+ this ->experienceSet .push_back (Experience (oldState,currState, action, r));
96+ }
97+ }
98+
99+ double NN_RLSolverBase::getRewards (const State &newS) const {
100+ return 0.0 ;
101+ }
102+ bool NN_RLSolverBase::terminate (const State& S) const {
103+ return false ;
104+ }
0 commit comments