Skip to content

Commit 0a520d4

Browse files
authored
Add checkpoint configuration parameters to the API bindings (#4739)
1 parent 96303b4 commit 0a520d4

4 files changed

Lines changed: 49 additions & 6 deletions

File tree

src_cpp/include/py_database.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class PyDatabase {
1818
static uint64_t getStorageVersion();
1919

2020
explicit PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
21-
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize);
21+
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize,
22+
bool autoCheckpoint, int64_t checkpointThreshold);
2223

2324
~PyDatabase();
2425

src_cpp/py_database.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ using namespace kuzu::common;
1111

1212
void PyDatabase::initialize(py::handle& m) {
1313
py::class_<PyDatabase>(m, "Database")
14-
.def(py::init<const std::string&, uint64_t, uint64_t, bool, bool, uint64_t>(),
14+
.def(
15+
py::init<const std::string&, uint64_t, uint64_t, bool, bool, uint64_t, bool, int64_t>(),
1516
py::arg("database_path"), py::arg("buffer_pool_size") = 0,
1617
py::arg("max_num_threads") = 0, py::arg("compression") = true,
17-
py::arg("read_only") = false, py::arg("max_db_size") = (uint64_t)1 << 43)
18+
py::arg("read_only") = false, py::arg("max_db_size") = (uint64_t)1 << 43,
19+
py::arg("auto_checkpoint") = true, py::arg("checkpoint_threshold") = -1)
1820
.def("scan_node_table_as_int64", &PyDatabase::scanNodeTable<std::int64_t>,
1921
py::arg("table_name"), py::arg("prop_name"), py::arg("indices"), py::arg("np_array"),
2022
py::arg("num_threads"))
@@ -44,9 +46,13 @@ uint64_t PyDatabase::getStorageVersion() {
4446
}
4547

4648
PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
47-
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize) {
48-
auto systemConfig =
49-
SystemConfig(bufferPoolSize, maxNumThreads, compression, readOnly, maxDBSize);
49+
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize,
50+
bool autoCheckpoint, int64_t checkpointThreshold) {
51+
auto systemConfig = SystemConfig(bufferPoolSize, maxNumThreads, compression, readOnly,
52+
maxDBSize, autoCheckpoint);
53+
if (checkpointThreshold >= 0) {
54+
systemConfig.checkpointThreshold = static_cast<uint64_t>(checkpointThreshold);
55+
}
5056
database = std::make_unique<Database>(databasePath, systemConfig);
5157
database->addTableFunction(kuzu::PandasScanFunction::name,
5258
kuzu::PandasScanFunction::getFunctionSet());

src_py/database.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(
3535
lazy_init: bool = False,
3636
read_only: bool = False,
3737
max_db_size: int = (1 << 43),
38+
auto_checkpoint: bool = True,
39+
checkpoint_threshold: int = -1,
3840
):
3941
"""
4042
Parameters
@@ -73,6 +75,14 @@ def __init__(
7375
a better solution later. The value defaults to 1 << 43 (8TB) under 64-bit
7476
environment and 1GB under 32-bit one.
7577
78+
auto_checkpoint: bool
79+
If true, the database will automatically checkpoint when the size of
80+
the WAL file exceeds the checkpoint threshold.
81+
82+
checkpoint_threshold: int
83+
The threshold of the WAL file size in bytes. When the size of the
84+
WAL file exceeds this threshold, the database will checkpoint if auto_checkpoint is true. A negative value (the default, -1) leaves the database's built-in threshold unchanged.
85+
7686
"""
7787
if database_path is None:
7888
database_path = ":memory:"
@@ -85,6 +95,8 @@ def __init__(
8595
self.compression = compression
8696
self.read_only = read_only
8797
self.max_db_size = max_db_size
98+
self.auto_checkpoint = auto_checkpoint
99+
self.checkpoint_threshold = checkpoint_threshold
88100
self.is_closed = False
89101

90102
self._database: Any = None # (type: _kuzu.Database from pybind11)
@@ -147,6 +159,8 @@ def init_database(self) -> None:
147159
self.compression,
148160
self.read_only,
149161
self.max_db_size,
162+
self.auto_checkpoint,
163+
self.checkpoint_threshold,
150164
)
151165

152166
def get_torch_geometric_remote_backend(

test/test_database.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,25 @@ def test_in_mem_database_no_db_path() -> None:
117117
conn.execute("CREATE (:person {name: 'Bob', age: 40});")
118118
with conn.execute("MATCH (p:person) RETURN p.*") as result:
119119
assert result.get_num_tuples() == 2
120+
121+
122+
def test_database_auto_checkpoint_config(tmp_path: Path) -> None:
123+
with kuzu.Database(database_path=tmp_path, auto_checkpoint=False) as db:
124+
assert not db.is_closed
125+
assert db._database is not None
126+
127+
conn = kuzu.Connection(db)
128+
with conn.execute("CALL current_setting('auto_checkpoint') RETURN *") as result:
129+
assert result.get_num_tuples() == 1
130+
assert result.get_next()[0] == "False"
131+
132+
133+
def test_database_checkpoint_threshold_config(tmp_path: Path) -> None:
134+
with kuzu.Database(database_path=tmp_path, checkpoint_threshold=1234) as db:
135+
assert not db.is_closed
136+
assert db._database is not None
137+
138+
conn = kuzu.Connection(db)
139+
with conn.execute("CALL current_setting('checkpoint_threshold') RETURN *") as result:
140+
assert result.get_num_tuples() == 1
141+
assert result.get_next()[0] == "1234"

0 commit comments

Comments
 (0)