Skip to content

Commit d240395

Browse files
Updating test configs for multi-pass.
1 parent b85f96b commit d240395

10 files changed

Lines changed: 72 additions & 15 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Features
5151
--------
5252

5353
TensorForce currently integrates with the OpenAI Gym API, OpenAI
54-
Universe, DeepMind lab, ALE and Maze explorer. The following algorithms are available (all
54+
Universe, the Unreal Engine (game engine), DeepMind lab, ALE and Maze explorer. The following algorithms are available (all
5555
policy methods both continuous/discrete and using a Beta distribution for bounded actions).
5656

5757
- A3C using distributed TensorFlow or a multithreaded runner - now as part of our generic Model

tensorforce/meta_parameter_recorder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def convert_dictionary_to_string(self, data, indent=0, format_type=0, separator=
104104
if separator is None:
105105
separator = ", "
106106

107-
#This should not ever occur but here as a catch
107+
# This should not ever occur but here as a catch
108108
if type(data) is not dict:
109109
raise TensorForceError(
110110
"Error: MetaParameterRecorder Dictionary conversion was passed a type {}"
@@ -140,7 +140,7 @@ def convert_list_to_string(self, data, indent=0, format_type=0, eol=None, count=
140140
if eol is None:
141141
eol = os.linesep
142142

143-
#This should not ever occur but here as a catch
143+
# This should not ever occur but here as a catch
144144
if type(data) is not list:
145145
raise TensorForceError(
146146
"Error: MetaParameterRecorder List conversion was passed a type {}"
@@ -171,7 +171,7 @@ def convert_ndarray_to_md(self, data, format_type=0, eol=None):
171171
if eol is None:
172172
eol = os.linesep
173173

174-
#This should not ever occur but here as a catch
174+
# This should not ever occur but here as a catch
175175
if type(data) is not np.ndarray:
176176
raise TensorForceError(
177177
"Error: MetaParameterRecorder ndarray conversion was passed"
@@ -254,7 +254,7 @@ def build_metagraph_list(self):
254254

255255
self.ignore_unknown_dtypes = True
256256
for key in sorted(self.meta_params):
257-
value=self.convert_data_to_string(self.meta_params[key])
257+
value = self.convert_data_to_string(self.meta_params[key])
258258

259259
if len(value) == 0:
260260
continue

tensorforce/tests/base_agent_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def tf_apply(self, x, internals, update, return_internals=False):
193193
name='multi',
194194
environment=environment,
195195
network_spec=CustomNetwork,
196-
**self.__class__.kwargs
196+
**self.__class__.multi_kwargs
197197
)
198198

199199
def test_lstm(self):

tensorforce/tests/test_dqfd_agent.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,20 @@ def pre_run(self, agent, environment):
107107

108108
agent.import_demonstrations(demonstrations)
109109
agent.pretrain(steps=1000)
110+
111+
multi_kwargs = dict(
112+
memory=dict(
113+
type='replay',
114+
capacity=1000
115+
),
116+
optimizer=dict(
117+
type="adam",
118+
learning_rate=0.01
119+
),
120+
repeat_update=1,
121+
batch_size=16,
122+
first_update=16,
123+
target_sync_frequency=10,
124+
demo_memory_capacity=100,
125+
demo_sampling_ratio=0.2
126+
)

tensorforce/tests/test_dqn_agent.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class TestDQNAgent(BaseAgentTest, unittest.TestCase):
3838
learning_rate=0.002
3939
),
4040
# Comment in to test exploration types
41-
# exploration=dict(
41+
# explorations_spec=dict(
4242
# type="epsilon_decay",
4343
# initial_epsilon=1.0,
4444
# final_epsilon=0.1,
@@ -58,3 +58,18 @@ class TestDQNAgent(BaseAgentTest, unittest.TestCase):
5858

5959
exclude_float = True
6060
exclude_bounded = True
61+
62+
multi_kwargs = dict(
63+
memory=dict(
64+
type='replay',
65+
capacity=1000
66+
),
67+
optimizer=dict(
68+
type="adam",
69+
learning_rate=0.01
70+
),
71+
repeat_update=1,
72+
batch_size=16,
73+
first_update=16,
74+
target_sync_frequency=10
75+
)

tensorforce/tests/test_dqn_memories.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def test_replay(self):
4040
type='replay',
4141
capacity=1000
4242
),
43-
batch_size=8,
44-
first_update=10,
43+
repeat_update=4,
44+
batch_size=32,
45+
first_update=64,
4546
target_sync_frequency=10
4647
)
4748

@@ -64,8 +65,9 @@ def test_prioritized_replay(self):
6465
type='prioritized_replay',
6566
capacity=1000
6667
),
67-
batch_size=8,
68-
first_update=10,
68+
repeat_update=4,
69+
batch_size=32,
70+
first_update=64,
6971
target_sync_frequency=10
7072
)
7173

@@ -87,8 +89,9 @@ def test_naive_prioritized_replay(self):
8789
type='naive_prioritized_replay',
8890
capacity=1000
8991
),
90-
batch_size=8,
91-
first_update=10,
92+
repeat_update=4,
93+
batch_size=32,
94+
first_update=64,
9295
target_sync_frequency=10
9396
)
9497

tensorforce/tests/test_dqn_nstep_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tensorforce.agents import DQNNstepAgent
2424

2525

26-
2726
class TestDQNNstepAgent(BaseAgentTest, unittest.TestCase):
2827

2928
agent = DQNNstepAgent
@@ -33,9 +32,10 @@ class TestDQNNstepAgent(BaseAgentTest, unittest.TestCase):
3332
batch_size=8,
3433
optimizer=dict(
3534
type='adam',
36-
learning_rate=1e-2
35+
learning_rate=0.01
3736
)
3837
)
3938

4039
exclude_float = True
4140
exclude_bounded = True
41+
exclude_multi = True

tensorforce/tests/test_ppo_agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ class TestPPOAgent(BaseAgentTest, unittest.TestCase):
3030
kwargs = dict(
3131
batch_size=8
3232
)
33+
34+
multi_kwargs = dict(
35+
batch_size=32,
36+
step_optimizer=dict(
37+
type='adam',
38+
learning_rate=0.001
39+
)
40+
)

tensorforce/tests/test_trpo_agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,8 @@ class TestTRPOAgent(BaseAgentTest, unittest.TestCase):
3131
kwargs = dict(
3232
batch_size=8
3333
)
34+
35+
multi_kwargs = dict(
36+
batch_size=64,
37+
learning_rate=0.1
38+
)

tensorforce/tests/test_vpg_agent.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,12 @@ class TestVPGAgent(BaseAgentTest, unittest.TestCase):
3030
kwargs = dict(
3131
batch_size=8
3232
)
33+
34+
multi_kwargs = dict(
35+
batch_size=64,
36+
optimizer=dict(
37+
type='adam',
38+
learning_rate=0.01
39+
)
40+
)
41+

0 commit comments

Comments (0)