Skip to content

Commit 51b6a24

Browse files
committed
Commit remaining Arrow table and Python API changes
1 parent 14cd99c commit 51b6a24

5 files changed

Lines changed: 1498 additions & 0 deletions

File tree

src_cpp/include/py_connection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class PyConnection {
5252

5353
std::unique_ptr<PyQueryResult> createArrowTable(const std::string& tableName,
5454
py::object arrowTable);
55+
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
56+
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName);
5557
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);
5658

5759
static Value transformPythonValue(const py::handle& val);

src_cpp/py_connection.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ void PyConnection::initialize(py::handle& m) {
4646
.def("remove_function", &PyConnection::removeScalarFunction, py::arg("name"))
4747
.def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"),
4848
py::arg("arrow_table"))
49+
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
50+
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"))
4951
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
5052
PyDateTime_IMPORT;
5153
}
@@ -817,6 +819,41 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string&
817819
return checkAndWrapQueryResult(result.queryResult);
818820
}
819821

822+
// Registers an in-memory, Arrow-backed relationship (REL) table named `tableName`
// that connects `srcTableName` -> `dstTableName`. Accepts a pyarrow Table, or a
// pandas/polars DataFrame (both are converted to a pyarrow Table first); any other
// object raises RuntimeException.
std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName,
    py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) {
    // We call back into Python (conversions, schema/batch export), so hold the GIL.
    py::gil_scoped_acquire acquire;

    // Normalize the input to a pyarrow Table.
    if (PyConnection::isPandasDataframe(arrowTable)) {
        arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
    } else if (PyConnection::isPolarsDataframe(arrowTable)) {
        arrowTable = arrowTable.attr("to_arrow")();
    }
    if (!PyConnection::isPyArrowTable(arrowTable)) {
        throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
    }

    // Export the schema and each record batch through the Arrow C data interface.
    // pyarrow's `_export_to_c` expects the destination address as an integer.
    // NOTE(review): `reinterpret_cast<uint64_t>` assumes 64-bit pointers; consider
    // `uintptr_t` for portability — verify against the project's supported targets.
    ArrowSchemaWrapper schema;
    arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
    std::vector<ArrowArrayWrapper> arrays;
    py::list batches = arrowTable.attr("to_batches")();
    for (auto& batch : batches) {
        // emplace first so we can pass the address of the in-place wrapper.
        arrays.emplace_back();
        batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
    }

    // Collect the Python objects backing the exported data so they can be pinned
    // for the lifetime of the registered table (presumably the exported arrays
    // reference buffers owned by these objects — mirrors createArrowTable).
    py::list keepAlive;
    keepAlive.append(arrowTable);
    keepAlive.append(batches);

    auto result = ArrowTableSupport::createRelTableFromArrowTable(*conn, tableName, srcTableName,
        dstTableName, std::move(schema), std::move(arrays));
    if (result.queryResult && result.queryResult->isSuccess()) {
        // Only pin the Python references if the table was actually created; the
        // entry is released when the table is dropped (see dropArrowTable).
        arrowTableRefs[tableName] = std::move(keepAlive);
    }

    return checkAndWrapQueryResult(result.queryResult);
}
856+
820857
std::unique_ptr<PyQueryResult> PyConnection::dropArrowTable(const std::string& tableName) {
821858
auto result = ArrowTableSupport::unregisterArrowTable(*conn, tableName);
822859
if (result && result->isSuccess()) {

src_py/connection.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,44 @@ def drop_arrow_table(self, table_name: str) -> QueryResult:
370370
if not query_result_internal.isSuccess():
371371
raise RuntimeError(query_result_internal.getErrorMessage())
372372
return QueryResult(self, query_result_internal)
373+
374+
def create_arrow_rel_table(
    self,
    table_name: str,
    dataframe: Any,
    src_table_name: str,
    dst_table_name: str,
) -> QueryResult:
    """
    Create an Arrow memory-backed relationship table from a DataFrame.

    Parameters
    ----------
    table_name : str
        Name of the relationship table to create.
    dataframe : Any
        A pandas DataFrame, polars DataFrame, or PyArrow table.
    src_table_name : str
        Source node table name in the FROM/TO pair.
    dst_table_name : str
        Destination node table name in the FROM/TO pair.

    Returns
    -------
    QueryResult
        Result of the table creation query.
    """
    self.init_connection()
    internal_result = self._connection.create_arrow_rel_table(
        table_name, dataframe, src_table_name, dst_table_name
    )
    # Surface backend failures as Python exceptions, as the other wrappers do.
    if not internal_result.isSuccess():
        raise RuntimeError(internal_result.getErrorMessage())
    return QueryResult(self, internal_result)
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import polars as pl
2+
import pytest
3+
from type_aliases import ConnDB
4+
5+
6+
def test_arrow_memory_backed_table_basic(conn_db_empty: ConnDB) -> None:
    """Create an Arrow memory-backed table from polars and read every row back."""
    conn, _ = conn_db_empty

    # Register a small polars DataFrame as an Arrow-backed node table.
    employees = pl.DataFrame({
        "id": [1, 2, 3, 4, 5],
        "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
        "age": [25, 30, 35, 40, 45],
        "salary": [50000.0, 60000.0, 75000.0, 90000.0, 100000.0],
    })
    conn.create_arrow_table("employees", employees)

    # Scan all rows in id order.
    result = conn.execute("MATCH (n:employees) RETURN n.id, n.name, n.age, n.salary ORDER BY n.id")
    fetched = []
    while result.has_next():
        fetched.append(result.get_next())

    # Every row should come back exactly as registered.
    assert fetched == [
        [1, "Alice", 25, 50000.0],
        [2, "Bob", 30, 60000.0],
        [3, "Charlie", 35, 75000.0],
        [4, "Diana", 40, 90000.0],
        [5, "Eve", 45, 100000.0],
    ]

    # Clean up
    conn.drop_arrow_table("employees")
36+
37+
38+
def test_arrow_memory_backed_table_filtering(conn_db_empty: ConnDB) -> None:
    """Test filtering rows from an Arrow memory-backed table using Cypher."""
    conn, _ = conn_db_empty

    # Create a polars DataFrame with more data
    df = pl.DataFrame({
        "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "name": ["Alice", "Bob", "Charlie", "Diana", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"],
        "age": [25, 30, 35, 40, 45, 28, 33, 38, 42, 50],
        "department": ["Engineering", "Sales", "Engineering", "HR", "Sales",
                       "Engineering", "HR", "Sales", "Engineering", "HR"],
        "salary": [50000.0, 60000.0, 75000.0, 55000.0, 70000.0,
                   52000.0, 58000.0, 65000.0, 80000.0, 60000.0],
    })

    # Register the Arrow table
    conn.create_arrow_table("staff", df)

    # Test 1: Filter by age > 35 (all ages are distinct, so the order is stable)
    result = conn.execute("MATCH (n:staff) WHERE n.age > 35 RETURN n.name, n.age ORDER BY n.age")
    rows = []
    while result.has_next():
        rows.append(result.get_next())

    assert len(rows) == 5
    assert rows[0] == ["Henry", 38]
    assert rows[1] == ["Diana", 40]
    assert rows[2] == ["Ivy", 42]
    assert rows[3] == ["Eve", 45]
    assert rows[4] == ["Jack", 50]

    # Test 2: Filter by department
    result = conn.execute(
        "MATCH (n:staff) WHERE n.department = 'Engineering' RETURN n.name, n.department ORDER BY n.id"
    )
    rows = []
    while result.has_next():
        rows.append(result.get_next())

    assert len(rows) == 4
    assert rows[0] == ["Alice", "Engineering"]
    assert rows[1] == ["Charlie", "Engineering"]
    assert rows[2] == ["Frank", "Engineering"]
    assert rows[3] == ["Ivy", "Engineering"]

    # Test 3: Filter by salary range. Bob and Jack tie at 60000.0, so order by
    # name as a secondary key to make the expected row order deterministic
    # (ordering on salary alone leaves the tie order unspecified — flaky test).
    result = conn.execute(
        "MATCH (n:staff) WHERE n.salary >= 60000.0 AND n.salary <= 75000.0 "
        "RETURN n.name, n.salary ORDER BY n.salary, n.name"
    )
    rows = []
    while result.has_next():
        rows.append(result.get_next())

    assert len(rows) == 5
    assert rows[0] == ["Bob", 60000.0]
    assert rows[1] == ["Jack", 60000.0]
    assert rows[2] == ["Henry", 65000.0]
    assert rows[3] == ["Eve", 70000.0]
    assert rows[4] == ["Charlie", 75000.0]

    # Test 4: Complex filter with AND/OR
    result = conn.execute(
        "MATCH (n:staff) WHERE (n.department = 'Engineering' AND n.salary > 60000.0) "
        "OR n.age > 45 RETURN n.name, n.department, n.salary, n.age ORDER BY n.id"
    )
    rows = []
    while result.has_next():
        rows.append(result.get_next())

    assert len(rows) == 3
    assert rows[0] == ["Charlie", "Engineering", 75000.0, 35]
    assert rows[1] == ["Ivy", "Engineering", 80000.0, 42]
    assert rows[2] == ["Jack", "HR", 60000.0, 50]

    # Clean up
    conn.drop_arrow_table("staff")
115+
116+
117+
def test_arrow_memory_backed_table_with_pandas(conn_db_empty: ConnDB) -> None:
    """Arrow memory-backed tables should accept pandas DataFrames too."""
    conn, _ = conn_db_empty

    # Skip cleanly when pandas is not installed in this environment.
    pd = pytest.importorskip("pandas")

    products = pd.DataFrame({
        "product_id": [101, 102, 103, 104, 105],
        "product_name": ["Widget A", "Widget B", "Gadget X", "Gadget Y", "Tool Z"],
        "price": [9.99, 14.99, 29.99, 34.99, 49.99],
        "in_stock": [True, True, False, True, False],
    })
    conn.create_arrow_table("products", products)

    # Only in-stock products under 20.0 should qualify, cheapest first.
    result = conn.execute(
        "MATCH (n:products) WHERE n.in_stock = true AND n.price < 20.0 "
        "RETURN n.product_name, n.price ORDER BY n.price"
    )
    fetched = []
    while result.has_next():
        fetched.append(result.get_next())

    assert fetched == [["Widget A", 9.99], ["Widget B", 14.99]]

    # Clean up
    conn.drop_arrow_table("products")
149+
150+
151+
def test_arrow_memory_backed_table_with_pyarrow(conn_db_empty: ConnDB) -> None:
    """Test Arrow memory-backed table with native PyArrow table."""
    conn, _ = conn_db_empty

    # Mirror the pandas test: skip (rather than error) when the optional
    # dependency is unavailable in the test environment.
    pa = pytest.importorskip("pyarrow")

    # Create a PyArrow table directly
    table = pa.table({
        "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
        "population": [8419000, 3980000, 2716000, 2328000, 1690000],
        "area_sq_miles": [302.6, 468.7, 227.3, 637.5, 517.6],
    })

    # Register the Arrow table
    conn.create_arrow_table("cities", table)

    # Query with filter
    result = conn.execute(
        "MATCH (n:cities) WHERE n.population > 2000000 AND n.area_sq_miles < 400 "
        "RETURN n.city, n.population, n.area_sq_miles ORDER BY n.population DESC"
    )
    rows = []
    while result.has_next():
        rows.append(result.get_next())

    assert len(rows) == 2
    assert rows[0] == ["New York", 8419000, 302.6]
    assert rows[1] == ["Chicago", 2716000, 227.3]

    # Clean up
    conn.drop_arrow_table("cities")
182+
183+
184+
def test_arrow_memory_backed_table_empty_result(conn_db_empty: ConnDB) -> None:
    """A filter that matches no rows should yield an empty result set."""
    conn, _ = conn_db_empty

    data = pl.DataFrame({
        "id": [1, 2, 3],
        "value": [10, 20, 30],
    })
    conn.create_arrow_table("data", data)

    # No value exceeds 100, so the scan should produce nothing.
    result = conn.execute("MATCH (n:data) WHERE n.value > 100 RETURN n.id")
    assert not result.has_next()

    # Clean up
    conn.drop_arrow_table("data")
201+
202+
203+
def test_arrow_memory_backed_table_count(conn_db_empty: ConnDB) -> None:
    """Aggregate (COUNT) over an Arrow memory-backed table."""
    conn, _ = conn_db_empty

    transactions = pl.DataFrame({
        "category": ["A", "B", "A", "C", "B", "A", "C", "B"],
        "amount": [100, 200, 150, 300, 250, 120, 280, 180],
    })
    conn.create_arrow_table("transactions", transactions)

    # Group by category and count rows per group.
    result = conn.execute(
        "MATCH (n:transactions) RETURN n.category, COUNT(*) as cnt ORDER BY n.category"
    )
    fetched = []
    while result.has_next():
        fetched.append(result.get_next())

    # Three A rows, three B rows, two C rows.
    assert fetched == [["A", 3], ["B", 3], ["C", 2]]

    # Clean up
    conn.drop_arrow_table("transactions")

0 commit comments

Comments
 (0)