forked from google-deepmind/dm_control
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathelement.py
More file actions
1404 lines (1200 loc) · 54 KB
/
element.py
File metadata and controls
1404 lines (1200 loc) · 54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Classes to represent MJCF elements in the object model."""
import collections
import copy
import functools
import os
import sys

from dm_control.mjcf import attribute as attribute_types
from dm_control.mjcf import base
from dm_control.mjcf import constants
from dm_control.mjcf import copier
from dm_control.mjcf import debugging
from dm_control.mjcf import namescope
from dm_control.mjcf import schema
from dm_control.mujoco.wrapper import util
from lxml import etree
import numpy as np
# Handle on the builtin `property`, captured before this module shadows the
# name `property` with its own decorator further down.
_raw_property = property  # pylint: disable=invalid-name
_UNITS = ('K', 'M', 'G', 'T', 'P', 'E')


def _to_bytes(value_str):
  """Converts a `str` value representing a size in bytes to `int`.

  The value is either a plain (optionally signed) integer such as `'1024'` or
  `'-1'`, or an integer followed by a single binary unit suffix from `_UNITS`
  (case-insensitive), where `K` = 2**10, `M` = 2**20, ..., `E` = 2**60.

  Args:
    value_str: `str` value to be converted.

  Returns:
    `int` corresponding size in bytes.

  Raises:
    ValueError: if the `str` value passed has an unsupported unit, or if the
      part before the unit is not an integer.
  """
  try:
    # Unit-less values, including negative ones such as "-1" (which
    # `str.isdigit` would reject), are plain integers.
    return int(value_str)
  except ValueError:
    pass
  # Slicing rather than indexing routes an empty string to the ValueError
  # below instead of raising IndexError.
  unit = value_str[-1:].upper()
  if unit not in _UNITS:
    raise ValueError(
        f'unit of `size.memory` should be one of [{", ".join(_UNITS)}], got'
        f' {unit}')
  power = 10 * (_UNITS.index(unit) + 1)
  return int(value_str[:-1]) * 2**power


def _max_bytes(first, second):
  """Returns the larger of two byte-size strings as a plain decimal string."""
  return str(max(_to_bytes(first), _to_bytes(second)))


# Conflict-resolution functions keyed by name; presumably looked up via an
# attribute spec's `conflict_behavior` (TODO confirm against the schema).
_CONFLICT_BEHAVIOR_FUNC = {'min': min, 'max': max, 'max_bytes': _max_bytes}
def property(method):  # pylint: disable=redefined-builtin
  """Modifies `@property` to keep track of any `AttributeError` raised.

  Our `Element` implementations overrides the `__getattr__` method. This does
  not interact well with `@property`: if a `property`'s code is buggy so as to
  raise an `AttributeError`, then Python would silently discard it and redirect
  to our `__getattr__` instead, leading to an uninformative stack trace. This
  makes it very difficult to debug issues that involve properties.

  To remedy this, we modify `@property` within this module to store any
  `AttributeError` raised within the respective `Element` object. Then, in our
  `__getattr__` logic, we could re-raise it to preserve the original stack
  trace.

  The reason that this is not implemented as a different decorator is that we
  could accidentally use @property on a new method. This would work fine until
  someone triggers a subtle bug. This is when a proper trace would be most
  useful, but we would still end up with a strange undebuggable stack trace
  anyway.

  Note that at the end of this module, we have a `del property` to prevent this
  override from being broadcasted externally.

  The wrapper is decorated with `functools.wraps`, so the resulting property
  keeps the decorated method's `__doc__` (and `__name__`).

  Args:
    method: The method that is being decorated.

  Returns:
    A `property` corresponding to the decorated method.
  """

  @functools.wraps(method)
  def _mjcf_property(self):
    try:
      return method(self)
    except:  # pylint: disable=bare-except
      _, err, tb = sys.exc_info()
      # Drop this wrapper's own frame so the traceback starts inside `method`.
      err_with_next_tb = err.with_traceback(tb.tb_next)
      if isinstance(err, AttributeError):
        # Stashed so that `__getattr__` can re-raise it instead of masking it.
        self._last_attribute_error = err_with_next_tb  # pylint: disable=protected-access
      raise err_with_next_tb  # pylint: disable=raise-missing-from

  return _raw_property(_mjcf_property)
def _make_element(spec, parent, attributes=None):
  """Instantiates the `Element` subclass appropriate for `spec` under `parent`.

  `<worldbody>` and sites placed directly under a body or worldbody become
  `_AttachableElement`; children of an attachment frame become
  `_AttachmentFrameChild`; `<default>` and `<actuator>` get their dedicated
  classes; everything else is a plain `_ElementImpl`.
  """
  tag = spec.name
  site_in_body = tag == constants.SITE and parent.tag in (
      constants.BODY, constants.WORLDBODY)
  if tag == constants.WORLDBODY or site_in_body:
    return _AttachableElement(spec, parent, attributes)
  if isinstance(parent, _AttachmentFrame):
    return _AttachmentFrameChild(spec, parent, attributes)
  if tag == constants.DEFAULT:
    element_cls = _DefaultElement
  elif tag == constants.ACTUATOR:
    element_cls = _ActuatorElement
  else:
    element_cls = _ElementImpl
  return element_cls(spec, parent, attributes)
# Asset tags whose `name` defaults to the basename (sans extension) of `file`.
_DEFAULT_NAME_FROM_FILENAME = frozenset(('mesh', 'hfield', 'texture'))
class _ElementImpl(base.Element):
"""Actual implementation of a generic MJCF element object."""
__slots__ = ['__weakref__', '_spec', '_parent', '_attributes', '_children',
'_own_attributes', '_attachments', '_is_removed', '_init_stack',
'_is_worldbody', '_cached_namescope', '_cached_root',
'_cached_full_identifier', '_cached_revision',
'_last_attribute_error']
def __init__(self, spec, parent, attributes=None):
  """Creates an element from its schema `spec` as a child of `parent`.

  Args:
    spec: The schema specification of this element; its `name`, `children`
      and `attributes` determine what gets created here.
    parent: The parent element. When truthy, this element's namescope and
      root are taken from it.
    attributes: (optional) A mapping of attribute names to initial values.
      `dclass` is accepted as an alias for `class`. The mapping is copied
      before being rewritten, so the caller's object is left untouched.

  Raises:
    AttributeError: if `attributes` contains a name that is not a valid
      attribute of this element.
    Exception: any error raised while constructing an attribute object is
      re-raised with its original type and a message naming the attribute,
      after force-clearing the attributes created so far.
  """
  # Work on a copy: the `name` defaulting and the `dclass` -> `class` rewrite
  # below must not leak into the caller's mapping.
  attributes = dict(attributes) if attributes else {}

  # For certain `asset` elements the `name` attribute can be omitted, in which
  # case the name will be the filename without the leading path and extension.
  # See http://www.mujoco.org/book/XMLreference.html#asset.
  if ('name' not in attributes
      and 'file' in attributes
      and spec.name in _DEFAULT_NAME_FROM_FILENAME):
    _, filename = os.path.split(attributes['file'])
    basename, _ = os.path.splitext(filename)
    attributes['name'] = basename

  self._spec = spec
  self._parent = parent
  self._attributes = collections.OrderedDict()
  self._own_attributes = None
  self._children = []
  self._attachments = collections.OrderedDict()
  self._is_removed = False
  self._is_worldbody = (self.tag == 'worldbody')

  if self._parent:
    # Children share their parent's namescope and root. Parentless elements
    # presumably initialize these caches elsewhere (TODO confirm).
    self._cached_namescope = self._parent.namescope
    self._cached_root = self._parent.root
    self._cached_full_identifier = ''
    self._cached_revision = -1

  self._last_attribute_error = None

  if debugging.debug_mode():
    self._init_stack = debugging.get_current_stack_trace()

  with debugging.freeze_current_stack_trace():
    # Pre-create every child that is neither repeated nor on-demand, so that
    # `get_children` can always find it.
    for child_spec in self._spec.children.values():
      if not (child_spec.repeated or child_spec.on_demand):
        self._children.append(_make_element(spec=child_spec, parent=self))

    if constants.DCLASS in attributes:
      attributes[constants.CLASS] = attributes[constants.DCLASS]
      del attributes[constants.DCLASS]

    # Reject unknown names before any attribute object is created.
    for attribute_name in attributes.keys():
      self._check_valid_attribute(attribute_name)

    for attribute_spec in self._spec.attributes.values():
      value = None
      # Some Reference attributes refer to a namespace that is specified
      # via another attribute. We therefore have to set things up for
      # the additional indirection.
      if attribute_spec.type is attribute_types.Reference:
        reference_namespace = (
            attribute_spec.other_kwargs['reference_namespace'])
        if reference_namespace.startswith(
            constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX):
          # Copy so the shared schema spec is not modified.
          attribute_spec = copy.deepcopy(attribute_spec)
          namespace_attrib_name = reference_namespace[
              len(constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX):]
          attribute_spec.other_kwargs['reference_namespace'] = (
              self._attributes[namespace_attrib_name])
      if attribute_spec.name in attributes:
        value = attributes[attribute_spec.name]
      try:
        self._attributes[attribute_spec.name] = attribute_spec.type(
            name=attribute_spec.name,
            required=attribute_spec.required,
            conflict_allowed=attribute_spec.conflict_allowed,
            conflict_behavior=attribute_spec.conflict_behavior,
            parent=self, value=value, **attribute_spec.other_kwargs)
      except:  # pylint: disable=bare-except
        # On failure, clear attributes already created
        for attribute_obj in self._attributes.values():
          attribute_obj._force_clear()  # pylint: disable=protected-access
        # Then raise a meaningful error
        err_type, err, tb = sys.exc_info()
        raise err_type(  # pylint: disable=raise-missing-from
            f'during initialization of attribute {attribute_spec.name!r} of '
            f'element <{self._spec.name}>: {err}').with_traceback(tb)
def get_init_stack(self):
  """Gets the stack trace where this element was first initialized.

  Returns:
    The stack trace recorded by `__init__` when debug mode is on, otherwise
    `None`.
  """
  if not debugging.debug_mode():
    return None
  return self._init_stack
def get_last_modified_stacks_for_all_attributes(self):
  """Gets a dict of stack traces where each attribute was last modified.

  Returns:
    A `collections.OrderedDict` keyed by attribute name, in spec order, whose
    values are each attribute's `last_modified_stack`.
  """
  stacks = collections.OrderedDict()
  for attribute_name in self._spec.attributes:
    stacks[attribute_name] = (
        self._attributes[attribute_name].last_modified_stack)
  return stacks
def is_same_as(self, other):
  """Checks whether another element is semantically equivalent to this one.

  Two elements are considered equivalent if they have the same
  specification (i.e. same tag appearing in the same context), the same
  attribute values, and all of their children are equivalent. The ordering
  of non-repeated children is not important for this comparison, while
  the ordering of repeated children are important only amongst the same
  type* of children. In other words, for two bodies to be considered
  equivalent, their child sites must appear in the same order, and their
  child geoms must appear in the same order, but permutations between sites
  and geoms are disregarded. (The only exception is in tendon definition,
  where strict ordering of all children is necessary for equivalence.)

  *Note that the notion of "same type" in this function is very loose:
  for example different actuator element subtypes are treated as separate
  types when children ordering is considered. Therefore, two <actuator>
  elements might be considered equivalent even though they result in different
  orderings of `mjData.ctrl` when compiled. As it stands, this function
  is designed primarily as a testing aid and should not be used to guarantee
  that models are actually identical.

  Args:
    other: An `mjcf.Element`

  Returns:
    `True` if `other` element is semantically equivalent to this one. Values
    that cannot be compared (a reference versus an unset value, or arrays of
    incompatible shapes) count as different rather than raising.
  """
  if other is None or other.spec != self._spec:
    return False
  for attribute_name in self._spec.attributes.keys():
    attribute = self._attributes[attribute_name]
    other_attribute = getattr(other, attribute_name)
    if isinstance(attribute.value, base.Element):
      # References are compared by identifier. `other` may hold something
      # that is not an element (e.g. `None` when unset): not the same.
      if (not isinstance(other_attribute, base.Element) or
          attribute.value.full_identifier != other_attribute.full_identifier):
        return False
    else:
      try:
        values_match = np.all(attribute.value == other_attribute)
      except ValueError:
        # Recent NumPy raises on `==` between arrays of incompatible shapes
        # instead of returning False; such values are not the same.
        values_match = False
      if not values_match:
        return False
  if (self._parent and
      self._parent.tag == constants.TENDON and
      self._parent.parent == self.root):
    # Children of tendon definitions require strict overall ordering.
    return self._tendon_has_same_children_as(other)
  else:
    return self._has_same_children_as(other)
def _has_same_children_as(self, other):
  """Helper function to check whether another element has the same children.

  See docstring for `is_same_as` for explanation about the treatment of
  children ordering.

  Args:
    other: An `mjcf.Element`

  Returns:
    A boolean
  """
  for child_name, child_spec in self._spec.children.items():
    mine = self.get_children(child_name)
    theirs = getattr(other, child_name)
    if child_spec.repeated:
      # Repeated children: same count, pairwise equivalent in order.
      if len(mine) != len(theirs):
        return False
      if not all(ours.is_same_as(their)
                 for ours, their in zip(mine, theirs)):
        return False
    elif mine is None:
      # Absent singleton on our side: the other side must be absent too.
      if theirs is not None:
        return False
    elif not mine.is_same_as(theirs):
      return False
  return True
def _tendon_has_same_children_as(self, other):
  """Checks children of a tendon definition with strict overall ordering.

  Unlike `_has_same_children_as`, every child (of any tag) must match its
  counterpart at the same position, and both sides must have the same number
  of children.

  Args:
    other: An `mjcf.Element`

  Returns:
    A boolean
  """
  children = self.all_children()
  other_children = other.all_children()
  # `zip` stops at the shorter list, so compare lengths explicitly; otherwise
  # a tendon that is a strict prefix of the other would count as equal.
  if len(children) != len(other_children):
    return False
  return all(child.is_same_as(other_child)
             for child, other_child in zip(children, other_children))
def _alias_attributes_dict(self, other):
  """Makes this element read and write its attributes through `other`.

  The element's own attribute dict is stashed only on the first call, so
  repeated aliasing keeps the original. `_restore_attributes_dict` copies
  values back and switches to the stashed dict again.
  """
  stashed = self._own_attributes
  self._own_attributes = self._attributes if stashed is None else stashed
  self._attributes = other
def _restore_attributes_dict(self):
  """Undoes `_alias_attributes_dict`, keeping the values set while aliased.

  Does nothing if no aliasing is active.
  """
  if self._own_attributes is None:
    return
  # Copy each aliased value into the element's own attribute objects.
  for attribute_name, aliased in self._attributes.items():
    self._own_attributes[attribute_name].value = aliased.value
  self._attributes, self._own_attributes = self._own_attributes, None
@property
def tag(self):
  """The XML tag of this element, i.e. the name of its spec."""
  element_spec = self._spec
  return element_spec.name
@property
def spec(self):
  """The schema specification this element was created from."""
  element_spec = self._spec
  return element_spec
@property
def parent(self):
  """The parent element (set to `None` once a repeated child is removed)."""
  parent_element = self._parent
  return parent_element
@property
def namescope(self):
  """The namescope this element belongs to (cached at construction)."""
  scope = self._cached_namescope
  return scope
@property
def root(self):
  """The root element of the model (cached at construction)."""
  root_element = self._cached_root
  return root_element
def prefixed_identifier(self, prefix_root):
  """Returns this element's identifier, prefixed relative to `prefix_root`.

  The worldbody is identified by its namescope prefix, or `'world'` if that
  prefix is empty. Elements whose spec has no identifier attribute yield
  `None`. An element with no explicit identifier value falls back to
  `_default_identifier`, prefixed by its namescope (or by the separator alone
  when the prefix is empty).
  """
  if self._is_worldbody:
    world_prefix = self.namescope.full_prefix(prefix_root=prefix_root)
    return world_prefix or 'world'
  identifier_name = self._spec.identifier
  if not identifier_name:
    return None
  explicit = self._attributes[identifier_name].to_xml_string(
      prefix_root=prefix_root)
  if explicit:
    return explicit
  scope_prefix = (self.namescope.full_prefix(prefix_root=prefix_root)
                  or constants.PREFIX_SEPARATOR)
  return scope_prefix + self._default_identifier
@property
def full_identifier(self):
  """Fully-qualified identifier used for this element in the generated XML.

  The value is cached and recomputed only after the namescope's revision
  counter has advanced past the cached revision.
  """
  scope = self.namescope
  if scope.revision <= self._cached_revision:
    return self._cached_full_identifier
  self._cached_full_identifier = self.prefixed_identifier(
      prefix_root=scope.root)
  self._cached_revision = scope.revision
  return self._cached_full_identifier
@property
def _default_identifier(self):
  """The default identifier used if this element is not named by the user.

  Has the form `<separator>unnamed_<namespace>_<index>`, where `index` is this
  element's position among all same-namespace elements of the model
  (excluding attached models). `None` if the spec has no identifier.
  """
  if self._spec.identifier:
    namespace = self._spec.namespace
    same_kind = self.root.find_all(namespace, exclude_attachments=True)
    return '{separator}unnamed_{namespace}_{index}'.format(
        separator=constants.PREFIX_SEPARATOR,
        namespace=namespace,
        index=same_kind.index(self))
  return None
def __dir__(self):
  """Lists class members plus this element's MJCF children and attributes.

  The `class` attribute is reported under its Python-friendly alias `dclass`.
  """
  names = set()
  # The MRO contains exactly the classes reachable through `__bases__`.
  for klass in type(self).__mro__:
    names.update(vars(klass))
  names.update(self._spec.children)
  names.update(self._spec.attributes)
  if constants.CLASS in names:
    names.discard(constants.CLASS)
    names.add(constants.DCLASS)
  return sorted(names)
def find(self, namespace, identifier):
  """Finds an element with a particular identifier.

  This function allows the direct access to an arbitrarily deeply nested
  child element by name, without the need to manually traverse through the
  object tree. The `namespace` argument specifies the kind of element to
  find. In most cases, this corresponds to the element's XML tag name.
  However, if an element has multiple specialized tags, then the namespace
  corresponds to the tag name of the most general element of that kind.
  For example, `namespace='joint'` would search for `<joint>` and
  `<freejoint>`, while `namespace='actuator'` would search for `<general>`,
  `<motor>`, `<position>`, `<velocity>`, and `<cylinder>`.

  Args:
    namespace: A string specifying the namespace being searched. See the
      docstring above for explanation.
    identifier: The identifier string of the desired element. An identifier
      containing the prefix separator is resolved inside the correspondingly
      named attached model.

  Returns:
    An `mjcf.Element` object, or `None` if an element with the specified
    identifier is not found. When this element has a parent, a match that is
    not a descendant of this element is also reported as `None`.

  Raises:
    ValueError: if either `namespace` or `identifier` is not a string, or if
      `namespace` is not a valid namespace.
  """
  if not isinstance(namespace, str):
    raise ValueError(
        '`namespace` should be a string: got {!r}'.format(namespace))
  if not isinstance(identifier, str):
    raise ValueError(
        '`identifier` should be a string: got {!r}'.format(identifier))
  if namespace not in schema.FINDABLE_NAMESPACES:
    raise ValueError('{!r} is not a valid namespace. Available: {}.'.format(
        namespace, schema.FINDABLE_NAMESPACES))
  if constants.PREFIX_SEPARATOR in identifier:
    # "scope<sep>rest": look up the attached model named `scope`, then
    # recurse into it with the remainder of the identifier.
    scope_name = identifier.split(constants.PREFIX_SEPARATOR)[0]
    try:
      attachment = self.namescope.get('attached_model', scope_name)
      # NOTE(review): the `+ 1` assumes PREFIX_SEPARATOR is one character.
      found_element = attachment.find(
          namespace, identifier[(len(scope_name) + 1):])
    except (KeyError, ValueError):
      found_element = None
  else:
    try:
      found_element = self.namescope.get(namespace, identifier)
    except KeyError:
      found_element = None
  if found_element and self._parent:
    # Non-root elements inherit their parent's namescope, so the lookup above
    # can return an element outside this subtree; walk up the ancestors of
    # the match and discard it unless this element is among them.
    next_parent = found_element.parent
    while next_parent and next_parent != self:
      next_parent = next_parent.parent
    if not next_parent:
      found_element = None
  return found_element
def find_all(self, namespace,
             immediate_children_only=False, exclude_attachments=False):
  """Finds all elements of a particular kind.

  The `namespace` argument specifies the kind of element to
  find. In most cases, this corresponds to the element's XML tag name.
  However, if an element has multiple specialized tags, then the namespace
  corresponds to the tag name of the most general element of that kind.
  For example, `namespace='joint'` would search for `<joint>` and
  `<freejoint>`, while `namespace='actuator'` would search for `<general>`,
  `<motor>`, `<position>`, `<velocity>`, and `<cylinder>`.

  Args:
    namespace: A string specifying the namespace being searched. See the
      docstring above for explanation.
    immediate_children_only: (optional) A boolean, if `True` then only
      the immediate children of this element are returned.
    exclude_attachments: (optional) A boolean, if `True` then elements
      belonging to attached models are excluded.

  Returns:
    A list of `mjcf.Element`, in depth-first pre-order (each matching child
    precedes the matches found beneath it).

  Raises:
    ValueError: if `namespace` is not a valid namespace.
  """
  if namespace not in schema.FINDABLE_NAMESPACES:
    raise ValueError('{!r} is not a valid namespace. Available: {}'.format(
        namespace, schema.FINDABLE_NAMESPACES))
  out = []
  children = self._children if exclude_attachments else self.all_children()
  for child in children:
    if (namespace == child.spec.namespace or
        # Direct children of attachment frames have custom namespaces of the
        # form "joint@attachment_frame_<id>".
        child.spec.namespace and child.spec.namespace.startswith(
            namespace + constants.NAMESPACE_SEPARATOR) or
        # Attachment frames are considered part of the "body" namespace.
        namespace == constants.BODY and isinstance(child, _AttachmentFrame)):
      out.append(child)
    # Recurse into every child, not only matching ones: matches can sit
    # beneath non-matching elements.
    if not immediate_children_only:
      out.extend(child.find_all(namespace,
                                exclude_attachments=exclude_attachments))
  return out
def enter_scope(self, scope_identifier):
  """Finds the root element of the given scope and returns it.

  This function allows the access to a nested scope that is a child of this
  element. The `scope_identifier` argument specifies the path to the child
  scope element; nested scopes are separated by the prefix separator and
  resolved one level at a time.

  Args:
    scope_identifier: The path of the desired scope element.

  Returns:
    An `mjcf.Element` object, or `None` if a scope element with the
    specified path is not found.
  """
  separator = constants.PREFIX_SEPARATOR
  if separator not in scope_identifier:
    try:
      return self.namescope.get('attached_model', scope_identifier)
    except KeyError:
      return None
  head = scope_identifier.split(separator)[0]
  try:
    attachment = self.namescope.get('attached_model', head)
  except KeyError:
    return None
  remainder = scope_identifier[(len(head) + 1):]
  return attachment.enter_scope(remainder) if remainder else attachment
def _check_valid_attribute(self, attribute_name):
  """Raises `AttributeError` unless `attribute_name` is in this spec."""
  if attribute_name in self._spec.attributes:
    return
  raise AttributeError(
      f'{attribute_name!r} is not a valid attribute for <{self._spec.name}>')
def _get_attribute(self, attribute_name):
  """Returns the current value of a validated attribute."""
  self._check_valid_attribute(attribute_name)
  attribute = self._attributes[attribute_name]
  return attribute.value
def get_attribute_xml_string(self,
                             attribute_name,
                             prefix_root=None,
                             *,
                             precision=constants.XML_DEFAULT_PRECISION,
                             zero_threshold=0):
  """Returns the XML string for one attribute of this element.

  Args:
    attribute_name: Name of the attribute; must be valid for this element.
    prefix_root: (optional) A `NameScope` treated as root for prefixing.
    precision: (optional) Number of digits for floating point output.
    zero_threshold: (optional) Floats with smaller magnitude print as zero.

  Returns:
    Whatever the attribute's `to_xml_string` returns.
  """
  self._check_valid_attribute(attribute_name)
  attribute = self._attributes[attribute_name]
  return attribute.to_xml_string(
      prefix_root, precision=precision, zero_threshold=zero_threshold)
def get_attributes(self):
  """Returns an OrderedDict of every attribute that currently has a value.

  Keys follow the spec's attribute order, with `class` reported as `dclass`.
  Attributes whose value is `None` are omitted.
  """
  attributes = collections.OrderedDict()
  for name in self._spec.attributes.keys():
    # Resolve each value once and reuse it for both the filter and the
    # result, rather than validating and looking it up twice.
    value = self._get_attribute(name)
    if value is not None:
      key = constants.DCLASS if name == constants.CLASS else name
      attributes[key] = value
  return attributes
def _set_attribute(self, attribute_name, value):
  """Assigns a validated attribute and bumps the namescope revision."""
  self._check_valid_attribute(attribute_name)
  target = self._attributes[attribute_name]
  target.value = value
  # Invalidates cached identifiers (see `full_identifier`).
  self.namescope.increment_revision()
def set_attributes(self, **kwargs):
  """Sets several attributes at once, all-or-nothing.

  `dclass` is accepted as an alias for `class`. If any assignment fails,
  including an unknown attribute name, every attribute already assigned
  by this call is restored to its previous value before the error is raised.

  Args:
    **kwargs: Attribute names mapped to their new values.

  Raises:
    Any exception raised while reading or assigning an attribute, re-raised
    with its original type and a message naming the offending attribute
    (e.g. `AttributeError` for names that are not valid attributes).
  """
  if constants.DCLASS in kwargs:
    kwargs[constants.CLASS] = kwargs[constants.DCLASS]
    del kwargs[constants.DCLASS]
  old_values = []
  with debugging.freeze_current_stack_trace():
    for attribute_name, new_value in kwargs.items():
      try:
        # Read the old value inside the `try`: an invalid name must also
        # trigger the rollback of the attributes assigned before it.
        old_value = self._get_attribute(attribute_name)
        self._set_attribute(attribute_name, new_value)
        old_values.append((attribute_name, old_value))
      except:  # pylint: disable=bare-except
        # On failure, restore old attribute values for those already set.
        for name, previous_value in old_values:
          self._set_attribute(name, previous_value)
        # Then raise a meaningful error.
        err_type, err, tb = sys.exc_info()
        raise err_type(  # pylint: disable=raise-missing-from
            f'during assignment to attribute {attribute_name!r} of '
            f'element <{self._spec.name}>: {err}').with_traceback(tb)
def _remove_attribute(self, attribute_name):
  """Clears a validated attribute and bumps the namescope revision."""
  self._check_valid_attribute(attribute_name)
  target = self._attributes[attribute_name]
  target.clear()
  self.namescope.increment_revision()
def _check_valid_child(self, element_name):
  """Returns the child spec for `element_name`.

  Raises:
    AttributeError: if `element_name` is not a valid child tag.
  """
  child_specs = self._spec.children
  try:
    found_spec = child_specs[element_name]
  except KeyError:
    raise AttributeError(  # pylint: disable=raise-missing-from
        f'<{element_name}> is not a valid child of <{self._spec.name}>')
  return found_spec
def get_children(self, element_name):
  """Returns the child (or view of children) with tag `element_name`.

  Repeated children are returned as an `_ElementListView`. A non-repeated
  child is returned directly, or `None` when it is on-demand and has not
  been created yet.

  Raises:
    AttributeError: if `element_name` is not a valid child tag.
    RuntimeError: if a pre-created child is unexpectedly missing.
  """
  child_spec = self._check_valid_child(element_name)
  if child_spec.repeated:
    return _ElementListView(spec=child_spec, parent=self)
  match = next(
      (child for child in self._children if child.tag == element_name), None)
  if match is not None:
    return match
  if child_spec.on_demand:
    return None
  raise RuntimeError(
      'Cannot find the non-repeated child <{}> of <{}>. '
      'This should never happen, as we pre-create these in __init__. '
      'Please file an bug report. Thank you.'
      .format(element_name, self._spec.name))
def add(self, element_name, **kwargs):
  """Appends a new child element to this element.

  Equivalent to `insert(element_name, position=None, **kwargs)`.

  Args:
    element_name: The tag of the element to add.
    **kwargs: Attributes of the new element being created.

  Raises:
    ValueError: If the 'element_name' is not a valid child, or if an invalid
      attribute is specified in `kwargs`.

  Returns:
    An `mjcf.Element` corresponding to the newly created child element.
  """
  new_child = self.insert(element_name, position=None, **kwargs)
  return new_child
def insert(self, element_name, position, **kwargs):
  """Add a new child element to this element.

  Only repeated children, or on-demand children that do not exist yet, can
  be created; pre-created singletons must be accessed directly.

  Args:
    element_name: The tag of the element to add.
    position: Index in the children list at which to insert, or `None` to
      append.
    **kwargs: Attributes of the new element being created.

  Raises:
    ValueError: If the 'element_name' is not a valid child, or if an invalid
      attribute is specified in `kwargs`.

  Returns:
    An `mjcf.Element` corresponding to the newly created child element.
  """
  child_spec = self._check_valid_child(element_name)
  missing_on_demand = (child_spec.on_demand and
                       self.get_children(element_name) is None)
  if not (child_spec.repeated or missing_on_demand):
    raise ValueError('A <{}> child already exists, please access it directly.'
                     .format(element_name))
  new_element = _make_element(child_spec, self, attributes=kwargs)
  if position is None:
    self._children.append(new_element)
  else:
    self._children.insert(position, new_element)
  self.namescope.increment_revision()
  return new_element
def __getattr__(self, name):
  """Resolves MJCF children and attributes accessed as Python attributes.

  Python only calls this after normal attribute lookup fails. Children are
  checked before attributes, and `dclass` maps to the `class` attribute.

  Raises:
    AttributeError: if `name` is neither a child nor an attribute, or to
      re-raise an `AttributeError` stashed by this module's `property`.
  """
  # NOTE(review): if `_last_attribute_error` has not been assigned on this
  # instance yet (e.g. something is read before `__init__` sets it), reading
  # the empty slot here raises AttributeError and re-enters `__getattr__`
  # recursively. Confirm this cannot happen.
  if self._last_attribute_error:
    # This means that we got here through a @property raising AttributeError.
    # We therefore just re-raise the last AttributeError back to the user.
    # Note that self._last_attribute_error was set by our specially
    # instrumented @property decorator.
    exc = self._last_attribute_error
    self._last_attribute_error = None
    raise exc  # pylint: disable=raising-bad-type
  elif name in self._spec.children:
    return self.get_children(name)
  elif name in self._spec.attributes:
    return self._get_attribute(name)
  elif name == constants.DCLASS and constants.CLASS in self._spec.attributes:
    return self._get_attribute(constants.CLASS)
  else:
    raise AttributeError('object has no attribute: {}'.format(name))
def __setattr__(self, name, value):
  """Assigns slots/settable properties, else the matching MJCF attribute.

  Raises:
    AttributeError: if `name` is neither settable on the class nor a valid
      MJCF attribute (`dclass` is accepted as an alias for `class`).
  """
  # If this name corresponds to a descriptor for a slotted attribute or
  # settable property then try to invoke the descriptor to set the attribute
  # and return if successful.
  # Slot members such as `_spec` are data descriptors on the class, so the
  # internal assignments made in `__init__` go through this branch.
  klass_attr = getattr(type(self), name, None)
  if klass_attr is not None:
    try:
      return klass_attr.__set__(self, value)
    except AttributeError:
      # Getter-only properties raise here, and plain class attributes such as
      # methods have no `__set__` at all; fall through to the MJCF spec.
      pass
  # If we did not find a settable descriptor then we look in the attribute
  # spec to see if there is a MuJoCo attribute matching this name.
  attribute_name = name if name != constants.DCLASS else constants.CLASS
  if attribute_name in self._spec.attributes:
    self._set_attribute(attribute_name, value)
  else:
    raise AttributeError('can\'t set attribute: {}'.format(name))
def __delattr__(self, name):
  """Removes a singleton child or clears an attribute via `del element.x`.

  Raises:
    AttributeError: for repeated children (use `.clear()` instead) and for
      names that are neither children nor attributes.
  """
  child_specs = self._spec.children
  if name in child_specs:
    if not child_specs[name].repeated:
      return self.get_children(name).remove()
    raise AttributeError(
        '`{0}` is a collection of child elements, '
        'which cannot be deleted. Did you mean to call `{0}.clear()`?'
        .format(name))
  if name in self._spec.attributes:
    return self._remove_attribute(name)
  raise AttributeError('object has no attribute: {}'.format(name))
def _check_attachments_on_remove(self, affect_attachments):
  """Verifies that removing this subtree is allowed.

  Raises:
    ValueError: if this element or any descendant has attachments and
      `affect_attachments` is falsy.
  """
  if self._attachments and not affect_attachments:
    raise ValueError(
        'please use remove(affect_attachments=True) as this will affect some '
        'attributes and/or children belonging to an attached model')
  for descendant in self._children:
    descendant._check_attachments_on_remove(affect_attachments)  # pylint: disable=protected-access
def remove(self, affect_attachments=False):
  """Removes this element from the model.

  Children are removed recursively first. Repeated and on-demand elements are
  detached from their parent, have their attributes force-cleared, and are
  marked as removed. Other (pre-created) elements stay in place and only have
  their attributes force-cleared.

  Args:
    affect_attachments: (optional) A boolean. Must be `True` if this element
      or any descendant has attachments; those attachments are then removed
      as well.

  Raises:
    ValueError: if `affect_attachments` is `False` but this element or one of
      its descendants has attachments.
  """
  # Validate the whole subtree before mutating anything.
  self._check_attachments_on_remove(affect_attachments)
  if affect_attachments:
    for attachment in self._attachments.values():
      attachment.remove(affect_attachments=True)
  # Iterate over a copy: removed children delete themselves from
  # `self._children`.
  for child in list(self._children):
    child.remove(affect_attachments)
  if self._spec.repeated or self._spec.on_demand:
    self._parent._children.remove(self)  # pylint: disable=protected-access
    for attribute in self._attributes.values():
      attribute._force_clear()  # pylint: disable=protected-access
    self._parent = None
    self._is_removed = True
  else:
    for attribute in self._attributes.values():
      attribute._force_clear()  # pylint: disable=protected-access
  # `namescope` is cached, so this works even after `_parent` was cleared.
  self.namescope.increment_revision()
@property
def is_removed(self):
  """Whether this (repeated or on-demand) element was removed via `remove`."""
  removed_flag = self._is_removed
  return removed_flag
def all_children(self):
  """Returns own children followed by repeated children of attachments.

  The result is a new list; mutating it does not affect this element.
  """
  collected = list(self._children)
  for attachment in self._attachments.values():
    collected.extend(
        child for child in attachment.all_children() if child.spec.repeated)
  return collected
def to_xml(self, prefix_root=None, debug_context=None,
           *,
           precision=constants.XML_DEFAULT_PRECISION,
           zero_threshold=0):
  """Generates an etree._Element corresponding to this MJCF element.

  Args:
    prefix_root: (optional) A `NameScope` object to be treated as root
      for the purpose of calculating the prefix. If falsy, this element's
      own namescope is used.
    debug_context: (optional) A `debugging.DebugContext` object to which
      the debugging information associated with the generated XML is written.
      This is intended for internal use within PyMJCF; users should never need
      manually pass this argument.
    precision: (optional) Number of digits to output for floating point
      quantities.
    zero_threshold: (optional) When outputting XML, floating point quantities
      whose absolute value falls below this threshold will be treated as zero.

  Returns:
    An etree._Element object.
  """
  if not prefix_root:
    prefix_root = self.namescope
  element_xml = etree.Element(self._spec.name)
  number_format = dict(precision=precision, zero_threshold=zero_threshold)
  self._attributes_to_xml(element_xml, prefix_root, debug_context,
                          **number_format)
  self._children_to_xml(element_xml, prefix_root, debug_context,
                        **number_format)
  return element_xml
def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None,
                       *, precision, zero_threshold):
  """Writes this element's attributes onto `xml_element`.

  Unset attributes are skipped, except the identifier attribute, which falls
  back to `full_identifier` so that every identifiable element is named.
  """
  del debug_context  # Unused.
  identifier_name = self._spec.identifier
  for attribute_name, attribute in self._attributes.items():
    xml_value = attribute.to_xml_string(prefix_root,
                                        precision=precision,
                                        zero_threshold=zero_threshold)
    if xml_value is not None:
      xml_element.set(attribute_name, xml_value)
    elif attribute_name == identifier_name:
      xml_element.set(attribute_name, self.full_identifier)
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
                     *, precision, zero_threshold):
  """Appends the XML of every child, including attached ones, to `xml_element`.

  A child is skipped when its XML has no attributes and no sub-elements and
  it is neither repeated nor on-demand (i.e. an untouched pre-created
  singleton). In debug mode with a `debug_context`, the object returned by
  `debug_context.register_element_for_debugging` (presumably an XML comment;
  TODO confirm) is appended after each child. A deep copy of it also becomes
  the child's first sub-node when the child has sub-elements.
  """
  for child in self.all_children():
    child_xml = child.to_xml(prefix_root, debug_context,
                             precision=precision,
                             zero_threshold=zero_threshold)
    if (child_xml.attrib or len(child_xml)  # pylint: disable=g-explicit-length-test
        or child.spec.repeated or child.spec.on_demand):
      xml_element.append(child_xml)
      if debugging.debug_mode() and debug_context:
        debug_comment = debug_context.register_element_for_debugging(child)
        xml_element.append(debug_comment)
        if len(child_xml) > 0:  # pylint: disable=g-explicit-length-test
          # A node can only have one parent in lxml, so insert a copy.
          child_xml.insert(0, copy.deepcopy(debug_comment))
def to_xml_string(self, prefix_root=None,
                  self_only=False, pretty_print=True, debug_context=None,
                  *,
                  precision=constants.XML_DEFAULT_PRECISION,
                  zero_threshold=0):
  """Generates an XML string corresponding to this MJCF element.

  Args:
    prefix_root: (optional) A `NameScope` object to be treated as root
      for the purpose of calculating the prefix.
      If `None` then no prefix is included.
    self_only: (optional) A boolean. If `True`, any sub-elements are replaced
      by a `...` placeholder, and the identifier attribute is omitted unless
      it was explicitly set on this element.
    pretty_print: (optional) A boolean, whether the XML string should be
      properly indented.
    debug_context: (optional) A `debugging.DebugContext` object to which
      the debugging information associated with the generated XML is written.
      This is intended for internal use within PyMJCF; users should never need
      to manually pass this argument.
    precision: (optional) Number of digits to output for floating point
      quantities.
    zero_threshold: (optional) When outputting XML, floating point quantities
      whose absolute value falls below this threshold will be treated as zero.

  Returns:
    A string.
  """
  xml_element = self.to_xml(prefix_root, debug_context,
                            precision=precision,
                            zero_threshold=zero_threshold)
  if self_only:
    if len(xml_element) > 0:  # pylint: disable=g-explicit-length-test
      etree.strip_elements(xml_element, '*')
      xml_element.text = '...'
    identifier = self._spec.identifier
    if identifier:
      explicit_identifier = self._attributes[identifier].to_xml_string(
          prefix_root, precision=precision, zero_threshold=zero_threshold)
      if not explicit_identifier:
        # The identifier was not set explicitly, so drop the auto-filled one.
        del xml_element.attrib[identifier]
  xml_string = util.to_native_string(
      etree.tostring(xml_element, pretty_print=pretty_print))
  if not (pretty_print and debug_context):
    return xml_string
  return debug_context.commit_xml_string(xml_string)
def __str__(self):
return self.to_xml_string(self_only=True, pretty_print=False)
def __repr__(self):
return 'MJCF Element: ' + str(self)
def _check_valid_attachment(self, other):
  """Checks that `other` has the same spec as this element.

  A <worldbody> spec is compared as if it were the spec of its <body> child,
  so a <worldbody> and a <body> are considered compatible.

  Args:
    other: The element that is about to be attached.

  Raises:
    ValueError: If the (normalized) specs differ.
  """
  def frame_spec(spec):
    if spec.name == constants.WORLDBODY:
      return spec.children[constants.BODY]
    return spec

  own = frame_spec(self._spec)
  theirs = frame_spec(other.spec)
  if theirs != own:
    raise ValueError(
        'The attachment must have the same spec as this element.')
def _attach(self, other, exclude_worldbody=False, dry_run=False):
  """Attaches another element of the same spec to this element.

  All children of `other` will be treated as children of this element.
  XML attributes that are defined in `other` but not in this element are
  copied over, and any conflicting XML attribute value causes an error.
  Unless `other` is a <worldbody>, `other` is afterwards aliased to this
  element's attribute dictionary, so that any XML attribute modified in this
  element also affects `other` and vice versa. Each non-repeated child of
  `other` is recursively attached to this element's child of the same name.

  Args:
    other: Another Element with the same spec.
    exclude_worldbody: (optional) A boolean. If `True`, then don't do anything
      if `other` is a worldbody.
    dry_run: (optional) A boolean, if `True` only verify that the operation
      is valid without actually committing any change.

  Raises:
    ValueError: If `other` has a different spec, or if there are conflicting
      XML attribute values.
  """
  self._check_valid_attachment(other)
  other_is_worldbody = other.tag == constants.WORLDBODY
  if exclude_worldbody and other_is_worldbody:
    return
  if not dry_run:
    self._attachments[other.namescope] = other
    self._sync_attributes(other, copying=False)
  else:
    # Only validate; leave both elements untouched.
    self._check_conflicting_attributes(other, copying=False)
  self._attach_children(other, exclude_worldbody, dry_run)
  if not (dry_run or other_is_worldbody):
    other._alias_attributes_dict(self._attributes)  # pylint: disable=protected-access
def _detach(self, other_namescope):
"""Detaches a model with the specified namescope."""
attached_element = self._attachments.get(other_namescope)
if attached_element:
attached_element._restore_attributes_dict() # pylint: disable=protected-access
del self._attachments[other_namescope]
for child in self._children:
child._detach(other_namescope) # pylint: disable=protected-access
def _check_conflicting_attributes(self, other, copying):
  """Raises if `other` holds an attribute value that conflicts with ours.

  An attribute conflicts when all of the following hold:
    * this element's attribute does not allow conflicts
      (`conflict_allowed` is false);
    * both this element and `other` hold a non-`None` value for it;
    * the two values differ, element-wise for array values.
  Values that cannot be compared element-wise because their shapes do not
  broadcast are treated as differing.

  An attribute named `constants.DCLASS` in `other` is checked against this
  element's `constants.CLASS` attribute.

  Args:
    other: The element whose attribute values are checked against ours.
    copying: Unused by this implementation.

  Raises:
    ValueError: If a conflicting attribute value is found.
  """
  del copying  # Unused.
  for attribute_name, other_value in other.get_attributes().items():
    if attribute_name == constants.DCLASS:
      attribute_name = constants.CLASS
    self_attribute = self._attributes[attribute_name]
    if (self_attribute.conflict_allowed
        or self_attribute.value is None
        or other_value is None):
      continue
    try:
      differs = np.asarray(self_attribute.value != other_value).any()
    except ValueError:
      # NumPy >= 1.25 raises when comparing arrays whose shapes cannot be
      # broadcast together; older versions returned a scalar `True`. In
      # either case the values are not equal, so report a conflict instead
      # of surfacing an opaque broadcasting error.
      differs = True
    if differs:
      raise ValueError(
          'Conflicting values for attribute `{}`: {} vs {}'
          .format(attribute_name,
                  self_attribute.value,
                  other_value))
def _sync_attributes(self, other, copying):
  """Merges `other`'s attribute values into this element's attributes.

  Conflicts are checked first via `_check_conflicting_attributes`. Then, for
  every attribute that `other` sets to a non-`None` value:
    * if the attribute's `conflict_behavior` has an entry in
      `_CONFLICT_BEHAVIOR_FUNC`, the two values are combined with that
      function (or `other`'s value is taken when ours is `None`);
    * otherwise `other`'s value overwrites ours when `copying` is true or
      when the attribute does not allow conflicts.
  An attribute named `constants.DCLASS` in `other` updates this element's
  `constants.CLASS` attribute.

  Args:
    other: The element whose attribute values are merged in.
    copying: Whether `other`'s values should overwrite ours even for
      attributes that allow conflicts.

  Raises:
    ValueError: If a conflicting attribute value is found.
  """
  self._check_conflicting_attributes(other, copying)
  for name, incoming in other.get_attributes().items():
    if name == constants.DCLASS:
      name = constants.CLASS
    target = self._attributes[name]
    if incoming is None:
      continue
    behavior = target.conflict_behavior
    if behavior in _CONFLICT_BEHAVIOR_FUNC:
      if target.value is None:
        target.value = incoming
      else:
        merge = _CONFLICT_BEHAVIOR_FUNC[behavior]
        target.value = merge(target.value, incoming)
    elif copying or not target.conflict_allowed:
      target.value = incoming
def _attach_children(self, other, exclude_worldbody, dry_run=False):
for other_child in other.all_children():
if not other_child.spec.repeated:
self_child = self.get_children(other_child.spec.name)
self_child._attach(other_child, exclude_worldbody, dry_run) # pylint: disable=protected-access
def resolve_references(self):
  """Replaces string-valued references with the elements they name.

  For every `attribute_types.Reference` attribute holding a non-empty string,
  the root is searched (`self.root.find`) in the attribute's
  `reference_namespace`; when a match is found, it replaces the string value.
  The same is then done recursively for all children.
  """
  for attribute in self._attributes.values():
    if not isinstance(attribute, attribute_types.Reference):
      continue
    target_name = attribute.value
    if not (target_name and isinstance(target_name, str)):
      continue
    referred = self.root.find(attribute.reference_namespace, target_name)
    if referred:
      attribute.value = referred
  for child in self.all_children():
    child.resolve_references()
def _update_references(self, reference_dict):
  """Remaps reference attribute values through `reference_dict`, recursively.

  Every `attribute_types.Reference` attribute whose current value is a key of
  `reference_dict` is set to the corresponding mapped value. All children
  are then updated the same way.

  Args:
    reference_dict: A mapping from old reference values to new ones.
  """
  reference_attributes = (
      attr for attr in self._attributes.values()
      if isinstance(attr, attribute_types.Reference))
  for attribute in reference_attributes:
    if attribute.value in reference_dict:
      attribute.value = reference_dict[attribute.value]
  for child in self.all_children():
    child._update_references(reference_dict)  # pylint: disable=protected-access
class _AttachableElement(_ElementImpl):
"""Specialized object representing a <site> or <worldbody> element.
This element defines a frame to which another MJCF model can be attached.
"""
__slots__ = []
def attach(self, attachment):
"""Attaches another MJCF model at this site.
An empty <body> will be created as an attachment frame. All children of
`attachment`'s <worldbody> will be treated as children of this frame.
Furthermore, all other elements in `attachment` are merged into the root
of the MJCF model to which this element belongs.
Args:
attachment: An MJCF `RootElement`
Returns:
An `mjcf.Element` corresponding to the attachment frame. A joint can be
added directly to this frame to give degrees of freedom to the attachment.
Raises:
ValueError: If `other` is not a valid attachment to this element.
"""