Skip to content

Commit f02a96a

Browse files
Merge pull request tensorforce#63 from Islandman93/ale
Merge ALE environment - pre environment restructuring
2 parents c877a7c + 9cc8737 commit f02a96a

2 files changed

Lines changed: 260 additions & 0 deletions

File tree

examples/ale.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2017 reinforce.io. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""
17+
Arcade Learning Environment execution
18+
"""
19+
20+
from __future__ import absolute_import
21+
from __future__ import division
22+
from __future__ import print_function
23+
24+
import argparse
25+
import logging
26+
import os
27+
import sys
28+
29+
from tensorforce import Configuration, TensorForceError
30+
from tensorforce.core.networks import from_json
31+
from tensorforce.agents import agents
32+
from tensorforce.environments.ale import ALE
33+
from tensorforce.execution import Runner
34+
35+
36+
def main():
    """Command-line entry point: train a TensorForce agent on an ALE rom.

    Parses CLI arguments, builds the ALE environment and the configured agent,
    optionally loads/saves the agent model, then runs the training loop via
    Runner with periodic progress logging.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('rom', help="File path of the rom")
    parser.add_argument('-a', '--agent', help='Agent')
    parser.add_argument('-c', '--agent-config', help="Agent configuration file")
    parser.add_argument('-n', '--network-config', help="Network configuration file")
    parser.add_argument('-fs', '--frame-skip', help="Number of frames to repeat action", type=int, default=1)
    parser.add_argument('-rc', '--reward-clipping', help="Reward clipping. EX: -1 1", nargs="+", type=float)
    parser.add_argument('-rap', '--repeat-action-probability', help="Repeat action probability", type=float, default=0.0)
    parser.add_argument('-lolt', '--loss-of-life-termination', help="Loss of life counts as terminal state", action='store_true')
    parser.add_argument('-lolr', '--loss-of-life-reward', help="Loss of life reward/penalty. EX: -1 to penalize", type=float, default=0.0)
    parser.add_argument('-ds', '--display-screen', action='store_true', default=False, help="Display emulator screen")
    parser.add_argument('-e', '--episodes', type=int, default=50000, help="Number of episodes")
    parser.add_argument('-t', '--max-timesteps', type=int, default=2000, help="Maximum number of timesteps per episode")
    parser.add_argument('-s', '--save', help="Save agent to this dir")
    parser.add_argument('-se', '--save-episodes', type=int, default=100, help="Save agent every x episodes")
    parser.add_argument('-l', '--load', help="Load agent from this dir")
    parser.add_argument('-D', '--debug', action='store_true', default=False, help="Show debug outputs")

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)  # TODO: make the log level configurable
    logger.addHandler(logging.StreamHandler(sys.stdout))

    environment = ALE(args.rom, frame_skip=args.frame_skip, reward_clipping=args.reward_clipping,
                      repeat_action_probability=args.repeat_action_probability,
                      loss_of_life_termination=args.loss_of_life_termination,
                      loss_of_life_reward=args.loss_of_life_reward,
                      display_screen=args.display_screen)

    if args.agent_config:
        agent_config = Configuration.from_json(args.agent_config)
    else:
        agent_config = Configuration()
        logger.info("No agent configuration provided.")
    if args.network_config:
        network = from_json(args.network_config)
    else:
        network = None
        logger.info("No network configuration provided.")
    # Fill in state/action specs from the environment without overriding the
    # user's explicit configuration.
    agent_config.default(dict(states=environment.states, actions=environment.actions, network=network))
    agent = agents[args.agent](config=agent_config)

    if args.load:
        load_dir = os.path.dirname(args.load)
        if not os.path.isdir(load_dir):
            raise OSError("Could not load agent from {}: No such directory.".format(load_dir))
        agent.load_model(args.load)

    if args.debug:
        logger.info("-" * 16)
        logger.info("Configuration:")
        logger.info(agent_config)

    if args.save:
        save_dir = os.path.dirname(args.save)
        if not os.path.isdir(save_dir):
            try:
                # BUG FIX: makedirs also creates missing parent directories;
                # mkdir would raise if any intermediate directory is absent.
                os.makedirs(save_dir, 0o755)
            except OSError as e:
                # BUG FIX: original message had an empty "()" placeholder;
                # include the underlying cause and chain the exception.
                raise OSError("Cannot save agent to dir {} ({})".format(save_dir, e))

    runner = Runner(
        agent=agent,
        environment=environment,
        repeat_actions=1,
        save_path=args.save,
        save_episodes=args.save_episodes
    )

    # BUG FIX: when fewer than 1000 episodes are requested, integer division
    # yields 0 and `r.episode % report_episodes` raises ZeroDivisionError.
    report_episodes = max(1, args.episodes // 1000)
    if args.debug:
        report_episodes = 1

    def episode_finished(r):
        # Periodic progress report; returning True keeps the runner going.
        if r.episode % report_episodes == 0:
            logger.info("Finished episode {ep} after {ts} timesteps".format(ep=r.episode, ts=r.timestep))
            logger.info("Episode reward: {}".format(r.episode_rewards[-1]))
            # BUG FIX: divide by the number of rewards actually present rather
            # than a fixed 500/100, which under-reported early averages.
            rewards_500 = r.episode_rewards[-500:]
            logger.info("Average of last 500 rewards: {}".format(sum(rewards_500) / len(rewards_500)))
            rewards_100 = r.episode_rewards[-100:]
            logger.info("Average of last 100 rewards: {}".format(sum(rewards_100) / len(rewards_100)))
        return True

    logger.info("Starting {agent} for Environment '{env}'".format(agent=agent, env=environment))
    runner.run(args.episodes, args.max_timesteps, episode_finished=episode_finished)
    logger.info("Learning finished. Total episodes: {ep}".format(ep=runner.episode))

    environment.close()


if __name__ == '__main__':
    main()

tensorforce/environments/ale.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2017 reinforce.io. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""
17+
Arcade Learning Environment (ALE). https://github.com/mgbellemare/Arcade-Learning-Environment
18+
"""
19+
20+
from __future__ import absolute_import
21+
from __future__ import print_function
22+
from __future__ import division
23+
24+
import numpy as np
25+
from ale_python_interface import ALEInterface
26+
27+
from tensorforce import TensorForceError
28+
from tensorforce.environments import Environment
29+
30+
31+
class ALE(Environment):
    """TensorForce environment wrapper around the Arcade Learning Environment.

    https://github.com/mgbellemare/Arcade-Learning-Environment
    """

    def __init__(self, rom, frame_skip=1, reward_clipping=None, repeat_action_probability=0.0,
                 loss_of_life_termination=False, loss_of_life_reward=0, display_screen=False,
                 seed=None):
        """
        Initialize ALE.

        Args:
            rom: Rom filename and directory.
            frame_skip: Repeat action for n frames. Default 1.
            reward_clipping: Clip rewards between (low, high). Can be None. Default None.
            repeat_action_probability: Repeats last action with given probability. Default 0.
            loss_of_life_termination: Signals a terminal state on loss of life. Default False.
            loss_of_life_reward: Reward/Penalty on loss of life (negative values are a penalty). Default 0.
            display_screen: Displays the emulator screen. Default False.
            seed: numpy RandomState used to draw the emulator random seed, or
                None to create a fresh one. Default None.
        """
        # BUG FIX: the original default `seed=np.random.RandomState()` was
        # evaluated once at function-definition time, so every ALE instance
        # shared (and advanced) the same generator. Create one per call.
        if seed is None:
            seed = np.random.RandomState()

        self.ale = ALEInterface()
        self.rom = rom

        self.ale.setBool(b'display_screen', display_screen)
        self.ale.setInt(b'random_seed', seed.randint(0, 9999))
        self.ale.setFloat(b'repeat_action_probability', repeat_action_probability)
        self.ale.setBool(b'color_averaging', False)
        self.ale.setInt(b'frame_skip', frame_skip)

        # all set commands must be done before loading the ROM
        self.ale.loadROM(rom.encode())

        # setup gamescreen object: (height, width, RGB) uint8 buffer reused by
        # getScreenRGB to avoid reallocating every frame
        width, height = self.ale.getScreenDims()
        self.gamescreen = np.empty((height, width, 3), dtype=np.uint8)

        self.frame_skip = frame_skip

        # setup action converter
        # ALE returns legal action indexes, convert these to just numbers
        self.action_inds = self.ale.getMinimalActionSet()

        # setup lives
        self.loss_of_life_reward = loss_of_life_reward
        self.cur_lives = self.ale.lives()
        self.loss_of_life_termination = loss_of_life_termination
        self.life_lost = False

        # reward clipping
        self.reward_clipping = reward_clipping

    def __str__(self):
        return 'ALE({})'.format(self.rom)

    def close(self):
        # Drop the emulator handle; the interface exposes no explicit shutdown.
        self.ale = None

    def reset(self):
        """Reset the emulator and return the initial screen state."""
        self.ale.reset_game()
        self.cur_lives = self.ale.lives()
        self.life_lost = False
        # Reallocate the screen buffer. NOTE: np.empty does not zero memory,
        # but current_state overwrites the whole buffer before it is read.
        self.gamescreen = np.empty(self.gamescreen.shape, dtype=np.uint8)
        return self.current_state

    def execute(self, action):
        """Apply one action index; return (next_state, reward, terminal)."""
        # convert action to ale action
        ale_action = self.action_inds[action]

        # get reward and process terminal & next state
        rew = self.ale.act(ale_action)
        # Life tracking is only needed when a loss of life affects reward or
        # termination.
        if self.loss_of_life_termination or self.loss_of_life_reward != 0:
            new_lives = self.ale.lives()
            if new_lives < self.cur_lives:
                self.cur_lives = new_lives
                self.life_lost = True
                rew += self.loss_of_life_reward

        if self.reward_clipping is not None:
            rew = np.clip(rew, self.reward_clipping[0], self.reward_clipping[1])
        terminal = self.is_terminal
        state_tp1 = self.current_state
        return state_tp1, rew, terminal

    @property
    def states(self):
        # Screen pixels exposed as a float-typed state space.
        return dict(shape=self.gamescreen.shape, type=float)

    @property
    def actions(self):
        # Discrete action space over the rom's minimal action set.
        return dict(continuous=False, num_actions=len(self.action_inds))

    @property
    def current_state(self):
        # getScreenRGB fills the provided buffer in place and returns it.
        self.gamescreen = self.ale.getScreenRGB(self.gamescreen)
        return self.gamescreen

    @property
    def is_terminal(self):
        # A lost life counts as terminal only when configured to; otherwise
        # defer to the emulator's game-over flag.
        if self.loss_of_life_termination and self.life_lost:
            return True
        else:
            return self.ale.game_over()

0 commit comments

Comments
 (0)