
Commit c5a7cd2

Add to_arrow with support for Arrow data format. (googleapis#8644)
* BQ Storage: Add basic arrow stream parser
* BQ Storage: Add tests for to_dataframe with arrow data
* Add to_arrow with BQ Storage API.
1 parent aba3216 · commit c5a7cd2

5 files changed: 446 additions & 30 deletions
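
Before the per-file diffs, a minimal usage sketch of the new API, modeled on the system test added below. `client` (a BigQueryStorageClient) and `project_id` are assumed to already be defined; everything else mirrors this commit.

    from google.cloud import bigquery_storage_v1beta1

    table_ref = bigquery_storage_v1beta1.types.TableReference()
    table_ref.project_id = "bigquery-public-data"
    table_ref.dataset_id = "new_york_citibike"
    table_ref.table_id = "citibike_stations"

    # Request the Arrow data format; the session default is Avro.
    session = client.create_read_session(
        table_ref,
        "projects/{}".format(project_id),
        format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
        requested_streams=1,
    )
    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
        stream=session.streams[0]
    )

    # New in this commit: materialize the whole stream as a pyarrow.Table.
    table = client.read_rows(stream_pos).to_arrow(session)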


bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py

Lines changed: 153 additions & 11 deletions
@@ -27,16 +27,29 @@
     import pandas
 except ImportError:  # pragma: NO COVER
     pandas = None
+try:
+    import pyarrow
+except ImportError:  # pragma: NO COVER
+    pyarrow = None
 import six

+try:
+    import pyarrow
+except ImportError:  # pragma: NO COVER
+    pyarrow = None
+
 from google.cloud.bigquery_storage_v1beta1 import types


 _STREAM_RESUMPTION_EXCEPTIONS = (google.api_core.exceptions.ServiceUnavailable,)
+
 _FASTAVRO_REQUIRED = (
     "fastavro is required to parse ReadRowResponse messages with Avro bytes."
 )
 _PANDAS_REQUIRED = "pandas is required to create a DataFrame"
+_PYARROW_REQUIRED = (
+    "pyarrow is required to parse ReadRowResponse messages with Arrow bytes."
+)


 class ReadRowsStream(object):
@@ -113,7 +126,7 @@ def __iter__(self):
         while True:
             try:
                 for message in self._wrapped:
-                    rowcount = message.avro_rows.row_count
+                    rowcount = message.row_count
                     self._position.offset += rowcount
                     yield message

@@ -152,11 +165,28 @@ def rows(self, read_session):
             Iterable[Mapping]:
                 A sequence of rows, represented as dictionaries.
         """
-        if fastavro is None:
-            raise ImportError(_FASTAVRO_REQUIRED)
-
         return ReadRowsIterable(self, read_session)

+    def to_arrow(self, read_session):
+        """Create a :class:`pyarrow.Table` of all rows in the stream.
+
+        This method requires the pyarrow library and a stream using the Arrow
+        format.
+
+        Args:
+            read_session ( \
+                ~google.cloud.bigquery_storage_v1beta1.types.ReadSession \
+            ):
+                The read session associated with this read rows stream. This
+                contains the schema, which is required to parse the data
+                messages.
+
+        Returns:
+            pyarrow.Table:
+                A table of all rows in the stream.
+        """
+        return self.rows(read_session).to_arrow()
+
     def to_dataframe(self, read_session, dtypes=None):
         """Create a :class:`pandas.DataFrame` of all rows in the stream.

@@ -186,8 +216,6 @@ def to_dataframe(self, read_session, dtypes=None):
             pandas.DataFrame:
                 A data frame of all rows in the stream.
         """
-        if fastavro is None:
-            raise ImportError(_FASTAVRO_REQUIRED)
         if pandas is None:
             raise ImportError(_PANDAS_REQUIRED)

@@ -212,6 +240,7 @@ def __init__(self, reader, read_session):
         self._status = None
         self._reader = reader
         self._read_session = read_session
+        self._stream_parser = _StreamParser.from_read_session(self._read_session)

     @property
     def total_rows(self):
@@ -231,17 +260,31 @@ def pages(self):
         """
         # Each page is an iterator of rows. But also has num_items, remaining,
         # and to_dataframe.
-        stream_parser = _StreamParser(self._read_session)
         for message in self._reader:
             self._status = message.status
-            yield ReadRowsPage(stream_parser, message)
+            yield ReadRowsPage(self._stream_parser, message)

     def __iter__(self):
         """Iterator for each row in all pages."""
         for page in self.pages:
             for row in page:
                 yield row

+    def to_arrow(self):
+        """Create a :class:`pyarrow.Table` of all rows in the stream.
+
+        This method requires the pyarrow library and a stream using the Arrow
+        format.
+
+        Returns:
+            pyarrow.Table:
+                A table of all rows in the stream.
+        """
+        record_batches = []
+        for page in self.pages:
+            record_batches.append(page.to_arrow())
+        return pyarrow.Table.from_batches(record_batches)
+
     def to_dataframe(self, dtypes=None):
         """Create a :class:`pandas.DataFrame` of all rows in the stream.

@@ -291,8 +334,8 @@ def __init__(self, stream_parser, message):
         self._stream_parser = stream_parser
         self._message = message
         self._iter_rows = None
-        self._num_items = self._message.avro_rows.row_count
-        self._remaining = self._message.avro_rows.row_count
+        self._num_items = self._message.row_count
+        self._remaining = self._message.row_count

     def _parse_rows(self):
         """Parse rows from the message only once."""
@@ -326,6 +369,15 @@ def next(self):
     # Alias needed for Python 2/3 support.
     __next__ = next

+    def to_arrow(self):
+        """Create an :class:`pyarrow.RecordBatch` of rows in the page.
+
+        Returns:
+            pyarrow.RecordBatch:
+                Rows from the message, as an Arrow record batch.
+        """
+        return self._stream_parser.to_arrow(self._message)
+
     def to_dataframe(self, dtypes=None):
         """Create a :class:`pandas.DataFrame` of rows in the page.

@@ -355,21 +407,61 @@ def to_dataframe(self, dtypes=None):


 class _StreamParser(object):
+    def to_arrow(self, message):
+        raise NotImplementedError("Not implemented.")
+
+    def to_dataframe(self, message, dtypes=None):
+        raise NotImplementedError("Not implemented.")
+
+    def to_rows(self, message):
+        raise NotImplementedError("Not implemented.")
+
+    @staticmethod
+    def from_read_session(read_session):
+        schema_type = read_session.WhichOneof("schema")
+        if schema_type == "avro_schema":
+            return _AvroStreamParser(read_session)
+        elif schema_type == "arrow_schema":
+            return _ArrowStreamParser(read_session)
+        else:
+            raise TypeError(
+                "Unsupported schema type in read_session: {0}".format(schema_type)
+            )
+
+
+class _AvroStreamParser(_StreamParser):
     """Helper to parse Avro messages into useful representations."""

     def __init__(self, read_session):
-        """Construct a _StreamParser.
+        """Construct an _AvroStreamParser.

         Args:
             read_session (google.cloud.bigquery_storage_v1beta1.types.ReadSession):
                 A read session. This is required because it contains the schema
                 used in the stream messages.
         """
+        if fastavro is None:
+            raise ImportError(_FASTAVRO_REQUIRED)
+
         self._read_session = read_session
         self._avro_schema_json = None
         self._fastavro_schema = None
         self._column_names = None

+    def to_arrow(self, message):
+        """Create an :class:`pyarrow.RecordBatch` of rows in the page.
+
+        Args:
+            message (google.cloud.bigquery_storage_v1beta1.types.ReadRowsResponse):
+                Protocol buffer from the read rows stream, to convert into an
+                Arrow record batch.
+
+        Returns:
+            pyarrow.RecordBatch:
+                Rows from the message, as an Arrow record batch.
+        """
+        raise NotImplementedError("to_arrow not implemented for Avro streams.")
+
     def to_dataframe(self, message, dtypes=None):
         """Create a :class:`pandas.DataFrame` of rows in the page.

@@ -447,6 +539,56 @@ def to_rows(self, message):
                 break  # Finished with message


+class _ArrowStreamParser(_StreamParser):
+    def __init__(self, read_session):
+        if pyarrow is None:
+            raise ImportError(_PYARROW_REQUIRED)
+
+        self._read_session = read_session
+        self._schema = None
+
+    def to_arrow(self, message):
+        return self._parse_arrow_message(message)
+
+    def to_rows(self, message):
+        record_batch = self._parse_arrow_message(message)
+
+        # Iterate through each column simultaneously, and make a dict from the
+        # row values
+        for row in zip(*record_batch.columns):
+            yield dict(zip(self._column_names, row))
+
+    def to_dataframe(self, message, dtypes=None):
+        record_batch = self._parse_arrow_message(message)
+
+        if dtypes is None:
+            dtypes = {}
+
+        df = record_batch.to_pandas()
+
+        for column in dtypes:
+            df[column] = pandas.Series(df[column], dtype=dtypes[column])
+
+        return df
+
+    def _parse_arrow_message(self, message):
+        self._parse_arrow_schema()
+
+        return pyarrow.read_record_batch(
+            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
+            self._schema,
+        )
+
+    def _parse_arrow_schema(self):
+        if self._schema:
+            return
+
+        self._schema = pyarrow.read_schema(
+            pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
+        )
+        self._column_names = [field.name for field in self._schema]
+
+
 def _copy_stream_position(position):
     """Copy a StreamPosition.

bigquery_storage/noxfile.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def default(session):
     session.install('mock', 'pytest', 'pytest-cov')
     for local_dep in LOCAL_DEPS:
         session.install('-e', local_dep)
-    session.install('-e', '.[pandas,fastavro]')
+    session.install('-e', '.[pandas,fastavro,pyarrow]')

     # Run py.test against the unit tests.
     session.run(
@@ -121,7 +121,7 @@ def system(session):
     session.install('-e', os.path.join('..', 'test_utils'))
     for local_dep in LOCAL_DEPS:
         session.install('-e', local_dep)
-    session.install('-e', '.[pandas,fastavro]')
+    session.install('-e', '.[fastavro,pandas,pyarrow]')

     # Run py.test against the system tests.
     session.run('py.test', '--quiet', 'tests/system/')
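
Both nox sessions now install the `pyarrow` extra so the new code paths are exercised. Not part of this commit, but a common companion pattern: tests that need Arrow support can skip cleanly when the extra is missing, e.g. with `pytest.importorskip`:

    import pytest

    # Skips the whole module if the optional extra is not installed.
    pyarrow = pytest.importorskip("pyarrow")

    def test_arrow_available():
        # Default inference maps Python ints to int64, matching the
        # station_id expectation in the system tests below.
        assert pyarrow.array([1, 2]).type == pyarrow.int64()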

bigquery_storage/setup.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 extras = {
     'pandas': 'pandas>=0.17.1',
     'fastavro': 'fastavro>=0.21.2',
+    'pyarrow': 'pyarrow>=0.13.0',
 }

 package_root = os.path.abspath(os.path.dirname(__file__))
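
The new extra stays optional at import time: `reader.py` guards the import and raises only when an Arrow stream is actually parsed. The pattern, condensed from the `reader.py` hunks above:

    try:
        import pyarrow
    except ImportError:  # pragma: NO COVER
        pyarrow = None

    _PYARROW_REQUIRED = (
        "pyarrow is required to parse ReadRowResponse messages with Arrow bytes."
    )

    class _ArrowStreamParser(object):
        def __init__(self, read_session):
            # Fail at first use, not at package import, so installs
            # without the extra keep working.
            if pyarrow is None:
                raise ImportError(_PYARROW_REQUIRED)
            self._read_session = read_session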

bigquery_storage/tests/system/test_system.py

Lines changed: 70 additions & 1 deletion
@@ -18,6 +18,7 @@
 import os

 import numpy
+import pyarrow.types
 import pytest

 from google.cloud import bigquery_storage_v1beta1
@@ -67,14 +68,82 @@ def test_read_rows_full_table(client, project_id, small_table_reference):
     assert len(block.avro_rows.serialized_binary_rows) > 0


-def test_read_rows_to_dataframe(client, project_id):
+def test_read_rows_to_arrow(client, project_id):
+    table_ref = bigquery_storage_v1beta1.types.TableReference()
+    table_ref.project_id = "bigquery-public-data"
+    table_ref.dataset_id = "new_york_citibike"
+    table_ref.table_id = "citibike_stations"
+
+    read_options = bigquery_storage_v1beta1.types.TableReadOptions()
+    read_options.selected_fields.append("station_id")
+    read_options.selected_fields.append("latitude")
+    read_options.selected_fields.append("longitude")
+    read_options.selected_fields.append("name")
+    session = client.create_read_session(
+        table_ref,
+        "projects/{}".format(project_id),
+        format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
+        read_options=read_options,
+        requested_streams=1,
+    )
+    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
+        stream=session.streams[0]
+    )
+
+    tbl = client.read_rows(stream_pos).to_arrow(session)
+
+    assert tbl.num_columns == 4
+    schema = tbl.schema
+    # Use field_by_name because the order doesn't currently match that of
+    # selected_fields.
+    assert pyarrow.types.is_int64(schema.field_by_name("station_id").type)
+    assert pyarrow.types.is_float64(schema.field_by_name("latitude").type)
+    assert pyarrow.types.is_float64(schema.field_by_name("longitude").type)
+    assert pyarrow.types.is_string(schema.field_by_name("name").type)
+
+
+def test_read_rows_to_dataframe_w_avro(client, project_id):
     table_ref = bigquery_storage_v1beta1.types.TableReference()
     table_ref.project_id = "bigquery-public-data"
     table_ref.dataset_id = "new_york_citibike"
     table_ref.table_id = "citibike_stations"
     session = client.create_read_session(
         table_ref, "projects/{}".format(project_id), requested_streams=1
     )
+    schema_type = session.WhichOneof("schema")
+    assert schema_type == "avro_schema"
+
+    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
+        stream=session.streams[0]
+    )
+
+    frame = client.read_rows(stream_pos).to_dataframe(
+        session, dtypes={"latitude": numpy.float16}
+    )
+
+    # Station ID is a required field (no nulls), so the datatype should always
+    # be integer.
+    assert frame.station_id.dtype.name == "int64"
+    assert frame.latitude.dtype.name == "float16"
+    assert frame.longitude.dtype.name == "float64"
+    assert frame["name"].str.startswith("Central Park").any()
+
+
+def test_read_rows_to_dataframe_w_arrow(client, project_id):
+    table_ref = bigquery_storage_v1beta1.types.TableReference()
+    table_ref.project_id = "bigquery-public-data"
+    table_ref.dataset_id = "new_york_citibike"
+    table_ref.table_id = "citibike_stations"
+
+    session = client.create_read_session(
+        table_ref,
+        "projects/{}".format(project_id),
+        format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
+        requested_streams=1,
+    )
+    schema_type = session.WhichOneof("schema")
+    assert schema_type == "arrow_schema"
+
     stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
         stream=session.streams[0]
     )
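
The `float16` assertion in the Avro test holds because `to_dataframe` applies the `dtypes` mapping after conversion, re-casting each requested column; the same loop appears in `_ArrowStreamParser.to_dataframe` in the `reader.py` diff above. A self-contained sketch of that override step (sample values invented for illustration):

    import numpy
    import pandas

    df = pandas.DataFrame({"station_id": [1, 2], "latitude": [40.76, 40.74]})
    dtypes = {"latitude": numpy.float16}

    # Re-cast only the columns named in the dtypes mapping.
    for column in dtypes:
        df[column] = pandas.Series(df[column], dtype=dtypes[column])

    assert df.latitude.dtype.name == "float16"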
