Skip to content

Commit b5a1c4e

Browse files
saran-tcopybara-github
authored andcommitted
Move implementation of transform_{vec,xmat}_to_egocentric_frame from base.Walker to base.Entity.
New methods in `base.Entity` are renamed to `global_{vector,xmat}_to_local_frame`. PiperOrigin-RevId: 296900770 Change-Id: I99538d7c4e72661d179ddcd4a84bf8464357c99b
1 parent c85a70b commit b5a1c4e

4 files changed

Lines changed: 127 additions & 26 deletions

File tree

dm_control/composer/entity.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,84 @@ def parent(self):
335335
def attachment_site(self):
336336
return self.mjcf_model
337337

338+
@property
339+
def root_body(self):
340+
if self.parent:
341+
return mjcf.get_attachment_frame(self.mjcf_model)
342+
else:
343+
return self.mjcf_model.worldbody
344+
345+
def global_vector_to_local_frame(self, physics, vec_in_world_frame):
346+
"""Linearly transforms a world-frame vector into entity's local frame.
347+
348+
Note that this function does not perform an affine transformation of the
349+
vector. In other words, the input vector is assumed to be specified with
350+
respect to the same origin as this entity's local frame. This function
351+
can also be applied to matrices whose innermost dimensions are either 2 or
352+
3. In this case, a matrix with the same leading dimensions is returned
353+
where the innermost vectors are replaced by their values computed in the
354+
local frame.
355+
356+
Args:
357+
physics: An `mjcf.Physics` instance.
358+
vec_in_world_frame: A NumPy array with last dimension of shape (2,) or
359+
(3,) that represents a vector quantity in the world frame.
360+
361+
Returns:
362+
The same quantity as `vec_in_world_frame` but reexpressed in this
363+
entity's local frame. The returned np.array has the same shape as
364+
np.asarray(vec_in_world_frame).
365+
366+
Raises:
367+
ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
368+
or (3,).
369+
"""
370+
vec_in_world_frame = np.asarray(vec_in_world_frame)
371+
372+
xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
373+
# The ordering of the np.dot is such that the transformation holds for any
374+
# matrix whose final dimensions are (2,) or (3,).
375+
if vec_in_world_frame.shape[-1] == 2:
376+
return np.dot(vec_in_world_frame, xmat[:2, :2])
377+
elif vec_in_world_frame.shape[-1] == 3:
378+
return np.dot(vec_in_world_frame, xmat)
379+
else:
380+
raise ValueError('`vec_in_world_frame` should have shape with final '
381+
'dimension 2 or 3: got {}'.format(
382+
vec_in_world_frame.shape))
383+
384+
def global_xmat_to_local_frame(self, physics, xmat):
385+
"""Transforms another entity's `xmat` into this entity's local frame.
386+
387+
This function takes another entity's (E) xmat, which is an SO(3) matrix
388+
from E's frame to the world frame, and turns it to a matrix that transforms
389+
from E's frame into this entity's local frame.
390+
391+
Args:
392+
physics: An `mjcf.Physics` instance.
393+
xmat: A NumPy array of shape (3, 3) or (9,) that represents another
394+
entity's xmat.
395+
396+
Returns:
397+
The `xmat` reexpressed in this entity's local frame. The returned
398+
np.array has the same shape as np.asarray(xmat).
399+
400+
Raises:
401+
ValueError: if `xmat` does not have shape (3, 3) or (9,).
402+
"""
403+
xmat = np.asarray(xmat)
404+
405+
input_shape = xmat.shape
406+
if xmat.shape == (9,):
407+
xmat = np.reshape(xmat, (3, 3))
408+
409+
self_xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
410+
if xmat.shape == (3, 3):
411+
return np.reshape(np.dot(self_xmat.T, xmat), input_shape)
412+
else:
413+
raise ValueError('`xmat` should have shape (3, 3) or (9,): got {}'.format(
414+
xmat.shape))
415+
338416
def get_pose(self, physics):
339417
"""Get the position and orientation of this entity relative to its parent.
340418

dm_control/composer/entity_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,51 @@ def testIterEntitiesExcludeSelf(self):
292292
self.assertEqual(
293293
list(entities[0].iter_entities(exclude_self=True)), entities[1:])
294294

295+
def testGlobalVectorToLocalFrame(self):
296+
parent = TestEntity()
297+
parent.mjcf_model.worldbody.add(
298+
'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
299+
physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
300+
301+
# 3D vectors
302+
np.testing.assert_allclose(
303+
self.entity.global_vector_to_local_frame(physics, [0, 1, 0]),
304+
[1, 0, 0], atol=1e-10)
305+
np.testing.assert_allclose(
306+
self.entity.global_vector_to_local_frame(physics, [-1, 0, 0]),
307+
[0, 1, 0], atol=1e-10)
308+
np.testing.assert_allclose(
309+
self.entity.global_vector_to_local_frame(physics, [0, 0, 1]),
310+
[0, 0, 1], atol=1e-10)
311+
312+
# 2D vectors; z-component is ignored
313+
np.testing.assert_allclose(
314+
self.entity.global_vector_to_local_frame(physics, [0, 1]),
315+
[1, 0], atol=1e-10)
316+
np.testing.assert_allclose(
317+
self.entity.global_vector_to_local_frame(physics, [-1, 0]),
318+
[0, 1], atol=1e-10)
319+
320+
def testGlobalMatrixToLocalFrame(self):
321+
parent = TestEntity()
322+
parent.mjcf_model.worldbody.add(
323+
'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
324+
physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
325+
326+
rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])
327+
ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])
328+
329+
np.testing.assert_allclose(
330+
self.entity.global_xmat_to_local_frame(physics, rotation_atob),
331+
ego_rotation_atob, atol=1e-10)
332+
333+
flat_rotation_atob = np.reshape(rotation_atob, -1)
334+
flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1)
335+
np.testing.assert_allclose(
336+
self.entity.global_xmat_to_local_frame(
337+
physics, flat_rotation_atob),
338+
flat_rotation_ego_atob, atol=1e-10)
339+
295340
@parameterized.parameters(*_param_product(
296341
position=[None, [1., 0., -1.]],
297342
quaternion=[None, _FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Z],

dm_control/locomotion/walkers/base.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,8 @@ def transform_vec_to_egocentric_frame(self, physics, vec_in_world_frame):
105105
ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
106106
or (3,).
107107
"""
108-
vec_in_world_frame = np.asarray(vec_in_world_frame)
109-
110-
xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
111-
# The ordering of the np.dot is such that the transformation holds for any
112-
# matrix whose final dimensions are (2,) or (3,).
113-
if vec_in_world_frame.shape[-1] == 2:
114-
return np.dot(vec_in_world_frame, xmat[:2, :2])
115-
elif vec_in_world_frame.shape[-1] == 3:
116-
return np.dot(vec_in_world_frame, xmat)
117-
else:
118-
raise ValueError('`vec_in_world_frame` should have shape with final '
119-
'dimension 2 or 3: got {}'.format(
120-
vec_in_world_frame.shape))
108+
return super(Walker, self).global_vector_to_local_frame(
109+
physics, vec_in_world_frame)
121110

122111
def transform_xmat_to_egocentric_frame(self, physics, xmat):
123112
"""Transforms another entity's `xmat` into this walker's egocentric frame.
@@ -138,18 +127,7 @@ def transform_xmat_to_egocentric_frame(self, physics, xmat):
138127
Raises:
139128
ValueError: if `xmat` does not have shape (3, 3) or (9,).
140129
"""
141-
xmat = np.asarray(xmat)
142-
143-
input_shape = xmat.shape
144-
if xmat.shape == (9,):
145-
xmat = np.reshape(xmat, (3, 3))
146-
147-
self_xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
148-
if xmat.shape == (3, 3):
149-
return np.reshape(np.dot(self_xmat.T, xmat), input_shape)
150-
else:
151-
raise ValueError('`xmat` should have shape (3, 3) or (9,): got {}'.format(
152-
xmat.shape))
130+
return super(Walker, self).global_xmat_to_local_frame(physics, xmat)
153131

154132
@abc.abstractproperty
155133
def root_body(self):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def find_data_files(package_dir, patterns):
166166

167167
setup(
168168
name='dm_control',
169-
version='0.0.296467089',
169+
version='0.0.296900770',
170170
description='Continuous control environments and MuJoCo Python bindings.',
171171
author='DeepMind',
172172
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)