Skip to content

Commit 0c50685

Browse files
committed
[SQL] Reduce the user effort necessary for implementing linear aggregates
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
1 parent 34df16b commit 0c50685

File tree

8 files changed

+164
-112
lines changed

8 files changed

+164
-112
lines changed

docs.feldera.com/docs/changelog.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ import TabItem from '@theme/TabItem';
1515

1616
## Unreleased
1717

18+
Simplified the way user-defined aggregates are defined -- the
19+
compiler now automates the handling of NULL values.
20+
21+
The following change doesn't affect the external Feldera API, only the
22+
pipeline's API available from a sidecar container. The `/status`
23+
endpoint no longer returns HTTP status 503 (SERVICE_UNAVAILABLE) while
24+
the pipeline is initializing. Instead it returns status OK with message
25+
body containing the "Initializing" string.
26+
27+
## 0.138.0
28+
1829
[Transaction (also known as huge-step) support](/pipelines/transactions).
1930

2031
TIMESTAMP is now the same as TIMESTAMP(3); TIME is now the same as
@@ -23,11 +34,7 @@ import TabItem from '@theme/TabItem';
2334
that differ from the default ones are ignored (and the compiler
2435
gives a warning).
2536

26-
The following change doesn't affect the external Feldera API, only the
27-
pipeline's API available from a sidecar container. The `/status`
28-
endpoint no longer returns HTTP status 503 (SERVICE_UNAVAILABLE) while
29-
the pipeline is initializing. Instead it returns status OK with message
30-
body containing the "Initializing" string.
37+
## 0.136.0
3138

3239
### Changes to Python SDK `feldera`:
3340
- `Pipeline.sync_checkpoint` will now raise a runtime error if `wait`

docs.feldera.com/docs/sql/udf.md

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -499,14 +499,8 @@ traits. Most of the code is devoted to this task, and is relatively
499499
straightforward.
500500

501501
For our example the accumulator type that the user has to define is
502-
named `i128_sum_accumulator_type`. In our implementation the
503-
accumulator is a tuple with 3 fields:
504-
505-
- the partial sum computed, stored in an I256 value
506-
507-
- the count of non-null elements in the collection encountered
508-
509-
- the total count of elements in the collection
502+
named `i128_sum_accumulator_type`, holding the computed partial sum,
503+
stored in an I256 value.
510504

511505
The user would add the following implementation to the `udf.rs` file:
512506

@@ -554,7 +548,6 @@ impl MulByRef<Weight> for I256Wrapper {
554548
type Output = Self;
555549

556550
fn mul_by_ref(&self, other: &Weight) -> Self::Output {
557-
println!("Mul {:?} by {}", self, other);
558551
Self {
559552
data: self.data.checked_mul_i64(*other)
560553
.expect("Overflow during multiplication"),
@@ -616,29 +609,18 @@ impl<D: Fallible + ?Sized> rkyv::Deserialize<I256Wrapper, D> for ArchivedI256Wra
616609
}
617610
}
618611

619-
pub type i128_sum_accumulator_type = Tup3<I256Wrapper, i64, i64>;
612+
pub type i128_sum_accumulator_type = I256Wrapper;
620613

621-
pub fn i128_sum_map(val: Option<ByteArray>) -> i128_sum_accumulator_type {
622-
match val {
623-
None => Tup3::new(I256Wrapper::zero(), 0, 1),
624-
Some(val) => Tup3::new(
625-
I256Wrapper::from(val.as_slice()),
626-
1,
627-
1,
628-
),
629-
}
614+
pub fn i128_sum_map(val: ByteArray) -> i128_sum_accumulator_type {
615+
I256Wrapper::from(val.as_slice())
630616
}
631617

632-
pub fn i128_sum_post(val: i128_sum_accumulator_type) -> Option<ByteArray> {
633-
if val.1 == 0 {
634-
None
635-
} else {
636-
// Check for overflow
637-
if val.0.data < I256::from(i128::MIN) || val.0.data > I256::from(i128::MAX) {
638-
panic!("Result of aggregation {} does not fit in 128 bits", val.0.data);
639-
}
640-
Some(ByteArray::new(&val.0.data.to_be_bytes()[16..]))
618+
pub fn i128_sum_post(val: i128_sum_accumulator_type) -> ByteArray {
619+
// Check for overflow
620+
if val.data < I256::from(i128::MIN) || val.data > I256::from(i128::MAX) {
621+
panic!("Result of aggregation {} does not fit in 128 bits", val.data);
641622
}
623+
ByteArray::new(&val.data.to_be_bytes()[16..])
642624
}
643625
```
644626

@@ -647,15 +629,15 @@ The two functions needed to implement the aggregation are
647629

648630
`i128_sum_map` converts a `BINARY(16)` value into an accumulator
649631
value. Notice that in the SQL runtime library `BINARY(16)` is
650-
implemented as a `ByteArray`.
632+
implemented as a `ByteArray`. The argument of this function must be
633+
non-nullable.
651634

652635
`i128_sum_post` converts the accumulator value into the expected
653-
result type `BINARY(16)`.
636+
result type `BINARY(16)`. The result must be non-nullable.
654637

655-
We use the `Tup3` type from our SQL runtime library. This type
656-
implements `Add` and other required operations if all fields do.
657-
The addition of `Tup3` values is done field-wise, and the `Zero` trait
658-
for `Tup3` is a tuple with all fields zero.
638+
The handling of `NULL` is dictated by the SQL semantics, and cannot be
639+
changed: aggregating a collection containing only `NULL` values (or
640+
an empty collection) produces `NULL`.
659641

660642
### Creating user-defined non-linear aggregate functions
661643

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ class TestUDA(unittest.TestCase):
88
def test_local(self):
99
sql = """
1010
CREATE LINEAR AGGREGATE I128_SUM(s BINARY(16)) RETURNS BINARY(16);
11-
CREATE TABLE T(x BINARY(16));
12-
CREATE MATERIALIZED VIEW V AS SELECT I128_SUM(x) AS S, COUNT(*) AS C FROM T;
11+
CREATE TABLE T(x BINARY(16), y BINARY(16) NOT NULL);
12+
CREATE MATERIALIZED VIEW V AS SELECT I128_SUM(x) AS S, I128_SUM(y) AS N, COUNT(*) AS C FROM T;
1313
"""
1414

1515
toml = """
@@ -123,29 +123,18 @@ def test_local(self):
123123
}
124124
}
125125
126-
pub type i128_sum_accumulator_type = Tup3<I256Wrapper, i64, i64>;
126+
pub type i128_sum_accumulator_type = I256Wrapper;
127127
128-
pub fn i128_sum_map(val: Option<ByteArray>) -> i128_sum_accumulator_type {
129-
match val {
130-
None => Tup3::new(I256Wrapper::zero(), 0, 1),
131-
Some(val) => Tup3::new(
132-
I256Wrapper::from(val.as_slice()),
133-
1,
134-
1,
135-
),
136-
}
128+
pub fn i128_sum_map(val: ByteArray) -> i128_sum_accumulator_type {
129+
I256Wrapper::from(val.as_slice())
137130
}
138131
139-
pub fn i128_sum_post(val: i128_sum_accumulator_type) -> Option<ByteArray> {
140-
if val.1 == 0 {
141-
None
142-
} else {
143-
// Check for overflow
144-
if val.0.data < I256::from(i128::MIN) || val.0.data > I256::from(i128::MAX) {
145-
panic!("Result of aggregation {} does not fit in 128 bits", val.0.data);
146-
}
147-
Some(ByteArray::new(&val.0.data.to_be_bytes()[16..]))
132+
pub fn i128_sum_post(val: i128_sum_accumulator_type) -> ByteArray {
133+
// Check for overflow
134+
if val.data < I256::from(i128::MIN) || val.data > I256::from(i128::MAX) {
135+
panic!("Result of aggregation {} does not fit in 128 bits", val.data);
148136
}
137+
ByteArray::new(&val.data.to_be_bytes()[16..])
149138
}
150139
"""
151140

@@ -160,14 +149,21 @@ def test_local(self):
160149
{
161150
"insert": {
162151
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
152+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
163153
}
164154
}
165155
],
166156
update_format="insert_delete",
167157
)
168158
pipeline.wait_for_idle()
169159
output = list(pipeline.query("SELECT * FROM V;"))
170-
assert output == [{"s": "00000000000000000000000000000001", "c": 1}]
160+
assert output == [
161+
{
162+
"s": "00000000000000000000000000000001",
163+
"n": "00000000000000000000000000000001",
164+
"c": 1,
165+
}
166+
]
171167

172168
# Insert -1
173169
pipeline.input_json(
@@ -193,59 +189,100 @@ def test_local(self):
193189
255,
194190
255,
195191
],
192+
"y": [
193+
255,
194+
255,
195+
255,
196+
255,
197+
255,
198+
255,
199+
255,
200+
255,
201+
255,
202+
255,
203+
255,
204+
255,
205+
255,
206+
255,
207+
255,
208+
255,
209+
],
196210
}
197211
}
198212
],
199213
update_format="insert_delete",
200214
)
201215
pipeline.wait_for_idle()
202216
output = list(pipeline.query("SELECT * FROM V;"))
203-
assert output == [{"s": "00000000000000000000000000000000", "c": 2}]
217+
assert output == [
218+
{
219+
"s": "00000000000000000000000000000000",
220+
"n": "00000000000000000000000000000000",
221+
"c": 2,
222+
}
223+
]
204224

205225
pipeline.input_json(
206226
"t",
207227
[
208228
{
209229
"insert": {
210230
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
231+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
211232
}
212233
}
213234
],
214235
update_format="insert_delete",
215236
)
216237
output = list(pipeline.query("SELECT * FROM V;"))
217-
assert output == [{"s": "00000000000000000000000000000002", "c": 3}]
238+
assert output == [
239+
{
240+
"s": "00000000000000000000000000000002",
241+
"n": "00000000000000000000000000000002",
242+
"c": 3,
243+
}
244+
]
218245

219246
pipeline.input_json(
220247
"t",
221248
[
222249
{
223250
"insert": {
224251
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
252+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
225253
}
226254
}
227255
],
228256
update_format="insert_delete",
229257
)
230258
output = list(pipeline.query("SELECT * FROM V;"))
231-
assert output == [{"s": "00000000000000000000000000000005", "c": 4}]
259+
assert output == [
260+
{
261+
"s": "00000000000000000000000000000005",
262+
"n": "00000000000000000000000000000005",
263+
"c": 4,
264+
}
265+
]
232266

233267
pipeline.input_json(
234268
"t",
235269
[
236270
{
237271
"delete": {
238272
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
273+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
239274
}
240275
},
241276
{
242277
"delete": {
243278
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
279+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
244280
}
245281
},
246282
{
247283
"delete": {
248284
"x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
285+
"y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
249286
}
250287
},
251288
{
@@ -268,13 +305,31 @@ def test_local(self):
268305
255,
269306
1,
270307
],
308+
"y": [
309+
255,
310+
255,
311+
255,
312+
255,
313+
255,
314+
255,
315+
255,
316+
255,
317+
255,
318+
255,
319+
255,
320+
255,
321+
255,
322+
255,
323+
255,
324+
1,
325+
],
271326
}
272327
},
273328
],
274329
update_format="insert_delete",
275330
)
276331
output = list(pipeline.query("SELECT * FROM V;"))
277-
assert output == [{"s": None, "c": 0}]
332+
assert output == [{"s": None, "n": None, "c": 0}]
278333

279334
pipeline.stop(force=True)
280335

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/frontend/CalciteToDBSPCompiler.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,17 +3108,18 @@ public DBSPNode compileCreateAggregate(CreateAggregateStatement aggregate) {
31083108
String name = uda.description.name.getSimple();
31093109
if (uda.isLinear()) {
31103110
// Add two functions that the user needs to define to the circuit declarations.
3111-
DBSPTypeUser accumulatorType = LinearAggregate.accumulatorType(node, name);
3111+
DBSPType accumulatorType = LinearAggregate.userAccumulatorType(node, name);
31123112
List<DBSPParameter> parameters = Linq.map(uda.description.parameterList,
3113-
p -> new DBSPParameter(p.getName(), this.convertType(node.getPositionRange(), p.getType(), false)));
3113+
p -> new DBSPParameter(p.getName(),
3114+
this.convertType(node.getPositionRange(), p.getType(), false).withMayBeNull(false)));
31143115
DBSPFunction mapFunction = new DBSPFunction(
31153116
node, LinearAggregate.userDefinedMapFunctionName(name), parameters, accumulatorType, null, Linq.list());
31163117
this.getCircuit().addDeclaration(new DBSPDeclaration(new DBSPFunctionItem(mapFunction)));
31173118

31183119
DBSPType resultType = this.convertType(node.getPositionRange(), uda.description.returnType, false);
31193120
DBSPFunction postFunction = new DBSPFunction(
31203121
node, LinearAggregate.userDefinedPostFunctionName(name),
3121-
Linq.list(new DBSPParameter("accumulator", accumulatorType)), resultType, null, Linq.list());
3122+
Linq.list(new DBSPParameter("accumulator", accumulatorType)), resultType.withMayBeNull(false), null, Linq.list());
31223123
this.getCircuit().addDeclaration(new DBSPDeclaration(new DBSPFunctionItem(postFunction)));
31233124
} else {
31243125
throw new UnimplementedException("Non-linear user-defined aggregation functions");

0 commit comments

Comments
 (0)