Skip to content

Commit 3f10a1c

Browse files
committed
[SQL] Optimized implementation of ARG_MAX/ARG_MIN for append-only relations
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
1 parent 7236c74 commit 3f10a1c

17 files changed

Lines changed: 550 additions & 232 deletions

File tree

crates/sqllib/src/aggregates.rs

Lines changed: 155 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use dbsp::algebra::{FirstLargeValue, HasOne, HasZero, SignedPrimInt, UnsignedPri
88
use num::PrimInt;
99
use num_traits::CheckedAdd;
1010
use std::cmp::Ord;
11-
use std::fmt::Debug;
11+
use std::fmt::{Debug, Display};
1212
use std::marker::Copy;
1313

1414
/// Holds some methods for wrapping values into unsigned values
@@ -32,7 +32,7 @@ impl UnsignedWrapper {
3232
S: PrimInt,
3333
I: SignedPrimInt + From<S>,
3434
U: UnsignedPrimInt + TryFrom<I> + Debug,
35-
<U as TryFrom<I>>::Error: std::fmt::Debug,
35+
<U as TryFrom<I>>::Error: Debug,
3636
{
3737
let s = <O as ToInteger<S>>::to_integer(&value);
3838
let i = <I as From<S>>::from(s);
@@ -53,7 +53,7 @@ impl UnsignedWrapper {
5353
S: SignedPrimInt,
5454
I: SignedPrimInt + From<S>,
5555
U: UnsignedPrimInt + TryFrom<I> + HasZero + Debug,
56-
<U as TryFrom<I>>::Error: std::fmt::Debug,
56+
<U as TryFrom<I>>::Error: Debug,
5757
{
5858
match value {
5959
None => {
@@ -78,8 +78,8 @@ impl UnsignedWrapper {
7878
S: SignedPrimInt + TryFrom<I> + Debug,
7979
I: SignedPrimInt + From<S> + TryFrom<U>,
8080
U: UnsignedPrimInt + TryFrom<I>,
81-
<I as TryFrom<U>>::Error: std::fmt::Debug,
82-
<S as TryFrom<I>>::Error: std::fmt::Debug,
81+
<I as TryFrom<U>>::Error: Debug,
82+
<S as TryFrom<I>>::Error: Debug,
8383
{
8484
let i = <I as TryFrom<U>>::try_from(value).unwrap();
8585
let i = i + <I as From<S>>::from(S::min_value()) - <I as HasOne>::one();
@@ -98,8 +98,8 @@ impl UnsignedWrapper {
9898
S: SignedPrimInt + TryFrom<U> + TryFrom<I>,
9999
I: SignedPrimInt + From<S> + TryFrom<U>,
100100
U: UnsignedPrimInt + TryFrom<I> + Debug,
101-
<I as TryFrom<U>>::Error: std::fmt::Debug,
102-
<S as TryFrom<I>>::Error: std::fmt::Debug,
101+
<I as TryFrom<U>>::Error: Debug,
102+
<S as TryFrom<I>>::Error: Debug,
103103
{
104104
if nullsLast {
105105
if <U as FirstLargeValue>::large() == value {
@@ -264,16 +264,16 @@ macro_rules! for_all_int_aggregate_non_null {
264264
// Macro to create variants of an aggregation function
265265
// There must exist a function f__(left: T, right: T) -> T
266266
// This creates 3 more functions
267-
// f_N_(left: Option<T>, right: T) -> Option<T>
267+
// f_N_<T>(left: Option<T>, right: T) -> Option<T>
268268
// etc.
269269
// And 4 more functions:
270-
// f_N_N_conditional(left: T, right: T, predicate: bool) -> T
270+
// f_N_N_conditional<T>(left: Option<T>, right: Option<T>, predicate: bool) -> Option<T>
271271
macro_rules! universal_aggregate {
272-
($func:ident) => {
272+
($func:ident, $t: ty where $($bounds:tt)*) => {
273273
::paste::paste! {
274274
#[doc(hidden)]
275-
pub fn [<$func _N_ >]<T>( left: Option<T>, right: T ) -> Option<T>
276-
where T: Ord + Clone,
275+
pub fn [<$func _N_ >]<$t>( left: Option<$t>, right: $t ) -> Option<$t>
276+
where $($bounds)*
277277
{
278278
match left {
279279
None => Some(right.clone()),
@@ -282,8 +282,8 @@ macro_rules! universal_aggregate {
282282
}
283283

284284
#[doc(hidden)]
285-
pub fn [<$func __N>]<T>( left: T, right: Option<T> ) -> Option<T>
286-
where T: Ord + Clone,
285+
pub fn [<$func __N>]<$t>( left: $t, right: Option<$t> ) -> Option<$t>
286+
where $($bounds)*
287287
{
288288
match right {
289289
None => Some(left.clone()),
@@ -292,19 +292,19 @@ macro_rules! universal_aggregate {
292292
}
293293

294294
#[doc(hidden)]
295-
pub fn [<$func _N_N>]<T>( left: Option<T>, right: Option<T> ) -> Option<T>
296-
where T: Ord + Clone,
295+
pub fn [<$func _N_N>]<$t>( left: Option<$t>, right: Option<$t> ) -> Option<$t>
296+
where $($bounds)*
297297
{
298298
match (left.clone(), right.clone()) {
299-
(None, _) => right.clone(),
300-
(_, None) => left.clone(),
299+
(None, right) => right,
300+
(left, None) => left,
301301
(Some(left), Some(right)) => Some([<$func __>](left, right)),
302302
}
303303
}
304304

305305
#[doc(hidden)]
306-
pub fn [<$func ___conditional>]<T>( left: T, right: T, predicate: bool ) -> T
307-
where T: Ord + Clone,
306+
pub fn [<$func ___conditional>]<$t>( left: $t, right: $t, predicate: bool ) -> $t
307+
where $($bounds)*
308308
{
309309
if predicate {
310310
[<$func __>](left, right)
@@ -314,35 +314,35 @@ macro_rules! universal_aggregate {
314314
}
315315

316316
#[doc(hidden)]
317-
pub fn [<$func _N__conditional>]<T>( left: Option<T>, right: T, predicate: bool ) -> Option<T>
318-
where T: Ord + Clone,
317+
pub fn [<$func _N__conditional>]<$t>( left: Option<$t>, right: $t, predicate: bool ) -> Option<$t>
318+
where $($bounds)*
319319
{
320320
match (left.clone(), right.clone(), predicate) {
321-
(_, _, false) => left.clone(),
322-
(None, _, _) => Some(right.clone()),
321+
(left, _, false) => left,
322+
(None, right, _) => Some(right),
323323
(Some(x), _, _) => Some([<$func __>](x, right)),
324324
}
325325
}
326326

327327
#[doc(hidden)]
328-
pub fn [<$func __N_conditional>]<T>( left: T, right: Option<T>, predicate: bool ) -> Option<T>
329-
where T: Ord + Clone,
328+
pub fn [<$func __N_conditional>]<$t>( left: $t, right: Option<$t>, predicate: bool ) -> Option<$t>
329+
where $($bounds)*
330330
{
331331
match (left.clone(), right.clone(), predicate) {
332-
(_, _, false) => Some(left.clone()),
333-
(_, None, _) => Some(left.clone()),
332+
(left, _, false) => Some(left),
333+
(left, None, _) => Some(left),
334334
(_, Some(y), _) => Some([<$func __>](left, y)),
335335
}
336336
}
337337

338338
#[doc(hidden)]
339-
pub fn [<$func _N_N_conditional>]<T>( left: Option<T>, right: Option<T>, predicate: bool ) -> Option<T>
340-
where T: Ord + Clone,
339+
pub fn [<$func _N_N_conditional>]<$t>( left: Option<$t>, right: Option<$t>, predicate: bool ) -> Option<$t>
340+
where $($bounds)*
341341
{
342342
match (left.clone(), right.clone(), predicate) {
343-
(_, _, false) => left.clone(),
344-
(None, _, _) => right.clone(),
345-
(_, None, _) => left.clone(),
343+
(left, _, false) => left,
344+
(None, right, _) => right,
345+
(left, None, _) => left,
346346
(Some(x), Some(y), _) => Some([< $func __ >](x, y)),
347347
}
348348
}
@@ -351,25 +351,141 @@ macro_rules! universal_aggregate {
351351
}
352352
pub(crate) use universal_aggregate;
353353

354+
/// Returns the smaller of `left` and `right` according to `T`'s `Ord` implementation.
///
/// This is the non-nullable base function that `universal_aggregate!(agg_min, ...)`
/// expands into the nullable and conditional variants.
#[doc(hidden)]
355+
pub fn agg_min__<T>(left: T, right: T) -> T
356+
where
357+
T: Ord + Clone + Debug,
358+
{
359+
left.min(right)
360+
}
361+
362+
universal_aggregate!(agg_min, T where T: Ord + Clone + Debug);
363+
354364
#[doc(hidden)]
355365
pub fn agg_max__<T>(left: T, right: T) -> T
356366
where
357-
T: Ord + Clone,
367+
T: Ord + Clone + Debug,
368+
{
369+
left.max(right)
370+
}
371+
372+
universal_aggregate!(agg_max, T where T: Ord + Clone + Debug);
373+
374+
/// Wraps the first component of a pair in `Some`, leaving the second component unchanged.
///
/// Used by `universal_aggregate2!` to lift a `(L, R)` result into `(Option<L>, R)`.
fn o0<L, R>(t: (L, R)) -> (Option<L>, R) {
375+
(Some(t.0), t.1)
376+
}
377+
378+
// Macro to create variants of an aggregation function
379+
// There must exist a function f__(left: (L, R), right: (L, R)) -> (L, R)
380+
// This creates 3 more functions
381+
// f_N_<L, R>(left: (Option<L>, R), right: (L, R)) -> (Option<L>, R)
382+
// etc.
383+
// And 4 more functions:
384+
// f_N_N_conditional<L, R>(left: (Option<L>, R), right: (Option<L>, R), predicate: bool) -> (Option<L>, R)
385+
// Arguments:
//   $func   - base name; a function `<$func>__((L, R), (L, R)) -> (L, R)` must already exist
//   $l, $r  - names of the two generic type parameters of the generated functions
//   $bounds - tokens placed verbatim in the `where` clause of every generated function
// In every generated function only the FIRST tuple component may be nullable.
// When that component is `None`, the whole pair (including its second component)
// is ignored in favor of the other argument.
macro_rules! universal_aggregate2 {
386+
($func:ident, $l: ty, $r: ty where $($bounds:tt)*) => {
387+
::paste::paste! {
// `<$func>_N_`: nullable left, non-nullable right.
// A `None` left yields `right`; otherwise the base function combines both.
388+
#[doc(hidden)]
389+
pub fn [<$func _N_ >]<$l, $r>( left: (Option<$l>, $r), right: ($l, $r)) -> (Option<$l>, $r)
390+
where $($bounds)*
391+
{
392+
match left {
393+
(None, _) => o0(right),
394+
(Some(left), r) => o0([<$func __>]((left, r), right)),
395+
}
396+
}
397+
// `<$func>__N`: non-nullable left, nullable right.
// A `None` right yields `left`; otherwise the base function combines both.
398+
#[doc(hidden)]
399+
pub fn [<$func __N>]<$l, $r>( left: ($l, $r), right: (Option<$l>, $r) ) -> (Option<$l>, $r)
400+
where $($bounds)*
401+
{
402+
match right {
403+
(None, _) => o0(left),
404+
(Some(right), r) => o0([<$func __>](left, (right, r))),
405+
}
406+
}
407+
// `<$func>_N_N`: both sides nullable.
// The arms are tried in order: a `None` left returns `right` as-is (which may itself
// be `None`), then a `None` right returns `left`, else both are combined.
408+
#[doc(hidden)]
409+
pub fn [<$func _N_N>]<$l, $r>( left: (Option<$l>, $r), right: (Option<$l>, $r) ) -> (Option<$l>, $r)
410+
where $($bounds)*
411+
{
412+
match (left.clone(), right.clone()) {
413+
((None, _), right) => right,
414+
(left, (None, _)) => left,
415+
((Some(left), l1), (Some(right), r1)) => o0([<$func __>]((left, l1), (right, r1))),
416+
}
417+
}
418+
// `<$func>___conditional`: neither side nullable.
// Combines the arguments only when `predicate` is true; otherwise returns `left`.
419+
#[doc(hidden)]
420+
pub fn [<$func ___conditional>]<$l, $r>( left: ($l, $r), right: ($l, $r), predicate: bool ) -> ($l, $r)
421+
where $($bounds)*
422+
{
423+
if predicate {
424+
[<$func __>](left, right)
425+
} else {
426+
left.clone()
427+
}
428+
}
429+
// `<$func>_N__conditional`: nullable left, non-nullable right, guarded by `predicate`.
// A false predicate returns `left` untouched, regardless of nullability.
430+
#[doc(hidden)]
431+
pub fn [<$func _N__conditional>]<$l, $r>( left: (Option<$l>, $r), right: ($l, $r), predicate: bool ) -> (Option<$l>, $r)
432+
where $($bounds)*
433+
{
434+
match (left.clone(), right.clone(), predicate) {
435+
(left, _, false) => left,
436+
((None, _), right, _) => o0(right),
437+
((Some(x), r), _, _) => o0([<$func __>]((x, r), right)),
438+
}
439+
}
440+
// `<$func>__N_conditional`: non-nullable left, nullable right, guarded by `predicate`.
// A false predicate or a `None` right both yield `left` wrapped by `o0`.
441+
#[doc(hidden)]
442+
pub fn [<$func __N_conditional>]<$l, $r>( left: ($l, $r), right: (Option<$l>, $r), predicate: bool ) -> (Option<$l>, $r)
443+
where $($bounds)*
444+
{
445+
match (left.clone(), right.clone(), predicate) {
446+
(left, _, false) => o0(left),
447+
(left, (None, _), _) => o0(left),
// The first tuple slot is ignored here; `left` below is the original function argument.
448+
(_, (Some(y), r), _) => o0([<$func __>](left, (y, r))),
449+
}
450+
}
451+
// `<$func>_N_N_conditional`: both sides nullable, guarded by `predicate`.
// Arm order: false predicate -> `left`; `None` left -> `right`; `None` right -> `left`;
// otherwise both are combined with the base function.
452+
#[doc(hidden)]
453+
pub fn [<$func _N_N_conditional>]<$l, $r>( left: (Option<$l>, $r), right: (Option<$l>, $r), predicate: bool ) -> (Option<$l>, $r)
454+
where $($bounds)*
455+
{
456+
match (left.clone(), right.clone(), predicate) {
457+
(left, _, false) => left,
458+
((None, _), right, _) => right,
459+
(left, (None, _), _) => left,
460+
((Some(x), l), (Some(y), r), _) => o0([< $func __ >]((x, l), (y, r))),
461+
}
462+
}
463+
}
464+
};
465+
}
466+
pub(crate) use universal_aggregate2;
467+
468+
/// Returns the larger of the two pairs using the tuple `Ord` implementation,
/// i.e. the first components are compared first and the second components break ties.
///
/// This is the non-nullable base function that `universal_aggregate2!(agg_max1, ...)`
/// expands into the nullable and conditional variants.
// NOTE(review): presumably this backs the ARG_MAX optimization named in the commit
// title, with the compared value in slot 0 and the returned argument in slot 1 — confirm
// against the compiler side.
#[doc(hidden)]
469+
pub fn agg_max1__<L, R>(left: (L, R), right: (L, R)) -> (L, R)
470+
where
471+
L: Ord + Clone + Debug,
472+
R: Ord + Clone + Debug,
358473
{
359474
left.max(right)
360475
}
361476

362-
universal_aggregate!(agg_max);
477+
universal_aggregate2!(agg_max1, L, R where L: Ord + Clone + Debug, R: Ord + Clone + Debug);
363478

364479
#[doc(hidden)]
365-
pub fn agg_min__<T>(left: T, right: T) -> T
480+
pub fn agg_min1__<L, R>(left: (L, R), right: (L, R)) -> (L, R)
366481
where
367-
T: Ord + Clone,
482+
L: Ord + Clone + Debug,
483+
R: Ord + Clone + Debug,
368484
{
369485
left.min(right)
370486
}
371487

372-
universal_aggregate!(agg_min);
488+
universal_aggregate2!(agg_min1, L, R where L: Ord + Clone + Debug, R: Ord + Clone + Debug);
373489

374490
#[doc(hidden)]
375491
pub fn agg_plus<T>(left: T, right: T) -> T
@@ -411,7 +527,7 @@ some_aggregate!(agg_plus_SqlDecimal<const P: usize, const S: usize>, agg_plus, S
411527
#[doc(hidden)]
412528
pub fn agg_plus_non_null<T>(left: T, right: T) -> T
413529
where
414-
T: CheckedAdd + Copy + std::fmt::Display,
530+
T: CheckedAdd + Copy + Display,
415531
{
416532
left.checked_add(&right)
417533
.unwrap_or_else(|| panic!("Overflow during aggregation {}+{}", left, right))

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/backend/ToJsonInnerVisitor.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.dbsp.sqlCompiler.ir.IDBSPInnerNode;
99
import org.dbsp.sqlCompiler.ir.aggregate.DBSPAggregateList;
1010
import org.dbsp.sqlCompiler.ir.aggregate.DBSPMinMax;
11+
import org.dbsp.sqlCompiler.ir.aggregate.MinMaxAggregate;
1112
import org.dbsp.sqlCompiler.ir.expression.*;
1213
import org.dbsp.sqlCompiler.ir.expression.literal.DBSPBinaryLiteral;
1314
import org.dbsp.sqlCompiler.ir.expression.literal.DBSPBoolLiteral;
@@ -571,6 +572,13 @@ public void postorder(DBSPConditionalAggregateExpression node) {
571572
super.postorder(node);
572573
}
573574

575+
/**
 * Visits a {@link MinMaxAggregate} node: calls {@code property("operation")},
 * appends the string form of {@code node.operation} to the output stream,
 * and then delegates to the superclass implementation.
 */
@Override
576+
public void postorder(MinMaxAggregate node) {
577+
// NOTE(review): presumably property(...) emits the JSON key and the append below
// supplies its value — confirm against the other postorder overrides in this visitor.
this.property("operation");
578+
this.stream.append(node.operation.toString());
579+
super.postorder(node);
580+
}
581+
574582
@Override
575583
public void postorder(DBSPUnaryExpression node) {
576584
this.property("opcode");

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/backend/rust/RustSqlRuntimeLibrary.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.dbsp.sqlCompiler.compiler.frontend.calciteObject.CalciteObject;
3030
import org.dbsp.sqlCompiler.ir.expression.*;
3131
import org.dbsp.sqlCompiler.ir.type.*;
32+
import org.dbsp.sqlCompiler.ir.type.derived.DBSPTypeRawTuple;
3233
import org.dbsp.sqlCompiler.ir.type.primitive.*;
3334
import org.dbsp.sqlCompiler.compiler.errors.UnimplementedException;
3435
import org.dbsp.util.Utilities;
@@ -66,6 +67,8 @@ protected RustSqlRuntimeLibrary() {
6667
this.universalFunctions.put(DBSPOpcode.AGG_GTE.toString(), DBSPOpcode.AGG_GTE);
6768
this.universalFunctions.put(DBSPOpcode.AGG_MIN.toString(), DBSPOpcode.AGG_MIN);
6869
this.universalFunctions.put(DBSPOpcode.AGG_MAX.toString(), DBSPOpcode.AGG_MAX);
70+
this.universalFunctions.put(DBSPOpcode.AGG_MIN1.toString(), DBSPOpcode.AGG_MIN1);
71+
this.universalFunctions.put(DBSPOpcode.AGG_MAX1.toString(), DBSPOpcode.AGG_MAX1);
6972

7073
this.arithmeticFunctions.put("plus", DBSPOpcode.ADD);
7174
this.arithmeticFunctions.put("minus", DBSPOpcode.SUB);
@@ -174,6 +177,7 @@ public String getFunctionName(CalciteObject node,
174177
opcode == DBSPOpcode.MAX_IGNORE_NULLS || opcode == DBSPOpcode.MIN_IGNORE_NULLS ||
175178
opcode == DBSPOpcode.AGG_GTE || opcode == DBSPOpcode.AGG_LTE ||
176179
opcode == DBSPOpcode.AGG_MIN || opcode == DBSPOpcode.AGG_MAX ||
180+
opcode == DBSPOpcode.AGG_MIN1 || opcode == DBSPOpcode.AGG_MAX1 ||
177181
opcode == DBSPOpcode.IS_DISTINCT) {
178182
map = this.universalFunctions;
179183
} else if (ltype.as(DBSPTypeBool.class) != null) {
@@ -209,10 +213,18 @@ public String getFunctionName(CalciteObject node,
209213
String tsuffixl;
210214
String tsuffixr;
211215
if (map == universalFunctions) {
216+
Utilities.enforce(rtype != null);
212217
tsuffixl = "";
213218
tsuffixr = "";
219+
if (opcode == DBSPOpcode.AGG_MIN1 || opcode == DBSPOpcode.AGG_MAX1) {
220+
// The function has 2 two-tuple arguments of types (L, R).
221+
// The name is generated based on the types of the L component only
222+
Utilities.enforce(ltype.is(DBSPTypeRawTuple.class));
223+
Utilities.enforce(rtype.is(DBSPTypeRawTuple.class));
224+
ltype = ltype.to(DBSPTypeRawTuple.class).getFieldType(0);
225+
rtype = rtype.to(DBSPTypeRawTuple.class).getFieldType(0);
226+
}
214227
suffixl = ltype.nullableSuffix();
215-
Utilities.enforce(rtype != null);
216228
suffixr = rtype.nullableSuffix();
217229
} else if (opcode == DBSPOpcode.CONTROLLED_FILTER_GTE) {
218230
tsuffixl = "";

0 commit comments

Comments
 (0)