Skip to content

Commit 9594997

Browse files
author
Alex Mirrington
committed
Fix OnDemandFeatureView type inference for array types
Signed-off-by: Alex Mirrington <alex.mirrington@rokt.com>
1 parent de5b0eb commit 9594997

File tree

5 files changed

+217
-5
lines changed

5 files changed

+217
-5
lines changed

sdk/python/feast/transformation/pandas_transformation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def infer_features(self, random_input: dict[str, list[Any]]) -> list[Field]:
4444
Field(
4545
name=f,
4646
dtype=from_value_type(
47-
python_type_to_feast_value_type(f, type_name=str(dt))
47+
python_type_to_feast_value_type(
48+
f, value=output_df[f].tolist()[0], type_name=str(dt)
49+
)
4850
),
4951
)
5052
for f, dt in zip(output_df.columns, output_df.dtypes)

sdk/python/feast/transformation/python_transformation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def infer_features(self, random_input: dict[str, list[Any]]) -> list[Field]:
4444
Field(
4545
name=f,
4646
dtype=from_value_type(
47-
python_type_to_feast_value_type(f, type_name=type(dt[0]).__name__)
47+
python_type_to_feast_value_type(
48+
f, value=dt[0], type_name=type(dt[0]).__name__
49+
)
4850
),
4951
)
5052
for f, dt in output_dict.items()

sdk/python/feast/transformation/substrait_transformation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def infer_features(self, random_input: dict[str, list[Any]]) -> list[Field]:
6464
Field(
6565
name=f,
6666
dtype=from_value_type(
67-
python_type_to_feast_value_type(f, type_name=str(dt))
67+
python_type_to_feast_value_type(
68+
f, value=output_df[f].tolist()[0], type_name=str(dt)
69+
)
6870
),
6971
)
7072
for f, dt in zip(output_df.columns, output_df.dtypes)

sdk/python/feast/type_map.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def python_type_to_feast_value_type(
155155
"uint16": ValueType.INT32,
156156
"uint8": ValueType.INT32,
157157
"int8": ValueType.INT32,
158+
"bool_": ValueType.BOOL,
158159
"bool": ValueType.BOOL,
159160
"boolean": ValueType.BOOL,
160161
"timedelta": ValueType.UNIX_TIMESTAMP,

sdk/python/tests/unit/test_on_demand_pandas_transformation.py

Lines changed: 207 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,26 @@
44

55
import pandas as pd
66

7-
from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
7+
from feast import (
8+
Entity,
9+
FeatureStore,
10+
FeatureView,
11+
FileSource,
12+
RepoConfig,
13+
RequestSource,
14+
)
815
from feast.driver_test_data import create_driver_hourly_stats_df
916
from feast.field import Field
1017
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
1118
from feast.on_demand_feature_view import on_demand_feature_view
12-
from feast.types import Float32, Float64, Int64
19+
from feast.types import (
20+
Array,
21+
Bool,
22+
Float32,
23+
Float64,
24+
Int64,
25+
String,
26+
)
1327

1428

1529
def test_pandas_transformation():
@@ -91,3 +105,194 @@ def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
91105
assert online_response["conv_rate_plus_acc"].equals(
92106
online_response["conv_rate"] + online_response["acc_rate"]
93107
)
108+
109+
110+
def test_pandas_transformation_returning_all_data_types():
111+
with tempfile.TemporaryDirectory() as data_dir:
112+
store = FeatureStore(
113+
config=RepoConfig(
114+
project="test_on_demand_python_transformation",
115+
registry=os.path.join(data_dir, "registry.db"),
116+
provider="local",
117+
entity_key_serialization_version=2,
118+
online_store=SqliteOnlineStoreConfig(
119+
path=os.path.join(data_dir, "online.db")
120+
),
121+
)
122+
)
123+
124+
# Generate test data.
125+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
126+
start_date = end_date - timedelta(days=15)
127+
128+
driver_entities = [1001, 1002, 1003, 1004, 1005]
129+
driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
130+
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
131+
driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)
132+
133+
driver = Entity(name="driver", join_keys=["driver_id"])
134+
135+
driver_stats_source = FileSource(
136+
name="driver_hourly_stats_source",
137+
path=driver_stats_path,
138+
timestamp_field="event_timestamp",
139+
created_timestamp_column="created",
140+
)
141+
142+
driver_stats_fv = FeatureView(
143+
name="driver_hourly_stats",
144+
entities=[driver],
145+
ttl=timedelta(days=0),
146+
schema=[
147+
Field(name="conv_rate", dtype=Float32),
148+
Field(name="acc_rate", dtype=Float32),
149+
Field(name="avg_daily_trips", dtype=Int64),
150+
],
151+
online=True,
152+
source=driver_stats_source,
153+
)
154+
155+
request_source = RequestSource(
156+
name="request_source",
157+
schema=[
158+
Field(name="avg_daily_trip_rank_thresholds", dtype=Array(Int64)),
159+
Field(name="avg_daily_trip_rank_names", dtype=Array(String)),
160+
],
161+
)
162+
163+
@on_demand_feature_view(
164+
sources=[request_source, driver_stats_fv],
165+
schema=[
166+
Field(name="highest_achieved_rank", dtype=String),
167+
Field(name="avg_daily_trips_plus_one", dtype=Int64),
168+
Field(name="conv_rate_plus_acc", dtype=Float64),
169+
Field(name="is_highest_rank", dtype=Bool),
170+
Field(name="achieved_ranks", dtype=Array(String)),
171+
Field(name="trips_until_next_rank_int", dtype=Array(Int64)),
172+
Field(name="trips_until_next_rank_float", dtype=Array(Float64)),
173+
Field(name="achieved_ranks_mask", dtype=Array(Bool)),
174+
],
175+
mode="pandas",
176+
)
177+
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
178+
df = pd.DataFrame()
179+
df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
180+
df["avg_daily_trips_plus_one"] = inputs["avg_daily_trips"] + 1
181+
182+
df["trips_until_next_rank_int"] = inputs[
183+
["avg_daily_trips", "avg_daily_trip_rank_thresholds"]
184+
].apply(
185+
lambda x: [max(threshold - x.iloc[0], 0) for threshold in x.iloc[1]],
186+
axis=1,
187+
)
188+
df["trips_until_next_rank_float"] = df["trips_until_next_rank_int"].map(
189+
lambda values: [float(value) for value in values]
190+
)
191+
df["achieved_ranks_mask"] = df["trips_until_next_rank_int"].map(
192+
lambda values: [value <= 0 for value in values]
193+
)
194+
195+
temp = pd.concat(
196+
[df[["achieved_ranks_mask"]], inputs[["avg_daily_trip_rank_names"]]],
197+
axis=1,
198+
)
199+
df["achieved_ranks"] = temp.apply(
200+
lambda x: [
201+
rank if achieved else "Locked"
202+
for achieved, rank in zip(x.iloc[0], x.iloc[1])
203+
],
204+
axis=1,
205+
)
206+
df["highest_achieved_rank"] = (
207+
df["achieved_ranks"]
208+
.map(
209+
lambda ranks: str([rank for rank in ranks if rank != "Locked"][-1])
210+
)
211+
.astype("string")
212+
)
213+
df["is_highest_rank"] = df["achieved_ranks"].map(
214+
lambda ranks: ranks[-1] != "Locked"
215+
)
216+
return df
217+
218+
store.apply([driver, driver_stats_source, driver_stats_fv, pandas_view])
219+
220+
entity_rows = [
221+
{
222+
"driver_id": 1001,
223+
"avg_daily_trip_rank_thresholds": [100, 250, 500, 1000],
224+
"avg_daily_trip_rank_names": ["Bronze", "Silver", "Gold", "Platinum"],
225+
}
226+
]
227+
store.write_to_online_store(
228+
feature_view_name="driver_hourly_stats", df=driver_df
229+
)
230+
231+
online_response = store.get_online_features(
232+
entity_rows=entity_rows,
233+
features=[
234+
"driver_hourly_stats:conv_rate",
235+
"driver_hourly_stats:acc_rate",
236+
"driver_hourly_stats:avg_daily_trips",
237+
"pandas_view:avg_daily_trips_plus_one",
238+
"pandas_view:conv_rate_plus_acc",
239+
"pandas_view:trips_until_next_rank_int",
240+
"pandas_view:trips_until_next_rank_float",
241+
"pandas_view:achieved_ranks_mask",
242+
"pandas_view:achieved_ranks",
243+
"pandas_view:highest_achieved_rank",
244+
"pandas_view:is_highest_rank",
245+
],
246+
).to_df()
247+
# We use to_df here to ensure we use the pandas backend, but convert to a dict for comparisons
248+
result = online_response.to_dict(orient="records")[0]
249+
print(result)
250+
251+
# Type assertions
252+
# Materialized view
253+
assert type(result["conv_rate"]) == float
254+
assert type(result["acc_rate"]) == float
255+
assert type(result["avg_daily_trips"]) == int
256+
# On-demand view
257+
assert type(result["avg_daily_trips_plus_one"]) == int
258+
assert type(result["conv_rate_plus_acc"]) == float
259+
assert type(result["highest_achieved_rank"]) == str
260+
assert type(result["is_highest_rank"]) == bool
261+
262+
assert type(result["trips_until_next_rank_int"]) == list
263+
assert all([type(e) == int for e in result["trips_until_next_rank_int"]])
264+
265+
assert type(result["trips_until_next_rank_float"]) == list
266+
assert all([type(e) == float for e in result["trips_until_next_rank_float"]])
267+
268+
assert type(result["achieved_ranks"]) == list
269+
assert all([type(e) == str for e in result["achieved_ranks"]])
270+
271+
assert type(result["achieved_ranks_mask"]) == list
272+
assert all([type(e) == bool for e in result["achieved_ranks_mask"]])
273+
274+
# Value assertions
275+
expected_trips_until_next_rank = [
276+
max(threshold - result["avg_daily_trips"], 0)
277+
for threshold in entity_rows[0]["avg_daily_trip_rank_thresholds"]
278+
]
279+
expected_mask = [value <= 0 for value in expected_trips_until_next_rank]
280+
expected_ranks = [
281+
rank if achieved else "Locked"
282+
for achieved, rank in zip(
283+
expected_mask, entity_rows[0]["avg_daily_trip_rank_names"]
284+
)
285+
]
286+
highest_rank = [rank for rank in expected_ranks if rank != "Locked"][-1]
287+
288+
assert result["conv_rate_plus_acc"] == result["conv_rate"] + result["acc_rate"]
289+
assert result["avg_daily_trips_plus_one"] == result["avg_daily_trips"] + 1
290+
assert result["highest_achieved_rank"] == highest_rank
291+
assert result["is_highest_rank"] == (expected_ranks[-1] != "Locked")
292+
293+
assert result["trips_until_next_rank_int"] == expected_trips_until_next_rank
294+
assert result["trips_until_next_rank_float"] == [
295+
float(value) for value in expected_trips_until_next_rank
296+
]
297+
assert result["achieved_ranks_mask"] == expected_mask
298+
assert result["achieved_ranks"] == expected_ranks

0 commit comments

Comments
 (0)