Skip to content

Commit d28e79b

Browse files
committed
Fix deprecation warnings on numpy 1.10
Apparently we were doing array == string and expecting a scalar to be returned. That's not a great plan, so don't do that.
1 parent 870d680 commit d28e79b

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

patsy/mgcv_cubic_splines.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111

1212
from patsy.util import (have_pandas, atleast_2d_column_default,
13-
no_pickling, assert_no_pickling)
13+
no_pickling, assert_no_pickling, safe_string_eq)
1414
from patsy.state import stateful_transform
1515

1616
if have_pandas:
@@ -623,7 +623,7 @@ def memorize_finish(self):
623623
constraints = args["constraints"]
624624
n_constraints = 0
625625
if constraints is not None:
626-
if constraints == "center":
626+
if safe_string_eq(constraints, "center"):
627627
# Here we collect only number of constraints,
628628
# actual centering constraint will be computed after all_knots
629629
n_constraints = 1
@@ -651,7 +651,7 @@ def memorize_finish(self):
651651
lower_bound=args["lower_bound"],
652652
upper_bound=args["upper_bound"])
653653
if constraints is not None:
654-
if constraints == "center":
654+
if safe_string_eq(constraints, "center"):
655655
# Now we can compute centering constraints
656656
constraints = _get_centering_constraint_from_dmatrix(
657657
_get_free_crs_dmatrix(x, self._all_knots, cyclic=self._cyclic)
@@ -895,7 +895,7 @@ def __init__(self):
895895
def memorize_chunk(self, *args, **kwargs):
896896
constraints = self._tmp.setdefault("constraints",
897897
kwargs.get("constraints"))
898-
if constraints == "center":
898+
if safe_string_eq(constraints, "center"):
899899
args_2d = []
900900
for arg in args:
901901
arg = atleast_2d_column_default(arg)
@@ -919,7 +919,7 @@ def memorize_finish(self):
919919
del self._tmp
920920

921921
if constraints is not None:
922-
if constraints == "center":
922+
if safe_string_eq(constraints, "center"):
923923
constraints = np.atleast_2d(tmp["sum"] / tmp["count"])
924924
else:
925925
constraints = np.atleast_2d(constraints)

patsy/util.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"safe_issubdtype",
2121
"no_pickling",
2222
"assert_no_pickling",
23+
"safe_string_eq",
2324
]
2425

2526
import sys
@@ -699,3 +700,23 @@ def assert_no_pickling(obj):
699700
import pickle
700701
from nose.tools import assert_raises
701702
assert_raises(NotImplementedError, pickle.dumps, obj)
703+
704+
# Use like:
705+
# if safe_string_eq(constraints, "center"):
706+
# ...
707+
# where 'constraints' might be a string or an array. (If it's an array, then
708+
# we can't use == becaues it might broadcast and ugh.)
709+
def safe_string_eq(obj, value):
710+
if isinstance(obj, six.string_types):
711+
return obj == value
712+
else:
713+
return False
714+
715+
def test_safe_string_eq():
716+
assert safe_string_eq("foo", "foo")
717+
assert not safe_string_eq("foo", "bar")
718+
719+
if not six.PY3:
720+
assert safe_string_eq(unicode("foo"), "foo")
721+
722+
assert not safe_string_eq(np.empty((2, 2)), "foo")

0 commit comments

Comments
 (0)