Skip to content

Commit f1551f4

Browse files
committed
memoize: make thread-safe
1 parent eb52f09 commit f1551f4

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ The same applies if you need the macro parts of `unpythonic` (i.e. import anythi
204204

205205
- Fix bug: `fup`/`fupdate`/`ShadowedSequence` now actually accept an infinite-length iterable as a replacement sequence (under the obvious usage limitations), as the documentation has always claimed.
206206

207+
- Fix bug: `memoize` is now thread-safe.
208+
207209

208210
---
209211

unpythonic/fun.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from collections import namedtuple
2121
from functools import wraps, partial as functools_partial
2222
from inspect import signature
23+
from threading import RLock
2324
from typing import get_type_hints
2425

2526
from .arity import (_resolve_bindings, tuplify_bindings, _bind)
@@ -61,18 +62,35 @@ def memoize(f):
6162
6263
**CAUTION**: ``f`` must be pure (no side effects, no internal state
6364
preserved between invocations) for this to make any sense.
65+
66+
Beginning with v0.15.0, `memoize` is thread-safe even when the same memoized
67+
function instance is called concurrently from multiple threads. Exactly one
68+
thread will compute the result. If `f` is recursive, the thread that acquired
69+
the lock is the one that is allowed to recurse into the memoized `f`.
6470
"""
71+
# One lock per use site of `memoize`. We use an `RLock` to allow recursive calls
72+
# to the memoized `f` in the thread that acquired the lock.
73+
lock = RLock()
6574
memo = {}
6675
@wraps(f)
6776
def memoized(*args, **kwargs):
6877
k = tuplify_bindings(_resolve_bindings(f, args, kwargs, _partial=False))
69-
if k not in memo:
70-
try:
71-
result = (_success, maybe_force_args(f, *args, **kwargs))
72-
except BaseException as err:
73-
result = (_fail, err)
74-
memo[k] = result # should yell separately if k is not a valid key
75-
kind, value = memo[k]
78+
try: # EAFP to eliminate TOCTTOU.
79+
kind, value = memo[k]
80+
except KeyError:
81+
# But we still need to be careful to avoid race conditions.
82+
with lock:
83+
if k not in memo:
84+
# We were the first thread to acquire the lock.
85+
try:
86+
result = (_success, maybe_force_args(f, *args, **kwargs))
87+
except BaseException as err:
88+
result = (_fail, err)
89+
memo[k] = result # should yell separately if k is not a valid key
90+
else:
91+
# Some other thread acquired the lock before us.
92+
pass
93+
kind, value = memo[k]
7694
if kind is _fail:
7795
raise value
7896
return value

unpythonic/tests/test_fun.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
from collections import Counter
77
import sys
8+
from queue import Queue
9+
import threading
10+
from time import sleep
811

912
from ..dispatch import generic
1013
from ..fun import (memoize, partial, curry, apply,
@@ -16,6 +19,8 @@
1619
to1st, to2nd, tokth, tolast, to,
1720
withself)
1821
from ..funutil import Values
22+
from ..it import allsame
23+
from ..misc import slurp
1924

2025
from ..dynassign import dyn
2126

@@ -135,6 +140,36 @@ def t():
135140
fail["memoize should not prevent exception propagation."] # pragma: no cover
136141
test[evaluations == 1]
137142

143+
with testset("@memoize thread-safety"):
144+
def threadtest():
145+
@memoize
146+
def f(x):
147+
# Sleep a "long" time to make actual concurrent operation more likely.
148+
sleep(0.001)
149+
150+
# The trick here is that because only one thread will compute the result
151+
# for the memo, all the results for the same `x` should be the same.
152+
return (id(threading.current_thread()), x)
153+
154+
comm = Queue()
155+
def worker(que):
156+
# The value of `x` doesn't matter, as long as it's the same in all workers.
157+
r = f(42)
158+
que.put(r)
159+
160+
n = 1000
161+
threads = [threading.Thread(target=worker, args=(comm,), kwargs={}) for _ in range(n)]
162+
for t in threads:
163+
t.start()
164+
for t in threads:
165+
t.join()
166+
167+
# Test that all threads finished, and that the results from each thread are the same.
168+
results = slurp(comm)
169+
test[the[len(results)] == the[n]]
170+
test[allsame(results)]
171+
threadtest()
172+
138173
with testset("partial (type-checking wrapper)"):
139174
def nottypedfunc(x):
140175
return "ok"

0 commit comments

Comments
 (0)