Skip to content

Commit a52fc7f

Browse files
royi-luoray6080
andauthored
Add ability to pass python dataframes as query parameters (#5376)
* Allow pyarrow tables to be bound as query parameters + used for COPY FROM Add support for polars/pandas Add tests for scaning from invalid source Add check to see if parameterized dataframes are used properly Remove unused include Fix test failures Validate params in prepare Throw exception if input parameter is not found in statement Fix test Fix gcc compile Update autogenerated files * add more tests * Fix missing param test * Deprecate python prepare() API + update docstrings * Fix CI errors * Fix python tests * Limit number of threads for async tests to make sure buffer pool limit isn't hit * Move param value copy to client context * Run clang-format --------- Co-authored-by: Guodong Jin <guod.jin@gmail.com> Co-authored-by: CI Bot <royi-luo@users.noreply.github.com>
1 parent 31b0a75 commit a52fc7f

13 files changed

+309
-127
lines changed

src_cpp/include/pandas/pandas_scan.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ struct PandasScanFunctionData : public function::TableFuncBindData {
5555
}
5656
};
5757

58-
std::unique_ptr<function::ScanReplacementData> tryReplacePD(py::dict& dict, py::str& objectName);
58+
std::unique_ptr<function::ScanReplacementData> tryReplacePD(py::handle& entry);
5959

6060
} // namespace kuzu

src_cpp/include/py_connection.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class PyConnection {
3030

3131
void setMaxNumThreadForExec(uint64_t numThreads);
3232

33-
PyPreparedStatement prepare(const std::string& query);
33+
PyPreparedStatement prepare(const std::string& query, const py::dict& parameters);
3434

3535
uint64_t getNumNodes(const std::string& nodeName);
3636

@@ -40,9 +40,9 @@ class PyConnection {
4040
const std::string& srcTableName, const std::string& relName,
4141
const std::string& dstTableName, size_t queryBatchSize);
4242

43-
static bool isPandasDataframe(const py::object& object);
44-
static bool isPolarsDataframe(const py::object& object);
45-
static bool isPyArrowTable(const py::object& object);
43+
static bool isPandasDataframe(const py::handle& object);
44+
static bool isPolarsDataframe(const py::handle& object);
45+
static bool isPyArrowTable(const py::handle& object);
4646

4747
void createScalarFunction(const std::string& name, const py::function& udf,
4848
const py::list& params, const std::string& retval, bool defaultNull, bool catchExceptions);

src_cpp/pandas/pandas_scan.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,7 @@ static bool isPyArrowBacked(const py::handle& df) {
141141
return false;
142142
}
143143

144-
std::unique_ptr<ScanReplacementData> tryReplacePD(py::dict& dict, py::str& objectName) {
145-
if (!dict.contains(objectName)) {
146-
return nullptr;
147-
}
148-
auto entry = dict[objectName];
144+
std::unique_ptr<ScanReplacementData> tryReplacePD(py::handle& entry) {
149145
if (PyConnection::isPandasDataframe(entry)) {
150146
auto scanReplacementData = std::make_unique<ScanReplacementData>();
151147
if (isPyArrowBacked(entry)) {

src_cpp/py_connection.cpp

Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ void PyConnection::initialize(py::handle& m) {
2929
.def("query", &PyConnection::query, py::arg("statement"))
3030
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
3131
py::arg("num_threads"))
32-
.def("prepare", &PyConnection::prepare, py::arg("query"))
32+
.def("prepare", &PyConnection::prepare, py::arg("query"),
33+
py::arg("parameters") = py::dict())
3334
.def("set_query_timeout", &PyConnection::setQueryTimeout, py::arg("timeout_in_ms"))
3435
.def("interrupt", &PyConnection::interrupt)
3536
.def("get_num_nodes", &PyConnection::getNumNodes, py::arg("node_name"))
@@ -44,12 +45,36 @@ void PyConnection::initialize(py::handle& m) {
4445
PyDateTime_IMPORT;
4546
}
4647

47-
static std::unique_ptr<function::ScanReplacementData> tryReplacePolars(py::dict& dict,
48-
py::str& objectName) {
49-
if (!dict.contains(objectName)) {
50-
return nullptr;
48+
static std::vector<function::scan_replace_handle_t> lookupPythonObject(
49+
const std::string& objectName) {
50+
std::vector<function::scan_replace_handle_t> ret;
51+
52+
py::gil_scoped_acquire acquire;
53+
auto pyTableName = py::str(objectName);
54+
// Here we do an exhaustive search on the frame lineage.
55+
auto currentFrame = importCache->inspect.currentframe()();
56+
while (hasattr(currentFrame, "f_locals")) {
57+
auto localDict = py::cast<py::dict>(currentFrame.attr("f_locals"));
58+
auto hasLocalDict = !py::none().is(localDict);
59+
if (hasLocalDict) {
60+
if (localDict.contains(pyTableName)) {
61+
ret.push_back(reinterpret_cast<function::scan_replace_handle_t>(
62+
localDict[pyTableName].ptr()));
63+
}
64+
}
65+
auto globalDict = py::reinterpret_borrow<py::dict>(currentFrame.attr("f_globals"));
66+
if (globalDict) {
67+
if (globalDict.contains(pyTableName)) {
68+
ret.push_back(reinterpret_cast<function::scan_replace_handle_t>(
69+
globalDict[pyTableName].ptr()));
70+
}
71+
}
72+
currentFrame = currentFrame.attr("f_back");
5173
}
52-
auto entry = dict[objectName];
74+
return ret;
75+
}
76+
77+
static std::unique_ptr<function::ScanReplacementData> tryReplacePolars(py::handle& entry) {
5378
if (PyConnection::isPolarsDataframe(entry)) {
5479
auto scanReplacementData = std::make_unique<function::ScanReplacementData>();
5580
scanReplacementData->func = PyArrowTableScanFunction::getFunction();
@@ -62,12 +87,7 @@ static std::unique_ptr<function::ScanReplacementData> tryReplacePolars(py::dict&
6287
}
6388
}
6489

65-
static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow(py::dict& dict,
66-
py::str& objectName) {
67-
if (!dict.contains(objectName)) {
68-
return nullptr;
69-
}
70-
auto entry = dict[objectName];
90+
static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow(py::handle& entry) {
7191
if (PyConnection::isPyArrowTable(entry)) {
7292
auto scanReplacementData = std::make_unique<function::ScanReplacementData>();
7393
scanReplacementData->func = PyArrowTableScanFunction::getFunction();
@@ -81,59 +101,33 @@ static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow(py::dict
81101
}
82102

83103
static std::unique_ptr<function::ScanReplacementData> replacePythonObject(
84-
const std::string& objectName) {
104+
std::span<function::scan_replace_handle_t> candidateHandles) {
85105
py::gil_scoped_acquire acquire;
86-
auto pyTableName = py::str(objectName);
87-
// Here we do an exhaustive search on the frame lineage.
88-
auto currentFrame = importCache->inspect.currentframe()();
89-
bool nameMatchFound = false;
90-
while (hasattr(currentFrame, "f_locals")) {
91-
auto localDict = py::cast<py::dict>(currentFrame.attr("f_locals"));
92-
auto hasLocalDict = !py::none().is(localDict);
93-
if (hasLocalDict) {
94-
if (localDict.contains(pyTableName)) {
95-
nameMatchFound = true;
96-
}
97-
auto result = tryReplacePD(localDict, pyTableName);
98-
if (!result) {
99-
result = tryReplacePolars(localDict, pyTableName);
100-
}
101-
if (!result) {
102-
result = tryReplacePyArrow(localDict, pyTableName);
103-
}
104-
if (result) {
105-
return result;
106-
}
106+
for (auto* handle : candidateHandles) {
107+
auto entry = py::handle(reinterpret_cast<PyObject*>(handle));
108+
auto result = tryReplacePD(entry);
109+
if (!result) {
110+
result = tryReplacePolars(entry);
107111
}
108-
auto globalDict = py::reinterpret_borrow<py::dict>(currentFrame.attr("f_globals"));
109-
if (globalDict) {
110-
if (globalDict.contains(pyTableName)) {
111-
nameMatchFound = true;
112-
}
113-
auto result = tryReplacePD(globalDict, pyTableName);
114-
if (!result) {
115-
result = tryReplacePolars(globalDict, pyTableName);
116-
}
117-
if (!result) {
118-
result = tryReplacePyArrow(globalDict, pyTableName);
119-
}
120-
if (result) {
121-
return result;
122-
}
112+
if (!result) {
113+
result = tryReplacePyArrow(entry);
114+
}
115+
if (result) {
116+
return result;
123117
}
124-
currentFrame = currentFrame.attr("f_back");
125118
}
126-
if (nameMatchFound) {
127-
throw BinderException(
128-
stringFormat("Variable {} found but no matches were scannable", objectName));
119+
if (!candidateHandles.empty()) {
120+
throw BinderException("Attempted to scan from unsupported python object. Can only scan "
121+
"from pandas/polars dataframes and pyarrow tables.");
129122
}
130123
return nullptr;
131124
}
132125

133126
PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) {
134127
storageDriver = std::make_unique<kuzu::main::StorageDriver>(pyDatabase->database.get());
135128
conn = std::make_unique<Connection>(pyDatabase->database.get());
136-
conn->getClientContext()->addScanReplace(function::ScanReplacement(replacePythonObject));
129+
conn->getClientContext()->addScanReplace(
130+
function::ScanReplacement(lookupPythonObject, replacePythonObject));
137131
if (numThreads > 0) {
138132
conn->setMaxNumThreadForExec(numThreads);
139133
}
@@ -175,8 +169,9 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
175169
conn->setMaxNumThreadForExec(numThreads);
176170
}
177171

178-
PyPreparedStatement PyConnection::prepare(const std::string& query) {
179-
auto preparedStatement = conn->prepare(query);
172+
PyPreparedStatement PyConnection::prepare(const std::string& query, const py::dict& parameters) {
173+
auto params = transformPythonParameters(parameters, conn.get());
174+
auto preparedStatement = conn->prepareWithParams(query, std::move(params));
180175
PyPreparedStatement pyPreparedStatement;
181176
pyPreparedStatement.preparedStatement = std::move(preparedStatement);
182177
return pyPreparedStatement;
@@ -261,21 +256,21 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
261256
conn->setMaxNumThreadForExec(numThreadsForExec);
262257
}
263258

264-
bool PyConnection::isPandasDataframe(const py::object& object) {
259+
bool PyConnection::isPandasDataframe(const py::handle& object) {
265260
if (!doesPyModuleExist("pandas")) {
266261
return false;
267262
}
268263
return py::isinstance(object, importCache->pandas.DataFrame());
269264
}
270265

271-
bool PyConnection::isPolarsDataframe(const py::object& object) {
266+
bool PyConnection::isPolarsDataframe(const py::handle& object) {
272267
if (!doesPyModuleExist("polars")) {
273268
return false;
274269
}
275270
return py::isinstance(object, importCache->polars.DataFrame());
276271
}
277272

278-
bool PyConnection::isPyArrowTable(const py::object& object) {
273+
bool PyConnection::isPyArrowTable(const py::handle& object) {
279274
if (!doesPyModuleExist("pyarrow")) {
280275
return false;
281276
}
@@ -389,6 +384,9 @@ static LogicalType pyLogicalType(const py::handle& val) {
389384
childType = std::move(result);
390385
}
391386
return LogicalType::LIST(std::move(childType));
387+
} else if (PyConnection::isPyArrowTable(val) || PyConnection::isPandasDataframe(val) ||
388+
PyConnection::isPolarsDataframe(val)) {
389+
return LogicalType::POINTER();
392390
} else {
393391
// LCOV_EXCL_START
394392
throw common::RuntimeException(
@@ -678,6 +676,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
678676
}
679677
return Value(type.copy(), std::move(children));
680678
}
679+
case LogicalTypeID::POINTER: {
680+
return Value::createValue(reinterpret_cast<uint8_t*>(val.ptr()));
681+
}
681682
default:
682683
return transformPythonValueAs(val, type);
683684
}

src_py/async_connection.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import threading
5+
import warnings
56
from concurrent.futures import ThreadPoolExecutor
67
from typing import TYPE_CHECKING, Any
78

@@ -173,29 +174,44 @@ async def execute(
173174
finally:
174175
self.__decrement_connection_counter(conn_index)
175176

176-
async def prepare(self, query: str) -> PreparedStatement:
177+
async def _prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
178+
"""
179+
The only parameters supported during prepare are dataframes.
180+
Any remaining parameters will be ignored and should be passed to execute().
181+
"""
182+
loop = asyncio.get_running_loop()
183+
conn, conn_index = self.__get_connection_with_least_queries()
184+
185+
try:
186+
preparedStatement = await loop.run_in_executor(self.executor, conn.prepare, query, parameters)
187+
return preparedStatement
188+
finally:
189+
self.__decrement_connection_counter(conn_index)
190+
191+
async def prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
177192
"""
178193
Create a prepared statement for a query asynchronously.
179194
180195
Parameters
181196
----------
182197
query : str
183198
Query to prepare.
199+
parameters : dict[str, Any]
200+
Parameters for the query.
184201
185202
Returns
186203
-------
187204
PreparedStatement
188205
Prepared statement.
189206
190207
"""
191-
loop = asyncio.get_running_loop()
192-
conn, conn_index = self.__get_connection_with_least_queries()
193-
194-
try:
195-
preparedStatement = await loop.run_in_executor(self.executor, conn.prepare, query)
196-
return preparedStatement
197-
finally:
198-
self.__decrement_connection_counter(conn_index)
208+
warnings.warn(
209+
"The use of separate prepare + execute of queries is deprecated. "
210+
"Please using a single call to the execute() API instead.",
211+
DeprecationWarning,
212+
stacklevel=2,
213+
)
214+
return await self._prepare(query, parameters)
199215

200216
def close(self) -> None:
201217
"""

src_py/connection.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import TYPE_CHECKING, Any, Callable
45

56
from . import _kuzu
@@ -129,7 +130,7 @@ def execute(
129130
if len(parameters) == 0 and isinstance(query, str):
130131
_query_result = self._connection.query(query)
131132
else:
132-
prepared_statement = self.prepare(query) if isinstance(query, str) else query
133+
prepared_statement = self._prepare(query, parameters) if isinstance(query, str) else query
133134
_query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
134135
if not _query_result.isSuccess():
135136
raise RuntimeError(_query_result.getErrorMessage())
@@ -144,7 +145,22 @@ def execute(
144145
all_query_results.append(QueryResult(self, _query_result))
145146
return all_query_results
146147

147-
def prepare(self, query: str) -> PreparedStatement:
148+
def _prepare(
149+
self,
150+
query: str,
151+
parameters: dict[str, Any] | None = None,
152+
) -> PreparedStatement:
153+
"""
154+
The only parameters supported during prepare are dataframes.
155+
Any remaining parameters will be ignored and should be passed to execute().
156+
"""
157+
return PreparedStatement(self, query, parameters)
158+
159+
def prepare(
160+
self,
161+
query: str,
162+
parameters: dict[str, Any] | None = None,
163+
) -> PreparedStatement:
148164
"""
149165
Create a prepared statement for a query.
150166
@@ -153,13 +169,22 @@ def prepare(self, query: str) -> PreparedStatement:
153169
query : str
154170
Query to prepare.
155171
172+
parameters : dict[str, Any]
173+
Parameters for the query.
174+
156175
Returns
157176
-------
158177
PreparedStatement
159178
Prepared statement.
160179
161180
"""
162-
return PreparedStatement(self, query)
181+
warnings.warn(
182+
"The use of separate prepare + execute of queries is deprecated. "
183+
"Please using a single call to the execute() API instead.",
184+
DeprecationWarning,
185+
stacklevel=2,
186+
)
187+
return self._prepare(query, parameters)
163188

164189
def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
165190
LIST_START_SYMBOL = "["

src_py/prepared_statement.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
if TYPE_CHECKING:
66
from .connection import Connection
@@ -12,16 +12,20 @@ class PreparedStatement:
1212
same query for repeated execution.
1313
"""
1414

15-
def __init__(self, connection: Connection, query: str):
15+
def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None):
1616
"""
1717
Parameters
1818
----------
1919
connection : Connection
2020
Connection to a database.
2121
query : str
2222
Query to prepare.
23+
parameters : dict[str, Any]
24+
Parameters for the query.
2325
"""
24-
self._prepared_statement = connection._connection.prepare(query)
26+
if parameters is None:
27+
parameters = {}
28+
self._prepared_statement = connection._connection.prepare(query, parameters)
2529
self._connection = connection
2630

2731
def is_success(self) -> bool:

0 commit comments

Comments
 (0)