# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tests for `mjcf.debugging`."""

import contextlib
import os
import re
import shutil
import sys

from absl.testing import absltest
from dm_control import mjcf
from dm_control.mjcf import code_for_debugging_test as test_code
from dm_control.mjcf import debugging

# Debug mode as it was when this module was imported; restored after each test.
ORIGINAL_DEBUG_MODE = debugging.debug_mode()


class DebuggingTest(absltest.TestCase):
  """Exercises stack-trace recording and error reporting of `mjcf.debugging`."""

  def tearDown(self):
    """Restores the debug mode that was active at import time."""
    super().tearDown()
    if ORIGINAL_DEBUG_MODE:
      debugging.enable_debug_mode()
    else:
      debugging.disable_debug_mode()

  def setup_debug_mode(self, debug_mode_enabled, full_dump_enabled=False):
    """Switches debug mode on or off and configures the full dump directory.

    Args:
      debug_mode_enabled: If true, debug mode is enabled; otherwise disabled.
      full_dump_enabled: If true, a fresh, empty `mjcf_debugging_test`
        directory is created under the default test tmpdir and registered as
        the full dump directory. Otherwise the dump directory is set to `''`.
    """
    if debug_mode_enabled:
      debugging.enable_debug_mode()
    else:
      debugging.disable_debug_mode()

    if full_dump_enabled:
      base_dir = absltest.get_default_test_tmpdir()
      self.dump_dir = os.path.join(base_dir, 'mjcf_debugging_test')
      # Start from an empty directory so stale dumps cannot satisfy the test.
      shutil.rmtree(self.dump_dir, ignore_errors=True)
      os.mkdir(self.dump_dir)
    else:
      self.dump_dir = ''
    debugging.set_full_dump_dir(self.dump_dir)

  def assertStackFromTestCode(self, stack, function_name, line_ref):
    """Asserts that the innermost frame of `stack` is a known test-code line.

    Args:
      stack: A sequence of stack entries; only the last one is inspected.
      function_name: Expected function name of the last stack entry.
      line_ref: Suffix which, joined to `function_name` with a '.', forms the
        key looked up in `test_code.LINE_REF`.
    """
    self.assertEqual(stack[-1].function_name, function_name)
    self.assertStartsWith(test_code.__file__, stack[-1].filename)
    line_info = test_code.LINE_REF['.'.join([function_name, line_ref])]
    self.assertEqual(stack[-1].line_number, line_info.line_number)
    self.assertEqual(stack[-1].text, line_info.text)

  @contextlib.contextmanager
  def assertRaisesTestCodeRef(self, line_ref):
    """Context manager expecting a `ValueError` that cites a test-code line.

    The error message must contain `<test code path>.py:<line number>`, where
    the line number is taken from `test_code.LINE_REF[line_ref]`.

    Args:
      line_ref: Key into `test_code.LINE_REF`.

    Yields:
      Nothing; the body of the `with` statement is expected to raise.
    """
    filename, _ = os.path.splitext(test_code.__file__)
    expected_message = (
        filename + '.py:' + str(test_code.LINE_REF[line_ref].line_number))
    print(expected_message)
    # The expected text is a file path, not a pattern: escape it so that '.'
    # and path separators such as '\' are matched literally.
    with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
      yield

  def test_get_current_stack_trace(self):
    """The innermost frame of the current stack trace is this test method."""
    self.setup_debug_mode(debug_mode_enabled=True)
    # NOTE: the text of the next line is asserted on below; keep it verbatim.
    stack_trace = debugging.get_current_stack_trace()
    self.assertStartsWith(
        sys.modules[__name__].__file__, stack_trace[-1].filename)
    self.assertEqual(stack_trace[-1].function_name,
                     'test_get_current_stack_trace')
    self.assertEqual(stack_trace[-1].text,
                     'stack_trace = debugging.get_current_stack_trace()')

  def test_disable_debug_mode(self):
    """No stacks are recorded while debug mode is disabled."""
    self.setup_debug_mode(debug_mode_enabled=False)
    mjcf_model = test_code.make_valid_model()
    test_code.break_valid_model(mjcf_model)
    self.assertFalse(mjcf_model.get_init_stack())

    my_actuator = mjcf_model.find('actuator', 'my_actuator')
    my_actuator_attrib_stacks = (
        my_actuator.get_last_modified_stacks_for_all_attributes())
    for stack in my_actuator_attrib_stacks.values():
      self.assertFalse(stack)

  def test_element_and_attribute_stacks(self):
    """Element and attribute stacks point at the test-code lines that set them."""
    self.setup_debug_mode(debug_mode_enabled=True)
    mjcf_model = test_code.make_valid_model()
    test_code.break_valid_model(mjcf_model)

    self.assertStackFromTestCode(mjcf_model.get_init_stack(),
                                 'make_valid_model', 'mjcf_model')

    my_actuator = mjcf_model.find('actuator', 'my_actuator')
    self.assertStackFromTestCode(my_actuator.get_init_stack(),
                                 'make_valid_model', 'my_actuator')

    my_actuator_attrib_stacks = (
        my_actuator.get_last_modified_stacks_for_all_attributes())

    # `name` attribute was assigned at the same time as the element was created.
    self.assertEqual(my_actuator_attrib_stacks['name'],
                     my_actuator.get_init_stack())

    # `joint` attribute was modified later on.
    self.assertStackFromTestCode(my_actuator_attrib_stacks['joint'],
                                 'break_valid_model', 'my_actuator.joint')

  def test_valid_physics(self):
    """A valid model compiles without error while debug mode is enabled."""
    self.setup_debug_mode(debug_mode_enabled=True)
    mjcf_model = test_code.make_valid_model()
    mjcf.Physics.from_mjcf_model(mjcf_model)  # Should not raise

  def test_physics_error_message_outside_of_debug_mode(self):
    """The error for a broken model mentions the `--pymjcf_debug` flag."""
    self.setup_debug_mode(debug_mode_enabled=False)
    mjcf_model = test_code.make_broken_model()
    # Make sure that we advertise debug mode if it's currently disabled.
    with self.assertRaisesRegex(ValueError, '--pymjcf_debug'):
      mjcf.Physics.from_mjcf_model(mjcf_model)

  def test_physics_error_message_in_debug_mode(self):
    """In debug mode, errors cite the test-code line that broke the model."""
    self.setup_debug_mode(debug_mode_enabled=True)
    mjcf_model_1 = test_code.make_broken_model()
    with self.assertRaisesTestCodeRef('make_broken_model.my_actuator'):
      mjcf.Physics.from_mjcf_model(mjcf_model_1)

    mjcf_model_2 = test_code.make_valid_model()
    physics = mjcf.Physics.from_mjcf_model(mjcf_model_2)  # Should not raise.
    test_code.break_valid_model(mjcf_model_2)
    with self.assertRaisesTestCodeRef('break_valid_model.my_actuator.joint'):
      physics.reload_from_mjcf_model(mjcf_model_2)

  def test_full_debug_dump(self):
    """Every instrumented line of the dumped XML has a matching `.dump` file."""
    self.setup_debug_mode(debug_mode_enabled=True, full_dump_enabled=False)
    mjcf_model = test_code.make_valid_model()
    test_code.break_valid_model(mjcf_model)
    # Make sure that we advertise full dump mode if it's currently disabled.
    with self.assertRaisesRegex(ValueError, '--pymjcf_debug_full_dump_dir'):
      mjcf.Physics.from_mjcf_model(mjcf_model)

    self.setup_debug_mode(debug_mode_enabled=True, full_dump_enabled=True)
    with self.assertRaises(ValueError):
      mjcf.Physics.from_mjcf_model(mjcf_model)

    with open(os.path.join(self.dump_dir, 'model.xml')) as f:
      dumped_xml = f.read()
    dumped_xml = [line.strip() for line in dumped_xml.strip().split('\n')]

    # NOTE(review): in the copy of this file under review the markup inside
    # the two patterns below had been stripped: the first was r'^(.*)$', which
    # has no group 2 although group(2) is read further down, and the four
    # alternatives of the second were all empty strings. They are reconstructed
    # here on the assumption that instrumented lines end in an XML comment of
    # the form `<!--<prefix>:<integer id>-->` and that only the `mujoco` and
    # `default` tags are left uninstrumented -- confirm against
    # `mjcf.debugging` before relying on the exact text.
    xml_line_pattern = re.compile(r'^(.*)<!--\w+:(\d+)-->$')
    uninstrumented_pattern = re.compile(r'({})'.format(
        '|'.join([
            r'<mujoco model=".*">',
            r'</mujoco>',
            r'<default class=".*"/?>',
            r'</default>',
        ])))

    for xml_line in dumped_xml:
      print(xml_line)
      xml_line_match = xml_line_pattern.match(xml_line)
      if not xml_line_match:
        # Only uninstrumented lines are allowed to have no metadata.
        self.assertIsNotNone(uninstrumented_pattern.match(xml_line))
      else:
        xml_element = xml_line_match.group(1)
        debug_id = int(xml_line_match.group(2))
        # Each debug id names a `<id>.dump` file inside the dump directory,
        # which must contain the XML text of the element it describes.
        with open(os.path.join(self.dump_dir, str(debug_id) + '.dump')) as f:
          element_dump = f.read()
        self.assertIn(xml_element, element_dump)


if __name__ == '__main__':
  absltest.main()