Skip to content

Commit 3b79f1f

Browse files
committed
add test to ensure only calling async when supported for the push mode
1 parent 974399b commit 3b79f1f

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

sdk/python/feast/feature_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ async def push(body=Depends(get_body)):
213213
and to in [PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE]
214214
)
215215
if should_push_async:
216+
print("im in async?")
216217
await store.push_async(**push_params)
217218
else:
218219
store.push(**push_params)

sdk/python/tests/foo_provider.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
from feast.infra.offline_stores.offline_store import RetrievalJob
2323
from feast.infra.provider import Provider
2424
from feast.infra.registry.base_registry import BaseRegistry
25+
from feast.infra.supported_async_methods import (
26+
ProviderAsyncMethods,
27+
SupportedAsyncMethods,
28+
)
2529
from feast.online_response import OnlineResponse
2630
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
2731
from feast.protos.feast.types.Value_pb2 import RepeatedValue
@@ -30,6 +34,20 @@
3034

3135

3236
class FooProvider(Provider):
37+
@staticmethod
38+
def with_async_support(online_read=False, online_write=False):
39+
class _FooProvider(FooProvider):
40+
@property
41+
def async_supported(self):
42+
return ProviderAsyncMethods(
43+
online=SupportedAsyncMethods(
44+
read=online_read,
45+
write=online_write,
46+
)
47+
)
48+
49+
return _FooProvider(None)
50+
3351
def __init__(self, config: RepoConfig):
3452
pass
3553

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from unittest.mock import AsyncMock, MagicMock, patch
3+
4+
import pytest
5+
from fastapi.testclient import TestClient
6+
7+
from feast import FeatureStore
8+
from feast.data_source import PushMode
9+
from feast.feature_server import get_app
10+
from feast.utils import _utc_now
11+
from tests.foo_provider import FooProvider
12+
13+
14+
@pytest.mark.parametrize(
15+
"online_write,push_mode,async_count",
16+
[
17+
(True, PushMode.ONLINE_AND_OFFLINE, 1),
18+
(True, PushMode.OFFLINE, 0),
19+
(True, PushMode.ONLINE, 1),
20+
(False, PushMode.ONLINE_AND_OFFLINE, 0),
21+
(False, PushMode.OFFLINE, 0),
22+
(False, PushMode.ONLINE, 0),
23+
],
24+
)
25+
def test_push_online_async_supported(online_write, push_mode, async_count, environment):
26+
push_payload = json.dumps(
27+
{
28+
"push_source_name": "location_stats_push_source",
29+
"df": {
30+
"location_id": [1],
31+
"temperature": [100],
32+
"event_timestamp": [str(_utc_now())],
33+
"created": [str(_utc_now())],
34+
},
35+
"to": push_mode.name.lower(),
36+
}
37+
)
38+
39+
provider = FooProvider.with_async_support(online_write=online_write)
40+
print(provider.async_supported.online.write)
41+
with patch.object(FeatureStore, "_get_provider", return_value=provider):
42+
fs = environment.feature_store
43+
fs.push = MagicMock()
44+
fs.push_async = AsyncMock()
45+
client = TestClient(get_app(fs))
46+
client.post("/push", data=push_payload)
47+
assert fs.push.call_count == 1 - async_count
48+
assert fs.push_async.await_count == async_count

0 commit comments

Comments
 (0)