Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import sys
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Tuple, Union, overload
from typing import TYPE_CHECKING, Any, Tuple, Union, cast, overload

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.utilities.batch.exceptions import (
Expand All @@ -35,7 +35,9 @@

if TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType

from aws_lambda_powertools.logging import Logger
from aws_lambda_powertools.utilities.batch.types import (
PartialItemFailureResponse,
PartialItemFailures,
Expand Down Expand Up @@ -68,10 +70,11 @@ class BasePartialProcessor(ABC):

lambda_context: LambdaContext

def __init__(self):
def __init__(self, logger: logging.Logger | Logger | None = None):
self.success_messages: list[BatchEventTypes] = []
self.fail_messages: list[BatchEventTypes] = []
self.exceptions: list[ExceptionInfo] = []
self.logger = logger

@abstractmethod
def _prepare(self):
Expand Down Expand Up @@ -237,6 +240,22 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
exception_string = f"{exception[0]}:{exception[1]}"
entry = ("fail", exception_string, record)
logger.debug(f"Record processing exception: {exception_string}")

# Log with full traceback when a customer-provided logger is present
# and the exception carries a real traceback (e.g. not a synthetic FIFO circuit-breaker)
batch_logger = self.logger
if batch_logger is not None and exception[2] is not None:
# ExceptionInfo allows None on every slot, but logging.warning's exc_info
# requires a fully populated tuple. We already excluded synthetic exceptions
# (no traceback) above, so the type and value are guaranteed to be set.
assert exception[0] is not None
assert exception[1] is not None
exc_info = cast("tuple[type[BaseException], BaseException, TracebackType]", exception)
batch_logger.warning(
"Record processing exception; skipping this record",
exc_info=exc_info,
)

self.exceptions.append(exception)
self.fail_messages.append(record)
return entry
Expand All @@ -250,6 +269,7 @@ def __init__(
event_type: EventType,
model: BatchTypeModels | None = None,
raise_on_entire_batch_failure: bool = True,
logger: logging.Logger | Logger | None = None,
):
"""Process batch and partially report failed items

Expand All @@ -262,6 +282,8 @@ def __init__(
raise_on_entire_batch_failure: bool
Raise an exception when the entire batch has failed processing.
When set to False, partial failures are reported in the response
logger: logging.Logger | Logger | None
Optional Logger instance to output warnings with tracebacks for failed records.

Exceptions
----------
Expand All @@ -285,7 +307,7 @@ def __init__(
EventType.Kafka: KafkaEventRecord,
}

super().__init__()
super().__init__(logger=logger)

def response(self) -> PartialItemFailureResponse:
"""Batch items that failed processing, if any"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

if TYPE_CHECKING:
from aws_lambda_powertools.logging import Logger
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,7 +67,12 @@ def lambda_handler(event, context: LambdaContext):
None,
)

def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
def __init__(
self,
model: BatchSqsTypeModel | None = None,
skip_group_on_error: bool = False,
logger: logging.Logger | Logger | None = None,
):
"""
Initialize the SqsFifoProcessor.

Expand All @@ -77,12 +83,14 @@ def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error:
skip_group_on_error: bool
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
Default is False.
logger: logging.Logger | Logger | None
Optional Logger instance to output warnings with tracebacks for failed records.

"""
self._skip_group_on_error: bool = skip_group_on_error
self._current_group_id = None
self._failed_group_ids: set[str] = set()
super().__init__(EventType.SQS, model)
super().__init__(EventType.SQS, model, logger=logger)

def _process_record(self, record):
self._current_group_id = record.get("attributes", {}).get("MessageGroupId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -861,3 +861,88 @@ async def simple_async_handler(record: SQSRecord):
# THEN record is processed successfully using asyncio.run()
assert result == {"batchItemFailures": []}
assert result == {"batchItemFailures": []}


def test_batch_processor_logs_exception_with_injected_logger(sqs_event_factory, caplog):
import logging

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, process_partial_response

fail_record = sqs_event_factory("fail")
success_record = sqs_event_factory("success")

def handler(record):
if "fail" in record["body"]:
raise ValueError("intentional failure")
return record["body"]

test_logger = logging.getLogger("test_logger")
processor = BatchProcessor(event_type=EventType.SQS, logger=test_logger)

with caplog.at_level(logging.WARNING, logger="test_logger"):
process_partial_response(
event={"Records": [fail_record, success_record]},
record_handler=handler,
processor=processor,
)

warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
assert len(warning_records) == 1, f"Expected 1 WARNING log, got {len(warning_records)}"
assert "intentional failure" in warning_records[0].getMessage() or warning_records[0].exc_info is not None
assert warning_records[0].exc_info is not None, "Expected exc_info (traceback) in log record"
assert warning_records[0].exc_info[0] is ValueError


def test_batch_processor_does_not_log_without_injected_logger(sqs_event_factory, caplog):
import logging

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, process_partial_response

fail_record = sqs_event_factory("fail")

def handler(record):
raise ValueError("intentional failure")

processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False, logger=None)

with caplog.at_level(logging.WARNING, logger="aws_lambda_powertools.utilities.batch.base"):
process_partial_response(
event={"Records": [fail_record]},
record_handler=handler,
processor=processor,
)

warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
assert len(warning_records) == 0, "Expected no WARNING logs when logger is None"


def test_sqs_fifo_circuit_breaker_does_not_log(sqs_event_fifo_factory, caplog):
import logging

from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor, process_partial_response

failing_record = sqs_event_fifo_factory("fail", "group-1")
short_circuited_record = sqs_event_fifo_factory("would-succeed", "group-1")

def handler(record):
if "fail" in record["body"]:
raise ValueError("first record failure")
return record["body"]

test_logger = logging.getLogger("test_logger")
processor = SqsFifoPartialProcessor(logger=test_logger)
processor.raise_on_entire_batch_failure = False

with caplog.at_level(logging.WARNING, logger="test_logger"):
process_partial_response(
event={"Records": [failing_record, short_circuited_record]},
record_handler=handler,
processor=processor,
)

warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
assert len(warning_records) == 1, (
f"Expected exactly 1 WARNING (real exception only), got {len(warning_records)}: "
+ str([r.getMessage() for r in warning_records])
)
assert warning_records[0].exc_info[0] is ValueError