Skip to content

Commit 54c6080

Browse files
chore: Adding more tests for On Demand Feature Views (feast-dev#4069)
* checking in progress...trying to fix tests Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * testing more... Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * fixed Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * fixed some tests Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * fixed test and serialization Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * removed commented out code Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * lint Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * added a test to make it explicit that feature calculation must happen on a list Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> * linter Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com> --------- Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com>
1 parent c06dda8 commit 54c6080

File tree

5 files changed

+303
-0
lines changed

5 files changed

+303
-0
lines changed

sdk/python/feast/on_demand_feature_view.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,23 @@ def from_proto(
300300
== "user_defined_function"
301301
and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
302302
!= ""
303+
and on_demand_feature_view_proto.spec.mode == "pandas"
303304
):
304305
transformation = PandasTransformation.from_proto(
305306
on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
306307
)
308+
elif (
309+
on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
310+
"transformation"
311+
)
312+
== "user_defined_function"
313+
and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
314+
!= ""
315+
and on_demand_feature_view_proto.spec.mode == "python"
316+
):
317+
transformation = PythonTransformation.from_proto(
318+
on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
319+
)
307320
elif (
308321
on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
309322
"transformation"

sdk/python/tests/example_repos/example_feature_repo_1.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from datetime import timedelta
22

3+
import pandas as pd
4+
35
from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource
6+
from feast.on_demand_feature_view import on_demand_feature_view
47
from feast.types import Float32, Int64, String
58

69
# Note that file source paths are not validated, so there doesn't actually need to be any data
@@ -99,6 +102,17 @@
99102
)
100103

101104

105+
@on_demand_feature_view(
106+
sources=[customer_profile],
107+
schema=[Field(name="on_demand_age", dtype=Int64)],
108+
mode="pandas",
109+
)
110+
def customer_profile_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame:
111+
outputs = pd.DataFrame()
112+
outputs["on_demand_age"] = inputs["age"] + 1
113+
return outputs
114+
115+
102116
all_drivers_feature_service = FeatureService(
103117
name="driver_locations_service",
104118
features=[driver_locations],

sdk/python/tests/unit/online_store/test_online_retrieval.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ def test_online() -> None:
124124

125125
assert "trips" in result
126126

127+
result = store.get_online_features(
128+
features=["customer_profile_pandas_odfv:on_demand_age"],
129+
entity_rows=[{"driver_id": 1, "customer_id": "5"}],
130+
full_feature_names=False,
131+
).to_dict()
132+
133+
assert "on_demand_age" in result
134+
assert result["driver_id"] == [1]
135+
assert result["customer_id"] == ["5"]
136+
assert result["on_demand_age"] == [4]
137+
127138
# invalid table reference
128139
with pytest.raises(FeatureViewNotFoundException):
129140
store.get_online_features(
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os
2+
import tempfile
3+
from datetime import datetime, timedelta
4+
5+
import pandas as pd
6+
7+
from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
8+
from feast.driver_test_data import create_driver_hourly_stats_df
9+
from feast.field import Field
10+
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
11+
from feast.on_demand_feature_view import on_demand_feature_view
12+
from feast.types import Float32, Float64, Int64
13+
14+
15+
def test_pandas_transformation():
16+
with tempfile.TemporaryDirectory() as data_dir:
17+
store = FeatureStore(
18+
config=RepoConfig(
19+
project="test_on_demand_python_transformation",
20+
registry=os.path.join(data_dir, "registry.db"),
21+
provider="local",
22+
entity_key_serialization_version=2,
23+
online_store=SqliteOnlineStoreConfig(
24+
path=os.path.join(data_dir, "online.db")
25+
),
26+
)
27+
)
28+
29+
# Generate test data.
30+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
31+
start_date = end_date - timedelta(days=15)
32+
33+
driver_entities = [1001, 1002, 1003, 1004, 1005]
34+
driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
35+
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
36+
driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)
37+
38+
driver = Entity(name="driver", join_keys=["driver_id"])
39+
40+
driver_stats_source = FileSource(
41+
name="driver_hourly_stats_source",
42+
path=driver_stats_path,
43+
timestamp_field="event_timestamp",
44+
created_timestamp_column="created",
45+
)
46+
47+
driver_stats_fv = FeatureView(
48+
name="driver_hourly_stats",
49+
entities=[driver],
50+
ttl=timedelta(days=0),
51+
schema=[
52+
Field(name="conv_rate", dtype=Float32),
53+
Field(name="acc_rate", dtype=Float32),
54+
Field(name="avg_daily_trips", dtype=Int64),
55+
],
56+
online=True,
57+
source=driver_stats_source,
58+
)
59+
60+
@on_demand_feature_view(
61+
sources=[driver_stats_fv],
62+
schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
63+
mode="pandas",
64+
)
65+
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
66+
df = pd.DataFrame()
67+
df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
68+
return df
69+
70+
store.apply([driver, driver_stats_source, driver_stats_fv, pandas_view])
71+
72+
entity_rows = [
73+
{
74+
"driver_id": 1001,
75+
}
76+
]
77+
store.write_to_online_store(
78+
feature_view_name="driver_hourly_stats", df=driver_df
79+
)
80+
81+
online_response = store.get_online_features(
82+
entity_rows=entity_rows,
83+
features=[
84+
"driver_hourly_stats:conv_rate",
85+
"driver_hourly_stats:acc_rate",
86+
"driver_hourly_stats:avg_daily_trips",
87+
"pandas_view:conv_rate_plus_acc",
88+
],
89+
).to_df()
90+
91+
assert online_response["conv_rate_plus_acc"].equals(
92+
online_response["conv_rate"] + online_response["acc_rate"]
93+
)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
from datetime import datetime, timedelta
5+
from typing import Any, Dict
6+
7+
import pandas as pd
8+
import pytest
9+
10+
from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
11+
from feast.driver_test_data import create_driver_hourly_stats_df
12+
from feast.field import Field
13+
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
14+
from feast.on_demand_feature_view import on_demand_feature_view
15+
from feast.types import Float32, Float64, Int64
16+
17+
18+
class TestOnDemandPythonTransformation(unittest.TestCase):
19+
def setUp(self):
20+
with tempfile.TemporaryDirectory() as data_dir:
21+
self.store = FeatureStore(
22+
config=RepoConfig(
23+
project="test_on_demand_python_transformation",
24+
registry=os.path.join(data_dir, "registry.db"),
25+
provider="local",
26+
entity_key_serialization_version=2,
27+
online_store=SqliteOnlineStoreConfig(
28+
path=os.path.join(data_dir, "online.db")
29+
),
30+
)
31+
)
32+
33+
# Generate test data.
34+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
35+
start_date = end_date - timedelta(days=15)
36+
37+
driver_entities = [1001, 1002, 1003, 1004, 1005]
38+
driver_df = create_driver_hourly_stats_df(
39+
driver_entities, start_date, end_date
40+
)
41+
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
42+
driver_df.to_parquet(
43+
path=driver_stats_path, allow_truncated_timestamps=True
44+
)
45+
46+
driver = Entity(name="driver", join_keys=["driver_id"])
47+
48+
driver_stats_source = FileSource(
49+
name="driver_hourly_stats_source",
50+
path=driver_stats_path,
51+
timestamp_field="event_timestamp",
52+
created_timestamp_column="created",
53+
)
54+
55+
driver_stats_fv = FeatureView(
56+
name="driver_hourly_stats",
57+
entities=[driver],
58+
ttl=timedelta(days=0),
59+
schema=[
60+
Field(name="conv_rate", dtype=Float32),
61+
Field(name="acc_rate", dtype=Float32),
62+
Field(name="avg_daily_trips", dtype=Int64),
63+
],
64+
online=True,
65+
source=driver_stats_source,
66+
)
67+
68+
@on_demand_feature_view(
69+
sources=[driver_stats_fv],
70+
schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)],
71+
mode="pandas",
72+
)
73+
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
74+
df = pd.DataFrame()
75+
df["conv_rate_plus_acc_pandas"] = (
76+
inputs["conv_rate"] + inputs["acc_rate"]
77+
)
78+
return df
79+
80+
@on_demand_feature_view(
81+
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
82+
schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
83+
mode="python",
84+
)
85+
def python_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
86+
output: Dict[str, Any] = {
87+
"conv_rate_plus_acc_python": [
88+
conv_rate + acc_rate
89+
for conv_rate, acc_rate in zip(
90+
inputs["conv_rate"], inputs["acc_rate"]
91+
)
92+
]
93+
}
94+
return output
95+
96+
@on_demand_feature_view(
97+
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
98+
schema=[
99+
Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64)
100+
],
101+
mode="python",
102+
)
103+
def python_singleton_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
104+
output: Dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf"))
105+
output["conv_rate_plus_acc_python_singleton"] = (
106+
inputs["conv_rate"] + inputs["acc_rate"]
107+
)
108+
return output
109+
110+
with pytest.raises(TypeError):
111+
# Note the singleton view will fail as the type is
112+
# expected to be a List which can be confirmed in _infer_features_dict
113+
self.store.apply(
114+
[
115+
driver,
116+
driver_stats_source,
117+
driver_stats_fv,
118+
pandas_view,
119+
python_view,
120+
python_singleton_view,
121+
]
122+
)
123+
124+
self.store.apply(
125+
[driver, driver_stats_source, driver_stats_fv, pandas_view, python_view]
126+
)
127+
self.store.write_to_online_store(
128+
feature_view_name="driver_hourly_stats", df=driver_df
129+
)
130+
131+
def test_python_pandas_parity(self):
132+
entity_rows = [
133+
{
134+
"driver_id": 1001,
135+
}
136+
]
137+
138+
online_python_response = self.store.get_online_features(
139+
entity_rows=entity_rows,
140+
features=[
141+
"driver_hourly_stats:conv_rate",
142+
"driver_hourly_stats:acc_rate",
143+
"python_view:conv_rate_plus_acc_python",
144+
],
145+
).to_dict()
146+
147+
online_pandas_response = self.store.get_online_features(
148+
entity_rows=entity_rows,
149+
features=[
150+
"driver_hourly_stats:conv_rate",
151+
"driver_hourly_stats:acc_rate",
152+
"pandas_view:conv_rate_plus_acc_pandas",
153+
],
154+
).to_df()
155+
156+
assert len(online_python_response) == 4
157+
assert all(
158+
key in online_python_response.keys()
159+
for key in [
160+
"driver_id",
161+
"acc_rate",
162+
"conv_rate",
163+
"conv_rate_plus_acc_python",
164+
]
165+
)
166+
assert len(online_python_response["conv_rate_plus_acc_python"]) == 1
167+
assert (
168+
online_python_response["conv_rate_plus_acc_python"][0]
169+
== online_pandas_response["conv_rate_plus_acc_pandas"][0]
170+
== online_python_response["conv_rate"][0]
171+
+ online_python_response["acc_rate"][0]
172+
)

0 commit comments

Comments
 (0)