Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix some invalid SQL
  • Loading branch information
kyleconroy committed Apr 4, 2024
commit f634f227ae6b8188a70ac69b6ac9a071b365f837
2 changes: 1 addition & 1 deletion internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestReplay(t *testing.T) {
}
switch c.SQL[i].Engine {
case config.EnginePostgreSQL:
uri := local.PostgreSQL(t, files)
uri := local.ReadOnlyPostgreSQL(t, files)
c.SQL[i].Database = &config.Database{
URI: uri,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CREATE TABLE foo(
CREATE TABLE foo (
bar_id text,
site_url text
);
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
CREATE TABLE foo(
CREATE TABLE foo (
bar text
);
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
CREATE SCHEMA foo;
CREATE TABLE foo.bar (id serial not null);

38 changes: 24 additions & 14 deletions internal/sqltest/local/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ import (
var flight singleflight.Group

func PostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, true)
}

func ReadOnlyPostgreSQL(t *testing.T, migrations []string) string {
return postgreSQL(t, migrations, false)
}

func postgreSQL(t *testing.T, migrations []string, rw bool) string {
ctx := context.Background()
t.Helper()

Expand Down Expand Up @@ -49,13 +57,19 @@ func PostgreSQL(t *testing.T, migrations []string) string {
seed = append(seed, migrate.RemoveRollbackStatements(string(blob)))
}

name := fmt.Sprintf("sqlc_test_%x", h.Sum(nil))
var name string
if rw {
name = fmt.Sprintf("sqlc_test_%s", id())
} else {
name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil))
}

uri, err := url.Parse(dburi)
if err != nil {
t.Fatal(err)
}
uri.Path = name
dropQuery := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name)

key := uri.String()

Expand All @@ -65,27 +79,17 @@ func PostgreSQL(t *testing.T, migrations []string) string {

var datname string
if err := row.Scan(&datname); err == nil {
fmt.Println("database already exists", name)
// Database already exists
t.Logf("database exists: %s", name)
return nil, nil
}

fmt.Println("creating database name", name)
t.Logf("creating database: %s", name)
if _, err := postgresPool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil {
return nil, err
}

dropQuery := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name)

cleanup := func() {
if _, err := postgresPool.Exec(ctx, dropQuery); err != nil {
t.Logf("failed cleaning up: %s", err)
}
}

conn, err := pgx.Connect(ctx, uri.String())
if err != nil {
cleanup()
return nil, fmt.Errorf("connect %s: %s", name, err)
}
defer conn.Close(ctx)
Expand All @@ -95,12 +99,18 @@ func PostgreSQL(t *testing.T, migrations []string) string {
continue
}
if _, err := conn.Exec(ctx, q); err != nil {
cleanup()
return nil, fmt.Errorf("%s: %s", q, err)
}
}
return nil, nil
})
if rw || err != nil {
t.Cleanup(func() {
if _, err := postgresPool.Exec(ctx, dropQuery); err != nil {
t.Fatalf("failed cleaning up: %s", err)
}
})
}
if err != nil {
t.Fatalf("create db: %s", err)
}
Expand Down
55 changes: 55 additions & 0 deletions scripts/cleanup-test-dbs/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

func main() {
if err := run(); err != nil {
log.Fatal(err)
}
}

const query = `
SELECT datname
FROM pg_database
WHERE datname LIKE 'sqlc_test_%'
`

func run() error {
ctx := context.Background()
dburi := os.Getenv("POSTGRESQL_SERVER_URI")
if dburi == "" {
return fmt.Errorf("POSTGRESQL_SERVER_URI is empty")
}
pool, err := pgxpool.New(ctx, dburi)
if err != nil {
return err
}

rows, err := pool.Query(ctx, query)
if err != nil {
return err
}

names, err := pgx.CollectRows(rows, pgx.RowTo[string])
if err != nil {
return err
}

for _, name := range names {
drop := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name)
if _, err := pool.Exec(ctx, drop); err != nil {
return err
}
log.Println("dropping database", name)
}

return nil
}