Skip to content

Commit c8cfba9

Browse files
Lab requires conv2d default network
1 parent 35475d0 commit c8cfba9

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package(default_visibility = ["//visibility:public"])
33
tensorforce_args = [
44
"--agent VPGAgent",
55
"--agent-config /configs/vpg_agent.json",
6-
"--network-config /configs/vpg_network.json",
6+
"--network-config /configs/vpg_network_visual.json",
77
"--episodes 1000",
88
"--max-timesteps 1000"
99
]
@@ -12,7 +12,7 @@ py_library(
1212
name = "tensorforce",
1313
imports = [":tensorforce"],
1414
data = ["//tensorforce:examples/configs/vpg_agent.json",
15-
"//tensorforce:examples/configs/vpg_network.json"],
15+
"//tensorforce:examples/configs/vpg_network_visual.json"],
1616
srcs = glob(["tensorforce/**/*.py"])
1717
)
1818

tensorforce/environments/deepmind_lab.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,19 @@ class DeepMindLab(Environment):
5050
# level = deepmind_lab.Lab(level_id, ())
5151
# return level.action_spec()
5252

53-
def __init__(self, level_id, repeat_action=1, state_attributes=['RGB_INTERLACED'], settings={'width': '320', 'height': '240', 'fps': '60', 'appendCommand': ''}):
53+
def __init__(self, level_id, repeat_action=1, state_attribute='RGB_INTERLACED', settings={'width': '320', 'height': '240', 'fps': '60', 'appendCommand': ''}):
5454
"""
5555
Initialize DeepMind Lab environment.
5656
5757
:param level_id: string with id/descriptor of the level, e.g. 'seekavoid_arena_01'
5858
:param num_steps: number of frames the environment is advanced, executing the given action during every frame
59-
:param state_attributes: list of attributes which represent the state for this environment, should adhere to the specification given in DeepMindLabEnvironment.state_spec(level_id)
59+
:param state_attribute: Attributes which represents the state for this environment, should adhere to the specification given in DeepMindLabEnvironment.state_spec(level_id)
6060
:param settings: dict specifying additional settings as key-value string pairs. The following options are recognized: 'width' (horizontal resolution of the observation frames), 'height' (vertical resolution of the observation frames), 'fps' (frames per second) and 'appendCommand' (commands for the internal Quake console).
6161
"""
6262
self.level_id = level_id
63-
self.level = deepmind_lab.Lab(level=level_id, observations=state_attributes, config=settings)
63+
self.level = deepmind_lab.Lab(level=level_id, observations=[state_attribute], config=settings)
6464
self.repeat_action = repeat_action
65+
self.state_attribute = state_attribute
6566

6667
def __str__(self):
6768
return 'DeepMindLab({})'.format(self.level_id)
@@ -80,7 +81,7 @@ def reset(self):
8081
:return: initial state
8182
"""
8283
self.level.reset() # optional: episode=-1, seed=None
83-
return self.level.observations()['RGB_INTERLACED']
84+
return self.level.observations()[self.state_attribute]
8485

8586
def execute(self, action):
8687
"""
@@ -112,7 +113,8 @@ def states(self):
112113
if state_type == np.uint8:
113114
state_type = np.int32
114115

115-
states[state['name']] = dict(shape=state['shape'], type=state_type)
116+
if state['name'] == self.state_attribute:
117+
states[state['name']] = dict(shape=state['shape'], type=state_type)
116118

117119
return states
118120

0 commit comments

Comments
 (0)