Skip to content

Commit 2fe6984

Browse files
authored
Merge pull request rlcode#68 from jcwleo/master
Add DQN with PER
2 parents a497d71 + 560e8dd commit 2fe6984

2 files changed

Lines changed: 279 additions & 0 deletions

File tree

2-cartpole/1-dqn/SumTree.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy
2+
3+
4+
class SumTree:
    """Binary sum tree holding priorities for proportional sampling.

    The tree lives in a flat array of ``2 * capacity - 1`` nodes.  The first
    ``capacity - 1`` entries are internal nodes, each storing the sum of its
    two children; the last ``capacity`` entries are leaves, each storing one
    priority.  Leaf ``capacity - 1 + i`` belongs to ``data[i]``.  After
    ``capacity`` items have been added, new items overwrite the oldest slot.

    Priorities are assumed to be non-negative.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        # Unfilled slots keep the placeholder value 0.
        self.data = numpy.zeros(capacity, dtype=object)
        # Per-instance cursor: the data slot the next add() will fill.
        self.write = 0

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx`` up to the root."""
        # Iterative walk.  A leaf that is also the root (capacity == 1) has
        # no ancestors, so nothing is touched in that case.
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Return the leaf under ``idx`` whose cumulative range contains ``s``.

        A subtree whose total is zero is never entered while its sibling is
        non-empty.  This keeps the search on populated leaves when ``s`` is
        exactly 0 or overshoots the total because of float rounding.
        """
        node = idx
        size = len(self.tree)
        while True:
            left = 2 * node + 1
            if left >= size:
                # No children: ``node`` is a leaf.
                return node
            right = left + 1
            left_sum = self.tree[left]
            right_sum = self.tree[right]
            if right_sum == 0 or (left_sum != 0 and s <= left_sum):
                node = left
            else:
                s -= left_sum
                node = right

    def total(self):
        """Return the sum of all stored priorities (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the slot under the cursor."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        # Advance the cursor and wrap around once the buffer is full.
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` to ``p`` and fix the sums.

        ``idx`` is a tree index such as the one returned by ``get``.
        """
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Look up the leaf covering cumulative priority ``s``.

        Returns:
            A tuple ``(tree_index, priority, data)`` for the selected leaf.
        """
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[data_idx])
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import sys
2+
import gym
3+
import pylab
4+
import random
5+
import numpy as np
6+
from SumTree import SumTree
7+
from collections import deque
8+
from keras.layers import Dense
9+
from keras.optimizers import Adam
10+
from keras.models import Sequential
11+
12+
EPISODES = 300
13+
14+
15+
# 카트폴 예제에서의 DQN 에이전트
16+
class DQNAgent:
    """DQN agent whose replay memory is sampled by TD-error priority."""

    def __init__(self, state_size, action_size):
        # Runtime switches: draw the environment / start from saved weights.
        self.render = False
        self.load_model = False

        # Sizes of the state vector and of the action set.
        self.state_size = state_size
        self.action_size = action_size

        # DQN hyperparameters.
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.0
        self.epsilon_decay = 0.999
        self.epsilon_min = 0.01
        self.batch_size = 64
        self.train_start = 2000
        self.memory_size = 2000

        # Prioritized replay memory holding at most memory_size transitions.
        self.memory = Memory(self.memory_size)

        # Online model and target model.
        self.model = self.build_model()
        self.target_model = self.build_model()

        # Start the target model from the online model's weights.
        self.update_target_model()

        if self.load_model:
            self.model.load_weights("./save_model/cartpole_dqn_trained.h5")

    def build_model(self):
        """Build the network: state in, one Q-value per action out."""
        net = Sequential()
        net.add(Dense(24, input_dim=self.state_size, activation='relu',
                      kernel_initializer='he_uniform'))
        net.add(Dense(24, activation='relu',
                      kernel_initializer='he_uniform'))
        net.add(Dense(self.action_size, activation='linear',
                      kernel_initializer='he_uniform'))
        net.summary()
        optimizer = Adam(lr=self.learning_rate)
        net.compile(loss='mse', optimizer=optimizer)
        return net

    def update_target_model(self):
        """Copy the online model's weights into the target model."""
        online_weights = self.model.get_weights()
        self.target_model.set_weights(online_weights)

    def get_action(self, state):
        """Choose an action with the epsilon-greedy policy."""
        explore = np.random.rand() <= self.epsilon
        if explore:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def append_sample(self, state, action, reward, next_state, done):
        """Store the transition <s, a, r, s', done> with its TD-error."""
        # NOTE(review): while epsilon is still exactly 1 every transition is
        # treated, and stored, as terminal. Confirm this is intended.
        done = True if self.epsilon == 1 else done

        # Compute the TD-error so it can be stored alongside the sample.
        q_values = self.model.predict([state])
        previous_q = q_values[0][action]
        next_q_values = self.target_model.predict([next_state])
        if done:
            q_values[0][action] = reward
        else:
            q_values[0][action] = reward + self.discount_factor * (
                np.amax(next_q_values[0]))
        td_error = abs(previous_q - q_values[0][action])

        transition = (state, action, reward, next_state, done)
        self.memory.add(td_error, transition)

    def train_model(self):
        """Fit the model on one batch drawn from memory and refresh priorities."""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Draw batch_size (tree index, transition) pairs from the memory.
        batch = self.memory.sample(self.batch_size)

        size = self.batch_size
        td_errors = np.zeros(size)
        states = np.zeros((size, self.state_size))
        next_states = np.zeros((size, self.state_size))
        actions, rewards, dones = [], [], []

        for row, (_, transition) in enumerate(batch):
            state, action, reward, next_state, done = transition
            states[row] = state
            actions.append(action)
            rewards.append(reward)
            next_states[row] = next_state
            dones.append(done)

        # Q-values of the online model for the current states and of the
        # target model for the next states.
        q_values = self.model.predict(states)
        next_q_values = self.target_model.predict(next_states)

        # Update targets from the Bellman optimality equation.
        for row, (action, reward, done) in enumerate(
                zip(actions, rewards, dones)):
            previous_q = q_values[row][action]
            if done:
                q_values[row][action] = reward
            else:
                q_values[row][action] = reward + self.discount_factor * (
                    np.amax(next_q_values[row]))
            # Keep the TD-error of this sample.
            td_errors[row] = abs(previous_q - q_values[row][action])

        # Refresh each sampled transition's priority with its new TD-error.
        for (tree_idx, _), td_error in zip(batch, td_errors):
            self.memory.update(tree_idx, td_error)

        self.model.fit(states, q_values, batch_size=self.batch_size,
                       epochs=1, verbose=0)
134+
135+
136+
class Memory:  # stored as ( s, a, r, s_, done ) in SumTree
    """Replay memory that samples transitions in proportion to priority."""

    # Offset added to every error; keeps the priority above zero for
    # non-negative errors.
    e = 0.01
    # Exponent applied to (error + e) to obtain the priority.
    a = 0.6

    def __init__(self, capacity):
        """Create a memory backed by a SumTree with ``capacity`` slots."""
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        """Convert an error into a priority: ``(error + e) ** a``."""
        shifted = error + self.e
        return shifted ** self.a

    def add(self, error, sample):
        """Store ``sample`` with the priority derived from ``error``."""
        self.tree.add(self._getPriority(error), sample)

    def sample(self, n):
        """Draw ``n`` entries, one from each equal slice of the total priority.

        Returns:
            A list of ``(tree_index, data)`` tuples.
        """
        segment = self.tree.total() / n
        picked = []

        for k in range(n):
            low = segment * k
            high = segment * (k + 1)

            draw = random.uniform(low, high)
            idx, _, data = self.tree.get(draw)
            picked.append((idx, data))

        return picked

    def update(self, idx, error):
        """Re-prioritize the entry at tree index ``idx`` from ``error``."""
        self.tree.update(idx, self._getPriority(error))
167+
168+
169+
if __name__ == "__main__":
    import os

    # The score plot and the trained weights are written into these folders.
    # Create them up front so a fresh checkout does not crash on the first
    # save or lose the trained weights at the end.
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("./save_model", exist_ok=True)

    # CartPole-v1 environment; the maximum number of time steps is 500.
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # Create the DQN agent.
    agent = DQNAgent(state_size, action_size)

    scores, episodes = [], []

    # Total number of environment steps taken across all episodes.
    step = 0
    for e in range(EPISODES):
        done = False
        score = 0
        # Reset the environment.
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()
            step += 1
            # Choose an action for the current state.
            action = agent.get_action(state)
            # Advance the environment one time step with that action.
            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            # An episode that ends before the score reaches 500 gets a
            # reward of -10 for its final step.
            r = reward if not done or score + reward == 500 else -10
            # Store the sample <s, a, r, s'> in the replay memory.
            agent.append_sample(state, action, r, next_state, done)
            # Train on every time step once enough steps have been taken.
            if step >= agent.train_start:
                agent.train_model()

            score += reward
            state = next_state

            if done:
                # Sync the target model with the model after every episode.
                agent.update_target_model()

                # Record and report the result of this episode.
                scores.append(score)
                episodes.append(e)
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_dqn.png")
                print("episode:", e, " score:", score, " memory length:",
                      min(step, agent.memory_size), " epsilon:", agent.epsilon)

                # Stop training once the mean score of the last 10 episodes
                # (or of all episodes so far, if fewer) exceeds 490.
                if np.mean(scores[-10:]) > 490:
                    agent.model.save_weights("./save_model/cartpole_dqn.h5")
                    sys.exit()

0 commit comments

Comments
 (0)