Skip to content

Commit 541b6b7

Browse files
committed
[python] Add tests for Arrow IPC query results
Signed-off-by: Mattias Matthiesen <mattias.matthiesen@eviny.no>
1 parent 8d285cf commit 541b6b7

File tree

2 files changed

+181
-3
lines changed

2 files changed

+181
-3
lines changed

python/tests/platform/test_shared_pipeline.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gzip
12
import io
23
import json
34
import os
@@ -6,9 +7,9 @@
67
import time
78
import unittest
89
import zipfile
9-
import gzip
1010

1111
import pandas as pd
12+
import pytest
1213

1314
from feldera import Pipeline
1415
from feldera.enums import CompletionTokenStatus, PipelineFieldSelector, PipelineStatus
@@ -167,6 +168,36 @@ def test_adhoc_query_json(self):
167168
got = list(resp)
168169
self.assertCountEqual(got, expected)
169170

171+
def test_adhoc_query_arrow(self):
172+
pa = pytest.importorskip("pyarrow")
173+
174+
data = "1\n2\n"
175+
self.pipeline.start()
176+
TEST_CLIENT.push_to_pipeline(self.pipeline.name, "tbl", "csv", data)
177+
178+
expected_rows = list(
179+
TEST_CLIENT.query_as_json(
180+
self.pipeline.name,
181+
"SELECT * FROM tbl ORDER BY id",
182+
)
183+
)
184+
expected_ids = [row["id"] for row in expected_rows]
185+
186+
batches_client = list(
187+
TEST_CLIENT.query_as_arrow(
188+
self.pipeline.name,
189+
"SELECT * FROM tbl ORDER BY id",
190+
)
191+
)
192+
table_client = pa.Table.from_batches(batches_client)
193+
assert table_client.column("id").to_pylist() == expected_ids
194+
195+
batches_pipeline = list(
196+
self.pipeline.query_arrow("SELECT * FROM tbl ORDER BY id")
197+
)
198+
table_pipeline = pa.Table.from_batches(batches_pipeline)
199+
assert table_pipeline.column("id").to_pylist() == expected_ids
200+
170201
def test_local(self):
171202
"""
172203
CREATE TABLE students (
@@ -347,8 +378,10 @@ def test_failed_pipeline_stop(self):
347378
self.pipeline.input_json("tbl", data, wait=False)
348379
wait_for_condition(
349380
"pipeline stops with deployment error after worker panic",
350-
lambda: self.pipeline.status() == PipelineStatus.STOPPED
351-
and len(self.pipeline.deployment_error()) > 0,
381+
lambda: (
382+
self.pipeline.status() == PipelineStatus.STOPPED
383+
and len(self.pipeline.deployment_error()) > 0
384+
),
352385
timeout_s=20.0,
353386
poll_interval_s=1.0,
354387
)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Unit tests for FelderaClient.query_as_arrow and Pipeline.query_arrow."""
2+
3+
import builtins
4+
import io
5+
import sys
6+
from unittest.mock import MagicMock
7+
8+
import pytest
9+
10+
from feldera.rest.feldera_client import FelderaClient
11+
12+
13+
def _import_arrow_modules():
14+
pa = pytest.importorskip("pyarrow")
15+
ipc = pytest.importorskip("pyarrow.ipc")
16+
return pa, ipc
17+
18+
19+
def _make_ipc_bytes(table) -> bytes:
20+
"""Serialise a ``pyarrow.Table`` to Arrow IPC stream bytes."""
21+
_, ipc = _import_arrow_modules()
22+
buf = io.BytesIO()
23+
with ipc.new_stream(buf, table.schema) as writer:
24+
if table.num_rows > 0:
25+
writer.write_table(table)
26+
return buf.getvalue()
27+
28+
29+
def _mock_response(ipc_bytes: bytes) -> MagicMock:
30+
"""Return a mock response whose ``raw`` is an Arrow IPC byte stream."""
31+
resp = MagicMock()
32+
resp.raw = io.BytesIO(ipc_bytes)
33+
return resp
34+
35+
36+
@pytest.fixture()
37+
def client() -> FelderaClient:
38+
"""A ``FelderaClient`` with a mocked HTTP layer (no real network calls)."""
39+
c = FelderaClient.__new__(FelderaClient)
40+
c.http = MagicMock()
41+
return c
42+
43+
44+
class TestQueryAsArrow:
45+
def test_non_empty_result_yields_correct_data(self, client: FelderaClient):
46+
pa, _ = _import_arrow_modules()
47+
schema = pa.schema([("id", pa.int64()), ("name", pa.utf8())])
48+
expected = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}, schema=schema)
49+
client.http.get.return_value = _mock_response(_make_ipc_bytes(expected))
50+
51+
batches = list(client.query_as_arrow("my_pipeline", "SELECT id, name FROM t"))
52+
result = pa.Table.from_batches(batches, schema=schema)
53+
54+
assert len(batches) > 0
55+
assert result.schema == schema
56+
assert result.num_rows == 3
57+
assert result.column("id").to_pylist() == [1, 2, 3]
58+
assert result.column("name").to_pylist() == ["a", "b", "c"]
59+
60+
def test_http_called_with_correct_params(self, client: FelderaClient):
61+
pa, _ = _import_arrow_modules()
62+
schema = pa.schema([("id", pa.int64())])
63+
table = pa.table({"id": [42]}, schema=schema)
64+
client.http.get.return_value = _mock_response(_make_ipc_bytes(table))
65+
66+
list(client.query_as_arrow("my_pipeline", "SELECT id FROM t"))
67+
68+
client.http.get.assert_called_once_with(
69+
path="/pipelines/my_pipeline/query",
70+
params={
71+
"pipeline_name": "my_pipeline",
72+
"sql": "SELECT id FROM t",
73+
"format": "arrow_ipc",
74+
},
75+
stream=True,
76+
)
77+
78+
def test_empty_result_yields_no_batches(self, client: FelderaClient):
79+
pa, _ = _import_arrow_modules()
80+
schema = pa.schema([("id", pa.int64()), ("value", pa.float64())])
81+
empty = pa.table(
82+
{
83+
"id": pa.array([], type=pa.int64()),
84+
"value": pa.array([], type=pa.float64()),
85+
},
86+
schema=schema,
87+
)
88+
client.http.get.return_value = _mock_response(_make_ipc_bytes(empty))
89+
90+
result_batches = list(
91+
client.query_as_arrow("my_pipeline", "SELECT id, value FROM t WHERE false")
92+
)
93+
94+
assert result_batches == []
95+
96+
def test_missing_pyarrow_raises_helpful_import_error(
97+
self, client: FelderaClient, monkeypatch
98+
):
99+
real_import = builtins.__import__
100+
101+
def _import(name, globals=None, locals=None, fromlist=(), level=0):
102+
if name == "pyarrow" or name.startswith("pyarrow."):
103+
raise ImportError("No module named 'pyarrow'")
104+
return real_import(name, globals, locals, fromlist, level)
105+
106+
monkeypatch.delitem(sys.modules, "pyarrow", raising=False)
107+
monkeypatch.delitem(sys.modules, "pyarrow.ipc", raising=False)
108+
monkeypatch.setattr(builtins, "__import__", _import)
109+
110+
with pytest.raises(ImportError, match="pip install feldera\\[arrow\\]"):
111+
next(client.query_as_arrow("my_pipeline", "SELECT 1"))
112+
113+
client.http.get.assert_not_called()
114+
115+
def test_response_closed_after_full_consumption(self, client: FelderaClient):
116+
pa, _ = _import_arrow_modules()
117+
schema = pa.schema([("id", pa.int64())])
118+
table = pa.table({"id": [1, 2]}, schema=schema)
119+
resp = _mock_response(_make_ipc_bytes(table))
120+
client.http.get.return_value = resp
121+
122+
list(client.query_as_arrow("my_pipeline", "SELECT id FROM t"))
123+
124+
resp.close.assert_called_once()
125+
126+
127+
class TestPipelineQueryArrow:
128+
def test_query_arrow_delegates_to_client(self):
129+
"""Pipeline.query_arrow must forward to client.query_as_arrow."""
130+
from feldera.pipeline import Pipeline
131+
132+
pipeline = Pipeline.__new__(Pipeline)
133+
pipeline._inner = MagicMock()
134+
pipeline._inner.name = "pipe1"
135+
pipeline.client = MagicMock()
136+
137+
expected = object()
138+
pipeline.client.query_as_arrow.return_value = expected
139+
140+
result = pipeline.query_arrow("SELECT x FROM v")
141+
142+
pipeline.client.query_as_arrow.assert_called_once_with(
143+
"pipe1", "SELECT x FROM v"
144+
)
145+
assert result is expected

0 commit comments

Comments
 (0)