Skip to content

Commit f01b691

Browse files
merging
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 3c94614 commit f01b691

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

sdk/python/tests/unit/test_on_demand_python_transformation.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from feast.field import Field
2222
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
2323
from feast.on_demand_feature_view import on_demand_feature_view
24+
<<<<<<< HEAD
2425
from feast.types import (
2526
Array,
2627
Bool,
@@ -33,6 +34,9 @@
3334
_utc_now,
3435
from_value_type,
3536
)
37+
=======
38+
from feast.types import Array, Bool, Float32, Float64, Int64, String, UnixTimestamp
39+
>>>>>>> b95d2a21 (updated test case)
3640

3741

3842
class TestOnDemandPythonTransformation(unittest.TestCase):
@@ -167,6 +171,29 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]:
167171
)
168172
return output
169173

174+
@on_demand_feature_view(
175+
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
176+
schema=[
177+
Field(name="conv_rate_plus_acc", dtype=Float64),
178+
Field(name="current_datetime", dtype=UnixTimestamp),
179+
],
180+
mode="python",
181+
write_to_online_store=True,
182+
)
183+
def python_stored_writes_feature_view(
184+
inputs: dict[str, Any],
185+
) -> dict[str, Any]:
186+
output: dict[str, Any] = {
187+
"conv_rate_plus_acc": [
188+
conv_rate + acc_rate
189+
for conv_rate, acc_rate in zip(
190+
inputs["conv_rate"], inputs["acc_rate"]
191+
)
192+
],
193+
"current_datetime": [datetime.now() for _ in inputs["conv_rate"]],
194+
}
195+
return output
196+
170197
with pytest.raises(TypeError):
171198
# Note the singleton view will fail as the type is
172199
# expected to be a list which can be confirmed in _infer_features_dict
@@ -191,6 +218,7 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]:
191218
python_view,
192219
python_demo_view,
193220
driver_stats_entity_less_fv,
221+
python_stored_writes_feature_view,
194222
]
195223
)
196224
self.store.write_to_online_store(
@@ -203,7 +231,7 @@ def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]:
203231

204232
assert len(self.store.list_all_feature_views()) == 5
205233
assert len(self.store.list_feature_views()) == 2
206-
assert len(self.store.list_on_demand_feature_views()) == 3
234+
assert len(self.store.list_on_demand_feature_views()) == 4
207235
assert len(self.store.list_stream_feature_views()) == 0
208236

209237
def test_python_pandas_parity(self):
@@ -291,6 +319,31 @@ def test_python_docs_demo(self):
291319
== online_python_response["conv_rate_plus_val2_python"][0]
292320
)
293321

322+
def test_stored_writes(self):
323+
entity_rows = [
324+
{
325+
"driver_id": 1001,
326+
}
327+
]
328+
329+
online_python_response = self.store.get_online_features(
330+
entity_rows=entity_rows,
331+
features=[
332+
"python_stored_writes_feature_view:conv_rate_plus_acc",
333+
"python_stored_writes_feature_view:current_datetime",
334+
],
335+
).to_dict()
336+
337+
assert sorted(list(online_python_response.keys())) == sorted(
338+
[
339+
"driver_id",
340+
"conv_rate_plus_acc",
341+
"current_datetime",
342+
]
343+
)
344+
print(online_python_response)
345+
# Now this is where we need to test the stored writes, this should return the same output as the previous
346+
294347

295348
class TestOnDemandPythonTransformationAllDataTypes(unittest.TestCase):
296349
def setUp(self):

0 commit comments

Comments
 (0)