Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion plugins/destination/postgresql/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Client struct {
batchSize int
writer *mixedbatchwriter.MixedBatchWriter

pgTablesToPKConstraints map[string]string

plugin.UnimplementedSource
}

Expand All @@ -41,7 +43,8 @@ const (

func New(ctx context.Context, logger zerolog.Logger, specBytes []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
c := &Client{
logger: logger.With().Str("module", "pg-dest").Logger(),
logger: logger.With().Str("module", "pg-dest").Logger(),
pgTablesToPKConstraints: make(map[string]string),
}
if opts.NoConnection {
return c, nil
Expand Down
11 changes: 1 addition & 10 deletions plugins/destination/postgresql/client/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,7 @@ func (c *Client) InsertBatch(ctx context.Context, messages message.WriteInserts)
return err
}

include := make([]string, len(tables))
for i, table := range tables {
include[i] = table.Name
}
var exclude []string
pgTables, err := c.listTables(ctx, include, exclude)
if err != nil {
return err
}
tables = c.normalizeTables(tables, pgTables)
tables = c.normalizeTables(tables)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion plugins/destination/postgresql/client/list_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (c *Client) listTables(ctx context.Context, include, exclude []string) (sch
}
table := tables[len(tables)-1]
if pkName != "" {
table.PkConstraintName = pkName
c.pgTablesToPKConstraints[tableName], table.PkConstraintName = pkName, pkName
}
table.Columns = append(table.Columns, schema.Column{
Name: columnName,
Expand Down
29 changes: 15 additions & 14 deletions plugins/destination/postgresql/client/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (c *Client) MigrateTableBatch(ctx context.Context, messages message.WriteMi
if err != nil {
return fmt.Errorf("failed listing postgres tables: %w", err)
}
tables = c.normalizeTables(tables, pgTables)
tables = c.normalizeTables(tables)

safeTables := map[string]bool{}
for _, msg := range messages {
Expand Down Expand Up @@ -80,7 +80,7 @@ func (c *Client) MigrateTableBatch(ctx context.Context, messages message.WriteMi
return nil
}

func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *schema.Table {
func (c *Client) normalizeTable(table *schema.Table) *schema.Table {
normalizedTable := schema.Table{
Name: table.Name,
}
Expand All @@ -90,10 +90,8 @@ func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *sch
}
col.Type = c.PgToSchemaType(c.SchemaTypeToPg(col.Type))
normalizedTable.Columns = append(normalizedTable.Columns, col)
}

if pgTable != nil && pgTable.PkConstraintName != "" {
normalizedTable.PkConstraintName = pgTable.PkConstraintName
// pgTablesToPKConstraints is populated when handling migrate messages
normalizedTable.PkConstraintName = c.pgTablesToPKConstraints[table.Name]
}

return &normalizedTable
Expand Down Expand Up @@ -142,10 +140,10 @@ func (*Client) canAutoMigrate(changes []schema.TableColumnChange) bool {
}

// normalize the requested schema to be compatible with what Postgres supports
func (c *Client) normalizeTables(tables schema.Tables, pgTables schema.Tables) schema.Tables {
func (c *Client) normalizeTables(tables schema.Tables) schema.Tables {
result := make(schema.Tables, len(tables))
for i, table := range tables {
result[i] = c.normalizeTable(table, pgTables.Get(table.Name))
result[i] = c.normalizeTable(table)
}
return result
}
Expand Down Expand Up @@ -189,10 +187,10 @@ func (c *Client) addColumn(ctx context.Context, tableName string, column schema.

func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) error {
var sb strings.Builder
tName := table.Name
tableName := pgx.Identifier{tName}.Sanitize()
tableName := table.Name
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed these since I initially used tableName for the dictionary key which is wrong since it had the sanitized value

sanitizedTableName := pgx.Identifier{tableName}.Sanitize()
sb.WriteString("CREATE TABLE IF NOT EXISTS ")
sb.WriteString(tableName)
sb.WriteString(sanitizedTableName)
sb.WriteString(" (")
totalColumns := len(table.Columns)

Expand All @@ -216,19 +214,22 @@ func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table)
}
}

pkConstraintName := tableName + "_cqpk"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't it actually use the c.pgTablesToPKConstraints value if that's available?

Copy link
Copy Markdown
Member Author

@erezrokah erezrokah Aug 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is inside createTableIfNotExist so there should not be any value available

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, but what about table.PkConstraintName?

c.pgTablesToPKConstraints[tableName] = pkConstraintName

if len(primaryKeys) > 0 {
// add composite PK constraint on primary key columns
sb.WriteString(", CONSTRAINT ")
sb.WriteString(pgx.Identifier{tName + "_cqpk"}.Sanitize())
sb.WriteString(pgx.Identifier{pkConstraintName}.Sanitize())
sb.WriteString(" PRIMARY KEY (")
sb.WriteString(strings.Join(primaryKeys, ","))
sb.WriteString(")")
}
sb.WriteString(")")
_, err := c.conn.Exec(ctx, sb.String())
if err != nil {
c.logger.Error().Err(err).Str("table", tName).Str("query", sb.String()).Msg("Failed to create table")
return fmt.Errorf("failed to create table %s: %w"+sb.String(), tName, err)
c.logger.Error().Err(err).Str("table", tableName).Str("query", sb.String()).Msg("Failed to create table")
return fmt.Errorf("failed to create table %s: %w"+sb.String(), tableName, err)
}
return nil
}