Skip to content

Commit d652e90

Browse files
committed
py: modify kafka integration test to support multi variant test methods
- Introduces multiple test methods per pipeline variant.
- Generates connector configurations from predefined variant settings.
- Inherits from SharedTestPipeline instead of unittest.TestCase to allow single SQL compilation.
- Separates SQL definitions into different functions instead of a single string.
- Adds helper functions for polling and loopback validation.

Signed-off-by: rivudhk <rivudhkr@gmail.com>
1 parent 0a33ed1 commit d652e90

File tree

1 file changed

+144
-71
lines changed

1 file changed

+144
-71
lines changed

python/tests/workloads/test_kafka_avro.py

Lines changed: 144 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import unittest
21
from tests import TEST_CLIENT
3-
from feldera import PipelineBuilder
42
import time
53
import os
64
from confluent_kafka.admin import AdminClient
75
import requests
86
import re
7+
import json
8+
9+
from tests.shared_test_pipeline import SharedTestPipeline, sql
910

1011

1112
def env(name: str, default: str) -> str:
@@ -68,14 +69,27 @@ def cleanup_kafka(sql: str, bootstrap_servers: str, registry_url: str):
6869
delete_schema_subjects(registry_url, subjects)
6970

7071

71-
# Set the limit for number of records to generate
72-
LIMIT = 1000000
72+
class Variant:
    """One pipeline variant: identical table/view SQL, distinct connector config.

    All runtime object names (topics, tables, views) are namespaced by the
    variant id so several variants can coexist in a single pipeline.
    """

    def __init__(self, cfg):
        # Required settings — a missing key is a config error, so use [].
        self.id = cfg["id"]
        self.limit = cfg["limit"]
        # Optional settings; None means "fall back to the connector default".
        self.partitions = cfg.get("partitions")
        self.sync = cfg.get("sync")
        self.start_from = cfg.get("start_from")

        # Kafka topics and SQL object names derived from the variant id.
        self.topic1 = f"my_topic_avro_{self.id}"
        self.topic2 = f"my_topic_avro2_{self.id}"
        self.source = f"t_{self.id}"
        self.view = f"v_{self.id}"
        self.loopback = f"loopback_{self.id}"
7488

75-
class TestKafkaAvro(unittest.TestCase):
76-
def test_check_avro(self):
77-
sql = f"""
78-
create table t (
89+
90+
def sql_source_table(v: Variant) -> str:
91+
return f"""
92+
create table {v.source} (
7993
id int,
8094
str varchar,
8195
dec decimal,
@@ -90,19 +104,23 @@ def test_check_avro(self):
90104
'connectors' = '[{{
91105
"transport": {{
92106
"name": "datagen",
93-
"config": {{ "plan": [{{"limit": {LIMIT}}}], "seed": 1 }}
107+
"config": {{ "plan": [{{"limit": {v.limit}}}], "seed": 1 }}
94108
}}
95109
}}]'
96110
);
111+
"""
97112

98-
create view v
113+
114+
def sql_view(v: Variant) -> str:
115+
return f"""
116+
create view {v.view}
99117
with (
100118
'connectors' = '[{{
101119
"transport": {{
102120
"name": "kafka_output",
103121
"config": {{
104122
"bootstrap.servers": "{KAFKA_BOOTSTRAP}",
105-
"topic": "my_topic_avro"
123+
"topic": "{v.topic1}"
106124
}}
107125
}},
108126
"format": {{
@@ -114,12 +132,12 @@ def test_check_avro(self):
114132
}}
115133
}},
116134
{{
117-
"index": "t_index",
135+
"index": "idx_{v.id}",
118136
"transport": {{
119137
"name": "kafka_output",
120138
"config": {{
121139
"bootstrap.servers": "{KAFKA_BOOTSTRAP}",
122-
"topic": "my_topic_avro2"
140+
"topic": "{v.topic2}"
123141
}}
124142
}},
125143
"format": {{
@@ -131,11 +149,31 @@ def test_check_avro(self):
131149
}}
132150
}}]'
133151
)
134-
as select * from t;
152+
as select * from {v.source};
153+
154+
create index idx_{v.id} on {v.view}(id);
155+
"""
135156

136-
create index t_index on v(id);
137157

138-
create table loopback (
158+
def sql_loopback_table(v: Variant) -> str:
159+
# Optional configurations that will use connector defaults if not specified
160+
config = {
161+
"bootstrap.servers": KAFKA_BOOTSTRAP,
162+
"topic": v.topic2,
163+
}
164+
165+
if v.start_from:
166+
config["start_from"] = v.start_from
167+
if v.partitions:
168+
config["partitions"] = v.partitions
169+
if v.sync:
170+
config["synchronize_partitions"] = v.sync
171+
172+
# Convert to SQL config string
173+
config_json = json.dumps(config)
174+
175+
return f"""
176+
create table {v.loopback} (
139177
id int,
140178
str varchar,
141179
dec decimal,
@@ -150,11 +188,7 @@ def test_check_avro(self):
150188
'connectors' = '[{{
151189
"transport": {{
152190
"name": "kafka_input",
153-
"config": {{
154-
"topic": "my_topic_avro2",
155-
"start_from": "earliest",
156-
"bootstrap.servers": "{KAFKA_BOOTSTRAP}"
157-
}}
191+
"config": {config_json}
158192
}},
159193
"format": {{
160194
"name": "avro",
@@ -166,61 +200,100 @@ def test_check_avro(self):
166200
}}]'
167201
);
168202
"""
169-
pipeline = PipelineBuilder(
170-
TEST_CLIENT,
171-
"test_kafka_avro",
172-
sql=sql,
173-
).create_or_replace()
174203

175-
try:
176-
pipeline.start()
177-
178-
# NOTE => total_completed_records counts all rows that are processed through each output as follows:
179-
# 1. Written by the view<v> -> Kafka
180-
# 2. Ingested into loopback table from Kafka
181-
# Thus, expected_records = generated_rows * number_of_outputs (in this case 2)
182-
expected_records = LIMIT * 2
183-
timeout_s = 1800
184-
poll_interval_s = 5
185-
186-
start_time = time.perf_counter()
187-
# Poll `total_completed_records` every `poll_interval_s` seconds until it reaches `expected_records`
188-
while True:
189-
stats = TEST_CLIENT.get_pipeline_stats(pipeline.name)
190-
completed = stats["global_metrics"]["total_completed_records"]
191-
192-
print(f"Processed {completed}/{expected_records} rows so far...")
193-
194-
if completed >= expected_records:
195-
break
196-
197-
# Prevent infinite polling
198-
if time.perf_counter() - start_time > timeout_s:
199-
raise AssertionError(
200-
f"Timeout: only {completed}/{expected_records} rows processed"
201-
)
202-
203-
time.sleep(poll_interval_s)
204-
205-
elapsed = time.perf_counter() - start_time
206-
print(
207-
f"All {completed}/{expected_records} rows processed in {elapsed:.3f}s"
204+
205+
def build_sql(configs) -> str:
    """Generate the pipeline SQL for all variants in *configs*.

    For each variant this emits the source table, the output view, and the
    loopback table, then joins every snippet into one program.
    """
    snippets = []
    for cfg in configs:
        variant = Variant(cfg)
        snippets.extend(
            (
                sql_source_table(variant),
                sql_view(variant),
                sql_loopback_table(variant),
            )
        )
    return "\n".join(snippets)
216+
217+
218+
def wait_for_rows(pipeline, expected_rows, timeout_s=1800, poll_interval_s=5):
    """Block until the pipeline reports `expected_rows` completed records.

    Records are not processed instantaneously, so poll the pipeline's
    `total_completed_records` metric every `poll_interval_s` seconds.
    Returns the final count, or raises AssertionError once `timeout_s`
    seconds elapse without reaching the target.
    """
    deadline = time.perf_counter() + timeout_s
    while True:
        metrics = TEST_CLIENT.get_pipeline_stats(pipeline.name)["global_metrics"]
        completed = metrics["total_completed_records"]
        print(f"Processed {completed}/{expected_rows} rows so far...")
        if completed >= expected_rows:
            return completed
        # Give up instead of polling forever.
        if time.perf_counter() > deadline:
            raise AssertionError(
                f"Timeout: only {completed}/{expected_rows} rows processed"
            )
        time.sleep(poll_interval_s)
209234

210-
# Validation: once finished, the loopback table should contain all generated values
211-
# Validate by comparing the hash of the source table 't' and loopback table
212235

213-
expected_hash = pipeline.query_hash("SELECT * FROM t ORDER BY id, str")
214-
result_hash = pipeline.query_hash("SELECT * FROM loopback ORDER BY id, str")
236+
def validate_loopback(self, variant: Variant):
    """Validate that the loopback table received every generated value.

    Once processing finishes, the loopback table should mirror the variant's
    source table; compare content hashes of both, using the same ordering on
    each side so equal contents produce equal hashes.
    """
    query_template = "SELECT * FROM {} ORDER BY id, str"

    src_tbl_hash = self.pipeline.query_hash(query_template.format(variant.source))
    loopback_tbl_hash = self.pipeline.query_hash(
        query_template.format(variant.loopback)
    )

    assert src_tbl_hash == loopback_tbl_hash, (
        f"Loopback table hash mismatch for variant {variant.id}!\n"
        f"Source table: {variant.source}\n"
        f"Loopback table: {variant.loopback}\n"
        f"Expected hash: {src_tbl_hash}\n"
        f"Got hash: {loopback_tbl_hash}"
    )

    print(f"Loopback table validated successfully for variant {variant.id}")
215256

216-
assert result_hash == expected_hash, (
217-
f"Validation failed: loopback table hash mismatch!\n"
218-
f"Expected: {expected_hash}\nGot: {result_hash}"
219-
)
220-
print("Loopback table validated successfully!")
221257

258+
class TestKafkaAvro(SharedTestPipeline):
    """Kafka/Avro round-trip tests.

    Each test method compiles its own SQL snippet (via the @sql decorator)
    and processes only its own variant.
    """

    # Connector configurations, one per pipeline variant.  Optional keys
    # ("partitions", "sync", "start_from") fall back to connector defaults.
    TEST_CONFIGS = [
        {"id": 0, "limit": 10},
        {"id": 1, "limit": 20},
        # TODO: re-enable this large partition/sync variant when ready.
        # {
        #     "id": 2,
        #     "limit": 1000000,
        #     "partitions": [0],
        #     "sync": True,
        #     "start_from": "earliest",
        # },
    ]

    def _run_variant(self, cfg):
        """Run the full round-trip for one variant config.

        Starts the pipeline, waits until all rows have flowed
        view -> Kafka -> loopback table, validates the loopback contents,
        then stops the pipeline and cleans up Kafka topics and schemas.
        """
        variant = Variant(cfg)

        self.pipeline.start()
        try:
            # Each generated row is counted twice: once when the view writes
            # it to Kafka, and once when the loopback table ingests it.
            expected_rows = variant.limit * 2
            wait_for_rows(self.pipeline, expected_rows)
            validate_loopback(self, variant)
        finally:
            self.pipeline.stop(force=True)
            cleanup_kafka(build_sql([cfg]), KAFKA_BOOTSTRAP, SCHEMA_REGISTRY)

    @sql(build_sql([TEST_CONFIGS[0]]))
    def test_kafka_avro_config_0(self):
        self._run_variant(self.TEST_CONFIGS[0])

    @sql(build_sql([TEST_CONFIGS[1]]))
    def test_kafka_avro_config_1(self):
        self._run_variant(self.TEST_CONFIGS[1])

0 commit comments

Comments
 (0)