Skip to content

Commit 3614218

Browse files
feat: Improve local data validation (#1598)
1 parent a4145a8 commit 3614218

File tree

7 files changed

+114
-71
lines changed

7 files changed

+114
-71
lines changed

packages/bigframes/bigframes/core/array_value.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,19 @@ class ArrayValue:
5858

5959
@classmethod
6060
def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
61-
adapted_table = local_data.adapt_pa_table(arrow_table)
62-
schema = local_data.arrow_schema_to_bigframes(adapted_table.schema)
61+
data_source = local_data.ManagedArrowTable.from_pyarrow(arrow_table)
62+
return cls.from_managed(source=data_source, session=session)
6363

64+
@classmethod
65+
def from_managed(cls, source: local_data.ManagedArrowTable, session: Session):
6466
scan_list = nodes.ScanList(
6567
tuple(
6668
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
67-
for item in schema.items
69+
for item in source.schema.items
6870
)
6971
)
70-
data_source = local_data.ManagedArrowTable(adapted_table, schema)
7172
node = nodes.ReadLocalNode(
72-
data_source,
73+
source,
7374
session=session,
7475
scan_list=scan_list,
7576
)

packages/bigframes/bigframes/core/blocks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from bigframes import session
5454
from bigframes._config import sampling_options
5555
import bigframes.constants
56+
from bigframes.core import local_data
5657
import bigframes.core as core
5758
import bigframes.core.compile.googlesql as googlesql
5859
import bigframes.core.expression as ex
@@ -187,8 +188,8 @@ def from_local(
187188

188189
pd_data = pd_data.set_axis(column_ids, axis=1)
189190
pd_data = pd_data.reset_index(names=index_ids)
190-
as_pyarrow = pa.Table.from_pandas(pd_data, preserve_index=False)
191-
array_value = core.ArrayValue.from_pyarrow(as_pyarrow, session=session)
191+
managed_data = local_data.ManagedArrowTable.from_pandas(pd_data)
192+
array_value = core.ArrayValue.from_managed(managed_data, session=session)
192193
block = cls(
193194
array_value,
194195
column_labels=column_labels,

packages/bigframes/bigframes/core/local_data.py

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818

1919
import dataclasses
2020
import functools
21+
from typing import cast, Union
2122
import uuid
2223

24+
import geopandas # type: ignore
25+
import numpy as np
26+
import pandas
2327
import pyarrow as pa
2428

2529
import bigframes.core.schema as schemata
@@ -32,51 +36,113 @@ class LocalTableMetadata:
3236
row_count: int
3337

3438
@classmethod
35-
def from_arrow(cls, table: pa.Table):
39+
def from_arrow(cls, table: pa.Table) -> LocalTableMetadata:
3640
return cls(total_bytes=table.nbytes, row_count=table.num_rows)
3741

3842

43+
_MANAGED_STORAGE_TYPES_OVERRIDES: dict[bigframes.dtypes.Dtype, pa.DataType] = {
44+
# wkt to be precise
45+
bigframes.dtypes.GEO_DTYPE: pa.string()
46+
}
47+
48+
3949
@dataclasses.dataclass(frozen=True)
4050
class ManagedArrowTable:
4151
data: pa.Table = dataclasses.field(hash=False)
4252
schema: schemata.ArraySchema = dataclasses.field(hash=False)
4353
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
4454

55+
def __post_init__(self):
56+
self.validate()
57+
4558
@functools.cached_property
46-
def metadata(self):
59+
def metadata(self) -> LocalTableMetadata:
4760
return LocalTableMetadata.from_arrow(self.data)
4861

49-
50-
def arrow_schema_to_bigframes(arrow_schema: pa.Schema) -> schemata.ArraySchema:
51-
"""Infer the corresponding bigframes schema given a pyarrow schema."""
52-
schema_items = tuple(
53-
schemata.SchemaItem(
54-
field.name,
55-
bigframes_type_for_arrow_type(field.type),
62+
@classmethod
63+
def from_pandas(cls, dataframe: pandas.DataFrame) -> ManagedArrowTable:
64+
"""Creates managed table from pandas. Ignores index, col names must be unique strings"""
65+
columns: list[pa.ChunkedArray] = []
66+
fields: list[schemata.SchemaItem] = []
67+
column_names = list(dataframe.columns)
68+
assert len(column_names) == len(set(column_names))
69+
70+
for name, col in dataframe.items():
71+
new_arr, bf_type = _adapt_pandas_series(col)
72+
columns.append(new_arr)
73+
fields.append(schemata.SchemaItem(str(name), bf_type))
74+
75+
return ManagedArrowTable(
76+
pa.table(columns, names=column_names), schemata.ArraySchema(tuple(fields))
5677
)
57-
for field in arrow_schema
58-
)
59-
return schemata.ArraySchema(schema_items)
6078

79+
@classmethod
80+
def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
81+
columns: list[pa.ChunkedArray] = []
82+
fields: list[schemata.SchemaItem] = []
83+
for name, arr in zip(table.column_names, table.columns):
84+
new_arr, bf_type = _adapt_arrow_array(arr)
85+
columns.append(new_arr)
86+
fields.append(schemata.SchemaItem(name, bf_type))
87+
88+
return ManagedArrowTable(
89+
pa.table(columns, names=table.column_names),
90+
schemata.ArraySchema(tuple(fields)),
91+
)
6192

62-
def adapt_pa_table(arrow_table: pa.Table) -> pa.Table:
63-
"""Adapt a pyarrow table to one that can be handled by bigframes. Converts tz to UTC and unit to us for temporal types."""
64-
new_schema = pa.schema(
65-
[
66-
pa.field(field.name, arrow_type_replacements(field.type))
67-
for field in arrow_table.schema
68-
]
69-
)
70-
return arrow_table.cast(new_schema)
93+
def validate(self):
94+
# TODO: Content-based validation for some datatypes (eg json, wkt, list) where logical domain is smaller than pyarrow type
95+
for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
96+
expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
97+
arrow_type = arrow_field.type
98+
if expected_arrow_type != arrow_type:
99+
raise TypeError(
100+
f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
101+
)
71102

72103

73-
def bigframes_type_for_arrow_type(pa_type: pa.DataType) -> bigframes.dtypes.Dtype:
74-
return bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
75-
arrow_type_replacements(pa_type)
76-
)
104+
def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
105+
if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES.keys():
106+
return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]
107+
else:
108+
return bigframes.dtypes.bigframes_dtype_to_arrow_dtype(dtype)
109+
110+
111+
def _adapt_pandas_series(
112+
series: pandas.Series,
113+
) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
114+
# Mostly rely on pyarrow conversions, but have to convert geo without its help.
115+
if series.dtype == bigframes.dtypes.GEO_DTYPE:
116+
series = geopandas.GeoSeries(series).to_wkt(rounding_precision=-1)
117+
return pa.array(series, type=pa.string()), bigframes.dtypes.GEO_DTYPE
118+
try:
119+
return _adapt_arrow_array(pa.array(series))
120+
except Exception as e:
121+
if series.dtype == np.dtype("O"):
122+
try:
123+
series = series.astype(bigframes.dtypes.GEO_DTYPE)
124+
except TypeError:
125+
pass
126+
raise e
127+
128+
129+
def _adapt_arrow_array(
130+
array: Union[pa.ChunkedArray, pa.Array]
131+
) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
132+
target_type = _arrow_type_replacements(array.type)
133+
if target_type != array.type:
134+
# TODO: Maybe warn if lossy conversion?
135+
array = array.cast(target_type)
136+
bf_type = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(target_type)
137+
storage_type = _get_managed_storage_type(bf_type)
138+
if storage_type != array.type:
139+
raise TypeError(
140+
f"Expected {bf_type} to use arrow {storage_type}, instead got {array.type}"
141+
)
142+
return array, bf_type
77143

78144

79-
def arrow_type_replacements(type: pa.DataType) -> pa.DataType:
145+
def _arrow_type_replacements(type: pa.DataType) -> pa.DataType:
80146
if pa.types.is_timestamp(type):
81147
# This is potentially lossy, but BigFrames doesn't support ns
82148
new_tz = "UTC" if (type.tz is not None) else None
@@ -91,18 +157,24 @@ def arrow_type_replacements(type: pa.DataType) -> pa.DataType:
91157
return pa.decimal128(38, 9)
92158
if pa.types.is_decimal256(type):
93159
return pa.decimal256(76, 38)
94-
if pa.types.is_dictionary(type):
95-
return arrow_type_replacements(type.value_type)
96160
if pa.types.is_large_string(type):
97161
# simple string type can handle the largest strings needed
98162
return pa.string()
99163
if pa.types.is_null(type):
100164
# null as a type not allowed, default type is float64 for bigframes
101165
return pa.float64()
102166
if pa.types.is_list(type):
103-
new_field_t = arrow_type_replacements(type.value_type)
167+
new_field_t = _arrow_type_replacements(type.value_type)
104168
if new_field_t != type.value_type:
105169
return pa.list_(new_field_t)
106170
return type
171+
if pa.types.is_struct(type):
172+
struct_type = cast(pa.StructType, type)
173+
new_fields: list[pa.Field] = []
174+
for i in range(struct_type.num_fields):
175+
field = struct_type.field(i)
176+
field.with_type(_arrow_type_replacements(field.type))
177+
new_fields.append(field.with_type(_arrow_type_replacements(field.type)))
178+
return pa.struct(new_fields)
107179
else:
108180
return type

packages/bigframes/bigframes/dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ def bigframes_dtype_to_arrow_dtype(
456456
if bigframes_dtype in _BIGFRAMES_TO_ARROW:
457457
return _BIGFRAMES_TO_ARROW[bigframes_dtype]
458458
if isinstance(bigframes_dtype, pd.ArrowDtype):
459+
if pa.types.is_duration(bigframes_dtype.pyarrow_dtype):
460+
return bigframes_dtype.pyarrow_dtype
459461
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
460462
return bigframes_dtype.pyarrow_dtype
461463
if pa.types.is_struct(bigframes_dtype.pyarrow_dtype):

packages/bigframes/bigframes/session/__init__.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,6 @@
108108

109109
logger = logging.getLogger(__name__)
110110

111-
NON_INLINABLE_DTYPES: Sequence[bigframes.dtypes.Dtype] = (
112-
# Currently excluded as doesn't have arrow type
113-
bigframes.dtypes.GEO_DTYPE,
114-
)
115-
116111

117112
class Session(
118113
third_party_pandas_gbq.GBQIOMixin,
@@ -838,17 +833,6 @@ def _read_pandas_inline(
838833
f"Could not convert with a BigQuery type: `{exc}`. "
839834
) from exc
840835

841-
# Make sure all types are inlinable to avoid escaping errors.
842-
inline_types = inline_df._block.expr.schema.dtypes
843-
noninlinable_types = [
844-
dtype for dtype in inline_types if dtype in NON_INLINABLE_DTYPES
845-
]
846-
if len(noninlinable_types) != 0:
847-
raise ValueError(
848-
f"Could not inline with a BigQuery type: `{noninlinable_types}`. "
849-
f"{constants.FEEDBACK_LINK}"
850-
)
851-
852836
return inline_df
853837

854838
def _read_pandas_load_job(

packages/bigframes/tests/unit/session/test_io_pandas.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import pandas.testing
2525
import pyarrow # type: ignore
2626
import pytest
27-
import shapely # type: ignore
2827

2928
import bigframes.core.schema
3029
import bigframes.features
@@ -504,17 +503,3 @@ def test_read_pandas_with_bigframes_dataframe():
504503
ValueError, match=re.escape("read_pandas() expects a pandas.DataFrame")
505504
):
506505
session.read_pandas(df)
507-
508-
509-
def test_read_pandas_inline_w_noninlineable_type_raises_error():
510-
session = resources.create_bigquery_session()
511-
data = [
512-
shapely.Point(1, 1),
513-
shapely.Point(2, 1),
514-
shapely.Point(1, 2),
515-
]
516-
s = pandas.Series(data, dtype=geopandas.array.GeometryDtype())
517-
with pytest.raises(
518-
ValueError, match="Could not (convert|inline) with a BigQuery type:"
519-
):
520-
session.read_pandas(s, write_engine="bigquery_inline")

packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -813,10 +813,8 @@ def visit_DefaultLiteral(self, op, *, value, dtype):
813813
elif dtype.is_json():
814814
return sge.ParseJSON(this=sge.convert(str(value)))
815815
elif dtype.is_geospatial():
816-
args = [value.wkt]
817-
if (srid := dtype.srid) is not None:
818-
args.append(srid)
819-
return self.f.st_geomfromtext(*args)
816+
wkt = value if isinstance(value, str) else value.wkt
817+
return self.f.st_geogfromtext(wkt)
820818

821819
raise NotImplementedError(f"Unsupported type: {dtype!r}")
822820

0 commit comments

Comments (0)