diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index c577159884d..b5175bbf2f2 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -53,11 +53,13 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): type: Literal["dynamodb"] = "dynamodb" """Online store type selector""" - batch_size: int = 40 - """Number of items to retrieve in a DynamoDB BatchGetItem call.""" + batch_size: int = 100 + """Number of items to retrieve in a DynamoDB BatchGetItem call. + DynamoDB supports up to 100 items per BatchGetItem request.""" endpoint_url: Union[str, None] = None - """DynamoDB local development endpoint Url, i.e. http://localhost:8000""" + """DynamoDB endpoint URL. Use for local development (e.g., http://localhost:8000) + or VPC endpoints for improved latency.""" region: StrictStr """AWS Region Name""" @@ -74,30 +76,33 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): session_based_auth: bool = False """AWS session based client authentication""" - max_pool_connections: int = 10 - """Max number of connections for async Dynamodb operations""" + max_pool_connections: int = 50 + """Max number of connections for async Dynamodb operations. + Increase for high-throughput workloads.""" - keepalive_timeout: float = 12.0 - """Keep-alive timeout in seconds for async Dynamodb connections.""" + keepalive_timeout: float = 30.0 + """Keep-alive timeout in seconds for async Dynamodb connections. + Higher values help reuse connections under sustained load.""" - connect_timeout: Union[int, float] = 60 + connect_timeout: Union[int, float] = 5 """The time in seconds until a timeout exception is thrown when attempting to make - an async connection.""" + an async connection. Lower values enable faster failure detection.""" - read_timeout: Union[int, float] = 60 + read_timeout: Union[int, float] = 10 """The time in seconds until a timeout exception is thrown when attempting to read - from an async connection.""" + from an async connection. Lower values enable faster failure detection.""" - total_max_retry_attempts: Union[int, None] = None + total_max_retry_attempts: Union[int, None] = 3 """Maximum number of total attempts that will be made on a single request. Maps to `retries.total_max_attempts` in botocore.config.Config. """ - retry_mode: Union[Literal["legacy", "standard", "adaptive"], None] = None + retry_mode: Union[Literal["legacy", "standard", "adaptive"], None] = "adaptive" """The type of retry mode (aio)botocore should use. Maps to `retries.mode` in botocore.config.Config. + 'adaptive' mode provides intelligent retry with client-side rate limiting. """ @@ -111,16 +116,22 @@ class DynamoDBOnlineStore(OnlineStore): _aioboto_session: Async boto session. _aioboto_client: Async boto client. _aioboto_context_stack: Async context stack. + _type_deserializer: Cached TypeDeserializer instance for performance. """ _dynamodb_client = None _dynamodb_resource = None + # Class-level cached TypeDeserializer to avoid per-request instantiation + _type_deserializer: Optional[TypeDeserializer] = None def __init__(self): super().__init__() self._aioboto_session = None self._aioboto_client = None self._aioboto_context_stack = None + # Initialize cached TypeDeserializer if not already done + if DynamoDBOnlineStore._type_deserializer is None: + DynamoDBOnlineStore._type_deserializer = TypeDeserializer() async def initialize(self, config: RepoConfig): online_config = config.online_store @@ -133,6 +144,7 @@ async def initialize(self, config: RepoConfig): online_config.read_timeout, online_config.total_max_retry_attempts, online_config.retry_mode, + online_config.endpoint_url, ) async def close(self): @@ -153,6 +165,7 @@ async def _get_aiodynamodb_client( read_timeout: Union[int, float], total_max_retry_attempts: Union[int, None], retry_mode: Union[Literal["legacy", "standard", "adaptive"], None], + endpoint_url: Optional[str] = None, ): if self._aioboto_client is None: logger.debug("initializing the aiobotocore dynamodb client") @@ -163,16 +176,23 @@ async def _get_aiodynamodb_client( if retry_mode is not None: retries["mode"] = retry_mode - client_context = self._get_aioboto_session().create_client( - "dynamodb", - region_name=region, - config=AioConfig( + # Build client kwargs, including endpoint_url for VPC endpoints or local testing + client_kwargs: Dict[str, Any] = { + "region_name": region, + "config": AioConfig( max_pool_connections=max_pool_connections, connect_timeout=connect_timeout, read_timeout=read_timeout, retries=retries if retries else None, connector_args={"keepalive_timeout": keepalive_timeout}, ), + } + if endpoint_url: + client_kwargs["endpoint_url"] = endpoint_url + + client_context = self._get_aioboto_session().create_client( + "dynamodb", + **client_kwargs, ) self._aioboto_context_stack = contextlib.AsyncExitStack() self._aioboto_client = ( @@ -431,6 +451,7 @@ async def online_write_batch_async( online_config.read_timeout, online_config.total_max_retry_attempts, online_config.retry_mode, + online_config.endpoint_url, ) await dynamo_write_items_async(client, table_name, items) @@ -448,6 +469,7 @@ def online_read( config: The RepoConfig for the current FeatureStore. table: Feast FeatureView. entity_keys: a list of entity keys that should be read from the FeatureStore. + requested_features: Optional list of feature names to retrieve. """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) @@ -479,7 +501,9 @@ def online_read( RequestItems=batch_entity_ids, ) batch_result = self._process_batch_get_response( - table_instance.name, response, entity_ids, batch + table_instance.name, + response, + batch, ) result.extend(batch_result) return result @@ -513,7 +537,10 @@ async def online_read_async( entity_ids_iter = iter(entity_ids) table_name = _get_table_name(online_config, config, table) - deserialize = TypeDeserializer().deserialize + # Use cached TypeDeserializer for better performance + if self._type_deserializer is None: + self._type_deserializer = TypeDeserializer() + deserialize = self._type_deserializer.deserialize def to_tbl_resp(raw_client_response): return { @@ -542,6 +569,7 @@ def to_tbl_resp(raw_client_response): online_config.read_timeout, online_config.total_max_retry_attempts, online_config.retry_mode, + online_config.endpoint_url, ) response_batches = await asyncio.gather( *[ @@ -557,7 +585,6 @@ def to_tbl_resp(raw_client_response): result_batch = self._process_batch_get_response( table_name, response, - entity_ids, batch, to_tbl_response=to_tbl_resp, ) @@ -589,26 +616,6 @@ def _get_dynamodb_resource( ) return self._dynamodb_resource - def _sort_dynamodb_response( - self, - responses: list, - order: list, - to_tbl_response: Callable = lambda raw_dict: raw_dict, - ) -> Any: - """DynamoDB Batch Get Item doesn't return items in a particular order.""" - # Assign an index to order - order_with_index = {value: idx for idx, value in enumerate(order)} - # Sort table responses by index - table_responses_ordered: Any = [ - (order_with_index[tbl_res["entity_id"]], tbl_res) - for tbl_res in map(to_tbl_response, responses) - ] - table_responses_ordered = sorted( - table_responses_ordered, key=lambda tup: tup[0] - ) - _, table_responses_ordered = zip(*table_responses_ordered) - return table_responses_ordered - def _write_batch_non_duplicates( self, table_instance, @@ -630,37 +637,77 @@ def _write_batch_non_duplicates( progress(1) def _process_batch_get_response( - self, table_name, response, entity_ids, batch, **sort_kwargs - ): - response = response.get("Responses") - table_responses = response.get(table_name) + self, + table_name: str, + response: Dict[str, Any], + batch: List[str], + to_tbl_response: Callable = lambda raw_dict: raw_dict, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """Process batch get response using O(1) dictionary lookup. - batch_result = [] - if table_responses: - table_responses = self._sort_dynamodb_response( - table_responses, entity_ids, **sort_kwargs - ) - entity_idx = 0 - for tbl_res in table_responses: - entity_id = tbl_res["entity_id"] - while entity_id != batch[entity_idx]: - batch_result.append((None, None)) - entity_idx += 1 - res = {} - for feature_name, value_bin in tbl_res["values"].items(): + DynamoDB BatchGetItem doesn't return items in a particular order, + so we use a dictionary for O(1) lookup instead of O(n log n) sorting. + + This method: + - Uses dictionary lookup instead of sorting for response ordering + - Pre-allocates the result list with None values + - Minimizes object creation in the hot path + + Args: + table_name: Name of the DynamoDB table + response: Raw response from DynamoDB batch_get_item + batch: List of entity_ids in the order they should be returned + to_tbl_response: Function to transform raw DynamoDB response items + (used for async client responses that need deserialization) + + Returns: + List of (timestamp, features) tuples in the same order as batch + """ + responses_data = response.get("Responses") + if not responses_data: + # No responses at all, return all None tuples + return [(None, None)] * len(batch) + + table_responses = responses_data.get(table_name) + if not table_responses: + # No responses for this table, return all None tuples + return [(None, None)] * len(batch) + + # Build a dictionary for O(1) lookup instead of O(n log n) sorting + response_dict: Dict[str, Any] = { + tbl_res["entity_id"]: tbl_res + for tbl_res in map(to_tbl_response, table_responses) + } + + # Pre-allocate result list with None tuples (faster than appending) + batch_size = len(batch) + result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [ + (None, None) + ] * batch_size + + # Process each entity in batch order using O(1) dict lookup + for idx, entity_id in enumerate(batch): + tbl_res = response_dict.get(entity_id) + if tbl_res is not None: + # Parse feature values + features: Dict[str, ValueProto] = {} + values_data = tbl_res["values"] + for feature_name, value_bin in values_data.items(): val = ValueProto() val.ParseFromString(value_bin.value) - res[feature_name] = val - batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res)) - entity_idx += 1 - # Not all entities in a batch may have responses - # Pad with remaining values in batch that were not found - batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) - batch_result.extend(batch_size_nones) - return batch_result + features[feature_name] = val + + # Parse timestamp and set result + result[idx] = ( + datetime.fromisoformat(tbl_res["event_ts"]), + features, + ) + + return result @staticmethod def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]): + """Convert entity keys to entity IDs.""" return [ compute_entity_id( entity_key, diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py index 4127f699810..7c99a07d7aa 100644 --- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py @@ -64,10 +64,17 @@ def test_dynamodb_online_store_config_default(): aws_region = "us-west-2" dynamodb_store_config = DynamoDBOnlineStoreConfig(region=aws_region) assert dynamodb_store_config.type == "dynamodb" - assert dynamodb_store_config.batch_size == 40 + assert dynamodb_store_config.batch_size == 100 assert dynamodb_store_config.endpoint_url is None assert dynamodb_store_config.region == aws_region assert dynamodb_store_config.table_name_template == "{project}.{table_name}" + # Verify other optimized defaults + assert dynamodb_store_config.max_pool_connections == 50 + assert dynamodb_store_config.keepalive_timeout == 30.0 + assert dynamodb_store_config.connect_timeout == 5 + assert dynamodb_store_config.read_timeout == 10 + assert dynamodb_store_config.total_max_retry_attempts == 3 + assert dynamodb_store_config.retry_mode == "adaptive" def test_dynamodb_online_store_config_custom_params():