-
Notifications
You must be signed in to change notification settings - Fork 745
Expand file tree
/
Copy pathcopier_test.py
More file actions
83 lines (69 loc) · 3.31 KB
/
copier_test.py
File metadata and controls
83 lines (69 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for `dm_control.mjcf.copier`."""
import os
from absl.testing import absltest
from dm_control import mjcf
from dm_control.mjcf import parser
import numpy as np
# Directory holding the XML fixtures, resolved relative to this test file.
_ASSETS_DIR = os.path.join(os.path.split(__file__)[0], 'test_assets')
# Individual fixture model files inside `_ASSETS_DIR`.
_TEST_MODEL_XML, _MODEL_WITH_ASSETS_XML = (
    os.path.join(_ASSETS_DIR, fixture_name)
    for fixture_name in ('test_model.xml', 'model_with_assets.xml'))
class CopierTest(absltest.TestCase):
  """Tests for `include_copy`: attributes, references and assets."""

  def testSimpleCopy(self):
    """Merges compiler attributes and bodies, enforcing conflict rules."""
    target = parser.from_path(_TEST_MODEL_XML)
    mixin = mjcf.RootElement(model='test_mixin')

    # An attribute set only on the mixin is merged into the target.
    mixin.compiler.boundmass = 1
    target.include_copy(mixin)
    # The target keeps its own model name after the merge.
    self.assertEqual(target.model, 'test')
    self.assertEqual(target.compiler.boundmass, mixin.compiler.boundmass)

    # A second attribute is picked up on a repeated include.
    mixin.compiler.boundinertia = 2
    target.include_copy(mixin)
    self.assertEqual(target.compiler.boundinertia,
                     mixin.compiler.boundinertia)

    # Changing an attribute the target already holds must be rejected...
    mixin.compiler.boundinertia = 1
    with self.assertRaisesRegex(ValueError, 'Conflicting values'):
      target.include_copy(mixin)

    # ...unless override_attributes=True, after which the mixin's values win.
    mixin.worldbody.add('body', name='b_0', pos=[0, 1, 2])
    target.include_copy(mixin, override_attributes=True)
    self.assertEqual(target.compiler.boundmass, mixin.compiler.boundmass)
    self.assertEqual(target.compiler.boundinertia,
                     mixin.compiler.boundinertia)
    copied_body = target.worldbody.body['b_0']
    np.testing.assert_array_equal(copied_body.pos, [0, 1, 2])

  def testCopyingWithReference(self):
    """Element references inside the mixin are rebound to the copies."""
    mixin = mjcf.RootElement('sensor_mixin')
    original_site = mixin.worldbody.add('site', name='touch_site')
    mixin.sensor.add('touch', name='touch_sensor', site=original_site)

    target = mjcf.RootElement('model')
    target.include_copy(mixin)

    # The copied sensor's `site` must be the copied site in the target.
    referenced_site = target.find('sensor', 'touch_sensor').site
    copied_site = target.find('site', 'touch_site')
    self.assertIs(referenced_site, copied_site)

  def testCopyingWithAssets(self):
    """Copied asset elements share the very same `file` objects."""
    source = parser.from_path(_MODEL_WITH_ASSETS_XML)
    target = mjcf.RootElement()
    target.include_copy(source)

    def gather_assets(model):
      # Meshes, then textures, then height fields -- the order matters
      # because the two lists are compared pairwise below.
      gathered = []
      for tag in ('mesh', 'texture', 'hfield'):
        gathered += model.find_all(tag)
      return gathered

    original_assets = gather_assets(source)
    copied_assets = gather_assets(target)

    self.assertLen(copied_assets, len(original_assets))
    for original, duplicate in zip(original_assets, copied_assets):
      self.assertIs(duplicate.file, original.file)
if '__main__' == __name__:
  # Run the absltest runner only when this file is executed as a script.
  absltest.main()