-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathfold.py
More file actions
409 lines (327 loc) · 14.5 KB
/
fold.py
File metadata and controls
409 lines (327 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
# -*- coding: utf-8 -*-
"""Folds, scans (lazy partial folds), and unfold.
For more batteries for itertools, see also the ``unpythonic.it`` module.
Racket-like multi-input ``foldl`` and ``foldr`` based on
https://docs.racket-lang.org/reference/pairs.html
``scanl`` and ``scanr`` inspired by ``itertools.accumulate``, Haskell,
and (stream-scan) in SRFI-41.
https://srfi.schemers.org/srfi-41/srfi-41.html
"""
# Names exported by ``from <this module> import *``: the scans, the folds,
# their reverse-input variants, the unfolds, and the small fold/scan-based
# utilities defined in this module.
__all__ = ["scanl", "scanr", "scanl1", "scanr1",
           "foldl", "foldr", "reducel", "reducer",
           "rscanl", "rscanl1", "rfoldl", "rreducel", # reverse each input, then left-scan/fold
           "unfold", "unfold1",
           "prod",
           "running_minmax", "minmax"]
from functools import partial
from itertools import zip_longest
from operator import mul
#from collections import deque
from .funutil import Values
#from .it import first, last, rev
from .it import last, rev
# Require at least one iterable to make this work seamlessly with curry. We take
# this approach with any new function families the standard library doesn't provide.
def scanl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Left scan (a.k.a. accumulate) over one or more iterables.

    Works like ``itertools.accumulate``, except that several input iterables
    may be given. At least one iterable (``iterable0``) is mandatory.

    The initial value ``init`` is mandatory as well; no sane default exists
    for the case with multiple inputs.

    By default, the scan ends when the shortest input runs out. To continue
    until the longest input runs out, pass ``longest=True``, and optionally
    a ``fillvalue`` to stand in for the missing elements.

    With iterator inputs, this is essentially a lazy ``foldl`` that hands out
    each intermediate result as it is computed. This makes it useful for
    partially folding infinite sequences (in the mathematical sense of
    "sequence").

    This is a generator function, which (roughly) does::

        acc = init
        yield acc
        for elts in zip(iterable0, *iterables):  # or zip_longest as appropriate
            acc = proc(*elts, acc)
            yield acc

    Example - partial sums and products::

        from operator import add, mul
        psums = composer(tail, curry(scanl, add, 0))  # tail to drop the init value
        pprods = composer(tail, curry(scanl, mul, 1))
        data = range(1, 5)
        assert tuple(psums(data)) == (1, 3, 6, 10)
        assert tuple(pprods(data)) == (1, 2, 6, 24)
    """
    if longest:
        zipper = partial(zip_longest, fillvalue=fillvalue)
    else:
        zipper = zip
    # The initial value is always the first item of the output.
    result = init
    yield result
    for elts in zipper(iterable0, *iterables):
        # The accumulator goes last: proc(e1, ..., en, acc).
        result = proc(*elts, result)
        yield result
def scanr(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Right scan; the dual of ``scanl``.

    Example::

        from operator import add
        assert tuple(scanl(add, 0, range(1, 5))) == (0, 1, 3, 6, 10)
        assert tuple(scanr(add, 0, range(1, 5))) == (0, 4, 7, 9, 10)

    **CAUTION**: The output order is not the same as in Haskell's ``scanr``;
    the results are yielded in the order in which they are computed (via a
    linear process).

    When there are several input iterables, *corresponding elements* are
    determined by syncing the **left** ends.

    Note difference between *l, *r and r*l where * = fold, scan::

        def append_tuple(a, b, acc):
            return acc + ((a, b),)

        # foldl: left-fold
        assert foldl(append_tuple, (), (1, 2, 3), (4, 5)) == ((1, 4), (2, 5))

        # foldr: right-fold
        assert foldr(append_tuple, (), (1, 2, 3), (4, 5)) == ((2, 5), (1, 4))

        # rfoldl: reverse each input, then left-fold
        assert rfoldl(append_tuple, (), (1, 2, 3), (4, 5)) == ((3, 5), (2, 4))
    """
    # Strategy (a linear process): zip the inputs, which syncs their left
    # ends; reverse the zipped stream; then accumulate as in scanl.
    # (Flat is better than nested, also for the call stack.)
    #
    # Reading everything into a tuple inside rev(...) may look inelegant, but
    # a recursive-process variant would hold the same data in stack frames
    # instead. Either way, all input elements must be read and stored before
    # the actual scanning can begin.

    # Single-input mode with the `_uselast` sentinel as `init`: the initial
    # value is the last element of the input.
    take_init_from_input = init is _uselast and not iterables

    if longest:
        zipper = partial(zip_longest, fillvalue=fillvalue)
    else:
        zipper = zip
    reversed_tuples = rev(zipper(iterable0, *iterables))

    if take_init_from_input:
        try:
            init = next(reversed_tuples)[0]
        except StopIteration:
            # Empty input and no init value: nothing to yield.
            return

    # Alternative, not used: left-append into a deque to get the same output
    # order as in Haskell:
    #
    #     result = init
    #     que = deque()
    #     que.appendleft(result)
    #     for elts in reversed_tuples:
    #         result = proc(*elts, result)
    #         que.appendleft(result)
    #     yield from que

    # To be more rackety/pythonic, yield results in the order they're computed.
    result = init
    yield result
    for elts in reversed_tuples:
        result = proc(*elts, result)
        yield result
# Equivalent recursive process:
#def scanr(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
# z = zip if not longest else partial(zip_longest, fillvalue=fillvalue)
# xss = z(iterable0, *iterables)
# pending_init_from_lastx = init is _uselast and not iterables
# def _scanr_recurser():
# try:
# xs = next(xss)
# except StopIteration:
# yield init # base case for recursion
# return
# subgen = _scanr_recurser()
# acc = next(subgen) # final result of previous step
#
# # The other base case: one iterable, no init given.
# # If pending_init_from_lastx is still True, we are the second-to-last subgen.
# nonlocal pending_init_from_lastx
# if pending_init_from_lastx:
# pending_init_from_lastx = False
# yield xs[0] # init value = last element from iterable0
# return
#
# # In case of all but the outermost generator, their final result has already
# # been read by the next(subgen), so they have only the last two yields remaining.
# yield proc(*(xs + (acc,))) # final result
# yield acc # previous result
# yield from subgen # sustain the chain
# return _scanr_recurser()
def scanl1(proc, iterable, init=None):
    """Single-input ``scanl`` where the initial value is optional.

    If ``init is None``, the first element of ``iterable`` serves as the
    initial value, and the scan proceeds over the remaining elements.

    If ``init is None`` and ``iterable`` is empty, the result is an empty
    generator.

    Example - partial sums and products::

        from operator import add, mul
        psums = curry(scanl1, add)
        pprods = curry(scanl1, mul)
        data = range(1, 5)
        assert tuple(psums(data)) == (1, 3, 6, 10)
        assert tuple(pprods(data)) == (1, 2, 6, 24)
    """
    it = iter(iterable)
    if init is not None:
        return scanl(proc, init, it)
    # No init given: consume the first element and use it as the seed.
    try:
        first_elt = next(it)
    except StopIteration:
        # Nothing to seed the scan with; hand back a generator that yields nothing.
        return (nothing for nothing in ())
    return scanl(proc, first_elt, it)
# Unique marker object, compared by identity. `scanr1` passes it to `scanr`
# as `init` when the caller gave no initial value; `scanr` then takes the
# initial value from the last element of its single input iterable.
_uselast = object() # sentinel
def scanr1(proc, iterable, init=None):
    """Single-input ``scanr`` where the initial value is optional; dual of ``scanl1``.

    If ``init is None``, the last element of ``iterable`` serves as the
    initial value.
    """
    if init is None:
        # Ask scanr to pick the initial value from the end of the input.
        init = _uselast
    return scanr(proc, init, iterable)
def foldl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Racket-like left fold over one or more iterables.

    At least one iterable (``iterable0``) is mandatory; any number of further
    iterables may follow.

    The initial value ``init`` is mandatory as well; no sane default exists
    for the case with multiple inputs.

    By default, the fold ends when the shortest input runs out. To continue
    until the longest input runs out, pass ``longest=True``, and optionally
    a ``fillvalue`` to stand in for the missing elements.

    Note the argument order ``proc(elt, acc)``, which is the opposite of what
    ``functools.reduce`` uses. In the general case, ``proc(e1, ..., en, acc)``.
    """
    # A fold is the final item of the corresponding scan.
    partial_results = scanl(proc, init, iterable0, *iterables,
                            longest=longest, fillvalue=fillvalue)
    return last(partial_results)
def foldr(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Right fold over one or more iterables; the dual of ``foldl``."""
    partial_results = scanr(proc, init, iterable0, *iterables,
                            longest=longest, fillvalue=fillvalue)
    # scanr yields its results in the order they are computed, so the fold
    # is the last item. (With a haskelly result ordering in scanr, it would
    # be the first item instead.)
    return last(partial_results)
def reducel(proc, iterable, init=None):
    """Single-input ``foldl`` where the initial value is optional.

    If ``init is None``, the first element of ``iterable`` serves as the
    initial value.

    Similar to ``functools.reduce``, but with the Racket-style argument
    order ``proc(elt, acc)``.
    """
    partial_results = scanl1(proc, iterable, init)
    return last(partial_results)
def reducer(proc, iterable, init=None):
    """Single-input ``foldr`` where the initial value is optional; dual of ``reducel``.

    If ``init is None``, the last element of ``iterable`` serves as the
    initial value.
    """
    partial_results = scanr1(proc, iterable, init)
    # scanr yields its results in the order they are computed, so the fold
    # is the last item. (With a haskelly result ordering in scanr, it would
    # be the first item instead.)
    return last(partial_results)
def rscanl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Left scan over the reversed inputs.

    Each input is reversed individually, so with several input iterables,
    *corresponding elements* are determined by syncing the **right** ends.

    ``rev`` is applied to the inputs. Note this forces any generators.
    """
    # Reverse in argument order: iterable0 first, then the rest.
    reversed_inputs = [rev(s) for s in (iterable0, *iterables)]
    return scanl(proc, init, *reversed_inputs,
                 longest=longest, fillvalue=fillvalue)
def rscanl1(proc, iterable, init=None):
    """Apply ``scanl1`` to the reversed input."""
    reversed_input = rev(iterable)
    return scanl1(proc, reversed_input, init)
def rfoldl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
    """Left fold over the reversed inputs.

    Each input is reversed individually, so with several input iterables,
    *corresponding elements* are determined by syncing the **right** ends.

    ``rev`` is applied to the inputs. Note this forces any generators.
    """
    # Reverse in argument order: iterable0 first, then the rest.
    reversed_inputs = [rev(s) for s in (iterable0, *iterables)]
    return foldl(proc, init, *reversed_inputs,
                 longest=longest, fillvalue=fillvalue)
def rreducel(proc, iterable, init=None):
    """Apply ``reducel`` to the reversed input."""
    reversed_input = rev(iterable)
    return reducel(proc, reversed_input, init)
def unfold1(proc, init):
    """Corecursively generate a sequence. The counterpart of foldl.

    This is a generator function.

    The state is initially ``init``.

    ``proc`` is called with one argument, the current state. For a complex,
    multi-component state that should be unpacked automatically, see
    ``unfold``.

    ``proc`` must return either a pair ``(value, newstate)``, or ``None`` to
    signal the end of the sequence (if the sequence is finite).

    ("Sequence" is here meant in the mathematical sense; in the Python sense,
    the output is an iterable.)

    Example::

        def step2(k):  # x0, x0 + 2, x0 + 4, ...
            return (k, k + 2)

        assert (tuple(take(10, unfold1(step2, 10))) ==
                (10, 12, 14, 16, 18, 20, 22, 24, 26, 28))
    """
    state = init
    result = proc(state)
    while result is not None:
        value, state = result
        yield value
        # Compute the next step only once the consumer asks for more.
        result = proc(state)
def unfold(proc, *inits, **kwinits):
    """Like unfold1, but for an n-in-(1+n)-out ``proc``.

    The current state is unpacked into the argument list of ``proc``. The
    initial state consists of ``inits`` (positional) and ``kwinits`` (named).

    ``proc`` must return either a ``Values`` object, or a bare ``None`` to
    signal the end of the sequence. In the ``Values`` object, the first
    positional return value is the ``value`` to be yielded at this iteration;
    everything else is the new state, to be unpacked into the args/kwargs of
    ``proc`` at the next iteration.

    Raises ``TypeError`` if ``proc`` returns anything other than ``None`` or
    a ``Values``.

    For a simple state such as a single number, see ``unfold1``.

    Example::

        def fibo(a, b):
            return Values(a, a=b, b=a + b)

        assert (tuple(take(10, unfold(fibo, 1, 1))) ==
                (1, 1, 2, 3, 5, 8, 13, 21, 34, 55))
    """
    state = Values(*inits, **kwinits)
    result = proc(*state.rets, **state.kwrets)
    while result is not None:
        if not isinstance(result, Values):
            raise TypeError(f"Expected `None` (to terminate) or a `Values` (to continue), got {type(result)} with value {repr(result)}")
        # Split off the first positional return value; the rest is the new state.
        value, *rets = result.rets
        state = Values(*rets, **result.kwrets)
        yield value
        # Compute the next step only once the consumer asks for more.
        result = proc(*state.rets, **state.kwrets)
# This is **not** how to make a right map; the result is exactly the same
# as for the ordinary (left) map and zip, but unnecessarily using a
# recursive process for something that can be done using a linear one.
# For documentation only. For working mapr, zipr, see unpythonic.it.
# The trick is in the order in which the recurser yields its results.
#
# def testme():
# squaretwo = lambda a, b: (a**2, b**2)
# print(tuple(mapr(squaretwo, (1, 2, 3), (4, 5))))
# print(tuple(map(squaretwo, (1, 2, 3), (4, 5))))
#
# def mapr(proc, *iterables):
# """Like map, but starting from the right. Recursive process.
#
# See ``rmap`` for the linear process that works by reversing each input.
# """
# def scanproc(*args):
# *elts, _ = args # discard acc
# return proc(*elts)
# # discard the init value with butlast
# return butlast(scanr(scanproc, None, *iterables))
#
# def zipr(*iterables):
# """Like zip, but starting from the right. Recursive process.
#
# See ``rzip`` for the linear process that works by reversing each input.
# """
# def identity(*args): # unpythonic.fun.identity, but dependency loop
# return args
# return mapr(identity, *iterables)
def prod(iterable, start=1):
    """Compute the product of the elements; the multiplicative analog of the builtin ``sum``.

    ``start`` is passed to ``reducel`` as the initial value.

    This is a fold operation.
    """
    product = reducel(mul, iterable, init=start)
    return product
def running_minmax(iterable):
    """Return a generator that extracts a running `(min, max)` from `iterable`.

    The iterable is iterated just once.

    If `iterable` is empty, an empty iterator is returned.

    We assume `iterable` contains no NaNs, and that all of its elements are
    comparable using `<` and `>`. Suggest filtering accordingly before
    calling this.

    This is a scan operation.
    """
    elements = iter(iterable)
    try:
        seed = next(elements)
    except StopIteration:  # behave like `unpack` and `window` on empty input
        return (nothing for nothing in ())

    def update(elt, acc):
        # `acc` is the `(min, max)` seen so far; widen it to include `elt`.
        lo, hi = acc
        return (elt if elt < lo else lo,
                elt if elt > hi else hi)

    # The first element is both the minimum and the maximum seen so far.
    return scanl(update, (seed, seed), elements)
def minmax(iterable):
    """Extract `(min, max)` from `iterable` in a single pass over it.

    If `iterable` is empty, return `(None, None)`.

    We assume `iterable` contains no NaNs, and that all of its elements are
    comparable using `<` and `>`. Suggest filtering accordingly before
    calling this.

    This is a fold operation.
    """
    # The overall result is the final item of the running (min, max).
    running = running_minmax(iterable)
    return last(running, default=(None, None))