Skip to content

Commit 7de8ff7

Browse files
yuvaltassaalimuldal
authored andcommitted
Expose mj_contactForce() as a method of Physics.data.
PiperOrigin-RevId: 311731186 Change-Id: Ia8028e12d51d4b14021044ec6048f847f1420259
1 parent 2bb45ec commit 7de8ff7

3 files changed

Lines changed: 97 additions & 1 deletion

File tree

dm_control/mujoco/wrapper/core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
_INVALID_FONT_SCALE = ("`font_scale` must be one of {}, got {{}}."
5353
.format(enums.mjtFontScale))
5454

55+
_CONTACT_ID_OUT_OF_RANGE = (
56+
"`contact_id` must be between 0 and {max_valid} (inclusive), got: {actual}."
57+
)
58+
5559
# Global cache used to store finalizers for freeing ctypes pointers.
5660
# Contains {pointer_address: weakref_object} pairs.
5761
_FINALIZERS = {}
@@ -714,6 +718,27 @@ def object_velocity(self, object_id, object_type, local_frame=False):
714718
# MuJoCo returns velocities in (angular, linear) order, which we flip here.
715719
return velocity.reshape(2, 3)[::-1]
716720

721+
def contact_force(self, contact_id):
722+
"""Returns the wrench of a contact as a 2 x 3 array of (forces, torques).
723+
724+
Args:
725+
contact_id: Integer, the index of the contact within the contact buffer
726+
(`self.contact`).
727+
728+
Returns:
729+
2x3 array with stacked (force, torque). Note that the order of dimensions
730+
is (normal, tangent, tangent), in the contact's frame.
731+
732+
Raises:
733+
ValueError: If `contact_id` is negative or bigger than ncon-1.
734+
"""
735+
if not 0 <= contact_id < self.ncon:
736+
raise ValueError(_CONTACT_ID_OUT_OF_RANGE
737+
.format(max_valid=self.ncon-1, actual=contact_id))
738+
wrench = np.empty(6, dtype=np.float64)
739+
mjlib.mj_contactForce(self.model.ptr, self.ptr, contact_id, wrench)
740+
return wrench.reshape(2, 3)
741+
717742
@property
718743
def model(self):
719744
"""The parent MjModel for this MjData instance."""

dm_control/mujoco/wrapper/core_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,77 @@ def testObjectVelocity(
465465
np.testing.assert_array_almost_equal(linvel, expected_linvel)
466466
np.testing.assert_array_almost_equal(angvel, expected_angvel)
467467

468+
def testContactForce(self):
469+
box_on_floor = """
470+
<mujoco>
471+
<worldbody>
472+
<geom name='floor' type='plane' size='1 1 1'/>
473+
<body name='box' pos='0 0 .1'>
474+
<freejoint/>
475+
<geom name='box' type='box' size='.1 .1 .1'/>
476+
</body>
477+
</worldbody>
478+
</mujoco>
479+
"""
480+
model = core.MjModel.from_xml_string(box_on_floor)
481+
data = core.MjData(model)
482+
# Settle for 500 timesteps (1 second):
483+
for _ in range(500):
484+
mjlib.mj_step(model.ptr, data.ptr)
485+
normal_force = 0.
486+
for contact_id in range(data.ncon):
487+
force = data.contact_force(contact_id)
488+
normal_force += force[0, 0]
489+
box_id = 1
490+
box_weight = -model.opt.gravity[2]*model.body_mass[box_id]
491+
self.assertAlmostEqual(normal_force, box_weight)
492+
# Test raising of out-of-range errors:
493+
bad_ids = [-1, data.ncon]
494+
for bad_id in bad_ids:
495+
with self.assertRaisesWithLiteralMatch(
496+
ValueError,
497+
core._CONTACT_ID_OUT_OF_RANGE.format(
498+
max_valid=data.ncon - 1, actual=bad_id)):
499+
data.contact_force(bad_id)
500+
501+
@parameterized.parameters(
502+
dict(
503+
condim=3, # Only sliding friction.
504+
expected_torques=[False, False, False], # No torques.
505+
),
506+
dict(
507+
condim=4, # Sliding and torsional friction.
508+
expected_torques=[True, False, False], # Only torsional torque.
509+
),
510+
dict(
511+
condim=6, # Sliding, torsional and rolling.
512+
expected_torques=[True, True, True], # All torques are nonzero.
513+
),
514+
)
515+
def testContactTorque(self, condim, expected_torques):
516+
ball_on_floor = """
517+
<mujoco>
518+
<worldbody>
519+
<geom name='floor' type='plane' size='1 1 1'/>
520+
<body name='ball' pos='0 0 .1'>
521+
<freejoint/>
522+
<geom name='ball' size='.1' friction='1 .1 .1'/>
523+
</body>
524+
</worldbody>
525+
</mujoco>
526+
"""
527+
model = core.MjModel.from_xml_string(ball_on_floor)
528+
data = core.MjData(model)
529+
model.geom_condim[:] = condim
530+
data.qvel[3:] = np.array((1., 1., 1.))
531+
# Settle for 10 timesteps (20 milliseconds):
532+
for _ in range(10):
533+
mjlib.mj_step(model.ptr, data.ptr)
534+
contact_id = 0 # This model has only one contact.
535+
_, torque = data.contact_force(contact_id)
536+
nonzero_torques = torque != 0
537+
np.testing.assert_array_equal(nonzero_torques, np.array((expected_torques)))
538+
468539
def testFreeMjrContext(self):
469540
for _ in range(5):
470541
renderer = _render.Renderer(640, 480)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def find_data_files(package_dir, patterns):
166166

167167
setup(
168168
name='dm_control',
169-
version='0.0.311497959',
169+
version='0.0.311731186',
170170
description='Continuous control environments and MuJoCo Python bindings.',
171171
author='DeepMind',
172172
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)