Skip to content

Commit f80f05f

Browse files
vmallya-123adchia
andauthored
feat: Adding billing_project_id in BigQueryOfflineStoreConfig (feast-dev#3253)
* adding_billing_project_in_config Signed-off-by: “Varun <varun.mallya@tech.jago.com> * Fix lint Signed-off-by: Danny Chiao <danny@tecton.ai> Signed-off-by: “Varun <varun.mallya@tech.jago.com> Signed-off-by: Danny Chiao <danny@tecton.ai> Co-authored-by: Danny Chiao <danny@tecton.ai>
1 parent 53dc811 commit f80f05f

1 file changed

Lines changed: 44 additions & 16 deletions

File tree

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pandas as pd
2020
import pyarrow
2121
import pyarrow.parquet
22-
from pydantic import StrictStr
22+
from pydantic import StrictStr, validator
2323
from pydantic.typing import Literal
2424
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
2525

@@ -83,7 +83,8 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
8383

8484
project_id: Optional[StrictStr] = None
8585
""" (optional) GCP project name used for the BigQuery offline store """
86-
86+
billing_project_id: Optional[StrictStr] = None
87+
""" (optional) GCP project name used to run the bigquery jobs at """
8788
location: Optional[StrictStr] = None
8889
""" (optional) GCP location name used for the BigQuery offline store.
8990
Examples of location names include ``US``, ``EU``, ``us-central1``, ``us-west4``.
@@ -94,6 +95,14 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
9495
gcs_staging_location: Optional[str] = None
9596
""" (optional) GCS location used for offloading BigQuery results as parquet files."""
9697

98+
@validator("billing_project_id")
99+
def project_id_exists(cls, v, values, **kwargs):
100+
if v and not values["project_id"]:
101+
raise ValueError(
102+
"please specify project_id if billing_project_id is specified"
103+
)
104+
return v
105+
97106

98107
class BigQueryOfflineStore(OfflineStore):
99108
@staticmethod
@@ -122,9 +131,11 @@ def pull_latest_from_table_or_query(
122131
timestamps.append(created_timestamp_column)
123132
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
124133
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
125-
134+
project_id = (
135+
config.offline_store.billing_project_id or config.offline_store.project_id
136+
)
126137
client = _get_bigquery_client(
127-
project=config.offline_store.project_id,
138+
project=project_id,
128139
location=config.offline_store.location,
129140
)
130141
query = f"""
@@ -162,9 +173,11 @@ def pull_all_from_table_or_query(
162173
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
163174
assert isinstance(data_source, BigQuerySource)
164175
from_expression = data_source.get_table_query_string()
165-
176+
project_id = (
177+
config.offline_store.billing_project_id or config.offline_store.project_id
178+
)
166179
client = _get_bigquery_client(
167-
project=config.offline_store.project_id,
180+
project=project_id,
168181
location=config.offline_store.location,
169182
)
170183
field_string = ", ".join(
@@ -197,17 +210,22 @@ def get_historical_features(
197210
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
198211
for fv in feature_views:
199212
assert isinstance(fv.batch_source, BigQuerySource)
200-
213+
project_id = (
214+
config.offline_store.billing_project_id or config.offline_store.project_id
215+
)
201216
client = _get_bigquery_client(
202-
project=config.offline_store.project_id,
217+
project=project_id,
203218
location=config.offline_store.location,
204219
)
205220

206221
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
207-
222+
if config.offline_store.billing_project_id:
223+
dataset_project = str(config.offline_store.project_id)
224+
else:
225+
dataset_project = client.project
208226
table_reference = _get_table_reference_for_new_entity(
209227
client,
210-
client.project,
228+
dataset_project,
211229
config.offline_store.dataset,
212230
config.offline_store.location,
213231
)
@@ -295,9 +313,11 @@ def write_logged_features(
295313
):
296314
destination = logging_config.destination
297315
assert isinstance(destination, BigQueryLoggingDestination)
298-
316+
project_id = (
317+
config.offline_store.billing_project_id or config.offline_store.project_id
318+
)
299319
client = _get_bigquery_client(
300-
project=config.offline_store.project_id,
320+
project=project_id,
301321
location=config.offline_store.location,
302322
)
303323

@@ -353,9 +373,11 @@ def offline_write_batch(
353373

354374
if table.schema != pa_schema:
355375
table = table.cast(pa_schema)
356-
376+
project_id = (
377+
config.offline_store.billing_project_id or config.offline_store.project_id
378+
)
357379
client = _get_bigquery_client(
358-
project=config.offline_store.project_id,
380+
project=project_id,
359381
location=config.offline_store.location,
360382
)
361383

@@ -451,7 +473,10 @@ def to_bigquery(
451473
if not job_config:
452474
today = date.today().strftime("%Y%m%d")
453475
rand_id = str(uuid.uuid4())[:7]
454-
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
476+
if self.config.offline_store.billing_project_id:
477+
path = f"{self.config.offline_store.project_id}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
478+
else:
479+
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
455480
job_config = bigquery.QueryJobConfig(destination=path)
456481

457482
if not job_config.dry_run and self.on_demand_feature_views:
@@ -525,7 +550,10 @@ def to_remote_storage(self) -> List[str]:
525550

526551
bucket: str
527552
prefix: str
528-
storage_client = StorageClient(project=self.client.project)
553+
if self.config.offline_store.billing_project_id:
554+
storage_client = StorageClient(project=self.config.offline_store.project_id)
555+
else:
556+
storage_client = StorageClient(project=self.client.project)
529557
bucket, prefix = self._gcs_path[len("gs://") :].split("/", 1)
530558
prefix = prefix.rsplit("/", 1)[0]
531559
if prefix.startswith("/"):

0 commit comments

Comments
 (0)