Skip to content

Commit 331bc2a

Browse files
committed
feat: Standardize parameter handling for cursor operations
1 parent 1a1e2bc commit 331bc2a

3 files changed

Lines changed: 149 additions & 26 deletions

File tree

spannerlib/wrappers/spannerlib-python/google-cloud-spanner-driver/google/cloud/spanner_driver/cursor.py

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import base64
15+
import datetime
1516
from enum import Enum
1617
import logging
1718
from typing import TYPE_CHECKING, Any
19+
import uuid
1820

1921
from google.cloud.spanner_v1 import (
2022
ExecuteBatchDmlRequest,
@@ -132,6 +134,84 @@ def rowcount(self) -> int:
132134
"""
133135
return self._rowcount
134136

137+
def _prepare_params(
138+
self, parameters: dict[str, Any] | list[Any] | tuple[Any] | None = None
139+
) -> (dict[str, Any] | None, dict[str, Type] | None):
140+
"""
141+
Prepares parameters for Spanner execution
142+
143+
Args:
144+
parameters: A dictionary (for named parameters/GoogleSQL)
145+
or a list/tuple
146+
(for positional parameters/PostgreSQL).
147+
148+
Returns:
149+
A tuple containing:
150+
- converted_params: Dictionary of parameters with values
151+
converted for Spanner (e.g. ints to strings).
152+
- param_types: Dictionary mapping parameter names to
153+
their Spanner Type.
154+
"""
155+
if not parameters:
156+
return {}, {}
157+
158+
converted_params = {}
159+
param_types = {}
160+
161+
# Normalize input to an iterable of (key, value)
162+
if isinstance(parameters, (list, tuple)):
163+
# PostgreSQL Dialect: Positional parameters $1, $2... are
164+
# mapped to P1, P2...
165+
iterator = ((f"P{i}", val) for i, val in enumerate(parameters, 1))
166+
elif isinstance(parameters, dict):
167+
# GoogleSQL Dialect: Named parameters @name are mapped directly.
168+
iterator = parameters.items()
169+
else:
170+
# If strictly required, raise an error for unsupported types
171+
return {}, {}
172+
173+
for key, value in iterator:
174+
if value is None:
175+
converted_params[key] = None
176+
continue
177+
# Note: check bool before int, as bool is a subclass of int
178+
if isinstance(value, bool):
179+
converted_params[key] = value
180+
param_types[key] = Type(code=TypeCode.BOOL)
181+
elif isinstance(value, int):
182+
# Spanner expects INT64 as strings to preserve precision
183+
# in JSON
184+
converted_params[key] = str(value)
185+
param_types[key] = Type(code=TypeCode.INT64)
186+
elif isinstance(value, float):
187+
converted_params[key] = value
188+
param_types[key] = Type(code=TypeCode.FLOAT64)
189+
elif isinstance(value, bytes):
190+
converted_params[key] = value
191+
param_types[key] = Type(code=TypeCode.BYTES)
192+
elif isinstance(value, uuid.UUID):
193+
# Convert UUID to string as requested
194+
converted_params[key] = str(value)
195+
# Use STRING type for UUIDs (unless specific UUID type is
196+
# required/supported by your backend version)
197+
param_types[key] = Type(code=TypeCode.STRING)
198+
elif isinstance(value, datetime.datetime):
199+
# Convert Datetime to string (RFC 3339 format is standard
200+
# for str(datetime))
201+
converted_params[key] = str(value)
202+
param_types[key] = Type(code=TypeCode.TIMESTAMP)
203+
elif isinstance(value, datetime.date):
204+
converted_params[key] = str(value)
205+
param_types[key] = Type(code=TypeCode.DATE)
206+
else:
207+
# Fallback for strings and other types
208+
converted_params[key] = value
209+
# For strings, we can explicitly set the type or let it default.
210+
if isinstance(value, str):
211+
param_types[key] = Type(code=TypeCode.STRING)
212+
213+
return converted_params, param_types
214+
135215
@check_not_closed
136216
def execute(
137217
self,
@@ -152,18 +232,8 @@ def execute(
152232
logger.debug(f"Executing operation: {operation}")
153233

154234
request = ExecuteSqlRequest(sql=operation)
155-
if parameters:
156-
converted_params = {}
157-
param_types = {}
158-
for key, value in parameters.items():
159-
if isinstance(value, int) and not isinstance(value, bool):
160-
converted_params[key] = str(value)
161-
param_types[key] = Type(code=TypeCode.INT64)
162-
else:
163-
converted_params[key] = value
164-
165-
request.params = converted_params
166-
request.param_types = param_types
235+
params, _ = self._prepare_params(parameters)
236+
request.params = params
167237

168238
try:
169239
self._rows = self._connection._internal_conn.execute(request)
@@ -202,18 +272,8 @@ def executemany(
202272

203273
for parameters in seq_of_parameters:
204274
statement = ExecuteBatchDmlRequest.Statement(sql=operation)
205-
if parameters:
206-
converted_params = {}
207-
param_types = {}
208-
for key, value in parameters.items():
209-
if isinstance(value, int) and not isinstance(value, bool):
210-
converted_params[key] = str(value)
211-
param_types[key] = Type(code=TypeCode.INT64)
212-
else:
213-
converted_params[key] = value
214-
215-
statement.params = converted_params
216-
statement.param_types = param_types
275+
params, _ = self._prepare_params(parameters)
276+
statement.params = params
217277

218278
request.statements.append(statement)
219279

spannerlib/wrappers/spannerlib-python/google-cloud-spanner-driver/tests/unit/test_connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def test_begin(self):
5858
self.mock_internal_conn.begin_transaction.assert_called_once()
5959

6060
def test_begin_error(self):
61-
self.mock_internal_conn.begin_transaction.side_effect = Exception("Internal Error")
61+
self.mock_internal_conn.begin_transaction.side_effect = Exception(
62+
"Internal Error"
63+
)
6264
with self.assertRaises(errors.DatabaseError):
6365
self.conn.begin()
6466

spannerlib/wrappers/spannerlib-python/google-cloud-spanner-driver/tests/unit/test_cursor.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
1516
import unittest
1617
from unittest import mock
18+
import uuid
1719

1820
from google.cloud.spanner_v1 import ExecuteSqlRequest, TypeCode
1921
from google.cloud.spanner_v1.types import StructField, Type
@@ -91,7 +93,6 @@ def test_execute_with_params(self):
9193
self.assertEqual(request.sql, operation)
9294
self.assertEqual(request.sql, operation)
9395
self.assertEqual(request.params, {"id": "1"})
94-
self.assertEqual(request.param_types, {"id": Type(code=TypeCode.INT64)})
9596

9697
def test_executemany(self):
9798
operation = "INSERT INTO table (id) VALUES (@id)"
@@ -288,3 +289,63 @@ def test_iterator(self):
288289
self.assertEqual(next(it), (1,))
289290
with self.assertRaises(StopIteration):
290291
next(it)
292+
293+
def test_prepare_params(self):
294+
# Test 1: None
295+
converted, types = self.cursor._prepare_params(None)
296+
self.assertEqual(converted, {})
297+
self.assertEqual(types, {})
298+
299+
# Test 2: Dict (GoogleSQL)
300+
uuid_val = uuid.uuid4()
301+
dt_val = datetime.datetime(2024, 1, 1, 12, 0, 0)
302+
date_val = datetime.date(2024, 1, 1)
303+
params = {
304+
"int_val": 123,
305+
"bool_val": True,
306+
"float_val": 1.23,
307+
"bytes_val": b"bytes",
308+
"str_val": "string",
309+
"uuid_val": uuid_val,
310+
"dt_val": dt_val,
311+
"date_val": date_val,
312+
"none_val": None,
313+
}
314+
converted, types = self.cursor._prepare_params(params)
315+
316+
self.assertEqual(converted["int_val"], "123")
317+
self.assertEqual(types["int_val"].code, TypeCode.INT64)
318+
319+
self.assertEqual(converted["bool_val"], True)
320+
self.assertEqual(types["bool_val"].code, TypeCode.BOOL)
321+
322+
self.assertEqual(converted["float_val"], 1.23)
323+
self.assertEqual(types["float_val"].code, TypeCode.FLOAT64)
324+
325+
self.assertEqual(converted["bytes_val"], b"bytes")
326+
self.assertEqual(types["bytes_val"].code, TypeCode.BYTES)
327+
328+
self.assertEqual(converted["str_val"], "string")
329+
self.assertEqual(types["str_val"].code, TypeCode.STRING)
330+
331+
self.assertEqual(converted["uuid_val"], str(uuid_val))
332+
self.assertEqual(types["uuid_val"].code, TypeCode.STRING)
333+
334+
self.assertEqual(converted["dt_val"], str(dt_val))
335+
self.assertEqual(types["dt_val"].code, TypeCode.TIMESTAMP)
336+
337+
self.assertEqual(converted["date_val"], str(date_val))
338+
self.assertEqual(types["date_val"].code, TypeCode.DATE)
339+
340+
self.assertIsNone(converted["none_val"])
341+
self.assertNotIn("none_val", types)
342+
343+
# Test 3: List (PostgreSQL)
344+
params_list = [1, "test"]
345+
converted, types = self.cursor._prepare_params(params_list)
346+
347+
self.assertEqual(converted["P1"], "1")
348+
self.assertEqual(types["P1"].code, TypeCode.INT64)
349+
350+
self.assertEqual(converted["P2"], "test")
351+
self.assertEqual(types["P2"].code, TypeCode.STRING)

0 commit comments

Comments
 (0)