Skip to content

Commit 51e0e4e

Browse files
authored
[BEAM-13605] Modify groupby.apply implementation in preparation for pandas 1.4.0 (#16706)
* Modify groupby.apply implementation in preparation for pandas 1.4.0 * fixup! Modify groupby.apply implementation in preparation for pandas 1.4.0 * Address review comments
1 parent 0c587e3 commit 51e0e4e

1 file changed

Lines changed: 94 additions & 23 deletions

File tree

sdks/python/apache_beam/dataframe/frames.py

Lines changed: 94 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2674,11 +2674,9 @@ def duplicated(self, keep, subset):
26742674

26752675
by = subset or list(self.columns)
26762676

2677-
# Workaround a bug where groupby.apply() that returns a single-element
2678-
# Series moves index label to column
26792677
return self.groupby(by).apply(
26802678
lambda df: pd.DataFrame(df.duplicated(keep=keep, subset=subset),
2681-
columns=[None]))[None]
2679+
columns=[None]))[None].droplevel(by)
26822680

26832681
@frame_base.with_docs_from(pd.DataFrame)
26842682
@frame_base.args_to_kwargs(pd.DataFrame)
@@ -3975,7 +3973,19 @@ def apply(self, func, *args, **kwargs):
39753973
object of the same type as what will be returned when the pipeline is
39763974
processing actual data. If the result is a pandas object it should have the
39773975
same type and name (for a Series) or column types and names (for
3978-
a DataFrame) as the actual results."""
3976+
a DataFrame) as the actual results.
3977+
3978+
Note that in pandas, ``apply`` attempts to detect if the index is unmodified
3979+
in ``func`` (indicating ``func`` is a transform) and drops the duplicate
3980+
index in the output. To determine this, pandas tests the indexes for
3981+
equality. However, Beam cannot do this since it is sensitive to the input
3982+
data; instead this implementation tests if the indexes are equivalent
3983+
with ``is``. See the `pandas 1.4.0 release notes
3984+
<https://pandas.pydata.org/docs/dev/whatsnew/v1.4.0.html#groupby-apply-consistent-transform-detection>`_
3985+
for a good explanation of the distinction between these approaches. In
3986+
practice, this just means that in some cases the Beam result will have
3987+
a duplicate index, whereas pandas would have dropped it."""
3988+
39793989
project = _maybe_project_func(self._projection)
39803990
grouping_indexes = self._grouping_indexes
39813991
grouping_columns = self._grouping_columns
@@ -3986,29 +3996,82 @@ def apply(self, func, *args, **kwargs):
39863996
fn_input = project(self._ungrouped_with_index.proxy().reset_index(
39873997
grouping_columns, drop=True))
39883998
result = func(fn_input)
3989-
if isinstance(result, pd.core.generic.NDFrame):
3990-
if result.index is fn_input.index:
3991-
proxy = result
3999+
def index_to_arrays(index):
4000+
return [index.get_level_values(level)
4001+
for level in range(index.nlevels)]
4002+
4003+
4004+
# By default do_apply will just call pandas apply()
4005+
# We override it below if necessary
4006+
do_apply = lambda gb: gb.apply(func, *args, **kwargs)
4007+
4008+
if (isinstance(result, pd.core.generic.NDFrame) and
4009+
result.index is fn_input.index):
4010+
# Special case where apply fn is a transform
4011+
# Note we trust that if the user fn produces a proxy with the identical
4012+
# index, it will produce results with identical indexes at execution
4013+
# time too
4014+
proxy = result
4015+
elif isinstance(result, pd.DataFrame):
4016+
# apply fn is not a transform, we need to make sure the original index
4017+
# values are prepended to the result's index
4018+
proxy = result[:0]
4019+
4020+
# First adjust proxy
4021+
proxy.index = pd.MultiIndex.from_arrays(
4022+
index_to_arrays(self._ungrouped.proxy().index) +
4023+
index_to_arrays(proxy.index),
4024+
names=self._ungrouped.proxy().index.names + proxy.index.names)
4025+
4026+
# Then override do_apply function
4027+
new_index_names = self._ungrouped.proxy().index.names
4028+
if len(new_index_names) > 1:
4029+
def add_key_index(key, df):
4030+
# df is a dataframe or Series representing the result of func for
4031+
# a single key
4032+
# key is a tuple with the MultiIndex values for this key
4033+
df.index = pd.MultiIndex.from_arrays(
4034+
[[key[i]] * len(df) for i in range(len(new_index_names))] +
4035+
index_to_arrays(df.index),
4036+
names=new_index_names + df.index.names)
4037+
return df
39924038
else:
3993-
proxy = result[:0]
3994-
3995-
def index_to_arrays(index):
3996-
return [index.get_level_values(level)
3997-
for level in range(index.nlevels)]
3998-
3999-
# The final result will have the grouped indexes + the indexes from the
4000-
# result
4001-
proxy.index = pd.MultiIndex.from_arrays(
4002-
index_to_arrays(self._ungrouped.proxy().index) +
4003-
index_to_arrays(proxy.index),
4004-
names=self._ungrouped.proxy().index.names + proxy.index.names)
4039+
def add_key_index(key, df):
4040+
# df is a dataframe or Series representing the result of func for
4041+
# a single key
4042+
df.index = pd.MultiIndex.from_arrays(
4043+
[[key] * len(df)] + index_to_arrays(df.index),
4044+
names=new_index_names + df.index.names)
4045+
return df
4046+
4047+
4048+
do_apply = lambda gb: pd.concat([
4049+
add_key_index(k, func(gb.get_group(k), *args, **kwargs))
4050+
for k in gb.groups.keys()])
4051+
elif isinstance(result, pd.Series):
4052+
if isinstance(fn_input, pd.DataFrame):
4053+
# DataFrameGroupBy
4054+
# In this case pandas transposes the Series result, s.t. the Series
4055+
# index values are the columns, and the grouping keys are the new index
4056+
# values.
4057+
dtype = pd.Series([result]).dtype
4058+
proxy = pd.DataFrame(columns=result.index,
4059+
dtype=result.dtype,
4060+
index=self._ungrouped.proxy().index)
4061+
elif isinstance(fn_input, pd.Series):
4062+
# SeriesGroupBy
4063+
# In this case the output is still a Series, but with an additional
4064+
# index with the grouping keys.
4065+
proxy = pd.Series(dtype=result.dtype,
4066+
name=result.name,
4067+
index=index_to_arrays(self._ungrouped.proxy().index) +
4068+
index_to_arrays(result[:0].index))
40054069
else:
40064070
# The user fn returns some non-pandas type. The expected result is a
40074071
# Series where each element is the result of one user fn call.
40084072
dtype = pd.Series([result]).dtype
40094073
proxy = pd.Series([], dtype=dtype, index=self._ungrouped.proxy().index)
40104074

4011-
40124075
def do_partition_apply(df):
40134076
# Remove columns from index, we only needed them there for partitioning
40144077
df = df.reset_index(grouping_columns, drop=True)
@@ -4017,7 +4080,8 @@ def do_partition_apply(df):
40174080
by=grouping_columns or None)
40184081

40194082
gb = project(gb)
4020-
return gb.apply(func, *args, **kwargs)
4083+
4084+
return do_apply(gb)
40214085

40224086
return DeferredDataFrame(
40234087
expressions.ComputedExpression(
@@ -4117,8 +4181,15 @@ def apply_fn(df):
41174181
@property # type: ignore
41184182
@frame_base.with_docs_from(DataFrameGroupBy)
41194183
def dtypes(self):
4120-
grouping_columns = self._grouping_columns
4121-
return self.apply(lambda df: df.drop(grouping_columns, axis=1).dtypes)
4184+
return frame_base.DeferredFrame.wrap(
4185+
expressions.ComputedExpression(
4186+
'dtypes',
4187+
lambda gb: gb.dtypes,
4188+
[self._expr],
4189+
requires_partition_by=partitionings.Arbitrary(),
4190+
preserves_partition_by=partitionings.Arbitrary()
4191+
)
4192+
)
41224193

41234194
fillna = frame_base.wont_implement_method(
41244195
DataFrameGroupBy, 'fillna', explanation=(

0 commit comments

Comments (0)