Skip to content

Commit ed0cdf4

Browse files
committed
add integration test
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
1 parent 25af94e commit ed0cdf4

File tree

6 files changed

+212
-7
lines changed

6 files changed

+212
-7
lines changed

sdk/python/feast/batch_feature_view.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class BatchFeatureView(FeatureView):
6464
udf: Optional[Callable[[Any], Any]]
6565
udf_string: Optional[str]
6666
feature_transformation: Transformation
67+
batch_engine: Optional[Field]
6768

6869
def __init__(
6970
self,
@@ -82,6 +83,7 @@ def __init__(
8283
udf: Optional[Callable[[Any], Any]],
8384
udf_string: Optional[str] = "",
8485
feature_transformation: Optional[Transformation] = None,
86+
batch_engine: Optional[Field] = None,
8587
):
8688
if not flags_helper.is_test():
8789
warnings.warn(
@@ -105,6 +107,7 @@ def __init__(
105107
self.feature_transformation = (
106108
feature_transformation or self.get_feature_transformation()
107109
)
110+
self.batch_engine = batch_engine
108111

109112
super().__init__(
110113
name=name,
@@ -147,18 +150,21 @@ def batch_feature_view(
147150
source: Optional[DataSource] = None,
148151
tags: Optional[Dict[str, str]] = None,
149152
online: bool = True,
153+
offline: bool = True,
150154
description: str = "",
151155
owner: str = "",
152156
schema: Optional[List[Field]] = None,
153157
):
154158
"""
155159
Args:
156160
name:
161+
mode:
157162
entities:
158163
ttl:
159164
source:
160165
tags:
161166
online:
167+
offline:
162168
description:
163169
owner:
164170
schema:
@@ -184,6 +190,7 @@ def decorator(user_function):
184190
source=source,
185191
tags=tags,
186192
online=online,
193+
offline=offline,
187194
description=description,
188195
owner=owner,
189196
schema=schema,

sdk/python/feast/infra/compute_engines/dag/builder.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,14 @@ def build(self) -> ExecutionPlan:
6969
return ExecutionPlan(self.nodes)
7070

7171
def _should_join(self):
72-
return (
73-
self.feature_view.compute_config.join_strategy == "engine"
74-
or self.task.config.compute_engine.get("point_in_time_join") == "engine"
75-
)
72+
if hasattr(self.feature_view, "batch_engine"):
73+
return hasattr(self.feature_view.batch_engine, "join_strategy") and (
74+
self.feature_view.batch_engine.join_strategy == "engine"
75+
or self.task.config.batch_engine.get("point_in_time_join") == "engine"
76+
)
77+
if hasattr(self.feature_view, "batch_engine_config"):
78+
return hasattr(self.feature_view.stream_engine, "join_strategy") and (
79+
self.feature_view.stream_engine.join_strategy == "engine"
80+
or self.task.config.stream_engine.get("point_in_time_join") == "engine"
81+
)
82+
return False

sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@ def __init__(
2727
self.spark_session = spark_session
2828

2929
def build_source_node(self):
    """Create the read node for the current task, register it, and return it.

    Materialization tasks read through ``SparkMaterializationReadNode``;
    any other task (historical retrieval) reads through
    ``SparkHistoricalRetrievalReadNode``, which also needs the Spark session.
    """
    is_materialization = isinstance(self.task, MaterializationTask)
    source_node = (
        SparkMaterializationReadNode("source", self.task)
        if is_materialization
        else SparkHistoricalRetrievalReadNode(
            "source", self.task, self.spark_session
        )
    )
    self.nodes.append(source_node)
    return source_node

sdk/python/feast/stream_feature_view.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class StreamFeatureView(FeatureView):
8383
udf: Optional[FunctionType]
8484
udf_string: Optional[str]
8585
feature_transformation: Optional[Transformation]
86+
stream_engine: Optional[Field]
8687

8788
def __init__(
8889
self,
@@ -103,6 +104,7 @@ def __init__(
103104
udf: Optional[FunctionType] = None,
104105
udf_string: Optional[str] = "",
105106
feature_transformation: Optional[Transformation] = None,
107+
stream_engine: Optional[Field] = None,
106108
):
107109
if not flags_helper.is_test():
108110
warnings.warn(
@@ -133,6 +135,7 @@ def __init__(
133135
self.feature_transformation = (
134136
feature_transformation or self.get_feature_transformation()
135137
)
138+
self.stream_engine = stream_engine
136139

137140
super().__init__(
138141
name=name,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from datetime import timedelta
2+
3+
from pyspark.sql import DataFrame
4+
5+
from feast import BatchFeatureView, Entity, Field, FileSource
6+
from feast.types import Float32, Int32, Int64
7+
8+
# File-based offline source for driver stats; the path is a placeholder that
# the integration test substitutes with a concrete parquet file path.
driver_hourly_stats = FileSource(
    path="%PARQUET_PATH%",  # placeholder to be replaced by the test
    timestamp_field="event_timestamp",
    created_timestamp_column="created",
)

# Entity keyed by driver_id, used by the driver_hourly_stats view below.
driver = Entity(
    name="driver_id",
    description="driver id",
)
18+
19+
20+
def transform_feature(df: DataFrame) -> DataFrame:
    """Return *df* with the conv_rate and acc_rate columns doubled."""
    for column in ("conv_rate", "acc_rate"):
        df = df.withColumn(column, df[column] * 2)
    return df
24+
25+
26+
# Batch feature view over the driver stats source. The udf (transform_feature)
# doubles conv_rate and acc_rate. online=True/offline=True presumably flag the
# view for both online and offline stores — confirm against BatchFeatureView.
driver_hourly_stats_view = BatchFeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    mode="python",
    udf=transform_feature,
    udf_string="transform_feature",
    ttl=timedelta(days=1),
    schema=[
        Field(name="conv_rate", dtype=Float32),
        Field(name="acc_rate", dtype=Float32),
        Field(name="avg_daily_trips", dtype=Int64),
        Field(name="driver_id", dtype=Int32),
    ],
    online=True,
    offline=True,
    source=driver_hourly_stats,
    tags={},
)
44+
45+
46+
# Second offline source for entity-less "global" stats; the path placeholder
# is substituted by the integration test.
global_daily_stats = FileSource(
    path="%PARQUET_PATH_GLOBAL%",  # placeholder to be replaced by the test
    timestamp_field="event_timestamp",
    created_timestamp_column="created",
)


# Entity-less batch feature view (entities=None) whose udf is the identity —
# rows pass through unchanged.
global_stats_feature_view = BatchFeatureView(
    name="global_daily_stats",
    entities=None,
    mode="python",
    udf=lambda x: x,
    ttl=timedelta(days=1),
    schema=[
        Field(name="num_rides", dtype=Int32),
        Field(name="avg_ride_length", dtype=Float32),
    ],
    online=True,
    offline=True,
    source=global_daily_stats,
    tags={},
)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from datetime import datetime, timedelta
2+
from typing import cast
3+
from unittest.mock import MagicMock
4+
5+
import pandas as pd
6+
import pytest
7+
8+
from feast.infra.compute_engines.base import HistoricalRetrievalTask
9+
from feast.infra.compute_engines.spark.compute import SparkComputeEngine
10+
from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob
11+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
12+
SparkOfflineStore,
13+
)
14+
from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
15+
SparkDataSourceCreator,
16+
)
17+
from tests.example_repos.example_feature_repo_with_bfvs_compute import (
18+
global_stats_feature_view,
19+
)
20+
from tests.integration.feature_repos.integration_test_repo_config import (
21+
IntegrationTestRepoConfig,
22+
)
23+
from tests.integration.feature_repos.repo_configuration import (
24+
construct_test_environment,
25+
)
26+
from tests.integration.feature_repos.universal.online_store.redis import (
27+
RedisOnlineStoreCreator,
28+
)
29+
30+
31+
@pytest.mark.integration
def test_spark_compute_engine_get_historical_features():
    """End-to-end check of SparkComputeEngine.get_historical_features:
    build a parquet-backed source, run a point-in-time join against an
    entity DataFrame, and verify the latest feature row per driver wins.
    """
    now = datetime.utcnow()

    # Spark offline store + Redis online store, driven by the Spark engine.
    spark_config = IntegrationTestRepoConfig(
        provider="local",
        online_store_creator=RedisOnlineStoreCreator,
        offline_store_creator=SparkDataSourceCreator,
        batch_engine={"type": "spark.engine", "partitions": 10},
    )
    spark_environment = construct_test_environment(
        spark_config, None, entity_key_serialization_version=2
    )

    spark_environment.setup()
    # TODO(review): wrap the remainder in try/finally and tear the environment
    # down so Spark/Redis resources are not leaked on assertion failure.

    # Prepare the test parquet feature rows: two drivers, driver 1001 with two
    # event times so the point-in-time join must pick the most recent row.
    # NOTE(review): these columns (conv_rate/acc_rate/avg_daily_trips) do not
    # match global_stats_feature_view's declared schema (num_rides,
    # avg_ride_length); the registry is mocked so this passes — confirm it is
    # intentional.
    df = pd.DataFrame(
        [
            {
                "driver_id": 1001,
                "event_timestamp": now - timedelta(days=1),
                "created": now - timedelta(hours=2),
                "conv_rate": 0.8,
                "acc_rate": 0.95,
                "avg_daily_trips": 15,
            },
            {
                "driver_id": 1001,
                "event_timestamp": now - timedelta(days=2),
                "created": now - timedelta(hours=3),
                "conv_rate": 0.75,
                "acc_rate": 0.9,
                "avg_daily_trips": 14,
            },
            {
                "driver_id": 1002,
                "event_timestamp": now - timedelta(days=1),
                "created": now - timedelta(hours=2),
                "conv_rate": 0.7,
                "acc_rate": 0.88,
                "avg_daily_trips": 12,
            },
        ]
    )

    ds = spark_environment.data_source_creator.create_data_source(
        df,
        spark_environment.feature_store.project,
        field_mapping={"ts_1": "ts"},
    )
    # Re-point the module-level feature view at the freshly created source.
    global_stats_feature_view.source = ds

    # Entity DataFrame to join with.
    entity_df = pd.DataFrame(
        [
            {"driver_id": 1001, "event_timestamp": now},
            {"driver_id": 1002, "event_timestamp": now},
        ]
    )

    # Build the retrieval task; registry is mocked because only the view
    # object passed in here is needed.
    task = HistoricalRetrievalTask(
        entity_df=entity_df,
        feature_view=global_stats_feature_view,
        full_feature_name=False,
        registry=MagicMock(),
        config=spark_environment.config,
        start_time=now - timedelta(days=1),
        end_time=now,
    )

    # Run SparkComputeEngine.
    engine = SparkComputeEngine(
        repo_config=task.config,
        offline_store=SparkOfflineStore(),
        online_store=MagicMock(),
        registry=MagicMock(),
    )

    spark_dag_retrieval_job = engine.get_historical_features(task)
    spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df()
    # BUGFIX: pyspark.sql.DataFrame exposes toPandas(), not to_pandas()
    # (to_pandas exists only on pandas-on-Spark frames), so the original
    # call would raise AttributeError.
    df_out = spark_df.toPandas().sort_values("driver_id").reset_index(drop=True)

    # Assert: one row per driver, carrying the most recent conv_rate.
    assert list(df_out.driver_id) == [1001, 1002]
    assert abs(df_out.loc[0]["conv_rate"] - 0.8) < 1e-6
    assert abs(df_out.loc[1]["conv_rate"] - 0.7) < 1e-6


if __name__ == "__main__":
    test_spark_compute_engine_get_historical_features()

0 commit comments

Comments
 (0)