Skip to content

Commit 7991146

Browse files
committed
add async writer for dynamo
Signed-off-by: Rob Howley <howley.robert@gmail.com>
1 parent 7fdc291 commit 7991146

File tree

3 files changed

+138
-21
lines changed

3 files changed

+138
-21
lines changed

sdk/python/feast/infra/online_stores/dynamodb.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from feast.infra.online_stores.helpers import compute_entity_id
2727
from feast.infra.online_stores.online_store import OnlineStore
2828
from feast.infra.supported_async_methods import SupportedAsyncMethods
29+
from feast.infra.utils.aws_utils import dynamo_write_items_async
2930
from feast.protos.feast.core.DynamoDBTable_pb2 import (
3031
DynamoDBTable as DynamoDBTableProto,
3132
)
@@ -103,7 +104,7 @@ async def close(self):
103104

104105
@property
105106
def async_supported(self) -> SupportedAsyncMethods:
106-
return SupportedAsyncMethods(read=True)
107+
return SupportedAsyncMethods(read=True, write=True)
107108

108109
def update(
109110
self,
@@ -238,6 +239,42 @@ def online_write_batch(
238239
)
239240
self._write_batch_non_duplicates(table_instance, data, progress, config)
240241

242+
async def online_write_batch_async(
243+
self,
244+
config: RepoConfig,
245+
table: FeatureView,
246+
data: List[
247+
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
248+
],
249+
progress: Optional[Callable[[int], Any]],
250+
) -> None:
251+
"""
252+
Writes a batch of feature rows to the online store asynchronously.
253+
254+
If a tz-naive timestamp is passed to this method, it is assumed to be UTC.
255+
256+
Args:
257+
config: The config for the current feature store.
258+
table: Feature view to which these feature rows correspond.
259+
data: A list of quadruplets containing feature data. Each quadruplet contains an entity
260+
key, a dict containing feature values, an event timestamp for the row, and the created
261+
timestamp for the row if it exists.
262+
progress: Function to be called once a batch of rows is written to the online store, used
263+
to show progress.
264+
"""
265+
online_config = config.online_store
266+
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
267+
268+
table_name = _get_table_name(online_config, config, table)
269+
items = [
270+
_to_write_item(config, entity_key, features, timestamp)
271+
for entity_key, features, timestamp, _ in data
272+
]
273+
client = _get_aiodynamodb_client(
274+
online_config.region, config.online_store.max_pool_connections
275+
)
276+
await dynamo_write_items_async(client, table_name, items)
277+
241278
def online_read(
242279
self,
243280
config: RepoConfig,
@@ -419,19 +456,8 @@ def _write_batch_non_duplicates(
419456
"""Deduplicate write batch request items on ``entity_id`` primary key."""
420457
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
421458
for entity_key, features, timestamp, created_ts in data:
422-
entity_id = compute_entity_id(
423-
entity_key,
424-
entity_key_serialization_version=config.entity_key_serialization_version,
425-
)
426459
batch.put_item(
427-
Item={
428-
"entity_id": entity_id, # PartitionKey
429-
"event_ts": str(utils.make_tzaware(timestamp)),
430-
"values": {
431-
k: v.SerializeToString()
432-
for k, v in features.items() # Serialized Features
433-
},
434-
}
460+
Item=_to_write_item(config, entity_key, features, timestamp)
435461
)
436462
if progress:
437463
progress(1)
@@ -675,3 +701,18 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
675701
region, endpoint_url
676702
)
677703
return self._dynamodb_resource
704+
705+
706+
def _to_write_item(config, entity_key, features, timestamp):
707+
entity_id = compute_entity_id(
708+
entity_key,
709+
entity_key_serialization_version=config.entity_key_serialization_version,
710+
)
711+
return {
712+
"entity_id": entity_id, # PartitionKey
713+
"event_ts": str(utils.make_tzaware(timestamp)),
714+
"values": {
715+
k: v.SerializeToString()
716+
for k, v in features.items() # Serialized Features
717+
},
718+
}

sdk/python/feast/infra/utils/aws_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import contextlib
3+
import itertools
24
import os
35
import tempfile
46
import uuid
@@ -10,6 +12,7 @@
1012
import pyarrow as pa
1113
import pyarrow.parquet as pq
1214
from tenacity import (
15+
AsyncRetrying,
1316
retry,
1417
retry_if_exception_type,
1518
stop_after_attempt,
@@ -1076,3 +1079,54 @@ def upload_arrow_table_to_athena(
10761079
# Clean up S3 temporary data
10771080
# for file_path in uploaded_files:
10781081
# s3_resource.Object(bucket, file_path).delete()
1082+
1083+
1084+
class DynamoUnprocessedWriteItems(Exception):
1085+
pass
1086+
1087+
1088+
async def dynamo_write_items_async(
1089+
dynamo_client, table_name: str, items: list[Any]
1090+
) -> None:
1091+
DYNAMO_MAX_WRITE_BATCH_SIZE = 25
1092+
1093+
async def _do_write(items):
1094+
item_iter = iter(items)
1095+
item_batches = []
1096+
while True:
1097+
item_batch = [
1098+
item
1099+
for item in itertools.islice(item_iter, DYNAMO_MAX_WRITE_BATCH_SIZE)
1100+
]
1101+
if not item_batch:
1102+
break
1103+
1104+
item_batches.append(item_batch)
1105+
1106+
return await asyncio.gather(
1107+
*[
1108+
dynamo_client.batch_write_item(
1109+
RequestItems={table_name: item_batch},
1110+
)
1111+
for item_batch in item_batches
1112+
]
1113+
)
1114+
1115+
put_items = [{"PutRequest": {"Item": item}} for item in items]
1116+
1117+
retries = AsyncRetrying(
1118+
retry=retry_if_exception_type(DynamoUnprocessedWriteItems),
1119+
wait=wait_exponential(multiplier=1, max=4),
1120+
reraise=True,
1121+
)
1122+
1123+
for attempt in retries:
1124+
with attempt:
1125+
response_batches = await _do_write(put_items)
1126+
1127+
put_items = []
1128+
for response in response_batches:
1129+
put_items.extend(response["UnprocessedItems"])
1130+
1131+
if put_items:
1132+
raise DynamoUnprocessedWriteItems()

sdk/python/tests/integration/online_store/test_push_features_to_online_store.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,51 @@
88
from tests.integration.feature_repos.universal.entities import location
99

1010

11-
@pytest.mark.integration
12-
@pytest.mark.universal_online_stores
13-
def test_push_features_and_read(environment, universal_data_sources):
11+
@pytest.fixture
12+
def store(environment, universal_data_sources):
1413
store = environment.feature_store
1514
_, _, data_sources = universal_data_sources
1615
feature_views = construct_universal_feature_views(data_sources)
1716
location_fv = feature_views.pushed_locations
1817
store.apply([location(), location_fv])
18+
return store
19+
1920

21+
def _ingest_df():
2022
data = {
2123
"location_id": [1],
2224
"temperature": [4],
2325
"event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
2426
"created": [pd.Timestamp(_utc_now()).round("ms")],
2527
}
26-
df_ingest = pd.DataFrame(data)
28+
return pd.DataFrame(data)
2729

28-
store.push("location_stats_push_source", df_ingest)
30+
31+
def assert_response(online_resp):
32+
online_resp_dict = online_resp.to_dict()
33+
assert online_resp_dict["location_id"] == [1]
34+
assert online_resp_dict["temperature"] == [4]
35+
36+
37+
@pytest.mark.integration
38+
@pytest.mark.universal_online_stores
39+
def test_push_features_and_read(store):
40+
store.push("location_stats_push_source", _ingest_df())
2941

3042
online_resp = store.get_online_features(
3143
features=["pushable_location_stats:temperature"],
3244
entity_rows=[{"location_id": 1}],
3345
)
34-
online_resp_dict = online_resp.to_dict()
35-
assert online_resp_dict["location_id"] == [1]
36-
assert online_resp_dict["temperature"] == [4]
46+
assert_response(online_resp)
47+
48+
49+
@pytest.mark.integration
50+
@pytest.mark.universal_online_stores(only=["dynamodb"])
51+
async def test_push_features_and_read_async(store):
52+
await store.push_async("location_stats_push_source", _ingest_df())
53+
54+
online_resp = await store.get_online_features_async(
55+
features=["pushable_location_stats:temperature"],
56+
entity_rows=[{"location_id": 1}],
57+
)
58+
assert_response(online_resp)

0 commit comments

Comments
 (0)