# limitations under the License.

import base64
import datetime
import gzip
import json
from dataclasses import dataclass
from typing import Any

from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.message import Message

from google.cloud.spanner_v1 import BatchTransactionId
from google.cloud.spanner_v1.types import ExecuteSqlRequest, DirectedReadOptions

# Proto message classes that may legitimately appear inside a partition
# payload. Only classes listed here are rehydrated back into message
# objects on decode; any other message type round-trips as a plain dict.
_PROTO_CLASS_MAP = {
    "QueryOptions": ExecuteSqlRequest.QueryOptions,
    "DirectedReadOptions": DirectedReadOptions,
}


def _serialize_value(val: Any) -> Any:
    """Convert *val* into a JSON-safe structure.

    Values JSON cannot represent natively (bytes, datetimes, tuples,
    protobuf messages) are wrapped in a ``{"__type__": ..., "value": ...}``
    marker dict so that ``_deserialize_value`` can restore them.
    Containers are converted recursively; anything else passes through.
    """
    if isinstance(val, bytes):
        return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")}
    if isinstance(val, datetime.datetime):
        return {"__type__": "datetime", "value": val.isoformat()}
    # proto-plus wrappers expose the raw protobuf message via ``_pb``.
    if hasattr(val, "_pb"):
        return {
            "__type__": "protobuf",
            "class": type(val).__name__,
            "value": MessageToDict(val._pb, preserving_proto_field_name=True),
        }
    if isinstance(val, Message):
        return {
            "__type__": "protobuf",
            "class": type(val).__name__,
            "value": MessageToDict(val, preserving_proto_field_name=True),
        }
    if isinstance(val, dict):
        # NOTE(review): a user-supplied dict that itself contains a
        # "__type__" key would be misread on decode; payloads here are
        # library-generated, so keys are presumed controlled — confirm.
        return {key: _serialize_value(item) for key, item in val.items()}
    if isinstance(val, list):
        return [_serialize_value(item) for item in val]
    if isinstance(val, tuple):
        return {"__type__": "tuple", "value": [_serialize_value(item) for item in val]}
    return val


def _deserialize_value(val: Any) -> Any:
    """Inverse of ``_serialize_value``: unwrap marker dicts and recurse."""
    if isinstance(val, list):
        return [_deserialize_value(item) for item in val]
    if not isinstance(val, dict):
        return val

    marker = val.get("__type__")
    if marker == "bytes":
        return base64.b64decode(val["value"])
    if marker == "datetime":
        text = val["value"]
        # Accept a trailing "Z" for ``fromisoformat`` compatibility on
        # Python versions older than 3.11.
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        return datetime.datetime.fromisoformat(text)
    if marker == "tuple":
        return tuple(_deserialize_value(item) for item in val["value"])
    if marker == "protobuf":
        payload = val["value"]
        proto_cls = _PROTO_CLASS_MAP.get(val.get("class"))
        if proto_cls is not None:
            # Parse into a raw pb message, then wrap it back in proto-plus.
            pb = proto_cls()._pb
            ParseDict(payload, pb)
            return proto_cls(pb)
        # Unknown message class: fall back to a plain deserialized dict.
        return _deserialize_value(payload)
    # Ordinary mapping (including dicts whose "__type__" matched nothing).
    return {key: _deserialize_value(item) for key, item in val.items()}


def decode_from_string(encoded_partition_id):
    """Rebuild a ``PartitionId`` from a string made by ``encode_to_string``.

    The payload is base64 -> gzip -> JSON. ``pickle`` is deliberately not
    used so that untrusted strings cannot execute code on load.

    :param encoded_partition_id: base64 text previously produced by
        ``encode_to_string``.
    :returns: the reconstructed ``PartitionId``.
    """
    compressed = base64.b64decode(encoded_partition_id.encode("utf-8"))
    payload = json.loads(gzip.decompress(compressed).decode("utf-8"))

    txn = payload["batch_transaction_id"]
    batch_transaction_id = BatchTransactionId(
        transaction_id=_deserialize_value(txn["transaction_id"]),
        session_id=txn["session_id"],
        read_timestamp=_deserialize_value(txn["read_timestamp"]),
    )
    partition_result = _deserialize_value(payload["partition_result"])
    return PartitionId(batch_transaction_id, partition_result)


def encode_to_string(batch_transaction_id, partition_result):
    """Serialize a partition (transaction id + result) to a portable string.

    :param batch_transaction_id: the ``BatchTransactionId`` to embed.
    :param partition_result: the partition payload (dict of query/read data).
    :returns: base64-encoded, gzip-compressed JSON that
        ``decode_from_string`` can reverse.
    """
    payload = {
        "batch_transaction_id": {
            "transaction_id": _serialize_value(batch_transaction_id.transaction_id),
            "session_id": batch_transaction_id.session_id,
            "read_timestamp": _serialize_value(batch_transaction_id.read_timestamp),
        },
        "partition_result": _serialize_value(partition_result),
    }
    compressed = gzip.compress(json.dumps(payload).encode("utf-8"))
    return base64.b64encode(compressed).decode("utf-8")
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import datetime
import gzip
import json
import unittest

from google.cloud.spanner_dbapi import partition_helper
from google.cloud.spanner_v1 import BatchTransactionId
from google.cloud.spanner_v1.types import ExecuteSqlRequest


class TestPartitionHelper(unittest.TestCase):
    """Round-trip and safety tests for the JSON-based partition codec."""

    def test_encode_and_decode_success_query(self):
        """A query partition survives an encode/decode round trip."""
        transaction_id = BatchTransactionId(
            transaction_id=b"test-txn-123",
            session_id="session-xyz",
            read_timestamp=datetime.datetime(
                2024, 5, 10, 12, 34, 56, tzinfo=datetime.timezone.utc
            ),
        )
        options = ExecuteSqlRequest.QueryOptions(
            optimizer_version="2",
            optimizer_statistics_package="package-abc",
        )
        partition_result = {
            "partition": b"partition-token-456",
            "query": {
                "sql": "SELECT * FROM users WHERE age > %s",
                "params": {"age": 21},
                "query_options": options,
            },
        }

        encoded = partition_helper.encode_to_string(transaction_id, partition_result)
        self.assertIsInstance(encoded, str)

        decoded = partition_helper.decode_from_string(encoded)
        self.assertIsInstance(decoded, partition_helper.PartitionId)

        # The transaction identity is restored exactly.
        restored = decoded.batch_transaction_id
        self.assertEqual(restored.transaction_id, transaction_id.transaction_id)
        self.assertEqual(restored.session_id, transaction_id.session_id)
        self.assertEqual(restored.read_timestamp, transaction_id.read_timestamp)

        # The partition payload is restored exactly.
        result = decoded.partition_result
        self.assertEqual(result["partition"], b"partition-token-456")
        self.assertEqual(
            result["query"]["sql"], "SELECT * FROM users WHERE age > %s"
        )
        self.assertEqual(result["query"]["params"], {"age": 21})

        # Query options come back as a message object, not a plain dict.
        restored_options = result["query"]["query_options"]
        self.assertEqual(restored_options.optimizer_version, "2")
        self.assertEqual(
            restored_options.optimizer_statistics_package, "package-abc"
        )

    def test_encode_and_decode_success_read(self):
        """A read partition (with no read timestamp) round-trips unchanged."""
        transaction_id = BatchTransactionId(
            transaction_id=b"test-txn-456",
            session_id="session-abc",
            read_timestamp=None,
        )
        partition_result = {
            "partition": b"partition-token-789",
            "read": {
                "table": "users",
                "columns": ["name", "age"],
                "keyset": {"keys": [[1], [2]]},
            },
        }

        decoded = partition_helper.decode_from_string(
            partition_helper.encode_to_string(transaction_id, partition_result)
        )

        restored = decoded.batch_transaction_id
        self.assertEqual(restored.transaction_id, transaction_id.transaction_id)
        self.assertEqual(restored.session_id, transaction_id.session_id)
        self.assertIsNone(restored.read_timestamp)

        result = decoded.partition_result
        self.assertEqual(result["partition"], b"partition-token-789")
        self.assertEqual(result["read"]["table"], "users")
        self.assertEqual(result["read"]["columns"], ["name", "age"])
        self.assertEqual(result["read"]["keyset"], {"keys": [[1], [2]]})

    def test_insecure_deserialization_failure(self):
        """A pickle payload is rejected instead of being deserialized.

        The decoder is JSON-based, so a gzip+base64-wrapped pickle blob
        must raise a decode error rather than reconstructing arbitrary
        objects (the vulnerability the old pickle-based codec had).
        """
        import pickle

        pickled = pickle.dumps({"test": "payload"})
        encoded = base64.b64encode(gzip.compress(pickled)).decode("utf-8")

        with self.assertRaises((json.JSONDecodeError, UnicodeDecodeError)):
            partition_helper.decode_from_string(encoded)