From 4c71ed77151c8b248794eea2873580e7d99aab24 Mon Sep 17 00:00:00 2001 From: Miguel Trejo Date: Fri, 8 Apr 2022 22:18:29 -0500 Subject: [PATCH] fix: dynamodb batch request overwrite partition keys Signed-off-by: Miguel Trejo --- .../feast/infra/online_stores/dynamodb.py | 42 ++++++++++++------- .../test_dynamodb_online_store.py | 25 ++++++++++- sdk/python/tests/utils/online_store_utils.py | 6 ++- 3 files changed, 55 insertions(+), 18 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index 61334be1a92..01562ad900c 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -191,21 +191,7 @@ def online_write_batch( table_instance = dynamodb_resource.Table( _get_table_name(online_config, config, table) ) - with table_instance.batch_writer() as batch: - for entity_key, features, timestamp, created_ts in data: - entity_id = compute_entity_id(entity_key) - batch.put_item( - Item={ - "entity_id": entity_id, # PartitionKey - "event_ts": str(utils.make_tzaware(timestamp)), - "values": { - k: v.SerializeToString() - for k, v in features.items() # Serialized Features - }, - } - ) - if progress: - progress(1) + self._write_batch_non_duplicates(table_instance, data, progress) @log_exceptions_and_usage(online_store="dynamodb") def online_read( @@ -299,6 +285,32 @@ def _sort_dynamodb_response(self, responses: list, order: list): _, table_responses_ordered = zip(*table_responses_ordered) return table_responses_ordered + @log_exceptions_and_usage(online_store="dynamodb") + def _write_batch_non_duplicates( + self, + table_instance, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ): + """Deduplicate write batch request items on ``entity_id`` primary key.""" + with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch: + for entity_key, features, timestamp, created_ts in data: + entity_id = compute_entity_id(entity_key) + batch.put_item( + Item={ + "entity_id": entity_id, # PartitionKey + "event_ts": str(utils.make_tzaware(timestamp)), + "values": { + k: v.SerializeToString() + for k, v in features.items() # Serialized Features + }, + } + ) + if progress: + progress(1) + def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None): return boto3.client("dynamodb", region_name=region, endpoint_url=endpoint_url) diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py index 7b0c5a4a619..7d6da0dc06d 100644 --- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py @@ -1,5 +1,7 @@ +from copy import deepcopy from dataclasses import dataclass +import boto3 import pytest from moto import mock_dynamodb2 @@ -162,7 +164,7 @@ def test_online_read(repo_config, n_samples): data = _create_n_customer_test_samples(n=n_samples) _insert_data_test_table(data, PROJECT, f"{TABLE_NAME}_{n_samples}", REGION) - entity_keys, features = zip(*data) + entity_keys, features, *rest = zip(*data) dynamodb_store = DynamoDBOnlineStore() returned_items = dynamodb_store.online_read( config=repo_config, @@ -171,3 +173,24 @@ def test_online_read(repo_config, n_samples): ) assert len(returned_items) == len(data) assert [item[1] for item in returned_items] == list(features) + + +@mock_dynamodb2 +def test_write_batch_non_duplicates(repo_config): + """Test DynamoDBOnline Store deduplicate write batch request items.""" + dynamodb_tbl = f"{TABLE_NAME}_batch_non_duplicates" + _create_test_table(PROJECT, dynamodb_tbl, REGION) + data = _create_n_customer_test_samples() + data_duplicate = deepcopy(data) + dynamodb_resource = boto3.resource("dynamodb", region_name=REGION) + table_instance = dynamodb_resource.Table(f"{PROJECT}.{dynamodb_tbl}") + dynamodb_store = DynamoDBOnlineStore() + # Insert duplicate data + dynamodb_store._write_batch_non_duplicates( + table_instance, data + data_duplicate, progress=None + ) + # Request more items than inserted + response = table_instance.scan(Limit=20) + returned_items = response.get("Items", None) + assert returned_items is not None + assert len(returned_items) == len(data) diff --git a/sdk/python/tests/utils/online_store_utils.py b/sdk/python/tests/utils/online_store_utils.py index ee90c2a5427..f72b4d5a2a3 100644 --- a/sdk/python/tests/utils/online_store_utils.py +++ b/sdk/python/tests/utils/online_store_utils.py @@ -19,6 +19,8 @@ def _create_n_customer_test_samples(n=10): "name": ValueProto(string_val="John"), "age": ValueProto(int64_val=3), }, + datetime.utcnow(), + None, ) for i in range(n) ] @@ -42,13 +44,13 @@ def _delete_test_table(project, tbl_name, region): def _insert_data_test_table(data, project, tbl_name, region): dynamodb_resource = boto3.resource("dynamodb", region_name=region) table_instance = dynamodb_resource.Table(f"{project}.{tbl_name}") - for entity_key, features in data: + for entity_key, features, timestamp, created_ts in data: entity_id = compute_entity_id(entity_key) with table_instance.batch_writer() as batch: batch.put_item( Item={ "entity_id": entity_id, - "event_ts": str(utils.make_tzaware(datetime.utcnow())), + "event_ts": str(utils.make_tzaware(timestamp)), "values": {k: v.SerializeToString() for k, v in features.items()}, } )