forked from google-deepmind/dm_control
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattribute.py
More file actions
556 lines (458 loc) · 18.2 KB
/
attribute.py
File metadata and controls
556 lines (458 loc) · 18.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Classes representing various MJCF attribute data types."""
import abc
import collections
import hashlib
import io
import os
from dm_control.mjcf import base
from dm_control.mjcf import constants
from dm_control.mjcf import debugging
from dm_control.mjcf import skin
from dm_control.mujoco.wrapper import util
import numpy as np
# Copybara placeholder for internal file handling dependency.
from dm_control.utils import io as resources
_INVALID_REFERENCE_TYPE = (
'Reference should be an MJCF Element whose type is {valid_type!r}: '
'got {actual_type!r}.')
_MESH_EXTENSIONS = ('.stl', '.msh', '.obj')
# MuJoCo's compiler enforces this.
_INVALID_MESH_EXTENSION = (
'Mesh files must have one of the following extensions: {}, got {{}}.'
.format(_MESH_EXTENSIONS))
class _Attribute(metaclass=abc.ABCMeta):
  """Abstract base class for MJCF attribute data types.

  Subclasses must implement `_assign`, which validates a non-string value and
  stores it in `self._value`. They may also override `_assign_from_string`
  (string input), `_before_clear` (a hook run just before the value is reset
  to None) and `to_xml_string` (serialization).
  """

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior):
    """Initializes the attribute and assigns its initial value.

    Args:
      name: The attribute's name; used in error messages.
      required: If True, `clear()` raises AttributeError instead of resetting
        the value (so a None initial `value` is rejected too).
      parent: The owning element; its `tag` is used in error messages.
      value: Initial value. None clears, a `str` is routed through
        `_assign_from_string`, anything else through `_assign`.
      conflict_allowed: Stored and exposed via the `conflict_allowed`
        property.
      conflict_behavior: Stored and exposed via the `conflict_behavior`
        property.
    """
    self._name = name
    self._required = required
    self._parent = parent
    self._value = None
    self._conflict_allowed = conflict_allowed
    self._conflict_behavior = conflict_behavior
    self._check_and_assign(value)

  def _record_modification(self):
    """Captures the current stack trace when debug mode is enabled."""
    if debugging.debug_mode():
      self._last_modified_stack = debugging.get_current_stack_trace()

  def _check_and_assign(self, new_value):
    """Dispatches `new_value` to `clear`, `_assign_from_string` or `_assign`."""
    if new_value is None:
      self.clear()
    elif isinstance(new_value, str):
      self._assign_from_string(new_value)
    else:
      self._assign(new_value)
    self._record_modification()

  @property
  def last_modified_stack(self):
    """Stack trace of the last modification, or None outside debug mode."""
    if debugging.debug_mode():
      # A trace is only recorded while debug mode is on. If debug mode was
      # enabled after the last modification, nothing has been recorded yet,
      # so return None rather than raising AttributeError.
      return getattr(self, '_last_modified_stack', None)
    return None

  @property
  def value(self):
    return self._value

  @value.setter
  def value(self, new_value):
    self._check_and_assign(new_value)

  @abc.abstractmethod
  def _assign(self, value):
    raise NotImplementedError  # pragma: no cover

  def clear(self):
    """Resets the value to None.

    Raises:
      AttributeError: If this attribute is required.
    """
    if self._required:
      raise AttributeError(
          'Attribute {!r} of element <{}> is required'
          .format(self._name, self._parent.tag))
    else:
      self._force_clear()

  def _force_clear(self):
    """Resets the value to None even if the attribute is required."""
    self._before_clear()
    self._value = None
    self._record_modification()

  def _before_clear(self):
    """Hook invoked before the value is cleared; no-op by default."""
    pass

  def _assign_from_string(self, string):
    """Assigns a string value; by default defers to `_assign`."""
    self._assign(string)

  def to_xml_string(self, prefix_root, **kwargs):  # pylint: disable=unused-argument
    """Returns `str(value)`, or None if the value is unset."""
    if self._value is None:
      return None
    else:
      return str(self._value)

  @property
  def conflict_allowed(self):
    return self._conflict_allowed

  @property
  def conflict_behavior(self):
    return self._conflict_behavior
class String(_Attribute):
  """A free-form string MJCF attribute; assigning '' clears the value."""

  def _assign(self, value):
    # Only `str` is accepted; the empty string is treated as "unset".
    if isinstance(value, str):
      if value:
        self._value = value
      else:
        self.clear()
      return
    raise ValueError('Expect a string value: got {}'.format(value))
class Integer(_Attribute):
  """An integer MJCF attribute.

  Accepts any value that `float()` converts to an integral number, such as
  `3`, `3.0`, `'3'` or `'3.0'`.
  """

  def _assign(self, value):
    """Validates `value` and stores it as a Python `int`.

    Raises:
      ValueError: If `float(value)` raises ValueError (e.g. an unparseable
        string), if the number is non-finite (NaN or +/-inf), or if it has a
        fractional part.
    """
    try:
      float_value = float(value)
      # `int()` raises ValueError for NaN and OverflowError for +/-inf; both
      # are reported as the same user-facing ValueError below.
      int_value = int(float_value)
      if float_value != int_value:
        raise ValueError
    except (ValueError, OverflowError):
      raise ValueError(
          'Expect an integer value: got {}'.format(value)) from None
    self._value = int_value
class Float(_Attribute):
  """A floating-point MJCF attribute."""

  def _assign(self, value):
    try:
      converted = float(value)
    except ValueError:
      raise ValueError('Expect a float value: got {}'.format(value)) from None
    self._value = converted

  def to_xml_string(self, prefix_root=None,
                    *,
                    precision=constants.XML_DEFAULT_PRECISION,
                    zero_threshold=0,
                    **kwargs):
    """Formats the value with `precision` significant digits (`%g` style).

    A value whose magnitude is below `zero_threshold` is written as 0.
    Returns None if the attribute is unset.
    """
    if self._value is None:
      return None
    number = 0.0 if abs(self._value) < zero_threshold else self._value
    buffer = io.BytesIO()
    np.savetxt(buffer, [number], fmt=f'%.{precision:d}g', newline=' ')
    # savetxt terminates every entry with `newline`; drop the final one.
    return util.to_native_string(buffer.getvalue())[:-1]
class Keyword(_Attribute):
  """A keyword MJCF attribute drawn case-insensitively from a fixed set."""

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior, valid_values):
    # Lowercased keyword -> canonical spelling, in declaration order.
    self._valid_values = collections.OrderedDict()
    for keyword in valid_values:
      self._valid_values[keyword.lower()] = keyword
    super().__init__(name, required, parent, value, conflict_allowed,
                     conflict_behavior)

  def _assign(self, value):
    """Stores the canonical spelling of `value`; None or '' clears."""
    if value is None or value == '':  # pylint: disable=g-explicit-bool-comparison
      self.clear()
      return
    lookup_key = str(value).lower()
    if lookup_key not in self._valid_values:
      raise ValueError('Expect keyword to be one of {} but got: {}'.format(
          list(self._valid_values.values()), value))
    self._value = self._valid_values[lookup_key]

  @property
  def valid_values(self):
    """The accepted keywords, lowercased, in declaration order."""
    return list(self._valid_values)
class Array(_Attribute):
  """A one-dimensional numeric array MJCF attribute.

  Values are stored as a 1-D `np.ndarray` of the configured `dtype`. When
  `length` is truthy it is an upper bound on the number of entries.
  """

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior, length, dtype):
    self._length = length
    self._dtype = dtype
    super().__init__(name, required, parent, value, conflict_allowed,
                     conflict_behavior)

  def _assign(self, value):
    self._value = self._check_shape(np.array(value, dtype=self._dtype))

  def _assign_from_string(self, string):
    # Parses space-separated numbers, e.g. '0 0.5 1'.
    self._assign(np.fromstring(string, dtype=self._dtype, sep=' '))

  def to_xml_string(self, prefix_root=None,
                    *,
                    precision=constants.XML_DEFAULT_PRECISION,
                    zero_threshold=0,
                    **kwargs):
    """Returns the entries as a space-separated string, or None if unset.

    Each entry is formatted with `precision` significant digits (`%g`).
    Entries whose magnitude is below `zero_threshold` are written as 0; this
    is applied to a copy, so the stored array is not modified.
    """
    if self._value is None:
      return None
    else:
      out = io.BytesIO()
      value = self._value
      if zero_threshold:
        value = np.copy(value)
        value[np.abs(value) < zero_threshold] = 0
      np.savetxt(out, value, fmt=f'%.{precision:d}g', newline=' ')
      return util.to_native_string(out.getvalue())[:-1]  # Strip trailing space.

  def _check_shape(self, array):
    """Validates that `array` is 1-D with at most `self._length` entries.

    Returns:
      `array`, unchanged.

    Raises:
      ValueError: If `array` is not one-dimensional (including a 0-d scalar),
        or if it has more than `self._length` entries.
    """
    # Check the rank before indexing `shape`; a 0-d array has an empty shape,
    # so `shape[0]` would raise IndexError instead of this ValueError.
    if array.ndim != 1:
      raise ValueError('Expect one-dimensional array: got {}'.format(array))
    if self._length and array.shape[0] > self._length:
      raise ValueError('Expect array with no more than {} entries: got {}'
                       .format(self._length, array))
    return array
class Identifier(_Attribute):
  """A string attribute holding an element's unique identifier (its name).

  Assigning a new name registers it in the parent's namescope under the
  parent's namespace, and unregisters the previous name if there was one.
  """

  def _assign(self, value):
    if not isinstance(value, str):
      raise ValueError('Expect a string value: got {}'.format(value))
    if not value:
      self.clear()
      return
    if self._parent.spec.namespace == 'body' and value == 'world':
      raise ValueError("A body cannot be named 'world'. "
                       "The name 'world' is used by MuJoCo to refer to the "
                       '<worldbody>.')
    if constants.PREFIX_SEPARATOR in value:
      raise ValueError(
          'An identifier cannot contain a {!r}, '
          'as this is reserved for scoping purposes: got {!r}'
          .format(constants.PREFIX_SEPARATOR, value))
    previous = self._value
    if value != previous:
      # The new name is added before the old one is removed. Presumably this
      # is so that a failing `add` (e.g. a name clash) leaves the old
      # registration intact.
      self._parent.namescope.add(
          self._parent.spec.namespace, value, self._parent)
      if previous:
        self._parent.namescope.remove(self._parent.spec.namespace, previous)
    self._value = value

  def _before_clear(self):
    """Unregisters the current name (if any) from the parent's namescope."""
    if not self._value:
      return
    self._parent.namescope.remove(self._parent.spec.namespace, self._value)

  def _defaults_string(self, prefix_root):
    """Returns the scoped defaults-class name; never empty."""
    parts = self._parent.namescope.full_prefix(prefix_root, as_list=True)
    parts.append(self._value or '')
    joined = constants.PREFIX_SEPARATOR.join(parts)
    return joined if joined else constants.PREFIX_SEPARATOR

  def to_xml_string(self, prefix_root=None, **kwargs):
    """Returns the name prefixed with its namescope path.

    Defaults classes are always given a name. Otherwise an unset or empty
    value is returned as-is.
    """
    if self._parent.tag == constants.DEFAULT:
      return self._defaults_string(prefix_root)
    if not self._value:
      return self._value
    parts = self._parent.namescope.full_prefix(prefix_root, as_list=True)
    parts.append(self._value)
    return constants.PREFIX_SEPARATOR.join(parts)
class Reference(_Attribute):
  """A string attribute that refers to another element's identifier.

  The value is either a name (`str`) or an `mjcf.Element`. An element value
  that has since been removed from the model is cleared lazily whenever the
  value is read or serialized.
  """

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior, reference_namespace):
    # Either a namespace string or another attribute whose value selects it.
    self._reference_namespace = reference_namespace
    super().__init__(name, required, parent, value, conflict_allowed,
                     conflict_behavior)

  def _check_dead_reference(self):
    """Clears the value if it points at an element that has been removed."""
    target = self._value
    if isinstance(target, base.Element) and target.is_removed:
      self.clear()

  @property
  def value(self):
    self._check_dead_reference()
    return super().value

  @value.setter
  def value(self, new_value):
    # Overriding only the getter discards the inherited setter, so re-use the
    # parent class's property setter explicitly.
    super(Reference, self.__class__).value.fset(self, new_value)

  @property
  def reference_namespace(self):
    """The namespace that referenced elements must belong to."""
    namespace = self._reference_namespace
    if not isinstance(namespace, _Attribute):
      return namespace
    # Indirect case: another attribute's value names the namespace, possibly
    # via an alias in INDIRECT_REFERENCE_ATTRIB.
    selector = namespace.value
    return constants.INDIRECT_REFERENCE_ATTRIB.get(selector, selector)

  def _assign(self, value):
    if not isinstance(value, (base.Element, str)):
      raise ValueError(
          'Expect a string or `mjcf.Element` value: got {}'.format(value))
    if not value:
      self.clear()
      return
    if isinstance(value, base.Element):
      value_namespace = value.spec.namespace.split(
          constants.NAMESPACE_SEPARATOR)[0]
      expected_namespace = self.reference_namespace
      if value_namespace != expected_namespace:
        raise ValueError(_INVALID_REFERENCE_TYPE.format(
            valid_type=expected_namespace,
            actual_type=value_namespace))
    self._value = value

  def _before_clear(self):
    """Also clears an indirect namespace attribute when dropping an element."""
    if (isinstance(self._value, base.Element)
        and isinstance(self._reference_namespace, _Attribute)):
      self._reference_namespace._force_clear()  # pylint: disable=protected-access

  def _defaults_string(self, prefix_root):
    """Generates the XML string if this is a reference to a defaults class.

    To prevent global defaults from clashing, all global defaults are turned
    into a properly named defaults class. Care is therefore needed when this
    attribute is not explicitly set: if the parent element can be traced up to
    a body with a nontrivial 'childclass', this attribute must be left
    undefined.

    Args:
      prefix_root: A `NameScope` object to be treated as root
        for the purpose of calculating the prefix.

    Returns:
      A string to be used in the generated XML, or None.
    """
    self._check_dead_reference()
    prefix = self._parent.namescope.full_prefix(prefix_root)
    if self._value:
      return prefix + self._value
    # Walk up the tree looking for an ancestor with a non-empty childclass.
    ancestor = self._parent.parent
    while ancestor is not None and not (
        hasattr(ancestor, constants.CHILDCLASS) and ancestor.childclass):
      ancestor = ancestor.parent
    if ancestor is not None:
      # Inside a childclass'd body: leave this attribute undefined.
      return None
    global_class = self._parent.root.default.dclass or ''
    return (prefix + global_class) or constants.PREFIX_SEPARATOR

  def to_xml_string(self, prefix_root, **kwargs):
    """Returns the scoped name of the referenced element, or None if unset."""
    self._check_dead_reference()
    target = self._value
    if isinstance(target, base.Element):
      return target.prefixed_identifier(prefix_root)
    if (self.reference_namespace == constants.DEFAULT
        and self._name != constants.CHILDCLASS):
      return self._defaults_string(prefix_root)
    if not target:
      return None
    return self._parent.namescope.full_prefix(prefix_root) + target
class BasePath(_Attribute):
  """A string attribute giving the base directory for one asset type.

  The path is recorded in the parent's namescope under `constants.BASEPATH`
  and `path_namespace`. `to_xml_string` always returns None for it.
  """

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior, path_namespace):
    self._path_namespace = path_namespace
    super().__init__(name, required, parent, value, conflict_allowed,
                     conflict_behavior)

  def _assign(self, value):
    if not isinstance(value, str):
      raise ValueError('Expect a string value: got {}'.format(value))
    if not value:
      self.clear()
      return
    self._parent.namescope.replace(
        constants.BASEPATH, self._path_namespace, value)
    self._value = value

  def _before_clear(self):
    """Removes the recorded base path from the namescope, if one was set."""
    if not self._value:
      return
    self._parent.namescope.remove(constants.BASEPATH, self._path_namespace)

  def to_xml_string(self, prefix_root=None, **kwargs):
    del prefix_root, kwargs  # Unused.
    return None
class BaseAsset:
  """Base class for binary assets.

  Subclasses must provide `contents` (as an attribute or property). It is
  passed through `util.to_binary_string` before hashing.
  """

  __slots__ = ('extension', 'prefix')

  def __init__(self, extension, prefix=''):
    self.extension = extension
    self.prefix = prefix

  def __eq__(self, other):
    """Assets are equal iff they map to the same VFS filename.

    Comparing with a non-asset returns NotImplemented, so Python falls back
    to its default comparison instead of raising AttributeError. Because
    `__eq__` is defined without `__hash__`, instances are unhashable.
    """
    if not isinstance(other, BaseAsset):
      return NotImplemented
    return self.get_vfs_filename() == other.get_vfs_filename()

  def get_vfs_filename(self):
    """Returns the name of the asset file as registered in MuJoCo's VFS.

    The name is `<prefix>-<sha1 of contents><extension>`, or
    `<sha1><extension>` when there is no prefix. The prefix is truncated from
    the end so the name fits within `constants.MAX_VFS_FILENAME_LENGTH`.
    """
    # Hash the contents of the asset to get a unique identifier.
    hash_string = hashlib.sha1(util.to_binary_string(self.contents)).hexdigest()
    # Prepend the prefix, if one exists.
    if self.prefix:
      prefix = self.prefix
      # +1 accounts for the '-' separator joining prefix and hash.
      raw_length = len(prefix) + len(hash_string) + len(self.extension) + 1
      if raw_length > constants.MAX_VFS_FILENAME_LENGTH:
        trim_amount = raw_length - constants.MAX_VFS_FILENAME_LENGTH
        prefix = prefix[:-trim_amount]
      filename = '-'.join([prefix, hash_string])
    else:
      filename = hash_string
    # An extension is needed because MuJoCo's compiler looks at this when
    # deciding how to load meshes and heightfields.
    return filename + self.extension
class Asset(BaseAsset):
  """A binary asset whose contents are held directly in memory."""

  __slots__ = ('contents',)

  def __init__(self, contents, extension, prefix=''):
    """Creates an in-memory asset.

    Args:
      contents: The contents of the file as a bytestring.
      extension: The file extension, including the dot (e.g. '.png', '.stl').
      prefix: (optional) A prefix applied to the filename given in MuJoCo's
        VFS.
    """
    super().__init__(extension, prefix)
    self.contents = contents
class SkinAsset(BaseAsset):
  """A skin asset, parsed on construction and re-serialized on demand.

  `contents` is cached and regenerated only when the parent's namescope
  revision advances past the cached one. Presumably the revision changes
  whenever the model is edited; verify against NameScope.
  """

  __slots__ = ('skin', 'parent', '_cached_revision', '_cached_contents')

  def __init__(self, contents, parent, extension, prefix=''):
    def find_body(body_name):
      # Resolves body names referenced by the skin against the parent's root.
      return parent.root.find('body', body_name)

    self.skin = skin.parse(contents, find_body)
    self.parent = parent
    # -1 forces serialization on the first `contents` access.
    self._cached_revision = -1
    self._cached_contents = None
    super().__init__(extension, prefix)

  @property
  def contents(self):
    stale = self._cached_revision < self.parent.namescope.revision
    if stale:
      self._cached_contents = skin.serialize(self.skin)
      self._cached_revision = self.parent.namescope.revision
    return self._cached_contents
class File(_Attribute):
  """Attribute referring to an asset file; the value is an asset object.

  The attribute registers itself in the parent namescope's `files` collection
  on construction.
  """

  def __init__(self, name, required, parent, value,
               conflict_allowed, conflict_behavior, path_namespace):
    self._path_namespace = path_namespace
    super().__init__(name, required, parent, value, conflict_allowed,
                     conflict_behavior)
    parent.namescope.files.add(self)

  def _assign(self, value):
    """Accepts a file path (`str`) or an `Asset`; a falsy value clears."""
    if not value:
      self.clear()
      return
    if isinstance(value, str):
      asset = self._get_asset_from_path(value)
    elif isinstance(value, Asset):
      asset = value
    else:
      raise ValueError('Expect either a string or `Asset` value: got {}'
                       .format(value))
    self._validate_extension(asset.extension)
    self._value = asset

  def _resolve_path(self, path):
    """Joins `path` onto the model directory and this asset type's base path."""
    namescope = self._parent.namescope
    parts = [namescope.model_dir] if namescope.model_dir else []
    try:
      parts.append(namescope.get(constants.BASEPATH, self._path_namespace))
    except KeyError:
      pass  # No base path (e.g. `meshdir`) configured for this asset type.
    parts.append(path)
    return os.path.join(*parts)  # pylint: disable=no-value-for-parameter

  def _get_asset_from_path(self, path):
    """Constructs an asset for `path`.

    The namescope's pre-loaded `assets` dict is consulted before the
    filesystem. A `SkinAsset` is returned for <skin> elements, otherwise an
    `Asset`. The file's base name (without extension) becomes the VFS prefix.
    """
    stem, extension = os.path.splitext(os.path.basename(path))
    try:
      contents = self._parent.namescope.assets[path]
    except KeyError:
      contents = resources.GetResource(self._resolve_path(path))
    if self._parent.tag == constants.SKIN:
      return SkinAsset(contents=contents, parent=self._parent,
                       extension=extension, prefix=stem)
    return Asset(contents=contents, extension=extension, prefix=stem)

  def _validate_extension(self, extension):
    """Rejects mesh files whose extension MuJoCo cannot load."""
    is_mesh = self._parent.tag == constants.MESH
    if is_mesh and extension.lower() not in _MESH_EXTENSIONS:
      raise ValueError(_INVALID_MESH_EXTENSION.format(extension))

  def get_contents(self):
    """Returns a bytestring representing the contents of the asset."""
    if self._value is not None:
      return self._value.contents
    raise RuntimeError('You must assign a value to this attribute before '
                       'querying the contents.')

  def to_xml_string(self, prefix_root=None, **kwargs):
    """Returns the asset filename as it will appear in the generated XML."""
    del prefix_root  # Unused
    asset = self._value
    return None if asset is None else asset.get_vfs_filename()