Skip to content

Commit 09bb705

Browse files
committed
Allow subset of schema to be passed into load_table_from_dataframe.
The types of any remaining columns will be auto-detected.
1 parent a6126b7 commit 09bb705

File tree

3 files changed

+118
-20
lines changed

3 files changed

+118
-20
lines changed

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,37 +187,50 @@ def bq_to_arrow_array(series, bq_field):
187187
return pyarrow.array(series, type=arrow_type)
188188

189189

190-
def dataframe_to_bq_schema(dataframe):
190+
def dataframe_to_bq_schema(dataframe, bq_schema):
191191
"""Convert a pandas DataFrame schema to a BigQuery schema.
192192
193-
TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
194-
schema for a subset of columns.
195-
196193
Args:
197194
dataframe (pandas.DataFrame):
198-
DataFrame to convert to convert to Parquet file.
195+
DataFrame for which the client determines the BigQuery schema.
196+
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
197+
A BigQuery schema, or None. Use this argument to override the
198+
autodetected type for some or all of the DataFrame columns.
199199
200200
Returns:
201201
Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
202202
The schema, combining any user-provided fields with the
203203
automatically determined ones. Returns None if the type of any remaining column cannot be determined.
204204
"""
205-
bq_schema = []
205+
if bq_schema:
206+
bq_schema_index = {field.name: field for field in bq_schema}
207+
else:
208+
bq_schema_index = {}
209+
210+
bq_schema_out = []
206211
for column, dtype in zip(dataframe.columns, dataframe.dtypes):
212+
# Use provided type from schema, if present.
213+
bq_field = bq_schema_index.get(column)
214+
if bq_field:
215+
bq_schema_out.append(bq_field)
216+
continue
217+
218+
# Otherwise, try to automatically determine the type based on the
219+
# pandas dtype.
207220
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
208221
if not bq_type:
209222
return None
210223
bq_field = schema.SchemaField(column, bq_type)
211-
bq_schema.append(bq_field)
212-
return tuple(bq_schema)
224+
bq_schema_out.append(bq_field)
225+
return tuple(bq_schema_out)
213226

214227

215228
def dataframe_to_arrow(dataframe, bq_schema):
216229
"""Convert pandas dataframe to Arrow table, using BigQuery schema.
217230
218231
Args:
219232
dataframe (pandas.DataFrame):
220-
DataFrame to convert to convert to Parquet file.
233+
DataFrame to convert to Arrow table.
221234
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
222235
Desired BigQuery schema. Number of columns must match number of
223236
columns in the DataFrame.
@@ -255,7 +268,7 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN
255268
256269
Args:
257270
dataframe (pandas.DataFrame):
258-
DataFrame to convert to convert to Parquet file.
271+
DataFrame to convert to Parquet file.
259272
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
260273
Desired BigQuery schema. Number of columns must match number of
261274
columns in the DataFrame.

bigquery/google/cloud/bigquery/client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,14 +1531,15 @@ def load_table_from_dataframe(
15311531
if location is None:
15321532
location = self.location
15331533

1534-
if not job_config.schema:
1535-
autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)
1536-
1537-
# Only use an explicit schema if we were able to determine one
1538-
# matching the dataframe. If not, fallback to the pandas to_parquet
1539-
# method.
1540-
if autodetected_schema:
1541-
job_config.schema = autodetected_schema
1534+
autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(
1535+
dataframe, job_config.schema
1536+
)
1537+
1538+
# Only use an explicit schema if we were able to determine one
1539+
# matching the dataframe. If not, fallback to the pandas to_parquet
1540+
# method.
1541+
if autodetected_schema:
1542+
job_config.schema = autodetected_schema
15421543

15431544
tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
15441545
os.close(tmpfd)

bigquery/tests/unit/test_client.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5393,6 +5393,90 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
53935393
SchemaField("ts_col", "TIMESTAMP"),
53945394
)
53955395

5396+
@unittest.skipIf(pandas is None, "Requires `pandas`")
5397+
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
5398+
def test_load_table_from_dataframe_w_partial_automatic_schema(self):
5399+
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
5400+
from google.cloud.bigquery import job
5401+
from google.cloud.bigquery.schema import SchemaField
5402+
5403+
client = self._make_client()
5404+
dt_col = pandas.Series(
5405+
[
5406+
datetime.datetime(2010, 1, 2, 3, 44, 50),
5407+
datetime.datetime(2011, 2, 3, 14, 50, 59),
5408+
datetime.datetime(2012, 3, 14, 15, 16),
5409+
],
5410+
dtype="datetime64[ns]",
5411+
)
5412+
ts_col = pandas.Series(
5413+
[
5414+
datetime.datetime(2010, 1, 2, 3, 44, 50),
5415+
datetime.datetime(2011, 2, 3, 14, 50, 59),
5416+
datetime.datetime(2012, 3, 14, 15, 16),
5417+
],
5418+
dtype="datetime64[ns]",
5419+
).dt.tz_localize(pytz.utc)
5420+
df_data = {
5421+
"int_col": [1, 2, 3],
5422+
"int_as_float_col": [1.0, float("nan"), 3.0],
5423+
"float_col": [1.0, 2.0, 3.0],
5424+
"bool_col": [True, False, True],
5425+
"dt_col": dt_col,
5426+
"ts_col": ts_col,
5427+
"string_col": ["abc", "def", "ghi"],
5428+
}
5429+
dataframe = pandas.DataFrame(
5430+
df_data,
5431+
columns=[
5432+
"int_col",
5433+
"int_as_float_col",
5434+
"float_col",
5435+
"bool_col",
5436+
"dt_col",
5437+
"ts_col",
5438+
"string_col",
5439+
],
5440+
)
5441+
load_patch = mock.patch(
5442+
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
5443+
)
5444+
5445+
schema = (
5446+
SchemaField("int_as_float_col", "INTEGER"),
5447+
SchemaField("string_col", "STRING"),
5448+
)
5449+
job_config = job.LoadJobConfig(schema=schema)
5450+
with load_patch as load_table_from_file:
5451+
client.load_table_from_dataframe(
5452+
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
5453+
)
5454+
5455+
load_table_from_file.assert_called_once_with(
5456+
client,
5457+
mock.ANY,
5458+
self.TABLE_REF,
5459+
num_retries=_DEFAULT_NUM_RETRIES,
5460+
rewind=True,
5461+
job_id=mock.ANY,
5462+
job_id_prefix=None,
5463+
location=self.LOCATION,
5464+
project=None,
5465+
job_config=mock.ANY,
5466+
)
5467+
5468+
sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
5469+
assert sent_config.source_format == job.SourceFormat.PARQUET
5470+
assert tuple(sent_config.schema) == (
5471+
SchemaField("int_col", "INTEGER"),
5472+
SchemaField("int_as_float_col", "INTEGER"),
5473+
SchemaField("float_col", "FLOAT"),
5474+
SchemaField("bool_col", "BOOLEAN"),
5475+
SchemaField("dt_col", "DATETIME"),
5476+
SchemaField("ts_col", "TIMESTAMP"),
5477+
SchemaField("string_col", "STRING"),
5478+
)
5479+
53965480
@unittest.skipIf(pandas is None, "Requires `pandas`")
53975481
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
53985482
def test_load_table_from_dataframe_w_schema_wo_pyarrow(self):
@@ -5402,7 +5486,7 @@ def test_load_table_from_dataframe_w_schema_wo_pyarrow(self):
54025486

54035487
client = self._make_client()
54045488
records = [{"name": "Monty", "age": 100}, {"name": "Python", "age": 60}]
5405-
dataframe = pandas.DataFrame(records)
5489+
dataframe = pandas.DataFrame(records, columns=["name", "age"])
54065490
schema = (SchemaField("name", "STRING"), SchemaField("age", "INTEGER"))
54075491
job_config = job.LoadJobConfig(schema=schema)
54085492

@@ -5514,7 +5598,7 @@ def test_load_table_from_dataframe_w_nulls(self):
55145598

55155599
client = self._make_client()
55165600
records = [{"name": None, "age": None}, {"name": None, "age": None}]
5517-
dataframe = pandas.DataFrame(records)
5601+
dataframe = pandas.DataFrame(records, columns=["name", "age"])
55185602
schema = [SchemaField("name", "STRING"), SchemaField("age", "INTEGER")]
55195603
job_config = job.LoadJobConfig(schema=schema)
55205604

0 commit comments

Comments (0)