Commit cabe817

Use a step counter rather than physics.time() to enforce episode time limits
This circumvents a problem where the step count sometimes differs from `time_limit / control_timestep` due to accumulation of rounding error in `mjData->time`.

PiperOrigin-RevId: 192746383
1 parent b7c85e5 commit cabe817
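
The rounding problem is easy to reproduce outside MuJoCo. Below is a minimal sketch (not part of the commit; the timestep and limit values are illustrative): accumulating a fixed timestep in a float, as `mjData->time` does, drifts away from the exact product, so a `time >= time_limit` check can end an episode a step early or late, whereas an integer step counter cannot drift.

timestep = 0.01    # physics timestep, as in the updated test below
time_limit = 10.0  # intended to correspond to exactly 1000 steps

time, steps = 0.0, 0
while time < time_limit:  # old-style check against accumulated time
  time += timestep        # rounds on every addition, like mjData->time
  steps += 1

# Accumulated error typically makes this loop run 1001 steps rather than
# the intended 1000 on IEEE-754 doubles; the exact drift is
# platform-dependent.
print(steps, time)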

2 files changed: 33 additions & 13 deletions


dm_control/rl/control.py

Lines changed: 15 additions & 8 deletions
@@ -23,8 +23,6 @@
 import collections
 import contextlib

-# Internal dependencies.
-
 import numpy as np
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -64,7 +62,6 @@ def __init__(self,
     """
     self._task = task
     self._physics = physics
-    self._time_limit = time_limit
     self._flat_observation = flat_observation

     if n_sub_steps is not None and control_timestep is not None:
@@ -77,11 +74,18 @@ def __init__(self,
     else:
       self._n_sub_steps = 1

+    if time_limit == float('inf'):
+      self._step_limit = float('inf')
+    else:
+      self._step_limit = time_limit / (
+          self._physics.timestep() * self._n_sub_steps)
+    self._step_count = 0
     self._reset_next_step = True

   def reset(self):
     """Starts a new episode and returns the first `TimeStep`."""
     self._reset_next_step = False
+    self._step_count = 0
     with self._physics.reset_context():
       self._task.initialize_episode(self._physics)

@@ -111,18 +115,21 @@ def step(self, action):
     if self._flat_observation:
       observation = flatten_observation(observation)

-    if self.physics.time() >= self._time_limit:
+    self._step_count += 1
+    if self._step_count >= self._step_limit:
       discount = 1.0
     else:
       discount = self._task.get_termination(self._physics)

-    if discount is None:
-      return environment.TimeStep(
-          environment.StepType.MID, reward, 1.0, observation)
-    else:
+    episode_over = discount is not None
+
+    if episode_over:
       self._reset_next_step = True
       return environment.TimeStep(
           environment.StepType.LAST, reward, discount, observation)
+    else:
+      return environment.TimeStep(
+          environment.StepType.MID, reward, 1.0, observation)

   def action_spec(self):
     """Returns the action specification for this environment."""

dm_control/rl/control_test.py

Lines changed: 18 additions & 5 deletions
@@ -74,12 +74,25 @@ def test_environment_calls(self):

     self.assertEquals(_CONSTANT_REWARD_VALUE, time_step.reward)

-  def test_timeout(self):
-    self._physics.time = mock.Mock(return_value=2.)
+  @parameterized.parameters(
+      {'physics_timestep': .01, 'control_timestep': None,
+       'expected_steps': 1000},
+      {'physics_timestep': .01, 'control_timestep': .05,
+       'expected_steps': 5000})
+  def test_timeout(self, expected_steps, physics_timestep, control_timestep):
+    self._physics.timestep.return_value = physics_timestep
+    time_limit = expected_steps * (control_timestep or physics_timestep)
     env = control.Environment(
-        physics=self._physics, task=self._task, time_limit=1.)
-    env.reset()
-    time_step = env.step([1])
+        physics=self._physics, task=self._task, time_limit=time_limit,
+        control_timestep=control_timestep)
+
+    time_step = env.reset()
+    steps = 0
+    while not time_step.last():
+      time_step = env.step([1])
+      steps += 1
+
+    self.assertEqual(steps, expected_steps)
     self.assertTrue(time_step.last())

     time_step = env.step([1])
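
For readers unfamiliar with the idiom, `parameterized.parameters` from `absl.testing` runs the decorated test once per entry, passing each dict as keyword arguments, and `control_timestep or physics_timestep` falls back to one physics step per control step when no control timestep is given. Here is a self-contained sketch of the same arithmetic (a hypothetical stand-in test, not part of the commit):

# Standalone sketch (hypothetical test, not from the commit) of the
# parameterized idiom and the step-limit arithmetic exercised above.
from absl.testing import absltest
from absl.testing import parameterized


class StepLimitArithmeticTest(parameterized.TestCase):

  @parameterized.parameters(
      {'physics_timestep': .01, 'control_timestep': None,
       'expected_steps': 1000},
      {'physics_timestep': .01, 'control_timestep': .05,
       'expected_steps': 5000})
  def test_step_limit(self, physics_timestep, control_timestep,
                      expected_steps):
    # When control_timestep is None, one control step is one physics step.
    effective = control_timestep or physics_timestep
    time_limit = expected_steps * effective
    n_sub_steps = round(effective / physics_timestep)
    step_limit = time_limit / (physics_timestep * n_sub_steps)
    self.assertEqual(round(step_limit), expected_steps)


if __name__ == '__main__':
  absltest.main()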
