Skip to content

Commit 4e6426a

Browse files
saran-tcopybara-github
authored andcommitted
Fix rendering context cleanup logic.
Fixes #310. PiperOrigin-RevId: 472296006 Change-Id: I558c1fdf3753ed660f4463ae2dc054f496164c8a
1 parent 3ee9a41 commit 4e6426a

2 files changed

Lines changed: 27 additions & 3 deletions

File tree

dm_control/_render/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,13 @@ def thread(self):
8888
return self._render_executor.thread
8989

9090
def _free_on_executor_thread(self): # pylint: disable=missing-function-docstring
91-
if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] == id(self):
92-
del _CURRENT_THREAD_FOR_CONTEXT[id(self)]
93-
del _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
91+
current_ctx = _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
92+
if current_ctx is not None:
93+
del _CURRENT_THREAD_FOR_CONTEXT[current_ctx]
94+
del _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
95+
9496
self._platform_make_current()
97+
9598
try:
9699
dummy = []
97100
while self._patients:

dm_control/_render/base_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,27 @@ def test_free(self):
110110
self.assertNotIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT)
111111
self.assertNotIn(thread, base._CURRENT_CONTEXT_FOR_THREAD)
112112

113+
def test_free_with_multiple_contexts(self):
114+
context1 = ContextBaseTests.ContextMock(WIDTH, HEIGHT,
115+
executor.PassthroughRenderExecutor)
116+
with context1.make_current():
117+
pass
118+
119+
context2 = ContextBaseTests.ContextMock(WIDTH, HEIGHT,
120+
executor.PassthroughRenderExecutor)
121+
with context2.make_current():
122+
pass
123+
124+
self.assertEqual(base._CURRENT_CONTEXT_FOR_THREAD[threading.main_thread()],
125+
id(context2))
126+
self.assertIs(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)],
127+
threading.main_thread())
128+
129+
context1.free()
130+
self.assertIsNone(
131+
base._CURRENT_CONTEXT_FOR_THREAD[self.context.free_thread])
132+
self.assertIsNone(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)])
133+
113134
def test_refcounting(self):
114135
thread = self.context.thread
115136

0 commit comments

Comments
 (0)