Skip to content

Commit f9e8b20

Browse files
committed
handle pandas.Categorical-valued factors as categorical data
1 parent bda9937 commit f9e8b20

3 files changed

Lines changed: 29 additions & 6 deletions

File tree

patsy/build.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ def _examine_factor_types(factors, factor_states, data_iter_maker):
441441
break
442442
for factor in list(examine_needed):
443443
value = factor.eval(factor_states[factor], data)
444+
if have_pandas and isinstance(value, pandas.Categorical):
445+
value = Categorical.from_pandas_categorical(value)
446+
# fall through into the next 'if':
444447
if isinstance(value, Categorical):
445448
postprocessor = CategoricalTransform(levels=value.levels)
446449
prefinished_postprocessors[factor] = postprocessor

patsy/categorical.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from patsy import PatsyError
1010
from patsy.state import stateful_transform
1111
from patsy.util import (SortAnythingKey,
12-
have_pandas, asarray_or_pandas,
13-
pandas_friendly_reshape)
12+
have_pandas, asarray_or_pandas,
13+
pandas_friendly_reshape)
1414

1515
if have_pandas:
1616
import pandas
@@ -39,6 +39,11 @@ def __init__(self, int_array, levels, contrast=None):
3939
self.levels = tuple(levels)
4040
self.contrast = contrast
4141

42+
@classmethod
43+
def from_pandas_categorical(cls, pandas_categorical):
44+
return Categorical(pandas_categorical.labels,
45+
pandas_categorical.levels)
46+
4247
@classmethod
4348
def from_sequence(cls, sequence, levels=None, **kwargs):
4449
"""from_sequence(sequence, levels=None, contrast=None)
@@ -166,8 +171,8 @@ def test_Categorical():
166171
class CategoricalTransform(object):
167172
"""C(data, contrast=None, levels=None)
168173
169-
Converts some `data` into :class:`Categorical` form. (It is also used
170-
called implicitly any time a formula contains a bare categorical factor.)
174+
Converts some `data` into :class:`Categorical` form. (It is also called
175+
implicitly any time a formula contains a bare categorical factor.)
171176
172177
This is used in two cases:
173178
@@ -209,6 +214,9 @@ def memorize_finish(self):
209214

210215
def transform(self, data, contrast=None, levels=None):
211216
kwargs = {"contrast": contrast}
217+
if have_pandas and isinstance(data, pandas.Categorical):
218+
data = Categorical.from_pandas_categorical(data)
219+
# fall through to the next 'if':
212220
if isinstance(data, Categorical):
213221
if levels is not None and data.levels != levels:
214222
raise PatsyError("changing levels of categorical data "
@@ -286,6 +294,13 @@ def test_C_pandas():
286294
assert np.array_equal(cat3.int_array.index, [10, 20, 30])
287295
assert cat3.contrast == "asdf"
288296

297+
def test_categorical_from_pandas_categorical():
298+
if have_pandas:
299+
pandas_categorical = pandas.Categorical.from_array(["a", "b", "a"])
300+
c = Categorical.from_pandas_categorical(pandas_categorical)
301+
assert np.array_equal(c.int_array, [0, 1, 0])
302+
assert c.levels == ("a", "b")
303+
289304
def test_categorical_non_strings():
290305
cat = C([1, "foo", ("a", "b")])
291306
assert set(cat.levels) == set([1, "foo", ("a", "b")])

patsy/test_build.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,19 @@ def iter_maker():
436436
def test_categorical():
437437
data_strings = {"a": ["a1", "a2", "a1"]}
438438
data_categ = {"a": C(["a2", "a1", "a2"])}
439+
datas = [data_strings, data_categ]
440+
if have_pandas:
441+
data_pandas = {"a": pandas.Categorical.from_array(["a1", "a2", "a2"])}
442+
datas.append(data_pandas)
439443
def t(data1, data2):
440444
def iter_maker():
441445
yield data1
442446
builders = design_matrix_builders([make_termlist(["a"])],
443447
iter_maker)
444448
build_design_matrices(builders, data2)
445-
t(data_strings, data_categ)
446-
t(data_categ, data_strings)
449+
for data1 in datas:
450+
for data2 in datas:
451+
t(data1, data2)
447452

448453
def test_contrast():
449454
from patsy.contrasts import ContrastMatrix, Sum

0 commit comments

Comments
 (0)