1919from __future__ import division
2020from __future__ import print_function
2121
22+ import itertools
2223# Internal dependencies.
2324
2425from absl .testing import absltest
26+ from absl .testing import parameterized
2527from dm_control import mjcf
28+ from dm_control .composer import arena
2629from dm_control .composer import define
2730from dm_control .composer import entity
2831from dm_control .composer .observation .observable import base as observable
32+ import numpy as np
2933import six
3034from six .moves import range
3135
36+ _NO_ROTATION = np .array ([1. , 0. , 0. , 0 ])
37+ _NINETY_DEGREES_ABOUT_X = np .array (
38+ [np .cos (np .pi / 4 ), np .sin (np .pi / 4 ), 0. , 0. ])
39+ _NINETY_DEGREES_ABOUT_Y = np .array (
40+ [np .cos (np .pi / 4 ), 0. , np .sin (np .pi / 4 ), 0. ])
41+ _NINETY_DEGREES_ABOUT_Z = np .array (
42+ [np .cos (np .pi / 4 ), 0. , 0. , np .sin (np .pi / 4 )])
43+ _FORTYFIVE_DEGREES_ABOUT_X = np .array (
44+ [np .cos (np .pi / 8 ), np .sin (np .pi / 8 ), 0. , 0. ])
45+
46+ _TEST_ROTATIONS = [
47+ # Triplets of original rotation, new rotation and final rotation.
48+ (None , _NO_ROTATION , _NO_ROTATION ),
49+ (_NO_ROTATION , _NINETY_DEGREES_ABOUT_Z , _NINETY_DEGREES_ABOUT_Z ),
50+ (_FORTYFIVE_DEGREES_ABOUT_X , _NINETY_DEGREES_ABOUT_Y ,
51+ np .array ([0.65328 , 0.2706 , 0.65328 , - 0.2706 ])),
52+ ]
53+
3254
3355class TestEntity (entity .Entity ):
3456 """Simple test entity that does nothing but declare some observables."""
3557
3658 def _build (self , name = 'test_entity' ):
3759 self ._mjcf_root = mjcf .element .RootElement (model = name )
60+ self ._mjcf_root .worldbody .add ('geom' , type = 'sphere' , size = (0.1 ,))
3861
3962 def _build_observables (self ):
4063 return TestEntityObservables (self )
@@ -56,7 +79,7 @@ def observable1(self):
5679 return observable .Generic (lambda phys : 1.0 )
5780
5881
59- class EntityTest (absltest .TestCase ):
82+ class EntityTest (parameterized .TestCase ):
6083
6184 def setUp (self ):
6285 super (EntityTest , self ).setUp ()
@@ -263,6 +286,83 @@ def testIterEntitiesExcludeSelf(self):
263286 self .assertEqual (
264287 list (entities [0 ].iter_entities (exclude_self = True )), entities [1 :])
265288
289+ @parameterized .parameters (
290+ dict (position = position , quaternion = quaternion , freejoint = freejoint )
291+ for position , quaternion , freejoint in itertools .product (
292+ [None , [1. , 0. , - 1. ]], # position
293+ [
294+ None ,
295+ _FORTYFIVE_DEGREES_ABOUT_X ,
296+ _NINETY_DEGREES_ABOUT_Z ,
297+ ], # quaternion
298+ [False , True ], # freejoint
299+ ))
300+ def testSetPose (self , position , quaternion , freejoint ):
301+ # Setup entity.
302+ test_arena = arena .Arena ()
303+ subentity = TestEntity (name = 'subentity' )
304+ frame = test_arena .attach (subentity )
305+ if freejoint :
306+ frame .add ('freejoint' )
307+
308+ physics = mjcf .Physics .from_mjcf_model (test_arena .mjcf_model )
309+
310+ if quaternion is None :
311+ ground_truth_quat = _NO_ROTATION
312+ else :
313+ ground_truth_quat = quaternion
314+
315+ if position is None :
316+ ground_truth_pos = np .zeros (shape = (3 ,))
317+ else :
318+ ground_truth_pos = position
319+
320+ subentity .set_pose (physics , position = position , quaternion = quaternion )
321+
322+ np .testing .assert_array_equal (physics .bind (frame ).xpos , ground_truth_pos )
323+ np .testing .assert_array_equal (physics .bind (frame ).xquat , ground_truth_quat )
324+
325+ @parameterized .parameters (
326+ dict (
327+ original_position = original_position ,
328+ position = position ,
329+ original_quaternion = test_rotation [0 ],
330+ quaternion = test_rotation [1 ],
331+ expected_quaternion = test_rotation [2 ],
332+ freejoint = freejoint )
333+ for (original_position , position , test_rotation ,
334+ freejoint ) in itertools .product (
335+ [[- 2 , - 1 , - 1. ], [1. , 0. , - 1. ]], # original_position
336+ [None , [1. , 0. , - 1. ]], # position
337+ _TEST_ROTATIONS , # (original_quat, quat, expected_quat)
338+ [False , True ], # freejoint
339+ ))
340+ def testShiftPose (self , original_position , position , original_quaternion ,
341+ quaternion , expected_quaternion , freejoint ):
342+ # Setup entity.
343+ test_arena = arena .Arena ()
344+ subentity = TestEntity (name = 'subentity' )
345+ frame = test_arena .attach (subentity )
346+ if freejoint :
347+ frame .add ('freejoint' )
348+
349+ physics = mjcf .Physics .from_mjcf_model (test_arena .mjcf_model )
350+
351+ # Set the original position
352+ subentity .set_pose (
353+ physics , position = original_position , quaternion = original_quaternion )
354+
355+ if position is None :
356+ ground_truth_pos = original_position
357+ else :
358+ ground_truth_pos = original_position + np .array (position )
359+ subentity .shift_pose (physics , position = position , quaternion = quaternion )
360+ np .testing .assert_array_equal (physics .bind (frame ).xpos , ground_truth_pos )
361+
362+ updated_quat = physics .bind (frame ).xquat
363+ np .testing .assert_array_almost_equal (updated_quat , expected_quaternion ,
364+ 1e-4 )
365+
266366
267367if __name__ == '__main__' :
268368 absltest .main ()
0 commit comments