diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 000000000..518d8fdf8 --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,3 @@ +code_review: + comment_severity_threshold: LOW +ignore_patterns: ['CHANGELOG.md'] diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 1216135c3..4d5c1d2c7 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -26,6 +26,7 @@ coc codegen coro datamodel +deepwiki drivername DSNs dunders @@ -80,5 +81,6 @@ tagwords taskupdate testuuid Tful +tiangolo typeerror vulnz diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d00164ab2..ce8d62ab9 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -53,7 +53,7 @@ jobs: run: | echo "$HOME/.cargo/bin" >> $GITHUB_PATH - name: Install dependencies - run: uv sync --dev --extra sql --extra encryption --extra grpc --extra telemetry + run: uv sync --dev --extra all - name: Run tests and check coverage run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 - name: Show coverage summary in log diff --git a/.ruff.toml b/.ruff.toml index 42a2340a2..3562de26e 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -32,6 +32,7 @@ ignore = [ "TRY003", "TRY201", "FIX002", + "UP038", ] select = [ diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c9062255..a73176991 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [0.3.6](https://github.com/a2aproject/a2a-python/compare/v0.3.5...v0.3.6) (2025-09-09) + + +### Features + +* add JSON-RPC `method` to `ServerCallContext.state` ([d62df7a](https://github.com/a2aproject/a2a-python/commit/d62df7a77e556f26556fc798a55dc6dacec21ea4)) +* **gRPC:** Add proto conversion utilities ([80fc33a](https://github.com/a2aproject/a2a-python/commit/80fc33aaef647826208d9020ef70e5e6592468e3)) + ## [0.3.5](https://github.com/a2aproject/a2a-python/compare/v0.3.4...v0.3.5) (2025-09-08) diff --git a/README.md b/README.md index a3463b07a..4c8f83785 100644 --- a/README.md +++ b/README.md @@ -47,20 +47,17 @@ Install the core SDK and any desired extras using your preferred package manager | Feature | `uv` Command | `pip` Command | | ------------------------ | ------------------------------------------ | -------------------------------------------- | | **Core SDK** | `uv add a2a-sdk` | `pip install a2a-sdk` | +| **All Extras** | `uv add a2a-sdk[all]` | `pip install a2a-sdk[all]` | | **HTTP Server** | `uv add "a2a-sdk[http-server]"` | `pip install "a2a-sdk[http-server]"` | | **gRPC Support** | `uv add "a2a-sdk[grpc]"` | `pip install "a2a-sdk[grpc]"` | | **OpenTelemetry Tracing**| `uv add "a2a-sdk[telemetry]"` | `pip install "a2a-sdk[telemetry]"` | - -#### Database Support - -Install the necessary drivers for your chosen SQL database. - -| Database | `uv` Command | `pip` Command | -| ------------- | ---------------------------------- | ------------------------------------ | -| **PostgreSQL**| `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` | -| **MySQL** | `uv add "a2a-sdk[mysql]"` | `pip install "a2a-sdk[mysql]"` | -| **SQLite** | `uv add "a2a-sdk[sqlite]"` | `pip install "a2a-sdk[sqlite]"` | -| **All SQL Drivers** | `uv add "a2a-sdk[sql]"` | `pip install "a2a-sdk[sql]"` | +| **Encryption** | `uv add "a2a-sdk[encryption]"` | `pip install "a2a-sdk[encryption]"` | +| | | | +| **Database Drivers** | | | +| **PostgreSQL** | `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` | +| **MySQL** | `uv add "a2a-sdk[mysql]"` | `pip install "a2a-sdk[mysql]"` | +| **SQLite** | `uv add "a2a-sdk[sqlite]"` | `pip install "a2a-sdk[sqlite]"` | +| **All SQL Drivers** | `uv add "a2a-sdk[sql]"` | `pip install "a2a-sdk[sql]"` | ## Examples diff --git a/pyproject.toml b/pyproject.toml index 80e38bd8e..0d2cb75a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,13 +30,22 @@ classifiers = [ [project.optional-dependencies] http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"] -postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] -mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] -sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] -sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"] encryption = ["cryptography>=43.0.0"] grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] +postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] +mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] +sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] + +sql = ["a2a-sdk[postgresql,mysql,sqlite]"] + +all = [ + "a2a-sdk[http-server]", + "a2a-sdk[sql]", + "a2a-sdk[encryption]", + "a2a-sdk[grpc]", + "a2a-sdk[telemetry]", +] [project.urls] homepage = "https://a2a-protocol.org/" diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 4f158da8e..ea73ff592 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -337,6 +337,7 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 # 3) Build call context and wrap the request for downstream handling call_context = self._context_builder.build(request) + call_context.state['method'] = method request_id = specific_request.id a2a_request = A2ARequest(root=specific_request) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 12b0d3072..bf7953ca9 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -18,9 +18,9 @@ # Regexp patterns for matching -_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)') +_TASK_NAME_MATCH = re.compile(r'tasks/([^/]+)') _TASK_PUSH_CONFIG_NAME_MATCH = re.compile( - r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' + r'tasks/([^/]+)/pushNotificationConfigs/([^/]+)' ) @@ -46,6 +46,86 @@ def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct: return struct +def make_dict_serializable(value: Any) -> Any: + """Dict pre-processing utility: converts non-serializable values to serializable form. + + Use this when you want to normalize a dictionary before dict->Struct conversion. + + Args: + value: The value to convert. + + Returns: + A serializable value. + """ + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, dict): + return {k: make_dict_serializable(v) for k, v in value.items()} + if isinstance(value, list | tuple): + return [make_dict_serializable(item) for item in value] + return str(value) + + +def normalize_large_integers_to_strings( + value: Any, max_safe_digits: int = 15 +) -> Any: + """Integer preprocessing utility: converts large integers to strings. + + Use this when you want to convert large integers to strings considering + JavaScript's MAX_SAFE_INTEGER (2^53 - 1) limitation. + + Args: + value: The value to convert. + max_safe_digits: Maximum safe integer digits (default: 15). + + Returns: + A normalized value. + """ + max_safe_int = 10**max_safe_digits - 1 + + def _normalize(item: Any) -> Any: + if isinstance(item, int) and abs(item) > max_safe_int: + return str(item) + if isinstance(item, dict): + return {k: _normalize(v) for k, v in item.items()} + if isinstance(item, list | tuple): + return [_normalize(i) for i in item] + return item + + return _normalize(value) + + +def parse_string_integers_in_dict(value: Any, max_safe_digits: int = 15) -> Any: + """String post-processing utility: converts large integer strings back to integers. + + Use this when you want to restore large integer strings to integers + after Struct->dict conversion. + + Args: + value: The value to convert. + max_safe_digits: Maximum safe integer digits (default: 15). + + Returns: + A parsed value. + """ + if isinstance(value, dict): + return { + k: parse_string_integers_in_dict(v, max_safe_digits) + for k, v in value.items() + } + if isinstance(value, list | tuple): + return [ + parse_string_integers_in_dict(item, max_safe_digits) + for item in value + ] + if isinstance(value, str): + # Handle potential negative numbers. + stripped_value = value.lstrip('-') + if stripped_value.isdigit() and len(stripped_value) > max_safe_digits: + return int(value) + return value + + class ToProto: """Converts Python types to proto types.""" diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index c6481b377..19f5abc16 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -1,5 +1,6 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock +import grpc import pytest from a2a.client.transports.grpc import GrpcTransport @@ -7,31 +8,38 @@ from a2a.types import ( AgentCapabilities, AgentCard, + Artifact, + GetTaskPushNotificationConfigParams, Message, MessageSendParams, Part, + PushNotificationAuthenticationInfo, + PushNotificationConfig, Role, Task, + TaskArtifactUpdateEvent, TaskIdParams, + TaskPushNotificationConfig, TaskQueryParams, TaskState, TaskStatus, + TaskStatusUpdateEvent, TextPart, ) from a2a.utils import get_text_parts, proto_utils +from a2a.utils.errors import ServerError -# Fixtures @pytest.fixture def mock_grpc_stub() -> AsyncMock: """Provides a mock gRPC stub with methods mocked.""" stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub) stub.SendMessage = AsyncMock() - stub.SendStreamingMessage = AsyncMock() + stub.SendStreamingMessage = MagicMock() stub.GetTask = AsyncMock() stub.CancelTask = AsyncMock() - stub.CreateTaskPushNotification = AsyncMock() - stub.GetTaskPushNotification = AsyncMock() + stub.CreateTaskPushNotificationConfig = AsyncMock() + stub.GetTaskPushNotificationConfig = AsyncMock() return stub @@ -93,6 +101,78 @@ def sample_message() -> Message: ) +@pytest.fixture +def sample_artifact() -> Artifact: + """Provides a sample Artifact object.""" + return Artifact( + artifact_id='artifact-1', + name='example.txt', + description='An example artifact', + parts=[Part(root=TextPart(text='Hi there'))], + metadata={}, + extensions=[], + ) + + +@pytest.fixture +def sample_task_status_update_event() -> TaskStatusUpdateEvent: + """Provides a sample TaskStatusUpdateEvent.""" + return TaskStatusUpdateEvent( + task_id='task-1', + context_id='ctx-1', + status=TaskStatus(state=TaskState.working), + final=False, + metadata={}, + ) + + +@pytest.fixture +def sample_task_artifact_update_event( + sample_artifact, +) -> TaskArtifactUpdateEvent: + """Provides a sample TaskArtifactUpdateEvent.""" + return TaskArtifactUpdateEvent( + task_id='task-1', + context_id='ctx-1', + artifact=sample_artifact, + append=True, + last_chunk=True, + metadata={}, + ) + + +@pytest.fixture +def sample_authentication_info() -> PushNotificationAuthenticationInfo: + """Provides a sample AuthenticationInfo object.""" + return PushNotificationAuthenticationInfo( + schemes=['apikey', 'oauth2'], credentials='secret-token' + ) + + +@pytest.fixture +def sample_push_notification_config( + sample_authentication_info: PushNotificationAuthenticationInfo, +) -> PushNotificationConfig: + """Provides a sample PushNotificationConfig object.""" + return PushNotificationConfig( + id='config-1', + url='https://example.com/notify', + token='example-token', + authentication=sample_authentication_info, + ) + + +@pytest.fixture +def sample_task_push_notification_config( + sample_push_notification_config: PushNotificationConfig, +) -> TaskPushNotificationConfig: + """Provides a sample TaskPushNotificationConfig object.""" + return TaskPushNotificationConfig( + task_id='task-1', + push_notification_config=sample_push_notification_config, + ) + + @pytest.mark.asyncio async def test_send_message_task_response( grpc_transport: GrpcTransport, @@ -134,6 +214,57 @@ async def test_send_message_message_response( ) +@pytest.mark.asyncio +async def test_send_message_streaming( # noqa: PLR0913 + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_message_send_params: MessageSendParams, + sample_message: Message, + sample_task: Task, + sample_task_status_update_event: TaskStatusUpdateEvent, + sample_task_artifact_update_event: TaskArtifactUpdateEvent, +): + """Test send_message_streaming that yields responses.""" + stream = MagicMock() + stream.read = AsyncMock( + side_effect=[ + a2a_pb2.StreamResponse( + msg=proto_utils.ToProto.message(sample_message) + ), + a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)), + a2a_pb2.StreamResponse( + status_update=proto_utils.ToProto.task_status_update_event( + sample_task_status_update_event + ) + ), + a2a_pb2.StreamResponse( + artifact_update=proto_utils.ToProto.task_artifact_update_event( + sample_task_artifact_update_event + ) + ), + grpc.aio.EOF, + ] + ) + mock_grpc_stub.SendStreamingMessage.return_value = stream + + responses = [ + response + async for response in grpc_transport.send_message_streaming( + sample_message_send_params + ) + ] + + mock_grpc_stub.SendStreamingMessage.assert_called_once() + assert isinstance(responses[0], Message) + assert responses[0].message_id == sample_message.message_id + assert isinstance(responses[1], Task) + assert responses[1].id == sample_task.id + assert isinstance(responses[2], TaskStatusUpdateEvent) + assert responses[2].task_id == sample_task_status_update_event.task_id + assert isinstance(responses[3], TaskArtifactUpdateEvent) + assert responses[3].task_id == sample_task_artifact_update_event.task_id + + @pytest.mark.asyncio async def test_get_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task @@ -188,3 +319,118 @@ async def test_cancel_task( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') ) assert response.status.state == TaskState.canceled + + +@pytest.mark.asyncio +async def test_set_task_callback_with_valid_task( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_task_push_notification_config: TaskPushNotificationConfig, +): + """Test setting a task push notification config with a valid task id.""" + mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( + proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ) + ) + + response = await grpc_transport.set_task_callback( + sample_task_push_notification_config + ) + + mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with( + a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{sample_task_push_notification_config.task_id}', + config_id=sample_task_push_notification_config.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ), + ) + ) + assert response.task_id == sample_task_push_notification_config.task_id + + +@pytest.mark.asyncio +async def test_set_task_callback_with_invalid_task( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_task_push_notification_config: TaskPushNotificationConfig, +): + """Test setting a task push notification config with an invalid task id.""" + mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( + name=( + f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/' + f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}' + ), + push_notification_config=proto_utils.ToProto.push_notification_config( + sample_task_push_notification_config.push_notification_config + ), + ) + + with pytest.raises(ServerError) as exc_info: + await grpc_transport.set_task_callback( + sample_task_push_notification_config + ) + assert ( + 'Bad TaskPushNotificationConfig resource name' + in exc_info.value.error.message + ) + + +@pytest.mark.asyncio +async def test_get_task_callback_with_valid_task( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_task_push_notification_config: TaskPushNotificationConfig, +): + """Test retrieving a task push notification config with a valid task id.""" + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( + proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ) + ) + params = GetTaskPushNotificationConfigParams( + id=sample_task_push_notification_config.task_id, + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, + ) + + response = await grpc_transport.get_task_callback(params) + + mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with( + a2a_pb2.GetTaskPushNotificationConfigRequest( + name=( + f'tasks/{params.id}/' + f'pushNotificationConfigs/{params.push_notification_config_id}' + ), + ) + ) + assert response.task_id == sample_task_push_notification_config.task_id + + +@pytest.mark.asyncio +async def test_get_task_callback_with_invalid_task( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_task_push_notification_config: TaskPushNotificationConfig, +): + """Test retrieving a task push notification config with an invalid task id.""" + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( + name=( + f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/' + f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}' + ), + push_notification_config=proto_utils.ToProto.push_notification_config( + sample_task_push_notification_config.push_notification_config + ), + ) + params = GetTaskPushNotificationConfigParams( + id=sample_task_push_notification_config.task_id, + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, + ) + + with pytest.raises(ServerError) as exc_info: + await grpc_transport.get_task_callback(params) + assert ( + 'Bad TaskPushNotificationConfig resource name' + in exc_info.value.error.message + ) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 72da73772..36309872e 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -289,6 +289,26 @@ def test_request_with_comma_separated_extensions_no_space( call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.requested_extensions == {'foo', 'bar', 'baz'} + def test_method_added_to_call_context_state(self, client, mock_handler): + response = client.post( + '/', + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.state['method'] == 'message/send' + def test_request_with_multiple_extension_headers( self, client, mock_handler ): diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index cce4bca23..da54f833f 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -251,3 +251,260 @@ def test_none_handling(self): assert proto_utils.ToProto.provider(None) is None assert proto_utils.ToProto.security(None) is None assert proto_utils.ToProto.security_schemes(None) is None + + def test_metadata_conversion(self): + """Test metadata conversion with various data types.""" + metadata = { + 'null_value': None, + 'bool_value': True, + 'int_value': 42, + 'float_value': 3.14, + 'string_value': 'hello', + 'dict_value': {'nested': 'dict', 'count': 10}, + 'list_value': [1, 'two', 3.0, True, None], + 'tuple_value': (1, 2, 3), + 'complex_list': [ + {'name': 'item1', 'values': [1, 2, 3]}, + {'name': 'item2', 'values': [4, 5, 6]}, + ], + } + + # Convert to proto + proto_metadata = proto_utils.ToProto.metadata(metadata) + assert proto_metadata is not None + + # Convert back to Python + roundtrip_metadata = proto_utils.FromProto.metadata(proto_metadata) + + # Verify all values are preserved correctly + assert roundtrip_metadata['null_value'] is None + assert roundtrip_metadata['bool_value'] is True + assert roundtrip_metadata['int_value'] == 42 + assert roundtrip_metadata['float_value'] == 3.14 + assert roundtrip_metadata['string_value'] == 'hello' + assert roundtrip_metadata['dict_value']['nested'] == 'dict' + assert roundtrip_metadata['dict_value']['count'] == 10 + assert roundtrip_metadata['list_value'] == [1, 'two', 3.0, True, None] + assert roundtrip_metadata['tuple_value'] == [ + 1, + 2, + 3, + ] # tuples become lists + assert len(roundtrip_metadata['complex_list']) == 2 + assert roundtrip_metadata['complex_list'][0]['name'] == 'item1' + + def test_metadata_with_custom_objects(self): + """Test metadata conversion with custom objects using preprocessing utility.""" + + class CustomObject: + def __str__(self): + return 'custom_object_str' + + def __repr__(self): + return 'CustomObject()' + + metadata = { + 'custom_obj': CustomObject(), + 'list_with_custom': [1, CustomObject(), 'text'], + 'nested_custom': {'obj': CustomObject(), 'normal': 'value'}, + } + + # Use preprocessing utility to make it serializable + serializable_metadata = proto_utils.make_dict_serializable(metadata) + + # Convert to proto + proto_metadata = proto_utils.ToProto.metadata(serializable_metadata) + assert proto_metadata is not None + + # Convert back to Python + roundtrip_metadata = proto_utils.FromProto.metadata(proto_metadata) + + # Custom objects should be converted to strings + assert roundtrip_metadata['custom_obj'] == 'custom_object_str' + assert roundtrip_metadata['list_with_custom'] == [ + 1, + 'custom_object_str', + 'text', + ] + assert roundtrip_metadata['nested_custom']['obj'] == 'custom_object_str' + assert roundtrip_metadata['nested_custom']['normal'] == 'value' + + def test_metadata_edge_cases(self): + """Test metadata conversion with edge cases.""" + metadata = { + 'empty_dict': {}, + 'empty_list': [], + 'zero': 0, + 'false': False, + 'empty_string': '', + 'unicode_string': 'string test', + 'safe_number': 9007199254740991, # JavaScript MAX_SAFE_INTEGER + 'negative_number': -42, + 'float_precision': 0.123456789, + 'numeric_string': '12345', + } + + # Convert to proto and back + proto_metadata = proto_utils.ToProto.metadata(metadata) + roundtrip_metadata = proto_utils.FromProto.metadata(proto_metadata) + + # Verify edge cases are handled correctly + assert roundtrip_metadata['empty_dict'] == {} + assert roundtrip_metadata['empty_list'] == [] + assert roundtrip_metadata['zero'] == 0 + assert roundtrip_metadata['false'] is False + assert roundtrip_metadata['empty_string'] == '' + assert roundtrip_metadata['unicode_string'] == 'string test' + assert roundtrip_metadata['safe_number'] == 9007199254740991 + assert roundtrip_metadata['negative_number'] == -42 + assert abs(roundtrip_metadata['float_precision'] - 0.123456789) < 1e-10 + assert roundtrip_metadata['numeric_string'] == '12345' + + def test_make_dict_serializable(self): + """Test the make_dict_serializable utility function.""" + + class CustomObject: + def __str__(self): + return 'custom_str' + + test_data = { + 'string': 'hello', + 'int': 42, + 'float': 3.14, + 'bool': True, + 'none': None, + 'custom': CustomObject(), + 'list': [1, 'two', CustomObject()], + 'tuple': (1, 2, CustomObject()), + 'nested': {'inner_custom': CustomObject(), 'inner_normal': 'value'}, + } + + result = proto_utils.make_dict_serializable(test_data) + + # Basic types should be unchanged + assert result['string'] == 'hello' + assert result['int'] == 42 + assert result['float'] == 3.14 + assert result['bool'] is True + assert result['none'] is None + + # Custom objects should be converted to strings + assert result['custom'] == 'custom_str' + assert result['list'] == [1, 'two', 'custom_str'] + assert result['tuple'] == [1, 2, 'custom_str'] # tuples become lists + assert result['nested']['inner_custom'] == 'custom_str' + assert result['nested']['inner_normal'] == 'value' + + def test_normalize_large_integers_to_strings(self): + """Test the normalize_large_integers_to_strings utility function.""" + + test_data = { + 'small_int': 42, + 'large_int': 9999999999999999999, # > 15 digits + 'negative_large': -9999999999999999999, + 'float': 3.14, + 'string': 'hello', + 'list': [123, 9999999999999999999, 'text'], + 'nested': {'inner_large': 9999999999999999999, 'inner_small': 100}, + } + + result = proto_utils.normalize_large_integers_to_strings(test_data) + + # Small integers should remain as integers + assert result['small_int'] == 42 + assert isinstance(result['small_int'], int) + + # Large integers should be converted to strings + assert result['large_int'] == '9999999999999999999' + assert isinstance(result['large_int'], str) + assert result['negative_large'] == '-9999999999999999999' + assert isinstance(result['negative_large'], str) + + # Other types should be unchanged + assert result['float'] == 3.14 + assert result['string'] == 'hello' + + # Lists should be processed recursively + assert result['list'] == [123, '9999999999999999999', 'text'] + + # Nested dicts should be processed recursively + assert result['nested']['inner_large'] == '9999999999999999999' + assert result['nested']['inner_small'] == 100 + + def test_parse_string_integers_in_dict(self): + """Test the parse_string_integers_in_dict utility function.""" + + test_data = { + 'regular_string': 'hello', + 'numeric_string_small': '123', # small, should stay as string + 'numeric_string_large': '9999999999999999999', # > 15 digits, should become int + 'negative_large_string': '-9999999999999999999', + 'float_string': '3.14', # not all digits, should stay as string + 'mixed_string': '123abc', # not all digits, should stay as string + 'int': 42, + 'list': ['hello', '9999999999999999999', '123'], + 'nested': { + 'inner_large_string': '9999999999999999999', + 'inner_regular': 'value', + }, + } + + result = proto_utils.parse_string_integers_in_dict(test_data) + + # Regular strings should remain unchanged + assert result['regular_string'] == 'hello' + assert ( + result['numeric_string_small'] == '123' + ) # too small, stays string + assert result['float_string'] == '3.14' # not all digits + assert result['mixed_string'] == '123abc' # not all digits + + # Large numeric strings should be converted to integers + assert result['numeric_string_large'] == 9999999999999999999 + assert isinstance(result['numeric_string_large'], int) + assert result['negative_large_string'] == -9999999999999999999 + assert isinstance(result['negative_large_string'], int) + + # Other types should be unchanged + assert result['int'] == 42 + + # Lists should be processed recursively + assert result['list'] == ['hello', 9999999999999999999, '123'] + + # Nested dicts should be processed recursively + assert result['nested']['inner_large_string'] == 9999999999999999999 + assert result['nested']['inner_regular'] == 'value' + + def test_large_integer_roundtrip_with_utilities(self): + """Test large integer handling with preprocessing and post-processing utilities.""" + + original_data = { + 'large_int': 9999999999999999999, + 'small_int': 42, + 'nested': {'another_large': 12345678901234567890, 'normal': 'text'}, + } + + # Step 1: Preprocess to convert large integers to strings + preprocessed = proto_utils.normalize_large_integers_to_strings( + original_data + ) + + # Step 2: Convert to proto + proto_metadata = proto_utils.ToProto.metadata(preprocessed) + assert proto_metadata is not None + + # Step 3: Convert back from proto + dict_from_proto = proto_utils.FromProto.metadata(proto_metadata) + + # Step 4: Post-process to convert large integer strings back to integers + final_result = proto_utils.parse_string_integers_in_dict( + dict_from_proto + ) + + # Verify roundtrip preserved the original data + assert final_result['large_int'] == 9999999999999999999 + assert isinstance(final_result['large_int'], int) + assert final_result['small_int'] == 42 + assert final_result['nested']['another_large'] == 12345678901234567890 + assert isinstance(final_result['nested']['another_large'], int) + assert final_result['nested']['normal'] == 'text'