package sqlancer.postgres; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLIntegrityConstraintViolationException; import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import org.postgresql.util.PSQLException; import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.SQLConnection; import sqlancer.common.DBMSCommon; import sqlancer.common.schema.AbstractRelationalTable; import sqlancer.common.schema.AbstractRowValue; import sqlancer.common.schema.AbstractSchema; import sqlancer.common.schema.AbstractTableColumn; import sqlancer.common.schema.AbstractTables; import sqlancer.common.schema.TableIndex; import sqlancer.postgres.PostgresSchema.PostgresTable; import sqlancer.postgres.PostgresSchema.PostgresTable.TableType; import sqlancer.postgres.ast.PostgresConstant; public class PostgresSchema extends AbstractSchema { private final String databaseName; public enum PostgresDataType { INT, BOOLEAN, TEXT, DECIMAL, FLOAT, REAL, RANGE, MONEY, BIT, INET; public static PostgresDataType getRandomType() { List dataTypes = new ArrayList<>(Arrays.asList(values())); if (PostgresProvider.generateOnlyKnown) { dataTypes.remove(PostgresDataType.DECIMAL); dataTypes.remove(PostgresDataType.FLOAT); dataTypes.remove(PostgresDataType.REAL); dataTypes.remove(PostgresDataType.INET); dataTypes.remove(PostgresDataType.RANGE); dataTypes.remove(PostgresDataType.MONEY); dataTypes.remove(PostgresDataType.BIT); } return Randomly.fromList(dataTypes); } } public static class PostgresColumn extends AbstractTableColumn { public PostgresColumn(String name, PostgresDataType columnType) { super(name, null, columnType); } public static PostgresColumn createDummy(String name) { return new PostgresColumn(name, PostgresDataType.INT); } } public static class PostgresTables extends AbstractTables { public PostgresTables(List tables) { super(tables); } public PostgresRowValue getRandomRowValue(SQLConnection con) throws SQLException { String randomRow = String.format("SELECT %s FROM %s ORDER BY RANDOM() LIMIT 1", columnNamesAsString( c -> c.getTable().getName() + "." + c.getName() + " AS " + c.getTable().getName() + c.getName()), // columnNamesAsString(c -> "typeof(" + c.getTable().getName() + "." + // c.getName() + ")") tableNamesAsString()); Map values = new HashMap<>(); try (Statement s = con.createStatement()) { ResultSet randomRowValues = s.executeQuery(randomRow); if (!randomRowValues.next()) { throw new AssertionError("could not find random row! " + randomRow + "\n"); } for (int i = 0; i < getColumns().size(); i++) { PostgresColumn column = getColumns().get(i); int columnIndex = randomRowValues.findColumn(column.getTable().getName() + column.getName()); assert columnIndex == i + 1; PostgresConstant constant; if (randomRowValues.getString(columnIndex) == null) { constant = PostgresConstant.createNullConstant(); } else { switch (column.getType()) { case INT: constant = PostgresConstant.createIntConstant(randomRowValues.getLong(columnIndex)); break; case BOOLEAN: constant = PostgresConstant.createBooleanConstant(randomRowValues.getBoolean(columnIndex)); break; case TEXT: constant = PostgresConstant.createTextConstant(randomRowValues.getString(columnIndex)); break; default: throw new IgnoreMeException(); } } values.put(column, constant); } assert !randomRowValues.next(); return new PostgresRowValue(this, values); } catch (PSQLException e) { throw new IgnoreMeException(); } } } public static PostgresDataType getColumnType(String typeString) { switch (typeString) { case "smallint": case "integer": case "bigint": return PostgresDataType.INT; case "boolean": return PostgresDataType.BOOLEAN; case "text": case "character": case "character varying": case "name": case "regclass": return PostgresDataType.TEXT; case "numeric": return PostgresDataType.DECIMAL; case "double precision": return PostgresDataType.FLOAT; case "real": return PostgresDataType.REAL; case "int4range": return PostgresDataType.RANGE; case "money": return PostgresDataType.MONEY; case "bit": case "bit varying": return PostgresDataType.BIT; case "inet": return PostgresDataType.INET; default: throw new AssertionError(typeString); } } public static class PostgresRowValue extends AbstractRowValue { protected PostgresRowValue(PostgresTables tables, Map values) { super(tables, values); } } public static class PostgresTable extends AbstractRelationalTable { public enum TableType { STANDARD, TEMPORARY } private final TableType tableType; private final List statistics; private final boolean isInsertable; private final boolean isPartitioned; public PostgresTable(String tableName, List columns, List indexes, TableType tableType, List statistics, boolean isView, boolean isInsertable) { super(tableName, columns, indexes, isView); this.statistics = statistics; this.isInsertable = isInsertable; this.tableType = tableType; // TODO: simple adapter for other implementations this.isPartitioned = false; } public PostgresTable(String tableName, List columns, List indexes, TableType tableType, List statistics, boolean isView, boolean isInsertable, boolean isPartitioned) { super(tableName, columns, indexes, isView); this.statistics = statistics; this.isInsertable = isInsertable; this.tableType = tableType; this.isPartitioned = isPartitioned; } public List getStatistics() { return statistics; } public TableType getTableType() { return tableType; } public boolean isInsertable() { return isInsertable; } public boolean isPartitioned() { return isPartitioned; } } public static final class PostgresStatisticsObject { private final String name; public PostgresStatisticsObject(String name) { this.name = name; } public String getName() { return name; } } public static final class PostgresIndex extends TableIndex { private PostgresIndex(String indexName) { super(indexName); } public static PostgresIndex create(String indexName) { return new PostgresIndex(indexName); } @Override public String getIndexName() { if (super.getIndexName().contentEquals("PRIMARY")) { return "`PRIMARY`"; } else { return super.getIndexName(); } } } public static PostgresSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { try { List databaseTables = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery( "SELECT t.table_name, t.table_schema, t.table_type, t.is_insertable_into, c.relkind FROM information_schema.tables t JOIN pg_class c ON c.relname = t.table_name JOIN pg_namespace n ON n.oid = c.relnamespace AND n.nspname = t.table_schema WHERE t.table_schema='public' OR t.table_schema LIKE 'pg_temp_%' ORDER BY t.table_name;")) { while (rs.next()) { String tableName = rs.getString("table_name"); String tableTypeSchema = rs.getString("table_schema"); boolean isInsertable = rs.getBoolean("is_insertable_into"); boolean isPartitioned = "p".equals(rs.getString("relkind")); // TODO: also check insertable // TODO: insert into view? boolean isView = matchesViewName(tableName); // tableTypeStr.contains("VIEW") || // tableTypeStr.contains("LOCAL TEMPORARY") && // !isInsertable; PostgresTable.TableType tableType = getTableType(tableTypeSchema); List databaseColumns = getTableColumns(con, tableName); List indexes = getIndexes(con, tableName); List statistics = getStatistics(con); PostgresTable t = new PostgresTable(tableName, databaseColumns, indexes, tableType, statistics, isView, isInsertable, isPartitioned); for (PostgresColumn c : databaseColumns) { c.setTable(t); } databaseTables.add(t); } } } return new PostgresSchema(databaseTables, databaseName); } catch (SQLIntegrityConstraintViolationException e) { throw new AssertionError(e); } } protected static List getStatistics(SQLConnection con) throws SQLException { List statistics = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery("SELECT stxname FROM pg_statistic_ext ORDER BY stxname;")) { while (rs.next()) { statistics.add(new PostgresStatisticsObject(rs.getString("stxname"))); } } } return statistics; } protected static PostgresTable.TableType getTableType(String tableTypeStr) throws AssertionError { PostgresTable.TableType tableType; if (tableTypeStr.contentEquals("public")) { tableType = TableType.STANDARD; } else if (tableTypeStr.startsWith("pg_temp")) { tableType = TableType.TEMPORARY; } else { throw new AssertionError(tableTypeStr); } return tableType; } protected static List getIndexes(SQLConnection con, String tableName) throws SQLException { List indexes = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery(String .format("SELECT indexname FROM pg_indexes WHERE tablename='%s' ORDER BY indexname;", tableName))) { while (rs.next()) { String indexName = rs.getString("indexname"); if (DBMSCommon.matchesIndexName(indexName)) { indexes.add(PostgresIndex.create(indexName)); } } } } return indexes; } protected static List getTableColumns(SQLConnection con, String tableName) throws SQLException { List columns = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s .executeQuery("select column_name, data_type from INFORMATION_SCHEMA.COLUMNS where table_name = '" + tableName + "' ORDER BY column_name")) { while (rs.next()) { String columnName = rs.getString("column_name"); String dataType = rs.getString("data_type"); PostgresColumn c = new PostgresColumn(columnName, getColumnType(dataType)); columns.add(c); } } } return columns; } public PostgresSchema(List databaseTables, String databaseName) { super(databaseTables); this.databaseName = databaseName; } public PostgresTables getRandomTableNonEmptyTables() { return new PostgresTables(Randomly.nonEmptySubset(getDatabaseTables())); } public String getDatabaseName() { return databaseName; } }