diff --git a/src/crate/client/cursor.py b/src/crate/client/cursor.py index f07d3197..66b61a4a 100644 --- a/src/crate/client/cursor.py +++ b/src/crate/client/cursor.py @@ -70,6 +70,37 @@ def _replace(match: "re.Match[str]") -> str: return converted_sql, new_params +def _convert_named_bulk_params( + sql: str, seq_of_dicts: t.Sequence[t.Dict[str, t.Any]] +) -> t.Tuple[str, t.List[t.List[t.Any]]]: + """Convert pyformat SQL and a sequence of dicts to positional bulk args. + + Uses the first row to determine the SQL template and position map, then + builds a positional argument list for every row. + + Raises ``ProgrammingError`` if a placeholder name is absent from any row. + Extra keys in each row are silently ignored (consistent with + ``_convert_named_to_positional``). + """ + first = seq_of_dicts[0] + converted_sql, _ = _convert_named_to_positional(sql, first) + positions = {k: i + 1 for i, k in enumerate(first)} + n = len(positions) + + bulk_args: t.List[t.List[t.Any]] = [] + for row in seq_of_dicts: + positional: t.List[t.Any] = [None] * n + for name, pos in positions.items(): + if name not in row: + raise ProgrammingError( + f"Named parameter '{name}' not found in the parameters dict" + ) + positional[pos - 1] = row[name] + bulk_args.append(positional) + + return converted_sql, bulk_args + + class Cursor: """ not thread-safe by intention @@ -118,7 +149,16 @@ def executemany(self, sql, seq_of_parameters): """ row_counts = [] durations = [] - self.execute(sql, bulk_parameters=seq_of_parameters) + bulk_parameters = seq_of_parameters + if ( + bulk_parameters + and isinstance(bulk_parameters[0], dict) + and _NAMED_PARAM_RE.search(sql) + ): + sql, bulk_parameters = _convert_named_bulk_params( + sql, bulk_parameters + ) + self.execute(sql, bulk_parameters=bulk_parameters) for result in self._result.get("results", []): if result.get("rowcount") > -1: diff --git a/tests/client/test_cursor.py b/tests/client/test_cursor.py index dcb71774..ec5f9695 100644 --- a/tests/client/test_cursor.py +++ b/tests/client/test_cursor.py @@ -125,6 +125,53 @@ def test_cursor_executemany(mocked_connection): assert response["results"] == result +def test_executemany_with_named_params(mocked_connection): + """ + Verify that executemany() translates pyformat %(name)s placeholders to + positional $N markers and converts each dict row to a positional list. + + """ + response = { + "col_types": [], + "cols": [], + "duration": 123, + "results": [{"rowcount": 1}, {"rowcount": 1}], + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): + cursor = mocked_connection.cursor() + cursor.executemany( + "INSERT INTO characters (name, age) VALUES (%(name)s, %(age)s)", + [ + {"name": "Arthur", "age": 42}, + {"name": "Bill", "age": 35}, + ], + ) + sql, _params, bulk_args = mocked_connection.client.sql.call_args[0] + assert sql == "INSERT INTO characters (name, age) VALUES ($1, $2)" + assert bulk_args == [["Arthur", 42], ["Bill", 35]] + + +def test_executemany_with_named_params_missing_key(mocked_connection): + """ + Verify that executemany() raises ProgrammingError when a row is missing a + key that appears as a placeholder in the SQL. + """ + cursor = mocked_connection.cursor() + with pytest.raises( + ProgrammingError, match="Named parameter 'age' not found" + ): + cursor.executemany( + "INSERT INTO characters (name, age) VALUES (%(name)s, %(age)s)", + [ + {"name": "Arthur", "age": 42}, + {"name": "Bill"}, # missing 'age' + ], + ) + mocked_connection.client.sql.assert_not_called() + + def test_create_with_timezone_as_datetime_object(mocked_connection): """ The cursor can return timezone-aware `datetime` objects when requested.