Skip to content

Commit ce5a130

Browse files
Agent007 and achals authored
Respect specified ValueTypes for features during materialization (feast-dev#1906)
* assert float feature is still float from online store
  Signed-off-by: Jeff <jeffxl@apple.com>
* ensure float features retain float type from online store
  Floats were converted to doubles when materialized to the online store. There is a broader bug trend around type conversions and this particular conversion utility function looks like it could use some cleanup. This commit is a quick fix.
  Signed-off-by: Jeff <jeffxl@apple.com>
* make fix more general
  Signed-off-by: Achal Shah <achals@gmail.com>
* Use assertAlmostEquals
  Signed-off-by: Achal Shah <achals@gmail.com>
* format
  Signed-off-by: Achal Shah <achals@gmail.com>
* Support pandas timestamps correctly
  Signed-off-by: Achal Shah <achals@gmail.com>
* Support pandas timestamps correctly
  Signed-off-by: Achal Shah <achals@gmail.com>
* Correct import
  Signed-off-by: Achal Shah <achals@gmail.com>
Co-authored-by: Achal Shah <achals@gmail.com>
1 parent 6faf3a2 commit ce5a130

3 files changed

Lines changed: 31 additions & 14 deletions

File tree

sdk/python/feast/type_map.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,13 @@ def _type_err(item, dtype):
245245
ValueType, Tuple[str, Any, Optional[Set[Type]]]
246246
] = {
247247
ValueType.INT32: ("int32_val", lambda x: int(x), None),
248-
ValueType.INT64: ("int64_val", lambda x: int(x), None),
248+
ValueType.INT64: (
249+
"int64_val",
250+
lambda x: int(x.timestamp())
251+
if isinstance(x, pd._libs.tslibs.timestamps.Timestamp)
252+
else int(x),
253+
None,
254+
),
249255
ValueType.FLOAT: ("float_val", lambda x: float(x), None),
250256
ValueType.DOUBLE: ("double_val", lambda x: x, {float, np.float64}),
251257
ValueType.STRING: ("string_val", lambda x: str(x), None),
@@ -317,7 +323,7 @@ def python_value_to_proto_value(
317323
value: Any, feature_type: ValueType = ValueType.UNKNOWN
318324
) -> ProtoValue:
319325
value_type = feature_type
320-
if value is not None:
326+
if value is not None and feature_type == ValueType.UNKNOWN:
321327
if isinstance(value, (list, np.ndarray)):
322328
value_type = (
323329
feature_type

sdk/python/tests/integration/online_store/test_e2e_local.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def _assert_online_features(
2727
):
2828
"""Assert that features in online store are up to date with `max_date` date."""
2929
# Read features back
30-
result = store.get_online_features(
30+
response = store.get_online_features(
3131
features=[
3232
"driver_hourly_stats:conv_rate",
3333
"driver_hourly_stats:avg_daily_trips",
@@ -36,8 +36,14 @@ def _assert_online_features(
3636
],
3737
entity_rows=[{"driver_id": 1001}],
3838
full_feature_names=True,
39-
).to_dict()
39+
)
40+
41+
# Float features should still be floats from the online store...
42+
assert (
43+
response.field_values[0].fields["driver_hourly_stats__conv_rate"].float_val > 0
44+
)
4045

46+
result = response.to_dict()
4147
assert len(result) == 5
4248
assert "driver_hourly_stats__avg_daily_trips" in result
4349
assert "driver_hourly_stats__conv_rate" in result

sdk/python/tests/integration/online_store/test_universal_online.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -110,24 +110,27 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name
110110

111111
assert df_features["customer_id"] == online_features_dict["customer_id"][i]
112112
assert df_features["driver_id"] == online_features_dict["driver_id"][i]
113-
assert (
113+
tc.assertAlmostEqual(
114114
online_features_dict[
115115
response_feature_name("conv_rate_plus_100", full_feature_names)
116-
][i]
117-
== df_features["conv_rate"] + 100
116+
][i],
117+
df_features["conv_rate"] + 100,
118+
delta=0.0001,
118119
)
119-
assert (
120+
tc.assertAlmostEqual(
120121
online_features_dict[
121122
response_feature_name("conv_rate_plus_val_to_add", full_feature_names)
122-
][i]
123-
== df_features["conv_rate"] + df_features["val_to_add"]
123+
][i],
124+
df_features["conv_rate"] + df_features["val_to_add"],
125+
delta=0.0001,
124126
)
125127
for unprefixed_feature_ref in unprefixed_feature_refs:
126-
tc.assertEqual(
128+
tc.assertAlmostEqual(
127129
df_features[unprefixed_feature_ref],
128130
online_features_dict[
129131
response_feature_name(unprefixed_feature_ref, full_feature_names)
130132
][i],
133+
delta=0.0001,
131134
)
132135

133136
# Check what happens for missing values
@@ -254,13 +257,15 @@ def assert_feature_service_correctness(
254257
+ 3
255258
) # Add two for the driver id and the customer id entity keys and val_to_add request data
256259

260+
tc = unittest.TestCase()
257261
for i, entity_row in enumerate(entity_rows):
258262
df_features = get_latest_feature_values_from_dataframes(
259263
drivers_df, customers_df, orders_df, global_df, entity_row
260264
)
261-
assert (
265+
tc.assertAlmostEqual(
262266
feature_service_online_features_dict[
263267
response_feature_name("conv_rate_plus_100", full_feature_names)
264-
][i]
265-
== df_features["conv_rate"] + 100
268+
][i],
269+
df_features["conv_rate"] + 100,
270+
delta=0.0001,
266271
)

0 commit comments

Comments
 (0)