Skip to content

Commit e6e9994

Browse files
committed
use drop_null_keys, some pandas fastpaths
1 parent d3a28c0 commit e6e9994

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

packages/python/plotly/_plotly_utils/basevalidators.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import warnings
11+
import narwhals.stable.v1 as nw
1112

1213
from _plotly_utils.optional_imports import get_module
1314

@@ -93,8 +94,19 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
9394
"O": "object",
9495
}
9596

96-
# Handle pandas Series and Index objects
97+
if isinstance(v, nw.Series):
98+
if nw.dependencies.is_pandas_like_series(v_native := v.to_native()):
99+
v = v_native
100+
else:
101+
v = v.to_numpy()
102+
elif isinstance(v, nw.DataFrame):
103+
if nw.dependencies.is_pandas_like_dataframe(v_native := v.to_native()):
104+
v = v_native
105+
else:
106+
v = v.to_numpy()
107+
97108
if pd and isinstance(v, (pd.Series, pd.Index)):
109+
# Handle pandas Series and Index objects
98110
if v.dtype.kind in numeric_kinds:
99111
# Get the numeric numpy array so we use fast path below
100112
v = v.values
@@ -189,10 +201,12 @@ def is_homogeneous_array(v):
189201
"""
190202
np = get_module("numpy", should_load=False)
191203
pd = get_module("pandas", should_load=False)
204+
import narwhals as nw
192205
if (
193206
np
194207
and isinstance(v, np.ndarray)
195208
or (pd and isinstance(v, (pd.Series, pd.Index)))
209+
or (isinstance(v, nw.Series))
196210
):
197211
return True
198212
if is_numpy_convertable(v):

packages/python/plotly/plotly/express/_core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def invert_label(args, column):
156156

157157

158158
def _is_continuous(df: nw.DataFrame, col_name: str) -> bool:
159+
if nw.dependencies.is_pandas_like_dataframe(df_native := df.to_native()):
160+
return df_native[col_name].dtype.kind in 'ifc'
159161
return df.get_column(col_name).dtype.is_numeric()
160162

161163

@@ -1114,15 +1116,12 @@ def to_unindexed_series(x, name=None, native_namespace=None):
11141116
itx index reset if pandas-like). Stripping the index from existing pd.Series is
11151117
required to get things to match up right in the new DataFrame we're building.
11161118
"""
1117-
x_native = nw.to_native(x, strict=False)
1118-
if nw.dependencies.is_pandas_like_series(x_native):
1119-
return nw.from_native(
1120-
x_native.__class__(x_native, name=name).reset_index(drop=True),
1121-
series_only=True,
1122-
)
11231119
x = nw.from_native(x, series_only=True, strict=False)
11241120
if isinstance(x, nw.Series):
1125-
return x.rename(name)
1121+
if name == x.name:
1122+
# Avoid potentially creating a copy in pre-copy-on-write pandas
1123+
return nw.maybe_reset_index(x)
1124+
return nw.maybe_reset_index(x).rename(name)
11261125
elif native_namespace is not None:
11271126
return nw.new_series(name=name, values=x, native_namespace=native_namespace)
11281127
else:
@@ -1907,7 +1906,7 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
19071906
for i, level in enumerate(path):
19081907

19091908
dfg = (
1910-
df.group_by(path[i:])
1909+
df.group_by(path[i:], drop_null_keys=True)
19111910
.agg(**agg_f)
19121911
.pipe(post_agg, continuous_aggs, discrete_aggs)
19131912
)
@@ -2307,7 +2306,7 @@ def get_groups_and_orders(args, grouper):
23072306
groups = {tuple(single_group_name): df}
23082307
else:
23092308
required_grouper = list(orders.keys())
2310-
grouped = dict(df.group_by(required_grouper).__iter__())
2309+
grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__())
23112310
sorted_group_names = list(grouped.keys())
23122311

23132312
for i, col in reversed(list(enumerate(required_grouper))):

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ def imshow(
326326
if binary_string:
327327
raise ValueError("Binary strings cannot be used with pandas arrays")
328328
is_dataframe = True
329-
img = img.to_numpy()
330329
else:
331330
is_dataframe = False
332331

packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def create_hexbin_mapbox(
407407
center = dict(lat=lat_range.mean(), lon=lon_range.mean())
408408

409409
if args["animation_frame"] is not None:
410-
groups = dict(args["data_frame"].group_by(args["animation_frame"]).__iter__())
410+
groups = dict(args["data_frame"].group_by(args["animation_frame"], drop_null_keys=True).__iter__())
411411
else:
412412
groups = {(0,): args["data_frame"]}
413413

0 commit comments

Comments
 (0)