Skip to content

Commit 314ce15

Browse files
DeepMindalimuldal
authored andcommitted
Add option to shift_pose to rotate velocities as well.
PiperOrigin-RevId: 230716764
1 parent 92f9913 commit 314ce15

2 files changed

Lines changed: 60 additions & 7 deletions

File tree

dm_control/composer/entity.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from absl import logging
2828
from dm_control import mjcf
2929
from dm_control.composer import define
30+
from dm_control.mujoco.wrapper import mjbindings
3031
import numpy as np
3132
import six
3233

@@ -38,12 +39,16 @@
3839

3940
# The component order differs from that used by the open-source `tf` package.
4041
def _multiply_quaternions(quat1, quat2):
41-
"""Multiplies two quaternions, expressed as [w, i, j, k]."""
42-
return np.array([
43-
[+quat1[0], -quat1[1], -quat1[2], -quat1[3]],
44-
[+quat1[1], +quat1[0], -quat1[3], +quat1[2]],
45-
[+quat1[2], +quat1[3], +quat1[0], -quat1[1]],
46-
[+quat1[3], -quat1[2], +quat1[1], +quat1[0]]]).dot(quat2)
42+
result = np.empty_like(quat1)
43+
mjbindings.mjlib.mju_mulQuat(result, quat1, quat2)
44+
return result
45+
46+
47+
def _rotate_vector(vec, quat):
48+
"""Rotates a vector by the given quaternion."""
49+
result = np.empty_like(vec)
50+
mjbindings.mjlib.mju_rotVecQuat(result, vec, quat)
51+
return result
4752

4853

4954
class _ObservableKeys(object):
@@ -407,7 +412,11 @@ def set_pose(self, physics, position=None, quaternion=None):
407412
normalised_quaternion = quaternion / np.linalg.norm(quaternion)
408413
physics.bind(attachment_frame).quat = normalised_quaternion
409414

410-
def shift_pose(self, physics, position=None, quaternion=None):
415+
def shift_pose(self,
416+
physics,
417+
position=None,
418+
quaternion=None,
419+
rotate_velocity=False):
411420
"""Shifts the position and/or orientation from its current configuration.
412421
413422
This is a convenience function that performs the same operation as
@@ -419,6 +428,11 @@ def shift_pose(self, physics, position=None, quaternion=None):
419428
physics: An instance of `mjcf.Physics`.
420429
position: (optional) A NumPy array of size 3.
421430
quaternion: (optional) A NumPy array of size 4.
431+
rotate_velocity: (optional) A bool, whether to shift the current linear
432+
velocity along with the pose. This will rotate the current linear
433+
velocity, which is expressed relative to the world frame. The angular
434+
velocity, which is expressed relative to the local frame is left
435+
unchanged.
422436
423437
Raises:
424438
RuntimeError: If the entity is not attached.
@@ -428,7 +442,17 @@ def shift_pose(self, physics, position=None, quaternion=None):
428442
if position is not None:
429443
new_position = current_position + position
430444
if quaternion is not None:
445+
quaternion = np.array(quaternion, copy=False)
431446
new_quaternion = _multiply_quaternions(quaternion, current_quaternion)
447+
root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
448+
if root_joint and rotate_velocity:
449+
# Rotate the linear velocity. The angular velocity (qvel[3:)
450+
# is left unchanged, as it is expressed in the local frame.
451+
# When rotatating the body frame the angular velocity already
452+
# tracks the rotation but the linear velocity does not.
453+
velocity = physics.bind(root_joint).qvel[:3]
454+
rotated_velocity = _rotate_vector(velocity, quaternion)
455+
self.set_velocity(physics, rotated_velocity)
432456
self.set_pose(physics, new_position, new_quaternion)
433457

434458
def set_velocity(self, physics, velocity=None, angular_velocity=None):

dm_control/composer/entity_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,35 @@ def testShiftPose(self, original_position, position, original_quaternion,
363363
np.testing.assert_array_almost_equal(updated_quat, expected_quaternion,
364364
1e-4)
365365

366+
@parameterized.parameters(False, True)
367+
def testShiftPoseWithVelocity(self, rotate_velocity):
368+
# Setup entity.
369+
test_arena = arena.Arena()
370+
subentity = TestEntity(name='subentity')
371+
frame = test_arena.attach(subentity)
372+
frame.add('freejoint')
373+
374+
physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
375+
376+
# Set the original position
377+
subentity.set_pose(physics, position=[0., 0., 0.])
378+
379+
# Set velocity in y dim.
380+
subentity.set_velocity(physics, [0., 1., 0.])
381+
382+
# Rotate the entity around the z axis.
383+
subentity.shift_pose(
384+
physics, quaternion=[0., 0., 0., 1.], rotate_velocity=rotate_velocity)
385+
386+
physics.forward()
387+
updated_position, _ = subentity.get_pose(physics)
388+
if rotate_velocity:
389+
# Should not have moved in the y dim.
390+
np.testing.assert_array_almost_equal(updated_position[1], 0.)
391+
else:
392+
# Should not have moved in the x dim.
393+
np.testing.assert_array_almost_equal(updated_position[0], 0.)
394+
366395

367396
if __name__ == '__main__':
368397
absltest.main()

0 commit comments

Comments
 (0)