-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathpy_database.cpp
More file actions
87 lines (77 loc) · 4.22 KB
/
py_database.cpp
File metadata and controls
87 lines (77 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include "include/py_database.h"
#include <memory>
#include <stdexcept>
#include "extension/extension.h"
#include "include/cached_import/py_cached_import.h"
#include "main/version.h"
#include "pandas/pandas_scan.h"
using namespace lbug::common;
// Registers the `Database` class and its methods on the Python module `m`.
//
// Constructor keyword arguments map one-to-one onto PyDatabase::PyDatabase;
// `checkpoint_threshold = -1` leaves the SystemConfig value untouched.
//
// The scan_node_table_as_<type> methods fill a caller-allocated NumPy array in
// place. `np_array` is therefore bound with noconvert(). Under pybind11's
// default (convert + forcecast), an array of the wrong dtype would be cast into
// a temporary copy, the scan would write into that copy, and the caller's array
// would silently stay unfilled. With noconvert() such calls raise TypeError.
void PyDatabase::initialize(py::handle& m) {
    py::class_<PyDatabase> cls(m, "Database");

    cls.def(py::init<const std::string&, uint64_t, uint64_t, bool, bool, uint64_t, bool, int64_t,
                bool, bool, bool>(),
        py::arg("database_path"), py::arg("buffer_pool_size") = 0,
        py::arg("max_num_threads") = 0, py::arg("compression") = true,
        py::arg("read_only") = false, py::arg("max_db_size") = uint64_t{1} << 43,
        py::arg("auto_checkpoint") = true, py::arg("checkpoint_threshold") = -1,
        py::arg("throw_on_wal_replay_failure") = true, py::arg("enable_checksums") = true,
        py::arg("enable_multi_writes") = false);

    // Every typed scan entry point shares the same argument list; keep it in one place.
    auto defScan = [&cls](const char* name, auto method) {
        cls.def(name, method, py::arg("table_name"), py::arg("prop_name"), py::arg("indices"),
            py::arg("np_array").noconvert(), py::arg("num_threads"));
    };
    defScan("scan_node_table_as_int64", &PyDatabase::scanNodeTable<std::int64_t>);
    defScan("scan_node_table_as_int32", &PyDatabase::scanNodeTable<std::int32_t>);
    defScan("scan_node_table_as_int16", &PyDatabase::scanNodeTable<std::int16_t>);
    defScan("scan_node_table_as_double", &PyDatabase::scanNodeTable<double>);
    defScan("scan_node_table_as_float", &PyDatabase::scanNodeTable<float>);
    defScan("scan_node_table_as_bool", &PyDatabase::scanNodeTable<bool>);

    cls.def("close", &PyDatabase::close);
    cls.def_static("get_version", &PyDatabase::getVersion);
    cls.def_static("get_storage_version", &PyDatabase::getStorageVersion);
}
// Returns Version::getVersion() converted to a Python `str`.
py::str PyDatabase::getVersion() {
    const auto version = Version::getVersion();
    return py::str(version);
}
// Returns the storage version reported by Version::getStorageVersion().
uint64_t PyDatabase::getStorageVersion() {
    const uint64_t storageVersion = Version::getStorageVersion();
    return storageVersion;
}
// Opens the database at `databasePath` and sets up the binding's helpers.
//
// Steps:
//   - builds a SystemConfig from the constructor arguments and opens the
//     Database with it;
//   - registers lbug::PandasScanFunction as a table function on the database;
//   - creates the StorageDriver used by scanNodeTable;
//   - while holding the GIL, creates the shared lbug::importCache if it does
//     not exist yet.
//
// A negative `checkpointThreshold` leaves the threshold exactly as the
// SystemConfig constructor set it.
PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
    uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize,
    bool autoCheckpoint, int64_t checkpointThreshold, bool throwOnWalReplayFailure,
    bool enableChecksums, bool enableMultiWrites) {
    SystemConfig config(bufferPoolSize, maxNumThreads, compression, readOnly, maxDBSize,
        autoCheckpoint);
    config.throwOnWalReplayFailure = throwOnWalReplayFailure;
    config.enableChecksums = enableChecksums;
    config.enableMultiWrites = enableMultiWrites;
    const bool hasExplicitThreshold = checkpointThreshold >= 0;
    if (hasExplicitThreshold) {
        config.checkpointThreshold = static_cast<uint64_t>(checkpointThreshold);
    }

    database = std::make_unique<Database>(databasePath, config);
    lbug::extension::ExtensionUtils::addTableFunc<lbug::PandasScanFunction>(*database);
    storageDriver = std::make_unique<StorageDriver>(database.get());

    {
        // The import cache is shared Python state; only touch it with the GIL held.
        py::gil_scoped_acquire gil;
        if (!lbug::importCache) {
            lbug::importCache = std::make_shared<lbug::PythonCachedImport>();
        }
    }
}
// Releases the storage driver before the database.
//
// storageDriver was constructed from database.get() and so points into
// `database`. Tearing it down explicitly guarantees it never outlives the
// database, regardless of the order in which the two members are declared in
// the class definition.
PyDatabase::~PyDatabase() {
    storageDriver.reset();
    database.reset();
}
void PyDatabase::close() {
database.reset();
}
// Scans property `propName` of node table `tableName` at the node offsets in
// `indices`, passing `result`'s buffer to StorageDriver::scan as the output.
//
// Both arrays are handed to the driver as raw flat pointers. Inputs that would
// otherwise be read or written out of bounds are rejected up front:
//   - the database must still be open (close() releases it)
//       -> std::runtime_error (RuntimeError in Python);
//   - `indices` and `result` must be C-contiguous
//       -> std::invalid_argument (ValueError);
//   - `result` must hold at least indices.size() elements
//       -> std::invalid_argument (ValueError);
//   - `result` must be writable; requesting a writable buffer makes pybind11
//     raise for read-only arrays.
template<class T>
void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName,
    const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads) {
    if (!database || !storageDriver) {
        throw std::runtime_error("Database is closed.");
    }
    // The driver walks the buffers linearly, so strided views would be misread/miswritten.
    if (!(indices.flags() & py::array::c_style) || !(result.flags() & py::array::c_style)) {
        throw std::invalid_argument("indices and np_array must be C-contiguous arrays.");
    }
    const auto numIndices = indices.size();
    if (result.size() < numIndices) {
        throw std::invalid_argument("np_array must have at least as many elements as indices.");
    }

    auto indicesInfo = indices.request(false /* writable */);
    auto resultInfo = result.request(true /* writable */);
    auto nodeOffsets = static_cast<offset_t*>(indicesInfo.ptr);
    auto resultBuffer = static_cast<uint8_t*>(resultInfo.ptr);
    storageDriver->scan(tableName, propName, nodeOffsets, numIndices, resultBuffer, numThreads);
}