Skip to content

Commit 3357be1

Browse files
committed
py: support listening to case sensitive views
Signed-off-by: Abhinav Gyawali <22275402+abhizer@users.noreply.github.com>
1 parent 30fe46c commit 3357be1

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

python/feldera/_callback_runner.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,12 @@ def run(self):
3939
"""
4040

4141
pipeline = self.client.get_pipeline(self.pipeline_name)
42-
schema = pipeline.program_info["schema"]
4342

44-
if schema:
45-
schemas = [relation for relation in schema["inputs"] + schema["outputs"]]
46-
for schema in schemas:
47-
if schema["name"] == self.view_name:
48-
self.schema = schema
49-
break
43+
schemas = pipeline.tables + pipeline.views
44+
for schema in schemas:
45+
if schema.name == self.view_name:
46+
self.schema = schema
47+
break
5048

5149
if self.schema is None:
5250
raise ValueError(
@@ -66,7 +64,10 @@ def run(self):
6664
case _CallbackRunnerInstruction.PipelineStarted:
6765
# listen to the pipeline
6866
gen_obj = self.client.listen_to_pipeline(
69-
self.pipeline_name, self.view_name, format="json"
67+
self.pipeline_name,
68+
self.view_name,
69+
format="json",
70+
case_sensitive=self.schema.case_sensitive,
7071
)
7172

7273
# if there is a queue set up, inform the main thread that the listener has been started, and it can
@@ -83,7 +84,7 @@ def run(self):
8384
seq_no: Optional[int] = chunk.get("sequence_number")
8485
if data is not None and seq_no is not None:
8586
self.callback(
86-
dataframe_from_response([data], self.schema), seq_no
87+
dataframe_from_response([data], self.schema.fields), seq_no
8788
)
8889

8990
if self.queue:

python/feldera/_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def ensure_dataframe_has_columns(df: pd.DataFrame):
6060
)
6161

6262

63-
def dataframe_from_response(buffer: list[list[dict]], schema: dict):
63+
def dataframe_from_response(buffer: list[list[dict]], fields: list[dict]):
6464
"""
6565
Converts the response from Feldera to a pandas DataFrame.
6666
"""
@@ -70,7 +70,7 @@ def dataframe_from_response(buffer: list[list[dict]], schema: dict):
7070
decimal_col = []
7171
uuid_col = []
7272

73-
for column in schema["fields"]:
73+
for column in fields:
7474
column_name = column["name"]
7575
if not column["case_sensitive"]:
7676
column_name = column_name.lower()

python/feldera/rest/feldera_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,7 @@ def listen_to_pipeline(
830830
backpressure: bool = True,
831831
array: bool = False,
832832
timeout: Optional[float] = None,
833+
case_sensitive: bool = False,
833834
):
834835
"""
835836
Listen for updates to views for pipeline, yields the chunks of data
@@ -845,6 +846,7 @@ def listen_to_pipeline(
845846
"json" format, the default value is False
846847
847848
:param timeout: The amount of time in seconds to listen to the stream for
849+
:param case_sensitive: True if the table name is case sensitive, False by default
848850
"""
849851

850852
params = {
@@ -855,6 +857,8 @@ def listen_to_pipeline(
855857
if format == "json":
856858
params["array"] = _prepare_boolean_input(array)
857859

860+
table_name = f'"{table_name}"' if case_sensitive else table_name
861+
858862
resp = self.http.post(
859863
path=f"/pipelines/{pipeline_name}/egress/{table_name}",
860864
params=params,

python/tests/platform/test_shared_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io
77
import tempfile
88
import zipfile
9+
import sys
910

1011
from tests.shared_test_pipeline import SharedTestPipeline
1112
from tests import TEST_CLIENT, enterprise_only
@@ -19,6 +20,7 @@ def test_create_pipeline(self):
1920
"""
2021
CREATE TABLE tbl(id INT) WITH ('materialized' = 'true');
2122
CREATE MATERIALIZED VIEW v0 AS SELECT * FROM tbl;
23+
CREATE MATERIALIZED VIEW "V0" AS SELECT * FROM tbl WHERE id % 2 <> 0;
2224
"""
2325
pass
2426

@@ -65,6 +67,26 @@ def test_get_pipeline_stats(self):
6567
assert stats.get("inputs") is not None
6668
assert stats.get("outputs") is not None
6769

70+
def test_case_sensitive_views_listen(self):
71+
all_stream = self.pipeline.listen("v0")
72+
odd_stream = self.pipeline.listen("V0")
73+
74+
self.pipeline.start()
75+
self.pipeline.input_json("tbl", [{"id": i} for i in range(10)])
76+
self.pipeline.wait_for_completion()
77+
78+
all = all_stream.to_dict()
79+
odd = odd_stream.to_dict()
80+
81+
expected_all = list(self.pipeline.query("select * from v0"))
82+
expected_odd = list(self.pipeline.query('select * from "V0"'))
83+
84+
def extract_ids(x):
85+
return sorted(i["id"] for i in x)
86+
87+
assert extract_ids(all) == extract_ids(expected_all)
88+
assert extract_ids(odd) == extract_ids(expected_odd)
89+
6890
def test_adhoc_query_text(self):
6991
data = "1\n2\n"
7092
self.pipeline.start()

0 commit comments

Comments
 (0)