# Copyright 2017 The dm_control Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for the base rendering module.""" import threading from absl.testing import absltest from dm_control._render import base from dm_control._render import executor WIDTH = 1024 HEIGHT = 768 class ContextBaseTests(absltest.TestCase): class ContextMock(base.ContextBase): def _platform_init(self, max_width, max_height): self.init_thread = threading.current_thread() self.make_current_count = 0 self.max_width = max_width self.max_height = max_height self.free_thread = None def _platform_make_current(self): self.make_current_count += 1 self.make_current_thread = threading.current_thread() def _platform_free(self): self.free_thread = threading.current_thread() def setUp(self): super().setUp() self.context = ContextBaseTests.ContextMock(WIDTH, HEIGHT) def test_init(self): self.assertIs(self.context.init_thread, self.context.thread) self.assertEqual(self.context.max_width, WIDTH) self.assertEqual(self.context.max_height, HEIGHT) def test_make_current(self): self.assertEqual(self.context.make_current_count, 0) with self.context.make_current(): pass self.assertEqual(self.context.make_current_count, 1) self.assertIs(self.context.make_current_thread, self.context.thread) # Already current, shouldn't trigger a call to `_platform_make_current`. with self.context.make_current(): pass self.assertEqual(self.context.make_current_count, 1) self.assertIs(self.context.make_current_thread, self.context.thread) def test_thread_sharing(self): first_context = ContextBaseTests.ContextMock( WIDTH, HEIGHT, executor.PassthroughRenderExecutor) second_context = ContextBaseTests.ContextMock( WIDTH, HEIGHT, executor.PassthroughRenderExecutor) with first_context.make_current(): pass self.assertEqual(first_context.make_current_count, 1) with first_context.make_current(): pass self.assertEqual(first_context.make_current_count, 1) with second_context.make_current(): pass self.assertEqual(second_context.make_current_count, 1) with second_context.make_current(): pass self.assertEqual(second_context.make_current_count, 1) with first_context.make_current(): pass self.assertEqual(first_context.make_current_count, 2) with second_context.make_current(): pass self.assertEqual(second_context.make_current_count, 2) def test_free(self): with self.context.make_current(): pass thread = self.context.thread self.assertIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT) self.assertIn(thread, base._CURRENT_CONTEXT_FOR_THREAD) self.context.free() self.assertIs(self.context.free_thread, thread) self.assertIsNone(self.context.thread) self.assertNotIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT) self.assertNotIn(thread, base._CURRENT_CONTEXT_FOR_THREAD) def test_free_with_multiple_contexts(self): context1 = ContextBaseTests.ContextMock(WIDTH, HEIGHT, executor.PassthroughRenderExecutor) with context1.make_current(): pass context2 = ContextBaseTests.ContextMock(WIDTH, HEIGHT, executor.PassthroughRenderExecutor) with context2.make_current(): pass self.assertEqual(base._CURRENT_CONTEXT_FOR_THREAD[threading.main_thread()], id(context2)) self.assertIs(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)], threading.main_thread()) context1.free() self.assertIsNone( base._CURRENT_CONTEXT_FOR_THREAD[self.context.free_thread]) self.assertIsNone(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)]) def test_refcounting(self): thread = self.context.thread self.assertEqual(self.context._refcount, 0) self.context.increment_refcount() self.assertEqual(self.context._refcount, 1) # Context should not be freed yet, since its refcount is still positive. self.context.free() self.assertIsNone(self.context.free_thread) self.assertIs(self.context.thread, thread) # Decrement the refcount to zero. self.context.decrement_refcount() self.assertEqual(self.context._refcount, 0) # Now the context can be freed. self.context.free() self.assertIs(self.context.free_thread, thread) self.assertIsNone(self.context.thread) def test_del(self): self.assertIsNone(self.context.free_thread) self.context.__del__() self.assertIsNotNone(self.context.free_thread) if __name__ == '__main__': absltest.main()