Skip to content

Commit 658b18f

Browse files
authored
chore: Make feature server test a unit test (#4694)
* make feature server tests unit tests, remove integration Signed-off-by: Rob Howley <howley.robert@gmail.com> * use local example repo for test Signed-off-by: Rob Howley <howley.robert@gmail.com> * use example repo 1, switch to drive loc Signed-off-by: Rob Howley <howley.robert@gmail.com> * missing timestamp in payload Signed-off-by: Rob Howley <howley.robert@gmail.com> * combine tests Signed-off-by: Rob Howley <howley.robert@gmail.com> * simplify the request bodies Signed-off-by: Rob Howley <howley.robert@gmail.com> * use diff feature view Signed-off-by: Rob Howley <howley.robert@gmail.com> * fix feature name Signed-off-by: Rob Howley <howley.robert@gmail.com> * approx equal Signed-off-by: Rob Howley <howley.robert@gmail.com> * approx equal Signed-off-by: Rob Howley <howley.robert@gmail.com> * fix response format assertion Signed-off-by: Rob Howley <howley.robert@gmail.com> * dont need lazy fixture, simplify Signed-off-by: Rob Howley <howley.robert@gmail.com> --------- Signed-off-by: Rob Howley <howley.robert@gmail.com>
1 parent 1443da4 commit 658b18f

File tree

2 files changed

+124
-173
lines changed

2 files changed

+124
-173
lines changed

sdk/python/tests/integration/online_store/test_python_feature_server.py

Lines changed: 0 additions & 150 deletions
This file was deleted.
Lines changed: 124 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,75 @@
11
import json
2-
from unittest.mock import AsyncMock, MagicMock, patch
2+
from unittest.mock import AsyncMock, MagicMock
33

44
import pytest
55
from fastapi.testclient import TestClient
66

7-
from feast import FeatureStore
87
from feast.data_source import PushMode
8+
from feast.errors import PushSourceNotFoundException
99
from feast.feature_server import get_app
10+
from feast.online_response import OnlineResponse
11+
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
1012
from feast.utils import _utc_now
1113
from tests.foo_provider import FooProvider
14+
from tests.utils.cli_repo_creator import CliRunner, get_example_repo
15+
16+
17+
@pytest.fixture
18+
def mock_fs_factory():
19+
def builder(**async_support):
20+
provider = FooProvider.with_async_support(**async_support)
21+
fs = MagicMock()
22+
fs._get_provider.return_value = provider
23+
empty_response = OnlineResponse(GetOnlineFeaturesResponse(results=[]))
24+
fs.get_online_features = MagicMock(return_value=empty_response)
25+
fs.push = MagicMock()
26+
fs.get_online_features_async = AsyncMock(return_value=empty_response)
27+
fs.push_async = AsyncMock()
28+
return fs
29+
30+
return builder
31+
32+
33+
@pytest.fixture
34+
def test_client():
35+
runner = CliRunner()
36+
with runner.local_repo(
37+
get_example_repo("example_feature_repo_1.py"), "file"
38+
) as store:
39+
yield TestClient(get_app(store))
40+
41+
42+
def get_online_features_body():
43+
return {
44+
"features": [
45+
"pushed_driver_locations:driver_lat",
46+
"pushed_driver_locations:driver_long",
47+
],
48+
"entities": {"driver_id": [123]},
49+
}
50+
51+
52+
def push_body(push_mode=PushMode.ONLINE, lat=42.0):
53+
return {
54+
"push_source_name": "driver_locations_push",
55+
"df": {
56+
"driver_lat": [lat],
57+
"driver_long": ["42.0"],
58+
"driver_id": [123],
59+
"event_timestamp": [str(_utc_now())],
60+
"created_timestamp": [str(_utc_now())],
61+
},
62+
"to": push_mode.name.lower(),
63+
}
64+
65+
66+
@pytest.mark.parametrize("async_online_read", [True, False])
67+
def test_get_online_features_async_supported(async_online_read, mock_fs_factory):
68+
fs = mock_fs_factory(online_read=async_online_read)
69+
client = TestClient(get_app(fs))
70+
client.post("/get-online-features", json=get_online_features_body())
71+
assert fs.get_online_features.call_count == int(not async_online_read)
72+
assert fs.get_online_features_async.await_count == int(async_online_read)
1273

1374

1475
@pytest.mark.parametrize(
@@ -22,26 +83,66 @@
2283
(False, PushMode.ONLINE, 0),
2384
],
2485
)
25-
def test_push_online_async_supported(online_write, push_mode, async_count, environment):
26-
push_payload = json.dumps(
27-
{
28-
"push_source_name": "location_stats_push_source",
29-
"df": {
30-
"location_id": [1],
31-
"temperature": [100],
32-
"event_timestamp": [str(_utc_now())],
33-
"created": [str(_utc_now())],
34-
},
35-
"to": push_mode.name.lower(),
36-
}
86+
def test_push_online_async_supported(
87+
online_write, push_mode, async_count, mock_fs_factory
88+
):
89+
fs = mock_fs_factory(online_write=online_write)
90+
client = TestClient(get_app(fs))
91+
client.post("/push", json=push_body(push_mode))
92+
assert fs.push.call_count == 1 - async_count
93+
assert fs.push_async.await_count == async_count
94+
95+
96+
async def test_push_and_get(test_client):
97+
driver_lat = 55.1
98+
push_payload = push_body(lat=driver_lat)
99+
response = test_client.post("/push", json=push_payload)
100+
assert response.status_code == 200
101+
102+
# Check new pushed temperature is fetched
103+
request_payload = get_online_features_body()
104+
actual_resp = test_client.post("/get-online-features", json=request_payload)
105+
actual = json.loads(actual_resp.text)
106+
107+
ix = actual["metadata"]["feature_names"].index("driver_lat")
108+
assert actual["results"][ix]["values"][0] == pytest.approx(driver_lat, 0.0001)
109+
110+
assert_get_online_features_response_format(
111+
actual, request_payload["entities"]["driver_id"][0]
37112
)
38113

39-
provider = FooProvider.with_async_support(online_write=online_write)
40-
with patch.object(FeatureStore, "_get_provider", return_value=provider):
41-
fs = environment.feature_store
42-
fs.push = MagicMock()
43-
fs.push_async = AsyncMock()
44-
client = TestClient(get_app(fs))
45-
client.post("/push", data=push_payload)
46-
assert fs.push.call_count == 1 - async_count
47-
assert fs.push_async.await_count == async_count
114+
115+
def assert_get_online_features_response_format(parsed_response, expected_entity_id):
116+
assert "metadata" in parsed_response
117+
metadata = parsed_response["metadata"]
118+
expected_features = ["driver_id", "driver_lat", "driver_long"]
119+
response_feature_names = metadata["feature_names"]
120+
assert len(response_feature_names) == len(expected_features)
121+
for expected_feature in expected_features:
122+
assert expected_feature in response_feature_names
123+
assert "results" in parsed_response
124+
results = parsed_response["results"]
125+
for result in results:
126+
# Same order as in metadata
127+
assert len(result["statuses"]) == 1 # Requested one entity
128+
for status in result["statuses"]:
129+
assert status == "PRESENT"
130+
results_driver_id_index = response_feature_names.index("driver_id")
131+
assert results[results_driver_id_index]["values"][0] == expected_entity_id
132+
133+
134+
def test_push_source_does_not_exist(test_client):
135+
with pytest.raises(
136+
PushSourceNotFoundException,
137+
match="Unable to find push source 'push_source_does_not_exist'",
138+
):
139+
test_client.post(
140+
"/push",
141+
json={
142+
"push_source_name": "push_source_does_not_exist",
143+
"df": {
144+
"any_data": [1],
145+
"event_timestamp": [str(_utc_now())],
146+
},
147+
},
148+
)

0 commit comments

Comments
 (0)