Skip to content

Commit 1958f1e

Browse files
committed
Revert "Revert "Implement Python Import Caching""
This reverts commit 1653b4026cec55a49747a17c8cb3c57ec6b0ed3c.
1 parent 085d3a9 commit 1958f1e

14 files changed

Lines changed: 275 additions & 37 deletions

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ file(GLOB SOURCE_PY
99
pybind11_add_module(_kuzu
1010
SHARED
1111
src_cpp/kuzu_binding.cpp
12+
src_cpp/cached_import/py_cached_item.cpp
13+
src_cpp/cached_import/py_cached_import.cpp
1214
src_cpp/py_connection.cpp
1315
src_cpp/py_database.cpp
1416
src_cpp/py_prepared_statement.cpp
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "cached_import/py_cached_import.h"
2+
3+
namespace kuzu {
4+
5+
PythonCachedImport::~PythonCachedImport() {
6+
py::gil_scoped_acquire acquire;
7+
allObjects.clear();
8+
}
9+
10+
py::handle PythonCachedImport::addToCache(py::object obj) {
11+
auto ptr = obj.ptr();
12+
allObjects.push_back(obj);
13+
return ptr;
14+
}
15+
16+
std::shared_ptr<PythonCachedImport> importCache;
17+
18+
} // namespace kuzu
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "cached_import/py_cached_item.h"
2+
3+
4+
#include "cached_import/py_cached_import.h"
5+
#include "common/exception/runtime.h"
6+
7+
namespace kuzu {
8+
9+
py::handle PythonCachedItem::operator()() {
10+
assert((bool)PyGILState_Check());
11+
// load if unloaded, return cached object if already loaded
12+
if (loaded) {
13+
return object;
14+
}
15+
if (parent == nullptr) {
16+
object = importCache->addToCache(std::move(py::module::import(name.c_str())));
17+
} else {
18+
object = importCache->addToCache(std::move((*parent)().attr(name.c_str())));
19+
}
20+
loaded = true;
21+
return object;
22+
}
23+
24+
} // namespace kuzu
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include "py_cached_modules.h"
6+
7+
namespace kuzu {
8+
9+
class PythonCachedImport {
10+
public:
11+
// Note: Callers generally acquire the GIL prior to entering functions
12+
// that require the import cache.
13+
14+
PythonCachedImport() = default;
15+
~PythonCachedImport();
16+
17+
py::handle addToCache(py::object obj);
18+
19+
DateTimeCachedItem datetime;
20+
DecimalCachedItem decimal;
21+
InspectCachedItem inspect;
22+
NumpyMaCachedItem numpyma;
23+
PandasCachedItem pandas;
24+
PyarrowCachedItem pyarrow;
25+
UUIDCachedItem uuid;
26+
27+
private:
28+
std::vector<py::object> allObjects;
29+
};
30+
31+
extern std::shared_ptr<PythonCachedImport> importCache;
32+
33+
} // namespace kuzu
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
6+
#include "pybind_include.h"
7+
8+
namespace kuzu {
9+
10+
class PythonCachedItem {
11+
public:
12+
explicit PythonCachedItem(const std::string& name, PythonCachedItem* parent = nullptr)
13+
: name(name), parent(parent), loaded(false) {}
14+
virtual ~PythonCachedItem() = default;
15+
16+
bool isLoaded() const {return loaded;}
17+
py::handle operator()();
18+
19+
private:
20+
std::string name;
21+
PythonCachedItem* parent;
22+
bool loaded;
23+
py::handle object;
24+
};
25+
26+
} // namespace kuzu
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#pragma once
2+
3+
#include "py_cached_item.h"
4+
5+
namespace kuzu {
6+
7+
class DateTimeCachedItem : public PythonCachedItem {
8+
9+
public:
10+
DateTimeCachedItem() : PythonCachedItem("datetime"), date("date", this),
11+
datetime("datetime", this), timedelta("timedelta", this) {}
12+
13+
PythonCachedItem date;
14+
PythonCachedItem datetime;
15+
PythonCachedItem timedelta;
16+
};
17+
18+
class DecimalCachedItem : public PythonCachedItem {
19+
20+
public:
21+
DecimalCachedItem() : PythonCachedItem("decimal"), Decimal("Decimal", this) {}
22+
23+
PythonCachedItem Decimal;
24+
};
25+
26+
class InspectCachedItem : public PythonCachedItem {
27+
28+
public:
29+
InspectCachedItem() : PythonCachedItem("inspect"), currentframe("currentframe", this) {}
30+
31+
PythonCachedItem currentframe;
32+
};
33+
34+
class NumpyMaCachedItem : public PythonCachedItem {
35+
36+
public:
37+
NumpyMaCachedItem() : PythonCachedItem("numpy.ma"), masked_array("masked_array", this) {}
38+
39+
PythonCachedItem masked_array;
40+
};
41+
42+
class PandasCachedItem : public PythonCachedItem {
43+
44+
class SeriesCachedItem : public PythonCachedItem {
45+
public:
46+
explicit SeriesCachedItem(PythonCachedItem* parent): PythonCachedItem("series", parent),
47+
Series("Series", this) {}
48+
49+
PythonCachedItem Series;
50+
};
51+
52+
class CoreCachedItem : public PythonCachedItem {
53+
public:
54+
explicit CoreCachedItem(PythonCachedItem* parent): PythonCachedItem("core", parent),
55+
series(this) {}
56+
57+
SeriesCachedItem series;
58+
};
59+
60+
class DataFrameCachedItem : public PythonCachedItem {
61+
public:
62+
explicit DataFrameCachedItem(PythonCachedItem* parent): PythonCachedItem("DataFrame", parent),
63+
from_dict("from_dict", this) {}
64+
65+
PythonCachedItem from_dict;
66+
};
67+
68+
public:
69+
PandasCachedItem() : PythonCachedItem("pandas"), core(this), DataFrame(this), NA("NA", this),
70+
NaT("NaT", this) {}
71+
72+
CoreCachedItem core;
73+
DataFrameCachedItem DataFrame;
74+
PythonCachedItem NA;
75+
PythonCachedItem NaT;
76+
};
77+
78+
class PyarrowCachedItem : public PythonCachedItem {
79+
80+
class RecordBatchCachedItem : public PythonCachedItem {
81+
public:
82+
explicit RecordBatchCachedItem(PythonCachedItem* parent): PythonCachedItem("RecordBatch", parent),
83+
_import_from_c("_import_from_c", this) {}
84+
85+
PythonCachedItem _import_from_c;
86+
};
87+
88+
class SchemaCachedItem : public PythonCachedItem {
89+
public:
90+
explicit SchemaCachedItem(PythonCachedItem* parent): PythonCachedItem("Schema", parent),
91+
_import_from_c("_import_from_c", this) {}
92+
93+
PythonCachedItem _import_from_c;
94+
};
95+
96+
class TableCachedItem : public PythonCachedItem {
97+
public:
98+
explicit TableCachedItem(PythonCachedItem* parent): PythonCachedItem("Table", parent),
99+
from_batches("from_batches", this) {}
100+
101+
PythonCachedItem from_batches;
102+
};
103+
104+
class LibCachedItem : public PythonCachedItem {
105+
public:
106+
explicit LibCachedItem(PythonCachedItem* parent): PythonCachedItem("lib", parent),
107+
RecordBatch(this), Schema(this), Table(this) {}
108+
109+
RecordBatchCachedItem RecordBatch;
110+
SchemaCachedItem Schema;
111+
TableCachedItem Table;
112+
};
113+
114+
public:
115+
PyarrowCachedItem(): PythonCachedItem("pyarrow"), lib(this) {}
116+
117+
LibCachedItem lib;
118+
};
119+
120+
class UUIDCachedItem : public PythonCachedItem {
121+
122+
public:
123+
UUIDCachedItem() : PythonCachedItem("uuid"), UUID("UUID", this) {}
124+
125+
PythonCachedItem UUID;
126+
};
127+
128+
} // namespace kuzu

src_cpp/include/py_database.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class PyDatabase {
2424
explicit PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
2525
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize);
2626

27-
~PyDatabase() = default;
27+
~PyDatabase();
2828

2929
template<class T>
3030
void scanNodeTable(const std::string& tableName, const std::string& propName,

src_cpp/pandas/pandas_analyzer.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "pandas/pandas_analyzer.h"
22

33
#include "function/built_in_function_utils.h"
4+
#include "cached_import/py_cached_import.h"
45
#include "py_conversion.h"
56

67
namespace kuzu {
@@ -37,7 +38,7 @@ common::LogicalType PandasAnalyzer::getListType(py::object& ele, bool& canConver
3738
for (auto pyVal : ele) {
3839
auto object = py::reinterpret_borrow<py::object>(pyVal);
3940
auto itemType = getItemType(object, canConvert);
40-
if (i != 0) {
41+
if (i == 0) {
4142
listType = itemType;
4243
} else {
4344
if (!upgradeType(listType, itemType)) {
@@ -88,8 +89,8 @@ static py::object findFirstNonNull(const py::handle& row, uint64_t numRows) {
8889

8990
common::LogicalType PandasAnalyzer::innerAnalyze(py::object column, bool& canConvert) {
9091
auto numRows = py::len(column);
91-
auto pandasModule = py::module::import("pandas");
92-
auto pandasSeries = pandasModule.attr("core").attr("series").attr("Series");
92+
auto pandasModule = importCache->pandas;
93+
auto pandasSeries = pandasModule.core.series.Series();
9394

9495
if (py::isinstance(column, pandasSeries)) {
9596
column = column.attr("__array__")();

src_cpp/pandas/pandas_scan.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "pandas/pandas_scan.h"
22

33
#include "function/table/bind_input.h"
4+
#include "cached_import/py_cached_import.h"
45
#include "numpy/numpy_scan.h"
56
#include "py_connection.h"
67
#include "pybind11/pytypes.h"
@@ -127,10 +128,9 @@ std::unique_ptr<Value> tryReplacePD(py::dict& dict, py::str& tableName) {
127128
}
128129

129130
std::unique_ptr<common::Value> replacePD(common::Value* value) {
130-
py::gil_scoped_acquire acquire;
131131
auto pyTableName = py::str(value->getValue<std::string>());
132132
// Here we do an exhaustive search on the frame lineage.
133-
auto currentFrame = py::module::import("inspect").attr("currentframe")();
133+
auto currentFrame = importCache->inspect.currentframe()();
134134
while (hasattr(currentFrame, "f_locals")) {
135135
auto localDict = py::reinterpret_borrow<py::dict>(currentFrame.attr("f_locals"));
136136
if (localDict) {

src_cpp/py_connection.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
#include "common/string_format.h"
66
#include "datetime.h" // from Python
7+
#include "cached_import/py_cached_import.h"
78
#include "main/connection.h"
89
#include "pandas/pandas_scan.h"
910
#include "processor/result/factorized_table.h"
1011
#include "common/types/uuid.h"
1112

1213
using namespace kuzu::common;
14+
using namespace kuzu;
1315

1416
void PyConnection::initialize(py::handle& m) {
1517
py::class_<PyConnection>(m, "Connection")
@@ -151,9 +153,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
151153
}
152154

153155
bool PyConnection::isPandasDataframe(const py::object& object) {
154-
// TODO(Ziyi): introduce PythonCachedImport to avoid unnecessary import.
155-
py::module pandas = py::module::import("pandas");
156-
return py::isinstance(object, pandas.attr("DataFrame"));
156+
return py::isinstance(object, importCache->pandas.DataFrame());
157157
}
158158

159159
static Value transformPythonValue(py::handle val);
@@ -176,11 +176,10 @@ std::unordered_map<std::string, std::unique_ptr<Value>> transformPythonParameter
176176
}
177177

178178
Value transformPythonValue(py::handle val) {
179-
auto datetime_mod = py::module::import("datetime");
180-
auto datetime_datetime = datetime_mod.attr("datetime");
181-
auto time_delta = datetime_mod.attr("timedelta");
182-
auto datetime_date = datetime_mod.attr("date");
183-
auto uuid = py::module::import("uuid").attr("UUID");
179+
auto datetime_datetime = importCache->datetime.datetime();
180+
auto time_delta = importCache->datetime.timedelta();
181+
auto datetime_date = importCache->datetime.date();
182+
auto uuid = importCache->uuid.UUID();
184183
if (py::isinstance<py::bool_>(val)) {
185184
return Value::createValue<bool>(val.cast<bool>());
186185
} else if (py::isinstance<py::int_>(val)) {

0 commit comments

Comments
 (0)