-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtest_feature_server.py
More file actions
209 lines (175 loc) · 7.33 KB
/
test_feature_server.py
File metadata and controls
209 lines (175 loc) · 7.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from feast.data_source import PushMode
from feast.errors import PushSourceNotFoundException
from feast.feature_server import get_app
from feast.online_response import OnlineResponse
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
from feast.utils import _utc_now
from tests.foo_provider import FooProvider
from tests.utils.cli_repo_creator import CliRunner, get_example_repo
@pytest.fixture
def mock_fs_factory():
    """Return a factory that builds a ``MagicMock`` feature store.

    Keyword arguments given to the factory are forwarded to
    ``FooProvider.with_async_support`` so each test can choose which provider
    operations advertise async support. Both the sync and async read/push
    entry points of the returned mock are stubbed, letting tests count which
    path the feature server actually took.
    """

    def _make_store(**async_support):
        store = MagicMock()
        store._get_provider.return_value = FooProvider.with_async_support(
            **async_support
        )
        no_results = OnlineResponse(GetOnlineFeaturesResponse(results=[]))
        # Sync entry points (inspected via call_count).
        store.get_online_features = MagicMock(return_value=no_results)
        store.push = MagicMock()
        # Async entry points (inspected via await_count).
        store.get_online_features_async = AsyncMock(return_value=no_results)
        store.push_async = AsyncMock()
        return store

    return _make_store
@pytest.fixture
def test_client():
    """Yield a ``TestClient`` wrapping a feature server for a local example repo.

    The repo is created from ``example_feature_repo_1.py`` through
    ``CliRunner.local_repo``; the store only lives for the duration of that
    context manager, so the client is yielded from inside it.
    """
    repo_definition = get_example_repo("example_feature_repo_1.py")
    with CliRunner().local_repo(repo_definition, "file") as store:
        app = get_app(store)
        yield TestClient(app)
def get_online_features_body():
    """Build the JSON body for a ``/get-online-features`` request.

    Asks for both ``pushed_driver_locations`` features for driver 123.
    A fresh dict is returned on every call.
    """
    feature_refs = [
        "pushed_driver_locations:driver_lat",
        "pushed_driver_locations:driver_long",
    ]
    entity_rows = {"driver_id": [123]}
    return {"features": feature_refs, "entities": entity_rows}
def push_body(push_mode=PushMode.ONLINE, lat=42.0):
    """Build the JSON body for a ``/push`` request to ``driver_locations_push``.

    Args:
        push_mode: ``PushMode`` member; its lower-cased name becomes ``"to"``.
        lat: value sent for ``driver_lat``.

    Returns:
        dict with ``push_source_name``, a single-row ``df`` and ``to``.
    """
    row = {
        "driver_lat": [lat],
        # Sent as a string, unlike driver_lat — presumably matches a
        # string-typed field in the example repo; confirm before changing.
        "driver_long": ["42.0"],
        "driver_id": [123],
        "event_timestamp": [str(_utc_now())],
        "created_timestamp": [str(_utc_now())],
    }
    target = push_mode.name.lower()
    return {"push_source_name": "driver_locations_push", "df": row, "to": target}
@pytest.mark.parametrize("async_online_read", [True, False])
def test_get_online_features_async_supported(async_online_read, mock_fs_factory):
    """Exactly one read path is used, chosen by the provider's async support."""
    store = mock_fs_factory(online_read=async_online_read)
    TestClient(get_app(store)).post(
        "/get-online-features", json=get_online_features_body()
    )

    expected_sync_calls = 0 if async_online_read else 1
    expected_async_awaits = 1 if async_online_read else 0
    assert store.get_online_features.call_count == expected_sync_calls
    assert store.get_online_features_async.await_count == expected_async_awaits
@pytest.mark.parametrize(
    "online_write,push_mode,async_count",
    [
        (True, PushMode.ONLINE_AND_OFFLINE, 1),
        (True, PushMode.OFFLINE, 0),
        (True, PushMode.ONLINE, 1),
        (False, PushMode.ONLINE_AND_OFFLINE, 0),
        (False, PushMode.OFFLINE, 0),
        (False, PushMode.ONLINE, 0),
    ],
)
def test_push_online_async_supported(
    online_write, push_mode, async_count, mock_fs_factory
):
    """Check whether the server pushes through the sync or the async path.

    Per the parameter table, ``async_count`` is 1 only when the provider
    supports async online writes and the push mode includes the online
    store; otherwise the sync ``push`` is expected to be called once.
    """
    store = mock_fs_factory(online_write=online_write)
    app = get_app(store)
    http = TestClient(app)
    http.post("/push", json=push_body(push_mode))

    sync_count = 1 - async_count
    assert store.push.call_count == sync_count
    assert store.push_async.await_count == async_count
def test_push_and_get(test_client):
    """Push a ``driver_lat`` value, then read it back via ``/get-online-features``.

    Written as a plain sync test: ``TestClient`` is synchronous and nothing in
    the body is awaited, so ``async def`` only made the test depend on an
    async pytest plugin/mode being configured in order to run at all.
    """
    driver_lat = 55.1
    push_response = test_client.post("/push", json=push_body(lat=driver_lat))
    assert push_response.status_code == 200

    # The freshly pushed driver_lat (default push mode is ONLINE) should be
    # returned by the online read.
    request_payload = get_online_features_body()
    actual_resp = test_client.post("/get-online-features", json=request_payload)
    assert actual_resp.status_code == 200
    actual = json.loads(actual_resp.text)

    ix = actual["metadata"]["feature_names"].index("driver_lat")
    assert actual["results"][ix]["values"][0] == pytest.approx(driver_lat, 0.0001)
    assert_get_online_features_response_format(
        actual, request_payload["entities"]["driver_id"][0]
    )
def assert_get_online_features_response_format(parsed_response, expected_entity_id):
    """Assert the shape of a parsed ``/get-online-features`` JSON response.

    Checks that ``metadata.feature_names`` holds exactly ``driver_id``,
    ``driver_lat`` and ``driver_long`` (any order), that every result column
    has a single ``PRESENT`` status (one entity was requested), and that the
    ``driver_id`` column carries ``expected_entity_id``.
    """
    assert "metadata" in parsed_response
    feature_names = parsed_response["metadata"]["feature_names"]
    wanted = ("driver_id", "driver_lat", "driver_long")
    assert len(feature_names) == len(wanted)
    assert all(name in feature_names for name in wanted)

    assert "results" in parsed_response
    results = parsed_response["results"]
    for column in results:
        statuses = column["statuses"]
        # One entity row was requested, so each column has one status.
        assert len(statuses) == 1
        assert all(status == "PRESENT" for status in statuses)

    # Result columns are ordered like metadata.feature_names.
    id_column = results[feature_names.index("driver_id")]
    assert id_column["values"][0] == expected_entity_id
def test_push_source_does_not_exist(test_client):
    """Pushing to an unregistered push source raises PushSourceNotFoundException.

    NOTE(review): this presumably relies on ``TestClient`` re-raising
    server-side exceptions (its default) rather than returning an HTTP error.
    """
    payload = {
        "push_source_name": "push_source_does_not_exist",
        "df": {
            "any_data": [1],
            "event_timestamp": [str(_utc_now())],
        },
    }
    with pytest.raises(
        PushSourceNotFoundException,
        match="Unable to find push source 'push_source_does_not_exist'",
    ):
        test_client.post("/push", json=payload)
def test_materialize_endpoint_logic():
    """Test the materialization endpoint logic without HTTP requests.

    Exercises ``MaterializeRequest`` with and without explicit timestamps,
    then mirrors the endpoint's epoch-to-now fallback used when
    ``disable_event_timestamp`` is set.
    """
    from datetime import datetime, timezone

    from feast.feature_server import MaterializeRequest

    # Test 1: Standard request with timestamps
    request = MaterializeRequest(
        start_ts="2021-01-01T00:00:00",
        end_ts="2021-01-02T00:00:00",
        feature_views=["test_view"],
    )
    assert request.disable_event_timestamp is False
    assert request.start_ts is not None
    assert request.end_ts is not None

    # Test 2: Request with disable_event_timestamp
    request_no_ts = MaterializeRequest(
        feature_views=["test_view"], disable_event_timestamp=True
    )
    assert request_no_ts.disable_event_timestamp is True
    assert request_no_ts.start_ts is None
    assert request_no_ts.end_ts is None

    # Test 3: with disable_event_timestamp set (asserted above) no timestamps
    # are required; the fallback window is epoch .. now and must be a valid,
    # non-empty range. The former "else" branch here was unreachable (the
    # flag is known to be True) and only contained `pass`, so it is removed.
    start_date = datetime(1970, 1, 1, tzinfo=timezone.utc)
    end_date = datetime.now(timezone.utc)
    assert start_date < end_date
def test_materialize_request_model():
    """Check MaterializeRequest fields with and without explicit timestamps."""
    from feast.feature_server import MaterializeRequest

    # Timestamps may be omitted when event timestamps are disabled.
    no_window = MaterializeRequest(
        feature_views=["test"], disable_event_timestamp=True
    )
    assert no_window.disable_event_timestamp is True
    assert no_window.start_ts is None
    assert no_window.end_ts is None

    # With explicit timestamps the flag is False and the values round-trip.
    start, end = "2021-01-01T00:00:00", "2021-01-02T00:00:00"
    windowed = MaterializeRequest(start_ts=start, end_ts=end, feature_views=["test"])
    assert windowed.disable_event_timestamp is False
    assert (windowed.start_ts, windowed.end_ts) == (start, end)