Skip to content

Commit 462afda

Browse files
committed
Fix formatting of very long arrays in generated XML
`np.array2string` truncates very long arrays (such as mesh vertices) by inserting '...' to represent the skipped elements. We now use `np.savetxt` instead, with a more precise string representation for floats. PiperOrigin-RevId: 220796989
1 parent c4e60f4 commit 462afda

4 files changed

Lines changed: 38 additions & 20 deletions

File tree

dm_control/mjcf/attribute.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
import collections
2424
import hashlib
2525
import os
26-
import re
2726

2827
from dm_control.mjcf import base
2928
from dm_control.mjcf import constants
3029
from dm_control.mjcf import debugging
3130
from dm_control.mujoco.wrapper import util
3231
import numpy as np
3332
import six
33+
3434
from dm_control.utils import io as resources
3535

3636

@@ -190,10 +190,12 @@ def to_xml_string(self, prefix_root=None): # pylint: disable=unused-argument
190190
if self._value is None:
191191
return None
192192
else:
193-
return re.sub(r'\s+', r' ',
194-
np.array2string(
195-
self._value, precision=15,
196-
suppress_small=False, max_line_width=999)[1:-1].strip())
193+
out = six.BytesIO()
194+
# 17 decimal digits is sufficient to represent a double float without loss
195+
# of precision.
196+
# https://en.wikipedia.org/wiki/IEEE_754#Character_representation
197+
np.savetxt(out, self._value, fmt='%.17g', newline=' ')
198+
return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
197199

198200
def _check_shape(self, array):
199201
actual_length = array.shape[0]

dm_control/mjcf/attribute_test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,22 @@ def testFloatArray(self):
191191
# XML string should not be affected by global print options
192192
np.set_printoptions(precision=3, suppress=True)
193193
mujoco.optional.float_array = [np.pi, 2, 1e-16]
194-
self.assertXMLStringEqual(
195-
mujoco.optional, 'float_array',
196-
'3.141592653589793e+00 2.000000000000000e+00 1.000000000000000e-16')
194+
self.assertXMLStringEqual(mujoco.optional, 'float_array',
195+
'3.1415926535897931 2 9.9999999999999998e-17')
197196
self.assertCanBeCleared(mujoco.optional, 'float_array')
198197

198+
def testFormatVeryLargeArray(self):
199+
mujoco = self._mujoco
200+
array = np.arange(2000, dtype=np.double)
201+
mujoco.optional.huge_float_array = array
202+
xml_string = mujoco.optional.get_attribute_xml_string('huge_float_array')
203+
self.assertNotIn('...', xml_string)
204+
# Check that array <--> string conversion is a round trip.
205+
mujoco.optional.huge_float_array = None
206+
self.assertIsNone(mujoco.optional.huge_float_array)
207+
mujoco.optional.huge_float_array = xml_string
208+
np.testing.assert_array_equal(mujoco.optional.huge_float_array, array)
209+
199210
def testIntArray(self):
200211
mujoco = self._mujoco
201212
mujoco.optional.int_array = [2, 2]
@@ -393,7 +404,7 @@ def testFileNameTrimming(self):
393404
asset = attribute.Asset(
394405
contents='', extension=extension, prefix=original_filename)
395406
vfs_filename = asset.get_vfs_filename()
396-
self.assertEqual(len(vfs_filename), constants.MAX_VFS_FILENAME_LENGTH)
407+
self.assertLen(vfs_filename, constants.MAX_VFS_FILENAME_LENGTH)
397408

398409
vfs = types.MJVFS()
399410
mjlib.mj_defaultVFS(vfs)

dm_control/mjcf/element_test.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,10 @@ def testAttach(self):
340340
parent=attachment_site.parent, root=submujoco)
341341
self.assertEqual(
342342
subsubmodel_frame.to_xml_string().split('\n')[0],
343-
'<body pos="0.1 0.1 0.1" quat="0. 1. 0. 0." name="subsubmodel/">')
343+
'<body '
344+
'pos="0.10000000000000001 0.10000000000000001 0.10000000000000001" '
345+
'quat="0 1 0 0" '
346+
'name="subsubmodel/">')
344347
self.assertEqual(subsubmodel_frame.all_children(),
345348
subsubmujoco.worldbody.all_children())
346349

@@ -485,9 +488,12 @@ def testAttachmentFrames(self):
485488
subsubmujoco_frame = submujoco.find('attachment_frame', 'subsubmodel')
486489
subsubmujoco_frame_xml = subsubmujoco_frame.to_xml_string(
487490
pretty_print=False, prefix_root=mujoco.namescope)
488-
self.assertStartsWith(subsubmujoco_frame_xml,
489-
'<body pos="0.1 0.1 0.1" quat="0. 1. 0. 0." '
490-
'name="submodel/subsubmodel/">')
491+
self.assertStartsWith(
492+
subsubmujoco_frame_xml,
493+
'<body '
494+
'pos="0.10000000000000001 0.10000000000000001 0.10000000000000001" '
495+
'quat="0 1 0 0" '
496+
'name="submodel/subsubmodel/">')
491497
self.assertEqual(subsubmujoco_frame.full_identifier,
492498
'submodel/subsubmodel/')
493499
with self.assertRaisesRegexp(AttributeError, 'not a valid child'):
@@ -497,7 +503,7 @@ def testAttachmentFrames(self):
497503
pretty_print=False, prefix_root=mujoco.namescope)
498504
self.assertEqual(
499505
hinge_joint_xml,
500-
'<joint class="submodel/" type="hinge" axis="1. 2. 3." '
506+
'<joint class="submodel/" type="hinge" axis="1 2 3" '
501507
'name="submodel/subsubmodel/"/>')
502508
self.assertEqual(hinge_joint.full_identifier, 'submodel/subsubmodel/')
503509

@@ -737,7 +743,7 @@ def testFindAll(self):
737743
mujoco.attach(submujoco)
738744

739745
geoms = mujoco.find_all('geom')
740-
self.assertEqual(len(geoms), 6)
746+
self.assertLen(geoms, 6)
741747
self.assertEqual(geoms[0].root, mujoco)
742748
self.assertEqual(geoms[1].root, mujoco)
743749
self.assertEqual(geoms[2].root, submujoco)
@@ -746,11 +752,9 @@ def testFindAll(self):
746752
self.assertEqual(geoms[5].root, submujoco)
747753

748754
b_0 = submujoco.find('body', 'b_0')
749-
self.assertEqual(len(b_0.find_all('joint')), 6)
750-
self.assertEqual(
751-
len(b_0.find_all('joint', immediate_children_only=True)), 1)
752-
self.assertEqual(
753-
len(b_0.find_all('joint', exclude_attachments=True)), 2)
755+
self.assertLen(b_0.find_all('joint'), 6)
756+
self.assertLen(b_0.find_all('joint', immediate_children_only=True), 1)
757+
self.assertLen(b_0.find_all('joint', exclude_attachments=True), 2)
754758

755759
def testFindAllFrameJoints(self):
756760
root_model = parser.from_path(_TEST_MODEL_XML)

dm_control/mjcf/test_assets/attribute_test_schema.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<attribute name="int" type="int"/>
1010
<attribute name="string" type="string"/>
1111
<attribute name="float_array" type="array" array_type="float" array_size="3"/>
12+
<attribute name="huge_float_array" type="array" array_type="float"/>
1213
<attribute name="int_array" type="array" array_type="int" array_size="2"/>
1314
<attribute name="keyword" type="keyword" valid_values="Alpha Beta Gamma"/>
1415
<attribute name="reftype" type="keyword" valid_values="entity optional"/>

0 commit comments

Comments
 (0)