diff --git a/plugins/destination/postgresql/client/client.go b/plugins/destination/postgresql/client/client.go index 7599433673d822..1ac954b4a0c8e7 100644 --- a/plugins/destination/postgresql/client/client.go +++ b/plugins/destination/postgresql/client/client.go @@ -6,7 +6,7 @@ import ( "strings" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v2/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" pgx_zero_log "github.com/jackc/pgx-zerolog" "github.com/jackc/pgx/v5" diff --git a/plugins/destination/postgresql/client/client_test.go b/plugins/destination/postgresql/client/client_test.go index 19a872dad5dff8..5f0a4cdc1f9f9d 100644 --- a/plugins/destination/postgresql/client/client_test.go +++ b/plugins/destination/postgresql/client/client_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v2/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" ) func getTestConnection() string { diff --git a/plugins/destination/postgresql/client/deletestale.go b/plugins/destination/postgresql/client/deletestale.go index 320d6ec52bb974..edad26fead5a9d 100644 --- a/plugins/destination/postgresql/client/deletestale.go +++ b/plugins/destination/postgresql/client/deletestale.go @@ -6,16 +6,16 @@ import ( "strings" "time" - "github.com/cloudquery/plugin-sdk/v2/schema" + "github.com/cloudquery/plugin-sdk/v3/schema" "github.com/jackc/pgx/v5" ) -func (c *Client) DeleteStale(ctx context.Context, tables schema.Schemas, source string, syncTime time.Time) error { +func (c *Client) DeleteStale(ctx context.Context, tables schema.Tables, source string, syncTime time.Time) error { batch := &pgx.Batch{} for _, table := range tables { var sb strings.Builder sb.WriteString("delete from ") - sb.WriteString(pgx.Identifier{schema.TableName(table)}.Sanitize()) + sb.WriteString(pgx.Identifier{table.Name}.Sanitize()) sb.WriteString(" where ") sb.WriteString(schema.CqSourceNameColumn.Name) sb.WriteString(" = $1 and ") diff --git a/plugins/destination/postgresql/client/metrics.go b/plugins/destination/postgresql/client/metrics.go index 6e0bc82117c87a..0abe9f700b726d 100644 --- a/plugins/destination/postgresql/client/metrics.go +++ b/plugins/destination/postgresql/client/metrics.go @@ -1,7 +1,7 @@ package client import ( - "github.com/cloudquery/plugin-sdk/v2/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" ) func (c *Client) Metrics() destination.Metrics { diff --git a/plugins/destination/postgresql/client/migrate.go b/plugins/destination/postgresql/client/migrate.go index 2fe0654047cdab..a13713d8a0437f 100644 --- a/plugins/destination/postgresql/client/migrate.go +++ b/plugins/destination/postgresql/client/migrate.go @@ -3,12 +3,10 @@ package client import ( "context" "fmt" - "strconv" "strings" - "github.com/apache/arrow/go/v13/arrow" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v2/schema" + "github.com/cloudquery/plugin-sdk/v3/schema" "github.com/jackc/pgx/v5" ) @@ -94,10 +92,8 @@ ORDER BY ` ) -func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Schemas) (schema.Schemas, error) { - var tables schema.Schemas - var fields []arrow.Field - tableMetaData := make(map[string]string) +func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Tables) (schema.Tables, error) { + var tables schema.Tables sql := selectAllTables if c.pgType == pgTypeCockroachDB { sql = selectAllTablesCockroach @@ -115,67 +111,52 @@ func (c *Client) listPgTables(ctx context.Context, pluginTables schema.Schemas) return nil, err } // We don't want to migrate tables that are not a part of the spec, or non CloudQuery tables - if pluginTables.SchemaByName(tableName) == nil { + if pluginTables.Get(tableName) == nil { continue } if ordinalPosition == 1 { - if fields != nil { - md := arrow.MetadataFrom(tableMetaData) - tables = append(tables, arrow.NewSchema(fields, &md)) - fields = nil - tableMetaData = make(map[string]string, 0) - } - tableMetaData[schema.MetadataTableName] = tableName + tables = append(tables, &schema.Table{ + Name: tableName, + }) } + table := tables[len(tables)-1] if pkName != "" { - tableMetaData[schema.MetadataConstraintName] = pkName + table.PkConstraintName = pkName } schemaType := c.PgToSchemaType(columnType) - fields = append(fields, arrow.Field{ - Name: columnName, - Type: schemaType, - Nullable: !notNull, - Metadata: arrow.MetadataFrom(map[string]string{ - schema.MetadataPrimaryKey: strconv.FormatBool(isPrimaryKey), - }), + table.Columns = append(table.Columns, schema.Column{ + Name: columnName, + Type: schemaType, + PrimaryKey: isPrimaryKey, + NotNull: notNull, }) } - if fields != nil { - md := arrow.MetadataFrom(tableMetaData) - tables = append(tables, arrow.NewSchema(fields, &md)) - } return tables, nil } -func (c *Client) normalizeTable(table *arrow.Schema, pgTable *arrow.Schema) *arrow.Schema { - fields := make([]arrow.Field, len(table.Fields())) - for i, f := range table.Fields() { - metadata := make(map[string]string, 0) - if c.enabledPks() && schema.IsPk(f) { - metadata[schema.MetadataPrimaryKey] = schema.MetadataTrue - f.Nullable = false +func (c *Client) normalizeTable(table *schema.Table, pgTable *schema.Table) *schema.Table { + normalizedTable := schema.Table{ + Name: table.Name, + } + for _, col := range table.Columns { + if c.enabledPks() && col.PrimaryKey { + col.NotNull = true } else { - metadata[schema.MetadataPrimaryKey] = schema.MetadataFalse + col.PrimaryKey = false } - - f.Metadata = arrow.MetadataFrom(metadata) - f.Type = c.PgToSchemaType(c.SchemaTypeToPg(f.Type)) - fields[i] = f + col.Type = c.PgToSchemaType(c.SchemaTypeToPg(col.Type)) + normalizedTable.Columns = append(normalizedTable.Columns, col) } - mdMap := make(map[string]string) - if pgTable != nil { - mdMap[schema.MetadataTableName] = schema.TableName(pgTable) - if constraintName, ok := pgTable.Metadata().GetValue(schema.MetadataConstraintName); ok { - mdMap[schema.MetadataConstraintName] = constraintName - } + + if pgTable != nil && pgTable.PkConstraintName != "" { + normalizedTable.PkConstraintName = pgTable.PkConstraintName } - mdMap[schema.MetadataTableName] = schema.TableName(table) - md := arrow.MetadataFrom(mdMap) - return arrow.NewSchema(fields, &md) + + return &normalizedTable } -func (c *Client) autoMigrateTable(ctx context.Context, table *arrow.Schema, changes []schema.FieldChange) error { - tableName := schema.TableName(table) +func (c *Client) autoMigrateTable(ctx context.Context, table *schema.Table, changes []schema.TableColumnChange) error { + tableName := table.Name for _, change := range changes { switch change.Type { case schema.TableColumnChangeTypeAdd: @@ -191,15 +172,15 @@ func (c *Client) autoMigrateTable(ctx context.Context, table *arrow.Schema, chan return nil } -func (*Client) canAutoMigrate(changes []schema.FieldChange) bool { +func (*Client) canAutoMigrate(changes []schema.TableColumnChange) bool { for _, change := range changes { switch change.Type { case schema.TableColumnChangeTypeAdd: - if schema.IsPk(change.Current) || !change.Current.Nullable { + if change.Current.PrimaryKey || change.Current.NotNull { return false } case schema.TableColumnChangeTypeRemove: - if schema.IsPk(change.Previous) || !change.Previous.Nullable { + if change.Previous.PrimaryKey || change.Previous.NotNull { return false } case schema.TableColumnChangeTypeUpdate: @@ -212,26 +193,30 @@ func (*Client) canAutoMigrate(changes []schema.FieldChange) bool { } // normalize the requested schema to be compatible with what Postgres supports -func (c *Client) normalizeTables(tables schema.Schemas, pgTables schema.Schemas) schema.Schemas { - var result schema.Schemas +func (c *Client) normalizeTables(tables schema.Tables, pgTables schema.Tables) schema.Tables { + var result schema.Tables for _, table := range tables { - pgTabe := pgTables.SchemaByName(schema.TableName(table)) - result = append(result, c.normalizeTable(table, pgTabe)) + pgTable := pgTables.Get(table.Name) + if pgTable == nil { + result = append(result, table) + } else { + result = append(result, c.normalizeTable(table, pgTable)) + } } return result } -func (c *Client) nonAutoMigrableTables(tables schema.Schemas, pgTables schema.Schemas) ([]string, [][]schema.FieldChange) { +func (c *Client) nonAutoMigrableTables(tables schema.Tables, pgTables schema.Tables) ([]string, [][]schema.TableColumnChange) { var result []string - var tableChanges [][]schema.FieldChange + var tableChanges [][]schema.TableColumnChange for _, t := range tables { - pgTable := pgTables.SchemaByName(schema.TableName(t)) + pgTable := pgTables.Get(t.Name) if pgTable == nil { continue } - changes := schema.GetSchemaChanges(t, pgTable) + changes := t.GetChanges(pgTable) if !c.canAutoMigrate(changes) { - result = append(result, schema.TableName(t)) + result = append(result, t.Name) tableChanges = append(tableChanges, changes) } } @@ -239,7 +224,7 @@ func (c *Client) nonAutoMigrableTables(tables schema.Schemas, pgTables schema.Sc } // This is the responsibility of the CLI of the client to lock before running migration -func (c *Client) Migrate(ctx context.Context, tables schema.Schemas) error { +func (c *Client) Migrate(ctx context.Context, tables schema.Tables) error { pgTables, err := c.listPgTables(ctx, tables) if err != nil { return fmt.Errorf("failed listing postgres tables: %w", err) @@ -253,20 +238,20 @@ func (c *Client) Migrate(ctx context.Context, tables schema.Schemas) error { } for _, table := range tables { - tableName := schema.TableName(table) + tableName := table.Name c.logger.Info().Str("table", tableName).Msg("Migrating table") - if len(table.Fields()) == 0 { + if len(table.Columns) == 0 { c.logger.Info().Str("table", tableName).Msg("Table with no columns, skipping") continue } - pgTable := pgTables.SchemaByName(tableName) + pgTable := pgTables.Get(tableName) if pgTable == nil { c.logger.Debug().Str("table", tableName).Msg("Table doesn't exist, creating") if err := c.createTableIfNotExist(ctx, table); err != nil { return err } } else { - changes := schema.GetSchemaChanges(table, pgTable) + changes := table.GetChanges(pgTable) if c.canAutoMigrate(changes) { c.logger.Info().Str("table", tableName).Msg("Table exists, auto-migrating") if err := c.autoMigrateTable(ctx, table, changes); err != nil { @@ -303,7 +288,7 @@ func (c *Client) dropTable(ctx context.Context, tableName string) error { return nil } -func (c *Client) addColumn(ctx context.Context, tableName string, column arrow.Field) error { +func (c *Client) addColumn(ctx context.Context, tableName string, column schema.Column) error { c.logger.Info().Str("table", tableName).Str("column", column.Name).Msg("Column doesn't exist, creating") columnName := pgx.Identifier{column.Name}.Sanitize() columnType := c.SchemaTypeToPg(column.Type) @@ -314,31 +299,31 @@ func (c *Client) addColumn(ctx context.Context, tableName string, column arrow.F return nil } -func (c *Client) createTableIfNotExist(ctx context.Context, table *arrow.Schema) error { +func (c *Client) createTableIfNotExist(ctx context.Context, table *schema.Table) error { var sb strings.Builder - tName := schema.TableName(table) + tName := table.Name tableName := pgx.Identifier{tName}.Sanitize() sb.WriteString("CREATE TABLE IF NOT EXISTS ") sb.WriteString(tableName) sb.WriteString(" (") - totalColumns := len(table.Fields()) + totalColumns := len(table.Columns) primaryKeys := []string{} - for i, col := range table.Fields() { + for i, col := range table.Columns { pgType := c.SchemaTypeToPg(col.Type) columnName := pgx.Identifier{col.Name}.Sanitize() fieldDef := columnName + " " + pgType - if schema.IsUnique(col) { + if col.Unique { fieldDef += " UNIQUE" } - if !col.Nullable { + if col.NotNull { fieldDef += " NOT NULL" } sb.WriteString(fieldDef) if i != totalColumns-1 { sb.WriteString(",") } - if c.enabledPks() && schema.IsPk(col) { + if c.enabledPks() && col.PrimaryKey { primaryKeys = append(primaryKeys, pgx.Identifier{col.Name}.Sanitize()) } } diff --git a/plugins/destination/postgresql/client/read.go b/plugins/destination/postgresql/client/read.go index f35e8cd93826e9..0f9e195494aa3f 100644 --- a/plugins/destination/postgresql/client/read.go +++ b/plugins/destination/postgresql/client/read.go @@ -1,6 +1,7 @@ package client import ( + "bytes" "context" "fmt" "net" @@ -8,29 +9,35 @@ import ( "strings" "time" + "github.com/goccy/go-json" + "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/memory" - "github.com/cloudquery/plugin-sdk/v2/schema" - "github.com/cloudquery/plugin-sdk/v2/types" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v3/types" "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" ) const ( readSQL = "SELECT %s FROM %s WHERE _cq_source_name = $1 order by _cq_sync_time asc" ) +// nolint: dupl func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) error { if val == nil { bldr.AppendNull() return nil } + switch b := bldr.(type) { case *array.BooleanBuilder: b.Append(val.(bool)) case *array.Int8Builder: - b.Append(val.(int8)) + // pgx always return int16 for int8 + b.Append(int8(val.(int16))) case *array.Int16Builder: b.Append(val.(int16)) case *array.Int32Builder: @@ -38,13 +45,14 @@ func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) er case *array.Int64Builder: b.Append(val.(int64)) case *array.Uint8Builder: - b.Append(val.(uint8)) + b.Append(uint8(val.(int16))) case *array.Uint16Builder: - b.Append(val.(uint16)) + b.Append(uint16(val.(int32))) case *array.Uint32Builder: - b.Append(val.(uint32)) + b.Append(uint32(val.(int64))) case *array.Uint64Builder: - b.Append(val.(uint64)) + v := val.(pgtype.Numeric) + b.Append(v.Int.Uint64()) case *array.Float32Builder: b.Append(val.(float32)) case *array.Float64Builder: @@ -60,7 +68,18 @@ func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) er case *array.BinaryBuilder: b.Append(val.([]byte)) case *array.TimestampBuilder: - b.Append(arrow.Timestamp(val.(time.Time).UnixMicro())) + switch b.Type().(*arrow.TimestampType).Unit { + case arrow.Second: + b.Append(arrow.Timestamp(val.(time.Time).Unix())) + case arrow.Millisecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixMilli())) + case arrow.Microsecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixMicro())) + case arrow.Nanosecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixNano())) + default: + return fmt.Errorf("unsupported timestamp unit %s", f.Type.(*arrow.TimestampType).Unit) + } case *types.UUIDBuilder: va, ok := val.([16]byte) if !ok { @@ -73,6 +92,15 @@ func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) er b.Append(u) case *types.JSONBuilder: b.Append(val) + case *array.StructBuilder: + structBytes, err := json.Marshal(val) + if err != nil { + return fmt.Errorf("failed to marshal struct: %w", err) + } + dec := json.NewDecoder(bytes.NewReader(structBytes)) + if err := b.UnmarshalOne(dec); err != nil { + return fmt.Errorf("failed to unmarshal struct: %w", err) + } case *types.InetBuilder: if v, ok := val.(netip.Prefix); ok { _, ipnet, err := net.ParseCIDR(v.String()) @@ -83,7 +111,7 @@ func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) er return nil } b.Append(val.(*net.IPNet)) - case *types.MacBuilder: + case *types.MACBuilder: if c.pgType == pgTypePostgreSQL { b.Append(val.(net.HardwareAddr)) } else { @@ -113,24 +141,31 @@ func (c *Client) reverseTransform(f arrow.Field, bldr array.Builder, val any) er return nil } -func (c *Client) reverseTransformer(sc *arrow.Schema, values []any) (arrow.Record, error) { +func (c *Client) reverseTransformer(table *schema.Table, values []any) (arrow.Record, error) { + sc := table.ToArrowSchema() bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) for i, f := range sc.Fields() { - if err := c.reverseTransform(f, bldr.Field(i), values[i]); err != nil { - return nil, err + if c.pgType == pgTypePostgreSQL { + if err := c.reverseTransform(f, bldr.Field(i), values[i]); err != nil { + return nil, err + } + } else { + if err := c.reverseTransformCockroach(f, bldr.Field(i), values[i]); err != nil { + return nil, err + } } } rec := bldr.NewRecord() return rec, nil } -func (c *Client) Read(ctx context.Context, table *arrow.Schema, sourceName string, res chan<- arrow.Record) error { - colNames := make([]string, 0, len(table.Fields())) - for _, col := range table.Fields() { +func (c *Client) Read(ctx context.Context, table *schema.Table, sourceName string, res chan<- arrow.Record) error { + colNames := make([]string, 0, len(table.Columns)) + for _, col := range table.Columns { colNames = append(colNames, pgx.Identifier{col.Name}.Sanitize()) } cols := strings.Join(colNames, ",") - tableName := schema.TableName(table) + tableName := table.Name sql := fmt.Sprintf(readSQL, cols, pgx.Identifier{tableName}.Sanitize()) rows, err := c.conn.Query(ctx, sql, sourceName) if err != nil { diff --git a/plugins/destination/postgresql/client/read_cockroach.go b/plugins/destination/postgresql/client/read_cockroach.go new file mode 100644 index 00000000000000..2af4feea58235d --- /dev/null +++ b/plugins/destination/postgresql/client/read_cockroach.go @@ -0,0 +1,133 @@ +package client + +import ( + "bytes" + "fmt" + "net" + "net/netip" + "time" + + "github.com/goccy/go-json" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/google/uuid" +) + +// nolint:dupl +func (c *Client) reverseTransformCockroach(f arrow.Field, bldr array.Builder, val any) error { + if val == nil { + bldr.AppendNull() + return nil + } + + switch b := bldr.(type) { + case *array.BooleanBuilder: + b.Append(val.(bool)) + case *array.Int8Builder: + // pgx always return int16 for int8 + b.Append(int8(val.(int16))) + case *array.Int16Builder: + b.Append(val.(int16)) + case *array.Int32Builder: + b.Append(int32(val.(int64))) + case *array.Int64Builder: + b.Append(val.(int64)) + case *array.Uint8Builder: + b.Append(uint8(val.(int16))) + case *array.Uint16Builder: + b.Append(uint16(val.(int64))) + case *array.Uint32Builder: + b.Append(uint32(val.(int64))) + case *array.Uint64Builder: + v := val.(pgtype.Numeric) + b.Append(v.Int.Uint64()) + case *array.Float32Builder: + b.Append(val.(float32)) + case *array.Float64Builder: + b.Append(val.(float64)) + case *array.StringBuilder: + va, ok := val.(string) + if !ok { + return fmt.Errorf("unsupported type %T with builder %T and column %s", val, bldr, f.Name) + } + b.Append(va) + case *array.LargeStringBuilder: + b.Append(val.(string)) + case *array.BinaryBuilder: + b.Append(val.([]byte)) + case *array.TimestampBuilder: + switch b.Type().(*arrow.TimestampType).Unit { + case arrow.Second: + b.Append(arrow.Timestamp(val.(time.Time).Unix())) + case arrow.Millisecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixMilli())) + case arrow.Microsecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixMicro())) + case arrow.Nanosecond: + b.Append(arrow.Timestamp(val.(time.Time).UnixNano())) + default: + return fmt.Errorf("unsupported timestamp unit %s", f.Type.(*arrow.TimestampType).Unit) + } + case *types.UUIDBuilder: + va, ok := val.([16]byte) + if !ok { + return fmt.Errorf("unsupported type %T with builder %T", val, bldr) + } + u, err := uuid.FromBytes(va[:]) + if err != nil { + return err + } + b.Append(u) + case *types.JSONBuilder: + b.Append(val) + case *array.StructBuilder: + structBytes, err := json.Marshal(val) + if err != nil { + return fmt.Errorf("failed to marshal struct: %w", err) + } + dec := json.NewDecoder(bytes.NewReader(structBytes)) + if err := b.UnmarshalOne(dec); err != nil { + return fmt.Errorf("failed to unmarshal struct: %w", err) + } + case *types.InetBuilder: + if v, ok := val.(netip.Prefix); ok { + _, ipnet, err := net.ParseCIDR(v.String()) + if err != nil { + return err + } + b.Append(ipnet) + return nil + } + b.Append(val.(*net.IPNet)) + case *types.MACBuilder: + if c.pgType == pgTypePostgreSQL { + b.Append(val.(net.HardwareAddr)) + } else { + hardwareAddr, err := net.ParseMAC(val.(string)) + if err != nil { + return err + } + b.Append(hardwareAddr) + } + case array.ListLikeBuilder: + b.Append(true) + valBuilder := b.ValueBuilder() + for _, v := range val.([]any) { + if err := c.reverseTransformCockroach(f, valBuilder, v); err != nil { + return err + } + } + default: + v, ok := val.(string) + if !ok { + return fmt.Errorf("unsupported type %T with builder %T", val, bldr) + } + if err := bldr.AppendValueFromString(v); err != nil { + return err + } + } + return nil +} diff --git a/plugins/destination/postgresql/client/transformer.go b/plugins/destination/postgresql/client/transformer.go index 3744d8ad953e7d..d19bdced827f72 100644 --- a/plugins/destination/postgresql/client/transformer.go +++ b/plugins/destination/postgresql/client/transformer.go @@ -6,7 +6,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" - "github.com/cloudquery/plugin-sdk/v2/types" + "github.com/cloudquery/plugin-sdk/v3/types" "github.com/jackc/pgx/v5/pgtype" ) @@ -65,6 +65,11 @@ func transformArr(arr arrow.Array) []any { Bool: a.Value(i), Valid: a.IsValid(i), } + case *array.Int8: + pgArr[i] = pgtype.Int2{ + Int16: int16(a.Value(i)), + Valid: a.IsValid(i), + } case *array.Int16: pgArr[i] = pgtype.Int2{ Int16: a.Value(i), @@ -80,6 +85,23 @@ func transformArr(arr arrow.Array) []any { Int64: a.Value(i), Valid: a.IsValid(i), } + case *array.Uint8: + pgArr[i] = pgtype.Int2{ + Int16: int16(a.Value(i)), + Valid: a.IsValid(i), + } + case *array.Uint16: + pgArr[i] = pgtype.Int4{ + Int32: int32(a.Value(i)), + Valid: a.IsValid(i), + } + case *array.Uint32: + pgArr[i] = pgtype.Int8{ + Int64: int64(a.Value(i)), + Valid: a.IsValid(i), + } + case *array.Uint64: + pgArr[i] = a.Value(i) case *array.Float32: pgArr[i] = pgtype.Float4{ Float32: a.Value(i), @@ -106,7 +128,7 @@ func transformArr(arr arrow.Array) []any { } case *array.Timestamp: pgArr[i] = pgtype.Timestamptz{ - Time: a.Value(i).ToTime(arrow.Microsecond), + Time: a.Value(i).ToTime(a.DataType().(*arrow.TimestampType).Unit).UTC(), Valid: a.IsValid(i), } case *types.UUIDArray: @@ -123,6 +145,8 @@ func transformArr(arr arrow.Array) []any { start, end := a.ValueOffsets(i) nested := array.NewSlice(a.ListValues(), start, end) pgArr[i] = transformArr(nested) + case *types.JSONArray: + pgArr[i] = a.Storage().(*array.Binary).Value(i) default: pgArr[i] = stripNulls(arr.ValueStr(i)) } diff --git a/plugins/destination/postgresql/client/types_cockroach.go b/plugins/destination/postgresql/client/types_cockroach.go index a2adf2dfe5e596..ccbeecb44b81ab 100644 --- a/plugins/destination/postgresql/client/types_cockroach.go +++ b/plugins/destination/postgresql/client/types_cockroach.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v2/types" + "github.com/cloudquery/plugin-sdk/v3/types" ) func (c *Client) SchemaTypeToCockroach(t arrow.DataType) string { @@ -16,14 +16,12 @@ func (c *Client) SchemaTypeToCockroach(t arrow.DataType) string { return c.SchemaTypeToCockroach(v.Elem()) + fmt.Sprintf("[%d]", v.Len()) case *arrow.BooleanType: return "boolean" - case *arrow.Int8Type, *arrow.Uint8Type: + case *arrow.Int8Type, *arrow.Uint8Type, *arrow.Int16Type: return "smallint" - case *arrow.Int16Type, *arrow.Uint16Type: - return "smallint" - case *arrow.Int32Type, *arrow.Uint32Type: - return "integer" - case *arrow.Int64Type, *arrow.Uint64Type: + case *arrow.Uint16Type, *arrow.Int32Type, *arrow.Uint32Type, *arrow.Int64Type: return "bigint" + case *arrow.Uint64Type: + return "numeric" case *arrow.Float32Type: return "real" case *arrow.Float64Type: @@ -58,10 +56,14 @@ func (c *Client) CockroachToSchemaType(t string) arrow.DataType { switch t { case "boolean": return arrow.FixedWidthTypes.Boolean - case "bigint", "int", "oid", "serial": + case "smallint": + return arrow.PrimitiveTypes.Int16 + case "int4", "bigint", "int", "oid", "serial", "integer", "int8", "int64": return arrow.PrimitiveTypes.Int64 case "decimal", "float", "real", "double precision": return arrow.PrimitiveTypes.Float64 + case "numeric": + return arrow.PrimitiveTypes.Uint64 case "uuid": return types.ExtensionTypes.UUID case "bytea": diff --git a/plugins/destination/postgresql/client/types_pg.go b/plugins/destination/postgresql/client/types_pg.go index 3edd52dfd14382..2e8069134c1851 100644 --- a/plugins/destination/postgresql/client/types_pg.go +++ b/plugins/destination/postgresql/client/types_pg.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v2/types" + "github.com/cloudquery/plugin-sdk/v3/types" ) func (c *Client) SchemaTypeToPg(t arrow.DataType) string { @@ -36,12 +36,14 @@ func (c *Client) SchemaTypeToPg10(t arrow.DataType) string { return "boolean" case *arrow.Int8Type, *arrow.Uint8Type: return "smallint" - case *arrow.Int16Type, *arrow.Uint16Type: + case *arrow.Int16Type: return "smallint" - case *arrow.Int32Type, *arrow.Uint32Type: + case *arrow.Uint16Type, *arrow.Int32Type: return "integer" - case *arrow.Int64Type, *arrow.Uint64Type: + case *arrow.Uint32Type, *arrow.Int64Type: return "bigint" + case *arrow.Uint64Type: + return "numeric" case *arrow.Float32Type: return "real" case *arrow.Float64Type: @@ -60,7 +62,7 @@ func (c *Client) SchemaTypeToPg10(t arrow.DataType) string { return "jsonb" case *types.InetType: return "inet" - case *types.MacType: + case *types.MACType: return "macaddr" default: return "text" @@ -86,6 +88,8 @@ func (c *Client) Pg10ToSchemaType(t string) arrow.DataType { return arrow.PrimitiveTypes.Int32 case "bigint": return arrow.PrimitiveTypes.Int64 + case "numeric": + return arrow.PrimitiveTypes.Uint64 case "real": return arrow.PrimitiveTypes.Float32 case "double precision": @@ -99,7 +103,7 @@ func (c *Client) Pg10ToSchemaType(t string) arrow.DataType { case "cidr": return types.ExtensionTypes.Inet case "macaddr", "macaddr8": - return types.ExtensionTypes.Mac + return types.ExtensionTypes.MAC case "inet": return types.ExtensionTypes.Inet default: diff --git a/plugins/destination/postgresql/client/write.go b/plugins/destination/postgresql/client/write.go index 9ad5a9e0916fbe..471362c90dbba3 100644 --- a/plugins/destination/postgresql/client/write.go +++ b/plugins/destination/postgresql/client/write.go @@ -10,7 +10,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v2/schema" + "github.com/cloudquery/plugin-sdk/v3/schema" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" ) @@ -54,7 +54,7 @@ func pgErrToStr(err *pgconn.PgError) string { return sb.String() } -func (c *Client) Write(ctx context.Context, tables schema.Schemas, res <-chan arrow.Record) error { +func (c *Client) Write(ctx context.Context, tables schema.Tables, res <-chan arrow.Record) error { var sql string batch := &pgx.Batch{} pgTables, err := c.listPgTables(ctx, tables) @@ -66,15 +66,19 @@ func (c *Client) Write(ctx context.Context, tables schema.Schemas, res <-chan ar return err } for r := range res { - tableName := schema.TableName(r.Schema()) - table := tables.SchemaByName(tableName) + md := r.Schema().Metadata() + tableName, ok := md.GetValue(schema.MetadataTableName) + if !ok { + return fmt.Errorf("table name not found in metadata") + } + table := tables.Get(tableName) if table == nil { - panic(fmt.Errorf("table %s not found", tableName)) + return fmt.Errorf("table %s not found", tableName) } if c.spec.WriteMode == specs.WriteModeAppend { sql = c.insert(table) } else { - if len(schema.PrimaryKeyIndices(table)) > 0 { + if len(table.PrimaryKeysIndexes()) > 0 { sql = c.upsert(table) } else { sql = c.insert(table) @@ -117,13 +121,13 @@ func (c *Client) Write(ctx context.Context, tables schema.Schemas, res <-chan ar return nil } -func (*Client) insert(table *arrow.Schema) string { +func (*Client) insert(table *schema.Table) string { var sb strings.Builder - tableName := schema.TableName(table) + tableName := table.Name sb.WriteString("insert into ") sb.WriteString(pgx.Identifier{tableName}.Sanitize()) sb.WriteString(" (") - columns := table.Fields() + columns := table.Columns columnsLen := len(columns) for i, c := range columns { sb.WriteString(pgx.Identifier{c.Name}.Sanitize()) @@ -144,17 +148,14 @@ func (*Client) insert(table *arrow.Schema) string { return sb.String() } -func (c *Client) upsert(table *arrow.Schema) string { +func (c *Client) upsert(table *schema.Table) string { var sb strings.Builder sb.WriteString(c.insert(table)) - columns := table.Fields() + columns := table.Columns columnsLen := len(columns) - constraintName, ok := table.Metadata().GetValue(schema.MetadataConstraintName) - if !ok { - panic(fmt.Errorf("constraint_name not found in table metadata")) - } + constraintName := table.PkConstraintName sb.WriteString(" on conflict on constraint ") sb.WriteString(pgx.Identifier{constraintName}.Sanitize()) sb.WriteString(" do update set ") diff --git a/plugins/destination/postgresql/go.mod b/plugins/destination/postgresql/go.mod index 3c253d2cb5441e..f2bb9747205116 100644 --- a/plugins/destination/postgresql/go.mod +++ b/plugins/destination/postgresql/go.mod @@ -5,7 +5,8 @@ go 1.19 require ( github.com/apache/arrow/go/v13 v13.0.0-20230509040948-de6c3cd2b604 github.com/cloudquery/plugin-pb-go v1.0.8 - github.com/cloudquery/plugin-sdk/v2 v2.7.0 + github.com/cloudquery/plugin-sdk/v3 v3.5.1 + github.com/goccy/go-json v0.9.11 github.com/google/uuid v1.3.0 github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb github.com/jackc/pgx/v5 v5.3.1 @@ -19,10 +20,10 @@ replace github.com/apache/arrow/go/v13 => github.com/cloudquery/arrow/go/v13 v13 require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/thrift v0.16.0 // indirect + github.com/cloudquery/plugin-sdk/v2 v2.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/getsentry/sentry-go v0.20.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect - github.com/goccy/go-json v0.9.11 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.8+incompatible // indirect @@ -39,6 +40,7 @@ require ( github.com/mattn/go-isatty v0.0.18 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect + github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cobra v1.6.1 // indirect diff --git a/plugins/destination/postgresql/go.sum b/plugins/destination/postgresql/go.sum index 844c893478c784..2daea4ff00526f 100644 --- a/plugins/destination/postgresql/go.sum +++ b/plugins/destination/postgresql/go.sum @@ -50,6 +50,8 @@ github.com/cloudquery/plugin-pb-go v1.0.8 h1:wn3GXhcNItcP+6wUUZuzUFbvdL59liKBO37 github.com/cloudquery/plugin-pb-go v1.0.8/go.mod h1:vAGA27psem7ZZNAY4a3S9TKuA/JDQWstjKcHPJX91Mc= github.com/cloudquery/plugin-sdk/v2 v2.7.0 h1:hRXsdEiaOxJtsn/wZMFQC9/jPfU1MeMK3KF+gPGqm7U= github.com/cloudquery/plugin-sdk/v2 v2.7.0/go.mod h1:pAX6ojIW99b/Vg4CkhnsGkRIzNaVEceYMR+Bdit73ug= +github.com/cloudquery/plugin-sdk/v3 v3.5.1 h1:797hWUEsojwvp7xtr6LSaf5tk5iG9UDixoRACxu3xrU= +github.com/cloudquery/plugin-sdk/v3 v3.5.1/go.mod h1:3JrZXEULmGXpkOukVaRIzaA63d7TJr9Ukp6hemTjbtc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -184,6 +186,7 @@ github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8D github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= +github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/plugins/destination/postgresql/main.go b/plugins/destination/postgresql/main.go index 1425dfbd2bb00f..a62404dea843d7 100644 --- a/plugins/destination/postgresql/main.go +++ b/plugins/destination/postgresql/main.go @@ -3,8 +3,8 @@ package main import ( "github.com/cloudquery/cloudquery/plugins/destination/postgresql/client" "github.com/cloudquery/cloudquery/plugins/destination/postgresql/resources/plugin" - "github.com/cloudquery/plugin-sdk/v2/plugins/destination" - "github.com/cloudquery/plugin-sdk/v2/serve" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/serve" ) const (