Skip to content

Commit d39ed99

Browse files
committed
add basesolver and model
1 parent e7cd966 commit d39ed99

3 files changed

Lines changed: 174 additions & 0 deletions

File tree

src/test/NN-RL/BaseModel.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
#include <array>
3+
namespace ReinforcementLearning {
4+
5+
typedef std::vector<double> State;
6+
7+
class BaseModel {
8+
public:
9+
virtual ~BaseModel(){}
10+
virtual void run(int actions) = 0;
11+
12+
virtual State getCurrState() {
13+
return currState;
14+
}
15+
virtual void createInitialState() = 0;
16+
virtual int getNumActions(){ return numActions;}
17+
18+
protected:
19+
State currState, prevState;
20+
int numActions;
21+
int stateDim;
22+
};
23+
24+
struct Experience{
25+
State oldState, newState;
26+
int action;
27+
double reward;
28+
Experience(State old0, State new0, int a0, double c0):
29+
oldState(old0),newState(new0), action(a0), reward(c0)
30+
{}
31+
};
32+
33+
}

src/test/NN-RL/NN_RLSolverBase.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#include "NN_RLSolverBase.h"
2+
3+
using namespace ReinforcementLearning;
4+
using namespace NeuralNet;
5+
using namespace DeepLearning;
6+
NN_RLSolverBase::NN_RLSolverBase(std::shared_ptr<BaseModel> m,
7+
std::shared_ptr<Net> net0,
8+
std::shared_ptr<Trainer> trainer0,
9+
int Dim, DeepLearning::QLearningSolverParameter para):
10+
model(m), net(net0), trainer(trainer0), stateDim(Dim),
11+
trainingPara(para){
12+
netInputDim = stateDim + 1;
13+
randChoice = std::make_shared<RandomStream>(0, model->getNumActions()-1);
14+
}
15+
void NN_RLSolverBase::train(){
16+
std::shared_ptr<arma::mat> trainingSampleX(new arma::mat);
17+
std::shared_ptr<arma::mat> trainingSampleY(new arma::mat);
18+
int maxIter = trainingPara.numtrainingepisodes();
19+
for (int iter = 0; iter < maxIter; iter++){
20+
std::cout << "RLsolver iteration: " << iter << std::endl;
21+
this->generateExperience();
22+
if (iter > 20) {
23+
this->generateTrainingSample(trainingSampleX, trainingSampleY);
24+
// trainingSampleX->print("X");
25+
// trainingSampleY->print("Y");
26+
trainer->setTrainingSamples(trainingSampleX, trainingSampleY);
27+
trainer->train();
28+
}
29+
}
30+
}
31+
32+
void NN_RLSolverBase::getMaxQ(const State& S, double* Q, int* action) {
33+
double maxQ;
34+
int a = 0;
35+
std::shared_ptr<arma::mat> inputTemp(new arma::mat(netInputDim, 1));
36+
maxQ = -std::numeric_limits<double>::max();
37+
for (int k = 0; k < this->stateDim; k++)
38+
inputTemp->at(k) = S[k] / this->state_norm[k];
39+
for (int j = 0; j < model->getNumActions(); j++) {
40+
inputTemp->at(stateDim) = j / state_norm[stateDim];
41+
net->setTrainingSamples(inputTemp, nullptr);
42+
net->forward();
43+
double tempQ = arma::as_scalar(*(net->netOutput()));
44+
if (maxQ < tempQ) {
45+
maxQ = tempQ;
46+
a = j;
47+
}
48+
}
49+
*Q = maxQ;
50+
*action = a;
51+
return;
52+
}
53+
54+
void NN_RLSolverBase::generateTrainingSample(std::shared_ptr<arma::mat> trainingX, std::shared_ptr<arma::mat> trainingY){
55+
trainingX->set_size(netInputDim, experienceSet.size());
56+
trainingY->set_size(1, experienceSet.size());
57+
double maxQ;
58+
int action;
59+
std::shared_ptr<arma::mat> inputTemp(new arma::mat(netInputDim, 1));
60+
for (int i = 0; i < this->experienceSet.size(); i++) {
61+
this->getMaxQ(experienceSet[i].newState,&maxQ,&action);
62+
double targetQ = experienceSet[i].reward + trainingPara.discount()*maxQ;;
63+
for ( int k = 0; k < this->stateDim; k++)
64+
inputTemp->at(k) = experienceSet[i].oldState[k] / this->state_norm[k];
65+
inputTemp->at(stateDim) = experienceSet[i].action / state_norm[stateDim];
66+
67+
trainingX->col(i) = *inputTemp;
68+
trainingY->at(i) = targetQ;
69+
}
70+
}
71+
72+
void NN_RLSolverBase::generateExperience(){
73+
double maxQ;
74+
int action;
75+
double epi = trainingPara.epsilon();
76+
arma::mat outputTemp(1,1);
77+
std::shared_ptr<arma::mat> inputTemp(new arma::mat(netInputDim, 1));
78+
model->createInitialState();
79+
int i;
80+
for(i = 0; i < trainingPara.episodelength(); i++){
81+
if( this->terminate(model->getCurrState()) ) {
82+
break;
83+
}
84+
State oldState = model->getCurrState();
85+
if (randChoice->nextDou()< epi){
86+
this->getMaxQ(oldState, &maxQ, &action);
87+
} else {
88+
action = randChoice->nextInt();
89+
}
90+
model->run(action);
91+
State currState = model->getCurrState();
92+
double r = this->getRewards(currState);
93+
oldState.shrink_to_fit();
94+
currState.shrink_to_fit();
95+
this->experienceSet.push_back(Experience(oldState,currState, action, r));
96+
}
97+
}
98+
99+
double NN_RLSolverBase::getRewards(const State &newS) const{
100+
return 0.0;
101+
}
102+
bool NN_RLSolverBase::terminate(const State& S) const {
103+
return false;
104+
}

src/test/NN-RL/NN_RLSolverBase.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
#include <armadillo>
3+
#include "BaseModel.h"
4+
#include "common.h"
5+
#include "Model_PoleSimple.h"
6+
#include "Net.h"
7+
#include "../Trainer/Trainer.h"
8+
9+
namespace ReinforcementLearning {
10+
11+
class NN_RLSolverBase {
12+
public:
13+
NN_RLSolverBase(std::shared_ptr<BaseModel> m,
14+
std::shared_ptr<NeuralNet::Net> net0,
15+
std::shared_ptr<NeuralNet::Trainer> trainer0, int Dim, DeepLearning::QLearningSolverParameter para);
16+
virtual ~NN_RLSolverBase(){}
17+
virtual void train();
18+
virtual void test(){}
19+
virtual void generateTrainingSample(std::shared_ptr<arma::mat> trainingSampleX, std::shared_ptr<arma::mat> trainingSampleY);
20+
virtual void generateExperience();
21+
virtual void getMaxQ(const State& S,double* Q, int* action);
22+
virtual double getRewards(const State& newS) const;
23+
virtual bool terminate(const State& S) const;
24+
virtual void setNormalizationConst(){}
25+
protected:
26+
int stateDim;
27+
int netInputDim;
28+
std::shared_ptr<BaseModel> model;
29+
std::shared_ptr<NeuralNet::Net> net;
30+
std::shared_ptr<NeuralNet::Trainer> trainer;
31+
DeepLearning::QLearningSolverParameter trainingPara;
32+
std::shared_ptr<RandomStream> randChoice;
33+
std::vector<Experience> experienceSet;
34+
State state_norm;
35+
double action_norm;
36+
};
37+
}

0 commit comments

Comments
 (0)