diff --git a/plugins/destination/postgresql/client/client.go b/plugins/destination/postgresql/client/client.go index 875b513fd36009..4228c61453d5d7 100644 --- a/plugins/destination/postgresql/client/client.go +++ b/plugins/destination/postgresql/client/client.go @@ -25,6 +25,8 @@ type Client struct { batchSize int writer *mixedbatchwriter.MixedBatchWriter + pgTablesToPKConstraints map[string]string + plugin.UnimplementedSource } @@ -41,7 +43,8 @@ const ( func New(ctx context.Context, logger zerolog.Logger, specBytes []byte, opts plugin.NewClientOptions) (plugin.Client, error) { c := &Client{ - logger: logger.With().Str("module", "pg-dest").Logger(), + logger: logger.With().Str("module", "pg-dest").Logger(), + pgTablesToPKConstraints: make(map[string]string), } if opts.NoConnection { return c, nil diff --git a/plugins/destination/postgresql/client/insert.go b/plugins/destination/postgresql/client/insert.go index 7b00128a1a5ec9..5cdd9231955849 100644 --- a/plugins/destination/postgresql/client/insert.go +++ b/plugins/destination/postgresql/client/insert.go @@ -20,16 +20,7 @@ func (c *Client) InsertBatch(ctx context.Context, messages message.WriteInserts) return err } - include := make([]string, len(tables)) - for i, table := range tables { - include[i] = table.Name - } - var exclude []string - pgTables, err := c.listTables(ctx, include, exclude) - if err != nil { - return err - } - tables = c.normalizeTables(tables, pgTables) + tables = c.normalizeTables(tables) if err != nil { return err } diff --git a/plugins/destination/postgresql/client/list_tables.go b/plugins/destination/postgresql/client/list_tables.go index b9139622a97526..ffb0399c2595bd 100644 --- a/plugins/destination/postgresql/client/list_tables.go +++ b/plugins/destination/postgresql/client/list_tables.go @@ -89,7 +89,7 @@ func (c *Client) listTables(ctx context.Context, include, exclude []string) (sch } table := tables[len(tables)-1] if pkName != "" { - table.PkConstraintName = pkName + c.pgTablesToPKConstraints[tableName], table.PkConstraintName = pkName, pkName } table.Columns = append(table.Columns, schema.Column{ Name: columnName, diff --git a/plugins/destination/postgresql/client/migrate.go b/plugins/destination/postgresql/client/migrate.go index 8dc95ae2dd9c10..c70a9eded2b0f6 100644 --- a/plugins/destination/postgresql/client/migrate.go +++ b/plugins/destination/postgresql/client/migrate.go @@ -25,7 +25,7 @@ func (c *Client) MigrateTableBatch(ctx context.Context, messages message.WriteMi if err != nil { return fmt.Errorf("failed listing postgres tables: %w", err) } - tables = c.normalizeTables(tables, pgTables) + tables = c.normalizeTables(tables) safeTables := map[string]bool{} for _, msg := range messages { @@ -80,7 +80,7 @@ func (c *Client) MigrateTableBatch(ctx context.Context, messages message.WriteMi return nil } -func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *schema.Table { +func (c *Client) normalizeTable(table *schema.Table) *schema.Table { normalizedTable := schema.Table{ Name: table.Name, } @@ -90,10 +90,8 @@ func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *sch } col.Type = c.PgToSchemaType(c.SchemaTypeToPg(col.Type)) normalizedTable.Columns = append(normalizedTable.Columns, col) - } - - if pgTable != nil && pgTable.PkConstraintName != "" { - normalizedTable.PkConstraintName = pgTable.PkConstraintName + // pgTablesToPKConstraints is populated when handling migrate messages + normalizedTable.PkConstraintName = c.pgTablesToPKConstraints[table.Name] } return &normalizedTable @@ -142,10 +140,10 @@ func (*Client) canAutoMigrate(changes []schema.TableColumnChange) bool { } // normalize the requested schema to be compatible with what Postgres supports -func (c *Client) normalizeTables(tables schema.Tables, pgTables schema.Tables) schema.Tables { +func (c *Client) normalizeTables(tables schema.Tables) schema.Tables { result := make(schema.Tables, len(tables)) for i, table := range tables { - result[i] = c.normalizeTable(table, pgTables.Get(table.Name)) + result[i] = c.normalizeTable(table) } return result } @@ -189,10 +187,10 @@ func (c *Client) addColumn(ctx context.Context, tableName string, column schema. func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) error { var sb strings.Builder - tName := table.Name - tableName := pgx.Identifier{tName}.Sanitize() + tableName := table.Name + sanitizedTableName := pgx.Identifier{tableName}.Sanitize() sb.WriteString("CREATE TABLE IF NOT EXISTS ") - sb.WriteString(tableName) + sb.WriteString(sanitizedTableName) sb.WriteString(" (") totalColumns := len(table.Columns) @@ -216,10 +214,13 @@ func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) } } + pkConstraintName := tableName + "_cqpk" + c.pgTablesToPKConstraints[tableName] = pkConstraintName + if len(primaryKeys) > 0 { // add composite PK constraint on primary key columns sb.WriteString(", CONSTRAINT ") - sb.WriteString(pgx.Identifier{tName + "_cqpk"}.Sanitize()) + sb.WriteString(pgx.Identifier{pkConstraintName}.Sanitize()) sb.WriteString(" PRIMARY KEY (") sb.WriteString(strings.Join(primaryKeys, ",")) sb.WriteString(")") @@ -227,8 +228,8 @@ func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) sb.WriteString(")") _, err := c.conn.Exec(ctx, sb.String()) if err != nil { - c.logger.Error().Err(err).Str("table", tName).Str("query", sb.String()).Msg("Failed to create table") - return fmt.Errorf("failed to create table %s: %w"+sb.String(), tName, err) + c.logger.Error().Err(err).Str("table", tableName).Str("query", sb.String()).Msg("Failed to create table") + return fmt.Errorf("failed to create table %s: %w"+sb.String(), tableName, err) } return nil }