@@ -29,7 +29,8 @@ void PyConnection::initialize(py::handle& m) {
2929 .def (" query" , &PyConnection::query, py::arg (" statement" ))
3030 .def (" set_max_threads_for_exec" , &PyConnection::setMaxNumThreadForExec,
3131 py::arg (" num_threads" ))
32- .def (" prepare" , &PyConnection::prepare, py::arg (" query" ))
32+ .def (" prepare" , &PyConnection::prepare, py::arg (" query" ),
33+ py::arg (" parameters" ) = py::dict ())
3334 .def (" set_query_timeout" , &PyConnection::setQueryTimeout, py::arg (" timeout_in_ms" ))
3435 .def (" interrupt" , &PyConnection::interrupt)
3536 .def (" get_num_nodes" , &PyConnection::getNumNodes, py::arg (" node_name" ))
@@ -44,12 +45,36 @@ void PyConnection::initialize(py::handle& m) {
4445 PyDateTime_IMPORT;
4546}
4647
47- static std::unique_ptr<function::ScanReplacementData> tryReplacePolars (py::dict& dict,
48- py::str& objectName) {
49- if (!dict.contains (objectName)) {
50- return nullptr ;
48+ static std::vector<function::scan_replace_handle_t > lookupPythonObject (
49+ const std::string& objectName) {
50+ std::vector<function::scan_replace_handle_t > ret;
51+
52+ py::gil_scoped_acquire acquire;
53+ auto pyTableName = py::str (objectName);
54+ // Here we do an exhaustive search on the frame lineage.
55+ auto currentFrame = importCache->inspect .currentframe ()();
56+ while (hasattr (currentFrame, " f_locals" )) {
57+ auto localDict = py::cast<py::dict>(currentFrame.attr (" f_locals" ));
58+ auto hasLocalDict = !py::none ().is (localDict);
59+ if (hasLocalDict) {
60+ if (localDict.contains (pyTableName)) {
61+ ret.push_back (reinterpret_cast <function::scan_replace_handle_t >(
62+ localDict[pyTableName].ptr ()));
63+ }
64+ }
65+ auto globalDict = py::reinterpret_borrow<py::dict>(currentFrame.attr (" f_globals" ));
66+ if (globalDict) {
67+ if (globalDict.contains (pyTableName)) {
68+ ret.push_back (reinterpret_cast <function::scan_replace_handle_t >(
69+ globalDict[pyTableName].ptr ()));
70+ }
71+ }
72+ currentFrame = currentFrame.attr (" f_back" );
5173 }
52- auto entry = dict[objectName];
74+ return ret;
75+ }
76+
77+ static std::unique_ptr<function::ScanReplacementData> tryReplacePolars (py::handle& entry) {
5378 if (PyConnection::isPolarsDataframe (entry)) {
5479 auto scanReplacementData = std::make_unique<function::ScanReplacementData>();
5580 scanReplacementData->func = PyArrowTableScanFunction::getFunction ();
@@ -62,12 +87,7 @@ static std::unique_ptr<function::ScanReplacementData> tryReplacePolars(py::dict&
6287 }
6388}
6489
65- static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow (py::dict& dict,
66- py::str& objectName) {
67- if (!dict.contains (objectName)) {
68- return nullptr ;
69- }
70- auto entry = dict[objectName];
90+ static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow (py::handle& entry) {
7191 if (PyConnection::isPyArrowTable (entry)) {
7292 auto scanReplacementData = std::make_unique<function::ScanReplacementData>();
7393 scanReplacementData->func = PyArrowTableScanFunction::getFunction ();
@@ -81,59 +101,33 @@ static std::unique_ptr<function::ScanReplacementData> tryReplacePyArrow(py::dict
81101}
82102
83103static std::unique_ptr<function::ScanReplacementData> replacePythonObject (
84- const std::string& objectName ) {
104+ std::span<function:: scan_replace_handle_t > candidateHandles ) {
85105 py::gil_scoped_acquire acquire;
86- auto pyTableName = py::str (objectName);
87- // Here we do an exhaustive search on the frame lineage.
88- auto currentFrame = importCache->inspect .currentframe ()();
89- bool nameMatchFound = false ;
90- while (hasattr (currentFrame, " f_locals" )) {
91- auto localDict = py::cast<py::dict>(currentFrame.attr (" f_locals" ));
92- auto hasLocalDict = !py::none ().is (localDict);
93- if (hasLocalDict) {
94- if (localDict.contains (pyTableName)) {
95- nameMatchFound = true ;
96- }
97- auto result = tryReplacePD (localDict, pyTableName);
98- if (!result) {
99- result = tryReplacePolars (localDict, pyTableName);
100- }
101- if (!result) {
102- result = tryReplacePyArrow (localDict, pyTableName);
103- }
104- if (result) {
105- return result;
106- }
106+ for (auto * handle : candidateHandles) {
107+ auto entry = py::handle (reinterpret_cast <PyObject*>(handle));
108+ auto result = tryReplacePD (entry);
109+ if (!result) {
110+ result = tryReplacePolars (entry);
107111 }
108- auto globalDict = py::reinterpret_borrow<py::dict>(currentFrame.attr (" f_globals" ));
109- if (globalDict) {
110- if (globalDict.contains (pyTableName)) {
111- nameMatchFound = true ;
112- }
113- auto result = tryReplacePD (globalDict, pyTableName);
114- if (!result) {
115- result = tryReplacePolars (globalDict, pyTableName);
116- }
117- if (!result) {
118- result = tryReplacePyArrow (globalDict, pyTableName);
119- }
120- if (result) {
121- return result;
122- }
112+ if (!result) {
113+ result = tryReplacePyArrow (entry);
114+ }
115+ if (result) {
116+ return result;
123117 }
124- currentFrame = currentFrame.attr (" f_back" );
125118 }
126- if (nameMatchFound ) {
127- throw BinderException (
128- stringFormat ( " Variable {} found but no matches were scannable " , objectName) );
119+ if (!candidateHandles. empty () ) {
120+ throw BinderException (" Attempted to scan from unsupported python object. Can only scan "
121+ " from pandas/polars dataframes and pyarrow tables. " );
129122 }
130123 return nullptr ;
131124}
132125
133126PyConnection::PyConnection (PyDatabase* pyDatabase, uint64_t numThreads) {
134127 storageDriver = std::make_unique<kuzu::main::StorageDriver>(pyDatabase->database .get ());
135128 conn = std::make_unique<Connection>(pyDatabase->database .get ());
136- conn->getClientContext ()->addScanReplace (function::ScanReplacement (replacePythonObject));
129+ conn->getClientContext ()->addScanReplace (
130+ function::ScanReplacement (lookupPythonObject, replacePythonObject));
137131 if (numThreads > 0 ) {
138132 conn->setMaxNumThreadForExec (numThreads);
139133 }
@@ -175,8 +169,9 @@ void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
175169 conn->setMaxNumThreadForExec (numThreads);
176170}
177171
178- PyPreparedStatement PyConnection::prepare (const std::string& query) {
179- auto preparedStatement = conn->prepare (query);
172+ PyPreparedStatement PyConnection::prepare (const std::string& query, const py::dict& parameters) {
173+ auto params = transformPythonParameters (parameters, conn.get ());
174+ auto preparedStatement = conn->prepareWithParams (query, std::move (params));
180175 PyPreparedStatement pyPreparedStatement;
181176 pyPreparedStatement.preparedStatement = std::move (preparedStatement);
182177 return pyPreparedStatement;
@@ -261,21 +256,21 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
261256 conn->setMaxNumThreadForExec (numThreadsForExec);
262257}
263258
264- bool PyConnection::isPandasDataframe (const py::object & object) {
259+ bool PyConnection::isPandasDataframe (const py::handle & object) {
265260 if (!doesPyModuleExist (" pandas" )) {
266261 return false ;
267262 }
268263 return py::isinstance (object, importCache->pandas .DataFrame ());
269264}
270265
271- bool PyConnection::isPolarsDataframe (const py::object & object) {
266+ bool PyConnection::isPolarsDataframe (const py::handle & object) {
272267 if (!doesPyModuleExist (" polars" )) {
273268 return false ;
274269 }
275270 return py::isinstance (object, importCache->polars .DataFrame ());
276271}
277272
278- bool PyConnection::isPyArrowTable (const py::object & object) {
273+ bool PyConnection::isPyArrowTable (const py::handle & object) {
279274 if (!doesPyModuleExist (" pyarrow" )) {
280275 return false ;
281276 }
@@ -389,6 +384,9 @@ static LogicalType pyLogicalType(const py::handle& val) {
389384 childType = std::move (result);
390385 }
391386 return LogicalType::LIST (std::move (childType));
387+ } else if (PyConnection::isPyArrowTable (val) || PyConnection::isPandasDataframe (val) ||
388+ PyConnection::isPolarsDataframe (val)) {
389+ return LogicalType::POINTER ();
392390 } else {
393391 // LCOV_EXCL_START
394392 throw common::RuntimeException (
@@ -678,6 +676,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
678676 }
679677 return Value (type.copy (), std::move (children));
680678 }
679+ case LogicalTypeID::POINTER: {
680+ return Value::createValue (reinterpret_cast <uint8_t *>(val.ptr ()));
681+ }
681682 default :
682683 return transformPythonValueAs (val, type);
683684 }
0 commit comments