Skip to content

Commit c3a102f

Browse files
authored
feat: Incorporate substrait ODFVs into ibis-based offline store queries (feast-dev#4102)
1 parent f2b4eb9 commit c3a102f

File tree

10 files changed

+132
-22
lines changed

10 files changed

+132
-22
lines changed

protos/feast/core/Transformation.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ message FeatureTransformationV2 {
2929

3030
message SubstraitTransformationV2 {
3131
bytes substrait_plan = 1;
32+
bytes ibis_function = 2;
3233
}

sdk/python/feast/infra/offline_stores/ibis.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,15 @@ def read_fv(
193193
event_timestamp_col=event_timestamp_col,
194194
)
195195

196+
odfvs = OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry)
197+
198+
substrait_odfvs = [fv for fv in odfvs if fv.mode == "substrait"]
199+
for odfv in substrait_odfvs:
200+
res = odfv.transform_ibis(res, full_feature_names)
201+
196202
return IbisRetrievalJob(
197203
res,
198-
OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry),
204+
[fv for fv in odfvs if fv.mode != "substrait"],
199205
full_feature_names,
200206
metadata=RetrievalMetadata(
201207
features=feature_refs,

sdk/python/feast/on_demand_feature_view.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,53 @@ def get_request_data_schema(self) -> Dict[str, ValueType]:
392392
def _get_projected_feature_name(self, feature: str) -> str:
393393
return f"{self.projection.name_to_use()}__{feature}"
394394

395+
def transform_ibis(
396+
self,
397+
ibis_table,
398+
full_feature_names: bool = False,
399+
):
400+
from ibis.expr.types import Table
401+
402+
if not isinstance(ibis_table, Table):
403+
raise TypeError("transform_ibis only accepts ibis.expr.types.Table")
404+
405+
assert type(self.feature_transformation) == SubstraitTransformation
406+
407+
columns_to_cleanup = []
408+
for source_fv_projection in self.source_feature_view_projections.values():
409+
for feature in source_fv_projection.features:
410+
full_feature_ref = f"{source_fv_projection.name}__{feature.name}"
411+
if full_feature_ref in ibis_table.columns:
412+
# Make sure the partial feature name is always present
413+
ibis_table = ibis_table.mutate(
414+
**{feature.name: ibis_table[full_feature_ref]}
415+
)
416+
columns_to_cleanup.append(feature.name)
417+
elif feature.name in ibis_table.columns:
418+
ibis_table = ibis_table.mutate(
419+
**{full_feature_ref: ibis_table[feature.name]}
420+
)
421+
columns_to_cleanup.append(full_feature_ref)
422+
423+
transformed_table = self.feature_transformation.transform_ibis(ibis_table)
424+
425+
transformed_table = transformed_table.drop(*columns_to_cleanup)
426+
427+
rename_columns: Dict[str, str] = {}
428+
for feature in self.features:
429+
short_name = feature.name
430+
long_name = self._get_projected_feature_name(feature.name)
431+
if short_name in transformed_table.columns and full_feature_names:
432+
rename_columns[short_name] = long_name
433+
elif not full_feature_names:
434+
rename_columns[long_name] = short_name
435+
436+
for rename_from, rename_to in rename_columns.items():
437+
if rename_from in transformed_table.columns:
438+
transformed_table = transformed_table.rename(**{rename_to: rename_from})
439+
440+
return transformed_table
441+
395442
def transform_arrow(
396443
self,
397444
pa_table: pyarrow.Table,
@@ -419,7 +466,7 @@ def transform_arrow(
419466
columns_to_cleanup.append(full_feature_ref)
420467

421468
df_with_transformed_features: pyarrow.Table = (
422-
self.feature_transformation.transform_arrow(pa_table)
469+
self.feature_transformation.transform_arrow(pa_table, self.features)
423470
)
424471

425472
# Work out whether the correct columns names are used.
@@ -438,7 +485,7 @@ def transform_arrow(
438485
# Cleanup extra columns used for transformation
439486
for col in columns_to_cleanup:
440487
if col in df_with_transformed_features.column_names:
441-
df_with_transformed_features = df_with_transformed_features.dtop(col)
488+
df_with_transformed_features = df_with_transformed_features.drop(col)
442489
return df_with_transformed_features.rename_columns(
443490
[
444491
rename_columns.get(c, c)
@@ -487,7 +534,9 @@ def get_transformed_features_df(
487534
rename_columns[long_name] = short_name
488535

489536
# Cleanup extra columns used for transformation
490-
df_with_features.drop(columns=columns_to_cleanup, inplace=True)
537+
df_with_transformed_features = df_with_transformed_features[
538+
[f.name for f in self.features]
539+
]
491540
return df_with_transformed_features.rename(columns=rename_columns)
492541

493542
def get_transformed_features_dict(

sdk/python/feast/transformation/pandas_transformation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
2727
self.udf = udf
2828
self.udf_string = udf_string
2929

30-
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
30+
def transform_arrow(
31+
self, pa_table: pyarrow.Table, features: List[Field]
32+
) -> pyarrow.Table:
3133
if not isinstance(pa_table, pyarrow.Table):
3234
raise TypeError(
3335
f"pa_table should be type pyarrow.Table but got {type(pa_table).__name__}"

sdk/python/feast/transformation/python_transformation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
2525
self.udf = udf
2626
self.udf_string = udf_string
2727

28-
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
28+
def transform_arrow(
29+
self, pa_table: pyarrow.Table, features: List[Field]
30+
) -> pyarrow.Table:
2931
raise Exception(
3032
'OnDemandFeatureView mode "python" not supported for offline processing.'
3133
)

sdk/python/feast/transformation/substrait_transformation.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from types import FunctionType
12
from typing import Any, Dict, List
23

4+
import dill
35
import pandas as pd
46
import pyarrow
57
import pyarrow.substrait as substrait # type: ignore # noqa
@@ -16,14 +18,16 @@
1618

1719

1820
class SubstraitTransformation:
19-
def __init__(self, substrait_plan: bytes):
21+
def __init__(self, substrait_plan: bytes, ibis_function: FunctionType):
2022
"""
2123
Creates an SubstraitTransformation object.
2224
2325
Args:
2426
substrait_plan: The user-provided substrait plan.
27+
ibis_function: The user-provided ibis function.
2528
"""
2629
self.substrait_plan = substrait_plan
30+
self.ibis_function = ibis_function
2731

2832
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
2933
def table_provider(names, schema: pyarrow.Schema):
@@ -34,13 +38,22 @@ def table_provider(names, schema: pyarrow.Schema):
3438
).read_all()
3539
return table.to_pandas()
3640

37-
def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
41+
def transform_ibis(self, table):
42+
return self.ibis_function(table)
43+
44+
def transform_arrow(
45+
self, pa_table: pyarrow.Table, features: List[Field] = []
46+
) -> pyarrow.Table:
3847
def table_provider(names, schema: pyarrow.Schema):
3948
return pa_table.select(schema.names)
4049

4150
table: pyarrow.Table = pyarrow.substrait.run_query(
4251
self.substrait_plan, table_provider=table_provider
4352
).read_all()
53+
54+
if features:
55+
table = table.select([f.name for f in features])
56+
4457
return table
4558

4659
def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
@@ -55,6 +68,7 @@ def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
5568
),
5669
)
5770
for f, dt in zip(output_df.columns, output_df.dtypes)
71+
if f not in random_input
5872
]
5973

6074
def __eq__(self, other):
@@ -66,18 +80,26 @@ def __eq__(self, other):
6680
if not super().__eq__(other):
6781
return False
6882

69-
return self.substrait_plan == other.substrait_plan
83+
return (
84+
self.substrait_plan == other.substrait_plan
85+
and self.ibis_function.__code__.co_code
86+
== other.ibis_function.__code__.co_code
87+
)
7088

7189
def to_proto(self) -> SubstraitTransformationProto:
72-
return SubstraitTransformationProto(substrait_plan=self.substrait_plan)
90+
return SubstraitTransformationProto(
91+
substrait_plan=self.substrait_plan,
92+
ibis_function=dill.dumps(self.ibis_function, recurse=True),
93+
)
7394

7495
@classmethod
7596
def from_proto(
7697
cls,
7798
substrait_transformation_proto: SubstraitTransformationProto,
7899
):
79100
return SubstraitTransformation(
80-
substrait_plan=substrait_transformation_proto.substrait_plan
101+
substrait_plan=substrait_transformation_proto.substrait_plan,
102+
ibis_function=dill.loads(substrait_transformation_proto.ibis_function),
81103
)
82104

83105
@classmethod
@@ -91,7 +113,7 @@ def from_ibis(cls, user_function, sources):
91113
input_fields = []
92114

93115
for s in sources:
94-
fields = s.projection.features if isinstance(s, FeatureView) else s.features
116+
fields = s.projection.features if isinstance(s, FeatureView) else s.schema
95117

96118
input_fields.extend(
97119
[
@@ -108,5 +130,6 @@ def from_ibis(cls, user_function, sources):
108130
expr = user_function(ibis.table(input_fields, "t"))
109131

110132
return SubstraitTransformation(
111-
substrait_plan=compiler.compile(expr).SerializeToString()
133+
substrait_plan=compiler.compile(expr).SerializeToString(),
134+
ibis_function=user_function,
112135
)

sdk/python/tests/integration/feature_repos/repo_configuration.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,17 +356,23 @@ def values(self):
356356
def construct_universal_feature_views(
357357
data_sources: UniversalDataSources,
358358
with_odfv: bool = True,
359+
use_substrait_odfv: bool = False,
359360
) -> UniversalFeatureViews:
360361
driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver)
361362
driver_hourly_stats_base_feature_view = (
362363
create_driver_hourly_stats_batch_feature_view(data_sources.driver)
363364
)
365+
364366
return UniversalFeatureViews(
365367
customer=create_customer_daily_profile_feature_view(data_sources.customer),
366368
global_fv=create_global_stats_feature_view(data_sources.global_ds),
367369
driver=driver_hourly_stats,
368370
driver_odfv=conv_rate_plus_100_feature_view(
369-
[driver_hourly_stats_base_feature_view, create_conv_rate_request_source()]
371+
[
372+
driver_hourly_stats_base_feature_view[["conv_rate"]],
373+
create_conv_rate_request_source(),
374+
],
375+
use_substrait_odfv=use_substrait_odfv,
370376
)
371377
if with_odfv
372378
else None,

sdk/python/tests/integration/feature_repos/universal/feature_views.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pandas as pd
6+
from ibis.expr.types.relations import Table
67

78
from feast import (
89
BatchFeatureView,
@@ -15,7 +16,7 @@
1516
)
1617
from feast.data_source import DataSource, RequestSource
1718
from feast.feature_view_projection import FeatureViewProjection
18-
from feast.on_demand_feature_view import PandasTransformation
19+
from feast.on_demand_feature_view import PandasTransformation, SubstraitTransformation
1920
from feast.types import Array, FeastType, Float32, Float64, Int32, Int64
2021
from tests.integration.feature_repos.universal.entities import (
2122
customer,
@@ -56,10 +57,22 @@ def conv_rate_plus_100(features_df: pd.DataFrame) -> pd.DataFrame:
5657
return df
5758

5859

60+
def conv_rate_plus_100_ibis(features_table: Table) -> Table:
61+
return features_table.mutate(
62+
conv_rate_plus_100=features_table["conv_rate"] + 100,
63+
conv_rate_plus_val_to_add=features_table["conv_rate"]
64+
+ features_table["val_to_add"],
65+
conv_rate_plus_100_rounded=(features_table["conv_rate"] + 100)
66+
.round(digits=0)
67+
.cast("int32"),
68+
)
69+
70+
5971
def conv_rate_plus_100_feature_view(
6072
sources: List[Union[FeatureView, RequestSource, FeatureViewProjection]],
6173
infer_features: bool = False,
6274
features: Optional[List[Field]] = None,
75+
use_substrait_odfv: bool = False,
6376
) -> OnDemandFeatureView:
6477
# Test that positional arguments and Features still work for ODFVs.
6578
_features = features or [
@@ -73,7 +86,10 @@ def conv_rate_plus_100_feature_view(
7386
sources=sources,
7487
feature_transformation=PandasTransformation(
7588
udf=conv_rate_plus_100, udf_string="raw udf source"
76-
),
89+
)
90+
if not use_substrait_odfv
91+
else SubstraitTransformation.from_ibis(conv_rate_plus_100_ibis, sources),
92+
mode="pandas" if not use_substrait_odfv else "substrait",
7793
)
7894

7995

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,19 @@
4141
@pytest.mark.integration
4242
@pytest.mark.universal_offline_stores
4343
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: f"full:{v}")
44-
def test_historical_features(environment, universal_data_sources, full_feature_names):
44+
@pytest.mark.parametrize(
45+
"use_substrait_odfv", [True, False], ids=lambda v: f"substrait:{v}"
46+
)
47+
def test_historical_features(
48+
environment, universal_data_sources, full_feature_names, use_substrait_odfv
49+
):
4550
store = environment.feature_store
4651

4752
(entities, datasets, data_sources) = universal_data_sources
4853

49-
feature_views = construct_universal_feature_views(data_sources)
54+
feature_views = construct_universal_feature_views(
55+
data_sources, use_substrait_odfv=use_substrait_odfv
56+
)
5057

5158
entity_df_with_request_data = datasets.entity_df.copy(deep=True)
5259
entity_df_with_request_data["val_to_add"] = [

sdk/python/tests/unit/test_substrait_transformation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,8 @@ def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
7575
mode="substrait",
7676
)
7777
def substrait_view(inputs: Table) -> Table:
78-
return inputs.select(
79-
(inputs["conv_rate"] + inputs["acc_rate"]).name(
80-
"conv_rate_plus_acc_substrait"
81-
)
78+
return inputs.mutate(
79+
conv_rate_plus_acc_substrait=inputs["conv_rate"] + inputs["acc_rate"]
8280
)
8381

8482
store.apply(

0 commit comments

Comments
 (0)