@@ -49,28 +49,45 @@ class _NoSuchType:
4949from math import log as math_log , copysign , trunc , floor , ceil
5050try :
5151 from mpmath import mpf , almosteq as mpf_almosteq
52- except ImportError :
52+ except ImportError : # pragma: no cover, optional at runtime, but installed at development time.
5353 # Can't use a gensym here since `mpf` must be a unique *type*.
5454 mpf = _NoSuchType
5555 mpf_almosteq = None
5656
5757def _numsign (x ):
58+ """The sign function, for numeric inputs."""
5859 if x == 0 :
5960 return 0
6061 return int (copysign (1.0 , x ))
6162
6263try :
6364 from sympy import log as _symlog , Expr as _symExpr , sign as _symsign
64- def log (x , b ):
65+ def log (x , b = None ):
66+ """The logarithm function.
67+
68+ Works for both numeric and symbolic (`SymPy.Expr`) inputs.
69+
70+ Default base `b=None` means `e`, i.e. take the natural logarithm.
71+ """
6572 if isinstance (x , _symExpr ):
6673 # https://stackoverflow.com/questions/46129259/how-to-simplify-logarithm-of-exponent-in-sympy
67- return _symlog (x , b ).expand (force = True )
68- return math_log (x , b )
74+ if b is not None :
75+ return _symlog (x , b ).expand (force = True )
76+ else :
77+ return _symlog (x ).expand (force = True )
78+ if b is not None :
79+ return math_log (x , b )
80+ else :
81+ return math_log (x )
6982 def sign (x ):
83+ """The sign function.
84+
85+ Works for both numeric and symbolic (`SymPy.Expr`) inputs.
86+ """
7087 if isinstance (x , _symExpr ):
7188 return _symsign (x )
7289 return _numsign (x )
73- except ImportError :
90+ except ImportError : # pragma: no cover, optional at runtime, but installed at development time.
7491 log = math_log
7592 sign = _numsign
7693 _symExpr = _NoSuchType
@@ -98,12 +115,12 @@ def almosteq(a, b, tol=1e-8):
98115 if isinstance (a , mpf ) and isinstance (b , mpf ):
99116 return mpf_almosteq (a , b , tol )
100117 # compare as native float if only one is an mpf
101- elif isinstance (a , mpf ) and isinstance (b , float ):
118+ elif isinstance (a , mpf ) and isinstance (b , ( float , int ) ):
102119 a = float (a )
103- elif isinstance (a , float ) and isinstance (b , mpf ):
120+ elif isinstance (a , ( float , int ) ) and isinstance (b , mpf ):
104121 b = float (b )
105122
106- if not all (isinstance (x , float ) for x in (a , b )):
123+ if not all (isinstance (x , ( float , int ) ) for x in (a , b )):
107124 return False # non-float type, already determined that a != b
108125 min_normal = float_info .min
109126 max_float = float_info .max
0 commit comments