1818
1919import dataclasses
2020import functools
21+ from typing import cast , Union
2122import uuid
2223
24+ import geopandas # type: ignore
25+ import numpy as np
26+ import pandas
2327import pyarrow as pa
2428
2529import bigframes .core .schema as schemata
@@ -32,51 +36,113 @@ class LocalTableMetadata:
3236 row_count : int
3337
3438 @classmethod
35- def from_arrow (cls , table : pa .Table ):
39+ def from_arrow (cls , table : pa .Table ) -> LocalTableMetadata :
3640 return cls (total_bytes = table .nbytes , row_count = table .num_rows )
3741
3842
# Bigframes dtypes whose local (managed) arrow storage type differs from the
# canonical bigframes -> arrow conversion.
_MANAGED_STORAGE_TYPES_OVERRIDES: dict[bigframes.dtypes.Dtype, pa.DataType] = {
    # Geography is stored as WKT strings (wkt to be precise), not a native
    # arrow geo type.
    bigframes.dtypes.GEO_DTYPE: pa.string()
}
47+
48+
@dataclasses.dataclass(frozen=True)
class ManagedArrowTable:
    """A local pyarrow table paired with its bigframes schema.

    The pair is validated at construction time (see ``validate``). ``data``
    and ``schema`` are excluded from hashing; instance identity is carried by
    the randomly generated ``id``.
    """

    # The arrow payload; column storage types must match ``schema``.
    data: pa.Table = dataclasses.field(hash=False)
    # Bigframes-level schema describing the logical column types.
    schema: schemata.ArraySchema = dataclasses.field(hash=False)
    # Random identity token; the only field that participates in hashing.
    id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)

    def __post_init__(self):
        # Fail fast if data and schema disagree.
        self.validate()

    @functools.cached_property
    def metadata(self) -> LocalTableMetadata:
        """Size and row-count metadata, computed once per instance."""
        return LocalTableMetadata.from_arrow(self.data)

    @classmethod
    def from_pandas(cls, dataframe: pandas.DataFrame) -> ManagedArrowTable:
        """Creates managed table from pandas. Ignores index, col names must be unique strings

        Raises:
            ValueError: if column names are not unique.
        """
        columns: list[pa.ChunkedArray] = []
        fields: list[schemata.SchemaItem] = []
        column_names = list(dataframe.columns)
        # Explicit check rather than assert: asserts are stripped under -O.
        if len(column_names) != len(set(column_names)):
            raise ValueError("DataFrame column names must be unique")

        for name, col in dataframe.items():
            new_arr, bf_type = _adapt_pandas_series(col)
            columns.append(new_arr)
            fields.append(schemata.SchemaItem(str(name), bf_type))

        return cls(
            pa.table(columns, names=column_names),
            schemata.ArraySchema(tuple(fields)),
        )

    @classmethod
    def from_pyarrow(cls, table: pa.Table) -> ManagedArrowTable:
        """Creates a managed table from a pyarrow table, coercing each column
        to its bigframes-compatible storage type."""
        columns: list[pa.ChunkedArray] = []
        fields: list[schemata.SchemaItem] = []
        for name, arr in zip(table.column_names, table.columns):
            new_arr, bf_type = _adapt_arrow_array(arr)
            columns.append(new_arr)
            fields.append(schemata.SchemaItem(name, bf_type))

        return cls(
            pa.table(columns, names=table.column_names),
            schemata.ArraySchema(tuple(fields)),
        )

    def validate(self):
        """Check that each arrow column's type matches the managed storage
        type expected for its declared bigframes dtype.

        Raises:
            TypeError: if any column's arrow type disagrees with the schema.
        """
        # TODO: Content-based validation for some datatypes (eg json, wkt, list)
        # where logical domain is smaller than pyarrow type
        for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
            expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
            arrow_type = arrow_field.type
            if expected_arrow_type != arrow_type:
                raise TypeError(
                    f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
                )
71102
72103
73- def bigframes_type_for_arrow_type (pa_type : pa .DataType ) -> bigframes .dtypes .Dtype :
74- return bigframes .dtypes .arrow_dtype_to_bigframes_dtype (
75- arrow_type_replacements (pa_type )
76- )
def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
    """Return the pyarrow type used to store ``dtype`` in a managed table.

    Consults the override table first (e.g. geography stored as WKT strings),
    otherwise falls back to the canonical bigframes -> arrow conversion.
    """
    # Membership test on the dict directly; `.keys()` was redundant.
    if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES:
        return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]
    return bigframes.dtypes.bigframes_dtype_to_arrow_dtype(dtype)
109+
110+
def _adapt_pandas_series(
    series: pandas.Series,
) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
    """Convert a pandas series to an arrow array plus its bigframes dtype.

    Mostly rely on pyarrow conversions, but have to convert geo without its help.

    Raises:
        Exception: whatever pyarrow raised, when the series cannot be adapted.
    """
    if series.dtype == bigframes.dtypes.GEO_DTYPE:
        # Geography is stored as WKT text; rounding_precision=-1 preserves
        # full coordinate precision.
        series = geopandas.GeoSeries(series).to_wkt(rounding_precision=-1)
        return pa.array(series, type=pa.string()), bigframes.dtypes.GEO_DTYPE
    try:
        return _adapt_arrow_array(pa.array(series))
    except Exception as e:
        # Object-dtype columns may hold geometries that pyarrow cannot convert
        # directly; retry through the geo path before propagating the error.
        # (Previously the astype result was discarded and the error raised
        # unconditionally, making the cast dead code.)
        if series.dtype == np.dtype("O"):
            try:
                series = series.astype(bigframes.dtypes.GEO_DTYPE)
            except TypeError:
                pass
            else:
                return _adapt_pandas_series(series)
        raise e
127+
128+
def _adapt_arrow_array(
    array: Union[pa.ChunkedArray, pa.Array]
) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
    """Coerce an arrow array to a bigframes-supported type and infer its dtype.

    Raises:
        TypeError: if the resulting arrow type is not the managed storage
            type expected for the inferred bigframes dtype.
    """
    replacement_type = _arrow_type_replacements(array.type)
    if replacement_type != array.type:
        # TODO: Maybe warn if lossy conversion?
        array = array.cast(replacement_type)
    bf_type = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(replacement_type)

    expected_storage = _get_managed_storage_type(bf_type)
    if expected_storage != array.type:
        raise TypeError(
            f"Expected {bf_type} to use arrow {expected_storage}, instead got {array.type}"
        )
    return array, bf_type
77143
78144
79- def arrow_type_replacements (type : pa .DataType ) -> pa .DataType :
145+ def _arrow_type_replacements (type : pa .DataType ) -> pa .DataType :
80146 if pa .types .is_timestamp (type ):
81147 # This is potentially lossy, but BigFrames doesn't support ns
82148 new_tz = "UTC" if (type .tz is not None ) else None
@@ -91,18 +157,24 @@ def arrow_type_replacements(type: pa.DataType) -> pa.DataType:
91157 return pa .decimal128 (38 , 9 )
92158 if pa .types .is_decimal256 (type ):
93159 return pa .decimal256 (76 , 38 )
94- if pa .types .is_dictionary (type ):
95- return arrow_type_replacements (type .value_type )
96160 if pa .types .is_large_string (type ):
97161 # simple string type can handle the largest strings needed
98162 return pa .string ()
99163 if pa .types .is_null (type ):
100164 # null as a type not allowed, default type is float64 for bigframes
101165 return pa .float64 ()
102166 if pa .types .is_list (type ):
103- new_field_t = arrow_type_replacements (type .value_type )
167+ new_field_t = _arrow_type_replacements (type .value_type )
104168 if new_field_t != type .value_type :
105169 return pa .list_ (new_field_t )
106170 return type
171+ if pa .types .is_struct (type ):
172+ struct_type = cast (pa .StructType , type )
173+ new_fields : list [pa .Field ] = []
174+ for i in range (struct_type .num_fields ):
175+ field = struct_type .field (i )
176+ field .with_type (_arrow_type_replacements (field .type ))
177+ new_fields .append (field .with_type (_arrow_type_replacements (field .type )))
178+ return pa .struct (new_fields )
107179 else :
108180 return type
0 commit comments