Skip to content

Commit eaae461

Browse files
committed
sqlc.embed: skip object if id is None
1 parent b821701 commit eaae461

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

internal/config.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ type Config struct {
99
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
1010
TablePrefix string `json:"table_prefix"`
1111
// When a query uses a table with RLS enforced fields, it will be required to
12-
// parametrized those fields. Associate tables are not covered!
12+
// parametrized those fields. Not covered:
13+
// - Associate tables
14+
// - sqlc.embed()
15+
// - json_agg(tbl.*)
1316
RLSEnforcedFields []string `json:"rls_enforced_fields"`
1417
// Merge queries defined in different files into one output queries.py file
1518
MergeQueryFiles bool `json:"merge_query_files"`

internal/gen.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,27 +118,33 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
118118
var embedFields []*pyast.Keyword
119119
for _, embed := range f.EmbedFields {
120120
embedFields = append(embedFields, &pyast.Keyword{
121-
Arg: embed.Name,
122-
Value: subscriptNode(
123-
rowVar,
124-
constantInt(idx),
125-
),
121+
Arg: embed.Name,
122+
Value: subscriptNode(rowVar, constantInt(idx)),
126123
})
127124
idx++
128125
}
129126
val = &pyast.Node{
130-
Node: &pyast.Node_Call{
131-
Call: &pyast.Call{
132-
Func: f.Type.Annotation(false),
133-
Keywords: embedFields,
127+
Node: &pyast.Node_Compare{
128+
Compare: &pyast.Compare{
129+
Left: &pyast.Node{
130+
Node: &pyast.Node_Call{
131+
Call: &pyast.Call{
132+
Func: f.Type.Annotation(false),
133+
Keywords: embedFields,
134+
},
135+
},
136+
},
137+
Ops: []*pyast.Node{
138+
poet.Name(fmt.Sprintf("if row[%d] else", idx-len(f.EmbedFields))),
139+
},
140+
Comparators: []*pyast.Node{
141+
poet.Constant(nil),
142+
},
134143
},
135144
},
136145
}
137146
} else {
138-
val = subscriptNode(
139-
rowVar,
140-
constantInt(idx),
141-
)
147+
val = subscriptNode(rowVar, constantInt(idx))
142148
idx++
143149
}
144150
call.Keywords = append(call.Keywords, &pyast.Keyword{

0 commit comments

Comments
 (0)