Skip to content

Commit ab066ad

Browse files
committed
Change setFromTables to setFromList for SQLite and MariaDB
1 parent 91a3a9b commit ab066ad

File tree

13 files changed

+69
-25
lines changed

13 files changed

+69
-25
lines changed

src/sqlancer/common/ast/newast/Select.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
public interface Select<J extends Join<E, T, C>, E extends Expression<C>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>>
99
extends Expression<C> {
1010

11-
void setFromTables(List<E> fromTables);
12-
1311
List<E> getFromList();
1412

1513
void setFromList(List<E> fromList);

src/sqlancer/mariadb/ast/MariaDBSelectStatement.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import java.util.List;
55

66
import sqlancer.common.ast.SelectBase;
7-
import sqlancer.mariadb.MariaDBSchema.MariaDBTable;
87

98
public class MariaDBSelectStatement extends SelectBase<MariaDBExpression> implements MariaDBExpression {
109

@@ -14,7 +13,6 @@ public enum MariaDBSelectType {
1413

1514
private List<MariaDBExpression> groupBys = new ArrayList<>();
1615
private List<MariaDBExpression> columns = new ArrayList<>();
17-
private List<MariaDBTable> tables = new ArrayList<>();
1816
private MariaDBSelectType selectType = MariaDBSelectType.ALL;
1917
private MariaDBExpression whereCondition;
2018

@@ -28,10 +26,6 @@ public void setFetchColumns(List<MariaDBExpression> columns) {
2826

2927
}
3028

31-
public void setFromTables(List<MariaDBTable> tables) {
32-
this.tables = tables;
33-
}
34-
3529
public void setSelectType(MariaDBSelectType selectType) {
3630
this.selectType = selectType;
3731
}
@@ -53,10 +47,6 @@ public MariaDBSelectType getSelectType() {
5347
return selectType;
5448
}
5549

56-
public List<MariaDBTable> getTables() {
57-
return tables;
58-
}
59-
6050
public MariaDBExpression getWhereCondition() {
6151
return whereCondition;
6252
}

src/sqlancer/mariadb/ast/MariaDBStringVisitor.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package sqlancer.mariadb.ast;
22

33
import java.util.List;
4-
import java.util.stream.Collectors;
54

65
public class MariaDBStringVisitor extends MariaDBVisitor {
76

@@ -41,8 +40,13 @@ public void visit(MariaDBSelectStatement s) {
4140
visit(column);
4241
}
4342
sb.append(" FROM ");
44-
sb.append(s.getTables().stream().map(t -> t.getName()).collect(Collectors.joining(", ")));
4543

44+
for (int j = 0; j < s.getFromList().size(); j++) {
45+
if (j != 0) {
46+
sb.append(", ");
47+
}
48+
visit(s.getFromList().get(j));
49+
}
4650
for (MariaDBExpression j : s.getJoinList()) {
4751
visit(j);
4852
}
@@ -167,4 +171,9 @@ public void visit(MariaDBJoin join) {
167171
visit(join.getOnClause());
168172
}
169173
}
174+
175+
@Override
176+
public void visit(MariaDBTableReference ref) {
177+
sb.append(ref.getTable().getName());
178+
}
170179
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sqlancer.mariadb.ast;
2+
3+
import sqlancer.mariadb.MariaDBSchema.MariaDBTable;
4+
5+
public class MariaDBTableReference implements MariaDBExpression {
6+
7+
private final MariaDBTable table;
8+
9+
public MariaDBTableReference(MariaDBTable table) {
10+
this.table = table;
11+
}
12+
13+
public MariaDBTable getTable() {
14+
return table;
15+
}
16+
}

src/sqlancer/mariadb/ast/MariaDBVisitor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ public abstract class MariaDBVisitor {
2424

2525
public abstract void visit(MariaDBJoin join);
2626

27+
public abstract void visit(MariaDBTableReference join);
28+
2729
public void visit(MariaDBExpression expr) {
2830
if (expr instanceof MariaDBConstant) {
2931
visit((MariaDBConstant) expr);
@@ -47,6 +49,8 @@ public void visit(MariaDBExpression expr) {
4749
visit((MariaDBInOperation) expr);
4850
} else if (expr instanceof MariaDBJoin) {
4951
visit((MariaDBJoin) expr);
52+
} else if (expr instanceof MariaDBTableReference) {
53+
visit((MariaDBTableReference) expr);
5054
} else {
5155
throw new AssertionError(expr.getClass());
5256
}

src/sqlancer/mariadb/oracle/MariaDBDQPOracle.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import sqlancer.mariadb.ast.MariaDBExpression;
1818
import sqlancer.mariadb.ast.MariaDBJoin;
1919
import sqlancer.mariadb.ast.MariaDBSelectStatement;
20+
import sqlancer.mariadb.ast.MariaDBTableReference;
2021
import sqlancer.mariadb.ast.MariaDBVisitor;
2122
import sqlancer.mariadb.gen.MariaDBExpressionGenerator;
2223
import sqlancer.mariadb.gen.MariaDBSetGenerator;
@@ -60,7 +61,8 @@ public void check() throws Exception {
6061
select.setJoinList(joinExpressions.stream().map(j -> (MariaDBExpression) j).collect(Collectors.toList()));
6162

6263
// Set the from clause from the tables that are not used in the join.
63-
select.setFromTables(tables.getTables());
64+
select.setFromList(
65+
tables.getTables().stream().map(t -> new MariaDBTableReference(t)).collect(Collectors.toList()));
6466

6567
// Get the result of the first query
6668
String originalQueryString = MariaDBVisitor.asString(select);

src/sqlancer/mariadb/oracle/MariaDBNoRECOracle.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation.MariaDBPostfixUnaryOperator;
2424
import sqlancer.mariadb.ast.MariaDBSelectStatement;
2525
import sqlancer.mariadb.ast.MariaDBSelectStatement.MariaDBSelectType;
26+
import sqlancer.mariadb.ast.MariaDBTableReference;
2627
import sqlancer.mariadb.ast.MariaDBText;
2728
import sqlancer.mariadb.ast.MariaDBVisitor;
2829
import sqlancer.mariadb.gen.MariaDBExpressionGenerator;
@@ -76,7 +77,7 @@ private int getUnoptimizedQuery(MariaDBTable randomTable, MariaDBExpression rand
7677
randomWhereCondition);
7778
MariaDBText asText = new MariaDBText(isTrue, " as count", false);
7879
select.setFetchColumns(Arrays.asList(asText));
79-
select.setFromTables(Arrays.asList(randomTable));
80+
select.setFromList(Arrays.asList(new MariaDBTableReference(randomTable)));
8081
select.setSelectType(MariaDBSelectType.ALL);
8182
int secondCount = 0;
8283

@@ -103,7 +104,7 @@ private int getOptimizedQuery(MariaDBTable randomTable, MariaDBExpression random
103104
new MariaDBColumnName(new MariaDBColumn("*", MariaDBDataType.INT, false, 0)),
104105
MariaDBAggregateFunction.COUNT);
105106
select.setFetchColumns(Arrays.asList(aggr));
106-
select.setFromTables(Arrays.asList(randomTable));
107+
select.setFromList(Arrays.asList(new MariaDBTableReference(randomTable)));
107108
select.setWhereClause(randomWhereCondition);
108109
select.setSelectType(MariaDBSelectType.ALL);
109110
int firstCount;

src/sqlancer/sqlite3/ast/SQLite3Select.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ public void setSelectType(SelectType fromOptions) {
4646
this.setFromOptions(fromOptions);
4747
}
4848

49-
public void setFromTables(List<SQLite3Expression> fromTables) {
50-
this.setFromList(fromTables);
51-
}
52-
5349
public SelectType getFromOptions() {
5450
return fromOptions;
5551
}

src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public void check() throws SQLException {
8080
List<Join> joinStatements = gen.getRandomJoinClauses(tables);
8181
List<SQLite3Expression> tableRefs = SQLite3Common.getTableRefs(tables, s);
8282
SQLite3Select select = new SQLite3Select();
83-
select.setFromTables(tableRefs);
83+
select.setFromList(tableRefs);
8484
select.setJoinClauses(joinStatements);
8585

8686
Function<SQLite3GlobalState, Integer> optimizedQuery = getOptimizedQuery(select, randomWhereCondition);

src/sqlancer/sqlite3/oracle/SQLite3PivotedQuerySynthesisOracle.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public SQLite3Select getQuery() throws SQLException {
7575
.filter(c -> !SQLite3Schema.ROWID_STRINGS.contains(c.getName())).collect(Collectors.toList());
7676
List<Join> joinStatements = getJoinStatements(globalState, tables, columnsWithoutRowid);
7777
selectStatement.setJoinClauses(joinStatements);
78-
selectStatement.setFromTables(SQLite3Common.getTableRefs(tables, globalState.getSchema()));
78+
selectStatement.setFromList(SQLite3Common.getTableRefs(tables, globalState.getSchema()));
7979

8080
fetchColumns = Randomly.nonEmptySubset(columnsWithoutRowid);
8181
List<SQLite3Table> allTables = new ArrayList<>();

0 commit comments

Comments
 (0)