Skip to content

Commit 3c46a4d

Browse files
authored
Add to_bigquery() function to BigQueryRetrievalJob (feast-dev#1634)
Add to_bigquery() function for the BigQuery retrieval job.

Squashed commits:
* Add to_bigquery() function for bq retrieval job
* Use tenacity for retries
* Refactor to_bigquery() function
* Add tenacity dependency; change temp table prefix to "historical"
* Use self.client instead of creating a new client
* Pin tenacity to major version
* Widen tenacity dependency range

Signed-off-by: Vivian Tao <vivian.tao@shopify.com>
1 parent d71b452 commit 3c46a4d

3 files changed

Lines changed: 48 additions & 4 deletions

File tree

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import time
2+
import uuid
23
from dataclasses import asdict, dataclass
3-
from datetime import datetime, timedelta
4+
from datetime import date, datetime, timedelta
45
from typing import List, Optional, Set, Union
56

67
import pandas
78
import pyarrow
89
from jinja2 import BaseLoader, Environment
10+
from tenacity import retry, stop_after_delay, wait_fixed
911

1012
from feast import errors
1113
from feast.data_source import BigQuerySource, DataSource
@@ -118,7 +120,7 @@ def get_historical_features(
118120
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
119121
)
120122

121-
job = BigQueryRetrievalJob(query=query, client=client)
123+
job = BigQueryRetrievalJob(query=query, client=client, config=config)
122124
return job
123125

124126

@@ -206,15 +208,41 @@ def _infer_event_timestamp_from_dataframe(entity_df: pandas.DataFrame) -> str:
206208

207209

208210
class BigQueryRetrievalJob(RetrievalJob):
    """Historical-retrieval job backed by a rendered BigQuery SQL query.

    The query is not executed at construction time; it runs when the
    caller asks for a result via to_df() or to_bigquery().
    """

    def __init__(self, query, client, config):
        # query: fully rendered SQL string for the historical retrieval.
        # client: an authenticated google.cloud.bigquery.Client.
        # config: repo config; config.offline_store.dataset names the
        #         dataset that receives materialized result tables.
        self.query = query
        self.client = client
        self.config = config

    def to_df(self):
        """Run the query and return the result as a pandas DataFrame."""
        # TODO: Ideally only start this job when the user runs
        # "get_historical_features", not when they run to_df()
        df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True)
        return df

    def to_bigquery(self, dry_run=False) -> Optional[str]:
        """Execute the query, writing the result to a new BigQuery table.

        Args:
            dry_run: if True, no table is created; the estimated bytes
                the query would process are printed and None is returned.

        Returns:
            The fully qualified path ("project.dataset.table") of the
            created table, or None for a dry run.

        Raises:
            The job's own exception if the query failed, or the last
            _JobStillRunning poll error (via reraise=True) if the job
            has not finished within 30 minutes.
        """
        # tenacity is already a module-level dependency of this file;
        # retry_if_exception_type is imported locally because the
        # top-of-file import list only brings in retry/stop/wait helpers.
        from tenacity import retry_if_exception_type

        class _JobStillRunning(Exception):
            """Internal sentinel: the BigQuery job has not finished yet."""

        # BUG FIX: tenacity's @retry only re-invokes the wrapped function
        # when it *raises*. The original returned a bool, so the poll ran
        # exactly once and never actually waited for the job — the method
        # could return a table path before the table existed. Raising
        # while the job is PENDING/RUNNING makes the retry loop fire.
        @retry(
            wait=wait_fixed(10),
            stop=stop_after_delay(1800),  # give up after 30 minutes
            retry=retry_if_exception_type(_JobStillRunning),
            reraise=True,
        )
        def _block_until_done():
            if self.client.get_job(bq_job.job_id).state in ["PENDING", "RUNNING"]:
                raise _JobStillRunning()

        # Unique destination table: historical_<YYYYMMDD>_<7-char uuid>.
        today = date.today().strftime("%Y%m%d")
        rand_id = str(uuid.uuid4())[:7]
        path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
        job_config = bigquery.QueryJobConfig(destination=path, dry_run=dry_run)
        bq_job = self.client.query(self.query, job_config=job_config)

        if dry_run:
            print(
                "This query will process {} bytes.".format(bq_job.total_bytes_processed)
            )
            return None

        _block_until_done()

        # Surface any failure recorded on the finished job.
        if bq_job.exception():
            raise bq_job.exception()

        print(f"Done writing to '{path}'.")
        return path
218246

219247
@dataclass(frozen=True)
220248
class FeatureViewQueryContext:

sdk/python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"pydantic>=1.0.0",
5555
"PyYAML==5.3.*",
5656
"tabulate==0.8.*",
57+
"tenacity>=7.*",
5758
"toml==0.10.*",
5859
"tqdm==4.*",
5960
]
@@ -92,7 +93,6 @@
9293
"pytest-mock==1.10.4",
9394
"Sphinx!=4.0.0",
9495
"sphinx-rtd-theme",
95-
"tenacity",
9696
"adlfs==0.5.9",
9797
"firebase-admin==4.5.2",
9898
"pre-commit",

sdk/python/tests/test_historical_retrieval.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,22 @@ def test_historical_features_from_bigquery_sources(
441441
],
442442
)
443443

444+
# Just a dry run, should not create table
445+
bq_dry_run = job_from_sql.to_bigquery(dry_run=True)
446+
assert bq_dry_run is None
447+
448+
bq_temp_table_path = job_from_sql.to_bigquery()
449+
assert bq_temp_table_path.split(".")[0] == gcp_project
450+
451+
if provider_type == "gcp_custom_offline_config":
452+
assert bq_temp_table_path.split(".")[1] == "foo"
453+
else:
454+
assert bq_temp_table_path.split(".")[1] == bigquery_dataset
455+
456+
# Check that this table actually exists
457+
actual_bq_temp_table = bigquery.Client().get_table(bq_temp_table_path)
458+
assert actual_bq_temp_table.table_id == bq_temp_table_path.split(".")[-1]
459+
444460
start_time = datetime.utcnow()
445461
actual_df_from_sql_entities = job_from_sql.to_df()
446462
end_time = datetime.utcnow()

0 commit comments

Comments
 (0)