diff --git a/transformers/struct.go b/transformers/struct.go index 0c69dfbe78..6560008233 100644 --- a/transformers/struct.go +++ b/transformers/struct.go @@ -272,9 +272,12 @@ func (t *structTransformer) addColumnFromField(field reflect.StructField, parent } for _, pk := range t.pkFields { - if pk == field.Name { + if pk == path { + // use path to allow the following + // 1. Don't duplicate the PK fields if the unwrapped struct contains a fields with the same name + // 2. Allow specifying the nested unwrapped field as part of the PK. column.CreationOptions.PrimaryKey = true - t.pkFieldsFound = append(t.pkFieldsFound, field.Name) + t.pkFieldsFound = append(t.pkFieldsFound, pk) } } diff --git a/transformers/struct_test.go b/transformers/struct_test.go index 3091bd18b8..751e4f87ec 100644 --- a/transformers/struct_test.go +++ b/transformers/struct_test.go @@ -8,11 +8,13 @@ import ( "github.com/cloudquery/plugin-sdk/schema" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/exp/slices" ) type ( embeddedStruct struct { EmbeddedString string + IntCol int `json:"int_col,omitempty"` } testStruct struct { @@ -45,10 +47,12 @@ type ( *embeddedStruct } testStructWithEmbeddedStruct struct { + IntCol int `json:"int_col,omitempty"` *testStruct *embeddedStruct } testStructWithNonEmbeddedStruct struct { + IntCol int `json:"int_col,omitempty"` TestStruct *testStruct NonEmbedded *embeddedStruct } @@ -140,16 +144,61 @@ var ( Columns: expectedColumns, } expectedTestTableEmbeddedStruct = schema.Table{ + Name: "test_struct", + Columns: append(expectedColumns, schema.Column{Name: "embedded_string", Type: schema.TypeString}), + } + expectedTestTableEmbeddedStructWithTopLevelPK = schema.Table{ + Name: "test_struct", + Columns: func(base schema.ColumnList) schema.ColumnList { + cols := slices.Clone(base) + cols = append(cols, schema.Column{Name: "embedded_string", Type: schema.TypeString}) + cols[cols.Index("int_col")].CreationOptions.PrimaryKey = true + return cols + }(expectedColumns), + } + expectedTestTableEmbeddedStructWithUnwrappedPK = schema.Table{ Name: "test_struct", Columns: append( expectedColumns, schema.Column{ - Name: "embedded_string", - Type: schema.TypeString, + Name: "embedded_string", + Type: schema.TypeString, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, }), } expectedTestTableNonEmbeddedStruct = schema.Table{ Name: "test_struct", Columns: schema.ColumnList{ + schema.Column{Name: "int_col", Type: schema.TypeInt}, + // Should not be unwrapped + schema.Column{Name: "test_struct", Type: schema.TypeJSON}, + // Should be unwrapped + schema.Column{Name: "non_embedded_embedded_string", Type: schema.TypeString}, + schema.Column{Name: "non_embedded_int_col", Type: schema.TypeInt}, + }, + } + expectedTestTableNonEmbeddedStructWithTopLevelPK = schema.Table{ + Name: "test_struct", + Columns: schema.ColumnList{ + schema.Column{ + Name: "int_col", + Type: schema.TypeInt, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + }, + // Should not be unwrapped + schema.Column{Name: "test_struct", Type: schema.TypeJSON}, + // Should be unwrapped + schema.Column{ + Name: "non_embedded_embedded_string", + Type: schema.TypeString, + }, + schema.Column{Name: "non_embedded_int_col", Type: schema.TypeInt}, + }, + } + expectedTestTableNonEmbeddedStructWithUnwrappedPK = schema.Table{ + Name: "test_struct", + Columns: schema.ColumnList{ + // shouldn't be PK + schema.Column{Name: "int_col", Type: schema.TypeInt}, // Should not be unwrapped schema.Column{Name: "test_struct", Type: schema.TypeJSON}, // Should be unwrapped @@ -157,6 +206,12 @@ var ( Name: "non_embedded_embedded_string", Type: schema.TypeString, }, + // should be PK + schema.Column{ + Name: "non_embedded_int_col", + Type: schema.TypeInt, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + }, }, } expectedTestSliceStruct = schema.Table{ @@ -219,6 +274,28 @@ func TestTableFromGoStruct(t *testing.T) { }, want: expectedTestTableEmbeddedStruct, }, + { + name: "should unwrap all embedded structs when option is set and use top-level field as PK", + args: args{ + testStruct: testStructWithEmbeddedStruct{}, + options: []StructTransformerOption{ + WithUnwrapAllEmbeddedStructs(), + WithPrimaryKeys("IntCol"), + }, + }, + want: expectedTestTableEmbeddedStructWithTopLevelPK, + }, + { + name: "should unwrap all embedded structs when option is set and use its field as PK", + args: args{ + testStruct: testStructWithEmbeddedStruct{}, + options: []StructTransformerOption{ + WithUnwrapAllEmbeddedStructs(), + WithPrimaryKeys("EmbeddedString"), + }, + }, + want: expectedTestTableEmbeddedStructWithUnwrappedPK, + }, { name: "should unwrap specific structs when option is set", args: args{ @@ -229,6 +306,28 @@ func TestTableFromGoStruct(t *testing.T) { }, want: expectedTestTableNonEmbeddedStruct, }, + { + name: "should unwrap specific structs when option is set and use top level field as PK", + args: args{ + testStruct: testStructWithNonEmbeddedStruct{}, + options: []StructTransformerOption{ + WithUnwrapStructFields("NonEmbedded"), + WithPrimaryKeys("IntCol"), + }, + }, + want: expectedTestTableNonEmbeddedStructWithTopLevelPK, + }, + { + name: "should unwrap specific structs when option is set and use its field as PK", + args: args{ + testStruct: testStructWithNonEmbeddedStruct{}, + options: []StructTransformerOption{ + WithUnwrapStructFields("NonEmbedded"), + WithPrimaryKeys("NonEmbedded.IntCol"), + }, + }, + want: expectedTestTableNonEmbeddedStructWithUnwrappedPK, + }, { name: "should generate table from slice struct", args: args{