Skip to content

Commit 83c8bdd

Browse files
DeepMindalimuldal
authored andcommitted
Add tests for set_pose and shift_pose
PiperOrigin-RevId: 230677378
1 parent 86e55b3 commit 83c8bdd

1 file changed

Lines changed: 101 additions & 1 deletion

File tree

dm_control/composer/entity_test.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,45 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import itertools
2223
# Internal dependencies.
2324

2425
from absl.testing import absltest
26+
from absl.testing import parameterized
2527
from dm_control import mjcf
28+
from dm_control.composer import arena
2629
from dm_control.composer import define
2730
from dm_control.composer import entity
2831
from dm_control.composer.observation.observable import base as observable
32+
import numpy as np
2933
import six
3034
from six.moves import range
3135

36+
_NO_ROTATION = np.array([1., 0., 0., 0])
37+
_NINETY_DEGREES_ABOUT_X = np.array(
38+
[np.cos(np.pi / 4), np.sin(np.pi / 4), 0., 0.])
39+
_NINETY_DEGREES_ABOUT_Y = np.array(
40+
[np.cos(np.pi / 4), 0., np.sin(np.pi / 4), 0.])
41+
_NINETY_DEGREES_ABOUT_Z = np.array(
42+
[np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)])
43+
_FORTYFIVE_DEGREES_ABOUT_X = np.array(
44+
[np.cos(np.pi / 8), np.sin(np.pi / 8), 0., 0.])
45+
46+
_TEST_ROTATIONS = [
47+
# Triplets of original rotation, new rotation and final rotation.
48+
(None, _NO_ROTATION, _NO_ROTATION),
49+
(_NO_ROTATION, _NINETY_DEGREES_ABOUT_Z, _NINETY_DEGREES_ABOUT_Z),
50+
(_FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Y,
51+
np.array([0.65328, 0.2706, 0.65328, -0.2706])),
52+
]
53+
3254

3355
class TestEntity(entity.Entity):
3456
"""Simple test entity that does nothing but declare some observables."""
3557

3658
def _build(self, name='test_entity'):
3759
self._mjcf_root = mjcf.element.RootElement(model=name)
60+
self._mjcf_root.worldbody.add('geom', type='sphere', size=(0.1,))
3861

3962
def _build_observables(self):
4063
return TestEntityObservables(self)
@@ -56,7 +79,7 @@ def observable1(self):
5679
return observable.Generic(lambda phys: 1.0)
5780

5881

59-
class EntityTest(absltest.TestCase):
82+
class EntityTest(parameterized.TestCase):
6083

6184
def setUp(self):
6285
super(EntityTest, self).setUp()
@@ -263,6 +286,83 @@ def testIterEntitiesExcludeSelf(self):
263286
self.assertEqual(
264287
list(entities[0].iter_entities(exclude_self=True)), entities[1:])
265288

289+
@parameterized.parameters(
290+
dict(position=position, quaternion=quaternion, freejoint=freejoint)
291+
for position, quaternion, freejoint in itertools.product(
292+
[None, [1., 0., -1.]], # position
293+
[
294+
None,
295+
_FORTYFIVE_DEGREES_ABOUT_X,
296+
_NINETY_DEGREES_ABOUT_Z,
297+
], # quaternion
298+
[False, True], # freejoint
299+
))
300+
def testSetPose(self, position, quaternion, freejoint):
301+
# Setup entity.
302+
test_arena = arena.Arena()
303+
subentity = TestEntity(name='subentity')
304+
frame = test_arena.attach(subentity)
305+
if freejoint:
306+
frame.add('freejoint')
307+
308+
physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
309+
310+
if quaternion is None:
311+
ground_truth_quat = _NO_ROTATION
312+
else:
313+
ground_truth_quat = quaternion
314+
315+
if position is None:
316+
ground_truth_pos = np.zeros(shape=(3,))
317+
else:
318+
ground_truth_pos = position
319+
320+
subentity.set_pose(physics, position=position, quaternion=quaternion)
321+
322+
np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos)
323+
np.testing.assert_array_equal(physics.bind(frame).xquat, ground_truth_quat)
324+
325+
@parameterized.parameters(
326+
dict(
327+
original_position=original_position,
328+
position=position,
329+
original_quaternion=test_rotation[0],
330+
quaternion=test_rotation[1],
331+
expected_quaternion=test_rotation[2],
332+
freejoint=freejoint)
333+
for (original_position, position, test_rotation,
334+
freejoint) in itertools.product(
335+
[[-2, -1, -1.], [1., 0., -1.]], # original_position
336+
[None, [1., 0., -1.]], # position
337+
_TEST_ROTATIONS, # (original_quat, quat, expected_quat)
338+
[False, True], # freejoint
339+
))
340+
def testShiftPose(self, original_position, position, original_quaternion,
341+
quaternion, expected_quaternion, freejoint):
342+
# Setup entity.
343+
test_arena = arena.Arena()
344+
subentity = TestEntity(name='subentity')
345+
frame = test_arena.attach(subentity)
346+
if freejoint:
347+
frame.add('freejoint')
348+
349+
physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
350+
351+
# Set the original position
352+
subentity.set_pose(
353+
physics, position=original_position, quaternion=original_quaternion)
354+
355+
if position is None:
356+
ground_truth_pos = original_position
357+
else:
358+
ground_truth_pos = original_position + np.array(position)
359+
subentity.shift_pose(physics, position=position, quaternion=quaternion)
360+
np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos)
361+
362+
updated_quat = physics.bind(frame).xquat
363+
np.testing.assert_array_almost_equal(updated_quat, expected_quaternion,
364+
1e-4)
365+
266366

267367
if __name__ == '__main__':
268368
absltest.main()

0 commit comments

Comments
 (0)