-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_gmemo.py
More file actions
350 lines (315 loc) · 14.1 KB
/
test_gmemo.py
File metadata and controls
350 lines (315 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# -*- coding: utf-8 -*-
from ..syntax import macros, test, test_raises, fail, the # noqa: F401
from ..test.fixtures import session, testset
from itertools import count, takewhile, chain
from collections import Counter
from ..gmemo import gmemoize, imemoize, fimemoize
from ..it import take, drop, last
from ..fold import prod
from ..funutil import call
from ..misc import timer
def runtests():
    """Run the tests for the generator memoizers `gmemoize`, `imemoize` and `fimemoize`.

    Each `with testset(...)` section groups related assertions. The `test[...]`,
    `test_raises[...]`, `fail[...]` and `the[...]` constructs are macros imported
    from `..syntax`, not ordinary subscript expressions.

    The prime generators defined in the sieve section are reused by the
    benchmark section at the end, so all sections share this function's scope.
    """
    with testset("multiple instances, interleaved"):
        # Counts how many times the body of the underlying generator has
        # actually produced a value, across all memoized instances.
        total_evaluations = 0

        @gmemoize
        def gen():
            nonlocal total_evaluations
            j = 1
            while True:
                total_evaluations += 1
                yield j
                j += 1

        # Two instances advanced in an interleaved order; each one sees the
        # full sequence from its own beginning.
        g1 = gen()
        g2 = gen()
        test[next(g1) == 1]
        test[next(g1) == 2]
        test[next(g2) == 1]
        test[next(g1) == 3]
        test[next(g2) == 2]
        test[next(g2) == 3]
        test[next(g2) == 4]
        test[next(g1) == 4]
        # An instance created later replays the same items.
        g3 = gen()
        test[next(g3) == 1]
        test[next(g3) == 2]
        test[next(g3) == 3]
        test[next(g3) == 4]
        # Three instances consumed four items each, but the generator body
        # produced only four items in total.
        test[total_evaluations == 4]

    with testset("multiple instances, exhaust one first"):
        total_evaluations = 0

        @gmemoize
        def gen():
            nonlocal total_evaluations
            for j in range(3):
                total_evaluations += 1
                yield j

        g1 = gen()
        g2 = gen()
        # Merely creating instances does not run the generator body.
        test[total_evaluations == 0]
        test[tuple(x for x in g1) == (0, 1, 2)]
        test[total_evaluations == 3]
        # The second instance yields the same items without any new evaluations.
        test[tuple(x for x in g2) == (0, 1, 2)]
        test[total_evaluations == 3]

    with testset("@gmemoize caches exceptions"):
        class AllOkJustTesting(Exception):
            pass

        total_evaluations = 0

        @gmemoize
        def gen():
            nonlocal total_evaluations
            total_evaluations += 1
            yield 1
            total_evaluations += 1
            raise AllOkJustTesting("ha ha only serious")

        g1 = gen()
        test[total_evaluations == 0]
        try:
            next(g1)
            test[total_evaluations == 1]
            next(g1)  # the generator body raises here
        except AllOkJustTesting as err:
            # Keep the instance so we can check identity against the replayed one below.
            exc_instance = err
        else:
            fail["Should have raised at the second next() call."]  # pragma: no cover
        test[total_evaluations == 2]

        g2 = gen()
        next(g2)
        test[total_evaluations == 2]  # still just two, it's memoized
        try:
            next(g2)
        except AllOkJustTesting as err2:
            test[the[err2 is exc_instance], "should be the same cached exception instance"]
        else:
            fail["Should have raised at the second next() call."]  # pragma: no cover
        # Replaying the exception did not re-run the generator body.
        test[total_evaluations == 2]

    with testset("subscripting to get already computed items"):
        @gmemoize
        def gen():
            yield from range(5)

        g3 = gen()
        # Any item that has entered the memo can be retrieved by subscripting.
        # len() is the current length of the memo.
        test[len(g3) == 0]
        next(g3)
        test[len(g3) == 1]
        next(g3)
        test[len(g3) == 2]
        next(g3)
        test[len(g3) == 3]
        test[g3[0] == 0]
        test[g3[1] == 1]
        test[g3[2] == 2]
        # Items not yet memoized cannot be retrieved from the memo.
        test_raises[IndexError, g3[3]]
        # Negative indices work too, counting from the current end of the memo.
        test[g3[-1] == 2]
        test[g3[-2] == 1]
        test[g3[-3] == 0]
        # Counting back past the start is an error, just like in `list`.
        test_raises[IndexError, g3[-4]]
        # Slicing is supported.
        test[g3[0:3] == [0, 1, 2]]
        test[g3[0:2] == [0, 1]]
        test[g3[::-1] == [2, 1, 0]]
        test[g3[0::2] == [0, 2]]
        test[g3[2::-2] == [2, 0]]
        # Out-of-range slices produce the empty list, like in `list`.
        test[g3[3:] == []]
        test[g3[-4::-1] == []]

    with testset("memoizing a sequence partially"):
        # To do this, build a chain of generators, then memoize only the last one:
        evaluations = Counter()  # n -> how many times the body of some_evens(n) was started

        def orig():
            yield from range(100)

        def evens():
            yield from (x for x in orig() if x % 2 == 0)

        @gmemoize
        def some_evens(n):  # drop n first terms
            evaluations[n] += 1
            yield from drop(n, evens())

        last(some_evens(25))
        last(some_evens(25))
        last(some_evens(20))
        # Each distinct n was evaluated exactly once, although n=25 was iterated twice.
        test[all(v == 1 for k, v in the[evaluations].items())]

        # Or use lambda for a more compact presentation:
        se = gmemoize(lambda n: (yield from drop(n, evens())))
        test[the[last(se(25))] == the[last(se(25))]]  # iterating twice!

        # Using fimemoize, we can omit the "yield from" (specifying a regular
        # factory function that makes an iterable, instead of a gfunc):
        se = fimemoize(lambda n: drop(n, evens()))
        test[the[last(se(25))] == the[last(se(25))]]  # iterating twice!

        # In the nonparametric case, we can memoize the iterable directly:
        se = imemoize(drop(25, evens()))
        test[the[last(se())] == the[last(se())]]  # iterating twice!

        # DANGER: WRONG! Now we get a new instance of evens() also for the same n,
        # so each call to se(n) caches separately. (This is why we have fimemoize.)
        #
        # The assertions below compare only the produced values, so they pass
        # even though nothing is shared between the calls.
        se = lambda n: call(imemoize(drop(n, evens())))  # call() invokes the gfunc
        test[the[last(se(25))] == the[last(se(25))]]
        test[the[last(se(20))] == the[last(se(20))]]

    with testset("FP sieve of Eratosthenes"):
        # Each variant below tests a candidate n by trial division with the
        # already known primes p for which p * p <= n.
        def primes():  # no memoization, recomputes unnecessarily, very slow
            yield 2
            for n in count(start=3, step=2):
                if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, primes())):
                    yield n

        @gmemoize  # <-- the only change (beside the function name)
        def mprimes():  # external memo for users, re-use it internally - simplest code
            yield 2
            for n in count(start=3, step=2):
                if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, mprimes())):
                    yield n

        def memo_primes():  # manual internal memo only - fastest, no caching for users
            memo = []  # primes found so far, private to this one generator instance

            def manual_mprimes():
                memo.append(2)
                yield 2
                for n in count(start=3, step=2):
                    if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, memo)):
                        memo.append(n)
                        yield n
            return manual_mprimes()

        # external memo for users, separate manual internal memo
        # doubles memory usage due to exactly one internal memo; almost as fast as memo_primes
        # since the tight inner loop skips the very general gmemoize machinery
        #
        # This version wins in speed for moderate n (1e5) on typical architectures where
        # the memory bus is a bottleneck, since the rule for generating new candidates is
        # simple arithmetic. Contrast memo_primes3, which needs to keep a table that gets
        # larger as n grows (so memory transfers dominate for large n). That strategy
        # seems faster for n ~ 1e3, though.
        @gmemoize
        def memo_primes2():
            memo = []

            def manual_mprimes2():
                memo.append(2)
                yield 2
                for n in count(start=3, step=2):
                    if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, memo)):
                        memo.append(n)
                        yield n
            return manual_mprimes2()

        # small refinement: skip testing 15, 25, 35, ...
        #   - we know that in base-10, for any prime > 10 the last digit must be 1, 3, 7 or 9;
        #     if it is even (0, 2, 4, 6, 8) or 5, the number is divisible by at least one
        #     factor of 10 (namely 2 or 5)
        #   - n < 10 must be checked separately; the primes are 2, 3, 5, 7
        #     (note the factors of 10 are there, plus some unrelated primes)
        @gmemoize
        def mprimes2():
            yield 2
            for n in chain([3, 5, 7], (d + k for d in count(10, step=10)
                                       for k in [1, 3, 7, 9])):
                if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, mprimes2())):
                    yield n

        # generalization: let's not be limited by base-10
        # base-b representation, switch b when appropriate:
        #   n = k*b + m
        #   b = 2*3, 2*3*5, 2*3*5*7, ...
        #   k: integer, 1, 2, ..., {next factor to account for in b} - 1
        #      so e.g. when b=6, we check from 6 to 29; when b=30, from 30 to 209, ...
        #   m: last digit in base-b representation of n, note m < b
        # for a number represented in base-b to be prime, m must not be divisible by any factor of b
        # Only the numbers up to b must be checked separately (and they already have been
        # by the time we reach the next b).
        #
        # For the first 5e4 primes, about 20% of the integers within each range are candidates.
        # If you want the details, add this just before "for n in ns:":
        #     print(b, ns[-1]**(1/2), len(ns), (nextp-1)*b, len(ns)/((nextp-1)*b))
        @gmemoize
        def mprimes3():
            # minimal init takes three terms; b = 2*3 = 6 > 5, so no overlap in output of init and general loop
            # (and this init yields all primes up to b = 6)
            yield from (2, 3, 5)
            # A second instance of this same memoized generator, used to read back
            # the primes that become the factors of b.
            theprimes = mprimes3()
            ps = list(take(2, theprimes))  # factors of b; b is chosen s.t. each factor is a different prime
            # p: the factor most recently included in b; np: how many factors b has.
            p, b, np = ps[-1], prod(ps), len(ps)
            lastdigits = [1, 3, 5]  # last digits in base-6 that are not divisible by 2
            while True:
                nextp = next(theprimes)
                # Eliminate the last digits divisible by the newest factor of b.
                lastdigits = [n for n in lastdigits if n % p != 0]
                ns = [k * b + m for k in range(1, nextp)
                                for m in lastdigits]
                # in ns, we have already eliminated the first np primes as possible factors, so skip checking them
                for n in ns:
                    if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, drop(np, mprimes3()))):
                        yield n
                # Switch to the next base: multiply in the next prime.
                ps.append(nextp)
                b *= nextp
                p = nextp
                np += 1
                # The candidates just checked become additional last digits in the new, larger base.
                lastdigits = lastdigits + ns
        test[the[tuple(take(500, mprimes3())) == tuple(take(500, mprimes2()))]]  # de-spam: don't capture the LHS.

        # Same algorithm as mprimes3, but with a manual internal memo (a plain list)
        # in place of the recursive use of the memoized generator.
        @gmemoize
        def memo_primes3():
            memo = []

            def manual_mprimes3():
                for p in (2, 3, 5):
                    memo.append(p)
                    yield p
                p, b, np = 3, 6, 2
                lastdigits = [1, 3, 5]
                while True:
                    nextp = memo[np]  # the next prime to multiply into b
                    lastdigits = [n for n in lastdigits if n % p != 0]
                    ns = [k * b + m for k in range(1, nextp)
                                    for m in lastdigits]
                    for n in ns:
                        if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, drop(np, memo))):
                            memo.append(n)
                            yield n
                    b *= nextp
                    p = nextp
                    np += 1
                    lastdigits += ns
            return manual_mprimes3()
        test[the[tuple(take(500, memo_primes3())) == tuple(take(500, mprimes2()))]]

        # Like memo_primes3, but stops growing b (and the table of last digits)
        # once b has maxnp factors, and from then on steps through multiples of that final b.
        @gmemoize
        def memo_primes4():
            memo = []

            def manual_mprimes4():
                for p in (2, 3, 5):
                    memo.append(p)
                    yield p
                p, b, np = 3, 6, 2
                lastdigits = [1, 3, 5]
                maxnp = 5  # --> b = 2*3*5*7*11 = 2310; optimal setting depends on CPU cache size
                while True:
                    nextp = memo[np]
                    lastdigits = [n for n in lastdigits if n % p != 0]
                    ns = [k * b + m for k in range(1, nextp)
                                    for m in lastdigits]
                    for n in ns:
                        if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, drop(np, memo))):
                            memo.append(n)
                            yield n
                    if np == maxnp:  # avoid table becoming too big (leading to memory bus dominated run time)
                        break
                    b *= nextp
                    p = nextp
                    np += 1
                    lastdigits += ns
                # once maximum b reached, stay at that b, using the final table of lastdigits
                # (the loop above already covered the candidates below nextp * b)
                for kb in count(nextp * b, step=b):
                    for n in (kb + m for m in lastdigits):
                        if not any(n % p == 0 for p in takewhile(lambda x: x * x <= n, drop(np, memo))):
                            memo.append(n)
                            yield n
            return manual_mprimes4()
        test[the[tuple(take(500, memo_primes4())) == tuple(take(500, mprimes2()))]]
        test[last(take(5000, memo_primes4())) == 48611]  # trigger the maxnp case

        # Sanity check: the variants agree on the first ten primes.
        test[tuple(take(10, primes())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, mprimes())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, memo_primes())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, mprimes2())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, memo_primes2())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, mprimes3())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]
        test[tuple(take(10, memo_primes3())) == (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)]

    # TODO: need some kind of benchmarking tools to do this properly.
    with testset("performance benchmark"):
        # Informational only: prints one timing per variant and asserts nothing.
        # The unmemoized primes() is not included here.
        n = 2500
        print(f"Performance for first {n:d} primes:")
        for g in (mprimes(), memo_primes(), mprimes2(), memo_primes2(), mprimes3(), memo_primes3(), memo_primes4()):
            with timer() as tictoc:
                last(take(n, g))
            print(g, tictoc.dt)
# Script entry point: run the tests inside a `session` (from `..test.fixtures`)
# that is given this file's path.
if __name__ == '__main__':  # pragma: no cover
    with session(__file__):
        runtests()