|
9 | 9 | from patsy import PatsyError |
10 | 10 | from patsy.state import stateful_transform |
11 | 11 | from patsy.util import (SortAnythingKey, |
12 | | - have_pandas, asarray_or_pandas, |
13 | | - pandas_friendly_reshape) |
| 12 | + have_pandas, asarray_or_pandas, |
| 13 | + pandas_friendly_reshape) |
14 | 14 |
|
15 | 15 | if have_pandas: |
16 | 16 | import pandas |
@@ -39,6 +39,11 @@ def __init__(self, int_array, levels, contrast=None): |
39 | 39 | self.levels = tuple(levels) |
40 | 40 | self.contrast = contrast |
41 | 41 |
|
| 42 | + @classmethod |
| 43 | + def from_pandas_categorical(cls, pandas_categorical): |
| 44 | + return Categorical(pandas_categorical.labels, |
| 45 | + pandas_categorical.levels) |
| 46 | + |
42 | 47 | @classmethod |
43 | 48 | def from_sequence(cls, sequence, levels=None, **kwargs): |
44 | 49 | """from_sequence(sequence, levels=None, contrast=None) |
@@ -166,8 +171,8 @@ def test_Categorical(): |
166 | 171 | class CategoricalTransform(object): |
167 | 172 | """C(data, contrast=None, levels=None) |
168 | 173 |
|
169 | | - Converts some `data` into :class:`Categorical` form. (It is also used |
170 | | - called implicitly any time a formula contains a bare categorical factor.) |
| 174 | + Converts some `data` into :class:`Categorical` form. (It is also called |
| 175 | + implicitly any time a formula contains a bare categorical factor.) |
171 | 176 |
|
172 | 177 | This is used in two cases: |
173 | 178 |
|
@@ -209,6 +214,9 @@ def memorize_finish(self): |
209 | 214 |
|
210 | 215 | def transform(self, data, contrast=None, levels=None): |
211 | 216 | kwargs = {"contrast": contrast} |
| 217 | + if have_pandas and isinstance(data, pandas.Categorical): |
| 218 | + data = Categorical.from_pandas_categorical(data) |
| 219 | + # fall through to the next 'if': |
212 | 220 | if isinstance(data, Categorical): |
213 | 221 | if levels is not None and data.levels != levels: |
214 | 222 | raise PatsyError("changing levels of categorical data " |
@@ -286,6 +294,13 @@ def test_C_pandas(): |
286 | 294 | assert np.array_equal(cat3.int_array.index, [10, 20, 30]) |
287 | 295 | assert cat3.contrast == "asdf" |
288 | 296 |
|
| 297 | +def test_categorical_from_pandas_categorical(): |
| 298 | + if have_pandas: |
| 299 | + pandas_categorical = pandas.Categorical.from_array(["a", "b", "a"]) |
| 300 | + c = Categorical.from_pandas_categorical(pandas_categorical) |
| 301 | + assert np.array_equal(c.int_array, [0, 1, 0]) |
| 302 | + assert c.levels == ("a", "b") |
| 303 | + |
289 | 304 | def test_categorical_non_strings(): |
290 | 305 | cat = C([1, "foo", ("a", "b")]) |
291 | 306 | assert set(cat.levels) == set([1, "foo", ("a", "b")]) |
|
0 commit comments