2727from absl import logging
2828from dm_control import mjcf
2929from dm_control .composer import define
30+ from dm_control .mujoco .wrapper import mjbindings
3031import numpy as np
3132import six
3233
3839
3940# The component order differs from that used by the open-source `tf` package.
4041def _multiply_quaternions (quat1 , quat2 ):
41- """Multiplies two quaternions, expressed as [w, i, j, k]."""
42- return np .array ([
43- [+ quat1 [0 ], - quat1 [1 ], - quat1 [2 ], - quat1 [3 ]],
44- [+ quat1 [1 ], + quat1 [0 ], - quat1 [3 ], + quat1 [2 ]],
45- [+ quat1 [2 ], + quat1 [3 ], + quat1 [0 ], - quat1 [1 ]],
46- [+ quat1 [3 ], - quat1 [2 ], + quat1 [1 ], + quat1 [0 ]]]).dot (quat2 )
42+ result = np .empty_like (quat1 )
43+ mjbindings .mjlib .mju_mulQuat (result , quat1 , quat2 )
44+ return result
45+
46+
47+ def _rotate_vector (vec , quat ):
48+ """Rotates a vector by the given quaternion."""
49+ result = np .empty_like (vec )
50+ mjbindings .mjlib .mju_rotVecQuat (result , vec , quat )
51+ return result
4752
4853
4954class _ObservableKeys (object ):
@@ -407,7 +412,11 @@ def set_pose(self, physics, position=None, quaternion=None):
407412 normalised_quaternion = quaternion / np .linalg .norm (quaternion )
408413 physics .bind (attachment_frame ).quat = normalised_quaternion
409414
410- def shift_pose (self , physics , position = None , quaternion = None ):
415+ def shift_pose (self ,
416+ physics ,
417+ position = None ,
418+ quaternion = None ,
419+ rotate_velocity = False ):
411420 """Shifts the position and/or orientation from its current configuration.
412421
413422 This is a convenience function that performs the same operation as
@@ -419,6 +428,11 @@ def shift_pose(self, physics, position=None, quaternion=None):
419428 physics: An instance of `mjcf.Physics`.
420429 position: (optional) A NumPy array of size 3.
421430 quaternion: (optional) A NumPy array of size 4.
431+ rotate_velocity: (optional) A bool, whether to shift the current linear
432+ velocity along with the pose. This will rotate the current linear
433+ velocity, which is expressed relative to the world frame. The angular
434+ velocity, which is expressed relative to the local frame is left
435+ unchanged.
422436
423437 Raises:
424438 RuntimeError: If the entity is not attached.
@@ -428,7 +442,17 @@ def shift_pose(self, physics, position=None, quaternion=None):
428442 if position is not None :
429443 new_position = current_position + position
430444 if quaternion is not None :
445+ quaternion = np .array (quaternion , copy = False )
431446 new_quaternion = _multiply_quaternions (quaternion , current_quaternion )
447+ root_joint = mjcf .get_frame_freejoint (self .mjcf_model )
448+ if root_joint and rotate_velocity :
449+ # Rotate the linear velocity. The angular velocity (qvel[3:)
450+ # is left unchanged, as it is expressed in the local frame.
451+ # When rotatating the body frame the angular velocity already
452+ # tracks the rotation but the linear velocity does not.
453+ velocity = physics .bind (root_joint ).qvel [:3 ]
454+ rotated_velocity = _rotate_vector (velocity , quaternion )
455+ self .set_velocity (physics , rotated_velocity )
432456 self .set_pose (physics , new_position , new_quaternion )
433457
434458 def set_velocity (self , physics , velocity = None , angular_velocity = None ):
0 commit comments