Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/destination/sqlite/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"

"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"
"github.com/rs/zerolog"

// Import sqlite3 driver
Expand Down
2 changes: 1 addition & 1 deletion plugins/destination/sqlite/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"testing"

"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"
)

var migrateStrategy = destination.MigrateStrategy{
Expand Down
6 changes: 3 additions & 3 deletions plugins/destination/sqlite/client/deletestale.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"strings"
"time"

"github.com/cloudquery/plugin-sdk/v2/schema"
"github.com/cloudquery/plugin-sdk/v3/schema"
)

func (c *Client) DeleteStale(ctx context.Context, tables schema.Schemas, source string, syncTime time.Time) error {
func (c *Client) DeleteStale(ctx context.Context, tables schema.Tables, source string, syncTime time.Time) error {
for _, table := range tables {
var sb strings.Builder
sb.WriteString("delete from ")
sb.WriteString(`"` + schema.TableName(table) + `"`)
sb.WriteString(`"` + table.Name + `"`)
sb.WriteString(" where ")
sb.WriteString(`"` + schema.CqSourceNameColumn.Name + `"`)
sb.WriteString(" = $1 and datetime(")
Expand Down
2 changes: 1 addition & 1 deletion plugins/destination/sqlite/client/metrics.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package client

import (
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v3/plugins/destination"
)

func (c *Client) Metrics() destination.Metrics {
Expand Down
187 changes: 82 additions & 105 deletions plugins/destination/sqlite/client/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/apache/arrow/go/v13/arrow"
"github.com/cloudquery/plugin-pb-go/specs"
"github.com/cloudquery/plugin-sdk/v2/schema"
"github.com/cloudquery/plugin-sdk/v3/schema"
)

const (
Expand All @@ -27,158 +27,142 @@ type tableInfo struct {
columns []columnInfo
}

func (c *Client) sqliteTables(schemas schema.Schemas) (schema.Schemas, error) {
var schemaTables schema.Schemas
for _, sc := range schemas {
var fields []arrow.Field
tableName := schema.TableName(sc)
if tableName == "" {
return nil, fmt.Errorf("schema %s has no table name", sc.String())
}
info, err := c.getTableInfo(tableName)
func (c *Client) sqliteTables(tables schema.Tables) (schema.Tables, error) {
var schemaTables schema.Tables
for _, table := range tables {
var columns []schema.Column
info, err := c.getTableInfo(table.Name)
if info == nil {
continue
}
if err != nil {
return nil, err
}
for _, col := range info.columns {
var fieldMetadata schema.MetadataFieldOptions
if col.pk != 0 {
fieldMetadata.PrimaryKey = true
}
fields = append(fields, arrow.Field{
Name: col.name,
Type: c.sqliteTypeToArrowType(col.typ),
Nullable: !col.notNull,
Metadata: schema.NewFieldMetadataFromOptions(fieldMetadata),
columns = append(columns, schema.Column{
Name: col.name,
Type: c.sqliteTypeToArrowType(col.typ),
PrimaryKey: col.pk != 0,
NotNull: col.notNull,
})
}
var tableMetadata schema.MetadataSchemaOptions
tableMetadata.TableName = tableName
m := schema.NewSchemaMetadataFromOptions(tableMetadata)
schemaTables = append(schemaTables, arrow.NewSchema(fields, &m))
schemaTables = append(schemaTables, &schema.Table{Name: table.Name, Columns: columns})
}
return schemaTables, nil
}

func (c *Client) normalizeSchemas(scs schema.Schemas) schema.Schemas {
var normalized schema.Schemas
for _, sc := range scs {
fields := make([]arrow.Field, 0)
for _, f := range sc.Fields() {
keys := make([]string, 0)
values := make([]string, 0)
origKeys := f.Metadata.Keys()
origValues := f.Metadata.Values()
for k, v := range origKeys {
if v != schema.MetadataUnique {
keys = append(keys, v)
values = append(values, origValues[k])
}
}
fields = append(fields, arrow.Field{
Name: f.Name,
Type: c.arrowTypeToSqlite(f.Type),
Nullable: f.Nullable,
Metadata: arrow.NewMetadata(keys, values),
})
}
func (c *Client) normalizeTables(tables schema.Tables) schema.Tables {
flattened := tables.FlattenTables()
normalized := make(schema.Tables, len(flattened))
for i, table := range flattened {
normalized[i] = c.normalizeTable(table)
}
return normalized
}

md := sc.Metadata()
normalized = append(normalized, arrow.NewSchema(fields, &md))
func (c *Client) normalizeTable(table *schema.Table) *schema.Table {
columns := make([]schema.Column, len(table.Columns))
for i, col := range table.Columns {
normalized := c.normalizeField(col.ToArrowField())
columns[i] = schema.NewColumnFromArrowField(*normalized)
}
return &schema.Table{Name: table.Name, Columns: columns}
}

return normalized
func (c *Client) normalizeField(field arrow.Field) *arrow.Field {
return &arrow.Field{
Name: field.Name,
Type: c.arrowTypeToSqlite(field.Type),
Nullable: field.Nullable,
Metadata: field.Metadata,
}
}

func (c *Client) nonAutoMigrableTables(tables schema.Schemas, sqliteTables schema.Schemas) ([]string, [][]schema.FieldChange) {
func (c *Client) nonAutoMigratableTables(tables schema.Tables, sqliteTables schema.Tables) ([]string, [][]schema.TableColumnChange) {
var result []string
var tableChanges [][]schema.FieldChange
var tableChanges [][]schema.TableColumnChange
for _, t := range tables {
tableName := schema.TableName(t)
sqliteTable := sqliteTables.SchemaByName(tableName)
sqliteTable := sqliteTables.Get(t.Name)
if sqliteTable == nil {
continue
}
changes := schema.GetSchemaChanges(t, sqliteTable)
changes := sqliteTable.GetChanges(t)
if !c.canAutoMigrate(changes) {
result = append(result, tableName)
result = append(result, t.Name)
tableChanges = append(tableChanges, changes)
}
}
return result, tableChanges
}

func (c *Client) autoMigrateTable(table *arrow.Schema, changes []schema.FieldChange) error {
func (c *Client) autoMigrateTable(table *schema.Table, changes []schema.TableColumnChange) error {
for _, change := range changes {
if change.Type == schema.TableColumnChangeTypeAdd {
if err := c.addColumn(schema.TableName(table), change.Current.Name, c.arrowTypeToSqliteStr(change.Current.Type)); err != nil {
if err := c.addColumn(table.Name, change.Current.Name, c.arrowTypeToSqliteStr(change.Current.Type)); err != nil {
return err
}
}
}
return nil
}

func (*Client) canAutoMigrate(changes []schema.FieldChange) bool {
func (*Client) canAutoMigrate(changes []schema.TableColumnChange) bool {
for _, change := range changes {
if change.Type == schema.TableColumnChangeTypeAdd && (schema.IsPk(change.Current) || !change.Current.Nullable) {
return false
}

if change.Type == schema.TableColumnChangeTypeRemove && (schema.IsPk(change.Previous) || !change.Previous.Nullable) {
return false
}

if change.Type == schema.TableColumnChangeTypeUpdate {
switch change.Type {
case schema.TableColumnChangeTypeAdd:
if change.Current.PrimaryKey || change.Current.NotNull {
return false
}
case schema.TableColumnChangeTypeRemove:
if change.Previous.PrimaryKey || change.Previous.NotNull {
return false
}
case schema.TableColumnChangeTypeUpdate:
return false
default:
panic("unknown change type")
}
}
return true
}

// This is the responsibility of the CLI of the client to lock before running migration
func (c *Client) Migrate(ctx context.Context, schemas schema.Schemas) error {
schemas = c.normalizeSchemas(schemas)
sqliteTables, err := c.sqliteTables(schemas)
func (c *Client) Migrate(ctx context.Context, tables schema.Tables) error {
normalizedTables := c.normalizeTables(tables)
sqliteTables, err := c.sqliteTables(normalizedTables)
if err != nil {
return err
}

if c.spec.MigrateMode != specs.MigrateModeForced {
nonAutoMigrableTables, changes := c.nonAutoMigrableTables(schemas, sqliteTables)
if len(nonAutoMigrableTables) > 0 {
return fmt.Errorf("tables %s with changes %v require force migration. use 'migrate_mode: forced'", strings.Join(nonAutoMigrableTables, ","), changes)
nonAutoMigratableTables, changes := c.nonAutoMigratableTables(normalizedTables, sqliteTables)
if len(nonAutoMigratableTables) > 0 {
return fmt.Errorf("tables %s with changes %v require force migration. use 'migrate_mode: forced'", strings.Join(nonAutoMigratableTables, ","), changes)
}
}

for _, table := range schemas {
tableName := schema.TableName(table)
if tableName == "" {
return fmt.Errorf("schema %s has no table name", table.String())
}
c.logger.Info().Str("table", tableName).Msg("Migrating table")
if len(table.Fields()) == 0 {
c.logger.Info().Str("table", tableName).Msg("Table with no columns, skipping")
for _, table := range normalizedTables {
c.logger.Info().Str("table", table.Name).Msg("Migrating table")
if len(table.Columns) == 0 {
c.logger.Info().Str("table", table.Name).Msg("Table with no columns, skipping")
continue
}

sqlite := sqliteTables.SchemaByName(tableName)
sqlite := sqliteTables.Get(table.Name)
if sqlite == nil {
c.logger.Debug().Str("table", tableName).Msg("Table doesn't exist, creating")
c.logger.Debug().Str("table", table.Name).Msg("Table doesn't exist, creating")
if err := c.createTableIfNotExist(table); err != nil {
return err
}
} else {
changes := schema.GetSchemaChanges(table, sqlite)
changes := table.GetChanges(sqlite)
if c.canAutoMigrate(changes) {
c.logger.Info().Str("table", tableName).Msg("Table exists, auto-migrating")
c.logger.Info().Str("table", table.Name).Msg("Table exists, auto-migrating")
if err := c.autoMigrateTable(table, changes); err != nil {
return err
}
} else {
c.logger.Info().Str("table", tableName).Msg("Table exists, force migration required")
c.logger.Info().Str("table", table.Name).Msg("Table exists, force migration required")
if err := c.recreateTable(table); err != nil {
return err
}
Expand All @@ -189,14 +173,10 @@ func (c *Client) Migrate(ctx context.Context, schemas schema.Schemas) error {
return nil
}

func (c *Client) recreateTable(table *arrow.Schema) error {
tableName, ok := table.Metadata().GetValue(schema.MetadataTableName)
if !ok {
return fmt.Errorf("schema %s has no table name", table.String())
}
sql := "drop table if exists \"" + tableName + "\""
func (c *Client) recreateTable(table *schema.Table) error {
sql := "drop table if exists \"" + table.Name + "\""
if _, err := c.db.Exec(sql); err != nil {
return fmt.Errorf("failed to drop table %s: %w", tableName, err)
return fmt.Errorf("failed to drop table %s: %w", table.Name, err)
}
return c.createTableIfNotExist(table)
}
Expand All @@ -209,44 +189,41 @@ func (c *Client) addColumn(tableName string, columnName string, columnType strin
return nil
}

func (c *Client) createTableIfNotExist(sc *arrow.Schema) error {
func (c *Client) createTableIfNotExist(table *schema.Table) error {
var sb strings.Builder
tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName)
if !ok {
return fmt.Errorf("schema %s has no table name", sc.String())
}
// TODO sanitize tablename

// TODO sanitize table.Name
sb.WriteString("CREATE TABLE IF NOT EXISTS ")
sb.WriteString(`"` + tableName + `"`)
sb.WriteString(`"` + table.Name + `"`)
sb.WriteString(" (")
totalColumns := len(sc.Fields())
totalColumns := len(table.Columns)

primaryKeys := []string{}
for i, col := range sc.Fields() {
for i, col := range table.Columns {
sqlType := c.arrowTypeToSqliteStr(col.Type)
if sqlType == "" {
c.logger.Warn().Str("table", tableName).Str("column", col.Name).Msg("Column type is not supported, skipping")
c.logger.Warn().Str("table", table.Name).Str("column", col.Name).Msg("Column type is not supported, skipping")
continue
}
// TODO: sanitize column name
fieldDef := `"` + col.Name + `" ` + sqlType
if !col.Nullable {
if col.NotNull {
fieldDef += " NOT NULL"
}
sb.WriteString(fieldDef)
if i != totalColumns-1 {
sb.WriteString(",")
}

if c.enabledPks() && schema.IsPk(col) {
if c.enabledPks() && col.PrimaryKey {
primaryKeys = append(primaryKeys, `"`+col.Name+`"`)
}
}

if len(primaryKeys) > 0 {
// add composite PK constraint on primary key columns
sb.WriteString(", CONSTRAINT ")
sb.WriteString(tableName)
sb.WriteString(table.Name)
sb.WriteString("_cqpk PRIMARY KEY (")
sb.WriteString(strings.Join(primaryKeys, ","))
sb.WriteString(")")
Expand Down
Loading