Skip to content

Commit 68e6961

Browse files
committed
[MySQL] Test aggregate generation
1 parent 66e8beb commit 68e6961

6 files changed

Lines changed: 69 additions & 14 deletions

File tree

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@
327327
<version>5.11.2</version>
328328
<scope>test</scope>
329329
</dependency>
330+
<dependency>
331+
<groupId>org.junit.jupiter</groupId>
332+
<artifactId>junit-jupiter-params</artifactId>
333+
<version>5.11.2</version>
334+
<scope>test</scope>
335+
</dependency>
330336
<dependency>
331337
<groupId>org.slf4j</groupId>
332338
<artifactId> slf4j-simple</artifactId>

src/sqlancer/mysql/MySQLExpectedValueVisitor.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,6 @@ public void visit(MySQLText text) {
169169

170170
@Override
171171
public void visit(MySQLAggregate aggr) {
172-
print(aggr);
173-
visit(aggr.getExpr());
174172
}
175173

176174
}

src/sqlancer/mysql/MySQLToStringVisitor.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,14 @@ public void visit(MySQLText text) {
328328
@Override
329329
public void visit(MySQLAggregate aggr) {
330330
MySQLAggregateFunction func = aggr.getFunc();
331+
String option = aggr.getOption();
331332

332333
sb.append(func);
333334
sb.append("(");
334-
sb.append(func.getRandomOption());
335+
if (option != null) {
336+
sb.append(option);
337+
sb.append(" ");
338+
}
335339
visit(aggr.getExpr());
336340
sb.append(")");
337341
}
Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package sqlancer.mysql.ast;
22

3-
import sqlancer.Randomly;
3+
import java.util.List;
44

55
public class MySQLAggregate implements MySQLExpression {
66

@@ -14,27 +14,25 @@ public enum MySQLAggregateFunction {
1414
// See https://dev.mysql.com/doc/refman/8.4/en/aggregate-functions.html#function_max.
1515
MAX("DISTINCT");
1616

17-
private final String[] options;
17+
private final List<String> options;
1818

1919
private MySQLAggregateFunction(String... options) {
20-
this.options = options.clone();
20+
this.options = List.of(options);
2121
}
2222

23-
public String getRandomOption() {
24-
if (options.length == 0 || Randomly.getBoolean()) {
25-
return "";
26-
}
27-
28-
return Randomly.fromOptions(options);
23+
public List<String> getOptions() {
24+
return options;
2925
}
3026
}
3127

3228
private final MySQLExpression expr;
3329
private final MySQLAggregateFunction func;
30+
private final String option;
3431

35-
public MySQLAggregate(MySQLExpression expr, MySQLAggregateFunction func) {
32+
public MySQLAggregate(MySQLExpression expr, MySQLAggregateFunction func, String option) {
3633
this.expr = expr;
3734
this.func = func;
35+
this.option = option;
3836
}
3937

4038
public MySQLExpression getExpr() {
@@ -44,4 +42,8 @@ public MySQLExpression getExpr() {
4442
public MySQLAggregateFunction getFunc() {
4543
return func;
4644
}
45+
46+
public String getOption() {
47+
return option;
48+
}
4749
}

src/sqlancer/mysql/gen/MySQLExpressionGenerator.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,13 @@ public String generateExplainQuery(MySQLSelect select) {
253253
public MySQLAggregate generateAggregate() {
254254
MySQLAggregateFunction func = Randomly.fromOptions(MySQLAggregateFunction.values());
255255
MySQLExpression expr = generateExpression();
256-
return new MySQLAggregate(expr, func);
256+
257+
if (Randomly.getBoolean() && func.getOptions().size() > 0) {
258+
String option = Randomly.fromList(func.getOptions());
259+
return new MySQLAggregate(expr, func, option);
260+
} else {
261+
return new MySQLAggregate(expr, func, null);
262+
}
257263
}
258264

259265
@Override
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package sqlancer.mysql;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import org.junit.jupiter.params.ParameterizedTest;
5+
import org.junit.jupiter.params.provider.EnumSource;
6+
7+
import sqlancer.mysql.ast.MySQLAggregate;
8+
import sqlancer.mysql.ast.MySQLColumnReference;
9+
10+
public class MySQLToStringVisitorTest {
11+
12+
@ParameterizedTest
13+
@EnumSource(MySQLAggregate.MySQLAggregateFunction.class)
14+
void visitAggregateWithOptions(MySQLAggregate.MySQLAggregateFunction function) {
15+
MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a",
16+
MySQLSchema.MySQLDataType.INT, false, 0);
17+
MySQLColumnReference aRef = new MySQLColumnReference(aCol, null);
18+
19+
MySQLToStringVisitor visitor = new MySQLToStringVisitor();
20+
21+
for (String option : function.getOptions()) {
22+
visitor.visit(new MySQLAggregate(aRef, function, option));
23+
assertEquals(String.format("%s(%s a)", function, option), visitor.get());
24+
}
25+
}
26+
27+
@ParameterizedTest
28+
@EnumSource(MySQLAggregate.MySQLAggregateFunction.class)
29+
void visitAggregateWithoutOptions(MySQLAggregate.MySQLAggregateFunction function) {
30+
MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a",
31+
MySQLSchema.MySQLDataType.INT, false, 0);
32+
MySQLColumnReference aRef = new MySQLColumnReference(aCol, null);
33+
34+
MySQLToStringVisitor visitor = new MySQLToStringVisitor();
35+
36+
visitor.visit(new MySQLAggregate(aRef, function, null));
37+
assertEquals(String.format("%s(a)", function), visitor.get());
38+
}
39+
}

0 commit comments

Comments
 (0)