From 04c8662151dc137d9c7c4e6f568a88811e5b0ccd Mon Sep 17 00:00:00 2001 From: shimonp21 Date: Thu, 29 Sep 2022 12:10:27 +0300 Subject: [PATCH] fix: Upsert in postgresql The previous upsert query just replaced the values with the values that were already in the table... i.e. the query was: DO UPDATE SET "column"=table."column" which of course does nothing. The new query is: "column"=excluded."column", which does what we want. --- .../postgresql/client/postgresql.go | 10 +- .../postgresql/client/postgresql_test.go | 115 ++++++++++++++++-- 2 files changed, 112 insertions(+), 13 deletions(-) diff --git a/plugins/destination/postgresql/client/postgresql.go b/plugins/destination/postgresql/client/postgresql.go index 3fff2dbfabd310..371a18bc9bfc1a 100644 --- a/plugins/destination/postgresql/client/postgresql.go +++ b/plugins/destination/postgresql/client/postgresql.go @@ -360,12 +360,10 @@ func upsert(table string, data map[string]interface{}) (string, []interface{}) { sb.WriteString(" on conflict on constraint ") sb.WriteString(constraintName) sb.WriteString(" do update set ") - for i, c := range columns { - sb.WriteString(c) - sb.WriteString("=") - sb.WriteString(table) - sb.WriteString(".") - sb.WriteString(c) + for i, column := range columns { + sb.WriteString(column) + sb.WriteString("=excluded.") // excluded references the new values + sb.WriteString(column) if i < len(columns)-1 { sb.WriteString(",") } else { diff --git a/plugins/destination/postgresql/client/postgresql_test.go b/plugins/destination/postgresql/client/postgresql_test.go index 0afa68d1c92ac0..5f8e7da0ae0833 100644 --- a/plugins/destination/postgresql/client/postgresql_test.go +++ b/plugins/destination/postgresql/client/postgresql_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "fmt" "os" "testing" "time" @@ -94,22 +95,36 @@ var createTablesTests = []*schema.Table{ }, } -func TestPostgreSqlCreateTables(t *testing.T) { +// Initializes a postgres client at "postgres://postgres:pass@localhost:5432/postgres" +func newLocalhostPostgresClient(ctx context.Context, t *testing.T) (*Client, error) { zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs - l := zerolog.New(zerolog.NewTestWriter(t)).Output( + logger := zerolog.New(zerolog.NewTestWriter(t)).Output( zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.StampMicro}, ).Level(zerolog.DebugLevel).With().Timestamp().Logger() - ctx := context.Background() - c := New() - c.SetLogger(l) - if err := c.Initialize(ctx, + + client := New() + client.SetLogger(logger) + + err := client.Initialize(ctx, specs.Destination{ Spec: &PostgreSqlSpec{ ConnectionString: "postgres://postgres:pass@localhost:5432/postgres", PgxLogLevel: LogLevelInfo, }, }, - ); err != nil { + ) + + if err != nil { + return nil, err + } + + return client, nil +} + +func TestPostgreSqlCreateTables(t *testing.T) { + ctx := context.Background() + c, err := newLocalhostPostgresClient(ctx, t) + if err != nil { t.Fatalf("failed to initialize client: %v", err) } @@ -186,3 +201,89 @@ func TestPostgreSqlCreateTables(t *testing.T) { t.Fatal(diff) } } + +func TestUpdate(t *testing.T) { + ctx := context.Background() + client, err := newLocalhostPostgresClient(ctx, t) + if err != nil { + t.Fatalf("failed to initialize client: %v", err) + } + + if err := client.Drop(ctx, createTablesTests); err != nil { + t.Fatalf("failed to drop tables: %v", err) + } + + if err := client.Migrate(ctx, createTablesTests); err != nil { + t.Fatalf("failed to migrate tables: %v", err) + } + + data := map[string]interface{}{ + "id": "9a6011b7-c5ee-4b55-95a6-37ce5e02a5a0", + "bool_column": true, + "int_column": float64(3), + "float_column": float64(3.3), + "uuid_column": "9a6011b7-c5ee-4b55-95a6-37ce5e02a5a0", + "string_column": "test", + "string_array_column": []interface{}{"test", "test2"}, + "int_array_column": []interface{}{float64(1), float64(2), float64(3)}, + "timestamp_column": "2019-01-01T00:00:00", + "json_column": map[string]interface{}{"1": float64(1), "test": "test"}, + "uuid_array_column": []interface{}{"1a6011b7-c5ee-4b55-95a6-37ce5e02a5a0", "9a6011b7-c5ee-4b55-95a6-37ce5e02a5a0"}, + "inet_column": "1.1.1.1", + "inet_array_column": []interface{}{"8.8.8.8/0"}, + "cidr_column": "0.0.0.0/24", + "cidr_array_column": []interface{}{"0.0.0.0/24", "0.0.0.0/16"}, + "mac_addr_column": "00:00:00:00:00:ab", + } + + if err := client.Write(ctx, "simple_table", data); err != nil { + t.Fatalf("failed to write data: %v", err) + } + + intColumn, err := getIntColumn(ctx, client) + if err != nil { + t.Fatal(err) + } + if intColumn != 3 { + t.Fatal("expected int_column to be 3, got", intColumn) + } + + // Update `int_column` to be 5, and make sure it's changed in the database. + data["int_column"] = float64(5) + + if err := client.Write(ctx, "simple_table", data); err != nil { + t.Fatalf("failed to write data: %v", err) + } + + intColumn, err = getIntColumn(ctx, client) + if err != nil { + t.Fatal(err) + } + if intColumn != 5 { + t.Fatal("expected int_column to be 5, got", intColumn) + } +} + +// Returns the value of "simple_table.int_column". +// Makes sure there is only one result in the table. +func getIntColumn(ctx context.Context, client *Client) (int, error) { + rows, err := client.conn.Query(ctx, "SELECT int_column FROM simple_table") + if err != nil { + return 0, err + } + + var intColumn int + totalResults := 0 + for rows.Next() { + if err := rows.Scan(&intColumn); err != nil { + return 0, err + } + + totalResults++ + } + if totalResults != 1 { + return 0, fmt.Errorf("expected 1 result, got %d", totalResults) + } + + return intColumn, nil +}