Skip to content

Commit 66038c7

Browse files
authored
chore: Implement to_remote_storage for supported offline stores (#2918)
* chore: Implement to_remote_storage for RedshiftRetrievalJob
* Implement to_remote_storage for BigQuery
* Add and fully flesh out to_remote_storage for Snowflake
* Better test config; fix tests (several iterations)
* Fix BigQuery export; use a temp table for the entity df table; simplify condition; remove temp table again

Signed-off-by: Achal Shah <achals@gmail.com>
1 parent 109ee9c commit 66038c7

File tree

8 files changed

+131
-5
lines changed

8 files changed

+131
-5
lines changed

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from google.cloud import bigquery
5858
from google.cloud.bigquery import Client, SchemaField, Table
5959
from google.cloud.bigquery._pandas_helpers import ARROW_SCALAR_IDS_TO_BQ
60+
from google.cloud.storage import Client as StorageClient
6061

6162
except ImportError as e:
6263
from feast.errors import FeastExtrasDependencyImportError
@@ -83,6 +84,9 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
8384
For more information on BigQuery data locations see: https://cloud.google.com/bigquery/docs/locations
8485
"""
8586

87+
gcs_staging_location: Optional[str] = None
88+
""" (optional) GCS location used for offloading BigQuery results as parquet files."""
89+
8690

8791
class BigQueryOfflineStore(OfflineStore):
8892
@staticmethod
@@ -386,6 +390,14 @@ def query_generator() -> Iterator[str]:
386390
on_demand_feature_views if on_demand_feature_views else []
387391
)
388392
self._metadata = metadata
393+
if self.config.offline_store.gcs_staging_location:
394+
self._gcs_path = (
395+
self.config.offline_store.gcs_staging_location
396+
+ f"/{self.config.project}/export/"
397+
+ str(uuid.uuid4())
398+
)
399+
else:
400+
self._gcs_path = None
389401

390402
@property
391403
def full_feature_names(self) -> bool:
@@ -478,6 +490,43 @@ def persist(self, storage: SavedDatasetStorage):
478490
def metadata(self) -> Optional[RetrievalMetadata]:
479491
return self._metadata
480492

493+
def supports_remote_storage_export(self) -> bool:
    """Report whether this job can offload results to GCS.

    Export is available only when ``gcs_staging_location`` was configured,
    in which case ``self._gcs_path`` holds the per-job export path.
    """
    if self._gcs_path is None:
        return False
    return True
495+
496+
def to_remote_storage(self) -> List[str]:
    """Export the query results to GCS as parquet files and return their URIs.

    Runs the retrieval query into a BigQuery table, extracts that table to
    ``self._gcs_path`` as parquet, then lists the written objects.

    Returns:
        List of ``gs://`` URIs of the exported parquet files.

    Raises:
        ValueError: if ``gcs_staging_location`` is not configured.
    """
    if not self._gcs_path:
        raise ValueError(
            "gcs_staging_location needs to be specified for the big query "
            "offline store when executing `to_remote_storage()`"
        )

    table = self.to_bigquery()

    job_config = bigquery.job.ExtractJobConfig()
    job_config.destination_format = "PARQUET"

    extract_job = self.client.extract_table(
        table,
        destination_uris=[f"{self._gcs_path}/*.parquet"],
        location=self.config.offline_store.location,
        job_config=job_config,
    )
    extract_job.result()

    storage_client = StorageClient(project=self.client.project)
    bucket, prefix = self._gcs_path[len("gs://") :].split("/", 1)
    # Fix: list only this job's export directory. The previous
    # `prefix = prefix.rsplit("/", 1)[0]` stripped the unique per-job
    # suffix, so list_blobs could also return parquet files written by
    # other exports that share the same staging location.
    if prefix.startswith("/"):
        prefix = prefix[1:]

    blobs = storage_client.list_blobs(bucket, prefix=prefix)
    return [f"gs://{b.bucket.name}/{b.name}" for b in blobs]
529+
481530

482531
def block_until_done(
483532
client: Client,

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def persist(self, storage: SavedDatasetStorage):
105105
def metadata(self) -> Optional[RetrievalMetadata]:
106106
return self._metadata
107107

108+
def supports_remote_storage_export(self) -> bool:
    """File-based retrieval jobs keep results local; remote export is unsupported."""
    return False
110+
108111

109112
class FileOfflineStore(OfflineStore):
110113
@staticmethod

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,13 @@ def persist(self, storage: SavedDatasetStorage):
490490
def metadata(self) -> Optional[RetrievalMetadata]:
491491
return self._metadata
492492

493+
def supports_remote_storage_export(self) -> bool:
    """Redshift always unloads through S3, so remote export is always available."""
    return True
495+
496+
def to_remote_storage(self) -> List[str]:
    """Unload the query results to S3 as parquet and return the file URIs.

    Delegates the UNLOAD to ``to_s3()`` and then enumerates the objects
    written under the resulting S3 prefix.
    """
    s3_prefix = self.to_s3()
    region = self._config.offline_store.region
    return aws_utils.list_s3_files(region, s3_prefix)
499+
493500

494501
def _upload_entity_df(
495502
entity_df: Union[pd.DataFrame, str],

sdk/python/feast/infra/offline_stores/snowflake.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import os
3+
import uuid
34
from datetime import datetime
45
from pathlib import Path
56
from typing import (
@@ -90,6 +91,12 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
9091
schema_: Optional[str] = Field(None, alias="schema")
9192
""" Snowflake schema name """
9293

94+
storage_integration_name: Optional[str] = None
95+
""" Storage integration name in snowflake """
96+
97+
blob_export_location: Optional[str] = None
98+
""" Location (in S3, Google storage or Azure storage) where data is offloaded """
99+
93100
class Config:
94101
allow_population_by_field_name = True
95102

@@ -378,6 +385,11 @@ def query_generator() -> Iterator[str]:
378385
on_demand_feature_views if on_demand_feature_views else []
379386
)
380387
self._metadata = metadata
388+
self.export_path: Optional[str]
389+
if self.config.offline_store.blob_export_location:
390+
self.export_path = f"{self.config.offline_store.blob_export_location}/{self.config.project}/{uuid.uuid4()}"
391+
else:
392+
self.export_path = None
381393

382394
@property
383395
def full_feature_names(self) -> bool:
@@ -413,7 +425,7 @@ def _to_arrow_internal(self) -> pa.Table:
413425
pd.DataFrame(columns=[md.name for md in empty_result.description])
414426
)
415427

416-
def to_snowflake(self, table_name: str) -> None:
428+
def to_snowflake(self, table_name: str, temporary=False) -> None:
417429
"""Save dataset as a new Snowflake table"""
418430
if self.on_demand_feature_views is not None:
419431
transformed_df = self.to_df()
@@ -425,7 +437,7 @@ def to_snowflake(self, table_name: str) -> None:
425437
return None
426438

427439
with self._query_generator() as query:
428-
query = f'CREATE TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
440+
query = f'CREATE {"TEMPORARY" if temporary else ""} TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
429441

430442
execute_snowflake_statement(self.snowflake_conn, query)
431443

@@ -453,6 +465,41 @@ def persist(self, storage: SavedDatasetStorage):
453465
def metadata(self) -> Optional[RetrievalMetadata]:
454466
return self._metadata
455467

468+
def supports_remote_storage_export(self) -> bool:
    """True when both a storage integration and an export location are configured.

    Coerces the result with ``bool`` so the declared return type holds —
    the previous bare ``and`` expression leaked the raw config string (or
    ``None``) to callers instead of a boolean.
    """
    offline_config = self.config.offline_store
    return bool(
        offline_config.storage_integration_name
        and offline_config.blob_export_location
    )
473+
474+
def to_remote_storage(self) -> List[str]:
    """COPY the query results into blob storage as parquet and return file URIs.

    Materializes the results into a uniquely named Snowflake table, then
    COPYs that table into ``self.export_path`` via the configured storage
    integration.

    Returns:
        List of blob-storage URIs of the exported parquet files.

    Raises:
        ValueError: if ``blob_export_location`` or
            ``storage_integration_name`` is missing from the config.
    """
    if not self.export_path:
        raise ValueError(
            "to_remote_storage() requires `blob_export_location` to be specified in config"
        )
    if not self.config.offline_store.storage_integration_name:
        raise ValueError(
            "to_remote_storage() requires `storage_integration_name` to be specified in config"
        )

    table = f"temporary_{uuid.uuid4().hex}"
    self.to_snowflake(table)

    # Fix: COPY into the same per-job export path that the returned URIs
    # are built from. Previously the COPY target was
    # `{blob_export_location}/{table}` while the return value joined
    # `self.export_path` with each FILE_NAME, so the returned paths did
    # not point at the files actually written.
    copy_into_query = f"""copy into '{self.export_path}' from "{self.config.offline_store.database}"."{self.config.offline_store.schema_}"."{table}"\n
      storage_integration = {self.config.offline_store.storage_integration_name}\n
      file_format = (TYPE = PARQUET)\n
      DETAILED_OUTPUT = TRUE\n
      HEADER = TRUE;\n
    """

    cursor = execute_snowflake_statement(self.snowflake_conn, copy_into_query)
    # NOTE(review): fetchall() loads every result row at once; this may
    # need pagination if exports ever produce very large file lists.
    all_rows = cursor.fetchall()
    file_name_column_index = [
        idx for idx, rm in enumerate(cursor.description) if rm.name == "FILE_NAME"
    ][0]
    return [f"{self.export_path}/{row[file_name_column_index]}" for row in all_rows]
502+
456503

457504
def _get_entity_schema(
458505
entity_df: Union[pd.DataFrame, str],

sdk/python/feast/infra/utils/aws_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tempfile
44
import uuid
55
from pathlib import Path
6-
from typing import Any, Dict, Iterator, Optional, Tuple, Union
6+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
77

88
import pandas as pd
99
import pyarrow
@@ -473,7 +473,7 @@ def execute_redshift_query_and_unload_to_s3(
473473
# Run the query, unload the results to S3
474474
unique_table_name = "_" + str(uuid.uuid4()).replace("-", "")
475475
query = f"CREATE TEMPORARY TABLE {unique_table_name} AS ({query});\n"
476-
query += f"UNLOAD ('SELECT * FROM {unique_table_name}') TO '{s3_path}/' IAM_ROLE '{iam_role}' PARQUET"
476+
query += f"UNLOAD ('SELECT * FROM {unique_table_name}') TO '{s3_path}/' IAM_ROLE '{iam_role}' FORMAT AS PARQUET"
477477
execute_redshift_statement(redshift_data_client, cluster_id, database, user, query)
478478

479479

@@ -632,3 +632,14 @@ def delete_api_gateway(api_gateway_client, api_gateway_id: str) -> Dict:
632632
def get_account_id() -> str:
633633
"""Get AWS Account ID"""
634634
return boto3.client("sts").get_caller_identity().get("Account")
635+
636+
637+
def list_s3_files(aws_region: str, path: str) -> List[str]:
    """List all objects under an S3 path and return their ``s3://`` URIs.

    Args:
        aws_region: AWS region the bucket lives in.
        path: Either ``s3://bucket/prefix`` or ``bucket/prefix``.

    Returns:
        Fully qualified ``s3://`` URIs of every object under the prefix
        (empty list when nothing matches).
    """
    s3 = boto3.client("s3", config=Config(region_name=aws_region))
    if path.startswith("s3://"):
        path = path[len("s3://") :]
    bucket, prefix = path.split("/", 1)
    # Fix: a single list_objects_v2 call is capped at 1000 keys and raises
    # KeyError on "Contents" when the prefix is empty; paginate and use
    # .get() so large and empty exports both work.
    files: List[str] = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for content in page.get("Contents", []):
            files.append(f"s3://{bucket}/{content['Key']}")
    return files

sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def teardown(self):
5151
self.dataset = None
5252

5353
def create_offline_store_config(self):
54-
return BigQueryOfflineStoreConfig()
54+
return BigQueryOfflineStoreConfig(
55+
location="US", gcs_staging_location="gs://feast-export/"
56+
)
5557

5658
def create_data_source(
5759
self,

sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def __init__(self, project_name: str, *args, **kwargs):
3434
warehouse=os.environ["SNOWFLAKE_CI_WAREHOUSE"],
3535
database="FEAST",
3636
schema="OFFLINE",
37+
storage_integration_name="FEAST_S3",
38+
blob_export_location="s3://feast-snowflake-offload/export",
3739
)
3840

3941
def create_data_source(

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
370370
full_feature_names=full_feature_names,
371371
)
372372

373+
if job_from_df.supports_remote_storage_export():
374+
files = job_from_df.to_remote_storage()
375+
print(files)
376+
assert len(files) > 0 # This test should be way more detailed
377+
373378
start_time = datetime.utcnow()
374379
actual_df_from_df_entities = job_from_df.to_df()
375380

0 commit comments

Comments
 (0)