Skip to content

Commit a6feed5

Browse files
committed
[SQL] Preserve source position information in casts during compilation
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
1 parent 3cbc4ac commit a6feed5

File tree

22 files changed

+190
-184
lines changed

22 files changed

+190
-184
lines changed

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/backend/rust/ToRustInnerVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1545,7 +1545,7 @@ public VisitDecision preorder(DBSPBinaryExpression expression) {
15451545
DBSPExpression sub1 = ExpressionCompiler.makeBinaryExpression(
15461546
expression.getNode(), indexType, DBSPOpcode.SUB,
15471547
expression.right, indexType.to(IsNumericType.class).getOne());
1548-
sub1 = sub1.cast(DBSPTypeISize.create(indexType.mayBeNull), false);
1548+
sub1 = sub1.cast(expression.getNode(), DBSPTypeISize.create(indexType.mayBeNull), false);
15491549
Simplify simplify = new Simplify(this.compiler);
15501550
sub1 = simplify.apply(sub1).to(DBSPExpression.class);
15511551
sub1.accept(this);

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/backend/rust/ToRustVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ public VisitDecision preorder(DBSPIndexedTopKOperator operator) {
977977
if (operator.numbering != DBSPIndexedTopKOperator.TopKNumbering.ROW_NUMBER)
978978
this.builder.append(", _");
979979
this.builder.append(">(hash, ");
980-
DBSPExpression cast = operator.limit.cast(
980+
DBSPExpression cast = operator.limit.cast(CalciteObject.EMPTY,
981981
DBSPTypeUSize.create(operator.limit.getType().mayBeNull), false);
982982
cast.accept(this.innerVisitor);
983983
if (operator.numbering != DBSPIndexedTopKOperator.TopKNumbering.ROW_NUMBER) {

sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/frontend/AggregateCompiler.java

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,21 @@ void processBitOp(SqlBitOpAggFunction function) {
194194
String helper = aggregatedValue.getType().is(DBSPTypeBinary.class) ?
195195
"right_xor_weigh_bytes" : "right_xor_weigh";
196196
helper += aggregatedValue.getType().nullableSuffix();
197-
aggregatedValue = new DBSPApplyExpression(node, helper,
197+
aggregatedValue = new DBSPApplyExpression(this.node, helper,
198198
aggregatedValue.getType(), aggregatedValue.applyCloneIfNeeded(), this.compiler.weightVar);
199199
}
200-
increment = this.aggregateOperation(node, opcode,
200+
increment = this.aggregateOperation(this.node, opcode,
201201
this.nullableResultType, accumulator, aggregatedValue, this.filterArgument());
202202
DBSPType semigroup = new DBSPTypeUser(CalciteObject.EMPTY, SEMIGROUP, "UnimplementedSemigroup",
203203
false, accumulator.getType());
204204
this.setResult(new NonLinearAggregate(
205-
node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
205+
this.node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
206206
}
207207

208208
void processGrouping(SqlAbstractGroupFunction function) {
209209
DBSPExpression zero = this.nullableResultType.to(IsNumericType.class).getZero();
210210
if (this.filterArgument() != null) {
211-
throw new UnimplementedException("GROUPING with FILTER not implemented", node);
211+
throw new UnimplementedException("GROUPING with FILTER not implemented", this.node);
212212
}
213213

214214
long result = 0;
@@ -223,13 +223,14 @@ void processGrouping(SqlAbstractGroupFunction function) {
223223
}
224224

225225
// TODO: should this be looking at the filter argument?
226-
DBSPExpression increment = new DBSPI64Literal(result).cast(this.nullableResultType, false);
226+
DBSPExpression increment = new DBSPI64Literal(result).cast(
227+
this.node, this.nullableResultType, false);
227228
DBSPVariablePath accumulator = this.nullableResultType.var();
228229
DBSPType semigroup = new DBSPTypeUser(CalciteObject.EMPTY, SEMIGROUP, "UnimplementedSemigroup",
229230
false, accumulator.getType());
230231
// Always non-linear (result can not be zero).
231232
this.setResult(new NonLinearAggregate(
232-
node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
233+
this.node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
233234
}
234235

235236
void processCount(SqlCountAggFunction function) {
@@ -249,15 +250,15 @@ void processCount(SqlCountAggFunction function) {
249250
DBSPExpression indicator;
250251
if (this.aggArgument != null) {
251252
DBSPExpression agg = this.getAggregatedValue();
252-
indicator = ExpressionCompiler.makeIndicator(node, this.resultType, agg);
253+
indicator = ExpressionCompiler.makeIndicator(this.node, this.resultType, agg);
253254
} else {
254255
indicator = one;
255256
}
256257

257258
DBSPExpression filter = this.filterArgument();
258259
DBSPExpression combined = indicator;
259260
if (filter != null)
260-
combined = new DBSPIfExpression(node, filter, indicator, zero);
261+
combined = new DBSPIfExpression(this.node, filter, indicator, zero);
261262
DBSPExpression mapBody = new DBSPTupleExpression(combined, one);
262263
DBSPVariablePath postVar = mapBody.getType().var();
263264
// post = |x| x.0
@@ -361,17 +362,17 @@ void processBasic(SqlBasicAggFunction function) {
361362
DBSPVariablePath accumulator = zeroType.var();
362363
DBSPExpression ge = new DBSPBinaryExpression(
363364
node, DBSPTypeBool.create(false), compareOpcode,
364-
tuple.fields[1].cast(currentType, false).applyCloneIfNeeded(),
365+
tuple.fields[1].cast(this.node, currentType, false).applyCloneIfNeeded(),
365366
accumulator.field(1).applyCloneIfNeeded());
366367
if (this.filterArgument >= 0) {
367368
ge = ExpressionCompiler.makeBinaryExpression(
368369
node, ge.getType(), DBSPOpcode.AND, ge, Objects.requireNonNull(this.filterArgument()));
369370
}
370371
DBSPTupleExpression aggArgCast = new DBSPTupleExpression(
371-
tuple.fields[0].cast(this.resultType, false).applyCloneIfNeeded(),
372-
tuple.fields[1].cast(currentType, false).applyCloneIfNeeded());
372+
tuple.fields[0].cast(this.node, this.resultType, false).applyCloneIfNeeded(),
373+
tuple.fields[1].cast(this.node, currentType, false).applyCloneIfNeeded());
373374
DBSPExpression increment = new DBSPIfExpression(node, ge, aggArgCast, accumulator.applyCloneIfNeeded());
374-
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, "UnimplementedSemigroup",
375+
DBSPType semigroup = new DBSPTypeUser(this.node, SEMIGROUP, "UnimplementedSemigroup",
375376
false, aggArgCast.getType());
376377
DBSPExpression postBody = accumulator.field(0).applyCloneIfNeeded();
377378
this.setResult(new NonLinearAggregate(
@@ -410,7 +411,7 @@ public DBSPExpression aggregateOperation(
410411
if (!leftType.withMayBeNull(rightType.mayBeNull).sameType(rightType)) {
411412
if (!rightType.is(DBSPTypeBaseType.class)) {
412413
// These can also be different DECIMAL types
413-
right = right.applyCloneIfNeeded().cast(leftType.withMayBeNull(rightType.mayBeNull), false);
414+
right = right.applyCloneIfNeeded().cast(node, leftType.withMayBeNull(rightType.mayBeNull), false);
414415
}
415416
}
416417

@@ -471,12 +472,12 @@ void processSum(SqlSumAggFunction function) {
471472
DBSPExpression condition;
472473
if (filter != null)
473474
condition = ExpressionCompiler.makeBinaryExpression(
474-
node, DBSPTypeBool.create(false), DBSPOpcode.AND, filter, agg);
475+
this.node, DBSPTypeBool.create(false), DBSPOpcode.AND, filter, agg);
475476
else
476477
condition = agg;
477478
DBSPExpression first = new DBSPIfExpression(
478-
node, condition,
479-
this.getAggregatedValue().cast(typedZero.getType(), false),
479+
this.node, condition,
480+
this.getAggregatedValue().cast(this.node, typedZero.getType(), false),
480481
typedZero);
481482
DBSPExpression second = new DBSPIfExpression(node, condition, one, realZero);
482483
DBSPExpression mapBody = new DBSPTupleExpression(first, second, one);
@@ -485,7 +486,7 @@ void processSum(SqlSumAggFunction function) {
485486
DBSPExpression postBody = new DBSPIfExpression(node,
486487
ExpressionCompiler.makeBinaryExpression(node,
487488
DBSPTypeBool.create(false), DBSPOpcode.NEQ, postVar.field(1), realZero),
488-
postVar.field(0).cast(this.nullableResultType, false), zero);
489+
postVar.field(0).cast(this.node, this.nullableResultType, false), zero);
489490
post = postBody.closure(postVar);
490491
map = mapBody.closure(this.v);
491492
this.setResult(new LinearAggregate(node, map, post, zero));
@@ -523,31 +524,31 @@ void processSumZero(SqlSumEmptyIsZeroAggFunction function) {
523524
DBSPExpression condition;
524525
if (filter != null)
525526
condition = ExpressionCompiler.makeBinaryExpression(
526-
node, DBSPTypeBool.create(false), DBSPOpcode.AND, filter, agg);
527+
this.node, DBSPTypeBool.create(false), DBSPOpcode.AND, filter, agg);
527528
else
528529
condition = agg;
529530
DBSPExpression first = new DBSPIfExpression(
530-
node, condition, this.getAggregatedValue().cast(zero.getType(), false), zero);
531+
this.node, condition, this.getAggregatedValue().cast(this.node, zero.getType(), false), zero);
531532
DBSPExpression mapBody = new DBSPTupleExpression(first, one);
532533
DBSPVariablePath postVar = mapBody.getType().var();
533534
// post = |x| x.0
534535
post = postVar.field(0).closure(postVar);
535536
map = mapBody.closure(this.v);
536-
this.setResult(new LinearAggregate(node, map, post, zero));
537+
this.setResult(new LinearAggregate(this.node, map, post, zero));
537538
} else {
538539
DBSPExpression weighted = new DBSPBinaryExpression(
539-
node, aggregatedValue.getType(),
540+
this.node, aggregatedValue.getType(),
540541
DBSPOpcode.MUL_WEIGHT, aggregatedValue, this.compiler.weightVar);
541542
increment = this.aggregateOperation(
542-
node, DBSPOpcode.AGG_ADD_NON_NULL, this.resultType,
543+
this.node, DBSPOpcode.AGG_ADD_NON_NULL, this.resultType,
543544
accumulator, weighted, this.filterArgument());
544545
String semigroupName = "DefaultSemigroup";
545546
if (accumulator.getType().mayBeNull)
546547
semigroupName = "DefaultOptSemigroup";
547-
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, semigroupName, false,
548+
DBSPType semigroup = new DBSPTypeUser(this.node, SEMIGROUP, semigroupName, false,
548549
accumulator.getType().withMayBeNull(false));
549550
this.setResult(new NonLinearAggregate(
550-
node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
551+
this.node, zero, this.makeRowClosure(increment, accumulator), zero, semigroup));
551552
}
552553
}
553554

@@ -596,7 +597,7 @@ AggregateBase doAverage(SqlAvgAggFunction function) {
596597
condition = agg;
597598
DBSPExpression first = new DBSPIfExpression(
598599
node, condition,
599-
this.getAggregatedValue().cast(typedZero.getType(), false),
600+
this.getAggregatedValue().cast(this.node, typedZero.getType(), false),
600601
typedZero);
601602
DBSPExpression second = new DBSPIfExpression(node, condition, one, typedZero);
602603
DBSPExpression mapBody = new DBSPTupleExpression(first, second, one);
@@ -607,7 +608,7 @@ AggregateBase doAverage(SqlAvgAggFunction function) {
607608
DBSPExpression postBody = new DBSPIfExpression(node,
608609
ExpressionCompiler.makeBinaryExpression(node,
609610
DBSPTypeBool.create(false), DBSPOpcode.NEQ, postVar.field(1), typedZero),
610-
div.cast(this.nullableResultType, false), postZero);
611+
div.cast(this.node, this.nullableResultType, false), postZero);
611612
post = postBody.closure(postVar);
612613
map = mapBody.closure(this.v);
613614
return new LinearAggregate(node, map, post, postZero);
@@ -623,7 +624,7 @@ AggregateBase doAverage(SqlAvgAggFunction function) {
623624
final int countIndex = 1;
624625
DBSPExpression countAccumulator = accumulator.field(countIndex);
625626
DBSPExpression sumAccumulator = accumulator.field(sumIndex);
626-
DBSPExpression aggregatedValue = this.getAggregatedValue().cast(intermediateResultType, false);
627+
DBSPExpression aggregatedValue = this.getAggregatedValue().cast(this.node, intermediateResultType, false);
627628
DBSPType intermediateResultTypeNonNull = intermediateResultType.withMayBeNull(false);
628629
DBSPExpression plusOne = intermediateResultTypeNonNull.to(IsNumericType.class).getOne();
629630

@@ -649,7 +650,7 @@ AggregateBase doAverage(SqlAvgAggFunction function) {
649650
DBSPExpression divide = ExpressionCompiler.makeBinaryExpression(
650651
node, this.resultType, DBSPOpcode.DIV,
651652
a.field(sumIndex), a.field(countIndex));
652-
divide = divide.cast(this.nullableResultType, false);
653+
divide = divide.cast(this.node, this.nullableResultType, false);
653654
DBSPClosureExpression post = new DBSPClosureExpression(
654655
node, divide, a.asParameter());
655656
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, "PairSemigroup", false,
@@ -688,7 +689,7 @@ NonLinearAggregate doStddev(SqlAvgAggFunction function) {
688689
DBSPExpression sumAccumulator = accumulator.field(sumIndex);
689690
DBSPExpression sumSquaresAccumulator = accumulator.field(sumSquaresIndex);
690691

691-
DBSPExpression aggregatedValue = this.getAggregatedValue().cast(intermediateResultType, false);
692+
DBSPExpression aggregatedValue = this.getAggregatedValue().cast(this.node, intermediateResultType, false);
692693
DBSPType intermediateResultTypeNonNull = intermediateResultType.withMayBeNull(false);
693694
DBSPExpression plusOne = intermediateResultTypeNonNull.to(IsNumericType.class).getOne();
694695

@@ -737,14 +738,14 @@ NonLinearAggregate doStddev(SqlAvgAggFunction function) {
737738
DBSPType sqrtType = new DBSPTypeDouble(node, intermediateResultType.mayBeNull);
738739
DBSPExpression div = ExpressionCompiler.makeBinaryExpression(
739740
node, intermediateResultType, DBSPOpcode.DIV_NULL,
740-
sub, denom).cast(sqrtType, false);
741+
sub, denom).cast(this.node, sqrtType, false);
741742
// Prevent sqrt from negative values if computations are unstable
742743
DBSPExpression max = ExpressionCompiler.makeBinaryExpression(
743744
node, sqrtType, DBSPOpcode.MAX,
744745
div, sqrtType.to(IsNumericType.class).getZero());
745746
DBSPExpression sqrt = ExpressionCompiler.compilePolymorphicFunction(
746747
false, "sqrt", node, sqrtType, Linq.list(max), 1);
747-
sqrt = sqrt.cast(this.nullableResultType, false);
748+
sqrt = sqrt.cast(this.node, this.nullableResultType, false);
748749
DBSPClosureExpression post = new DBSPClosureExpression(node, sqrt, a.asParameter());
749750
DBSPExpression postZero = DBSPLiteral.none(this.nullableResultType);
750751
DBSPType semigroup = new DBSPTypeUser(node, SEMIGROUP, "TripleSemigroup", false,

0 commit comments

Comments
 (0)