Skip to content

Commit c926041

Browse files
committed
pyarrow backend scanning for pandas
formatting checks remove disabled tests more clang format fixes... py lint check clang tidy more clang tidy and py lint checks more and more clang tidy explicit pyarrow scan ctor possibly fixed tests not running? CI fixes fix pytest non portable type resolution solution apple clang test fix? add some requested changes apply backend switching remove fixed list update httpfs version & remove python-debug clang tidy fix array arithmetic on void revert extension version change apply clang-tidy fix compiler errors ... more clang-tidy fixes added requested changes clang-tidy
1 parent 52c1d6b commit c926041

9 files changed

Lines changed: 487 additions & 34 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ pybind11_add_module(_kuzu
1717
src_cpp/py_query_result.cpp
1818
src_cpp/py_query_result_converter.cpp
1919
src_cpp/py_conversion.cpp
20+
src_cpp/pyarrow/pyarrow_bind.cpp
21+
src_cpp/pyarrow/pyarrow_scan.cpp
2022
src_cpp/pandas/pandas_bind.cpp
2123
src_cpp/pandas/pandas_scan.cpp
2224
src_cpp/pandas/pandas_analyzer.cpp

src_cpp/include/cached_import/py_cached_modules.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ class PandasCachedItem : public PythonCachedItem {
6666
};
6767

6868
public:
69-
PandasCachedItem() : PythonCachedItem("pandas"), core(this), DataFrame(this), NA("NA", this),
70-
NaT("NaT", this) {}
69+
PandasCachedItem() : PythonCachedItem("pandas"), ArrowDtype("ArrowDtype", this), core(this), DataFrame(this),
70+
NA("NA", this), NaT("NaT", this) {}
7171

72+
PythonCachedItem ArrowDtype;
7273
CoreCachedItem core;
7374
DataFrameCachedItem DataFrame;
7475
PythonCachedItem NA;
@@ -96,9 +97,10 @@ class PyarrowCachedItem : public PythonCachedItem {
9697
class TableCachedItem : public PythonCachedItem {
9798
public:
9899
explicit TableCachedItem(PythonCachedItem* parent): PythonCachedItem("Table", parent),
99-
from_batches("from_batches", this) {}
100+
from_batches("from_batches", this), from_pandas("from_pandas", this) {}
100101

101102
PythonCachedItem from_batches;
103+
PythonCachedItem from_pandas;
102104
};
103105

104106
class LibCachedItem : public PythonCachedItem {
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include "common/arrow/arrow_converter.h"
4+
#include "py_object_container.h"
5+
#include "pybind_include.h"
6+
7+
namespace kuzu {
8+
9+
namespace main {
10+
class ClientContext;
11+
}
12+
13+
struct Pyarrow {
14+
static std::shared_ptr<ArrowSchemaWrapper> bind(py::handle tableToBind,
15+
std::vector<common::LogicalType>& returnTypes, std::vector<std::string>& names);
16+
};
17+
18+
} // namespace kuzu
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#pragma once
2+
3+
#include "common/arrow/arrow.h"
4+
#include "function/scalar_function.h"
5+
#include "function/table/bind_data.h"
6+
#include "function/table/scan_functions.h"
7+
#include "function/table_functions.h"
8+
#include "pyarrow_bind.h"
9+
#include "pybind_include.h"
10+
11+
namespace kuzu {
12+
13+
struct PyArrowTableScanLocalState final : public function::TableFuncLocalState {
14+
ArrowArrayWrapper* arrowArray;
15+
16+
explicit PyArrowTableScanLocalState(ArrowArrayWrapper* arrowArray) : arrowArray{arrowArray} {}
17+
};
18+
19+
struct PyArrowTableScanSharedState final : public function::BaseScanSharedState {
20+
std::vector<std::shared_ptr<ArrowArrayWrapper>> chunks;
21+
uint64_t currentChunk;
22+
std::mutex lock;
23+
24+
PyArrowTableScanSharedState(
25+
uint64_t numRows, std::vector<std::shared_ptr<ArrowArrayWrapper>>&& chunks)
26+
: BaseScanSharedState{numRows}, chunks{std::move(chunks)}, currentChunk{0} {}
27+
28+
ArrowArrayWrapper* getNextChunk();
29+
};
30+
31+
struct PyArrowTableScanFunctionData final : public function::TableFuncBindData {
32+
std::shared_ptr<ArrowSchemaWrapper> schema;
33+
std::unique_ptr<py::object> table;
34+
uint64_t numRows;
35+
36+
PyArrowTableScanFunctionData(std::vector<common::LogicalType> columnTypes,
37+
std::shared_ptr<ArrowSchemaWrapper> schema, std::vector<std::string> columnNames,
38+
py::object table, uint64_t numRows)
39+
: TableFuncBindData{std::move(columnTypes), std::move(columnNames)},
40+
schema{std::move(schema)}, table{std::make_unique<py::object>(table)}, numRows{numRows} {}
41+
42+
~PyArrowTableScanFunctionData() override {
43+
py::gil_scoped_acquire acquire;
44+
table.reset();
45+
}
46+
47+
std::unique_ptr<function::TableFuncBindData> copy() const override {
48+
py::gil_scoped_acquire acquire;
49+
// the schema is considered immutable so copying it by copying the shared_ptr is fine.
50+
return std::make_unique<PyArrowTableScanFunctionData>(
51+
columnTypes, schema, columnNames, *table, numRows);
52+
}
53+
};
54+
55+
struct PyArrowTableScanFunction {
56+
static function::function_set getFunctionSet();
57+
58+
static function::TableFunction getFunction();
59+
};
60+
61+
} // namespace kuzu

src_cpp/pandas/pandas_scan.cpp

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "pandas/pandas_scan.h"
22

3+
#include "pyarrow/pyarrow_scan.h"
34
#include "function/table/bind_input.h"
45
#include "cached_import/py_cached_import.h"
56
#include "numpy/numpy_scan.h"
@@ -13,31 +14,7 @@ using namespace kuzu::catalog;
1314

1415
namespace kuzu {
1516

16-
static offset_t tableFunc(TableFuncInput&, TableFuncOutput&);
17-
static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext*,
18-
TableFuncBindInput*);
19-
static std::unique_ptr<TableFuncSharedState> initSharedState(
20-
TableFunctionInitInput&);
21-
static std::unique_ptr<TableFuncLocalState> initLocalState(
22-
TableFunctionInitInput&, TableFuncSharedState*,
23-
storage::MemoryManager*);
24-
static bool sharedStateNext(const TableFuncBindData*,
25-
PandasScanLocalState*, TableFuncSharedState*);
26-
static void pandasBackendScanSwitch(PandasColumnBindData*, uint64_t,
27-
uint64_t, ValueVector*);
28-
29-
static TableFunction getFunction() {
30-
return TableFunction(READ_PANDAS_FUNC_NAME, tableFunc, bindFunc, initSharedState,
31-
initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
32-
}
33-
34-
function_set PandasScanFunction::getFunctionSet() {
35-
function_set functionSet;
36-
functionSet.push_back(getFunction().copy());
37-
return functionSet;
38-
}
39-
40-
std::unique_ptr<TableFuncBindData> bindFunc(
17+
std::unique_ptr<function::TableFuncBindData> bindFunc(
4118
main::ClientContext* /*context*/, TableFuncBindInput* input) {
4219
py::gil_scoped_acquire acquire;
4320
py::handle df(reinterpret_cast<PyObject*>(input->inputs[0].getValue<uint8_t*>()));
@@ -71,16 +48,16 @@ bool sharedStateNext(const TableFuncBindData* /*bindData*/,
7148
return true;
7249
}
7350

74-
std::unique_ptr<TableFuncLocalState> initLocalState(
75-
TableFunctionInitInput& input, TableFuncSharedState* sharedState,
76-
storage::MemoryManager*) {
51+
std::unique_ptr<function::TableFuncLocalState> initLocalState(
52+
function::TableFunctionInitInput& input, function::TableFuncSharedState* sharedState,
53+
storage::MemoryManager* /*mm*/) {
7754
auto localState = std::make_unique<PandasScanLocalState>(0 /* start */, 0 /* end */);
7855
sharedStateNext(input.bindData, localState.get(), sharedState);
7956
return localState;
8057
}
8158

82-
std::unique_ptr<TableFuncSharedState> initSharedState(
83-
TableFunctionInitInput& input) {
59+
std::unique_ptr<function::TableFuncSharedState> initSharedState(
60+
function::TableFunctionInitInput& input) {
8461
// LCOV_EXCL_START
8562
if (PyGILState_Check()) {
8663
throw RuntimeException("PandasScan called but GIL was already held!");
@@ -132,14 +109,44 @@ std::vector<std::unique_ptr<PandasColumnBindData>> PandasScanFunctionData::copyC
132109
return result;
133110
}
134111

112+
// Builds the pandas scan table function; it takes a single POINTER argument.
static TableFunction getFunction() {
    std::vector<LogicalTypeID> parameterTypes{LogicalTypeID::POINTER};
    return TableFunction(READ_PANDAS_FUNC_NAME, tableFunc, bindFunc, initSharedState,
        initLocalState, std::move(parameterTypes));
}
116+
117+
// Returns a function set whose only entry is a copy of the pandas scan table function.
function_set PandasScanFunction::getFunctionSet() {
    function_set result;
    auto scanFunction = getFunction();
    result.push_back(scanFunction.copy());
    return result;
}
122+
123+
// Reports whether at least one column of the dataframe `df` has a dtype that is an
// instance of pandas.ArrowDtype. A dataframe with no dtypes is never Arrow-backed, and
// ArrowDtype is only looked up when there is something to check.
static bool isPyArrowBacked(const py::handle &df) {
    py::list columnDtypes = df.attr("dtypes");
    if (columnDtypes.empty()) {
        return false;
    }

    auto arrowDtypeClass = importCache->pandas.ArrowDtype();
    bool anyArrowColumn = false;
    for (auto &columnDtype : columnDtypes) {
        if (py::isinstance(columnDtype, arrowDtypeClass)) {
            // One Arrow-typed column is enough; stop scanning.
            anyArrowColumn = true;
            break;
        }
    }
    return anyArrowColumn;
}
137+
135138
static std::unique_ptr<ScanReplacementData> tryReplacePD(py::dict& dict, py::str& objectName) {
136139
if (!dict.contains(objectName)) {
137140
return nullptr;
138141
}
139142
auto entry = dict[objectName];
140143
if (PyConnection::isPandasDataframe(entry)) {
141144
auto scanReplacementData = std::make_unique<ScanReplacementData>();
142-
scanReplacementData->func = getFunction();
145+
if (isPyArrowBacked(entry)) {
146+
scanReplacementData->func = PyArrowTableScanFunction::getFunction();
147+
} else {
148+
scanReplacementData->func = getFunction();
149+
}
143150
auto bindInput = TableFuncBindInput();
144151
bindInput.inputs.push_back(Value::createValue(reinterpret_cast<uint8_t*>(entry.ptr())));
145152
scanReplacementData->bindInput = std::move(bindInput);

src_cpp/py_database.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "include/cached_import/py_cached_import.h"
44
#include "pandas/pandas_scan.h"
5+
#include "pyarrow/pyarrow_scan.h"
56

67
#include <memory>
78

src_cpp/pyarrow/pyarrow_bind.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "pyarrow/pyarrow_bind.h"
2+
3+
#include "cached_import/py_cached_import.h"
4+
#include "common/arrow/arrow.h"
5+
#include "common/arrow/arrow_converter.h"
6+
7+
namespace kuzu {
8+
9+
// Exports the schema of `tableToBind` into a freshly allocated ArrowSchemaWrapper and
// describes its columns to the caller.
//
// For every child of the exported schema, one entry is appended to `names` and one to
// `returnTypes` (converted with ArrowConverter::fromArrowSchema), in schema order.
// Returns the wrapper holding the exported schema.
std::shared_ptr<ArrowSchemaWrapper> Pyarrow::bind(py::handle tableToBind,
    std::vector<common::LogicalType>& returnTypes, std::vector<std::string>& names) {

    std::shared_ptr<ArrowSchemaWrapper> schema = std::make_shared<ArrowSchemaWrapper>();
    auto pyschema = tableToBind.attr("schema");
    auto exportSchemaToC = pyschema.attr("_export_to_c");
    // pyarrow's _export_to_c takes the destination struct's address as an integer.
    exportSchemaToC(reinterpret_cast<uint64_t>(schema.get()));

    for (int64_t i = 0; i < schema->n_children; i++) {
        ArrowSchema* child = schema->children[i];
        // The Arrow C data interface allows an omitted field name to be NULL; constructing
        // a std::string from a null pointer is undefined, so fall back to an empty name.
        names.emplace_back(child->name != nullptr ? child->name : "");
        returnTypes.push_back(common::ArrowConverter::fromArrowSchema(child));
    }

    return schema;
}
25+
26+
} // namespace kuzu

src_cpp/pyarrow/pyarrow_scan.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "pyarrow/pyarrow_scan.h"
2+
3+
#include "cached_import/py_cached_import.h"
4+
#include "common/arrow/arrow_converter.h"
5+
#include "function/table/bind_input.h"
6+
#include "py_connection.h"
7+
#include "pybind11/pytypes.h"
8+
9+
using namespace kuzu::function;
10+
using namespace kuzu::common;
11+
using namespace kuzu::catalog;
12+
13+
namespace kuzu {
14+
15+
// Bind step of the pyarrow scan.
//
// input->inputs[0] carries a raw PyObject* for the object to scan. A pandas DataFrame is
// first converted with pyarrow's Table.from_pandas; the resulting table's schema is then
// exported to obtain the column types and names. Returns bind data that keeps a reference
// to the table together with its row count.
static std::unique_ptr<function::TableFuncBindData> bindFunc(
    main::ClientContext* /*context*/, TableFuncBindInput* input) {

    py::gil_scoped_acquire acquire;
    // The pointer stored in the bind input is a borrowed reference (it is produced with
    // `.ptr()` on the caller's object), so it must be borrowed here. Stealing it would
    // drop a reference this function never owned and could free the object while the
    // caller still uses it.
    py::object table(py::reinterpret_borrow<py::object>(
        reinterpret_cast<PyObject*>(input->inputs[0].getValue<uint8_t*>())));
    if (py::isinstance(table, importCache->pandas.DataFrame())) {
        table = importCache->pyarrow.lib.Table.from_pandas()(table);
    }
    std::vector<LogicalType> returnTypes;
    std::vector<std::string> names;
    if (py::isinstance<py::dict>(table)) {
        KU_UNREACHABLE;
    }
    auto numRows = py::len(table);
    auto schema = Pyarrow::bind(table, returnTypes, names);
    return std::make_unique<PyArrowTableScanFunctionData>(
        std::move(returnTypes), std::move(schema), std::move(names), table, numRows);
}
34+
35+
// Hands out the next unclaimed batch under the shared-state mutex and advances the
// cursor. Returns nullptr once the cursor has reached the end of the chunk list.
ArrowArrayWrapper* PyArrowTableScanSharedState::getNextChunk() {
    std::lock_guard<std::mutex> guard{lock};
    if (currentChunk != chunks.size()) {
        return chunks[currentChunk++].get();
    }
    return nullptr;
}
42+
43+
// Creates the shared scan state: splits the bound table into record batches by calling
// its `to_batches` with DEFAULT_VECTOR_CAPACITY, exports each batch into its own
// ArrowArrayWrapper via `_export_to_c`, and stores them together with the row count.
static std::unique_ptr<function::TableFuncSharedState> initSharedState(
    function::TableFunctionInitInput& input) {

    py::gil_scoped_acquire acquire;
    auto* scanData = dynamic_cast<PyArrowTableScanFunctionData*>(input.bindData);
    py::list recordBatches = scanData->table->attr("to_batches")(DEFAULT_VECTOR_CAPACITY);

    std::vector<std::shared_ptr<ArrowArrayWrapper>> exportedBatches;
    for (auto& batch : recordBatches) {
        auto wrapper = std::make_shared<ArrowArrayWrapper>();
        // pyarrow's _export_to_c takes the destination struct's address as an integer.
        batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(wrapper.get()));
        exportedBatches.push_back(std::move(wrapper));
    }

    return std::make_unique<PyArrowTableScanSharedState>(
        scanData->numRows, std::move(exportedBatches));
}
60+
61+
// Creates a thread's local scan state and immediately claims its first batch from the
// shared state (nullptr if no batches remain).
static std::unique_ptr<function::TableFuncLocalState> initLocalState(
    function::TableFunctionInitInput& /*input*/, function::TableFuncSharedState* sharedState,
    storage::MemoryManager* /*mm*/) {

    auto* scanState = dynamic_cast<PyArrowTableScanSharedState*>(sharedState);
    ArrowArrayWrapper* firstChunk = scanState->getNextChunk();
    return std::make_unique<PyArrowTableScanLocalState>(firstChunk);
}
69+
70+
// Emits the thread's current batch into the output data chunk, one column at a time, then
// claims the next batch from the shared state. Returns the length of the emitted batch,
// or 0 when this thread has no batch left.
static common::offset_t tableFunc(
    function::TableFuncInput& input, function::TableFuncOutput& output) {

    auto* localState = dynamic_cast<PyArrowTableScanLocalState*>(input.localState);
    auto* sharedState = dynamic_cast<PyArrowTableScanSharedState*>(input.sharedState);
    auto* scanData = dynamic_cast<PyArrowTableScanFunctionData*>(input.bindData);

    ArrowArrayWrapper* batch = localState->arrowArray;
    if (batch == nullptr) {
        return 0;
    }

    for (auto column = 0u; column < scanData->columnTypes.size(); column++) {
        common::ArrowConverter::fromArrowArray(scanData->schema->children[column],
            batch->children[column], *output.dataChunk.getValueVector(column));
    }

    // Read the length before replacing the local batch pointer.
    auto emittedRows = batch->length;
    localState->arrowArray = sharedState->getNextChunk();
    return emittedRows;
}
87+
88+
// Returns a function set whose only entry is the pyarrow scan table function.
//
// The entry is produced from getFunction() so the function's name, callbacks and
// parameter types are defined in exactly one place, mirroring the pandas scan's
// getFunctionSet().
function::function_set PyArrowTableScanFunction::getFunctionSet() {
    function_set functionSet;
    functionSet.push_back(getFunction().copy());
    return functionSet;
}
96+
97+
// Builds the pyarrow scan table function; it takes a single POINTER argument.
TableFunction PyArrowTableScanFunction::getFunction() {
    std::vector<LogicalTypeID> parameterTypes{LogicalTypeID::POINTER};
    return TableFunction(READ_PYARROW_FUNC_NAME, tableFunc, bindFunc, initSharedState,
        initLocalState, std::move(parameterTypes));
}
101+
102+
} // namespace kuzu

0 commit comments

Comments
 (0)