Skip to content

Commit 0277e43

Browse files
committed
Ensure that consecutive observations are not views onto the same memory addresses
This means that observation arrays will not be mutated when the environment steps (google-deepmind#34). PiperOrigin-RevId: 216901616
1 parent fe1a361 commit 0277e43

8 files changed

Lines changed: 29 additions & 17 deletions

File tree

dm_control/mujoco/engine.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -442,31 +442,31 @@ def _physics_state_items(self):
442442
# Named views of simulation data.
443443

444444
def control(self):
445-
"""Returns MuJoCo actuation vector as defined in the model."""
446-
return self.data.ctrl[:]
445+
"""Returns a copy of the control signals for the actuators."""
446+
return self.data.ctrl[:].copy()
447447

448448
def activation(self):
449-
"""Returns the internal states of 'filter' or 'integrator' actuators.
449+
"""Returns a copy of the internal states of actuators.
450450
451451
For details, please refer to
452452
http://www.mujoco.org/book/computation.html#geActuation
453453
454454
Returns:
455455
Activations in a numpy array.
456456
"""
457-
return self.data.act[:]
457+
return self.data.act[:].copy()
458458

459459
def state(self):
460460
"""Returns the full physics state. Alias for `get_physics_state`."""
461461
return np.concatenate(self._physics_state_items())
462462

463463
def position(self):
464-
"""Returns generalized positions (system configuration)."""
465-
return self.data.qpos[:]
464+
"""Returns a copy of the generalized positions (system configuration)."""
465+
return self.data.qpos[:].copy()
466466

467467
def velocity(self):
468-
"""Returns generalized velocities."""
469-
return self.data.qvel[:]
468+
"""Returns a copy of the generalized velocities."""
469+
return self.data.qvel[:].copy()
470470

471471
def timestep(self):
472472
"""Returns the simulation timestep."""

dm_control/suite/cheetah.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_observation(self, physics):
9191
"""Returns an observation of the state, ignoring horizontal position."""
9292
obs = collections.OrderedDict()
9393
# Ignores horizontal position to maintain translational invariance.
94-
obs['position'] = physics.data.qpos[1:]
94+
obs['position'] = physics.data.qpos[1:].copy()
9595
obs['velocity'] = physics.velocity()
9696
return obs
9797

dm_control/suite/hopper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def get_observation(self, physics):
117117
"""Returns an observation of positions, velocities and touch sensors."""
118118
obs = collections.OrderedDict()
119119
# Ignores horizontal position to maintain translational invariance:
120-
obs['position'] = physics.data.qpos[1:]
120+
obs['position'] = physics.data.qpos[1:].copy()
121121
obs['velocity'] = physics.velocity()
122122
obs['touch'] = physics.touch()
123123
return obs

dm_control/suite/humanoid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,19 @@ def head_height(self):
110110

111111
def center_of_mass_position(self):
112112
"""Returns position of the center-of-mass."""
113-
return self.named.data.subtree_com['torso']
113+
return self.named.data.subtree_com['torso'].copy()
114114

115115
def center_of_mass_velocity(self):
116116
"""Returns the velocity of the center-of-mass."""
117-
return self.named.data.sensordata['torso_subtreelinvel']
117+
return self.named.data.sensordata['torso_subtreelinvel'].copy()
118118

119119
def torso_vertical_orientation(self):
120120
"""Returns the z-projection of the torso orientation matrix."""
121121
return self.named.data.xmat['torso', ['zx', 'zy', 'zz']]
122122

123123
def joint_angles(self):
124124
"""Returns the state without global orientation or position."""
125-
return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
125+
return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
126126

127127
def extremities(self):
128128
"""Returns end effector positions in egocentric frame."""

dm_control/suite/humanoid_CMU.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def center_of_mass_position(self):
9090

9191
def center_of_mass_velocity(self):
9292
"""Returns the velocity of the center-of-mass."""
93-
return self.named.data.sensordata['thorax_subtreelinvel']
93+
return self.named.data.sensordata['thorax_subtreelinvel'].copy()
9494

9595
def torso_vertical_orientation(self):
9696
"""Returns the z-projection of the thorax orientation matrix."""
9797
return self.named.data.xmat['thorax', ['zx', 'zy', 'zz']]
9898

9999
def joint_angles(self):
100100
"""Returns the state without global orientation or position."""
101-
return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
101+
return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
102102

103103
def extremities(self):
104104
"""Returns end effector positions in egocentric frame."""

dm_control/suite/pendulum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def pole_vertical(self):
6464

6565
def angular_velocity(self):
6666
"""Returns the angular velocity of the pole."""
67-
return self.named.data.qvel['hinge']
67+
return self.named.data.qvel['hinge'].copy()
6868

6969
def pole_orientation(self):
7070
"""Returns both horizontal and vertical components of pole frame."""

dm_control/suite/swimmer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def body_velocities(self):
163163

164164
def joints(self):
165165
"""Returns all internal joint angles (excluding root joints)."""
166-
return self.data.qpos[3:]
166+
return self.data.qpos[3:].copy()
167167

168168

169169
class Swimmer(base.Task):

dm_control/suite/tests/domains_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,5 +204,17 @@ def test_task_supports_environment_kwargs(self, domain, task):
204204
self.assertSetEqual(set(env.observation_spec()),
205205
{control.FLAT_OBSERVATION_KEY})
206206

207+
@parameterized.parameters(*suite.ALL_TASKS)
208+
def test_observation_arrays_dont_share_memory(self, domain, task):
209+
env = suite.load(domain, task)
210+
first_timestep = env.reset()
211+
action = np.zeros(env.action_spec().shape)
212+
second_timestep = env.step(action)
213+
for name, first_array in six.iteritems(first_timestep.observation):
214+
second_array = second_timestep.observation[name]
215+
self.assertFalse(
216+
np.may_share_memory(first_array, second_array),
217+
msg='Consecutive observations of {!r} may share memory.'.format(name))
218+
207219
if __name__ == '__main__':
208220
absltest.main()

0 commit comments

Comments
 (0)