Skip to content

Commit 9ed0a09

Browse files
authored
feat: Make arrow primary interchange for offline ODFV execution (#4083)
1 parent a05cdbc commit 9ed0a09

File tree

6 files changed

+104
-46
lines changed

6 files changed

+104
-46
lines changed

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -76,36 +76,11 @@ def to_df(
7676
validation_reference (optional): The validation to apply against the retrieved dataframe.
7777
timeout (optional): The query timeout if applicable.
7878
"""
79-
features_df = self._to_df_internal(timeout=timeout)
80-
81-
if self.on_demand_feature_views:
82-
# TODO(adchia): Fix requirement to specify dependent feature views in feature_refs
83-
for odfv in self.on_demand_feature_views:
84-
if odfv.mode not in {"pandas", "substrait"}:
85-
raise Exception(
86-
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
87-
)
88-
features_df = features_df.join(
89-
odfv.get_transformed_features_df(
90-
features_df,
91-
self.full_feature_names,
92-
)
93-
)
94-
95-
if validation_reference:
96-
if not flags_helper.is_test():
97-
warnings.warn(
98-
"Dataset validation is an experimental feature. "
99-
"This API is unstable and it could and most probably will be changed in the future. "
100-
"We do not guarantee that future changes will maintain backward compatibility.",
101-
RuntimeWarning,
102-
)
103-
104-
validation_result = validation_reference.profile.validate(features_df)
105-
if not validation_result.is_success:
106-
raise ValidationFailed(validation_result)
107-
108-
return features_df
79+
return (
80+
self.to_arrow(validation_reference=validation_reference, timeout=timeout)
81+
.to_pandas()
82+
.reset_index(drop=True)
83+
)
10984

11085
def to_arrow(
11186
self,
@@ -122,23 +97,20 @@ def to_arrow(
12297
validation_reference (optional): The validation to apply against the retrieved dataframe.
12398
timeout (optional): The query timeout if applicable.
12499
"""
125-
if not self.on_demand_feature_views and not validation_reference:
126-
return self._to_arrow_internal(timeout=timeout)
127-
128-
features_df = self._to_df_internal(timeout=timeout)
100+
features_table = self._to_arrow_internal(timeout=timeout)
129101
if self.on_demand_feature_views:
130102
for odfv in self.on_demand_feature_views:
131-
if odfv.mode not in {"pandas", "substrait"}:
132-
raise Exception(
133-
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
134-
)
135-
features_df = features_df.join(
136-
odfv.get_transformed_features_df(
137-
features_df,
138-
self.full_feature_names,
139-
)
103+
transformed_arrow = odfv.transform_arrow(
104+
features_table, self.full_feature_names
140105
)
141106

107+
for col in transformed_arrow.column_names:
108+
if col.startswith("__index"):
109+
continue
110+
features_table = features_table.append_column(
111+
col, transformed_arrow[col]
112+
)
113+
142114
if validation_reference:
143115
if not flags_helper.is_test():
144116
warnings.warn(
@@ -148,11 +120,13 @@ def to_arrow(
148120
RuntimeWarning,
149121
)
150122

151-
validation_result = validation_reference.profile.validate(features_df)
123+
validation_result = validation_reference.profile.validate(
124+
features_table.to_pandas()
125+
)
152126
if not validation_result.is_success:
153127
raise ValidationFailed(validation_result)
154128

155-
return pyarrow.Table.from_pandas(features_df)
129+
return features_table
156130

157131
def to_sql(self) -> str:
158132
"""

sdk/python/feast/on_demand_feature_view.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import dill
1010
import pandas as pd
11+
import pyarrow
1112
from typeguard import typechecked
1213

1314
from feast.base_feature_view import BaseFeatureView
@@ -391,6 +392,60 @@ def get_request_data_schema(self) -> Dict[str, ValueType]:
391392
def _get_projected_feature_name(self, feature: str) -> str:
392393
return f"{self.projection.name_to_use()}__{feature}"
393394

395+
def transform_arrow(
    self,
    pa_table: pyarrow.Table,
    full_feature_names: bool = False,
) -> pyarrow.Table:
    """Apply this view's feature transformation to an Arrow table.

    Before the transformation runs, every source feature found in
    ``pa_table`` is made available under both its short name
    (``<feature>``) and its prefixed name (``<projection>__<feature>``).
    The aliases added for that purpose are dropped again from the result.

    Args:
        pa_table: Input table holding the source feature columns.
        full_feature_names: When true, output features found under their
            short name are renamed to the projected (prefixed) name; when
            false, prefixed output names are mapped back to short names.

    Returns:
        The table produced by ``self.feature_transformation.transform_arrow``
        with helper columns removed and output columns renamed.

    Raises:
        TypeError: If ``pa_table`` is not a ``pyarrow.Table``.
    """
    if not isinstance(pa_table, pyarrow.Table):
        raise TypeError("transform_arrow only accepts pyarrow.Table")

    # Alias columns added below so the transformation can use either naming
    # scheme; they are removed again after the transformation has run.
    columns_to_cleanup = []
    for source_fv_projection in self.source_feature_view_projections.values():
        for feature in source_fv_projection.features:
            full_feature_ref = f"{source_fv_projection.name}__{feature.name}"
            if full_feature_ref in pa_table.column_names:
                # Make sure the partial feature name is always present
                pa_table = pa_table.append_column(
                    feature.name, pa_table[full_feature_ref]
                )
                columns_to_cleanup.append(feature.name)
            elif feature.name in pa_table.column_names:
                # Make sure the full feature name is always present
                pa_table = pa_table.append_column(
                    full_feature_ref, pa_table[feature.name]
                )
                columns_to_cleanup.append(full_feature_ref)

    df_with_transformed_features: pyarrow.Table = (
        self.feature_transformation.transform_arrow(pa_table)
    )

    # Work out whether the correct columns names are used.
    rename_columns: Dict[str, str] = {}
    for feature in self.features:
        short_name = feature.name
        long_name = self._get_projected_feature_name(feature.name)
        if (
            short_name in df_with_transformed_features.column_names
            and full_feature_names
        ):
            rename_columns[short_name] = long_name
        elif not full_feature_names:
            rename_columns[long_name] = short_name

    # Cleanup extra columns used for transformation. pyarrow tables are
    # immutable, so ``drop`` returns a new table; it is given a list because
    # not every pyarrow release accepts a bare column-name string.
    present_columns = set(df_with_transformed_features.column_names)
    removable = [
        col for col in dict.fromkeys(columns_to_cleanup) if col in present_columns
    ]
    if removable:
        df_with_transformed_features = df_with_transformed_features.drop(removable)

    return df_with_transformed_features.rename_columns(
        [
            rename_columns.get(c, c)
            for c in df_with_transformed_features.column_names
        ]
    )
448+
394449
def get_transformed_features_df(
395450
self,
396451
df_with_features: pd.DataFrame,

sdk/python/feast/transformation/pandas_transformation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import dill
55
import pandas as pd
6+
import pyarrow
67

78
from feast.field import Field, from_value_type
89
from feast.protos.feast.core.Transformation_pb2 import (
@@ -26,6 +27,19 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
2627
self.udf = udf
2728
self.udf_string = udf_string
2829

30+
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Run the pandas UDF on an Arrow table and return an Arrow table.

    The input is converted to a pandas DataFrame, passed to ``self.udf``,
    and the UDF's result is converted back with
    ``pyarrow.Table.from_pandas``.

    Args:
        pa_table: Input table handed to the UDF (as a DataFrame).

    Returns:
        The UDF output as a ``pyarrow.Table``.

    Raises:
        TypeError: If ``pa_table`` is not a ``pyarrow.Table`` or the UDF
            does not return a ``pd.DataFrame``.
    """
    if not isinstance(pa_table, pyarrow.Table):
        raise TypeError(
            f"pa_table should be type pyarrow.Table but got {type(pa_table).__name__}"
        )
    output_df = self.udf(pa_table.to_pandas())
    # Validate the UDF result *before* converting: once it has gone through
    # ``from_pandas`` the value is always a Table, so a check placed after
    # the conversion could never report a bad UDF return type.
    if not isinstance(output_df, pd.DataFrame):
        raise TypeError(
            f"output_df should be type pd.DataFrame but got {type(output_df).__name__}"
        )
    return pyarrow.Table.from_pandas(output_df)
42+
2943
def transform(self, input_df: pd.DataFrame) -> pd.DataFrame:
3044
if not isinstance(input_df, pd.DataFrame):
3145
raise TypeError(

sdk/python/feast/transformation/python_transformation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, List
33

44
import dill
5+
import pyarrow
56

67
from feast.field import Field, from_value_type
78
from feast.protos.feast.core.Transformation_pb2 import (
@@ -24,6 +25,11 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
2425
self.udf = udf
2526
self.udf_string = udf_string
2627

28+
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Reject Arrow-based execution for this transformation.

    Args:
        pa_table: Ignored; present only to match the signature.

    Raises:
        Exception: Always, with a message stating that "python" mode is not
            supported for offline processing.
    """
    message = 'OnDemandFeatureView mode "python" not supported for offline processing.'
    raise Exception(message)
32+
2733
def transform(self, input_dict: Dict) -> Dict:
2834
if not isinstance(input_dict, Dict):
2935
raise TypeError(

sdk/python/feast/transformation/substrait_transformation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ def table_provider(names, schema: pyarrow.Schema):
3434
).read_all()
3535
return table.to_pandas()
3636

37+
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Execute ``self.substrait_plan`` against ``pa_table``.

    Args:
        pa_table: Table served to the query through the table provider.

    Returns:
        The fully-read query result as a ``pyarrow.Table``.
    """

    def _provide_table(names, schema: pyarrow.Schema):
        # Project the input down to the columns listed in the schema that
        # is passed to the provider; ``names`` is not used.
        return pa_table.select(schema.names)

    result_reader = pyarrow.substrait.run_query(
        self.substrait_plan, table_provider=_provide_table
    )
    return result_reader.read_all()
45+
3746
def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
3847
df = pd.DataFrame.from_dict(random_input)
3948
output_df: pd.DataFrame = self.transform(df)

sdk/python/tests/unit/infra/offline_stores/test_offline_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_to_sql():
216216

217217
@pytest.mark.parametrize("timeout", (None, 30))
218218
def test_to_df_timeout(retrieval_job, timeout: Optional[int]):
219-
with patch.object(retrieval_job, "_to_df_internal") as mock_to_df_internal:
219+
with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_df_internal:
220220
retrieval_job.to_df(timeout=timeout)
221221
mock_to_df_internal.assert_called_once_with(timeout=timeout)
222222

0 commit comments

Comments
 (0)