Skip to content

Commit 20bf989

Browse files
committed
feat: support create/drop arrow tables from python
1 parent 33e62fa commit 20bf989

3 files changed

Lines changed: 103 additions & 1 deletion

File tree

src_cpp/include/py_connection.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class PyConnection {
4848
const py::list& params, const std::string& retval, bool defaultNull, bool catchExceptions);
4949
void removeScalarFunction(const std::string& name);
5050

51+
std::unique_ptr<PyQueryResult> createArrowTable(const std::string& tableName,
52+
py::object arrowTable);
53+
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);
54+
5155
static Value transformPythonValue(const py::handle& val);
5256
static Value transformPythonValueAs(const py::handle& val, const LogicalType& type);
5357
static Value transformPythonValueFromParameter(const py::handle& val);

src_cpp/py_connection.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "pandas/pandas_scan.h"
1717
#include "processor/result/factorized_table.h"
1818
#include "pyarrow/pyarrow_scan.h"
19+
#include "storage/table/arrow_table_support.h"
1920
#include <format>
2021

2122
using namespace lbug::common;
@@ -42,7 +43,10 @@ void PyConnection::initialize(py::handle& m) {
4243
.def("create_function", &PyConnection::createScalarFunction, py::arg("name"),
4344
py::arg("udf"), py::arg("params_type"), py::arg("return_value"),
4445
py::arg("default_null"), py::arg("catch_exceptions"))
45-
.def("remove_function", &PyConnection::removeScalarFunction, py::arg("name"));
46+
.def("remove_function", &PyConnection::removeScalarFunction, py::arg("name"))
47+
.def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"),
48+
py::arg("arrow_table"))
49+
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
4650
PyDateTime_IMPORT;
4751
}
4852

@@ -768,3 +772,48 @@ void PyConnection::createScalarFunction(const std::string& name, const py::funct
768772
void PyConnection::removeScalarFunction(const std::string& name) {
769773
conn->removeUDFFunction(name);
770774
}
775+
776+
std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName,
777+
py::object arrowTable) {
778+
py::gil_scoped_acquire acquire;
779+
780+
// Convert pandas/polars to pyarrow if needed
781+
if (PyConnection::isPandasDataframe(arrowTable)) {
782+
arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
783+
} else if (PyConnection::isPolarsDataframe(arrowTable)) {
784+
arrowTable = arrowTable.attr("to_arrow")();
785+
}
786+
787+
// Ensure we have a pyarrow table
788+
if (!PyConnection::isPyArrowTable(arrowTable)) {
789+
throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
790+
}
791+
792+
// Export Arrow table to C Data Interface
793+
// First, get the schema
794+
ArrowSchemaWrapper schema;
795+
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
796+
797+
// Get the batches (arrays)
798+
std::vector<ArrowArrayWrapper> arrays;
799+
py::list batches = arrowTable.attr("to_batches")();
800+
for (auto& batch : batches) {
801+
arrays.emplace_back();
802+
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
803+
}
804+
805+
// Convert wrappers to raw Arrow structs for the API
806+
std::vector<ArrowArray> rawArrays;
807+
for (auto& arr : arrays) {
808+
rawArrays.push_back(static_cast<ArrowArray>(arr));
809+
}
810+
811+
auto result = ArrowTableSupport::createViewFromArrowTable(*conn, tableName, schema, rawArrays);
812+
813+
return checkAndWrapQueryResult(result.queryResult);
814+
}
815+
816+
std::unique_ptr<PyQueryResult> PyConnection::dropArrowTable(const std::string& tableName) {
817+
auto result = ArrowTableSupport::unregisterArrowTable(*conn, tableName);
818+
return checkAndWrapQueryResult(result);
819+
}

src_py/connection.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,52 @@ def remove_function(self, name: str) -> None:
321321
name of function to be removed.
322322
"""
323323
self._connection.remove_function(name)
324+
325+
def create_arrow_table(
326+
self,
327+
table_name: str,
328+
dataframe: Any,
329+
) -> QueryResult:
330+
"""
331+
Create an Arrow memory-backed table from a DataFrame.
332+
333+
Parameters
334+
----------
335+
table_name : str
336+
Name of the table to create.
337+
338+
dataframe : Any
339+
A pandas DataFrame, polars DataFrame, or PyArrow table.
340+
341+
Returns
342+
-------
343+
QueryResult
344+
Result of the table creation query.
345+
346+
"""
347+
self.init_connection()
348+
query_result_internal = self._connection.create_arrow_table(table_name, dataframe)
349+
if not query_result_internal.isSuccess():
350+
raise RuntimeError(query_result_internal.getErrorMessage())
351+
return QueryResult(self, query_result_internal)
352+
353+
def drop_arrow_table(self, table_name: str) -> QueryResult:
354+
"""
355+
Drop an Arrow memory-backed table.
356+
357+
Parameters
358+
----------
359+
table_name : str
360+
Name of the table to drop.
361+
362+
Returns
363+
-------
364+
QueryResult
365+
Result of the drop table query.
366+
367+
"""
368+
self.init_connection()
369+
query_result_internal = self._connection.drop_arrow_table(table_name)
370+
if not query_result_internal.isSuccess():
371+
raise RuntimeError(query_result_internal.getErrorMessage())
372+
return QueryResult(self, query_result_internal)

0 commit comments

Comments
 (0)