|
| 1 | +# Copyright 2017 reinforce.io. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +""" |
| 17 | +Arcade Learning Environment execution |
| 18 | +""" |
| 19 | + |
| 20 | +from __future__ import absolute_import |
| 21 | +from __future__ import division |
| 22 | +from __future__ import print_function |
| 23 | + |
| 24 | +import argparse |
| 25 | +import logging |
| 26 | +import os |
| 27 | +import sys |
| 28 | + |
| 29 | +from tensorforce import Configuration, TensorForceError |
| 30 | +from tensorforce.core.networks import from_json |
| 31 | +from tensorforce.agents import agents |
| 32 | +from tensorforce.environments.ale import ALE |
| 33 | +from tensorforce.execution import Runner |
| 34 | + |
| 35 | + |
def main():
    """Train a TensorForce agent on an Arcade Learning Environment ROM.

    Parses command-line options, constructs the ALE environment and the
    requested agent, optionally restores a saved model, then runs training
    episodes via a Runner with periodic progress logging and checkpointing.

    Raises:
        OSError: If the load directory does not exist, or the save
            directory cannot be created.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('rom', help="File path of the rom")
    parser.add_argument('-a', '--agent', help='Agent')
    parser.add_argument('-c', '--agent-config', help="Agent configuration file")
    parser.add_argument('-n', '--network-config', help="Network configuration file")
    parser.add_argument('-fs', '--frame-skip', help="Number of frames to repeat action", type=int, default=1)
    parser.add_argument('-rc', '--reward-clipping', help="Reward clipping. EX: -1 1", nargs="+", type=float)
    parser.add_argument('-rap', '--repeat-action-probability', help="Repeat action probability", type=float, default=0.0)
    parser.add_argument('-lolt', '--loss-of-life-termination', help="Loss of life counts as terminal state", action='store_true')
    parser.add_argument('-lolr', '--loss-of-life-reward', help="Loss of life reward/penalty. EX: -1 to penalize", type=float, default=0.0)
    parser.add_argument('-ds', '--display-screen', action='store_true', default=False, help="Display emulator screen")
    parser.add_argument('-e', '--episodes', type=int, default=50000, help="Number of episodes")
    parser.add_argument('-t', '--max-timesteps', type=int, default=2000, help="Maximum number of timesteps per episode")
    parser.add_argument('-s', '--save', help="Save agent to this dir")
    parser.add_argument('-se', '--save-episodes', type=int, default=100, help="Save agent every x episodes")
    parser.add_argument('-l', '--load', help="Load agent from this dir")
    parser.add_argument('-D', '--debug', action='store_true', default=False, help="Show debug outputs")

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)  # TODO: make the log level configurable via CLI
    logger.addHandler(logging.StreamHandler(sys.stdout))

    environment = ALE(args.rom, frame_skip=args.frame_skip, reward_clipping=args.reward_clipping,
                      repeat_action_probability=args.repeat_action_probability,
                      loss_of_life_termination=args.loss_of_life_termination,
                      loss_of_life_reward=args.loss_of_life_reward,
                      display_screen=args.display_screen)

    # Fall back to an empty configuration / no network when files are not given,
    # so agent defaults still apply.
    if args.agent_config:
        agent_config = Configuration.from_json(args.agent_config)
    else:
        agent_config = Configuration()
        logger.info("No agent configuration provided.")
    if args.network_config:
        network = from_json(args.network_config)
    else:
        network = None
        logger.info("No network configuration provided.")
    agent_config.default(dict(states=environment.states, actions=environment.actions, network=network))
    agent = agents[args.agent](config=agent_config)

    if args.load:
        load_dir = os.path.dirname(args.load)
        if not os.path.isdir(load_dir):
            raise OSError("Could not load agent from {}: No such directory.".format(load_dir))
        agent.load_model(args.load)

    if args.debug:
        logger.info("-" * 16)
        logger.info("Configuration:")
        logger.info(agent_config)

    if args.save:
        save_dir = os.path.dirname(args.save)
        if not os.path.isdir(save_dir):
            try:
                # makedirs (not mkdir) so nested save paths work.
                os.makedirs(save_dir, 0o755)
            except OSError as e:
                raise OSError("Cannot save agent to dir {} ({})".format(save_dir, e))

    runner = Runner(
        agent=agent,
        environment=environment,
        repeat_actions=1,
        save_path=args.save,
        save_episodes=args.save_episodes
    )

    # Report roughly 1000 times over the run; never let this reach 0, which
    # would crash the modulo below when fewer than 1000 episodes are requested.
    report_episodes = max(1, args.episodes // 1000)
    if args.debug:
        report_episodes = 1

    def episode_finished(r):
        # Runner callback: log progress every report_episodes; returning True
        # keeps training going.
        if r.episode % report_episodes == 0:
            logger.info("Finished episode {ep} after {ts} timesteps".format(ep=r.episode, ts=r.timestep))
            logger.info("Episode reward: {}".format(r.episode_rewards[-1]))
            # Average over the episodes actually available, so early reports
            # are not diluted by a fixed divisor.
            last_500 = r.episode_rewards[-500:]
            last_100 = r.episode_rewards[-100:]
            logger.info("Average of last 500 rewards: {}".format(sum(last_500) / len(last_500)))
            logger.info("Average of last 100 rewards: {}".format(sum(last_100) / len(last_100)))
        return True

    logger.info("Starting {agent} for Environment '{env}'".format(agent=agent, env=environment))
    runner.run(args.episodes, args.max_timesteps, episode_finished=episode_finished)
    logger.info("Learning finished. Total episodes: {ep}".format(ep=runner.episode))

    environment.close()
| 125 | + |
| 126 | + |
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
0 commit comments