Skip to content

Commit a56b451

Browse files
committed
Add a utility module for performing inverse kinematics on MuJoCo models
PiperOrigin-RevId: 207116592
1 parent 67ed879 commit a56b451

4 files changed

Lines changed: 594 additions & 0 deletions

File tree

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
<mujoco model='jaco'>
2+
<compiler coordinate='local' angle='radian' eulerseq='yxz'/>
3+
4+
<contact>
5+
<exclude body1='b_base' body2='b_1'/>
6+
<exclude body1='b_finger_1' body2='b_finger_2'/>
7+
<exclude body1='b_finger_1' body2='b_finger_3'/>
8+
<exclude body1='b_finger_2' body2='b_finger_3'/>
9+
</contact>
10+
11+
<default>
12+
<geom contype='1' conaffinity='1' condim='3' friction='.1 .1' solimp='.95 .98 .0005' solref='0.02 1.1' density='1'/>
13+
<joint type='hinge' armature='0' damping='0'
14+
solimpfriction='.95 .95 0' solreffriction='.02 1' solimplimit='0 .99 .01' solreflimit='.02 1'/>
15+
<default class='finger'>
16+
<joint frictionloss='0.1' armature='.1' axis='0 0 1' limited='true' range='-0.15 1.2' damping='.1'/>
17+
<velocity kv='0.5' ctrlrange='-1 1' gear='1' forcerange='-1.5 1.5'/>
18+
<geom type='capsule' size="0.0108231 0.0145116" pos="0.0190913 -0.0112192 1.10387e-06" quat="0.475811 -0.475815 0.523071 -0.523068"/>
19+
</default>
20+
<default class='fingertip'>
21+
<geom type='box' pos='0.065 -0.022 0' size='0.02 .005 .011' rgba='1 0 0 1'
22+
quat='1 0 0 -0.235' condim='4' />
23+
<site type='ellipsoid' pos='0.065 -0.022 0' size='0.02 .005 .011'
24+
quat='1 0 0 -0.235'/>
25+
</default>
26+
</default>
27+
28+
<worldbody>
29+
<geom name='ground' type='plane' pos='0 0 0.06' size='1 1 1' rgba='0.78 0.93 0.92 1.0' margin='0.1' gap='0.1' />
30+
<body name='b_base' pos='0 0 0.005'>
31+
<inertial pos='-0.000132925 -0.000516418 0.0820911' quat='0.6951 -0.00714556 0.00738537 0.71884' mass='0.00718278' diaginertia='0.00157301 0.00157269 0.00058831' />
32+
<geom name='base' rgba='0 0.4470 0.7410 1' type='capsule' size="0.0350512 0.0555543" pos="-0.000132925 -0.000516418 0.0820911" quat="0.6951 -0.00714564 0.00738544 0.71884" />
33+
<body name='b_1' pos='0 0 0.1535' quat='0 0 1 0'>
34+
<inertial pos='-3.49946e-005 0.010651 -0.0670012' quat='0.704277 0.0627825 -0.0628 0.70435' mass='0.00644357' diaginertia='0.00172311 0.00163259 0.00042883' />
35+
<joint name='joint_1' axis='0 0 -1' />
36+
<geom name='link_1' rgba='0.8500 0.3250 0.0980 1' type='capsule' size="0.031417 0.0668352" pos="-3.49946e-05 0.010651 -0.0670012" quat="0.704277 0.0627825 -0.0628 0.70435" />
37+
<body name='b_2' pos='0 0 -0.1185' quat='0 0 1 1'>
38+
<inertial pos='0.205004 7.5835e-005 -0.0202812' quat='2.24027e-006 0.707108 5.30493e-006 0.707106' mass='0.011691' diaginertia='0.0253426 0.0250512 0.000536357' />
39+
<joint name='joint_2' axis='0 0 1' range='-3.4906585 0.34906585' limited='true' />
40+
<geom name='link_2' rgba='0.9290 0.6940 0.1250 1' type='capsule' size="0.0251592 0.240341" pos="0.205004 7.5835e-05 -0.0202812" quat="2.23957e-06 0.707108 5.30416e-06 0.707106" />
41+
<body name='b_3' pos='0.41 0 0' quat='0 0.707107 0.707107 0'>
42+
<inertial pos='0.0843942 7.48909e-006 -0.0177123' quat='-4.64469e-005 0.692702 -4.11798e-005 0.721224' mass='0.00673417' diaginertia='0.00420255 0.00409328 0.000303584' />
43+
<joint name='joint_3' axis='0 0 -1' range='-4.01425728 0.872664626' limited='true' />
44+
<geom name='link_3' rgba='0.4940 0.1840 0.5560 1' type='capsule' size="0.0255647 0.120643" pos="0.0843942 7.48907e-06 -0.0177123" quat="-4.64475e-05 0.692702 -4.11814e-05 0.721224" />
45+
<body name='b_4' pos='0.207 0 -0.01125' quat='0 0.707107 0 -0.707107'>
46+
<inertial pos='0.0101659 -4.64261e-005 -0.0369867' quat='0.965925 1.91737e-005 -0.258823 -1.40769e-005' mass='0.00221423' diaginertia='0.0001816 0.000174586 9.79319e-005' />
47+
<joint name='joint_4' axis='0 0 -1' />
48+
<geom name='link_4' rgba='0.4660 0.6740 0.1880 1' type='capsule' size="0.0257405 0.0289568" pos="0.0101659 -4.64261e-05 -0.0369867" quat="0.965925 1.43712e-05 -0.258823 3.85082e-06" />
49+
<body name='b_5' pos='0.037 0 -0.06408' quat='0.866025 0 -0.5 0'>
50+
<inertial pos='0.0101659 -4.64261e-005 -0.0369867' quat='0.965925 1.91737e-005 -0.258823 -1.40769e-005' mass='0.00221423' diaginertia='0.0001816 0.000174586 9.79319e-005' />
51+
<joint name='joint_5' axis='0 0 -1' />
52+
<geom name='link_5' rgba='0.3010 0.7450 0.9330 1' type='capsule' size="0.0257405 0.0289568" pos="0.0101659 -4.64261e-05 -0.0369867" quat="0.965925 1.43712e-05 -0.258823 3.85082e-06" />
53+
<body name='b_hand' pos='0.037 0 -0.06408' quat='0.612372 -0.353553 -0.353553 0.612372'>
54+
<site type='sphere' size='.01' name='gripsite' pos='0 0 -.16' rgba='.5 .5 .5 .3' />
55+
<site type='sphere' size='.01' name='pinchsite' pos='0 0.015 -0.195' rgba='.5 .5 .5 .3' />
56+
<inertial pos='0.00628384 -2.92087e-005 -0.0608681' quat='0.708562 -0.0338601 -0.0358744 0.703923' mass='0.00547172' diaginertia='0.000759818 0.000676099 0.0004995' />
57+
<joint name='joint_6' axis='0 0 -1' range='-6.28319 6.28319' limited='false'/>
58+
<geom name='link_6' rgba='0.6350 0.0780 0.1840 1' type='capsule' size="0.0368731 0.0322296" pos="0.00628384 -2.92087e-05 -0.0608681" quat="0.708562 -0.0338601 -0.0358744 0.703923"/>
59+
<!-- optional collision geom to be used if issues arise with the mesh
60+
<geom type='ellipsoid' rgba='0.6350 0.0780 0.1840 1' size='.035 .025 .01' pos='.005 0 -.117' group='1'/> -->
61+
<body name='b_finger_1' childclass='finger' pos='-0.029 .003 -0.1145' quat='-0.414818 -0.329751 -0.663854 0.52772'>
62+
<inertial pos='0.0485761 -0.000715511 0' quat='0.507589 0.507348 0.492543 0.492294' mass='0.000379077' diaginertia='4.00708e-005 4.00527e-005 2.156e-006' />
63+
<joint name='joint_finger_1'/>
64+
<geom name='finger_knuckle_1' rgba='0 0.4470 0.7410 1'/>
65+
<geom class='fingertip' name='finger_tip_1' rgba='0 0.4470 0.7410 1'/>
66+
<site class='fingertip' name='fingertip1'/>
67+
</body>
68+
<body name='b_finger_2' childclass='finger' pos='0.0295 0.0216 -0.115' quat='0.561254 -0.620653 0.321748 0.443014'>
69+
<inertial pos='0.0485761 -0.000715511 0' quat='0.507589 0.507348 0.492543 0.492294' mass='0.000379077' diaginertia='4.00708e-005 4.00527e-005 2.156e-006' />
70+
<joint name='joint_finger_2'/>
71+
<geom name='finger_knuckle_2' rgba='0.9290 0.6940 0.1250 1'/>
72+
<geom class='fingertip' name='finger_tip_2' rgba='0.9290 0.6940 0.1250 1'/>
73+
<site class='fingertip' name='fingertip2'/>
74+
</body>
75+
<body name='b_finger_3' childclass='finger' pos='0.0295 -0.0216 -0.1145' quat='0.625248 -0.567602 0.434845 0.312735'>
76+
<inertial pos='0.0485761 -0.000715511 0' quat='0.507589 0.507348 0.492543 0.492294' mass='0.000379077' diaginertia='4.00708e-005 4.00527e-005 2.156e-006' />
77+
<joint name='joint_finger_3'/>
78+
<geom name='finger_knuckle_3' rgba='0.8500 0.3250 0.0980 1'/>
79+
<geom class='fingertip' name='finger_tip_3' rgba='0.8500 0.3250 0.0980 1'/>
80+
<site class='fingertip' name='fingertip3'/>
81+
</body>
82+
</body>
83+
</body>
84+
</body>
85+
</body>
86+
</body>
87+
</body>
88+
</body>
89+
</worldbody>
90+
91+
</mujoco>
92+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
<mujoco>
2+
<default>
3+
<geom type='capsule' size='0.01'/>
4+
<site type='sphere' size='0.03'/>
5+
<joint type='ball' damping='0.005'/>
6+
</default>
7+
<worldbody>
8+
<body>
9+
<geom fromto='0 0 0 0 0 0.1'/>
10+
<body pos='0 0 0.1'>
11+
<geom fromto='0 0 0 0 0.1 0.1'/>
12+
<joint name='joint_1'/>
13+
<body pos='0 0.1 0.1'>
14+
<geom fromto='0 0 0 0 0.1 0.0'/>
15+
<joint name='joint_2'/>
16+
<site name='gripsite' pos='0 0.1 0'/>
17+
</body>
18+
</body>
19+
</body>
20+
</worldbody>
21+
</mujoco>
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
# Copyright 2017-2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Functions for computing inverse kinematics on MuJoCo models."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import collections
23+
24+
from absl import logging
25+
from dm_control.mujoco.wrapper.mjbindings import mjlib
26+
import numpy as np
27+
from six.moves import range
28+
29+
30+
_INVALID_JOINT_NAMES_TYPE = (
31+
'`joint_names` must be either None, a list, a tuple, or a numpy array; '
32+
'got {}.')
33+
_REQUIRE_TARGET_POS_OR_QUAT = (
34+
'At least one of `target_pos` or `target_quat` must be specified.')
35+
36+
IKResult = collections.namedtuple(
37+
'IKResult', ['qpos', 'err_norm', 'steps', 'success'])
38+
39+
40+
def qpos_from_site_pose(physics,
41+
site_name,
42+
target_pos=None,
43+
target_quat=None,
44+
joint_names=None,
45+
tol=1e-14,
46+
rot_weight=1.0,
47+
regularization_threshold=0.1,
48+
regularization_strength=3e-2,
49+
max_update_norm=2.0,
50+
progress_thresh=20.0,
51+
max_steps=100,
52+
inplace=False):
53+
"""Find joint positions that satisfy a target site position and/or rotation.
54+
55+
Args:
56+
physics: A `mujoco.Physics` instance.
57+
site_name: A string specifying the name of the target site.
58+
target_pos: A (3,) numpy array specifying the desired Cartesian position of
59+
the site, or None if the position should be unconstrained (default).
60+
One or both of `target_pos` or `target_quat` must be specified.
61+
target_quat: A (4,) numpy array specifying the desired orientation of the
62+
site as a quarternion, or None if the orientation should be unconstrained
63+
(default). One or both of `target_pos` or `target_quat` must be specified.
64+
joint_names: (optional) A list, tuple or numpy array specifying the names of
65+
one or more joints that can be manipulated in order to achieve the target
66+
site pose. If None (default), all joints may be manipulated.
67+
tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
68+
in the stopping criterion).
69+
rot_weight: (optional) Determines the weight given to rotational error
70+
relative to translational error.
71+
regularization_threshold: (optional) L2 regularization will be used when
72+
inverting the Jacobian whilst `err_norm` is greater than this value.
73+
regularization_strength: (optional) Coefficient of the quadratic penalty
74+
on joint movements.
75+
max_update_norm: (optional) The maximum L2 norm of the update applied to
76+
the joint positions on each iteration. The update vector will be scaled
77+
such that its magnitude never exceeds this value.
78+
progress_thresh: (optional) If `err_norm` divided by the magnitude of the
79+
joint position update is greater than this value then the optimization
80+
will terminate prematurely. This is a useful heuristic to avoid getting
81+
stuck in local minima.
82+
max_steps: (optional) The maximum number of iterations to perform.
83+
inplace: (optional) If True, `physics.data` will be modified in place.
84+
Default value is False, i.e. a copy of `physics.data` will be made.
85+
86+
Returns:
87+
An `IKResult` namedtuple with the following fields:
88+
qpos: An (nq,) numpy array of joint positions.
89+
err_norm: A float, the weighted sum of L2 norms for the residual
90+
translational and rotational errors.
91+
steps: An int, the number of iterations that were performed.
92+
success: Boolean, True if we converged on a solution within `max_steps`,
93+
False otherwise.
94+
95+
Raises:
96+
ValueError: If both `target_pos` and `target_quat` are None, or if
97+
`joint_names` has an invalid type.
98+
"""
99+
100+
dtype = physics.data.qpos.dtype
101+
102+
if target_pos is not None and target_quat is not None:
103+
jac = np.empty((6, physics.model.nv), dtype=dtype)
104+
err = np.empty(6, dtype=dtype)
105+
jac_pos, jac_rot = jac[:3], jac[3:]
106+
err_pos, err_rot = err[:3], err[3:]
107+
else:
108+
jac = np.empty((3, physics.model.nv), dtype=dtype)
109+
err = np.empty(3, dtype=dtype)
110+
if target_pos is not None:
111+
jac_pos, jac_rot = jac, None
112+
err_pos, err_rot = err, None
113+
elif target_quat is not None:
114+
jac_pos, jac_rot = None, jac
115+
err_pos, err_rot = None, err
116+
else:
117+
raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
118+
119+
update_nv = np.zeros(physics.model.nv, dtype=dtype)
120+
121+
if target_quat is not None:
122+
site_xquat = np.empty(4, dtype=dtype)
123+
neg_site_xquat = np.empty(4, dtype=dtype)
124+
err_rot_quat = np.empty(4, dtype=dtype)
125+
126+
if not inplace:
127+
physics = physics.copy(share_model=True)
128+
129+
# Ensure that the Cartesian position of the site is up to date.
130+
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
131+
132+
# Convert site name to index.
133+
site_id = physics.model.name2id(site_name, 'site')
134+
135+
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
136+
# update them in place, so we can avoid indexing overhead in the main loop.
137+
site_xpos = physics.named.data.site_xpos[site_name]
138+
site_xmat = physics.named.data.site_xmat[site_name]
139+
140+
# This is an index into the rows of `update` and the columns of `jac`
141+
# that selects DOFs associated with joints that we are allowed to manipulate.
142+
if joint_names is None:
143+
dof_indices = slice(None) # Update all DOFs.
144+
elif isinstance(joint_names, (list, np.ndarray, tuple)):
145+
if isinstance(joint_names, tuple):
146+
joint_names = list(joint_names)
147+
# Find the indices of the DOFs belonging to each named joint. Note that
148+
# these are not necessarily the same as the joint IDs, since a single joint
149+
# may have >1 DOF (e.g. ball joints).
150+
indexer = physics.named.model.dof_jntid.axes.row
151+
# `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
152+
# indexer to map each joint name to the indices of its corresponding DOFs.
153+
dof_indices = indexer.convert_key_item(joint_names)
154+
else:
155+
raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
156+
157+
steps = 0
158+
success = False
159+
160+
for steps in range(max_steps):
161+
162+
err_norm = 0.0
163+
164+
if target_pos is not None:
165+
# Translational error.
166+
err_pos[:] = target_pos - site_xpos
167+
err_norm += np.linalg.norm(err_pos)
168+
if target_quat is not None:
169+
# Rotational error.
170+
mjlib.mju_mat2Quat(site_xquat, site_xmat)
171+
mjlib.mju_negQuat(neg_site_xquat, site_xquat)
172+
mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
173+
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
174+
err_norm += np.linalg.norm(err_rot) * rot_weight
175+
176+
if err_norm < tol:
177+
logging.debug('Converged after %i steps: err_norm=%3g', steps, err_norm)
178+
success = True
179+
break
180+
else:
181+
# TODO(b/112141670): Generalize this to other entities besides sites.
182+
mjlib.mj_jacSite(
183+
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id)
184+
jac_joints = jac[:, dof_indices]
185+
186+
# TODO(b/112141592): This does not take joint limits into consideration.
187+
reg_strength = (
188+
regularization_strength if err_norm > regularization_threshold
189+
else 0.0)
190+
update_joints = nullspace_method(
191+
jac_joints, err, regularization_strength=reg_strength)
192+
193+
update_norm = np.linalg.norm(update_joints)
194+
195+
# Check whether we are still making enough progress, and halt if not.
196+
progress_criterion = err_norm / update_norm
197+
if progress_criterion > progress_thresh:
198+
logging.debug('Step %2i: err_norm / update_norm (%3g) > '
199+
'tolerance (%3g). Halting due to insufficient progress',
200+
steps, progress_criterion, progress_thresh)
201+
break
202+
203+
if update_norm > max_update_norm:
204+
update_joints *= max_update_norm / update_norm
205+
206+
# Write the entries for the specified joints into the full `update_nv`
207+
# vector.
208+
update_nv[dof_indices] = update_joints
209+
210+
# Update `physics.qpos`, taking quaternions into account.
211+
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
212+
213+
# Compute the new Cartesian position of the site.
214+
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
215+
216+
logging.debug('Step %2i: err_norm=%-10.3g update_norm=%-10.3g',
217+
steps, err_norm, update_norm)
218+
219+
if not success and steps == max_steps - 1:
220+
logging.warning('Failed to converge after %i steps: err_norm=%3g',
221+
steps, err_norm)
222+
223+
if not inplace:
224+
# Our temporary copy of physics.data is about to go out of scope, and when
225+
# it does the underlying mjData pointer will be freed and physics.data.qpos
226+
# will be a view onto a block of deallocated memory. We therefore need to
227+
# make a copy of physics.data.qpos while physics.data is still alive.
228+
qpos = physics.data.qpos.copy()
229+
else:
230+
# If we're modifying physics.data in place then it's fine to return a view.
231+
qpos = physics.data.qpos
232+
233+
return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
234+
235+
236+
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
237+
"""Calculates the joint velocities to achieve a specified end effector delta.
238+
239+
Args:
240+
jac_joints: The Jacobian of the end effector with respect to the joints. A
241+
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
242+
and `nv` is the number of degrees of freedom.
243+
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
244+
`(6,)` containing either position deltas, rotation deltas, or both.
245+
regularization_strength: (optional) Coefficient of the quadratic penalty
246+
on joint movements. Default is zero, i.e. no regularization.
247+
248+
Returns:
249+
An `(nv,)` numpy array of joint velocities.
250+
251+
Reference:
252+
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
253+
transpose, pseudoinverse and damped least squares methods.
254+
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
255+
"""
256+
hess_approx = jac_joints.T.dot(jac_joints)
257+
joint_delta = jac_joints.T.dot(delta)
258+
if regularization_strength > 0:
259+
# L2 regularization
260+
hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
261+
return np.linalg.solve(hess_approx, joint_delta)
262+
else:
263+
return np.linalg.lstsq(hess_approx, joint_delta)[0]

0 commit comments

Comments
 (0)