Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/destination/postgresql/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"strings"

"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"

pgx_zero_log "github.com/jackc/pgx-zerolog"
"github.com/jackc/pgx/v5"
Expand Down
2 changes: 1 addition & 1 deletion plugins/destination/postgresql/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"testing"

"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"
)

func getTestConnection() string {
Expand Down
6 changes: 3 additions & 3 deletions plugins/destination/postgresql/client/deletestale.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import (
"strings"
"time"

"github.com/cloudquery/plugin-sdk/v2/schema"
"github.com/cloudquery/plugin-sdk/v3/schema"
"github.com/jackc/pgx/v5"
)

func (c *Client) DeleteStale(ctx context.Context, tables schema.Schemas, source string, syncTime time.Time) error {
func (c *Client) DeleteStale(ctx context.Context, tables schema.Tables, source string, syncTime time.Time) error {
batch := &pgx.Batch{}
for _, table := range tables {
var sb strings.Builder
sb.WriteString("delete from ")
sb.WriteString(pgx.Identifier{schema.TableName(table)}.Sanitize())
sb.WriteString(pgx.Identifier{table.Name}.Sanitize())
sb.WriteString(" where ")
sb.WriteString(schema.CqSourceNameColumn.Name)
sb.WriteString(" = $1 and ")
Expand Down
2 changes: 1 addition & 1 deletion plugins/destination/postgresql/client/metrics.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package client

import (
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"
)

func (c *Client) Metrics() destination.Metrics {
Expand Down
135 changes: 60 additions & 75 deletions plugins/destination/postgresql/client/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package client
import (
"context"
"fmt"
"strconv"
"strings"

"github.com/apache/arrow/go/v13/arrow"
"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/schema"
"github.com/cloudquery/plugin-sdk/v3/schema"
"github.com/jackc/pgx/v5"
)

Expand Down Expand Up @@ -94,10 +92,8 @@ ORDER BY
`
)

func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Schemas) (schema.Schemas, error) {
var tables schema.Schemas
var fields []arrow.Field
tableMetaData := make(map[string]string)
func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Tables) (schema.Tables, error) {
var tables schema.Tables
sql := selectAllTables
if c.pgType == pgTypeCockroachDB {
sql = selectAllTablesCockroach
Expand All @@ -115,67 +111,52 @@ func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Schemas)
return nil, err
}
// We don't want to migrate tables that are not a part of the spec, or non CloudQuery tables
if pluginTables.SchemaByName(tableName) == nil {
if pluginTables.Get(tableName) == nil {
continue
}
if ordinalPosition == 1 {
if fields != nil {
md := arrow.MetadataFrom(tableMetaData)
tables = append(tables, arrow.NewSchema(fields, &md))
fields = nil
tableMetaData = make(map[string]string, 0)
}
tableMetaData[schema.MetadataTableName] = tableName
tables = append(tables, &schema.Table{
Name: tableName,
})
}
table := tables[len(tables)-1]
if pkName != "" {
tableMetaData[schema.MetadataConstraintName] = pkName
table.PkConstraintName = pkName
}
schemaType := c.PgToSchemaType(columnType)
fields = append(fields, arrow.Field{
Name: columnName,
Type: schemaType,
Nullable: !notNull,
Metadata: arrow.MetadataFrom(map[string]string{
schema.MetadataPrimaryKey: strconv.FormatBool(isPrimaryKey),
}),
table.Columns = append(table.Columns, schema.Column{
Name: columnName,
Type: schemaType,
PrimaryKey: isPrimaryKey,
NotNull: notNull,
})
}
if fields != nil {
md := arrow.MetadataFrom(tableMetaData)
tables = append(tables, arrow.NewSchema(fields, &md))
}
return tables, nil
}

func (c *Client) normalizeTable(table *arrow.Schema, pgTable *arrow.Schema) *arrow.Schema {
fields := make([]arrow.Field, len(table.Fields()))
for i, f := range table.Fields() {
metadata := make(map[string]string, 0)
if c.enabledPks() && schema.IsPk(f) {
metadata[schema.MetadataPrimaryKey] = schema.MetadataTrue
f.Nullable = false
func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *schema.Table {
normalizedTable := schema.Table{
Name: table.Name,
}
for _, col := range table.Columns {
if c.enabledPks() && col.PrimaryKey {
col.NotNull = true
} else {
metadata[schema.MetadataPrimaryKey] = schema.MetadataFalse
col.PrimaryKey = false
}

f.Metadata = arrow.MetadataFrom(metadata)
f.Type = c.PgToSchemaType(c.SchemaTypeToPg(f.Type))
fields[i] = f
col.Type = c.PgToSchemaType(c.SchemaTypeToPg(col.Type))
normalizedTable.Columns = append(normalizedTable.Columns, col)
}
mdMap := make(map[string]string)
if pgTable != nil {
mdMap[schema.MetadataTableName] = schema.TableName(pgTable)
if constraintName, ok := pgTable.Metadata().GetValue(schema.MetadataConstraintName); ok {
mdMap[schema.MetadataConstraintName] = constraintName
}

if pgTable != nil && pgTable.PkConstraintName != "" {
normalizedTable.PkConstraintName = pgTable.PkConstraintName
}
mdMap[schema.MetadataTableName] = schema.TableName(table)
md := arrow.MetadataFrom(mdMap)
return arrow.NewSchema(fields, &md)

return &normalizedTable
}

func (c *Client) autoMigrateTable(ctx context.Context, table *arrow.Schema, changes []schema.FieldChange) error {
tableName := schema.TableName(table)
func (c *Client) autoMigrateTable(ctx context.Context, table *schema.Table, changes []schema.TableColumnChange) error {
tableName := table.Name
for _, change := range changes {
switch change.Type {
case schema.TableColumnChangeTypeAdd:
Expand All @@ -191,15 +172,15 @@ func (c *Client) autoMigrateTable(ctx context.Context, table *arrow.Schema, chan
return nil
}

func (*Client) canAutoMigrate(changes []schema.FieldChange) bool {
func (*Client) canAutoMigrate(changes []schema.TableColumnChange) bool {
for _, change := range changes {
switch change.Type {
case schema.TableColumnChangeTypeAdd:
if schema.IsPk(change.Current) || !change.Current.Nullable {
if change.Current.PrimaryKey || change.Current.NotNull {
return false
}
case schema.TableColumnChangeTypeRemove:
if schema.IsPk(change.Previous) || !change.Previous.Nullable {
if change.Previous.PrimaryKey || change.Previous.NotNull {
return false
}
case schema.TableColumnChangeTypeUpdate:
Expand All @@ -212,34 +193,38 @@ func (*Client) canAutoMigrate(changes []schema.FieldChange) bool {
}

// normalize the requested schema to be compatible with what Postgres supports
func (c *Client) normalizeTables(tables schema.Schemas, pgTables schema.Schemas) schema.Schemas {
var result schema.Schemas
func (c *Client) normalizeTables(tables schema.Tables, pgTables schema.Tables) schema.Tables {
var result schema.Tables
for _, table := range tables {
pgTabe := pgTables.SchemaByName(schema.TableName(table))
result = append(result, c.normalizeTable(table, pgTabe))
pgTable := pgTables.Get(table.Name)
if pgTable == nil {
result = append(result, table)
} else {
result = append(result, c.normalizeTable(table, pgTable))
}
}
return result
}

func (c *Client) nonAutoMigrableTables(tables schema.Schemas, pgTables schema.Schemas) ([]string, [][]schema.FieldChange) {
func (c *Client) nonAutoMigrableTables(tables schema.Tables, pgTables schema.Tables) ([]string, [][]schema.TableColumnChange) {
var result []string
var tableChanges [][]schema.FieldChange
var tableChanges [][]schema.TableColumnChange
for _, t := range tables {
pgTable := pgTables.SchemaByName(schema.TableName(t))
pgTable := pgTables.Get(t.Name)
if pgTable == nil {
continue
}
changes := schema.GetSchemaChanges(t, pgTable)
changes := t.GetChanges(pgTable)
if !c.canAutoMigrate(changes) {
result = append(result, schema.TableName(t))
result = append(result, t.Name)
tableChanges = append(tableChanges, changes)
}
}
return result, tableChanges
}

// This is the responsibility of the CLI of the client to lock before running migration
func (c *Client) Migrate(ctx context.Context, tables schema.Schemas) error {
func (c *Client) Migrate(ctx context.Context, tables schema.Tables) error {
pgTables, err := c.listPgTables(ctx, tables)
if err != nil {
return fmt.Errorf("failed listing postgres tables: %w", err)
Expand All @@ -253,20 +238,20 @@ func (c *Client) Migrate(ctx context.Context, tables schema.Schemas) error {
}

for _, table := range tables {
tableName := schema.TableName(table)
tableName := table.Name
c.logger.Info().Str("table", tableName).Msg("Migrating table")
if len(table.Fields()) == 0 {
if len(table.Columns) == 0 {
c.logger.Info().Str("table", tableName).Msg("Table with no columns, skipping")
continue
}
pgTable := pgTables.SchemaByName(tableName)
pgTable := pgTables.Get(tableName)
if pgTable == nil {
c.logger.Debug().Str("table", tableName).Msg("Table doesn't exist, creating")
if err := c.createTableIfNotExist(ctx, table); err != nil {
return err
}
} else {
changes := schema.GetSchemaChanges(table, pgTable)
changes := table.GetChanges(pgTable)
if c.canAutoMigrate(changes) {
c.logger.Info().Str("table", tableName).Msg("Table exists, auto-migrating")
if err := c.autoMigrateTable(ctx, table, changes); err != nil {
Expand Down Expand Up @@ -303,7 +288,7 @@ func (c *Client) dropTable(ctx context.Context, tableName string) error {
return nil
}

func (c *Client) addColumn(ctx context.Context, tableName string, column arrow.Field) error {
func (c *Client) addColumn(ctx context.Context, tableName string, column schema.Column) error {
c.logger.Info().Str("table", tableName).Str("column", column.Name).Msg("Column doesn't exist, creating")
columnName := pgx.Identifier{column.Name}.Sanitize()
columnType := c.SchemaTypeToPg(column.Type)
Expand All @@ -314,31 +299,31 @@ func (c *Client) addColumn(ctx context.Context, tableName string, column arrow.F
return nil
}

func (c *Client) createTableIfNotExist(ctx context.Context, table *arrow.Schema) error {
func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) error {
var sb strings.Builder
tName := schema.TableName(table)
tName := table.Name
tableName := pgx.Identifier{tName}.Sanitize()
sb.WriteString("CREATE TABLE IF NOT EXISTS ")
sb.WriteString(tableName)
sb.WriteString(" (")
totalColumns := len(table.Fields())
totalColumns := len(table.Columns)

primaryKeys := []string{}
for i, col := range table.Fields() {
for i, col := range table.Columns {
pgType := c.SchemaTypeToPg(col.Type)
columnName := pgx.Identifier{col.Name}.Sanitize()
fieldDef := columnName + " " + pgType
if schema.IsUnique(col) {
if col.Unique {
fieldDef += " UNIQUE"
}
if !col.Nullable {
if col.NotNull {
fieldDef += " NOT NULL"
}
sb.WriteString(fieldDef)
if i != totalColumns-1 {
sb.WriteString(",")
}
if c.enabledPks() && schema.IsPk(col) {
if c.enabledPks() && col.PrimaryKey {
primaryKeys = append(primaryKeys, pgx.Identifier{col.Name}.Sanitize())
}
}
Expand Down
Loading