|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +"""Test/usage example: numerical tricks in FP.""" |
| 4 | + |
| 5 | +from operator import add, mul |
| 6 | +from itertools import repeat |
| 7 | +from math import sin, pi, log2 |
| 8 | + |
| 9 | +from unpythonic.fun import curry |
| 10 | +from unpythonic.it import unpack, drop, take, tail, first, second, last, iterate1 |
| 11 | +from unpythonic.fold import scanl, scanl1 |
| 12 | + |
| 13 | +def test(): |
| 14 | + # http://learnyouahaskell.com/higher-order-functions |
| 15 | + def collatz(n): |
| 16 | + if n < 1: |
| 17 | + raise ValueError() |
| 18 | + while True: |
| 19 | + yield n |
| 20 | + if n == 1: |
| 21 | + break |
| 22 | + n = n // 2 if n % 2 == 0 else 3 * n + 1 |
| 23 | + assert tuple(collatz(13)) == (13, 40, 20, 10, 5, 16, 8, 4, 2, 1) |
| 24 | + assert tuple(collatz(10)) == (10, 5, 16, 8, 4, 2, 1) |
| 25 | + assert tuple(collatz(30)) == (30, 15, 46, 23, 70, 35, 106, 53, 160, 80, 40, 20, 10, 5, 16, 8, 4, 2, 1) |
| 26 | + def len_gt(k, s): |
| 27 | + a, _ = unpack(1, drop(k, s)) |
| 28 | + return a # None if no item |
| 29 | + islong = curry(len_gt, 15) |
| 30 | + assert sum(1 for n in range(1, 101) if islong(collatz(n))) == 66 |
| 31 | + |
| 32 | + # Implicitly defined infinite streams, using generators. |
| 33 | + # |
| 34 | + def adds(s1, s2): |
| 35 | + """Add two infinite streams (elementwise).""" |
| 36 | + return map(add, s1, s2) |
| 37 | + def muls(s, c): |
| 38 | + """Multiply an infinite stream by a constant.""" |
| 39 | + return map(lambda x: c * x, s) |
| 40 | + |
| 41 | + # will eventually crash (stack overflow) |
| 42 | + def ones_fp(): |
| 43 | + yield 1 |
| 44 | + yield from ones_fp() |
| 45 | + def nats_fp(start=0): |
| 46 | + yield start |
| 47 | + yield from adds(nats_fp(start), ones_fp()) |
| 48 | + def fibos_fp(): |
| 49 | + yield 1 |
| 50 | + yield 1 |
| 51 | + yield from adds(fibos_fp(), tail(fibos_fp())) |
| 52 | + def powers_of_2(): |
| 53 | + yield 1 |
| 54 | + yield from muls(powers_of_2(), 2) |
| 55 | + assert tuple(take(10, ones_fp())) == (1,) * 10 |
| 56 | + assert tuple(take(10, nats_fp())) == tuple(range(10)) |
| 57 | + assert tuple(take(10, fibos_fp())) == (1, 1, 2, 3, 5, 8, 13, 21, 34, 55) |
| 58 | + assert tuple(take(10, powers_of_2())) == (1, 2, 4, 8, 16, 32, 64, 128, 256, 512) |
| 59 | + |
| 60 | + # The scanl equations are sometimes useful. The conditions |
| 61 | + # rs[0] = s0 |
| 62 | + # rs[k+1] = rs[k] + xs[k] |
| 63 | + # are equivalent with |
| 64 | + # rs = scanl(add, s0, xs) |
| 65 | + # https://www.vex.net/~trebla/haskell/scanl.xhtml |
| 66 | + def zs(): # s0 = 0, rs = [0, ...], xs = [0, ...] |
| 67 | + yield from scanl(add, 0, zs()) |
| 68 | + def os(): # s0 = 1, rs = [1, ...], xs = [0, ...] |
| 69 | + yield from scanl(add, 1, zs()) |
| 70 | + def ns(start=0): # s0 = start, rs = [start, start+1, ...], xs = [1, ...] |
| 71 | + yield from scanl(add, start, os()) |
| 72 | + def fs(): # s0 = 1, scons(1, rs) = fibos, xs = fibos |
| 73 | + yield 1 |
| 74 | + yield from scanl(add, 1, fs()) |
| 75 | + def p2s(): # s0 = 1, rs = xs = [1, 2, 4, ...] |
| 76 | + yield from scanl(add, 1, p2s()) |
| 77 | + assert tuple(take(10, zs())) == (0,) * 10 |
| 78 | + assert tuple(take(10, os())) == (1,) * 10 |
| 79 | + assert tuple(take(10, ns())) == tuple(range(10)) |
| 80 | + assert tuple(take(10, fs())) == (1, 1, 2, 3, 5, 8, 13, 21, 34, 55) |
| 81 | + assert tuple(take(10, p2s())) == (1, 2, 4, 8, 16, 32, 64, 128, 256, 512) |
| 82 | + |
| 83 | + # better Python: simple is better than complex (also no stack overflow) |
| 84 | + def ones(): |
| 85 | + return repeat(1) |
| 86 | + def nats(start=0): |
| 87 | + return scanl(add, start, ones()) |
| 88 | + def fibos(): |
| 89 | + a, b = 1, 1 |
| 90 | + while True: |
| 91 | + yield a |
| 92 | + a, b = b, a + b |
| 93 | + def pows(): |
| 94 | + return scanl(mul, 1, repeat(2)) |
| 95 | + assert tuple(take(10, ones())) == (1,) * 10 |
| 96 | + assert tuple(take(10, nats())) == tuple(range(10)) |
| 97 | + assert tuple(take(10, fibos())) == (1, 1, 2, 3, 5, 8, 13, 21, 34, 55) |
| 98 | + assert tuple(take(10, pows())) == (1, 2, 4, 8, 16, 32, 64, 128, 256, 512) |
| 99 | + |
| 100 | + # How to improve accuracy of numeric differentiation with FP tricks. |
| 101 | + # |
| 102 | + # See: |
| 103 | + # Hughes, 1984: Why Functional Programming Matters, p. 11 ff. |
| 104 | + # http://www.cse.chalmers.se/~rjmh/Papers/whyfp.html |
| 105 | + # |
| 106 | + def easydiff(f, x, h): # as well known, wildly inaccurate |
| 107 | + return (f(x + h) - f(x)) / h |
| 108 | + def halve(x): |
| 109 | + return x / 2 |
| 110 | + def differentiate(h0, f, x): |
| 111 | + return map(curry(easydiff, f, x), iterate1(halve, h0)) |
| 112 | + def within(eps, s): |
| 113 | + while True: |
| 114 | + # unpack with peek (but be careful, the rewinded tail is a tee'd copy) |
| 115 | + a, b, s = unpack(2, s, k=1) |
| 116 | + if abs(a - b) < eps: |
| 117 | + return b |
| 118 | + def differentiate_with_tol(h0, f, x, eps): |
| 119 | + return within(eps, differentiate(h0, f, x)) |
| 120 | + assert abs(differentiate_with_tol(0.1, sin, pi/2, 1e-8)) < 1e-7 |
| 121 | + |
| 122 | + def order(s): |
| 123 | + """Estimate asymptotic order of s, consuming the first three terms.""" |
| 124 | + a, b, c, _ = unpack(3, s) |
| 125 | + return round(log2(abs((a - c) / (b - c)) - 1)) |
| 126 | + def eliminate_error(n, s): |
| 127 | + """Eliminate error term of given asymptotic order n. |
| 128 | +
|
| 129 | + The stream s must be based on halving h at each step |
| 130 | + for the formula used here to work.""" |
| 131 | + while True: |
| 132 | + a, b, s = unpack(2, s, k=1) |
| 133 | + yield (b*2**n - a) / (2**(n - 1)) |
| 134 | + def improve(s): |
| 135 | + """Eliminate asymptotically dominant error term from s. |
| 136 | +
|
| 137 | + Consumes the first three terms to estimate the order. |
| 138 | + """ |
| 139 | + return eliminate_error(order(s), s) |
| 140 | + def better_differentiate_with_tol(h0, f, x, eps): |
| 141 | + return within(eps, improve(differentiate(h0, f, x))) |
| 142 | + assert abs(better_differentiate_with_tol(0.1, sin, pi/2, 1e-8)) < 1e-9 |
| 143 | + |
| 144 | + def super_improve(s): |
| 145 | + return map(second, iterate1(improve, s)) |
| 146 | + def best_differentiate_with_tol(h0, f, x, eps): |
| 147 | + return within(eps, super_improve(differentiate(h0, f, x))) |
| 148 | + assert abs(best_differentiate_with_tol(0.1, sin, pi/2, 1e-8)) < 1e-12 |
| 149 | + |
| 150 | + # pi approximation with Euler series acceleration |
| 151 | + # |
| 152 | + # See SICP, 2nd ed., sec. 3.5.3. |
| 153 | + # |
| 154 | + # This implementation originally by Jim Hoover, in Racket, from: |
| 155 | + # https://sites.ualberta.ca/~jhoover/325/CourseNotes/section/Streams.htm |
| 156 | + # |
| 157 | + partial_sums = curry(scanl1, add) |
| 158 | + def pi_summands(n): # π/4 = 1 - 1/3 + 1/5 - 1/7 + ... |
| 159 | + sign = +1 |
| 160 | + while True: |
| 161 | + yield sign / n |
| 162 | + n += 2 |
| 163 | + sign *= -1 |
| 164 | + pi_stream = muls(partial_sums(pi_summands(1)), 4) |
| 165 | + |
| 166 | + # http://mathworld.wolfram.com/EulerTransform.html |
| 167 | + # https://en.wikipedia.org/wiki/Series_acceleration#Euler%27s_transform |
| 168 | + def euler_transform(s): |
| 169 | + while True: |
| 170 | + a, b, c, s = unpack(3, s, k=1) |
| 171 | + yield c - ((c - b)**2 / (a - 2*b + c)) |
| 172 | + faster_pi_stream = euler_transform(pi_stream) |
| 173 | + |
| 174 | + def super_accelerate(transform, s): |
| 175 | + return map(first, iterate1(transform, s)) |
| 176 | + fastest_pi_stream = super_accelerate(euler_transform, pi_stream) |
| 177 | + |
| 178 | + assert abs(last(take(6, pi_stream)) - pi) < 0.2 |
| 179 | + assert abs(last(take(6, faster_pi_stream)) - pi) < 1e-3 |
| 180 | + assert abs(last(take(6, fastest_pi_stream)) - pi) < 1e-15 |
| 181 | + |
| 182 | + print("All tests PASSED") |
| 183 | + |
| 184 | +if __name__ == '__main__': |
| 185 | + test() |
0 commit comments