Skip to content

Commit 1bf5f2b

Browse files
committed
DQN 추가
1 parent 7b228b1 commit 1bf5f2b

27 files changed

Lines changed: 504 additions & 0 deletions

06 - Game Agent (DQN)/README.md

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Deep Q-network
2+
3+
- 구글의 딥마인드에서 개발한 Deep Q-network (DQN)을 이용하여 Q-learning 을 구현해봅니다.
4+
- 딥마인드의 논문에서는 신경망 모델을 CNN 모델을 사용하지만, 여기서는 간단히 기본적인 다중 신경망 모델을 사용합니다.
5+
- 게임은 간단한 장애물 피하기 게임이며 화면 출력은 matplotlib 으로 구현하였습니다.
6+
7+
### 파일 설명
8+
9+
- agent.py
10+
- 게임을 진행하거나 학습시키는 에이전트입니다.
11+
- game.py
12+
- 게임을 구현해 놓은 파일입니다. 게임의 상태를 화면의 픽셀로 가져오지 않고, 좌표값을 이용하여 계산량을 줄이도록 하였습니다.
13+
- model.py
14+
- DQN을 구현해 놓은 파일입니다.
15+
- 논문에서는 CNN 모델을 사용하였지만, 구현을 간단히 하고 성능을 빠르게 하기 위해 기본적인 신경망 모델을 사용합니다.
16+
17+
### 핵심 코드
18+
19+
게임 구현을 위한 다양한 내용들이 들어있어 코드분량이 꽤 많지만, 핵심 내용은 딱 다음과 같습니다.
20+
21+
1. Q_value 를 이용해 얻어온 액션을 수행하고, 해당 액션에 의한 게임의 상태와 리워드를 획득한 뒤, 이것을 메모리에 순차적으로 쌓아둡니다.
22+
2. 일정 수준 이상의 메모리가 쌓이면, 메모리에 저장된 것들 중 샘플링을 하여 논문의 다음 수식을 이용해 최적화를 수행합니다.
23+
24+
```
25+
Set y_j =
26+
if episode is terminates at step j+1 then r_j
27+
otherwise r_j + γ*max_a'Q(ð_(j+1),a';θ')
28+
With respect to the network parameters θ
29+
Perform a gradient descent step on (y_j-Q(ð_j,a_j;θ))^2
30+
Every C steps reset Q^ = Q
31+
```
32+
33+
위 내용을 구현한 코드는 model.py 파일의 아래의 내용과 같습니다.
34+
35+
```python
36+
# model.py
37+
38+
def build_model(self):
39+
L1 = tf.nn.relu(tf.matmul(state, W1) + b1)
40+
Q_value = tf.matmul(L2, W3) + b3
41+
42+
Q_action = tf.reduce_sum(tf.mul(Q_value, self.input_action), axis=1)
43+
cost = tf.reduce_mean(tf.square(self.input_Y - Q_action))
44+
train_op = tf.train.AdamOptimizer(1e-6).minimize(cost, global_step=self.global_step)
45+
46+
def train(self):
47+
Q_value = self.Q_value.eval(feed_dict={self.input_state: next_state})
48+
49+
for i in range(0, self.BATCH_SIZE):
50+
if minibatch[i][4]: # if episode is terminates
51+
Y.append(reward[i])
52+
else:
53+
Y.append(reward[i] + self.GAMMA * np.max(Q_value[i]))
54+
55+
self.train_op.run(feed_dict={
56+
self.input_Y: Y,
57+
self.input_action: action,
58+
self.input_state: state
59+
})
60+
```
61+
62+
### 결과물
63+
64+
- 상상력을 발휘해주세요. 검정색 배경은 도로, 사각형을 자동차들로 그리고 녹색 사각형을 자율 주행차라고 상상하고 즐겨주시면 감사하겠습니다. :-D
65+
- 100만번 정도의 학습 후 최고의 성능을 내기 시작했으며, 2012 맥북프로 CPU 버전으로 최고 성능을 내는데까지 약 1시간 정도 걸렸습니다.
66+
67+
![게임](screenshot_game.gif)
68+
69+
![텐서보드](screenshot_tensorboard.png)
70+
71+
### 사용법
72+
73+
자가 학습시키기
74+
75+
```
76+
python agent.py -train
77+
```
78+
79+
얼마나 잘 하는지 확인해보기
80+
81+
```
82+
python agent.py
83+
```
84+
85+
텐서보드로 평균 보상값 확인해보기
86+
87+
```
88+
tensorboard --logdir=./logs
89+
```

06 - Game Agent (DQN)/agent.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# -*- coding: utf-8 -*-
2+
# 게임 구현과 DQN 모델을 이용해 게임을 실행하고 학습을 진행합니다.
3+
4+
import tensorflow as tf
5+
import numpy as np
6+
import time
7+
8+
from game import Game
9+
from model import DQN
10+
11+
12+
tf.app.flags.DEFINE_boolean("train", False, "학습모드. 게임을 화면에 보여주지 않습니다.")
13+
FLAGS = tf.app.flags.FLAGS
14+
15+
# action: 0: 좌, 1: 유지, 2: 우
16+
n_action = 3
17+
screen_width = 6
18+
screen_height = 10
19+
20+
21+
def main(_):
22+
game = Game(screen_width, screen_height, show_game=not FLAGS.train)
23+
state = game.get_state()
24+
brain = DQN(n_action, screen_width, screen_height, state)
25+
26+
while 1:
27+
game.reset()
28+
gameover = FLAGS.train
29+
30+
print " Avg. Reward: %d, Total Game: %d" % (
31+
game.total_reward / game.total_game, game.total_game)
32+
33+
while not gameover:
34+
# DQN 모델을 이용해 실행할 액션을 결정합니다.
35+
action = brain.get_action(FLAGS.train)
36+
37+
# 결정한 액션을 이용해 게임을 진행하고, 보상과 게임의 종료 여부를 받아옵니다.
38+
reward, gameover = game.proceed(np.argmax(action))
39+
40+
# 위에서 결정한 액션에 따른 현재 상태를 가져옵니다.
41+
# 상태는 screen_width x screen_height 크기의 화면 구성입니다.
42+
state = game.get_state()
43+
44+
# DQN 으로 학습을 진행합니다.
45+
brain.step(state, action, reward, gameover)
46+
47+
# 학습모드가 아닌 경우, 게임 진행을 인간이 인지할 수 있는 속도로^^; 보여줍니다.
48+
if not FLAGS.train:
49+
time.sleep(0.3)
50+
51+
if __name__ == '__main__':
52+
tf.app.run()

06 - Game Agent (DQN)/game.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# -*- coding: utf-8 -*-
2+
# 장애물 회피 게임 즉, 자율주행차:-D 게임을 구현합니다.
3+
4+
import numpy as np
5+
import random
6+
7+
import matplotlib.pyplot as plt
8+
import matplotlib.patches as patches
9+
10+
11+
class Game:
12+
def __init__(self, screen_width, screen_height, show_game=True):
13+
self.screen_width = screen_width
14+
self.screen_height = screen_height
15+
# 도로의 크기는 스크린의 반으로 정하며, 도로의 좌측 우측의 여백을 계산해둡니다.
16+
self.road_width = (screen_width / 2)
17+
self.road_left = self.road_width / 2 + 1
18+
self.road_right = self.road_left + self.road_width - 1
19+
20+
# 자동차와 장애물의 초기 위치와, 장애물 각각의 속도를 정합니다.
21+
self.car = {"col": 0, "row": 2}
22+
self.block = [
23+
{"col": 0, "row": 0, "speed": 1},
24+
{"col": 0, "row": 0, "speed": 2},
25+
]
26+
27+
self.total_reward = 0.
28+
self.current_reward = 0.
29+
self.total_game = 0
30+
self.show_game = show_game
31+
32+
if show_game:
33+
self.fig, self.axis = self.prepare_display()
34+
35+
def prepare_display(self):
36+
"""게임을 화면에 보여주기 위해 matplotlib 으로 출력할 화면을 설정합니다."""
37+
fig, axis = plt.subplots(figsize=(4, 6))
38+
fig.set_size_inches(4, 6)
39+
# 화면을 닫으면 프로그램을 종료합니다.
40+
fig.canvas.mpl_connect('close_event', exit)
41+
plt.axis((0, self.screen_width, 0, self.screen_height))
42+
plt.tick_params(top='off', right='off',
43+
left='off', labelleft='off',
44+
bottom='off', labelbottom='off')
45+
46+
plt.draw()
47+
# 게임을 진행하며 화면을 업데이트 할 수 있도록 interactive 모드로 설정합니다.
48+
plt.ion()
49+
plt.show()
50+
51+
return fig, axis
52+
53+
def get_state(self):
54+
"""게임의 상태를 가져옵니다.
55+
56+
게임의 상태는 screen_width x screen_height 크기로 각 위치에 대한 상태값을 가지고 있으며,
57+
빈 공간인 경우에는 0, 사물이 있는 경우에는 1이 들어있는 1차원 배열입니다.
58+
계산의 편의성을 위해 2차원 -> 1차원으로 변환하여 사용합니다.
59+
"""
60+
state = np.zeros((self.screen_width, self.screen_height))
61+
62+
state[self.car["col"], self.car["row"]] = 1
63+
64+
if self.block[0]["row"] < self.screen_height:
65+
state[self.block[0]["col"], self.block[0]["row"]] = 1
66+
67+
if self.block[1]["row"] < self.screen_height:
68+
state[self.block[1]["col"], self.block[1]["row"]] = 1
69+
70+
return state.reshape((-1, self.screen_width * self.screen_height))
71+
72+
def draw_screen(self):
73+
title = " Avg. Reward: %d Reward: %d Total Game: %d" % (
74+
self.total_reward / self.total_game,
75+
self.current_reward,
76+
self.total_game)
77+
78+
self.axis.clear()
79+
self.axis.set_title(title, fontsize=12)
80+
81+
road = patches.Rectangle((self.road_left - 1, 0), self.road_width + 1, self.screen_height, linewidth=0, facecolor="#333333")
82+
# 자동차, 장애물들을 1x1 크기의 정사각형으로 그리도록하며, 좌표를 기준으로 중앙에 위치시킵니다.
83+
# 자동차의 경우에는 장애물과 충돌시 확인이 가능하도록 0.5만큼 아래쪽으로 이동하여 그립니다.
84+
car = patches.Rectangle((self.car["col"] - 0.5, self.car["row"] - 0.5), 1, 1, linewidth=0, facecolor="#00FF00")
85+
block1 = patches.Rectangle((self.block[0]["col"] - 0.5, self.block[0]["row"]), 1, 1, linewidth=0, facecolor="#0000FF")
86+
block2 = patches.Rectangle((self.block[1]["col"] - 0.5, self.block[1]["row"]), 1, 1, linewidth=0, facecolor="#FF0000")
87+
88+
self.axis.add_patch(road)
89+
self.axis.add_patch(car)
90+
self.axis.add_patch(block1)
91+
self.axis.add_patch(block2)
92+
93+
self.fig.canvas.draw()
94+
# 게임의 다음 단계 진행을 위해 matplot 의 이벤트 루프를 잠시 멈춥니다.
95+
plt.pause(0.0001)
96+
97+
def reset(self):
98+
"""자동차, 장애물의 위치와 보상값들을 초기화합니다."""
99+
self.current_reward = 0
100+
self.total_game += 1
101+
102+
self.car["col"] = int(self.screen_width / 2)
103+
104+
self.block[0]["col"] = random.randrange(self.road_left, self.road_right + 1)
105+
self.block[0]["row"] = 0
106+
self.block[1]["col"] = random.randrange(self.road_left, self.road_right + 1)
107+
self.block[1]["row"] = 0
108+
109+
self.update_block()
110+
111+
def update_car(self, move):
112+
"""액션에 따라 자동차를 이동시킵니다.
113+
114+
자동차 위치 제한을 도로가 아니라 화면의 좌우측 끝으로 하고,
115+
도로를 넘어가면 패널티를 주도록 학습해서 도로를 넘지 않게 만들면 더욱 좋을 것 같습니다.
116+
"""
117+
118+
# 자동차의 위치가 도로의 좌측을 넘지 않도록 합니다: max(0, move) > 0
119+
self.car["col"] = max(self.road_left, self.car["col"] + move)
120+
# 자동차의 위치가 도로의 우측을 넘지 않도록 합니다.: min(max, screen_width) < screen_width
121+
self.car["col"] = min(self.car["col"], self.road_right)
122+
123+
def update_block(self):
124+
"""장애물을 이동시킵니다.
125+
126+
장애물이 화면 내에 있는 경우는 각각의 속도에 따라 위치 변경을,
127+
화면을 벗어난 경우에는 다시 방해를 시작하도록 재설정을 합니다.
128+
"""
129+
reward = 0
130+
131+
if self.block[0]["row"] > 0:
132+
self.block[0]["row"] -= self.block[0]["speed"]
133+
else:
134+
self.block[0]["col"] = random.randrange(self.road_left, self.road_right + 1)
135+
self.block[0]["row"] = self.screen_height
136+
reward += 1
137+
138+
if self.block[1]["row"] > 0:
139+
self.block[1]["row"] -= self.block[1]["speed"]
140+
else:
141+
self.block[1]["col"] = random.randrange(self.road_left, self.road_right + 1)
142+
self.block[1]["row"] = self.screen_height
143+
reward += 1
144+
145+
return reward
146+
147+
def is_gameover(self):
148+
# 장애물과 자동차가 충돌했는지를 파악합니다.
149+
# 사각형 박스의 충돌을 체크하는 것이 아니라 좌표를 체크하는 것이어서 화면에는 약간 다르게 보일 수 있습니다.
150+
if ((self.car["col"] == self.block[0]["col"] and
151+
self.car["row"] == self.block[0]["row"]) or
152+
(self.car["col"] == self.block[1]["col"] and
153+
self.car["row"] == self.block[1]["row"])):
154+
155+
self.total_reward += self.current_reward
156+
157+
return True
158+
else:
159+
return False
160+
161+
def proceed(self, action):
162+
# action: 0: 좌, 1: 유지, 2: 우
163+
# action - 1 을 하여, 좌표를 액션이 0 일 경우 -1 만큼, 2 일 경우 1 만큼 옮깁니다.
164+
self.update_car(action - 1)
165+
# 장애물을 이동시킵니다. 장애물이 자동차에 충돌하지 않고 화면을 모두 지나가면 보상을 얻습니다.
166+
escape_reward = self.update_block()
167+
# 움직임이 적을 경우에도 보상을 줘서 안정적으로 이동하는 것 처럼 보이게 만듭니다.
168+
stable_reward = 1. / self.screen_height if action == 1 else 0
169+
# 게임이 종료됐는지를 판단합니다. 자동차와 장애물이 충돌했는지를 파악합니다.
170+
gameover = self.is_gameover()
171+
172+
if gameover:
173+
# 장애물에 충돌한 경우 -2점을 보상으로 줍니다. 장애물이 두 개이기 때문입니다.
174+
# 장애물을 회피했을 때 보상을 주지 않고, 충돌한 경우에만 -1점을 주어도 됩니다.
175+
reward = -2
176+
else:
177+
reward = escape_reward + stable_reward
178+
self.current_reward += reward
179+
180+
if self.show_game:
181+
self.draw_screen()
182+
183+
return reward, gameover
2.34 MB
Binary file not shown.
99.2 KB
Binary file not shown.
99.1 KB
Binary file not shown.
99.1 KB
Binary file not shown.
99.1 KB
Binary file not shown.

0 commit comments

Comments
 (0)