Skip to content

Commit a17ee4b

Browse files
authored
feat: Add unwrap option to transformations (#573)
Follow-up on transformation. Adding unwrap option as needed in some cases where we create our own wrapper struct. This also moves the tests from codegen.
1 parent d89a911 commit a17ee4b

2 files changed

Lines changed: 321 additions & 14 deletions

File tree

transformers/struct.go

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import (
1010
)
1111

1212
type structTransformer struct {
13-
table *schema.Table
14-
skipFields []string
15-
nameTransformer NameTransformer
16-
typeTransformer TypeTransformer
13+
table *schema.Table
14+
skipFields []string
15+
nameTransformer NameTransformer
16+
typeTransformer TypeTransformer
17+
unwrapAllEmbeddedStructFields bool
18+
structFieldsToUnwrap []string
1719
}
1820

1921
type NameTransformer func(reflect.StructField) (string, error)
@@ -22,6 +24,31 @@ type TypeTransformer func(reflect.StructField) (schema.ValueType, error)
2224

2325
type StructTransformerOption func(*structTransformer)
2426

27+
func isFieldStruct(reflectType reflect.Type) bool {
28+
switch reflectType.Kind() {
29+
case reflect.Struct:
30+
return true
31+
case reflect.Ptr:
32+
return reflectType.Elem().Kind() == reflect.Struct
33+
default:
34+
return false
35+
}
36+
}
37+
38+
// WithUnwrapAllEmbeddedStructs instructs codegen to unwrap all embedded fields (1 level deep only)
39+
func WithUnwrapAllEmbeddedStructs() StructTransformerOption {
40+
return func(t *structTransformer) {
41+
t.unwrapAllEmbeddedStructFields = true
42+
}
43+
}
44+
45+
// WithUnwrapStructFields allows to unwrap specific struct fields (1 level deep only)
46+
func WithUnwrapStructFields(fields ...string) StructTransformerOption {
47+
return func(t *structTransformer) {
48+
t.structFieldsToUnwrap = fields
49+
}
50+
}
51+
2552
// WithSkipFields allows to specify what struct fields should be skipped.
2653
func WithSkipFields(fields ...string) StructTransformerOption {
2754
return func(t *structTransformer) {
@@ -37,14 +64,6 @@ func WithNameTransformer(transformer NameTransformer) StructTransformerOption {
3764
}
3865
}
3966

40-
// WithTypeTransformer sets a function that can override the schema type for specific fields. Return `schema.TypeInvalid` to fall back to default behavior.
41-
// DefaultTypeTransformer is used as the default.
42-
func WithTypeTransformer(transformer TypeTransformer) StructTransformerOption {
43-
return func(t *structTransformer) {
44-
t.typeTransformer = transformer
45-
}
46-
}
47-
4867
func TransformWithStruct(st any, opts ...StructTransformerOption) schema.Transform {
4968
t := &structTransformer{
5069
nameTransformer: codegen.DefaultNameTransformer,
@@ -70,14 +89,69 @@ func TransformWithStruct(st any, opts ...StructTransformerOption) schema.Transfo
7089
for i := 0; i < e.NumField(); i++ {
7190
field := eType.Field(i)
7291

73-
if err := t.addColumnFromField(field, nil); err != nil {
74-
return fmt.Errorf("failed to add column for field %s: %w", field.Name, err)
92+
switch {
93+
case t.shouldUnwrapField(field):
94+
if err := t.unwrapField(field); err != nil {
95+
return err
96+
}
97+
default:
98+
if err := t.addColumnFromField(field, nil); err != nil {
99+
return fmt.Errorf("failed to add column for field %s: %w", field.Name, err)
100+
}
75101
}
76102
}
77103
return nil
78104
}
79105
}
80106

107+
func (t *structTransformer) getUnwrappedFields(field reflect.StructField) []reflect.StructField {
108+
reflectType := field.Type
109+
if reflectType.Kind() == reflect.Ptr {
110+
reflectType = reflectType.Elem()
111+
}
112+
113+
fields := make([]reflect.StructField, 0)
114+
for i := 0; i < reflectType.NumField(); i++ {
115+
sf := reflectType.Field(i)
116+
if t.ignoreField(sf) {
117+
continue
118+
}
119+
120+
fields = append(fields, sf)
121+
}
122+
return fields
123+
}
124+
125+
func (t *structTransformer) unwrapField(field reflect.StructField) error {
126+
unwrappedFields := t.getUnwrappedFields(field)
127+
var parent *reflect.StructField
128+
// For non embedded structs we need to add the parent field name to the path
129+
if !field.Anonymous {
130+
parent = &field
131+
}
132+
for _, f := range unwrappedFields {
133+
if err := t.addColumnFromField(f, parent); err != nil {
134+
return fmt.Errorf("failed to add column from field %s: %w", f.Name, err)
135+
}
136+
}
137+
return nil
138+
}
139+
140+
func (t *structTransformer) shouldUnwrapField(field reflect.StructField) bool {
141+
switch {
142+
case !isFieldStruct(field.Type):
143+
return false
144+
case slices.Contains(t.structFieldsToUnwrap, field.Name):
145+
return true
146+
case !field.Anonymous:
147+
return false
148+
case t.unwrapAllEmbeddedStructFields:
149+
return true
150+
default:
151+
return false
152+
}
153+
}
154+
81155
func (t *structTransformer) ignoreField(field reflect.StructField) bool {
82156
switch {
83157
case len(field.Name) == 0,

transformers/struct_test.go

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
package transformers
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/cloudquery/plugin-sdk/schema"
8+
"github.com/google/go-cmp/cmp"
9+
"github.com/google/go-cmp/cmp/cmpopts"
10+
)
11+
12+
type (
13+
embeddedStruct struct {
14+
EmbeddedString string
15+
}
16+
17+
testStruct struct {
18+
// IntCol this is an example documentation comment
19+
IntCol int `json:"int_col,omitempty"`
20+
Int64Col int64 `json:"int64_col,omitempty"`
21+
StringCol string `json:"string_col,omitempty"`
22+
FloatCol float64 `json:"float_col,omitempty"`
23+
BoolCol bool `json:"bool_col,omitempty"`
24+
JSONCol struct {
25+
IntCol int `json:"int_col,omitempty"`
26+
StringCol string `json:"string_col,omitempty"`
27+
}
28+
IntArrayCol []int `json:"int_array_col,omitempty"`
29+
IntPointerArrayCol []*int `json:"int_pointer_array_col,omitempty"`
30+
31+
StringArrayCol []string `json:"string_array_col,omitempty"`
32+
StringPointerArrayCol []*string `json:"string_pointer_array_col,omitempty"`
33+
34+
TimeCol time.Time `json:"time_col,omitempty"`
35+
TimePointerCol *time.Time `json:"time_pointer_col,omitempty"`
36+
JSONTag *string `json:"json_tag"`
37+
SkipJSONTag *string `json:"-"`
38+
NoJSONTag *string
39+
*embeddedStruct
40+
}
41+
testStructWithEmbeddedStruct struct {
42+
*testStruct
43+
*embeddedStruct
44+
}
45+
testStructWithNonEmbeddedStruct struct {
46+
TestStruct *testStruct
47+
NonEmbedded *embeddedStruct
48+
}
49+
50+
testSliceStruct []struct {
51+
IntCol int
52+
}
53+
)
54+
55+
var (
56+
expectedColumns = []schema.Column{
57+
{
58+
Name: "int_col",
59+
Type: schema.TypeInt,
60+
},
61+
{
62+
Name: "int64_col",
63+
Type: schema.TypeInt,
64+
},
65+
{
66+
Name: "string_col",
67+
Type: schema.TypeString,
68+
},
69+
{
70+
Name: "float_col",
71+
Type: schema.TypeFloat,
72+
},
73+
{
74+
Name: "bool_col",
75+
Type: schema.TypeBool,
76+
},
77+
{
78+
Name: "json_col",
79+
Type: schema.TypeJSON,
80+
},
81+
{
82+
Name: "int_array_col",
83+
Type: schema.TypeIntArray,
84+
},
85+
{
86+
Name: "int_pointer_array_col",
87+
Type: schema.TypeIntArray,
88+
},
89+
{
90+
Name: "string_array_col",
91+
Type: schema.TypeStringArray,
92+
},
93+
{
94+
Name: "string_pointer_array_col",
95+
Type: schema.TypeStringArray,
96+
},
97+
{
98+
Name: "time_col",
99+
Type: schema.TypeTimestamp,
100+
},
101+
{
102+
Name: "time_pointer_col",
103+
Type: schema.TypeTimestamp,
104+
},
105+
{
106+
Name: "json_tag",
107+
Type: schema.TypeString,
108+
},
109+
{
110+
Name: "no_json_tag",
111+
Type: schema.TypeString,
112+
},
113+
}
114+
expectedTestTable = schema.Table{
115+
Name: "test_struct",
116+
Columns: expectedColumns,
117+
}
118+
expectedTestTableEmbeddedStruct = schema.Table{
119+
Name: "test_struct",
120+
Columns: append(
121+
expectedColumns, schema.Column{
122+
Name: "embedded_string",
123+
Type: schema.TypeString,
124+
}),
125+
}
126+
expectedTestTableNonEmbeddedStruct = schema.Table{
127+
Name: "test_struct",
128+
Columns: schema.ColumnList{
129+
// Should not be unwrapped
130+
schema.Column{Name: "test_struct", Type: schema.TypeJSON},
131+
// Should be unwrapped
132+
schema.Column{
133+
Name: "non_embedded_embedded_string",
134+
Type: schema.TypeString,
135+
},
136+
},
137+
}
138+
expectedTestTableStructForCustomResolvers = schema.Table{
139+
Name: "test_struct",
140+
Columns: schema.ColumnList{
141+
{
142+
Name: "time_col",
143+
Type: schema.TypeTimestamp,
144+
},
145+
{
146+
Name: "custom",
147+
Type: schema.TypeTimestamp,
148+
},
149+
},
150+
}
151+
expectedTestSliceStruct = schema.Table{
152+
Name: "test_struct",
153+
Columns: schema.ColumnList{
154+
{
155+
Name: "int_col",
156+
Type: schema.TypeInt,
157+
},
158+
},
159+
}
160+
)
161+
162+
func TestTableFromGoStruct(t *testing.T) {
163+
type args struct {
164+
testStruct any
165+
options []StructTransformerOption
166+
}
167+
168+
tests := []struct {
169+
name string
170+
args args
171+
want schema.Table
172+
wantErr bool
173+
}{
174+
{
175+
name: "should generate table from struct with default options",
176+
args: args{
177+
testStruct: testStruct{},
178+
},
179+
want: expectedTestTable,
180+
},
181+
{
182+
name: "should unwrap all embedded structs when option is set",
183+
args: args{
184+
testStruct: testStructWithEmbeddedStruct{},
185+
options: []StructTransformerOption{
186+
WithUnwrapAllEmbeddedStructs(),
187+
},
188+
},
189+
want: expectedTestTableEmbeddedStruct,
190+
},
191+
{
192+
name: "should unwrap specific structs when option is set",
193+
args: args{
194+
testStruct: testStructWithNonEmbeddedStruct{},
195+
options: []StructTransformerOption{
196+
WithUnwrapStructFields("NonEmbedded"),
197+
},
198+
},
199+
want: expectedTestTableNonEmbeddedStruct,
200+
},
201+
{
202+
name: "should generate table from slice struct",
203+
args: args{
204+
testStruct: testSliceStruct{},
205+
},
206+
want: expectedTestSliceStruct,
207+
},
208+
}
209+
210+
for _, tt := range tests {
211+
t.Run(tt.name, func(t *testing.T) {
212+
table := schema.Table{
213+
Name: "test",
214+
Columns: schema.ColumnList{},
215+
}
216+
transformer := TransformWithStruct(tt.args.testStruct, tt.args.options...)
217+
if err := transformer(&table); err != nil {
218+
t.Fatal(err)
219+
}
220+
// table, err := NewTableFromStruct("test_struct", tt.args.testStruct, tt.args.options...)
221+
// if (err != nil) != tt.wantErr {
222+
// t.Fatalf("error = %v, wantErr %v", err, tt.wantErr)
223+
// }
224+
// if tt.wantErr {
225+
// return
226+
// }
227+
if diff := cmp.Diff(table.Columns, tt.want.Columns,
228+
cmpopts.IgnoreFields(schema.Column{}, "Resolver")); diff != "" {
229+
t.Fatalf("table does not match expected. diff (-got, +want): %v", diff)
230+
}
231+
})
232+
}
233+
}

0 commit comments

Comments
 (0)