Skip to content

Commit 819d224

Browse files
committed
[SQL] Eliminate unnecessary clones from array_agg and map_agg
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
1 parent 1d8b652 commit 819d224

File tree

14 files changed

+69
-70
lines changed

14 files changed

+69
-70
lines changed

crates/sqllib/src/array.rs

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -482,13 +482,7 @@ where
482482
}
483483

484484
#[doc(hidden)]
485-
pub fn array_agg<T>(
486-
accumulator: &mut Vec<T>,
487-
value: T,
488-
weight: Weight,
489-
distinct: bool,
490-
keep: bool,
491-
) -> Vec<T>
485+
pub fn array_agg<T>(accumulator: &mut Vec<T>, value: T, weight: Weight, distinct: bool, keep: bool)
492486
where
493487
T: Clone,
494488
{
@@ -502,7 +496,6 @@ where
502496
accumulator.push(value.clone())
503497
}
504498
}
505-
accumulator.to_vec()
506499
}
507500

508501
#[doc(hidden)]
@@ -512,13 +505,12 @@ pub fn array_aggN<T>(
512505
weight: Weight,
513506
distinct: bool,
514507
keep: bool,
515-
) -> Option<Vec<T>>
516-
where
508+
) where
517509
T: Clone,
518510
{
519-
accumulator
520-
.as_mut()
521-
.map(|accumulator| array_agg(accumulator, value, weight, distinct, keep))
511+
if let Some(accumulator) = accumulator.as_mut() {
512+
array_agg(accumulator, value, weight, distinct, keep)
513+
}
522514
}
523515

524516
#[doc(hidden)]
@@ -529,14 +521,11 @@ pub fn array_agg_opt<T>(
529521
distinct: bool,
530522
keep: bool,
531523
ignore_nulls: bool,
532-
) -> Vec<Option<T>>
533-
where
524+
) where
534525
T: Clone,
535526
{
536-
if ignore_nulls && value.is_none() {
537-
accumulator.to_vec()
538-
} else {
539-
array_agg(accumulator, value, weight, distinct, keep)
527+
if !ignore_nulls || value.is_some() {
528+
array_agg(accumulator, value, weight, distinct, keep);
540529
}
541530
}
542531

@@ -548,13 +537,12 @@ pub fn array_agg_optN<T>(
548537
distinct: bool,
549538
keep: bool,
550539
ignore_nulls: bool,
551-
) -> Option<Vec<Option<T>>>
552-
where
540+
) where
553541
T: Clone,
554542
{
555-
accumulator
556-
.as_mut()
557-
.map(|accumulator| array_agg_opt(accumulator, value, weight, distinct, keep, ignore_nulls))
543+
if let Some(accumulator) = accumulator.as_mut() {
544+
array_agg_opt(accumulator, value, weight, distinct, keep, ignore_nulls);
545+
}
558546
}
559547

560548
/////////////// Variant index

crates/sqllib/src/map.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,7 @@ where
8282
}
8383

8484
#[doc(hidden)]
85-
pub fn map_agg<K, V>(
86-
accumulator: &mut BTreeMap<K, V>,
87-
value: Tup2<K, V>,
88-
weight: Weight,
89-
) -> BTreeMap<K, V>
85+
pub fn map_agg<K, V>(accumulator: &mut BTreeMap<K, V>, value: Tup2<K, V>, weight: Weight)
9086
where
9187
K: Clone + Ord,
9288
V: Clone + Ord,
@@ -97,22 +93,17 @@ where
9793
let k = value.0;
9894
let v = value.1;
9995
insert_or_keep_largest(accumulator, &k, &v);
100-
accumulator.clone()
10196
}
10297

10398
#[doc(hidden)]
104-
pub fn map_aggN<K, V>(
105-
accumulator: &mut Option<BTreeMap<K, V>>,
106-
value: Tup2<K, V>,
107-
weight: Weight,
108-
) -> Option<BTreeMap<K, V>>
99+
pub fn map_aggN<K, V>(accumulator: &mut Option<BTreeMap<K, V>>, value: Tup2<K, V>, weight: Weight)
109100
where
110101
K: Clone + Ord,
111102
V: Clone + Ord,
112103
{
113-
accumulator
114-
.as_mut()
115-
.map(|accumulator| map_agg(accumulator, value, weight))
104+
if let Some(accumulator) = accumulator.as_mut() {
105+
map_agg(accumulator, value, weight)
106+
}
116107
}
117108

118109
/////////////////////////////////////////

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/circuit/operator/DBSPNestedOperator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public int outputCount() {
175175
public DBSPType streamType(int outputNumber) {
176176
OutputPort port = this.internalOutputs.get(outputNumber);
177177
if (port == null)
178-
return new DBSPTypeVoid();
178+
return DBSPTypeVoid.INSTANCE;
179179
return port.node().streamType(port.outputNumber);
180180
}
181181
}

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/frontend/AggregateCompiler.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeDouble;
7070
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeInteger;
7171
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeNull;
72+
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeVoid;
7273
import org.dbsp.sqlCompiler.ir.type.user.DBSPTypeUser;
7374
import org.dbsp.sqlCompiler.ir.type.user.DBSPTypeVec;
7475
import org.dbsp.util.ICastable;
@@ -316,7 +317,7 @@ void processArrayAgg(SqlBasicAggFunction function) {
316317
if (arguments.length == 6) {
317318
arguments[5] = new DBSPBoolLiteral(ignoreNulls);
318319
}
319-
DBSPExpression increment = new DBSPApplyExpression(node, functionName, this.resultType, arguments);
320+
DBSPExpression increment = new DBSPApplyExpression(node, functionName, DBSPTypeVoid.INSTANCE, arguments);
320321
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, "ConcatSemigroup", false, accumulator.getType());
321322
this.setResult(new NonLinearAggregate(
322323
node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/frontend/CalciteToDBSPCompiler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,7 @@ private void visitCollect(Collect collect) {
18791879
arguments[2] = this.compiler.weightVar;
18801880
arguments[3] = new DBSPBoolLiteral(false);
18811881
arguments[4] = new DBSPBoolLiteral(true);
1882-
DBSPExpression increment = new DBSPApplyExpression(node, functionName, vecType, arguments);
1882+
DBSPExpression increment = new DBSPApplyExpression(node, functionName, DBSPTypeVoid.INSTANCE, arguments);
18831883
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, "ConcatSemigroup", false, accumulator.getType());
18841884
agg = new NonLinearAggregate(
18851885
node, zero, increment.closure(accumulator, row, this.compiler.weightVar),
@@ -1902,7 +1902,7 @@ private void visitCollect(Collect collect) {
19021902
DBSPVariablePath accumulator = accumulatorType.var();
19031903
String functionName;
19041904
functionName = "map_agg" + mapType.nullableSuffix();
1905-
DBSPExpression increment = new DBSPApplyExpression(node, functionName, accumulatorType,
1905+
DBSPExpression increment = new DBSPApplyExpression(node, functionName, DBSPTypeVoid.INSTANCE,
19061906
accumulator.borrow(true),
19071907
row.deref().applyCloneIfNeeded(),
19081908
this.compiler.weightVar);
@@ -2957,7 +2957,7 @@ void visitSort(LogicalSort sort) {
29572957
DBSPVariablePath row = inputRowType.ref().var();
29582958
// An element with weight 'w' is pushed 'w' times into the vector
29592959
DBSPExpression wPush = new DBSPApplyExpression(node,
2960-
"weighted_push", new DBSPTypeVoid(), accum, row, this.compiler.weightVar);
2960+
"weighted_push", DBSPTypeVoid.INSTANCE, accum, row, this.compiler.weightVar);
29612961
DBSPExpression push = wPush.closure(accum, row, this.compiler.weightVar);
29622962
DBSPExpression constructor =
29632963
new DBSPPath(new DBSPSimplePathSegment("Fold", DBSPTypeAny.getDefault(),

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/EliminateDump.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public VisitDecision preorder(DBSPApplyExpression expression) {
4343
assert arguments.length == 2: "Expected 2 arguments for dump function";
4444
Function<DBSPExpression, DBSPExpressionStatement> makePrint = stringArgument ->
4545
new DBSPApplyExpression(
46-
expression.getNode(), "print", new DBSPTypeVoid(), stringArgument)
46+
expression.getNode(), "print", DBSPTypeVoid.INSTANCE, stringArgument)
4747
.toStatement();
4848
List<DBSPStatement> block = new ArrayList<>();
4949
block.add(makePrint.apply(arguments[0]));

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/ExpandWriteLog.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ public VisitDecision preorder(DBSPApplyExpression expression) {
7171
} else {
7272
String printFunction = type.mayBeNull ? "print_opt" : "print";
7373
DBSPExpression print = new DBSPApplyExpression(
74-
expression.getNode(), printFunction, new DBSPTypeVoid(), castToStr.applyCloneIfNeeded());
74+
expression.getNode(), printFunction, DBSPTypeVoid.INSTANCE, castToStr.applyCloneIfNeeded());
7575
statements.add(print.toStatement());
7676
}
7777
if (!part.isEmpty()) {
7878
DBSPExpression print = new DBSPApplyExpression(
79-
expression.getNode(), "print", new DBSPTypeVoid(), new DBSPStringLiteral(part));
79+
expression.getNode(), "print", DBSPTypeVoid.INSTANCE, new DBSPStringLiteral(part));
8080
statements.add(print.toStatement());
8181
}
8282
}

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/aggregate/NonLinearAggregate.java

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,24 @@
99
import org.dbsp.sqlCompiler.compiler.visitors.inner.InnerVisitor;
1010
import org.dbsp.sqlCompiler.ir.IDBSPInnerNode;
1111
import org.dbsp.sqlCompiler.ir.expression.DBSPAssignmentExpression;
12+
import org.dbsp.sqlCompiler.ir.expression.DBSPBlockExpression;
1213
import org.dbsp.sqlCompiler.ir.expression.DBSPClosureExpression;
1314
import org.dbsp.sqlCompiler.ir.expression.DBSPExpression;
1415
import org.dbsp.sqlCompiler.ir.expression.DBSPTupleExpression;
1516
import org.dbsp.sqlCompiler.ir.expression.DBSPVariablePath;
1617
import org.dbsp.sqlCompiler.ir.path.DBSPPath;
1718
import org.dbsp.sqlCompiler.ir.path.DBSPSimplePathSegment;
19+
import org.dbsp.sqlCompiler.ir.statement.DBSPExpressionStatement;
20+
import org.dbsp.sqlCompiler.ir.statement.DBSPStatement;
1821
import org.dbsp.sqlCompiler.ir.type.DBSPType;
1922
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeAny;
2023
import org.dbsp.sqlCompiler.ir.type.derived.DBSPTypeTuple;
24+
import org.dbsp.sqlCompiler.ir.type.primitive.DBSPTypeVoid;
2125
import org.dbsp.sqlCompiler.ir.type.user.DBSPTypeSemigroup;
2226
import org.dbsp.util.IIndentStream;
2327

2428
import javax.annotation.Nullable;
29+
import java.util.ArrayList;
2530
import java.util.List;
2631
import java.util.Objects;
2732

@@ -36,7 +41,8 @@
3641
public class NonLinearAggregate extends AggregateBase {
3742
/** Zero of the fold function. */
3843
public final DBSPExpression zero;
39-
/** A closure with signature |accumulator, value, weight| -> accumulator */
44+
/** A closure with signature |&mut accumulator, value, weight|.
45+
* The closure may return a result, or may just mutate the accumulator. */
4046
public final DBSPClosureExpression increment;
4147
/** Function that may post-process the accumulator to produce the final result. */
4248
@Nullable
@@ -57,6 +63,7 @@ public NonLinearAggregate(
5763
super(origin, emptySetResult.getType());
5864
this.zero = zero;
5965
this.increment = increment;
66+
assert increment.parameters.length == 3;
6067
this.postProcess = postProcess;
6168
this.emptySetResult = emptySetResult;
6269
this.semigroup = semigroup;
@@ -76,6 +83,10 @@ public DBSPExpression getEmptySetResult() {
7683
return this.emptySetResult;
7784
}
7885

86+
public DBSPType getIncrementType() {
87+
return this.increment.parameters[0].getType();
88+
}
89+
7990
@Override
8091
public boolean isLinear() {
8192
return false;
@@ -97,7 +108,7 @@ public void validate() {
97108
throw new InternalCompilerError("Post-process result type " + postProcessType +
98109
" different from empty set type " + emptyResultType, this);
99110
} else {
100-
DBSPType incrementResultType = this.increment.getResultType();
111+
DBSPType incrementResultType = this.getIncrementType();
101112
if (!emptyResultType.sameType(incrementResultType)) {
102113
throw new InternalCompilerError("Increment result type " + incrementResultType +
103114
" different from empty set type " + emptyResultType, this);
@@ -124,7 +135,7 @@ public DBSPClosureExpression getPostprocessing() {
124135
if (this.postProcess != null)
125136
return this.postProcess;
126137
// If it is not set return the identity function
127-
DBSPVariablePath var = this.increment.getResultType().var();
138+
DBSPVariablePath var = this.getIncrementType().var();
128139
return var.closure(var);
129140
}
130141

@@ -190,7 +201,7 @@ public static NonLinearAggregate combine(
190201
DBSPVariablePath rowVar, List<NonLinearAggregate> components) {
191202
int parts = components.size();
192203
DBSPExpression[] zeros = new DBSPExpression[parts];
193-
DBSPExpression[] increments = new DBSPExpression[parts];
204+
DBSPClosureExpression[] increments = new DBSPClosureExpression[parts];
194205
DBSPExpression[] posts = new DBSPExpression[parts];
195206
DBSPExpression[] emptySetResults = new DBSPExpression[parts];
196207

@@ -199,7 +210,7 @@ public static NonLinearAggregate combine(
199210
DBSPType weightType = null;
200211
for (int i = 0; i < parts; i++) {
201212
NonLinearAggregate implementation = components.get(i);
202-
DBSPType incType = implementation.increment.getResultType();
213+
DBSPType incType = implementation.getIncrementType();
203214
zeros[i] = implementation.zero;
204215
increments[i] = implementation.increment;
205216
if (implementation.increment.parameters.length != 3)
@@ -224,19 +235,25 @@ public static NonLinearAggregate combine(
224235
DBSPVariablePath accumulator = accumulatorType.ref(true).var();
225236
DBSPVariablePath postAccumulator = accumulatorType.var();
226237

238+
List<DBSPStatement> block = new ArrayList<>();
227239
DBSPVariablePath weightVar = new DBSPVariablePath(Objects.requireNonNull(weightType));
228240
for (int i = 0; i < parts; i++) {
229241
DBSPExpression accumulatorField = accumulator.deref().field(i);
230-
DBSPExpression expr = increments[i].call(
231-
accumulatorField, rowVar, weightVar);
242+
DBSPExpression expr = increments[i].call(accumulatorField, rowVar, weightVar);
232243
BetaReduction reducer = new BetaReduction(compiler);
233-
increments[i] = reducer.reduce(expr);
244+
expr = reducer.reduce(expr);
245+
// Generate either increment(&a.i...); or *a.i = increment(&a.i...)
246+
// depending on the type of the result returned by the increment function
247+
if (increments[i].getResultType().is(DBSPTypeVoid.class))
248+
block.add(new DBSPExpressionStatement(expr));
249+
else
250+
block.add(new DBSPExpressionStatement(
251+
new DBSPAssignmentExpression(accumulatorField, expr)));
234252
DBSPExpression postAccumulatorField = postAccumulator.field(i);
235253
expr = posts[i].call(postAccumulatorField);
236254
posts[i] = reducer.reduce(expr);
237255
}
238-
DBSPAssignmentExpression accumulatorBody = new DBSPAssignmentExpression(
239-
accumulator.deref(), new DBSPTupleExpression(increments));
256+
DBSPExpression accumulatorBody = new DBSPBlockExpression(block, null);
240257
DBSPClosureExpression accumFunction = accumulatorBody.closure(
241258
accumulator, rowVar,
242259
weightVar);

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/DBSPAssignmentExpression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public final class DBSPAssignmentExpression extends DBSPExpression {
3535
public final DBSPExpression right;
3636

3737
public DBSPAssignmentExpression(DBSPExpression left, DBSPExpression right) {
38-
super(left.getNode(), new DBSPTypeVoid());
38+
super(left.getNode(), DBSPTypeVoid.INSTANCE);
3939
this.left = left;
4040
this.right = right;
4141
}

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/DBSPBlockExpression.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public final class DBSPBlockExpression extends DBSPExpression {
4343
public final DBSPExpression lastExpression;
4444

4545
public DBSPBlockExpression(List<DBSPStatement> contents, @Nullable DBSPExpression last) {
46-
super(CalciteObject.EMPTY, last != null ? last.getType() : new DBSPTypeVoid());
46+
super(CalciteObject.EMPTY, last != null ? last.getType() : DBSPTypeVoid.INSTANCE);
4747
this.contents = contents;
4848
this.lastExpression = last;
4949
}

0 commit comments

Comments
 (0)