diff --git a/go.mod b/go.mod index 73ec0596f43ad0..8287caa97a7c79 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.17 require ( github.com/VividCortex/ewma v1.2.0 // indirect github.com/aws/aws-lambda-go v1.23.0 - github.com/cloudquery/cq-provider-sdk v0.6.1 + github.com/cloudquery/cq-provider-sdk v0.7.0-alpha2 github.com/fatih/color v1.13.0 github.com/fsnotify/fsnotify v1.4.9 github.com/golang-migrate/migrate/v4 v4.15.0 @@ -42,7 +42,6 @@ require ( github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/hashicorp/go-getter v1.5.10 - github.com/huandu/go-sqlbuilder v1.13.0 github.com/jackc/pgconn v1.10.0 github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451 github.com/jackc/pgtype v1.8.1 @@ -82,7 +81,6 @@ require ( github.com/hashicorp/go-safetemp v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/yamux v0.0.0-20210826001029-26ff87cf9493 // indirect - github.com/huandu/xstrings v1.3.2 // indirect github.com/iancoleman/strcase v0.2.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect diff --git a/go.sum b/go.sum index c7955b295d0222..bacc88c831fda6 100644 --- a/go.sum +++ b/go.sum @@ -206,8 +206,8 @@ github.com/cilium/ebpf v0.4.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJ github.com/cilium/ebpf v0.6.2/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= -github.com/cloudquery/cq-provider-sdk v0.6.1 h1:pyHabGR81AdsnwtZF0oaJhF9VPK5tHiC8DukA5JEW/A= -github.com/cloudquery/cq-provider-sdk v0.6.1/go.mod h1:lLjzStk8uqMiunTDnAp26QXyQ3XAMexOqzuo8T2riMc= +github.com/cloudquery/cq-provider-sdk v0.7.0-alpha2 h1:GY0NJLEYf5JSHluVJsdAfFN00ygX5A+HZHw6/LDif5Q= +github.com/cloudquery/cq-provider-sdk v0.7.0-alpha2/go.mod h1:T+ngRXzcjJ6otKDGkWnPrHTsZuHUe3KZKtyhSLcvHCs= github.com/cloudquery/faker/v3 v3.7.4 h1:cCcU3r0yHpS0gqKj9rRKAGS0/hY33fBxbqCNFtDD4ec= github.com/cloudquery/faker/v3 v3.7.4/go.mod h1:1b8WVG9Gh0T2hVo1a8dWeXfu0AhqSB6J/mmJaesqOeo= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -257,8 +257,8 @@ github.com/containerd/containerd v1.5.0-beta.1/go.mod h1:5HfvG1V2FsKesEGQ17k5/T7 github.com/containerd/containerd v1.5.0-beta.3/go.mod h1:/wr9AVtEM7x9c+n0+stptlo/uBBoBORwEx6ardVcmKU= github.com/containerd/containerd v1.5.0-beta.4/go.mod h1:GmdgZd2zA2GYIBZ0w09ZvgqEq8EfBp/m3lcVZIvPHhI= github.com/containerd/containerd v1.5.0-rc.0/go.mod h1:V/IXoMqNGgBlabz3tHD2TWDoTJseu1FGOKuoA4nNb2s= -github.com/containerd/containerd v1.5.8 h1:NmkCC1/QxyZFBny8JogwLpOy2f+VEbO/f6bV2Mqtwuw= -github.com/containerd/containerd v1.5.8/go.mod h1:YdFSv5bTFLpG2HIYmfqDpSYYTDX+mc5qtSuYx1YUb/s= +github.com/containerd/containerd v1.5.9 h1:rs6Xg1gtIxaeyG+Smsb/0xaSDu1VgFhOCKBXxMxbsF4= +github.com/containerd/containerd v1.5.9/go.mod h1:fvQqCfadDGga5HZyn3j4+dx56qj2I9YwBrlSdalvJYQ= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/containerd/continuity v0.0.0-20190815185530-f2a389ac0a02/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/containerd/continuity v0.0.0-20191127005431-f65d91d395eb/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= @@ -666,12 +666,6 @@ github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKe github.com/hashicorp/yamux v0.0.0-20210826001029-26ff87cf9493 h1:brI5vBRUlAlM34VFmnLPwjnCL/FxAJp9XvOdX6Zt+XE= github.com/hashicorp/yamux v0.0.0-20210826001029-26ff87cf9493/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= -github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= -github.com/huandu/go-sqlbuilder v1.13.0 h1:IN1VRzcyQ+Kx74L0g5ZAY5qDaRJjwMWVmb6GrFAF8Jc= -github.com/huandu/go-sqlbuilder v1.13.0/go.mod h1:LILlbQo0MOYjlIiGgOSR3UcWQpd5Y/oZ7HLNGyAUz0E= -github.com/huandu/xstrings v1.3.2 h1:L18LIDzqlW6xN2rEkpdV8+oL/IXWJ1APd+vsdYy4Wdw= -github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHLwW0= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/internal/test/provider/migrations/1_v0.0.1.down.sql b/internal/test/provider/migrations/1_v0.0.1.down.sql deleted file mode 100644 index 0ea1d04617f1b7..00000000000000 --- a/internal/test/provider/migrations/1_v0.0.1.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE "slow_resource" - DROP COLUMN IF EXISTS upgrade_column \ No newline at end of file diff --git a/internal/test/provider/migrations/1_v0.0.1.up.sql b/internal/test/provider/migrations/1_v0.0.1.up.sql deleted file mode 100644 index 5b37beac44692a..00000000000000 --- a/internal/test/provider/migrations/1_v0.0.1.up.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE "slow_resource" - ADD COLUMN IF NOT EXISTS upgrade_column integer \ No newline at end of file diff --git a/internal/test/provider/migrations/2_v0.0.2.down.sql b/internal/test/provider/migrations/2_v0.0.2.down.sql deleted file mode 100644 index 5db8e2799a8f66..00000000000000 --- a/internal/test/provider/migrations/2_v0.0.2.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE "slow_resource" - DROP COLUMN IF EXISTS upgrade_column_2 \ No newline at end of file diff --git a/internal/test/provider/migrations/2_v0.0.2.up.sql b/internal/test/provider/migrations/2_v0.0.2.up.sql deleted file mode 100644 index 39954932fa4cc8..00000000000000 --- a/internal/test/provider/migrations/2_v0.0.2.up.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE "slow_resource" - ADD COLUMN IF NOT EXISTS upgrade_column_2 integer \ No newline at end of file diff --git a/internal/test/provider/migrations/postgres/1_v0.0.1.down.sql b/internal/test/provider/migrations/postgres/1_v0.0.1.down.sql new file mode 100644 index 00000000000000..19db8ca6e8d45d --- /dev/null +++ b/internal/test/provider/migrations/postgres/1_v0.0.1.down.sql @@ -0,0 +1,8 @@ +-- Resource: error_resource +DROP TABLE IF EXISTS error_resource; + +-- Resource: slow_resource +DROP TABLE IF EXISTS slow_resource; + +-- Resource: very_slow_resource +DROP TABLE IF EXISTS very_slow_resource; diff --git a/internal/test/provider/migrations/postgres/1_v0.0.1.up.sql b/internal/test/provider/migrations/postgres/1_v0.0.1.up.sql new file mode 100644 index 00000000000000..c8b26fb0669960 --- /dev/null +++ b/internal/test/provider/migrations/postgres/1_v0.0.1.up.sql @@ -0,0 +1,24 @@ +-- Resource: error_resource +CREATE TABLE IF NOT EXISTS "error_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + CONSTRAINT error_resource_pk PRIMARY KEY(cq_id), + UNIQUE(cq_id) +); + +-- Resource: slow_resource +CREATE TABLE IF NOT EXISTS "slow_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + "some_bool" boolean, + CONSTRAINT slow_resource_pk PRIMARY KEY(cq_id), + UNIQUE(cq_id) +); + +-- Resource: very_slow_resource +CREATE TABLE IF NOT EXISTS "very_slow_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + CONSTRAINT very_slow_resource_pk PRIMARY KEY(cq_id), + UNIQUE(cq_id) +); diff --git a/internal/test/provider/migrations/postgres/2_v0.0.2.down.sql b/internal/test/provider/migrations/postgres/2_v0.0.2.down.sql new file mode 100644 index 00000000000000..683e0a60a9eaef --- /dev/null +++ b/internal/test/provider/migrations/postgres/2_v0.0.2.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + DROP COLUMN IF EXISTS upgrade_column; \ No newline at end of file diff --git a/internal/test/provider/migrations/postgres/2_v0.0.2.up.sql b/internal/test/provider/migrations/postgres/2_v0.0.2.up.sql new file mode 100644 index 00000000000000..a513aac13c44b5 --- /dev/null +++ b/internal/test/provider/migrations/postgres/2_v0.0.2.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + ADD COLUMN IF NOT EXISTS upgrade_column integer; \ No newline at end of file diff --git a/internal/test/provider/migrations/postgres/3_v0.0.3.down.sql b/internal/test/provider/migrations/postgres/3_v0.0.3.down.sql new file mode 100644 index 00000000000000..60217bb38fe84a --- /dev/null +++ b/internal/test/provider/migrations/postgres/3_v0.0.3.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + DROP COLUMN IF EXISTS upgrade_column_2; \ No newline at end of file diff --git a/internal/test/provider/migrations/postgres/3_v0.0.3.up.sql b/internal/test/provider/migrations/postgres/3_v0.0.3.up.sql new file mode 100644 index 00000000000000..157c03fd8aebef --- /dev/null +++ b/internal/test/provider/migrations/postgres/3_v0.0.3.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + ADD COLUMN IF NOT EXISTS upgrade_column_2 integer; \ No newline at end of file diff --git a/internal/test/provider/migrations/timescale/1_v0.0.1.down.sql b/internal/test/provider/migrations/timescale/1_v0.0.1.down.sql new file mode 100644 index 00000000000000..19db8ca6e8d45d --- /dev/null +++ b/internal/test/provider/migrations/timescale/1_v0.0.1.down.sql @@ -0,0 +1,8 @@ +-- Resource: error_resource +DROP TABLE IF EXISTS error_resource; + +-- Resource: slow_resource +DROP TABLE IF EXISTS slow_resource; + +-- Resource: very_slow_resource +DROP TABLE IF EXISTS very_slow_resource; diff --git a/internal/test/provider/migrations/timescale/1_v0.0.1.up.sql b/internal/test/provider/migrations/timescale/1_v0.0.1.up.sql new file mode 100644 index 00000000000000..951a2e57f2e657 --- /dev/null +++ b/internal/test/provider/migrations/timescale/1_v0.0.1.up.sql @@ -0,0 +1,27 @@ +-- Resource: error_resource +CREATE TABLE IF NOT EXISTS "error_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + "cq_fetch_date" timestamp without time zone NOT NULL, + CONSTRAINT error_resource_pk PRIMARY KEY(cq_fetch_date,cq_id), + UNIQUE(cq_fetch_date,cq_id) +); + +-- Resource: slow_resource +CREATE TABLE IF NOT EXISTS "slow_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + "cq_fetch_date" timestamp without time zone NOT NULL, + "some_bool" boolean, + CONSTRAINT slow_resource_pk PRIMARY KEY(cq_fetch_date,cq_id), + UNIQUE(cq_fetch_date,cq_id) +); + +-- Resource: very_slow_resource +CREATE TABLE IF NOT EXISTS "very_slow_resource" ( + "cq_id" uuid NOT NULL, + "cq_meta" jsonb, + "cq_fetch_date" timestamp without time zone NOT NULL, + CONSTRAINT very_slow_resource_pk PRIMARY KEY(cq_fetch_date,cq_id), + UNIQUE(cq_fetch_date,cq_id) +); diff --git a/internal/test/provider/migrations/timescale/2_v0.0.2.down.sql b/internal/test/provider/migrations/timescale/2_v0.0.2.down.sql new file mode 100644 index 00000000000000..683e0a60a9eaef --- /dev/null +++ b/internal/test/provider/migrations/timescale/2_v0.0.2.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + DROP COLUMN IF EXISTS upgrade_column; \ No newline at end of file diff --git a/internal/test/provider/migrations/timescale/2_v0.0.2.up.sql b/internal/test/provider/migrations/timescale/2_v0.0.2.up.sql new file mode 100644 index 00000000000000..a513aac13c44b5 --- /dev/null +++ b/internal/test/provider/migrations/timescale/2_v0.0.2.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + ADD COLUMN IF NOT EXISTS upgrade_column integer; \ No newline at end of file diff --git a/internal/test/provider/migrations/timescale/3_v0.0.3.down.sql b/internal/test/provider/migrations/timescale/3_v0.0.3.down.sql new file mode 100644 index 00000000000000..60217bb38fe84a --- /dev/null +++ b/internal/test/provider/migrations/timescale/3_v0.0.3.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + DROP COLUMN IF EXISTS upgrade_column_2; \ No newline at end of file diff --git a/internal/test/provider/migrations/timescale/3_v0.0.3.up.sql b/internal/test/provider/migrations/timescale/3_v0.0.3.up.sql new file mode 100644 index 00000000000000..157c03fd8aebef --- /dev/null +++ b/internal/test/provider/migrations/timescale/3_v0.0.3.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "slow_resource" + ADD COLUMN IF NOT EXISTS upgrade_column_2 integer; \ No newline at end of file diff --git a/internal/test/test_history_config.hcl b/internal/test/test_history_config.hcl index 2df1292da7679e..ff1b67e91ca36d 100644 --- a/internal/test/test_history_config.hcl +++ b/internal/test/test_history_config.hcl @@ -1,7 +1,7 @@ cloudquery { connection { - dsn = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" + dsn = "tsdb://postgres:pass@localhost:5432/postgres?sslmode=disable" } provider "test" { source = "cloudquery" diff --git a/internal/test/tools/migrations.go b/internal/test/tools/migrations.go new file mode 100644 index 00000000000000..9d06f1744fc80a --- /dev/null +++ b/internal/test/tools/migrations.go @@ -0,0 +1,17 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/cloudquery/cloudquery/internal/test/provider" + "github.com/cloudquery/cq-provider-sdk/migration" +) + +func main() { + if err := migration.Run(context.Background(), provider.Provider(), "internal/test/provider/migrations"); err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err.Error()) + os.Exit(1) + } +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 7113e63fef8766..dc6dcad06a71f4 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "io/fs" "path/filepath" "sort" "strconv" @@ -14,6 +15,8 @@ import ( "github.com/cloudquery/cloudquery/internal/logging" "github.com/cloudquery/cloudquery/internal/telemetry" + "github.com/cloudquery/cloudquery/pkg/client/database" + "github.com/cloudquery/cloudquery/pkg/client/database/timescale" "github.com/cloudquery/cloudquery/pkg/client/history" "github.com/cloudquery/cloudquery/pkg/config" "github.com/cloudquery/cloudquery/pkg/module" @@ -23,8 +26,10 @@ import ( "github.com/cloudquery/cloudquery/pkg/policy" "github.com/cloudquery/cloudquery/pkg/ui" "github.com/cloudquery/cq-provider-sdk/cqproto" - "github.com/cloudquery/cq-provider-sdk/helpers" - "github.com/cloudquery/cq-provider-sdk/provider" + sdkdb "github.com/cloudquery/cq-provider-sdk/database" + "github.com/cloudquery/cq-provider-sdk/database/dsn" + "github.com/cloudquery/cq-provider-sdk/migration" + "github.com/cloudquery/cq-provider-sdk/migration/migrator" "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/cloudquery/cq-provider-sdk/provider/schema/diag" "github.com/getsentry/sentry-go" @@ -33,7 +38,6 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-version" "github.com/hashicorp/hcl/v2" - "github.com/jackc/pgx/v4/pgxpool" zerolog "github.com/rs/zerolog/log" "go.opentelemetry.io/otel/attribute" otrace "go.opentelemetry.io/otel/trace" @@ -44,21 +48,17 @@ import ( var ( ErrMigrationsNotSupported = errors.New("provider doesn't support migrations") - //go:embed migrations/*.sql + //go:embed migrations/*/*.sql coreMigrations embed.FS ) -const ( - latestVersion = "latest" -) - // FetchRequest is provided to the Client to execute a fetch on one or more providers type FetchRequest struct { // UpdateCallback allows gets called when the client receives updates on fetch. UpdateCallback FetchUpdateCallback // Providers list of providers to call for fetching Providers []*config.Provider - // Optional: Adds extra fields to the provider, this is used for testing purposes. + // Optional: Adds extra fields to the provider, this is used for history mode and testing purposes. ExtraFields map[string]interface{} } @@ -201,11 +201,7 @@ type FetchDoneResult struct { // TableCreator creates tables based on schema received from providers type TableCreator interface { - CreateTable(ctx context.Context, conn *pgxpool.Conn, t *schema.Table, p *schema.Table) error -} - -type TableRemover interface { - DropTable(ctx context.Context, conn *pgxpool.Conn, t *schema.Table) error + CreateTable(context.Context, schema.QueryExecer, *schema.Table, *schema.Table) error } type FetchUpdateCallback func(update FetchUpdate) @@ -244,12 +240,13 @@ type Client struct { ModuleManager module.Manager // ModuleManager manages all modules lifecycle PolicyManager policy.Manager - // TableCreator defines how table are created in the database + // TableCreator defines how tables are created in the database, only for plugin protocol < 4 TableCreator TableCreator // HistoryConfig defines configuration for CloudQuery history mode HistoryCfg *history.Config - // pool is a list of connection that are used for policy/query execution - pool *pgxpool.Pool + + db *sdkdb.DB + dialectExecutor database.DialectExecutor } func New(ctx context.Context, options ...Option) (*Client, error) { @@ -276,29 +273,13 @@ func New(ctx context.Context, options ...Option) (*Client, error) { if c.DSN == "" { c.Logger.Warn("missing DSN, some commands won't work") - } else { - c.pool, err = CreateDatabase(ctx, c.DSN) - if err != nil { - return nil, err - } - if err := ValidatePostgresVersion(ctx, c.pool, MinPostgresVersion); err != nil { - c.Logger.Warn("postgres validation warning", "err", err) - } - } - // migrate cloudquery core tables to latest version - if c.DSN != "" { - if err := c.MigrateCore(ctx); err != nil { - return nil, fmt.Errorf("failed to migrate cloudquery_core tables: %w", err) - } - } - - if err := c.setupTableCreator(ctx); err != nil { + } else if err := c.initDatabase(ctx); err != nil { return nil, err } c.initModules() - c.PolicyManager = policy.NewManager(c.PolicyDirectory, c.pool, c.Logger) + c.PolicyManager = policy.NewManager(c.PolicyDirectory, c.db, c.Logger) return c, nil } @@ -403,13 +384,19 @@ func (c *Client) Fetch(ctx context.Context, request FetchRequest) (res *FetchRes c.Logger.Info("received fetch request", "extra_fields", request.ExtraFields, "history_enabled", c.HistoryCfg != nil) - searchPath := "" + var dsnURI string if c.HistoryCfg != nil { - searchPath = "history" - } - dsn, err := parseDSN(c.DSN, searchPath) - if err != nil { - return nil, err + var err error + dsnURI, err = history.TransformDSN(c.DSN) + if err != nil { + return nil, err + } + } else { + parsed, err := dsn.ParseConnectionString(c.DSN) + if err != nil { + return nil, err + } + dsnURI = parsed.String() } fetchSummaries := make(chan ProviderFetchSummary, len(request.Providers)) @@ -440,7 +427,7 @@ func (c *Client) Fetch(ctx context.Context, request FetchRequest) (res *FetchRes } defer func() { - if err := SaveFetchSummary(ctx, c.pool, &fs); err != nil { + if err := c.SaveFetchSummary(ctx, &fs); err != nil { c.Logger.Error("failed to save fetch summary", "err", err) } }() @@ -448,20 +435,20 @@ func (c *Client) Fetch(ctx context.Context, request FetchRequest) (res *FetchRes pLog := c.Logger.With("provider", providerConfig.Name, "alias", providerConfig.Alias, "version", providerPlugin.Version()) pLog.Info("requesting provider to configure") if c.HistoryCfg != nil { - pLog.Info("history enabled adding fetch date", "fetch_date", c.HistoryCfg.FetchDate().Format(time.RFC3339)) + fd := c.HistoryCfg.FetchDate() + pLog.Info("history enabled adding fetch date", "fetch_date", fd.Format(time.RFC3339)) if request.ExtraFields == nil { request.ExtraFields = make(map[string]interface{}) } - request.ExtraFields["cq_fetch_date"] = c.HistoryCfg.FetchDate() + request.ExtraFields["cq_fetch_date"] = fd } _, err = providerPlugin.Provider().ConfigureProvider(ctx, &cqproto.ConfigureProviderRequest{ CloudQueryVersion: Version, Connection: cqproto.ConnectionDetails{ - DSN: dsn, + DSN: dsnURI, }, - Config: providerConfig.Configuration, - DisableDelete: true, - ExtraFields: request.ExtraFields, + Config: providerConfig.Configuration, + ExtraFields: request.ExtraFields, }) if err != nil { pLog.Error("failed to configure provider", "error", err) @@ -581,7 +568,13 @@ func (c *Client) Fetch(ctx context.Context, request FetchRequest) (res *FetchRes return response, nil } -func (c *Client) GetProviderSchema(ctx context.Context, providerName string) (*cqproto.GetProviderSchemaResponse, error) { +type ProviderSchema struct { + *cqproto.GetProviderSchemaResponse + + ProtocolVersion int +} + +func (c *Client) GetProviderSchema(ctx context.Context, providerName string) (*ProviderSchema, error) { providerPlugin, err := c.Manager.CreatePlugin(providerName, "", nil) if err != nil { c.Logger.Error("failed to create provider plugin", "provider", providerName, "error", err) @@ -596,7 +589,16 @@ func (c *Client) GetProviderSchema(ctx context.Context, providerName string) (*c c.Logger.Warn("failed to kill provider", "provider", providerName) } }() - return providerPlugin.Provider().GetProviderSchema(ctx, &cqproto.GetProviderSchemaRequest{}) + + schema, err := providerPlugin.Provider().GetProviderSchema(ctx, &cqproto.GetProviderSchemaRequest{}) + if err != nil { + return nil, err + } + + return &ProviderSchema{ + GetProviderSchemaResponse: schema, + ProtocolVersion: providerPlugin.ProtocolVersion(), + }, nil } func (c *Client) GetProviderConfiguration(ctx context.Context, providerName string) (*cqproto.GetProviderConfigResponse, error) { @@ -627,24 +629,30 @@ func (c *Client) BuildProviderTables(ctx context.Context, providerName string) ( if err != nil { return err } - conn, err := c.pool.Acquire(ctx) - if err != nil { - return err - } - defer conn.Release() - for name, t := range s.ResourceTables { - c.Logger.Debug("creating tables for resource for provider", "resource_name", name, "provider", s.Name, "version", s.Version) - if err := c.TableCreator.CreateTable(ctx, conn, t, nil); err != nil { - return err - } - } if s.Migrations == nil { - c.Logger.Debug("provider doesn't support migrations", "provider", providerName) + // Keep the table creator if we don't have any migrations defined for this provider and hope that it works + for name, t := range s.ResourceTables { + c.Logger.Debug("creating tables for resource for provider", "resource_name", name, "provider", s.Name, "version", s.Version) + if err := c.TableCreator.CreateTable(ctx, c.db, t, nil); err != nil { + return fmt.Errorf("CreateTable(%s) failed: %w", t.Name, err) + } + } + return nil } + + defer func() { + if retErr == nil || !errors.Is(retErr, fs.ErrNotExist) { + return + } + + c.Logger.Error("BuildProviderTables failed", "error", retErr) + retErr = fmt.Errorf("Incompatible provider schema: Please drop provider tables and recreate, alternatively execute `cq provider drop %s`", providerName) + }() + // create migration table and set it to version based on latest create table - m, cfg, err := c.buildProviderMigrator(s.Migrations, providerName) + m, cfg, err := c.buildProviderMigrator(ctx, s.Migrations, providerName) if err != nil { return err } @@ -653,13 +661,10 @@ func (c *Client) BuildProviderTables(ctx context.Context, providerName string) ( c.Logger.Error("failed to close migrator connection", "error", err) } }() - if _, _, err := m.Version(); err == migrate.ErrNilVersion { - mv, err := m.FindLatestMigration(cfg.Version) - if err != nil { - return err - } - c.Logger.Debug("setting provider schema migration version", "version", cfg.Version, "migration_version", mv) - return m.SetVersion(cfg.Version) + + c.Logger.Debug("setting provider schema migration version", "version", cfg.Version) + if err := m.UpgradeProvider(cfg.Version); err != nil && err != migrate.ErrNoChange { + return err } return nil } @@ -684,7 +689,7 @@ func (c *Client) UpgradeProvider(ctx context.Context, providerName string) (retE if s.Migrations == nil { return ErrMigrationsNotSupported } - m, cfg, err := c.buildProviderMigrator(s.Migrations, providerName) + m, cfg, err := c.buildProviderMigrator(ctx, s.Migrations, providerName) if err != nil { return err } @@ -725,7 +730,7 @@ func (c *Client) DowngradeProvider(ctx context.Context, providerName string) (re if s.Migrations == nil { return fmt.Errorf("provider doesn't support migrations") } - m, cfg, err := c.buildProviderMigrator(s.Migrations, providerName) + m, cfg, err := c.buildProviderMigrator(ctx, s.Migrations, providerName) if err != nil { return err } @@ -748,7 +753,7 @@ func (c *Client) DropProvider(ctx context.Context, providerName string) (retErr if err != nil { return err } - m, cfg, err := c.buildProviderMigrator(s.Migrations, providerName) + m, cfg, err := c.buildProviderMigrator(ctx, s.Migrations, providerName) if err != nil { return err } @@ -874,8 +879,8 @@ func (c *Client) ExecuteModule(ctx context.Context, req ModuleRunRequest) (res * func (c *Client) Close() { c.Manager.Shutdown() - if c.pool != nil { - c.pool.Close() + if c.db != nil { + c.db.Close() } } @@ -887,7 +892,7 @@ func (c *Client) SetProviderVersion(ctx context.Context, providerName, version s if s.Migrations == nil { return fmt.Errorf("provider doesn't support migrations") } - m, cfg, err := c.buildProviderMigrator(s.Migrations, providerName) + m, cfg, err := c.buildProviderMigrator(ctx, s.Migrations, providerName) if err != nil { return err } @@ -896,11 +901,11 @@ func (c *Client) SetProviderVersion(ctx context.Context, providerName, version s } func (c *Client) initModules() { - c.ModuleManager = module.NewManager(c.pool, c.Logger) + c.ModuleManager = module.NewManager(c.db, c.Logger) c.ModuleManager.RegisterModule(drift.New(c.Logger)) } -func (c *Client) buildProviderMigrator(migrations map[string][]byte, providerName string) (*provider.Migrator, *config.RequiredProvider, error) { +func (c *Client) buildProviderMigrator(ctx context.Context, migrations map[string]map[string][]byte, providerName string) (*migrator.Migrator, *config.RequiredProvider, error) { providerConfig, err := c.getProviderConfig(providerName) if err != nil { return nil, nil, err @@ -910,32 +915,38 @@ func (c *Client) buildProviderMigrator(migrations map[string][]byte, providerNam return nil, nil, err } - dsn := c.DSN - if c.HistoryCfg != nil { - dsn, err = parseDSN(c.DSN, "history") - if err != nil { - return nil, nil, err - } + dsn, err := c.dialectExecutor.Setup(ctx) + if err != nil { + return nil, nil, fmt.Errorf("dialectExecutor.Setup: %w", err) } - m, err := provider.NewMigrator(c.Logger, migrations, dsn, fmt.Sprintf("%s_%s", org, name)) + m, err := migrator.New(c.Logger, c.db.DialectType(), migrations, dsn, fmt.Sprintf("%s_%s", org, name), c.dialectExecutor.Finalize) if err != nil { return nil, nil, err } return m, providerConfig, err } -func (c *Client) MigrateCore(ctx context.Context) error { - err := createCoreSchema(ctx, c.pool) +func (c *Client) MigrateCore(ctx context.Context, de database.DialectExecutor) error { + err := createCoreSchema(ctx, c.db) + if err != nil { + return err + } + + newDSN, err := de.Setup(ctx) + if err != nil { + return err + } + + migrations, err := migrator.ReadMigrationFiles(c.Logger, coreMigrations) if err != nil { return err } - migrations, err := provider.ReadMigrationFiles(c.Logger, coreMigrations) + newDSN, err = dsn.SetDSNElement(newDSN, map[string]string{"search_path": "cloudquery"}) if err != nil { return err } - dsn := c.DSN + "&search_path=cloudquery" - m, err := provider.NewMigrator(c.Logger, migrations, dsn, "cloudquery_core") + m, err := migrator.New(c.Logger, schema.Postgres, migrations, newDSN, "cloudquery_core", nil) if err != nil { return err } @@ -946,7 +957,7 @@ func (c *Client) MigrateCore(ctx context.Context) error { } }() - if err := m.UpgradeProvider(latestVersion); err != nil && err != migrate.ErrNoChange { + if err := m.UpgradeProvider(migrator.Latest); err != nil && err != migrate.ErrNoChange { return fmt.Errorf("failed to migrate cloudquery core schema: %w", err) } return nil @@ -966,44 +977,6 @@ func (c *Client) getProviderConfig(providerName string) (*config.RequiredProvide return providerConfig, nil } -func (c *Client) setupTableCreator(ctx context.Context) error { - if c.TableCreator != nil { - c.Logger.Debug("table creator already set") - return nil - } - if c.HistoryCfg == nil { - c.Logger.Debug("using default table creator without history mode enabled.") - c.TableCreator = provider.NewTableCreator(c.Logger) - return nil - } - creator, err := history.NewHistoryTableCreator(c.HistoryCfg, c.Logger) - if err != nil { - return err - } - // set history table creator - c.TableCreator = creator - conn, err := c.pool.Acquire(ctx) - if err != nil { - return fmt.Errorf("failed to acquire connection for history setup: %w", err) - } - defer conn.Release() - return history.SetupHistory(ctx, conn) -} - -func parseDSN(dsn, searchPath string) (string, error) { - url, err := helpers.ParseConnectionString(dsn) - if err != nil { - return "", err - } - if searchPath == "" { - return url.String(), nil - } - if url.RawQuery != "" { - return url.String() + "&search_path=history", nil - } - return url.String() + "search_path=history", nil -} - func parsePartialFetchKV(r *cqproto.FailedResourceFetch) []interface{} { kv := []interface{}{"table", r.TableName, "err", r.Error} if r.RootTableName != "" { @@ -1106,16 +1079,52 @@ func reportFetchSummaryErrors(span otrace.Span, fetchSummaries map[string]Provid ) } -func createCoreSchema(ctx context.Context, pool *pgxpool.Pool) error { - conn, err := pool.Acquire(ctx) +func createCoreSchema(ctx context.Context, db schema.QueryExecer) error { + return db.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS cloudquery") +} + +func (c *Client) initDatabase(ctx context.Context) error { + var err error + c.db, err = sdkdb.New(ctx, c.Logger, c.DSN) if err != nil { return err } - defer conn.Release() - _, err = conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS cloudquery") + var dt schema.DialectType + dt, c.dialectExecutor, err = database.GetExecutor(c.Logger, c.DSN, c.HistoryCfg) + if err != nil { + return fmt.Errorf("getExecutor: %w", err) + } + + if c.HistoryCfg != nil && dt != schema.TSDB { + // check if we're already on TSDB but the dsn is wrong + ts, err := timescale.New(c.Logger, c.DSN, c.HistoryCfg) + if err != nil { + return err + } + if ok, err := ts.Validate(ctx); ok && err == nil { + return fmt.Errorf("you must update the dsn to use tsdb:// prefix") + } + + return fmt.Errorf("history is only supported on timescaledb") + } + + if ok, err := c.dialectExecutor.Validate(ctx); err != nil { + return fmt.Errorf("validate: %w", err) + } else if !ok { + c.Logger.Warn("postgres validation warning") + } + + // migrate cloudquery core tables to latest version + if err := c.MigrateCore(ctx, c.dialectExecutor); err != nil { + return fmt.Errorf("failed to migrate cloudquery_core tables: %w", err) + } + + dialect, err := schema.GetDialect(c.db.DialectType()) if err != nil { return err } + c.TableCreator = migration.NewTableCreator(c.Logger, dialect) + return nil } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 968d07351b76e5..20d4e4c0c98923 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -3,10 +3,12 @@ package client import ( "context" "errors" + "math/rand" "net" "os" "path/filepath" "reflect" + "strconv" "strings" "testing" "time" @@ -25,24 +27,66 @@ import ( "github.com/stretchr/testify/require" ) -var ( - providerSrc = "cloudquery" - requiredTestProviders = []*config.RequiredProvider{ +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func requiredTestProviders() []*config.RequiredProvider { + providerSrc := "cloudquery" + return []*config.RequiredProvider{ { Name: "test", Source: &providerSrc, Version: "latest", }, } -) +} + +func setupDB(t *testing.T) (dsn string) { + baseDSN := os.Getenv("CQ_CLIENT_TEST_DSN") + if baseDSN == "" { + baseDSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" + } + + conn, err := pgx.Connect(context.Background(), baseDSN) + if err != nil { + assert.FailNow(t, "failed to create connection") + return + } + + newDB := "test_" + strconv.Itoa(rand.Int()) + + _, err = conn.Exec(context.Background(), "CREATE DATABASE "+newDB) + assert.NoError(t, err) + + t.Cleanup(func() { + defer conn.Close(context.Background()) + + if os.Getenv("CQ_TEST_DEBUG") != "" && t.Failed() { + t.Log("Not dropping database", newDB) + return + } + + if _, err := conn.Exec(context.Background(), "DROP DATABASE "+newDB+" WITH(FORCE)"); err != nil { + t.Logf("teardown: drop database failed: %v", err) + } + }) + + return strings.Replace(baseDSN, "/postgres?", "/"+newDB+"?", 1) +} func TestClient_FailOnFetchWithPartialFetch(t *testing.T) { ctx := context.Background() + + dbDSN := setupDB(t) + c, err := New(ctx, func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) + t.Cleanup(c.Close) + // download test provider if it doesn't already exist err = c.DownloadProviders(ctx) assert.Nil(t, err) @@ -69,11 +113,15 @@ func TestClient_FailOnFetchWithPartialFetch(t *testing.T) { func TestClient_FailOnFetch(t *testing.T) { ctx := context.Background() + dbDSN := setupDB(t) + c, err := New(ctx, func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) + t.Cleanup(c.Close) + // download test provider if it doesn't already exist err = c.DownloadProviders(ctx) assert.Nil(t, err) @@ -99,11 +147,15 @@ func TestClient_FailOnFetch(t *testing.T) { func TestClient_PartialFetch(t *testing.T) { ctx := context.Background() + dbDSN := setupDB(t) + c, err := New(ctx, func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) + t.Cleanup(c.Close) + // download test provider if it doesn't already exist err = c.DownloadProviders(ctx) assert.Nil(t, err) @@ -128,12 +180,15 @@ func TestClient_PartialFetch(t *testing.T) { func TestClient_TestNoDownload(t *testing.T) { _ = os.RemoveAll(".cq/downloadTest") + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = setupDB(t) + options.Providers = requiredTestProviders() options.PluginDirectory = ".cq/downloadTest" }) assert.Nil(t, err) + t.Cleanup(c.Close) + _, err = c.Manager.GetPluginDetails("test") assert.Error(t, err) @@ -147,12 +202,14 @@ func TestClient_TestNoDownload(t *testing.T) { pd, err := c.Manager.GetPluginDetails("test") assert.Nil(t, err) assert.Equal(t, "test", pd.Name) + c, err = New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = setupDB(t) + options.Providers = requiredTestProviders() options.PluginDirectory = ".cq/downloadTest" }) assert.Nil(t, err) + t.Cleanup(c.Close) pd2, err := c.Manager.GetPluginDetails("test") assert.Nil(t, err) assert.Equal(t, pd2.FilePath, pd.FilePath) @@ -164,16 +221,20 @@ func TestClient_TestNoDownload(t *testing.T) { func TestClient_FetchTimeout(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") } + t.Cleanup(c.Close) assert.Nil(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() _, err = c.Fetch(ctx, FetchRequest{ Providers: []*config.Provider{ @@ -192,18 +253,24 @@ func TestClient_FetchTimeout(t *testing.T) { func TestClient_FetchNilConfig(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() - cfg, diags := config.NewParser().LoadConfigFromSource("config.hcl", []byte(testConfig)) + + dbDSN := setupDB(t) + + testCfg := []byte(strings.Replace(testConfig, "DSN_PLACEHOLDER", `"`+dbDSN+`"`, 1)) + + cfg, diags := config.NewParser().LoadConfigFromSource("config.hcl", testCfg) assert.Nil(t, diags) // Set configuration to nil cfg.Providers[0].Configuration = nil c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") } + t.Cleanup(c.Close) ctx := context.Background() _, err = c.Fetch(ctx, FetchRequest{ Providers: []*config.Provider{ @@ -220,15 +287,18 @@ func TestClient_FetchNilConfig(t *testing.T) { func TestClient_Fetch(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "host=localhost user=postgres password=pass database=postgres port=5432 sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") } - assert.Nil(t, err) + t.Cleanup(c.Close) ctx := context.Background() _, err = c.Fetch(ctx, FetchRequest{ @@ -245,15 +315,19 @@ func TestClient_Fetch(t *testing.T) { func TestClient_GetProviderSchema(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "host=localhost user=postgres password=pass database=postgres port=5432 sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") return } + t.Cleanup(c.Close) ctx := context.Background() s, err := c.GetProviderSchema(ctx, "test") if s == nil { @@ -269,15 +343,18 @@ func TestClient_GetProviderConfig(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") return } + t.Cleanup(c.Close) ctx := context.Background() pConfig, err := c.GetProviderConfiguration(ctx, "test") @@ -295,14 +372,17 @@ func TestClient_ProviderUpgradeNoBuild(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.NoError(t, err) if c == nil { assert.FailNow(t, "failed to create client") } + t.Cleanup(c.Close) ctx := context.Background() err = c.DropProvider(ctx, "test") assert.NoError(t, err) @@ -314,14 +394,17 @@ func TestClient_ProviderMigrations(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.NoError(t, err) if c == nil { assert.FailNow(t, "failed to create client") } + t.Cleanup(c.Close) ctx := context.Background() err = c.DropProvider(ctx, "test") assert.NoError(t, err) @@ -330,7 +413,7 @@ func TestClient_ProviderMigrations(t *testing.T) { err = c.UpgradeProvider(ctx, "test") assert.ErrorIs(t, err, migrate.ErrNoChange) - conn, err := pgx.Connect(ctx, "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable") + conn, err := pgx.Connect(ctx, dbDSN) if err != nil { assert.FailNow(t, "failed to create connection") return @@ -341,7 +424,7 @@ func TestClient_ProviderMigrations(t *testing.T) { c.Providers[0].Version = "v0.0.1" err = c.DowngradeProvider(ctx, "test") assert.NoError(t, err) - _, err = conn.Exec(ctx, "select some_bool, upgrade_column from slow_resource") + _, err = conn.Exec(ctx, "select some_bool from slow_resource") assert.NoError(t, err) _, err = conn.Exec(ctx, "select some_bool, upgrade_column, upgrade_column_2 from slow_resource") assert.Error(t, err) @@ -349,7 +432,7 @@ func TestClient_ProviderMigrations(t *testing.T) { c.Providers[0].Version = "v0.0.2" err = c.UpgradeProvider(ctx, "test") assert.NoError(t, err) - _, err = conn.Exec(ctx, "select some_bool, upgrade_column, upgrade_column_2 from slow_resource") + _, err = conn.Exec(ctx, "select some_bool, upgrade_column from slow_resource") assert.NoError(t, err) } @@ -358,14 +441,17 @@ func TestClient_ProviderSkipVersionMigrations(t *testing.T) { cancelServe := setupTestPlugin(t) defer cancelServe() + dbDSN := setupDB(t) + c, err := New(context.Background(), func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - options.Providers = requiredTestProviders + options.DSN = dbDSN + options.Providers = requiredTestProviders() }) assert.Nil(t, err) if c == nil { assert.FailNow(t, "failed to create client") } + t.Cleanup(c.Close) ctx := context.Background() err = c.DropProvider(ctx, "test") assert.Nil(t, err) @@ -374,7 +460,7 @@ func TestClient_ProviderSkipVersionMigrations(t *testing.T) { err = c.UpgradeProvider(ctx, "test") assert.ErrorIs(t, err, migrate.ErrNoChange) - conn, err := pgx.Connect(ctx, "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable") + conn, err := pgx.Connect(ctx, dbDSN) if err != nil { assert.FailNow(t, "failed to create connection") return @@ -385,37 +471,40 @@ func TestClient_ProviderSkipVersionMigrations(t *testing.T) { c.Providers[0].Version = "v0.0.1" err = c.DowngradeProvider(ctx, "test") assert.Nil(t, err) - _, err = conn.Exec(ctx, "select some_bool, upgrade_column from slow_resource") + _, err = conn.Exec(ctx, "select some_bool from slow_resource") assert.Nil(t, err) _, err = conn.Exec(ctx, "select some_bool, upgrade_column, upgrade_column_2 from slow_resource") assert.Error(t, err) c.Providers[0].Version = "v0.0.5" - // latest migration should be to v0.0.2 + // latest migration should be to v0.0.3 err = c.UpgradeProvider(ctx, "test") assert.Nil(t, err) _, err = conn.Exec(ctx, "select some_bool, upgrade_column, upgrade_column_2 from slow_resource") assert.Nil(t, err) // insert dummy migration files like test provider just for version number return - m, _, err := c.buildProviderMigrator(map[string][]byte{ - "1_v0.0.1.up.sql": []byte(""), - "1_v0.0.1.down.sql": []byte(""), - "2_v0.0.2.up.sql": []byte(""), - "2_v0.0.2.down.sql": []byte(""), + m, _, err := c.buildProviderMigrator(ctx, map[string]map[string][]byte{ + "postgres": { + "1_v0.0.1.up.sql": []byte(""), + "1_v0.0.1.down.sql": []byte(""), + "2_v0.0.2.up.sql": []byte(""), + "2_v0.0.2.down.sql": []byte(""), + "3_v0.0.3.up.sql": []byte(""), + "3_v0.0.3.down.sql": []byte(""), + }, }, "test") - if err != nil { - t.Fatal(err) - } - // migrations should be in 2 i.e v0.0.2 + assert.NoError(t, err) + + // migrations should be in 3 i.e v0.0.3 v, dirty, err := m.Version() - assert.Equal(t, []interface{}{"v0.0.2", false, nil}, []interface{}{v, dirty, err}) + assert.Equal(t, []interface{}{"v0.0.3", false, nil}, []interface{}{v, dirty, err}) } const testConfig = `cloudquery { connection { - dsn = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" + dsn = DSN_PLACEHOLDER } provider "test" { source = "cloudquery" @@ -655,7 +744,7 @@ func Test_CheckForProviderUpdates(t *testing.T) { { Name: "test", Source: &source, - Version: "v0.0.9", + Version: "v0.0.11", }, }, 0, @@ -675,11 +764,13 @@ func Test_CheckForProviderUpdates(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() + c, err := New(ctx, func(options *Client) { - options.DSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" + options.DSN = setupDB(t) options.Providers = tt.providers }) assert.Nil(t, err) + t.Cleanup(c.Close) providers, err := c.CheckForProviderUpdates(ctx) assert.Nil(t, err) assert.Len(t, providers, tt.updates) diff --git a/pkg/client/database/database.go b/pkg/client/database/database.go new file mode 100644 index 00000000000000..29bf37e3993617 --- /dev/null +++ b/pkg/client/database/database.go @@ -0,0 +1,53 @@ +package database + +import ( + "context" + "fmt" + + "github.com/cloudquery/cloudquery/pkg/client/database/postgres" + "github.com/cloudquery/cloudquery/pkg/client/database/timescale" + "github.com/cloudquery/cloudquery/pkg/client/history" + sdkdb "github.com/cloudquery/cq-provider-sdk/database" + "github.com/cloudquery/cq-provider-sdk/provider/schema" + "github.com/hashicorp/go-hclog" +) + +type DialectExecutor interface { + // Setup is called on the dialect on initialization, returns the DSN (modified if necessary) to use for migrations + Setup(context.Context) (string, error) + + // Validate is called before startup to check that the dialect can execute properly + Validate(context.Context) (bool, error) + + // Finalize is called after migrations and upgrades are run + Finalize(context.Context) error +} + +var ( + _ DialectExecutor = (*postgres.Executor)(nil) + _ DialectExecutor = (*timescale.Executor)(nil) +) + +func GetExecutor(logger hclog.Logger, dsn string, c *history.Config) (schema.DialectType, DialectExecutor, error) { + if dsn == "" { + return schema.Postgres, nil, fmt.Errorf("missing DSN") + } + + dType, dsn, err := sdkdb.ParseDialectDSN(dsn) + if err != nil { + return dType, nil, err + } + + switch dType { + case schema.Postgres: + return dType, postgres.New(logger, dsn), nil + case schema.TSDB: + ts, err := timescale.New(logger, dsn, c) + if err != nil { + return dType, nil, err + } + return dType, ts, nil + default: + return dType, nil, fmt.Errorf("unhandled dialect type") + } +} diff --git a/pkg/client/database.go b/pkg/client/database/postgres/executor.go similarity index 70% rename from pkg/client/database.go rename to pkg/client/database/postgres/executor.go index 4f23a235cb81c3..d8a9827c3c8ee0 100644 --- a/pkg/client/database.go +++ b/pkg/client/database/postgres/executor.go @@ -1,44 +1,50 @@ -package client +package postgres import ( "context" "fmt" "strings" + sdkpg "github.com/cloudquery/cq-provider-sdk/database/postgres" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-version" - "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" ) +type Executor struct { + logger hclog.Logger + dsn string +} + var MinPostgresVersion = version.Must(version.NewVersion("11.0")) -func CreateDatabase(ctx context.Context, dsn string) (*pgxpool.Pool, error) { - if dsn == "" { - return nil, fmt.Errorf("missing DSN") - } - poolCfg, err := pgxpool.ParseConfig(dsn) - if err != nil { - return nil, err +func New(logger hclog.Logger, dsn string) Executor { + return Executor{ + logger: logger, + dsn: dsn, } +} - poolCfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { - UUIDType := pgtype.DataType{ - Value: &UUID{}, - Name: "uuid", - OID: pgtype.UUIDOID, - } +func (e Executor) Setup(ctx context.Context) (string, error) { + return e.dsn, nil +} - conn.ConnInfo().RegisterDataType(UUIDType) - return nil +func (e Executor) Validate(ctx context.Context) (bool, error) { + pool, err := sdkpg.Connect(ctx, e.dsn) + if err != nil { + return false, err } - poolCfg.LazyConnect = true - pool, err := pgxpool.ConnectConfig(ctx, poolCfg) - if err != nil { - return nil, err + if err := ValidatePostgresVersion(ctx, pool, MinPostgresVersion); err != nil { + return false, err } - return pool, err + + return true, nil +} + +func (e Executor) Finalize(ctx context.Context) error { + return nil } // queryRower helps with unit tests diff --git a/pkg/client/database_test.go b/pkg/client/database/postgres/executor_test.go similarity index 99% rename from pkg/client/database_test.go rename to pkg/client/database/postgres/executor_test.go index e9cb46df87100f..abf39249ac5121 100644 --- a/pkg/client/database_test.go +++ b/pkg/client/database/postgres/executor_test.go @@ -1,4 +1,4 @@ -package client +package postgres import ( "context" diff --git a/pkg/client/database/timescale/ddlmanager.go b/pkg/client/database/timescale/ddlmanager.go new file mode 100644 index 00000000000000..d85d10275f32fb --- /dev/null +++ b/pkg/client/database/timescale/ddlmanager.go @@ -0,0 +1,221 @@ +package timescale + +import ( + "context" + "fmt" + + "github.com/cloudquery/cloudquery/pkg/client/history" + "github.com/cloudquery/cq-provider-sdk/provider/schema" + "github.com/georgysavva/scany/pgxscan" + "github.com/hashicorp/go-hclog" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +const ( + listHyperTables = `SELECT hypertable_name FROM timescaledb_information.hypertables WHERE hypertable_schema=$1 ORDER BY 1` + + setChunkTimeInterval = `SELECT * FROM set_chunk_time_interval($1, INTERVAL '%d hour');` + dataRetentionPolicy = `SELECT history.update_retention($1, INTERVAL '%d day');` + + dropTableView = `DROP VIEW IF EXISTS "%[1]s"` + createTableView = `CREATE VIEW "%[1]s" AS SELECT * FROM history."%[1]s" WHERE cq_fetch_date = find_latest('history', '%[1]s')` +) + +type DDLManager struct { + log hclog.Logger + conn *pgxpool.Conn + cfg *history.Config + dialect schema.Dialect +} + +func NewDDLManager(l hclog.Logger, conn *pgxpool.Conn, cfg *history.Config, dt schema.DialectType) (*DDLManager, error) { + if dt != schema.TSDB { + return nil, fmt.Errorf("history is only supported on timescaledb") + } + + dialect, err := schema.GetDialect(dt) + if err != nil { + return nil, err + } + + return &DDLManager{ + log: l, + conn: conn, + cfg: cfg, + dialect: dialect, + }, nil +} + +func (h DDLManager) SetupHistory(ctx context.Context, conn *pgxpool.Conn) error { + var tables []string + if err := pgxscan.Select(ctx, conn, &tables, listHyperTables, history.SchemaName); err != nil { + return fmt.Errorf("failed to list hypertables: %w", err) + } + + for _, table := range tables { + if err := h.configureHyperTable(ctx, conn, table); err != nil { + return fmt.Errorf("failed to configure hypertable for table: %s: %w", table, err) + } + if err := h.recreateView(ctx, conn, table); err != nil { + return fmt.Errorf("recreateView: %w", err) + } + } + + return nil +} + +func (h DDLManager) configureHyperTable(ctx context.Context, conn *pgxpool.Conn, tableName string) error { + tName := fmt.Sprintf(`"%s"."%s"`, history.SchemaName, tableName) + + if _, err := conn.Exec(ctx, fmt.Sprintf(setChunkTimeInterval, h.cfg.TimeInterval), tName); err != nil { + return err + } + h.log.Debug("updated chunk_time_interval for table", "table", tableName, "interval", h.cfg.TimeInterval) + + // Below call is only needed for "parent" tables. dataRetentionPolicy function takes care of that by updating retention ONLY IF a previous retention policy is set. + if _, err := conn.Exec(ctx, fmt.Sprintf(dataRetentionPolicy, h.cfg.Retention), tName); err != nil { + return err + } + + h.log.Debug("created data retention policy", "table", tableName, "days", h.cfg.Retention) + return nil +} + +func (h DDLManager) recreateView(ctx context.Context, conn *pgxpool.Conn, table string) error { + if err := conn.BeginTxFunc(ctx, pgx.TxOptions{}, func(tx pgx.Tx) error { + // Must drop the view first -- CREATE OR REPLACE view won't cut it if columns are changed. PostgreSQL doc states: + // > The new query must generate the same columns that were generated by the existing view query (that is, the same column names in the same order and with + // > the same data types), but it may add additional columns to the end of the list. + // ref: https://www.postgresql.org/docs/14/sql-createview.html + + if _, err := tx.Exec(ctx, fmt.Sprintf(dropTableView, table)); err != nil { + return fmt.Errorf("failed to drop view for table: %w", err) + } + + if _, err := tx.Exec(ctx, fmt.Sprintf(createTableView, table)); err != nil { + return fmt.Errorf("failed to create view for table: %w", err) + } + + return nil + }); err != nil { + return fmt.Errorf("tx failed for %s: %w", table, err) + } + return nil +} + +func AddHistoryFunctions(ctx context.Context, conn *pgxpool.Conn) error { + return conn.BeginFunc(ctx, func(tx pgx.Tx) error { + if _, err := tx.Exec(ctx, createHistorySchema); err != nil { + return err + } + if _, err := tx.Exec(ctx, setupTriggerFunction); err != nil { + return err + } + if _, err := tx.Exec(ctx, setupParentFunction); err != nil { + return err + } + if _, err := tx.Exec(ctx, defineRetentionFunction); err != nil { + return err + } + if _, err := tx.Exec(ctx, cascadeDeleteFunction); err != nil { + return err + } + if _, err := tx.Exec(ctx, findLatestFetchDate); err != nil { + return err + } + return nil + }) +} + +const ( + createHistorySchema = `CREATE SCHEMA IF NOT EXISTS history;` + cascadeDeleteFunction = ` + CREATE OR REPLACE FUNCTION history.cascade_delete() + RETURNS trigger + LANGUAGE 'plpgsql' + COST 100 + VOLATILE NOT LEAKPROOF + AS $BODY$ + BEGIN + BEGIN + IF (TG_OP = 'DELETE') THEN + EXECUTE format('DELETE FROM history.%I where %I = %L AND cq_fetch_date = %L', TG_ARGV[0], TG_ARGV[1], OLD.cq_id, OLD.cq_fetch_date); + RETURN OLD; + END IF; + RETURN NULL; -- result is ignored since this is an AFTER trigger + END; + END; + $BODY$;` + + // Creates trigger on a referenced table, so each time a row from the parent table is deleted, referencing (child) rows are also cleared from database. + setupTriggerFunction = ` + CREATE OR REPLACE FUNCTION history.setup_tsdb_child(_table_name text, _column_name text, _parent_table_name text, _parent_column_name text) + RETURNS integer + LANGUAGE 'plpgsql' + COST 100 + VOLATILE PARALLEL UNSAFE + AS $BODY$ + BEGIN + PERFORM public.create_hypertable(_table_name, 'cq_fetch_date', chunk_time_interval => INTERVAL '1 day', if_not_exists => true); + + IF NOT EXISTS ( SELECT 1 FROM pg_trigger WHERE tgname = _table_name ) then + EXECUTE format( + 'CREATE TRIGGER %I BEFORE DELETE ON history.%I FOR EACH ROW EXECUTE PROCEDURE history.cascade_delete(%s, %s)'::text, + _table_name, _parent_table_name, _table_name, _column_name); + return 0; + ELSE + return 1; + END IF; + END; + $BODY$;` + + // Creates hypertable on the given table with a default chunk_time_interval, and adds a default retention policy + setupParentFunction = ` + CREATE OR REPLACE FUNCTION history.setup_tsdb_parent(_table_name text) + RETURNS integer + LANGUAGE 'plpgsql' + COST 100 + VOLATILE PARALLEL UNSAFE + AS $BODY$ + DECLARE + result integer; + BEGIN + PERFORM public.create_hypertable(_table_name, 'cq_fetch_date', chunk_time_interval => INTERVAL '1 day', if_not_exists => true); + SELECT public.add_retention_policy(_table_name, INTERVAL '14 day', if_not_exists => true) into result; + RETURN result; + END; + $BODY$;` + + // Updates the retention policy on the given table, only if a policy already exists. + defineRetentionFunction = ` + CREATE OR REPLACE FUNCTION history.update_retention(_table_name text, _retention interval) + RETURNS integer + LANGUAGE 'plpgsql' + COST 100 + VOLATILE PARALLEL UNSAFE + AS $BODY$ + DECLARE + result integer; + BEGIN + IF EXISTS ( SELECT 1 FROM timescaledb_information.jobs WHERE proc_name = 'policy_retention' AND hypertable_name = _table_name) THEN + PERFORM remove_retention_policy(_table_name, if_exists => true); + SELECT add_retention_policy(_table_name, _retention, if_not_exists => true) INTO result; + RETURN result; + ELSE + RETURN -2; + END IF; + END; + $BODY$;` + + findLatestFetchDate = ` + CREATE OR REPLACE FUNCTION find_latest(schema TEXT, _table_name TEXT) + RETURNS timestamp without time zone AS $body$ + DECLARE + fetchDate timestamp without time zone; + BEGIN + EXECUTE format('SELECT cq_fetch_date FROM %I.%I order by cq_fetch_date desc limit 1', schema, _table_name) into fetchDate; + return fetchDate; + END; + $body$ LANGUAGE plpgsql IMMUTABLE` +) diff --git a/pkg/client/database/timescale/timescale.go b/pkg/client/database/timescale/timescale.go new file mode 100644 index 00000000000000..8c37c61ccdaa71 --- /dev/null +++ b/pkg/client/database/timescale/timescale.go @@ -0,0 +1,94 @@ +package timescale + +import ( + "context" + "fmt" + + "github.com/cloudquery/cloudquery/pkg/client/database/postgres" + "github.com/cloudquery/cloudquery/pkg/client/history" + pgsdk "github.com/cloudquery/cq-provider-sdk/database/postgres" + "github.com/cloudquery/cq-provider-sdk/provider/schema" + "github.com/georgysavva/scany/pgxscan" + "github.com/hashicorp/go-hclog" +) + +const ( + validateTimescaleInstalled = `SELECT EXISTS(SELECT 1 FROM pg_extension where extname = 'timescaledb')` +) + +type Executor struct { + logger hclog.Logger + dsn string + cfg *history.Config +} + +func New(logger hclog.Logger, dsn string, cfg *history.Config) (*Executor, error) { + if cfg == nil { + return nil, fmt.Errorf("missing history config") + } + return &Executor{ + logger: logger, + dsn: dsn, + cfg: cfg, + }, nil +} + +// Setup sets all required history functions and validation checks that it can run cleanly. +func (e Executor) Setup(ctx context.Context) (string, error) { + pool, err := pgsdk.Connect(ctx, e.dsn) + if err != nil { + return e.dsn, err + } + defer pool.Close() + conn, err := pool.Acquire(ctx) + if err != nil { + return e.dsn, err + } + defer conn.Release() + + if err := AddHistoryFunctions(ctx, conn); err != nil { + return e.dsn, fmt.Errorf("failed to create history functions: %w", err) + } + + return history.TransformDSN(e.dsn) +} +func (e Executor) Validate(ctx context.Context) (bool, error) { + pool, err := pgsdk.Connect(ctx, e.dsn) + if err != nil { + return false, err + } + defer pool.Close() + + if err := postgres.ValidatePostgresVersion(ctx, pool, postgres.MinPostgresVersion); err != nil { + return false, err + } + + var installed bool + if err := pgxscan.Get(ctx, pool, &installed, validateTimescaleInstalled); err != nil { + return false, err + } + if !installed { + return false, fmt.Errorf("timescaledb extension not installed, `CREATE EXTENSION IF NOT EXISTS timescaledb;`") + } + + return true, nil +} + +func (e Executor) Finalize(ctx context.Context) error { + pool, err := pgsdk.Connect(ctx, e.dsn) + if err != nil { + return err + } + defer pool.Close() + conn, err := pool.Acquire(ctx) + if err != nil { + return err + } + defer conn.Release() + + ddl, err := NewDDLManager(e.logger, conn, e.cfg, schema.TSDB) + if err != nil { + return err + } + return ddl.SetupHistory(ctx, conn) +} diff --git a/pkg/client/database/timescale/timescale_test.go b/pkg/client/database/timescale/timescale_test.go new file mode 100644 index 00000000000000..fba4f5dee873fb --- /dev/null +++ b/pkg/client/database/timescale/timescale_test.go @@ -0,0 +1,163 @@ +//go:build history +// +build history + +package timescale + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/cloudquery/cloudquery/pkg/client/history" + pgsdk "github.com/cloudquery/cq-provider-sdk/database/postgres" + "github.com/cloudquery/cq-provider-sdk/migration" + "github.com/cloudquery/cq-provider-sdk/provider/schema" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" +) + +var testTable = &schema.Table{ + Name: "test_table", + Columns: []schema.Column{ + { + Name: "id", + Type: schema.TypeString, + }, + }, + Relations: []*schema.Table{ + { + Name: "test_rel_table", + Columns: []schema.Column{ + { + Name: "parent_cq_id", + Type: schema.TypeUUID, + Resolver: schema.ParentIdResolver, + }, + { + Name: "test", + Type: schema.TypeString, + }, + }, + }, + }, + Options: schema.TableCreationOptions{PrimaryKeys: []string{"id"}}, +} + +func getDSN() string { + dbDSN := os.Getenv("CQ_TIMESCALE_TEST_DSN") + if dbDSN == "" { + dbDSN = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" // timescale + } + return dbDSN +} + +func TestSetupHistory(t *testing.T) { + ctx := context.TODO() + ts, err := New(hclog.L(), getDSN(), &history.Config{ + Retention: 1, + TimeInterval: 1, + TimeTruncation: 24, + }) + assert.NoError(t, err) + + ok, err := ts.Validate(ctx) + assert.NoError(t, err) + assert.True(t, ok) + + migrationDSN, err := ts.Setup(ctx) + assert.NoError(t, err) + + { + pool, err := pgsdk.Connect(ctx, migrationDSN) + assert.NoError(t, err) + defer pool.Close() + + conn, err := pool.Acquire(ctx) + assert.NoError(t, err) + defer conn.Release() + + tc := migration.NewTableCreator(hclog.L(), schema.TSDBDialect{}) + ups, downs, err := tc.CreateTableDefinitions(ctx, testTable, nil) + assert.NoError(t, err) + + newDowns := make([]string, len(downs)) + for i, sql := range downs { + if strings.HasPrefix(sql, "DROP TABLE ") { + sql = strings.TrimSuffix(sql, ";") + " CASCADE" + } + newDowns[i] = sql + } + defer func() { + for _, sql := range newDowns { + _, err = conn.Exec(ctx, sql) + assert.NoError(t, err) + } + }() + + for _, sql := range append(newDowns, ups...) { // DROP old tables first, if they exist + _, err = conn.Exec(ctx, sql) + assert.NoError(t, err) + } + } + + err = ts.Finalize(ctx) + assert.NoError(t, err) + + t.Run("FinalizeSecondTime", func(t *testing.T) { + // Finalize() again shouldn't create any errors + err := ts.Finalize(ctx) + assert.NoError(t, err) + }) + + pool, err := pgsdk.Connect(ctx, getDSN()) + assert.NoError(t, err) + defer pool.Close() + + conn, err := pool.Acquire(ctx) + assert.NoError(t, err) + defer conn.Release() + + t.Run("QueryView", func(t *testing.T) { + _, err = conn.Exec(ctx, "select cq_fetch_date from test_table") + assert.Nil(t, err) + }) + + t.Run("QueryHistoryTable", func(t *testing.T) { + _, err = conn.Exec(ctx, "select cq_fetch_date from history.test_table") + assert.Nil(t, err) + }) + + partitionDate := time.Now().Format("2006/01/02") + + t.Run("Insert", func(t *testing.T) { + const ( + sqlInsertMainTable = `INSERT INTO public.test_table(cq_id, cq_meta, cq_fetch_date, id) + VALUES ('0d0bf7c6-c87d-4b3c-a270-60246dcb6ab1', NULL, TO_DATE('%s', 'YYYY/MM/DD'), 'test_id')` + sqlInsertRelTable = `INSERT INTO public.test_rel_table(cq_id, cq_meta, cq_fetch_date, parent_cq_id, test) + VALUES (gen_random_uuid(), null, TO_DATE('%s', 'YYYY/MM/DD'), '0d0bf7c6-c87d-4b3c-a270-60246dcb6ab1', 'test2')` + ) + + _, err = conn.Exec(ctx, fmt.Sprintf(sqlInsertMainTable, partitionDate)) + assert.NoError(t, err) + _, err = conn.Exec(ctx, fmt.Sprintf(sqlInsertRelTable, partitionDate)) + assert.NoError(t, err) + }) + + t.Run("Select", func(t *testing.T) { + res, err := conn.Exec(ctx, "select * from test_rel_table") + assert.NoError(t, err) + assert.Equal(t, res.RowsAffected(), int64(1)) + }) + + t.Run("DeleteCascadeTrigger", func(t *testing.T) { + res, err := conn.Exec(ctx, fmt.Sprintf(`DELETE FROM test_table WHERE cq_fetch_date = TO_DATE('%s', 'YYYY/MM/DD')`, partitionDate)) + assert.NoError(t, err) + assert.Equal(t, res.RowsAffected(), int64(1)) + res, err = conn.Exec(ctx, "select * from test_rel_table") + assert.NoError(t, err) + assert.Equal(t, res.RowsAffected(), int64(0)) + }) +} diff --git a/pkg/client/fetch.go b/pkg/client/fetch.go index 6fe89c3b9ed806..5261684571b4b4 100644 --- a/pkg/client/fetch.go +++ b/pkg/client/fetch.go @@ -10,7 +10,6 @@ import ( "github.com/cloudquery/cq-provider-sdk/provider/schema/diag" "github.com/doug-martin/goqu/v9" "github.com/google/uuid" - "github.com/jackc/pgx/v4/pgxpool" ) // FetchSummary includes a summarized report of fetch, such as fetch id, fetch start and finish, @@ -59,13 +58,7 @@ type ResourceFetchSummary struct { } // SaveFetchSummary saves fetch summary into fetches database -func SaveFetchSummary(ctx context.Context, pool *pgxpool.Pool, fs *FetchSummary) error { - conn, err := pool.Acquire(ctx) - if err != nil { - return err - } - defer conn.Release() - +func (c *Client) SaveFetchSummary(ctx context.Context, fs *FetchSummary) error { id, err := uuid.NewUUID() if err != nil { return err @@ -77,6 +70,5 @@ func SaveFetchSummary(ctx context.Context, pool *pgxpool.Pool, fs *FetchSummary) return err } - _, err = conn.Exec(ctx, sql, args...) - return err + return c.db.Exec(ctx, sql, args...) } diff --git a/pkg/client/fetch_test.go b/pkg/client/fetch_test.go index 244d346deb52f3..254289bba0fbf6 100644 --- a/pkg/client/fetch_test.go +++ b/pkg/client/fetch_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/google/uuid" - "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/assert" ) @@ -75,38 +74,21 @@ var fetchSummaryTests = []fetchSummaryTest{ }, } -func setupDatabase(dsn string) (*pgxpool.Pool, error) { - poolCfg, err := pgxpool.ParseConfig(dsn) - if err != nil { - return nil, err - } - poolCfg.LazyConnect = true - pool, err := pgxpool.ConnectConfig(context.Background(), poolCfg) - if err != nil { - return nil, err - } - return pool, nil -} - func TestFetchSummary(t *testing.T) { - option := func(c *Client) { + c, err := New(context.Background(), func(c *Client) { c.DSN = testDBConnection - } - _, err := New(context.Background(), option) - assert.NoError(t, err) - pool, err := setupDatabase(testDBConnection) - assert.NoError(t, err) - defer pool.Close() + }) assert.NoError(t, err) + fetchId := uuid.New() for _, f := range fetchSummaryTests { if !f.skipFetchId { f.summary.FetchId = fetchId } f.summary.Start = time.Now() - err := SaveFetchSummary(context.Background(), pool, &f.summary) + err := c.SaveFetchSummary(context.Background(), &f.summary) if f.err != nil { - assert.Equal(t, f.err.Error(), err.Error()) + assert.EqualError(t, err, f.err.Error()) } else { assert.NoError(t, err) } diff --git a/pkg/client/history/history.go b/pkg/client/history/history.go index 5eb80d9110da8e..82b126c36dbd65 100644 --- a/pkg/client/history/history.go +++ b/pkg/client/history/history.go @@ -1,71 +1,12 @@ package history import ( - "context" - "errors" - "fmt" "time" - "github.com/georgysavva/scany/pgxscan" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" + "github.com/cloudquery/cq-provider-sdk/database/dsn" ) -const ( - createHistorySchema = `CREATE SCHEMA IF NOT EXISTS history;` - validateTimescaleInstalled = `SELECT EXISTS(SELECT 1 FROM pg_extension where extname = 'timescaledb')` - cascadeDeleteFunction = ` - CREATE OR REPLACE FUNCTION history.cascade_delete() - RETURNS trigger - LANGUAGE 'plpgsql' - COST 100 - VOLATILE NOT LEAKPROOF - AS $BODY$ - BEGIN - BEGIN - IF (TG_OP = 'DELETE') THEN - EXECUTE format('DELETE FROM history.%I where %I = %L AND cq_fetch_date = %L', TG_ARGV[0], TG_ARGV[1], OLD.cq_id, OLD.cq_fetch_date); - RETURN OLD; - END IF; - RETURN NULL; -- result is ignored since this is an AFTER trigger - END; - END; - $BODY$;` - buildTriggerFunction = ` - CREATE OR REPLACE FUNCTION history.build_trigger(_table_name text, _child_table_name text, _parent_id text) - RETURNS integer - LANGUAGE 'plpgsql' - COST 100 - VOLATILE PARALLEL UNSAFE - AS $BODY$ - BEGIN - IF NOT EXISTS ( SELECT 1 FROM pg_trigger WHERE tgname = _child_table_name ) then - EXECUTE format( - 'CREATE TRIGGER %I BEFORE DELETE ON history.%I FOR EACH ROW EXECUTE PROCEDURE history.cascade_delete(%s, %s)'::text, - _child_table_name, _table_name, _child_table_name, _parent_id); - return 0; - ELSE - return 1; - END IF; - END; - $BODY$;` - - createHyperTable = `SELECT * FROM create_hypertable($1, 'cq_fetch_date', chunk_time_interval => INTERVAL '%d day', if_not_exists => true);` - dataRetentionPolicy = `SELECT add_retention_policy($1, INTERVAL '%d day', if_not_exists => true);` - findLatestFetchDate = ` - CREATE OR REPLACE FUNCTION find_latest(schema TEXT, _table_name TEXT) - RETURNS timestamp without time zone AS $body$ - DECLARE - fetchDate timestamp without time zone; - BEGIN - EXECUTE format('SELECT cq_fetch_date FROM %I.%I order by cq_fetch_date desc limit 1', schema, _table_name) into fetchDate; - return fetchDate; - END; - $body$ LANGUAGE plpgsql IMMUTABLE` - - dropTableView = `DROP VIEW IF EXISTS "%[1]s"` - createTableView = `CREATE VIEW "%[1]s" AS SELECT * FROM history."%[1]s" WHERE cq_fetch_date = find_latest('history', '%[1]s')` -) +const SchemaName = "history" type Config struct { // Retention of data in days, defaults to 7 @@ -81,43 +22,7 @@ func (c Config) FetchDate() time.Time { return time.Now().UTC().Truncate(time.Duration(c.TimeTruncation) * time.Hour) } -// SetupHistory sets all required history functions and validation checks that it can run cleanly. -func SetupHistory(ctx context.Context, conn *pgxpool.Conn) error { - installed, err := validateInstalled(ctx, conn) - if err != nil { - return fmt.Errorf("failed to validate timescale installed: %w", err) - } - if !installed { - return errors.New("timescaledb extension not installed, `CREATE EXTENSION IF NOT EXISTS timescaledb;`") - } - if err := addHistoryFunctions(ctx, conn); err != nil { - return fmt.Errorf("failed to create history functions: %w", err) - } - return nil -} - -func addHistoryFunctions(ctx context.Context, conn *pgxpool.Conn) error { - return conn.BeginFunc(ctx, func(tx pgx.Tx) error { - if _, err := tx.Exec(ctx, createHistorySchema); err != nil { - return err - } - if _, err := tx.Exec(ctx, buildTriggerFunction); err != nil { - return err - } - if _, err := tx.Exec(ctx, cascadeDeleteFunction); err != nil { - return err - } - if _, err := tx.Exec(ctx, findLatestFetchDate); err != nil { - return err - } - return nil - }) -} - -func validateInstalled(ctx context.Context, conn *pgxpool.Conn) (bool, error) { - var installed bool - if err := pgxscan.Get(ctx, conn, &installed, validateTimescaleInstalled); err != nil { - return false, err - } - return installed, nil +// TransformDSN sets the search_path of the given DSN to the history schema +func TransformDSN(inputDSN string) (string, error) { + return dsn.SetDSNElement(inputDSN, map[string]string{"search_path": SchemaName}) } diff --git a/pkg/client/history/table.go b/pkg/client/history/table.go deleted file mode 100644 index 06cc629a7ecbe6..00000000000000 --- a/pkg/client/history/table.go +++ /dev/null @@ -1,165 +0,0 @@ -package history - -import ( - "context" - "fmt" - "strconv" - "strings" - - "github.com/cloudquery/cq-provider-sdk/provider/schema" - "github.com/georgysavva/scany/pgxscan" - "github.com/hashicorp/go-hclog" - "github.com/huandu/go-sqlbuilder" - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" -) - -type CreateHyperTableResult struct { - HypertableId int `db:"hypertable_id"` - SchemaName string `db:"schema_name"` - TableName string `db:"table_name"` - Created bool `db:"created"` -} - -type TableCreator struct { - log hclog.Logger - cfg *Config -} - -func NewHistoryTableCreator(cfg *Config, l hclog.Logger) (*TableCreator, error) { - return &TableCreator{ - l, - cfg, - }, nil -} - -func (h TableCreator) CreateTable(ctx context.Context, conn *pgxpool.Conn, t, p *schema.Table) error { - sql, err := h.buildTableSQL(t, p) - if err != nil { - return err - } - - h.log.Debug("creating table if not exists", "table", t.Name) - if _, err := conn.Exec(ctx, sql); err != nil { - return fmt.Errorf("failed to create table: %w", err) - } - - if err := h.createHyperTable(ctx, t, p, conn); err != nil { - return fmt.Errorf("failed to create hypertable for table: %s: %w", t.Name, err) - } - - if err := conn.BeginTxFunc(ctx, pgx.TxOptions{}, func(tx pgx.Tx) error { - // Must drop the view first -- CREATE OR REPLACE view won't cut it if columns are changed. PostgreSQL doc states: - // > The new query must generate the same columns that were generated by the existing view query (that is, the same column names in the same order and with - // > the same data types), but it may add additional columns to the end of the list. - // ref: https://www.postgresql.org/docs/14/sql-createview.html - - if _, err := tx.Exec(ctx, fmt.Sprintf(dropTableView, t.Name)); err != nil { - return fmt.Errorf("failed to drop view for table: %w", err) - } - - if _, err := tx.Exec(ctx, fmt.Sprintf(createTableView, t.Name)); err != nil { - return fmt.Errorf("failed to create view for table: %w", err) - } - - return nil - }); err != nil { - return fmt.Errorf("tx failed for %s: %w", t.Name, err) - } - - if p != nil { - if err := h.buildCascadeTrigger(ctx, conn, t, p); err != nil { - return fmt.Errorf("table build %s failed: %w", t.Name, err) - } - } - - // Create relation tables - for _, r := range t.Relations { - h.log.Debug("creating table relation", "table", r.Name) - if err := h.CreateTable(ctx, conn, r, t); err != nil { - return err - } - } - - return nil -} - -func (h TableCreator) createHyperTable(ctx context.Context, t, p *schema.Table, conn *pgxpool.Conn) error { - var hyperTable CreateHyperTableResult - tName := fmt.Sprintf(`"history"."%s"`, t.Name) - if err := pgxscan.Get(ctx, conn, &hyperTable, fmt.Sprintf(createHyperTable, h.cfg.TimeInterval), tName); err != nil { - return fmt.Errorf("failed to create hypertable: %w", err) - } - h.log.Debug("created hyper table for table", "table", hyperTable.TableName, "id", hyperTable.HypertableId, "created", hyperTable.Created) - if p != nil { - return nil - } - if _, err := conn.Exec(ctx, fmt.Sprintf(dataRetentionPolicy, h.cfg.Retention), tName); err != nil { - return err - } - h.log.Debug("created data retention policy", "table", hyperTable.TableName, "days", h.cfg.Retention) - return nil -} - -func (h TableCreator) buildCascadeTrigger(ctx context.Context, conn *pgxpool.Conn, t, p *schema.Table) error { - c := h.findParentIdColumn(t) - if c == nil { - return fmt.Errorf("failed to find parent cq id column for %s", t.Name) - } - if _, err := conn.Exec(ctx, "SELECT history.build_trigger($1, $2, $3);", p.Name, t.Name, c.Name); err != nil { - return fmt.Errorf("failed to create trigger: %w", err) - } - if _, err := conn.Exec(ctx, fmt.Sprintf("CREATE INDEX ON \"history\".\"%s\" (cq_fetch_date, %s)", t.Name, c.Name)); err != nil { - return fmt.Errorf("failed to create index on %s (cq_fetch_date, %s): %w", t.Name, c.Name, err) - } - return nil -} - -func (h TableCreator) findParentIdColumn(t *schema.Table) *schema.Column { - for _, c := range t.Columns { - if c.Meta().Resolver != nil && c.Meta().Resolver.Name == "ParentIdResolver" { - return &c - } - } - // Support old school columns instead of meta, this is backwards compatibility for providers using SDK prior v0.5.0 - for _, c := range t.Columns { - if strings.HasSuffix(c.Name, "cq_id") && c.Name != "cq_id" { - return &c - } - } - - return nil -} - -func (h TableCreator) buildTableSQL(table, _ *schema.Table) (string, error) { - // Build SQL to create a table. - ctb := sqlbuilder.CreateTable(fmt.Sprintf("history.%s", table.Name)).IfNotExists() - var uniques []string - for _, c := range schema.GetDefaultSDKColumns() { - ctb.Define(c.Name, schema.GetPgTypeFromType(c.Type)) - if c.CreationOptions.Unique { - uniques = append(uniques, c.Name) - } - } - ctb.Define("cq_fetch_date", schema.GetPgTypeFromType(schema.TypeTimestamp)) - h.buildColumns(ctb, table.Columns) - for _, s := range uniques { - ctb.Define(fmt.Sprintf("UNIQUE (cq_fetch_date, %s)", s)) - } - allKeys := append([]string{"cq_fetch_date"}, table.PrimaryKeys()...) - ctb.Define(fmt.Sprintf("constraint %s_pk primary key(%s)", schema.TruncateTableConstraint(table.Name), strings.Join(allKeys, ","))) - - sql, _ := ctb.BuildWithFlavor(sqlbuilder.PostgreSQL) - h.log.Trace("creating table if not exists", "table", table.Name) - return sql, nil -} - -func (h TableCreator) buildColumns(ctb *sqlbuilder.CreateTableBuilder, cc []schema.Column) { - for _, c := range cc { - defs := []string{strconv.Quote(c.Name), schema.GetPgTypeFromType(c.Type)} - if c.CreationOptions.Unique { - defs = []string{strconv.Quote(c.Name), schema.GetPgTypeFromType(c.Type), "unique"} - } - ctb.Define(defs...) - } -} diff --git a/pkg/client/history/table_test.go b/pkg/client/history/table_test.go deleted file mode 100644 index 87adc2ac204c77..00000000000000 --- a/pkg/client/history/table_test.go +++ /dev/null @@ -1,108 +0,0 @@ -//go:build history -// +build history - -package history_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/cloudquery/cloudquery/pkg/client" - "github.com/cloudquery/cloudquery/pkg/client/history" - "github.com/cloudquery/cq-provider-sdk/provider/schema" - "github.com/hashicorp/go-hclog" - "github.com/stretchr/testify/assert" -) - -const ( - testDBConnection = "postgres://postgres:pass@localhost:5432/postgres?sslmode=disable" - sqlInsertMainTable = `INSERT INTO public.test_table( - cq_id, meta, cq_fetch_date, test) - VALUES ('0d0bf7c6-c87d-4b3c-a270-60246dcb6ab1', NULL, TO_DATE('%s', 'YYYY/MM/DD'), 'test'); - ` - sqlInsertRelTable = `INSERT INTO public.test_rel_table( - cq_id, meta, cq_fetch_date, parent_cq_id, test) - VALUES (gen_random_uuid(), null, TO_DATE('%s', 'YYYY/MM/DD'), '0d0bf7c6-c87d-4b3c-a270-60246dcb6ab1', 'test2'); - ` -) - -var testTable = &schema.Table{ - Name: "test_table", - Columns: []schema.Column{ - { - Name: "test", - Type: schema.TypeString, - }, - }, - Relations: []*schema.Table{ - { - Name: "test_rel_table", - Columns: []schema.Column{ - { - Name: "parent_cq_id", - Type: schema.TypeUUID, - }, - { - Name: "test", - Type: schema.TypeString, - }, - }, - }, - }, - Options: schema.TableCreationOptions{PrimaryKeys: []string{"test"}}, -} - -func TestHistory_SetupHistory(t *testing.T) { - pool, err := client.CreateDatabase(context.Background(), testDBConnection) - assert.NoError(t, err) - defer pool.Close() - conn, err := pool.Acquire(context.Background()) - assert.NoError(t, err) - defer conn.Release() - assert.NoError(t, history.SetupHistory(context.Background(), conn)) -} - -func TestHistoryTableCreator_CreateTables(t *testing.T) { - m, err := history.NewHistoryTableCreator(&history.Config{Retention: 1, - TimeInterval: 1, - TimeTruncation: 24, - }, hclog.L()) - assert.NoError(t, err) - assert.NotNil(t, m) - - pool, err := client.CreateDatabase(context.Background(), testDBConnection) - assert.NoError(t, err) - defer pool.Close() - conn, err := pool.Acquire(context.Background()) - assert.NoError(t, err) - defer conn.Release() - // Call setup history as previous test can execute before - assert.NoError(t, history.SetupHistory(context.Background(), conn)) - - assert.NoError(t, m.CreateTable(context.Background(), conn, testTable, nil)) - // creating tables again shouldn't create any errors - assert.NoError(t, m.CreateTable(context.Background(), conn, testTable, nil)) - // query the view - _, err = conn.Exec(context.Background(), "select cq_fetch_date from test_table") - assert.Nil(t, err) - // query the history table itself - _, err = conn.Exec(context.Background(), "select cq_fetch_date from history.test_table") - assert.Nil(t, err) - partitionDate := time.Now().Format("2006/01/02") - _, err = conn.Exec(context.Background(), fmt.Sprintf(sqlInsertMainTable, partitionDate)) - assert.Nil(t, err) - _, err = conn.Exec(context.Background(), fmt.Sprintf(sqlInsertRelTable, partitionDate)) - // Check data was inserted - res, err := conn.Exec(context.Background(), "select * from test_rel_table") - assert.Nil(t, err) - assert.Equal(t, res.RowsAffected(), int64(1)) - // Test that delete cascade trigger works - res, err = conn.Exec(context.Background(), fmt.Sprintf(`DELETE FROM test_table WHERE cq_fetch_date = TO_DATE('%s', 'YYYY/MM/DD')`, partitionDate)) - assert.Nil(t, err) - assert.Equal(t, res.RowsAffected(), int64(1)) - res, err = conn.Exec(context.Background(), "select * from test_rel_table") - assert.Nil(t, err) - assert.Equal(t, res.RowsAffected(), int64(0)) -} diff --git a/pkg/client/migrations/1_v0.19.2.down.sql b/pkg/client/migrations/postgres/1_v0.19.2.down.sql similarity index 100% rename from pkg/client/migrations/1_v0.19.2.down.sql rename to pkg/client/migrations/postgres/1_v0.19.2.down.sql diff --git a/pkg/client/migrations/1_v0.19.2.up.sql b/pkg/client/migrations/postgres/1_v0.19.2.up.sql similarity index 100% rename from pkg/client/migrations/1_v0.19.2.up.sql rename to pkg/client/migrations/postgres/1_v0.19.2.up.sql diff --git a/pkg/module/drift/config.go b/pkg/module/drift/config.go index 83322eaa742d73..465e4e6677a6a1 100644 --- a/pkg/module/drift/config.go +++ b/pkg/module/drift/config.go @@ -422,6 +422,13 @@ func (d *Drift) applyProvider(cfg *ProviderConfig, p *cqproto.GetProviderSchemaR if len(cfg.versionConstraints) > 0 { pver, err := version.NewSemver(p.Version) + if err == nil { + if pr := pver.Prerelease(); pr != "" && strings.HasPrefix(pr, "SNAPSHOT") { + // re-parse without prerelease info + v := strings.SplitN(p.Version, "-", 2) + pver, err = version.NewVersion(v[0]) + } + } if err != nil { return false, []*hcl.Diagnostic{ { diff --git a/pkg/module/drift/drift.go b/pkg/module/drift/drift.go index 11753bef8af3f4..370688a8af4049 100644 --- a/pkg/module/drift/drift.go +++ b/pkg/module/drift/drift.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" + "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9/exp" "github.com/georgysavva/scany/pgxscan" @@ -15,7 +16,6 @@ import ( "github.com/hashicorp/hcl/v2/hclparse" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v4/pgxpool" "github.com/spf13/afero" "github.com/cloudquery/cloudquery/pkg/module" @@ -260,7 +260,7 @@ func (d *Drift) run(ctx context.Context, req *module.ExecuteRequest) (*Results, return resList, nil } -func queryIntoResourceList(ctx context.Context, logger hclog.Logger, conn *pgxpool.Conn, sel *goqu.SelectDataset) (ResourceList, error) { +func queryIntoResourceList(ctx context.Context, logger hclog.Logger, conn schema.QueryExecer, sel *goqu.SelectDataset) (ResourceList, error) { query, args, err := sel.ToSQL() if err != nil { return nil, fmt.Errorf("goqu build failed: %w", err) diff --git a/pkg/module/drift/terraform.go b/pkg/module/drift/terraform.go index 23600ff0d69d8a..9b12b5c110f58f 100644 --- a/pkg/module/drift/terraform.go +++ b/pkg/module/drift/terraform.go @@ -15,7 +15,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/go-hclog" - "github.com/jackc/pgx/v4/pgxpool" "github.com/olekukonko/tablewriter" "github.com/tidwall/gjson" ) @@ -161,7 +160,7 @@ func parseTerraformAttribute(val interface{}, t schema.ValueType) interface{} { } } -func driftTerraform(ctx context.Context, logger hclog.Logger, conn *pgxpool.Conn, cloudName string, cloudTable *traversedTable, resName string, resources map[string]*ResourceConfig, iacData *IACConfig, states TFStates, runParams RunParams, accountIDs []string) (*Result, error) { +func driftTerraform(ctx context.Context, logger hclog.Logger, conn schema.QueryExecer, cloudName string, cloudTable *traversedTable, resName string, resources map[string]*ResourceConfig, iacData *IACConfig, states TFStates, runParams RunParams, accountIDs []string) (*Result, error) { res := &Result{ Different: nil, Equal: nil, diff --git a/pkg/module/manager.go b/pkg/module/manager.go index 124ff9377a9208..878673e044aa83 100644 --- a/pkg/module/manager.go +++ b/pkg/module/manager.go @@ -4,9 +4,9 @@ import ( "context" "fmt" + "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/hashicorp/go-hclog" "github.com/hashicorp/hcl/v2" - "github.com/jackc/pgx/v4/pgxpool" ) // ManagerImpl is the manager implementation struct. @@ -14,8 +14,8 @@ type ManagerImpl struct { modules map[string]Module modOrder []string - // Instance of a database connection pool - pool *pgxpool.Pool + // Instance of database + pool schema.QueryExecer // Logger instance logger hclog.Logger @@ -35,7 +35,7 @@ type Manager interface { } // NewManager returns a new manager instance. -func NewManager(pool *pgxpool.Pool, logger hclog.Logger) *ManagerImpl { +func NewManager(pool schema.QueryExecer, logger hclog.Logger) *ManagerImpl { return &ManagerImpl{ modules: make(map[string]Module), pool: pool, @@ -64,14 +64,7 @@ func (m *ManagerImpl) ExecuteModule(ctx context.Context, modName string, cfg hcl return nil, fmt.Errorf("module configuration failed: %w", err) } - var err error - - // Acquire connection from the connection pool - execReq.Conn, err = m.pool.Acquire(ctx) - if err != nil { - return nil, fmt.Errorf("failed to acquire connection from the connection pool: %w", err) - } - defer execReq.Conn.Release() + execReq.Conn = m.pool return mod.Execute(ctx, execReq), nil } diff --git a/pkg/module/types.go b/pkg/module/types.go index 11bf96db46d737..b8f1fb0a84d2a0 100644 --- a/pkg/module/types.go +++ b/pkg/module/types.go @@ -3,10 +3,9 @@ package module import ( "context" - "github.com/hashicorp/hcl/v2" - "github.com/jackc/pgx/v4/pgxpool" - "github.com/cloudquery/cq-provider-sdk/cqproto" + "github.com/cloudquery/cq-provider-sdk/provider/schema" + "github.com/hashicorp/hcl/v2" ) type Module interface { @@ -29,7 +28,7 @@ type ExecuteRequest struct { // Providers is the list of providers to process Providers []*cqproto.GetProviderSchemaResponse // Conn is the db connection to use - Conn *pgxpool.Conn + Conn schema.QueryExecer } type ExecutionResult struct { diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index db3666ed0c8195..b4fc8fc51c369d 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -27,6 +27,7 @@ var pluginMap = map[string]plugin.Plugin{ type Plugin interface { Name() string Version() string + ProtocolVersion() int Provider() cqproto.CQProvider Close() } @@ -46,7 +47,7 @@ func newRemotePlugin(details *registry.ProviderDetails, alias string, env []stri client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: serve.Handshake, VersionedPlugins: map[int]plugin.PluginSet{ - 3: pluginMap, + cqproto.V4: pluginMap, }, Managed: true, Cmd: cmd, @@ -89,6 +90,8 @@ func (m managedPlugin) Name() string { return m.name } func (m managedPlugin) Version() string { return m.version } +func (m managedPlugin) ProtocolVersion() int { return m.client.NegotiatedVersion() } + func (m managedPlugin) Provider() cqproto.CQProvider { return m.provider } func (m managedPlugin) Close() { @@ -141,6 +144,8 @@ func (m unmanagedPlugin) Name() string { return m.name } func (m unmanagedPlugin) Version() string { return Unmanaged } +func (m unmanagedPlugin) ProtocolVersion() int { return cqproto.Vunmanaged } + func (m unmanagedPlugin) Provider() cqproto.CQProvider { return m.provider } func (m unmanagedPlugin) Close() {} diff --git a/pkg/policy/execute.go b/pkg/policy/execute.go index 121d1add63f897..b55353af320abe 100644 --- a/pkg/policy/execute.go +++ b/pkg/policy/execute.go @@ -6,13 +6,12 @@ import ( "errors" "fmt" "path" - "strings" - "path/filepath" + "strings" + "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-version" - "github.com/jackc/pgx/v4/pgxpool" "github.com/spf13/afero" ) @@ -46,7 +45,7 @@ func (f Update) DoneCount() int { // Executor implements the execution framework. type Executor struct { // Connection to the database - conn *pgxpool.Conn + conn schema.QueryExecer log hclog.Logger PolicyPath []string @@ -99,7 +98,7 @@ type ExecuteRequest struct { } // NewExecutor creates a new executor. -func NewExecutor(conn *pgxpool.Conn, log hclog.Logger, progressUpdate UpdateCallback) *Executor { +func NewExecutor(conn schema.QueryExecer, log hclog.Logger, progressUpdate UpdateCallback) *Executor { return &Executor{ conn: conn, log: log, @@ -239,7 +238,7 @@ func (e *Executor) executeQuery(ctx context.Context, q *Check) (*QueryResult, er func (e *Executor) createViews(ctx context.Context, policy *Policy) error { for _, v := range policy.Views { e.log.Info("creating policy view", "view", v.Name) - if _, err := e.conn.Exec(ctx, fmt.Sprintf("CREATE OR REPLACE TEMPORARY VIEW %s AS %s", v.Name, v.Query)); err != nil { + if err := e.conn.Exec(ctx, fmt.Sprintf("CREATE OR REPLACE TEMPORARY VIEW %s AS %s", v.Name, v.Query)); err != nil { return fmt.Errorf("failed to create view %s/%s: %w", policy.Name, v.Name, err) } } diff --git a/pkg/policy/execute_test.go b/pkg/policy/execute_test.go index f6d24c0a2c00b9..adfdf2de130df4 100644 --- a/pkg/policy/execute_test.go +++ b/pkg/policy/execute_test.go @@ -5,31 +5,27 @@ import ( "fmt" "testing" + sdkdb "github.com/cloudquery/cq-provider-sdk/database" + "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/hashicorp/go-hclog" - "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/assert" ) -func setupPolicyDatabase(t *testing.T, tableName string) (*pgxpool.Pool, func(t *testing.T)) { - poolCfg, err := pgxpool.ParseConfig("postgres://postgres:pass@localhost:5432/postgres") - assert.NoError(t, err) - poolCfg.LazyConnect = true - pool, err := pgxpool.ConnectConfig(context.Background(), poolCfg) - assert.NoError(t, err) - conn, err := pool.Acquire(context.Background()) +func setupPolicyDatabase(t *testing.T, tableName string) (schema.QueryExecer, func(t *testing.T)) { + conn, err := sdkdb.New(context.Background(), hclog.NewNullLogger(), "postgres://postgres:pass@localhost:5432/postgres") assert.NoError(t, err) // Setup test data - _, err = conn.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + err = conn.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) assert.NoError(t, err) - _, err = conn.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s (id serial PRIMARY KEY, name VARCHAR(50) NOT NULL)", tableName)) + err = conn.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s (id serial PRIMARY KEY, name VARCHAR(50) NOT NULL)", tableName)) assert.NoError(t, err) - _, err = conn.Exec(context.Background(), fmt.Sprintf("INSERT INTO %s VALUES (1, 'john')", tableName)) + err = conn.Exec(context.Background(), fmt.Sprintf("INSERT INTO %s VALUES (1, 'john')", tableName)) assert.NoError(t, err) // Return conn and tear down func - return pool, func(t *testing.T) { - _, err = conn.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName)) + return conn, func(t *testing.T) { + err = conn.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", tableName)) assert.NoError(t, err) } } @@ -55,10 +51,8 @@ func TestExecutor_executeQuery(t *testing.T) { }, } - pool, tearDownFunc := setupPolicyDatabase(t, t.Name()) + conn, tearDownFunc := setupPolicyDatabase(t, t.Name()) defer tearDownFunc(t) - conn, err := pool.Acquire(context.Background()) - assert.NoError(t, err) executor := NewExecutor(conn, hclog.Default(), nil) for _, tc := range cases { @@ -147,10 +141,8 @@ func TestExecutor_executePolicy(t *testing.T) { }, } - pool, tearDownFunc := setupPolicyDatabase(t, t.Name()) + conn, tearDownFunc := setupPolicyDatabase(t, t.Name()) defer tearDownFunc(t) - conn, err := pool.Acquire(context.Background()) - assert.NoError(t, err) executor := NewExecutor(conn, hclog.Default(), nil) for _, tc := range cases { @@ -310,10 +302,8 @@ func TestExecutor_Execute(t *testing.T) { }, } - pool, tearDownFunc := setupPolicyDatabase(t, t.Name()) + conn, tearDownFunc := setupPolicyDatabase(t, t.Name()) defer tearDownFunc(t) - conn, err := pool.Acquire(context.Background()) - assert.NoError(t, err) executor := NewExecutor(conn, hclog.Default(), nil) for _, tc := range cases { diff --git a/pkg/policy/manager.go b/pkg/policy/manager.go index e5c14b1ccda3f8..cd7b2ad2aec441 100644 --- a/pkg/policy/manager.go +++ b/pkg/policy/manager.go @@ -2,14 +2,13 @@ package policy import ( "context" - "fmt" "strings" + "github.com/cloudquery/cq-provider-sdk/provider/schema" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/hashicorp/go-hclog" - "github.com/jackc/pgx/v4/pgxpool" ) const ( @@ -22,7 +21,7 @@ type ManagerImpl struct { policyDirectory string // Instance of a database connection pool - pool *pgxpool.Pool + pool schema.QueryExecer // Logger instance logger hclog.Logger @@ -39,7 +38,7 @@ type Manager interface { } // NewManager returns a new manager instance. -func NewManager(policyDir string, pool *pgxpool.Pool, logger hclog.Logger) *ManagerImpl { +func NewManager(policyDir string, pool schema.QueryExecer, logger hclog.Logger) *ManagerImpl { return &ManagerImpl{ policyDirectory: policyDir, pool: pool, @@ -70,14 +69,6 @@ func (m *ManagerImpl) Load(ctx context.Context, policy *Policy) (*Policy, error) } func (m *ManagerImpl) Run(ctx context.Context, request *ExecuteRequest) (*ExecutionResult, error) { - // Acquire connection from the connection pool - conn, err := m.pool.Acquire(ctx) - m.logger.Trace("acquired connection from the connection pool", "err", err) - if err != nil { - return nil, fmt.Errorf("failed to acquire connection from the connection pool: %w", err) - } - defer conn.Release() - var ( totalQueriesToRun = request.Policy.TotalQueries() finishedQueries = 0 @@ -117,7 +108,7 @@ func (m *ManagerImpl) Run(ctx context.Context, request *ExecuteRequest) (*Execut } // execute the queries - return NewExecutor(conn, m.logger, progressUpdate).Execute(ctx, request, request.Policy, selector) + return NewExecutor(m.pool, m.logger, progressUpdate).Execute(ctx, request, request.Policy, selector) } func (m *ManagerImpl) loadPolicyFromSource(ctx context.Context, name, subPolicy, sourceURL string) (*Policy, error) { diff --git a/pkg/ui/console/client.go b/pkg/ui/console/client.go index da12f271f884e3..3325a3eabb8cb5 100644 --- a/pkg/ui/console/client.go +++ b/pkg/ui/console/client.go @@ -480,7 +480,7 @@ func (c Client) getModuleProviders(ctx context.Context) ([]*cqproto.GetProviderS s.Version = deets.Version } } - list[i] = s + list[i] = s.GetProviderSchemaResponse } return list, nil