Skip to content

Commit 38bf640

Browse files
authored
Query Plan Guidance Framework (#661)
* QPG framework * QPG for SQLite * QPG for CockroachDB * QPG for TiDB
1 parent fb93140 commit 38bf640

File tree

8 files changed

+430
-5
lines changed

8 files changed

+430
-5
lines changed

src/sqlancer/DatabaseProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ public interface DatabaseProvider<G extends GlobalState<O, ?, C>, O extends DBMS
3232
*/
3333
Reproducer<G> generateAndTestDatabase(G globalState) throws Exception;
3434

35+
/**
36+
* The experimental feature: Query Plan Guidance.
37+
*
38+
* @param globalState
39+
* the state created and is valid for this method call.
40+
*
41+
* @throws Exception
42+
* if testing fails.
43+
*
44+
*/
45+
void generateAndTestDatabaseWithQueryPlanGuidance(G globalState) throws Exception;
46+
3547
C createDatabase(G globalState) throws Exception;
3648

3749
/**

src/sqlancer/Main.java

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,13 @@ public static final class StateLogger {
5151

5252
private final File loggerFile;
5353
private File curFile;
54+
private File queryPlanFile;
5455
private FileWriter logFileWriter;
5556
public FileWriter currentFileWriter;
57+
private FileWriter queryPlanFileWriter;
5658
private static final List<String> INITIALIZED_PROVIDER_NAMES = new ArrayList<>();
5759
private final boolean logEachSelect;
60+
private final boolean logQueryPlan;
5861
private final DatabaseProvider<?, ?, ?> databaseProvider;
5962

6063
private static final class AlsoWriteToConsoleFileWriter extends FileWriter {
@@ -87,6 +90,10 @@ public StateLogger(String databaseName, DatabaseProvider<?, ?, ?> provider, Main
8790
if (logEachSelect) {
8891
curFile = new File(dir, databaseName + "-cur.log");
8992
}
93+
logQueryPlan = options.logQueryPlan();
94+
if (logQueryPlan) {
95+
queryPlanFile = new File(dir, databaseName + "-plan.log");
96+
}
9097
this.databaseProvider = provider;
9198
}
9299

@@ -138,6 +145,20 @@ public FileWriter getCurrentFileWriter() {
138145
return currentFileWriter;
139146
}
140147

148+
public FileWriter getQueryPlanFileWriter() {
149+
if (!logQueryPlan) {
150+
throw new UnsupportedOperationException();
151+
}
152+
if (queryPlanFileWriter == null) {
153+
try {
154+
queryPlanFileWriter = new FileWriter(queryPlanFile, true);
155+
} catch (IOException e) {
156+
throw new AssertionError(e);
157+
}
158+
}
159+
return queryPlanFileWriter;
160+
}
161+
141162
public void writeCurrent(StateToReproduce state) {
142163
if (!logEachSelect) {
143164
throw new UnsupportedOperationException();
@@ -172,6 +193,18 @@ private void write(Loggable loggable) {
172193
}
173194
}
174195

196+
public void writeQueryPlan(String queryPlan) {
197+
if (!logQueryPlan) {
198+
throw new UnsupportedOperationException();
199+
}
200+
try {
201+
getQueryPlanFileWriter().append(removeNamesFromQueryPlans(queryPlan));
202+
queryPlanFileWriter.flush();
203+
} catch (IOException e) {
204+
throw new AssertionError();
205+
}
206+
}
207+
175208
public void logException(Throwable reduce, StateToReproduce state) {
176209
Loggable stackTrace = getStackTrace(reduce);
177210
FileWriter logFileWriter2 = getLogFileWriter();
@@ -201,8 +234,7 @@ private void printState(FileWriter writer, StateToReproduce state) {
201234
.getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString());
202235

203236
for (Query<?> s : state.getStatements()) {
204-
sb.append(s.getLogString());
205-
sb.append('\n');
237+
sb.append(databaseProvider.getLoggableFactory().createLoggable(s.getLogString()).getLogString());
206238
}
207239
try {
208240
writer.write(sb.toString());
@@ -211,6 +243,13 @@ private void printState(FileWriter writer, StateToReproduce state) {
211243
}
212244
}
213245

246+
private String removeNamesFromQueryPlans(String queryPlan) {
247+
String result = queryPlan;
248+
result = result.replaceAll("t[0-9]+", "t0"); // Avoid duplicate tables
249+
result = result.replaceAll("v[0-9]+", "v0"); // Avoid duplicate views
250+
result = result.replaceAll("i[0-9]+", "i0"); // Avoid duplicate indexes
251+
return result + "\n";
252+
}
214253
}
215254

216255
public static class QueryManager<C extends SQLancerDBConnection> {
@@ -243,6 +282,10 @@ public void incrementSelectQueryCount() {
243282
Main.nrQueries.addAndGet(1);
244283
}
245284

285+
public Long getSelectQueryCount() {
286+
return Main.nrQueries.get();
287+
}
288+
246289
public void incrementCreateDatabase() {
247290
Main.nrDatabases.addAndGet(1);
248291
}
@@ -314,7 +357,12 @@ public void run() throws Exception {
314357
if (options.logEachSelect()) {
315358
logger.writeCurrent(state.getState());
316359
}
317-
Reproducer<G> reproducer = provider.generateAndTestDatabase(state);
360+
Reproducer<G> reproducer = null;
361+
if (options.enableQPG()) {
362+
provider.generateAndTestDatabaseWithQueryPlanGuidance(state);
363+
} else {
364+
reproducer = provider.generateAndTestDatabase(state);
365+
}
318366
try {
319367
logger.getCurrentFileWriter().close();
320368
logger.currentFileWriter = null;

src/sqlancer/MainOptions.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ public class MainOptions {
5050
@Parameter(names = "--print-failed", description = "Logs failed insert, create and other statements without results", arity = 1)
5151
private boolean loggerPrintFailed = true; // NOPMD
5252

53+
@Parameter(names = "--qpg-enable", description = "Enable the experimental feature Query Plan Guidance (QPG)", arity = 1)
54+
private boolean enableQPG;
55+
56+
@Parameter(names = "--qpg-log-query-plan", description = "Logs the query plans of each query (requires --qpg-enable)", arity = 1)
57+
private boolean logQueryPlan;
58+
59+
@Parameter(names = "--qpg-max-interval", description = "The maximum number of iterations to mutate tables if no new query plans (requires --qpg-enable)")
60+
private static int qpgMaxInterval = 1000;
61+
62+
@Parameter(names = "--qpg-reward-weight", description = "The weight (0-1) of last reward when updating weighted average reward. A higher value denotes average reward is more affected by the last reward (requires --qpg-enable)")
63+
private static double qpgk = 0.25;
64+
65+
@Parameter(names = "--qpg-selection-probability", description = "The probability (0-1) of the random selection of mutators. A higher value (>0.5) favors exploration over exploitation. (requires --qpg-enable)")
66+
private static double qpgProbability = 0.7;
67+
5368
@Parameter(names = "--username", description = "The user name used to log into the DBMS")
5469
private String userName = "sqlancer"; // NOPMD
5570

@@ -151,6 +166,26 @@ public boolean loggerPrintFailed() {
151166
return loggerPrintFailed;
152167
}
153168

169+
public boolean logQueryPlan() {
170+
return logQueryPlan;
171+
}
172+
173+
public boolean enableQPG() {
174+
return enableQPG;
175+
}
176+
177+
public int getQPGMaxMutationInterval() {
178+
return qpgMaxInterval;
179+
}
180+
181+
public double getQPGk() {
182+
return qpgk;
183+
}
184+
185+
public double getQPGProbability() {
186+
return qpgProbability;
187+
}
188+
154189
public int getNrQueries() {
155190
return nrQueries;
156191
}

src/sqlancer/ProviderAdapter.java

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package sqlancer;
22

3+
import java.sql.SQLException;
4+
import java.util.HashMap;
5+
import java.util.Iterator;
36
import java.util.List;
7+
import java.util.Map;
48
import java.util.stream.Collectors;
59

610
import sqlancer.StateToReproduce.OracleRunReproductionState;
11+
import sqlancer.common.DBMSCommon;
712
import sqlancer.common.oracle.CompositeTestOracle;
813
import sqlancer.common.oracle.TestOracle;
914
import sqlancer.common.schema.AbstractSchema;
@@ -14,6 +19,13 @@ public abstract class ProviderAdapter<G extends GlobalState<O, ? extends Abstrac
1419
private final Class<G> globalClass;
1520
private final Class<O> optionClass;
1621

22+
// Variables for QPG
23+
Map<String, String> queryPlanPool = new HashMap<>();
24+
static double[] weightedAverageReward; // static variable for sharing across all threads
25+
int currentSelectRewards;
26+
int currentSelectCounts;
27+
int currentMutationOperator = -1;
28+
1729
public ProviderAdapter(Class<G> globalClass, Class<O> optionClass) {
1830
this.globalClass = globalClass;
1931
this.optionClass = optionClass;
@@ -67,7 +79,7 @@ public Reproducer<G> generateAndTestDatabase(G globalState) throws Exception {
6779
return null;
6880
}
6981

70-
protected abstract void checkViewsAreValid(G globalState);
82+
protected abstract void checkViewsAreValid(G globalState) throws SQLException;
7183

7284
protected TestOracle<G> getTestOracle(G globalState) throws Exception {
7385
List<? extends OracleFactory<G>> testOracleFactory = globalState.getDbmsSpecificOptions()
@@ -77,7 +89,11 @@ protected TestOracle<G> getTestOracle(G globalState) throws Exception {
7789
boolean userRequiresMoreThanZeroRows = globalState.getOptions().testOnlyWithMoreThanZeroRows();
7890
boolean checkZeroRows = testOracleRequiresMoreThanZeroRows || userRequiresMoreThanZeroRows;
7991
if (checkZeroRows && globalState.getSchema().containsTableWithZeroRows(globalState)) {
80-
throw new IgnoreMeException();
92+
if (globalState.getOptions().enableQPG()) {
93+
addRowsToAllTables(globalState);
94+
} else {
95+
throw new IgnoreMeException();
96+
}
8197
}
8298
if (testOracleFactory.size() == 1) {
8399
return testOracleFactory.get(0).create(globalState);
@@ -94,4 +110,152 @@ protected TestOracle<G> getTestOracle(G globalState) throws Exception {
94110

95111
public abstract void generateDatabase(G globalState) throws Exception;
96112

113+
// QPG: entry function
114+
@Override
115+
public void generateAndTestDatabaseWithQueryPlanGuidance(G globalState) throws Exception {
116+
if (weightedAverageReward == null) {
117+
weightedAverageReward = initializeWeightedAverageReward(); // Same length as the list of mutators
118+
}
119+
try {
120+
generateDatabase(globalState);
121+
checkViewsAreValid(globalState);
122+
globalState.getManager().incrementCreateDatabase();
123+
124+
Long executedQueryCount = 0L;
125+
while (executedQueryCount < globalState.getOptions().getNrQueries()) {
126+
int numOfNoNewQueryPlans = 0;
127+
TestOracle<G> oracle = getTestOracle(globalState);
128+
while (true) {
129+
try (OracleRunReproductionState localState = globalState.getState().createLocalState()) {
130+
assert localState != null;
131+
try {
132+
oracle.check();
133+
String query = oracle.getLastQueryString();
134+
executedQueryCount += 1;
135+
if (addQueryPlan(query, globalState)) {
136+
numOfNoNewQueryPlans = 0;
137+
} else {
138+
numOfNoNewQueryPlans++;
139+
}
140+
globalState.getManager().incrementSelectQueryCount();
141+
} catch (IgnoreMeException e) {
142+
143+
}
144+
assert localState != null;
145+
localState.executedWithoutError();
146+
}
147+
// exit loop to mutate tables if no new query plans have been found after a while
148+
if (numOfNoNewQueryPlans > globalState.getOptions().getQPGMaxMutationInterval()) {
149+
mutateTables(globalState);
150+
break;
151+
}
152+
}
153+
}
154+
} finally {
155+
globalState.getConnection().close();
156+
}
157+
}
158+
159+
// QPG: mutate tables for a new database state
160+
private synchronized boolean mutateTables(G globalState) throws Exception {
161+
// Update rewards based on a set of newly generated queries in last iteration
162+
if (currentMutationOperator != -1) {
163+
weightedAverageReward[currentMutationOperator] += ((double) currentSelectRewards
164+
/ (double) currentSelectCounts) * globalState.getOptions().getQPGk();
165+
}
166+
currentMutationOperator = -1;
167+
168+
// Choose mutator based on the rewards
169+
int selectedActionIndex = 0;
170+
if (Randomly.getPercentage() < globalState.getOptions().getQPGProbability()) {
171+
selectedActionIndex = globalState.getRandomly().getInteger(0, weightedAverageReward.length);
172+
} else {
173+
selectedActionIndex = DBMSCommon.getMaxIndexInDoubleArrary(weightedAverageReward);
174+
}
175+
int reward = 0;
176+
177+
try {
178+
executeMutator(selectedActionIndex, globalState);
179+
checkViewsAreValid(globalState); // Remove the invalid views
180+
reward = checkQueryPlan(globalState);
181+
} catch (IgnoreMeException | AssertionError e) {
182+
} finally {
183+
// Update rewards based on existing queries associated with the query plan pool
184+
updateReward(selectedActionIndex, (double) reward / (double) queryPlanPool.size(), globalState);
185+
currentMutationOperator = selectedActionIndex;
186+
}
187+
188+
// Clear the variables for storing the rewards of the action on a set of newly generated queries
189+
currentSelectRewards = 0;
190+
currentSelectCounts = 0;
191+
return true;
192+
}
193+
194+
// QPG: add a query plan to the query plan pool and return true if the query plan is new
195+
private boolean addQueryPlan(String selectStr, G globalState) throws Exception {
196+
String queryPlan = getQueryPlan(selectStr, globalState);
197+
198+
if (globalState.getOptions().logQueryPlan()) {
199+
globalState.getLogger().writeQueryPlan(queryPlan);
200+
}
201+
202+
currentSelectCounts += 1;
203+
if (queryPlanPool.containsKey(queryPlan)) {
204+
return false;
205+
} else {
206+
queryPlanPool.put(queryPlan, selectStr);
207+
currentSelectRewards += 1;
208+
return true;
209+
}
210+
}
211+
212+
// Obtain the reward of the current action based on the queries associated with the query plan pool
213+
private int checkQueryPlan(G globalState) throws Exception {
214+
int newQueryPlanFound = 0;
215+
HashMap<String, String> modifiedQueryPlan = new HashMap<>();
216+
for (Iterator<Map.Entry<String, String>> it = queryPlanPool.entrySet().iterator(); it.hasNext();) {
217+
Map.Entry<String, String> item = it.next();
218+
String queryPlan = item.getKey();
219+
String selectStr = item.getValue();
220+
String newQueryPlan = getQueryPlan(selectStr, globalState);
221+
if (newQueryPlan.isEmpty()) { // Invalid query
222+
it.remove();
223+
} else if (!queryPlan.equals(newQueryPlan)) { // A query plan has been changed
224+
it.remove();
225+
modifiedQueryPlan.put(newQueryPlan, selectStr);
226+
if (!queryPlanPool.containsKey(newQueryPlan)) { // A new query plan is found
227+
newQueryPlanFound++;
228+
}
229+
}
230+
}
231+
queryPlanPool.putAll(modifiedQueryPlan);
232+
return newQueryPlanFound;
233+
}
234+
235+
// QPG: update the reward of current action
236+
private void updateReward(int actionIndex, double reward, G globalState) {
237+
weightedAverageReward[actionIndex] += (reward - weightedAverageReward[actionIndex])
238+
* globalState.getOptions().getQPGk();
239+
}
240+
241+
// QPG: initialize the weighted average reward of all mutation operators (required implementation in specific DBMS)
242+
protected double[] initializeWeightedAverageReward() {
243+
throw new UnsupportedOperationException();
244+
}
245+
246+
// QPG: obtain the query plan of a query (required implementation in specific DBMS)
247+
protected String getQueryPlan(String selectStr, G globalState) throws Exception {
248+
throw new UnsupportedOperationException();
249+
}
250+
251+
// QPG: execute a mutation operator (required implementation in specific DBMS)
252+
protected void executeMutator(int index, G globalState) throws Exception {
253+
throw new UnsupportedOperationException();
254+
}
255+
256+
// QPG: add rows to all tables (required implementation in specific DBMS when enabling PQS oracle for QPG)
257+
protected boolean addRowsToAllTables(G globalState) throws Exception {
258+
throw new UnsupportedOperationException();
259+
}
260+
97261
}

0 commit comments

Comments
 (0)