Skip to content

Commit 12dbbea

Browse files
MattDelac and Tsotne Tabidze
authored and committed
Add to_table() to RetrievalJob object (feast-dev#1663)
* Add notion of OfflineJob

  Signed-off-by: Matt Delacour <matt.delacour@shopify.com>

* Use RetrievalJob instead of creating a new OfflineJob object

  Signed-off-by: Matt Delacour <matt.delacour@shopify.com>

* Add to_table() in integration tests

  Signed-off-by: Matt Delacour <matt.delacour@shopify.com>

Co-authored-by: Tsotne Tabidze <tsotne@tecton.ai>
Signed-off-by: Mwad22 <51929507+Mwad22@users.noreply.github.com>
1 parent c02b9eb commit 12dbbea

10 files changed

Lines changed: 102 additions & 71 deletions

File tree

sdk/python/feast/feature_store.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,13 @@ def tqdm_builder(length):
397397
end_date = utils.make_tzaware(end_date)
398398

399399
provider.materialize_single_feature_view(
400-
feature_view,
401-
start_date,
402-
end_date,
403-
self._registry,
404-
self.project,
405-
tqdm_builder,
400+
config=self.config,
401+
feature_view=feature_view,
402+
start_date=start_date,
403+
end_date=end_date,
404+
registry=self._registry,
405+
project=self.project,
406+
tqdm_builder=tqdm_builder,
406407
)
407408

408409
self._registry.apply_materialization(
@@ -475,12 +476,13 @@ def tqdm_builder(length):
475476
end_date = utils.make_tzaware(end_date)
476477

477478
provider.materialize_single_feature_view(
478-
feature_view,
479-
start_date,
480-
end_date,
481-
self._registry,
482-
self.project,
483-
tqdm_builder,
479+
config=self.config,
480+
feature_view=feature_view,
481+
start_date=start_date,
482+
end_date=end_date,
483+
registry=self._registry,
484+
project=self.project,
485+
tqdm_builder=tqdm_builder,
484486
)
485487

486488
self._registry.apply_materialization(

sdk/python/feast/infra/gcp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def online_read(
8181

8282
def materialize_single_feature_view(
8383
self,
84+
config: RepoConfig,
8485
feature_view: FeatureView,
8586
start_date: datetime,
8687
end_date: datetime,
@@ -99,7 +100,8 @@ def materialize_single_feature_view(
99100
created_timestamp_column,
100101
) = _get_column_names(feature_view, entities)
101102

102-
table = self.offline_store.pull_latest_from_table_or_query(
103+
offline_job = self.offline_store.pull_latest_from_table_or_query(
104+
config=config,
103105
data_source=feature_view.input,
104106
join_key_columns=join_key_columns,
105107
feature_name_columns=feature_name_columns,
@@ -108,6 +110,7 @@ def materialize_single_feature_view(
108110
start_date=start_date,
109111
end_date=end_date,
110112
)
113+
table = offline_job.to_table()
111114

112115
if feature_view.input.field_mapping is not None:
113116
table = _run_field_mapping(table, feature_view.input.field_mapping)

sdk/python/feast/infra/local.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def online_read(
8080

8181
def materialize_single_feature_view(
8282
self,
83+
config: RepoConfig,
8384
feature_view: FeatureView,
8485
start_date: datetime,
8586
end_date: datetime,
@@ -98,15 +99,17 @@ def materialize_single_feature_view(
9899
created_timestamp_column,
99100
) = _get_column_names(feature_view, entities)
100101

101-
table = self.offline_store.pull_latest_from_table_or_query(
102+
offline_job = self.offline_store.pull_latest_from_table_or_query(
102103
data_source=feature_view.input,
103104
join_key_columns=join_key_columns,
104105
feature_name_columns=feature_name_columns,
105106
event_timestamp_column=event_timestamp_column,
106107
created_timestamp_column=created_timestamp_column,
107108
start_date=start_date,
108109
end_date=end_date,
110+
config=config,
109111
)
112+
table = offline_job.to_table()
110113

111114
if feature_view.input.field_mapping is not None:
112115
table = _run_field_mapping(table, feature_view.input.field_mapping)

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
from feast.data_source import BigQuerySource, DataSource
1616
from feast.errors import FeastProviderLoginError
1717
from feast.feature_view import FeatureView
18-
from feast.infra.offline_stores.offline_store import OfflineStore
18+
from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
1919
from feast.infra.provider import (
2020
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
21-
RetrievalJob,
2221
_get_requested_feature_views_to_features_dict,
2322
)
2423
from feast.registry import Registry
@@ -52,14 +51,15 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
5251
class BigQueryOfflineStore(OfflineStore):
5352
@staticmethod
5453
def pull_latest_from_table_or_query(
54+
config: RepoConfig,
5555
data_source: DataSource,
5656
join_key_columns: List[str],
5757
feature_name_columns: List[str],
5858
event_timestamp_column: str,
5959
created_timestamp_column: Optional[str],
6060
start_date: datetime,
6161
end_date: datetime,
62-
) -> pyarrow.Table:
62+
) -> RetrievalJob:
6363
assert isinstance(data_source, BigQuerySource)
6464
from_expression = data_source.get_table_query_string()
6565

@@ -74,6 +74,7 @@ def pull_latest_from_table_or_query(
7474
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
7575
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
7676

77+
client = _get_bigquery_client(project=config.offline_store.project_id)
7778
query = f"""
7879
SELECT {field_string}
7980
FROM (
@@ -84,14 +85,7 @@ def pull_latest_from_table_or_query(
8485
)
8586
WHERE _feast_row = 1
8687
"""
87-
88-
return BigQueryOfflineStore._pull_query(query)
89-
90-
@staticmethod
91-
def _pull_query(query: str) -> pyarrow.Table:
92-
client = _get_bigquery_client()
93-
query_job = client.query(query)
94-
return query_job.to_arrow()
88+
return BigQueryRetrievalJob(query=query, client=client, config=config)
9589

9690
@staticmethod
9791
def get_historical_features(
@@ -104,19 +98,18 @@ def get_historical_features(
10498
full_feature_names: bool = False,
10599
) -> RetrievalJob:
106100
# TODO: Add entity_df validation in order to fail before interacting with BigQuery
101+
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
107102

108-
client = _get_bigquery_client()
109-
103+
client = _get_bigquery_client(project=config.offline_store.project_id)
110104
expected_join_keys = _get_join_keys(project, feature_views, registry)
111105

112106
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
113-
dataset_project = config.offline_store.project_id or client.project
114107

115108
table = _upload_entity_df_into_bigquery(
116109
client=client,
117110
project=config.project,
118111
dataset_name=config.offline_store.dataset,
119-
dataset_project=dataset_project,
112+
dataset_project=client.project,
120113
entity_df=entity_df,
121114
)
122115

@@ -265,10 +258,7 @@ def _block_until_done():
265258
if not job_config:
266259
today = date.today().strftime("%Y%m%d")
267260
rand_id = str(uuid.uuid4())[:7]
268-
dataset_project = (
269-
self.config.offline_store.project_id or self.client.project
270-
)
271-
path = f"{dataset_project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
261+
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
272262
job_config = bigquery.QueryJobConfig(destination=path)
273263

274264
bq_job = self.client.query(self.query, job_config=job_config)
@@ -287,6 +277,9 @@ def _block_until_done():
287277
print(f"Done writing to '{job_config.destination}'.")
288278
return str(job_config.destination)
289279

280+
def to_table(self) -> pyarrow.Table:
281+
return self.client.query(self.query).to_arrow()
282+
290283

291284
@dataclass(frozen=True)
292285
class FeatureViewQueryContext:
@@ -451,9 +444,9 @@ def build_point_in_time_query(
451444
return query
452445

453446

454-
def _get_bigquery_client():
447+
def _get_bigquery_client(project: Optional[str] = None):
455448
try:
456-
client = bigquery.Client()
449+
client = bigquery.Client(project=project)
457450
except DefaultCredentialsError as e:
458451
raise FeastProviderLoginError(
459452
str(e)

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def to_df(self):
3838
df = self.evaluation_function()
3939
return df
4040

41+
def to_table(self):
42+
# Only execute the evaluation function to build the final historical retrieval dataframe at the last moment.
43+
df = self.evaluation_function()
44+
return pyarrow.Table.from_pandas(df)
45+
4146

4247
class FileOfflineStore(OfflineStore):
4348
@staticmethod
@@ -49,7 +54,7 @@ def get_historical_features(
4954
registry: Registry,
5055
project: str,
5156
full_feature_names: bool = False,
52-
) -> FileRetrievalJob:
57+
) -> RetrievalJob:
5358
if not isinstance(entity_df, pd.DataFrame):
5459
raise ValueError(
5560
f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}"
@@ -207,49 +212,56 @@ def evaluate_historical_retrieval():
207212

208213
@staticmethod
209214
def pull_latest_from_table_or_query(
215+
config: RepoConfig,
210216
data_source: DataSource,
211217
join_key_columns: List[str],
212218
feature_name_columns: List[str],
213219
event_timestamp_column: str,
214220
created_timestamp_column: Optional[str],
215221
start_date: datetime,
216222
end_date: datetime,
217-
) -> pyarrow.Table:
223+
) -> RetrievalJob:
218224
assert isinstance(data_source, FileSource)
219225

220-
source_df = pd.read_parquet(data_source.path)
221-
# Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
222-
source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
223-
lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
224-
)
225-
if created_timestamp_column:
226-
source_df[created_timestamp_column] = source_df[
227-
created_timestamp_column
228-
].apply(lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc))
229-
230-
source_columns = set(source_df.columns)
231-
if not set(join_key_columns).issubset(source_columns):
232-
raise FeastJoinKeysDuringMaterialization(
233-
data_source.path, set(join_key_columns), source_columns
226+
# Create lazy function that is only called from the RetrievalJob object
227+
def evaluate_offline_job():
228+
source_df = pd.read_parquet(data_source.path)
229+
# Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
230+
source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
231+
lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
234232
)
233+
if created_timestamp_column:
234+
source_df[created_timestamp_column] = source_df[
235+
created_timestamp_column
236+
].apply(
237+
lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
238+
)
235239

236-
ts_columns = (
237-
[event_timestamp_column, created_timestamp_column]
238-
if created_timestamp_column
239-
else [event_timestamp_column]
240-
)
240+
source_columns = set(source_df.columns)
241+
if not set(join_key_columns).issubset(source_columns):
242+
raise FeastJoinKeysDuringMaterialization(
243+
data_source.path, set(join_key_columns), source_columns
244+
)
241245

242-
source_df.sort_values(by=ts_columns, inplace=True)
246+
ts_columns = (
247+
[event_timestamp_column, created_timestamp_column]
248+
if created_timestamp_column
249+
else [event_timestamp_column]
250+
)
243251

244-
filtered_df = source_df[
245-
(source_df[event_timestamp_column] >= start_date)
246-
& (source_df[event_timestamp_column] < end_date)
247-
]
248-
last_values_df = filtered_df.drop_duplicates(
249-
join_key_columns, keep="last", ignore_index=True
250-
)
252+
source_df.sort_values(by=ts_columns, inplace=True)
251253

252-
columns_to_extract = set(join_key_columns + feature_name_columns + ts_columns)
253-
table = pyarrow.Table.from_pandas(last_values_df[columns_to_extract])
254+
filtered_df = source_df[
255+
(source_df[event_timestamp_column] >= start_date)
256+
& (source_df[event_timestamp_column] < end_date)
257+
]
258+
last_values_df = filtered_df.drop_duplicates(
259+
join_key_columns, keep="last", ignore_index=True
260+
)
261+
262+
columns_to_extract = set(
263+
join_key_columns + feature_name_columns + ts_columns
264+
)
265+
return last_values_df[columns_to_extract]
254266

255-
return table
267+
return FileRetrievalJob(evaluation_function=evaluate_offline_job)

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ class RetrievalJob(ABC):
2828
"""RetrievalJob is used to manage the execution of a historical feature retrieval"""
2929

3030
@abstractmethod
31-
def to_df(self):
31+
def to_df(self) -> pd.DataFrame:
3232
"""Return dataset as Pandas DataFrame synchronously"""
3333
pass
3434

35+
@abstractmethod
36+
def to_table(self) -> pyarrow.Table:
37+
"""Return dataset as pyarrow Table synchronously"""
38+
pass
39+
3540

3641
class OfflineStore(ABC):
3742
"""
@@ -42,14 +47,15 @@ class OfflineStore(ABC):
4247
@staticmethod
4348
@abstractmethod
4449
def pull_latest_from_table_or_query(
50+
config: RepoConfig,
4551
data_source: DataSource,
4652
join_key_columns: List[str],
4753
feature_name_columns: List[str],
4854
event_timestamp_column: str,
4955
created_timestamp_column: Optional[str],
5056
start_date: datetime,
5157
end_date: datetime,
52-
) -> pyarrow.Table:
58+
) -> RetrievalJob:
5359
"""
5460
Note that join_key_columns, feature_name_columns, event_timestamp_column, and created_timestamp_column
5561
have all already been mapped to column names of the source table and those column names are the values passed

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import List, Optional, Union
33

44
import pandas as pd
5-
import pyarrow
65
from pydantic import StrictStr
76
from pydantic.typing import Literal
87

@@ -38,14 +37,15 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
3837
class RedshiftOfflineStore(OfflineStore):
3938
@staticmethod
4039
def pull_latest_from_table_or_query(
40+
config: RepoConfig,
4141
data_source: DataSource,
4242
join_key_columns: List[str],
4343
feature_name_columns: List[str],
4444
event_timestamp_column: str,
4545
created_timestamp_column: Optional[str],
4646
start_date: datetime,
4747
end_date: datetime,
48-
) -> pyarrow.Table:
48+
) -> RetrievalJob:
4949
pass
5050

5151
@staticmethod

sdk/python/feast/infra/provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def online_write_batch(
9797
@abc.abstractmethod
9898
def materialize_single_feature_view(
9999
self,
100+
config: RepoConfig,
100101
feature_view: FeatureView,
101102
start_date: datetime,
102103
end_date: datetime,

sdk/python/tests/foo_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def online_write_batch(
4545

4646
def materialize_single_feature_view(
4747
self,
48+
config: RepoConfig,
4849
feature_view: FeatureView,
4950
start_date: datetime,
5051
end_date: datetime,

0 commit comments

Comments (0)