Skip to content

Commit dded5b9

Browse files
committed
[DataFusion] Use common NoREC oracle
1 parent fd1ba0e commit dded5b9

File tree

2 files changed

+35
-76
lines changed

2 files changed

+35
-76
lines changed

src/sqlancer/datafusion/DataFusionErrors.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import static sqlancer.datafusion.DataFusionUtil.dfAssert;
44

5+
import java.util.ArrayList;
6+
import java.util.List;
7+
58
import sqlancer.common.query.ExpectedErrors;
69

710
public final class DataFusionErrors {
@@ -17,7 +20,8 @@ private DataFusionErrors() {
1720
* Note now it's implemented this way for simplicity This way might cause false negative, because Q1 and Q2 should
1821
* both succeed or both fail TODO(datafusion): ensure both succeed or both fail
1922
*/
20-
public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
23+
public static List<String> getExpectedExecutionErrors() {
24+
ArrayList<String> errors = new ArrayList<>();
2125
/*
2226
* Expected
2327
*/
@@ -40,5 +44,11 @@ public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
4044
errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr
4145
// is generated in where
4246
// clause
47+
48+
return errors;
49+
}
50+
51+
public static void registerExpectedExecutionErrors(ExpectedErrors errors) {
52+
errors.addAll(getExpectedExecutionErrors());
4353
}
4454
}
Lines changed: 24 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,44 @@
11
package sqlancer.datafusion.test;
22

3-
import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.ERROR;
4-
import static sqlancer.datafusion.ast.DataFusionSelect.getRandomSelect;
5-
63
import java.sql.SQLException;
7-
import java.util.List;
84

9-
import sqlancer.ComparatorHelper;
10-
import sqlancer.common.oracle.NoRECBase;
5+
import sqlancer.Reproducer;
6+
import sqlancer.common.oracle.NoRECOracle;
117
import sqlancer.common.oracle.TestOracle;
8+
import sqlancer.common.query.ExpectedErrors;
129
import sqlancer.datafusion.DataFusionErrors;
1310
import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState;
14-
import sqlancer.datafusion.DataFusionToStringVisitor;
15-
import sqlancer.datafusion.DataFusionUtil;
11+
import sqlancer.datafusion.DataFusionSchema;
12+
import sqlancer.datafusion.DataFusionSchema.DataFusionColumn;
13+
import sqlancer.datafusion.DataFusionSchema.DataFusionTable;
14+
import sqlancer.datafusion.ast.DataFusionExpression;
15+
import sqlancer.datafusion.ast.DataFusionJoin;
1616
import sqlancer.datafusion.ast.DataFusionSelect;
17+
import sqlancer.datafusion.gen.DataFusionExpressionGenerator;
1718

18-
public class DataFusionNoRECOracle extends NoRECBase<DataFusionGlobalState>
19-
implements TestOracle<DataFusionGlobalState> {
19+
public class DataFusionNoRECOracle implements TestOracle<DataFusionGlobalState> {
2020

21-
private final DataFusionGlobalState state;
21+
NoRECOracle<DataFusionSelect, DataFusionJoin, DataFusionExpression, DataFusionSchema, DataFusionTable, DataFusionColumn, DataFusionGlobalState> oracle;
2222

2323
public DataFusionNoRECOracle(DataFusionGlobalState globalState) {
24-
super(globalState);
25-
this.state = globalState;
26-
DataFusionErrors.registerExpectedExecutionErrors(errors);
24+
DataFusionExpressionGenerator gen = new DataFusionExpressionGenerator(globalState);
25+
ExpectedErrors errors = ExpectedErrors.newErrors().with(DataFusionErrors.getExpectedExecutionErrors())
26+
.with("canceling statement due to statement timeout").build();
27+
this.oracle = new NoRECOracle<>(globalState, gen, errors);
2728
}
2829

29-
/*
30-
* Non-Optimizing Reference Engine Construction q1: SELECT [expr1] FROM [expr2] WHERE [expr3] q2: SELECT [expr3]
31-
* FROM [expr2]
32-
*
33-
* Oracle Check: q1's result size equals to `true` count in q2's result set
34-
*/
3530
@Override
3631
public void check() throws SQLException {
37-
/*
38-
* Setup Q1 and Q2
39-
*/
40-
// generate a random:
41-
// SELECT [expr1] FROM [expr2] WHERE [expr3]
42-
DataFusionSelect randomSelect = getRandomSelect(state);
43-
// Q1: SELECT count(*) FROM [expr2] WHERE [expr3]
44-
DataFusionSelect q1 = new DataFusionSelect();
45-
q1.setFetchColumnsString("COUNT(*)");
46-
q1.setFromList(randomSelect.getFromList());
47-
q1.setWhereClause(randomSelect.getWhereClause());
48-
// Q2: SELECT count(case when [expr3] then 1 else null end) FROM [expr2]
49-
DataFusionSelect q2 = new DataFusionSelect();
50-
String selectExpr = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)",
51-
DataFusionToStringVisitor.asString(randomSelect.getWhereClause()));
52-
q2.setFetchColumnsString(selectExpr);
53-
q2.setFromList(randomSelect.getFromList());
54-
q2.setWhereClause(null);
55-
56-
/*
57-
* Execute Q1 and Q2
58-
*/
59-
String q1String = DataFusionToStringVisitor.asString(q1);
60-
String q2String = DataFusionToStringVisitor.asString(q2);
61-
List<String> q1ResultSet = null;
62-
List<String> q2ResultSet = null;
63-
try {
64-
q1ResultSet = ComparatorHelper.getResultSetFirstColumnAsString(q1String, errors, state);
65-
q2ResultSet = ComparatorHelper.getResultSetFirstColumnAsString(q2String, errors, state);
66-
} catch (AssertionError e) {
67-
// Append detailed error message
68-
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
69-
String newMessage = e.getMessage() + "\n" + e.getCause() + "\n" + replay + "\n";
70-
state.dfLogger.appendToLog(ERROR, newMessage);
71-
72-
throw new AssertionError(newMessage);
73-
}
74-
75-
/*
76-
* NoREC check
77-
*/
78-
int count1 = q1ResultSet != null ? Integer.parseInt(q1ResultSet.get(0)) : -1;
79-
int count2 = q2ResultSet != null ? Integer.parseInt(q2ResultSet.get(0)) : -1;
80-
if (count1 != count2) {
81-
StringBuilder errorMessage = new StringBuilder().append("NoREC oracle violated:\n")
82-
.append(" Q1(result size ").append(count1).append("):").append(q1String).append(";\n")
83-
.append(" Q2(result size ").append(count2).append("):").append(q2String).append(";\n")
84-
.append("=======================================\n").append("Reproducer: \n");
85-
86-
String replay = DataFusionUtil.getReplay(state.getDatabaseName());
32+
oracle.check();
33+
}
8734

88-
String errorLog = errorMessage.toString() + replay + "\n";
89-
String indentedErrorLog = errorLog.replaceAll("(?m)^", " ");
90-
state.dfLogger.appendToLog(ERROR, errorLog);
35+
@Override
36+
public Reproducer<DataFusionGlobalState> getLastReproducer() {
37+
return oracle.getLastReproducer();
38+
}
9139

92-
throw new AssertionError("\n\n" + indentedErrorLog);
93-
}
40+
@Override
41+
public String getLastQueryString() {
42+
return oracle.getLastQueryString();
9443
}
9544
}

0 commit comments

Comments
 (0)