Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/hashicorp/go-plugin v1.4.3
github.com/hashicorp/go-version v1.4.0
github.com/hashicorp/hcl/v2 v2.10.1
github.com/jackc/pgx/v4 v4.15.0
github.com/jackc/pgx/v4 v4.16.0
github.com/mattn/go-isatty v0.0.14
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/rs/zerolog v1.26.1
Expand All @@ -24,7 +24,7 @@ require (
github.com/thoas/go-funk v0.9.1
github.com/vbauerster/mpb/v6 v6.0.3
github.com/zclconf/go-cty v1.9.1
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
Expand All @@ -35,14 +35,15 @@ require (
github.com/aws/aws-sdk-go v1.43.41
github.com/creasty/defaults v1.5.2
github.com/doug-martin/goqu/v9 v9.17.0
github.com/driftprogramming/pgxpoolmock v1.1.0
github.com/georgysavva/scany v0.2.9
github.com/getsentry/sentry-go v0.13.0
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.7
github.com/google/uuid v1.3.0
github.com/hairyhenderson/go-fsimpl v0.0.0-20220419174024-16654461dc34
github.com/hashicorp/go-getter v1.5.10
github.com/jackc/pgconn v1.11.0
github.com/jackc/pgconn v1.12.0
github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451
github.com/jeremywohl/flatten v1.0.1
github.com/johannesboyne/gofakes3 v0.0.0-20220314170512-33c13122505e
Expand Down Expand Up @@ -116,9 +117,9 @@ require (
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.2.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgtype v1.10.0 // indirect
github.com/jackc/pgtype v1.11.0 // indirect
github.com/jackc/puddle v1.2.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/klauspost/compress v1.15.1 // indirect
Expand Down
106 changes: 98 additions & 8 deletions go.sum

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pkg/core/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ type DialectExecutor interface {
// Setup is called on the dialect on initialization, returns the DSN (modified if necessary) to use for migrations
Setup(context.Context) (string, error)

// Identifier returns a unique identifier for the database if possible, or "", false
Identifier(context.Context) (string, bool)

// Validate is called before startup to check that the dialect can execute properly. If returns true and error is set, the error is merely logged.
Validate(context.Context) (bool, error)

Expand Down
41 changes: 41 additions & 0 deletions pkg/core/database/postgres/dbid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package postgres

import (
"context"
"os"
"testing"

"github.com/cloudquery/cq-provider-sdk/database"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
)

func TestDatabaseId(t *testing.T) {
dbUrl := os.Getenv("DATABASE_URL")
if dbUrl == "" {
dbUrl = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable"
}

pool, err := database.New(context.Background(), hclog.NewNullLogger(), dbUrl)
assert.NoError(t, err)

dbId1, err := GetDatabaseId(context.Background(), pool)
assert.NoError(t, err)
assert.NotEmpty(t, dbId1)

dbId2, err := GetDatabaseId(context.Background(), pool)
assert.NoError(t, err)
assert.Equal(t, dbId1, dbId2)

pool.Close()

// new conn
pool, err = database.New(context.Background(), hclog.NewNullLogger(), dbUrl)
assert.NoError(t, err)

dbId3, err := GetDatabaseId(context.Background(), pool)
assert.NoError(t, err)
assert.Equal(t, dbId1, dbId3)

pool.Close()
}
60 changes: 35 additions & 25 deletions pkg/core/database/postgres/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ import (
"time"

sdkpg "github.com/cloudquery/cq-provider-sdk/database/postgres"
"github.com/georgysavva/scany/pgxscan"
"github.com/hashicorp/go-version"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)

type Executor struct {
dsn string
dsn string
dbId string
}

var MinPostgresVersion = version.Must(version.NewVersion("11.0"))

func New(dsn string) Executor {
return Executor{
func New(dsn string) *Executor {
return &Executor{
dsn: dsn,
}
}
Expand All @@ -28,7 +29,7 @@ func (e Executor) Setup(ctx context.Context) (string, error) {
return e.dsn, nil
}

func (e Executor) Validate(ctx context.Context) (bool, error) {
func (e *Executor) Validate(ctx context.Context) (bool, error) {
pool, err := sdkpg.Connect(ctx, e.dsn)
if err != nil {
return false, err
Expand All @@ -42,7 +43,15 @@ func (e Executor) Validate(ctx context.Context) (bool, error) {
return true, err
}

return true, nil
e.dbId, err = GetDatabaseId(ctx, pool)
return true, err
}

func (e Executor) Identifier(_ context.Context) (string, bool) {
if e.dbId == "" {
return "", false
}
return e.dbId, true
}

func (e Executor) Prepare(_ context.Context) error {
Expand All @@ -61,24 +70,6 @@ func ValidatePostgresConnection(ctx context.Context, pool *pgxpool.Pool) error {
return pool.Ping(ctx)
}

// queryRower helps with unit tests
type queryRower interface {
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
}

func runningPostgresVersion(ctx context.Context, q queryRower) (*version.Version, error) {
row := q.QueryRow(ctx, "SELECT version()")
var result string
if err := row.Scan(&result); err != nil {
return nil, err
}
fields := strings.Fields(result)
if len(fields) < 2 {
return nil, fmt.Errorf("failed to parse version: %s", result)
}
return version.NewVersion(fields[1])
}

// ValidatePostgresVersion checks that PostgreSQL instance version available through pool is not lower than wanted version.
// In this case it returns nil. Otherwise returns error describing current and desired version or any other error encountered
// during the check.
Expand All @@ -91,7 +82,13 @@ func ValidatePostgresVersion(ctx context.Context, pool *pgxpool.Pool) error {
return doValidatePostgresVersion(ctx, conn, MinPostgresVersion)
}

func doValidatePostgresVersion(ctx context.Context, q queryRower, want *version.Version) error {
func GetDatabaseId(ctx context.Context, q pgxscan.Querier) (string, error) {
var result string
err := pgxscan.Get(ctx, q, &result, `SELECT system_identifier::varchar AS id FROM pg_control_system()`)
return result, err
}

func doValidatePostgresVersion(ctx context.Context, q pgxscan.Querier, want *version.Version) error {
got, err := runningPostgresVersion(ctx, q)
if err != nil {
return fmt.Errorf("error getting PostgreSQL version: %w", err)
Expand All @@ -101,3 +98,16 @@ func doValidatePostgresVersion(ctx context.Context, q queryRower, want *version.
}
return nil
}

func runningPostgresVersion(ctx context.Context, q pgxscan.Querier) (*version.Version, error) {
var result string
if err := pgxscan.Get(ctx, q, &result, `SELECT version()`); err != nil {
return nil, err
}

fields := strings.Fields(result)
if len(fields) < 2 {
return nil, fmt.Errorf("failed to parse version: %s", result)
}
return version.NewVersion(fields[1])
}
66 changes: 30 additions & 36 deletions pkg/core/database/postgres/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,85 +5,79 @@ import (
"errors"
"testing"

"github.com/driftprogramming/pgxpoolmock"
"github.com/golang/mock/gomock"
"github.com/hashicorp/go-version"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
)

type mockConn struct {
row pgx.Row
}

func (m mockConn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
return m.row
}

type mockScanner struct {
t *testing.T
val string
err error
}

func (m mockScanner) Scan(dst ...interface{}) error {
if len(dst) != 1 {
m.t.Fatalf("called with %d args, want exactly one", len(dst))
}
ptr, ok := dst[0].(*string)
if !ok {
m.t.Fatalf("received %T, expected *string", dst[0])
}
*ptr = m.val
return m.err
}

func Test_doValidatePostgresVersion(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

tests := []struct {
name string
q mockConn
value string
minVersion string
mockErr error
wantErr error
}{
{
"scan error",
mockConn{row: mockScanner{err: errors.New("scan")}},
"",
"10.0",
errors.New("error getting PostgreSQL version: scan"),
errors.New("scan"),
errors.New("error getting PostgreSQL version: scany: query one result row: scan"),
},
{
"strange version output",
mockConn{row: mockScanner{t, "MSSQL", nil}},
"MSSQL",
"10.0",
nil,
errors.New("error getting PostgreSQL version: failed to parse version: MSSQL"),
},
{
"unparsable version",
mockConn{row: mockScanner{t, "PostgreSQL 10.a.1", nil}},
"PostgreSQL 10.a.1",
"10.0",
nil,
errors.New("error getting PostgreSQL version: Malformed version: 10.a.1"),
},
{
"lower than needed",
mockConn{row: mockScanner{t, "PostgreSQL 9.5 blah blah", nil}},
"PostgreSQL 9.5 blah blah",
"10.0",
nil,
errors.New("unsupported PostgreSQL version: 9.5.0. (should be >= 10.0.0)"),
},
{
"equal",
mockConn{row: mockScanner{t, "PostgreSQL 10.0 blah blah", nil}},
"PostgreSQL 10.0 blah blah",
"10.0",
nil,
nil,
},
{
"greater than needed",
mockConn{row: mockScanner{t, "PostgreSQL 12.5 blah blah", nil}},
"PostgreSQL 12.5 blah blah",
"10.0",
nil,
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockPool := pgxpoolmock.NewMockPgxPool(ctrl)
pgxRows := pgxpoolmock.NewRows([]string{"value"}).AddRow(tt.value).ToPgxRows()

if tt.mockErr == nil {
mockPool.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any()).Return(pgxRows, nil)
} else {
mockPool.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, tt.mockErr)
}

want := version.Must(version.NewVersion(tt.minVersion))
err := doValidatePostgresVersion(context.Background(), tt.q, want)
err := doValidatePostgresVersion(context.Background(), mockPool, want)
if (tt.wantErr == nil) != (err == nil) {
t.Errorf("wantErr is %v, returned error is %v", tt.wantErr, err)
}
Expand Down
32 changes: 17 additions & 15 deletions pkg/core/database/timescale/timescale.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/georgysavva/scany/pgxscan"
"github.com/golang-migrate/migrate/v4"
"github.com/hashicorp/go-version"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)

Expand All @@ -23,15 +22,11 @@ const (

var MinTimescaleVersion = version.Must(version.NewVersion("2.0"))

// queryRower helps with unit tests
type queryRower interface {
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
}

type Executor struct {
dsn string
cfg *history.Config
ddl *DDLManager
dsn string
dbId string
cfg *history.Config
ddl *DDLManager
}

func New(dsn string, cfg *history.Config) (*Executor, error) {
Expand Down Expand Up @@ -63,7 +58,7 @@ func (e *Executor) Setup(ctx context.Context) (string, error) {
return history.TransformDSN(e.dsn)
}

func (e Executor) Validate(ctx context.Context) (bool, error) {
func (e *Executor) Validate(ctx context.Context) (bool, error) {
pool, err := pgsdk.Connect(ctx, e.dsn)
if err != nil {
return false, err
Expand All @@ -90,7 +85,15 @@ func (e Executor) Validate(ctx context.Context) (bool, error) {
return false, err
}

return true, nil
e.dbId, err = postgres.GetDatabaseId(ctx, pool)
return true, err
}

func (e Executor) Identifier(_ context.Context) (string, bool) {
if e.dbId == "" {
return "", false
}
return e.dbId, true
}

func (e Executor) Prepare(ctx context.Context) error {
Expand All @@ -111,10 +114,9 @@ func (e Executor) Finalize(ctx context.Context, retErr error) error {
return retErr // keep migrate.ErrNoChange
}

func runningTimescaleVersion(ctx context.Context, q queryRower) (*version.Version, error) {
row := q.QueryRow(ctx, timescaleVersionQuery)
func runningTimescaleVersion(ctx context.Context, q pgxscan.Querier) (*version.Version, error) {
var result string
if err := row.Scan(&result); err != nil {
if err := pgxscan.Get(ctx, q, &result, timescaleVersionQuery); err != nil {
return nil, err
}
fields := strings.Fields(result)
Expand All @@ -136,7 +138,7 @@ func ValidateTimescaleVersion(ctx context.Context, pool *pgxpool.Pool) error {
return doValidateTimescaleVersion(ctx, conn, MinTimescaleVersion)
}

func doValidateTimescaleVersion(ctx context.Context, q queryRower, want *version.Version) error {
func doValidateTimescaleVersion(ctx context.Context, q pgxscan.Querier, want *version.Version) error {
got, err := runningTimescaleVersion(ctx, q)
if err != nil {
return fmt.Errorf("error getting Timescale version: %w", err)
Expand Down
Loading