11#include " pandas/pandas_scan.h"
22
3+ #include " pyarrow/pyarrow_scan.h"
34#include " function/table/bind_input.h"
45#include " cached_import/py_cached_import.h"
56#include " numpy/numpy_scan.h"
@@ -13,31 +14,7 @@ using namespace kuzu::catalog;
1314
1415namespace kuzu {
1516
16- static offset_t tableFunc (TableFuncInput&, TableFuncOutput&);
17- static std::unique_ptr<TableFuncBindData> bindFunc (main::ClientContext*,
18- TableFuncBindInput*);
19- static std::unique_ptr<TableFuncSharedState> initSharedState (
20- TableFunctionInitInput&);
21- static std::unique_ptr<TableFuncLocalState> initLocalState (
22- TableFunctionInitInput&, TableFuncSharedState*,
23- storage::MemoryManager*);
24- static bool sharedStateNext (const TableFuncBindData*,
25- PandasScanLocalState*, TableFuncSharedState*);
26- static void pandasBackendScanSwitch (PandasColumnBindData*, uint64_t ,
27- uint64_t , ValueVector*);
28-
29- static TableFunction getFunction () {
30- return TableFunction (READ_PANDAS_FUNC_NAME, tableFunc, bindFunc, initSharedState,
31- initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
32- }
33-
34- function_set PandasScanFunction::getFunctionSet () {
35- function_set functionSet;
36- functionSet.push_back (getFunction ().copy ());
37- return functionSet;
38- }
39-
40- std::unique_ptr<TableFuncBindData> bindFunc (
17+ std::unique_ptr<function::TableFuncBindData> bindFunc (
4118 main::ClientContext* /* context*/ , TableFuncBindInput* input) {
4219 py::gil_scoped_acquire acquire;
4320 py::handle df (reinterpret_cast <PyObject*>(input->inputs [0 ].getValue <uint8_t *>()));
@@ -71,16 +48,16 @@ bool sharedStateNext(const TableFuncBindData* /*bindData*/,
7148 return true ;
7249}
7350
74- std::unique_ptr<TableFuncLocalState> initLocalState (
75- TableFunctionInitInput& input, TableFuncSharedState* sharedState,
76- storage::MemoryManager*) {
51+ std::unique_ptr<function:: TableFuncLocalState> initLocalState (
52+ function:: TableFunctionInitInput& input, function:: TableFuncSharedState* sharedState,
53+ storage::MemoryManager* /* mm */ ) {
7754 auto localState = std::make_unique<PandasScanLocalState>(0 /* start */ , 0 /* end */ );
7855 sharedStateNext (input.bindData , localState.get (), sharedState);
7956 return localState;
8057}
8158
82- std::unique_ptr<TableFuncSharedState> initSharedState (
83- TableFunctionInitInput& input) {
59+ std::unique_ptr<function:: TableFuncSharedState> initSharedState (
60+ function:: TableFunctionInitInput& input) {
8461 // LCOV_EXCL_START
8562 if (PyGILState_Check ()) {
8663 throw RuntimeException (" PandasScan called but GIL was already held!" );
@@ -132,14 +109,44 @@ std::vector<std::unique_ptr<PandasColumnBindData>> PandasScanFunctionData::copyC
132109 return result;
133110}
134111
112+ static TableFunction getFunction () {
113+ return TableFunction (READ_PANDAS_FUNC_NAME, tableFunc, bindFunc, initSharedState,
114+ initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
115+ }
116+
117+ function_set PandasScanFunction::getFunctionSet () {
118+ function_set functionSet;
119+ functionSet.push_back (getFunction ().copy ());
120+ return functionSet;
121+ }
122+
123+ static bool isPyArrowBacked (const py::handle &df) {
124+ py::list dtypes = df.attr (" dtypes" );
125+ if (dtypes.empty ()) {
126+ return false ;
127+ }
128+
129+ auto arrow_dtype = importCache->pandas .ArrowDtype ();
130+ for (auto &dtype : dtypes) {
131+ if (py::isinstance (dtype, arrow_dtype)) {
132+ return true ;
133+ }
134+ }
135+ return false ;
136+ }
137+
135138static std::unique_ptr<ScanReplacementData> tryReplacePD (py::dict& dict, py::str& objectName) {
136139 if (!dict.contains (objectName)) {
137140 return nullptr ;
138141 }
139142 auto entry = dict[objectName];
140143 if (PyConnection::isPandasDataframe (entry)) {
141144 auto scanReplacementData = std::make_unique<ScanReplacementData>();
142- scanReplacementData->func = getFunction ();
145+ if (isPyArrowBacked (entry)) {
146+ scanReplacementData->func = PyArrowTableScanFunction::getFunction ();
147+ } else {
148+ scanReplacementData->func = getFunction ();
149+ }
143150 auto bindInput = TableFuncBindInput ();
144151 bindInput.inputs .push_back (Value::createValue (reinterpret_cast <uint8_t *>(entry.ptr ())));
145152 scanReplacementData->bindInput = std::move (bindInput);
0 commit comments