Skip to content

Commit 3764929

Browse files
committed
Implement pattern matching
1 parent d3c90cc commit 3764929

4 files changed

Lines changed: 264 additions & 183 deletions

File tree

internal/cmd/shim.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,51 @@
11
package cmd
22

33
import (
4+
"strings"
5+
46
"github.com/kyleconroy/sqlc/internal/compiler"
57
"github.com/kyleconroy/sqlc/internal/config"
68
"github.com/kyleconroy/sqlc/internal/plugin"
79
"github.com/kyleconroy/sqlc/internal/sql/catalog"
810
)
911

12+
func pluginOverride(o config.Override) *plugin.Override {
13+
var column string
14+
var table plugin.Identifier
15+
16+
if o.Column != "" {
17+
colParts := strings.Split(o.Column, ".")
18+
switch len(colParts) {
19+
case 2:
20+
table.Schema = "public"
21+
table.Name = colParts[0]
22+
column = colParts[1]
23+
case 3:
24+
table.Schema = colParts[0]
25+
table.Name = colParts[1]
26+
column = colParts[2]
27+
case 4:
28+
table.Catalog = colParts[0]
29+
table.Schema = colParts[1]
30+
table.Name = colParts[2]
31+
column = colParts[3]
32+
}
33+
}
34+
return &plugin.Override{
35+
CodeType: "", // FIXME
36+
DbType: o.DBType,
37+
Nullable: o.Nullable,
38+
Column: o.Column,
39+
ColumnName: column,
40+
Table: &table,
41+
PythonType: pluginPythonType(o.PythonType),
42+
}
43+
}
44+
1045
func pluginSettings(cs config.CombinedSettings) *plugin.Settings {
1146
var over []*plugin.Override
1247
for _, o := range cs.Overrides {
13-
over = append(over, &plugin.Override{
14-
CodeType: "", // FIXME
15-
DbType: o.DBType,
16-
Nullable: o.Nullable,
17-
Column: o.Column,
18-
PythonType: pluginPythonType(o.PythonType),
19-
})
48+
over = append(over, pluginOverride(o))
2049
}
2150
return &plugin.Settings{
2251
Version: cs.Global.Version,

internal/codegen/python/gen.go

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"github.com/kyleconroy/sqlc/internal/codegen"
1212
"github.com/kyleconroy/sqlc/internal/config"
1313
"github.com/kyleconroy/sqlc/internal/core"
14-
"github.com/kyleconroy/sqlc/internal/debug"
1514
"github.com/kyleconroy/sqlc/internal/inflection"
1615
"github.com/kyleconroy/sqlc/internal/metadata"
1716
"github.com/kyleconroy/sqlc/internal/plugin"
@@ -194,11 +193,10 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string {
194193
if !pyTypeIsSet(oride.PythonType) {
195194
continue
196195
}
197-
// TODO: What do we do about regexs?
198-
// sameTable := oride.Matches(col.Table, req.Catalog.DefaultSchema)
199-
// if oride.Column != "" && oride.ColumnName.MatchString(col.Name) && sameTable {
200-
// return pyTypeString(oride.PythonType)
201-
// }
196+
sameTable := matches(oride, col.Table, req.Catalog.DefaultSchema)
197+
if oride.Column != "" && matchString(oride.ColumnName, col.Name) && sameTable {
198+
return pyTypeString(oride.PythonType)
199+
}
202200
if oride.DbType != "" && oride.DbType == col.DataType && oride.Nullable != (col.NotNull || col.IsArray) {
203201
return pyTypeString(oride.PythonType)
204202
}
@@ -213,6 +211,48 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string {
213211
}
214212
}
215213

214+
func matchString(pattern, target string) bool {
215+
// TODO: Create a separate package for the matchers
216+
matcher, err := config.MatchCompile(pattern)
217+
if err != nil {
218+
panic(err)
219+
}
220+
return matcher.MatchString(target)
221+
}
222+
223+
func matches(o *plugin.Override, n *plugin.Identifier, defaultSchema string) bool {
224+
if n == nil {
225+
return false
226+
}
227+
228+
schema := n.Schema
229+
if n.Schema == "" {
230+
schema = defaultSchema
231+
}
232+
233+
if o.Table.Catalog != "" && !matchString(o.Table.Catalog, n.Catalog) {
234+
return false
235+
}
236+
237+
if o.Table.Schema == "" && schema != "" {
238+
return false
239+
}
240+
241+
if o.Table.Schema != "" && !matchString(o.Table.Schema, schema) {
242+
return false
243+
}
244+
245+
if o.Table.Name == "" && n.Name != "" {
246+
return false
247+
}
248+
249+
if o.Table.Name != "" && !matchString(o.Table.Name, n.Name) {
250+
return false
251+
}
252+
253+
return true
254+
}
255+
216256
func modelName(name string, settings *plugin.Settings) string {
217257
if rename := settings.Rename[name]; rename != "" {
218258
return rename
@@ -413,12 +453,6 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
413453
SourceName: query.Filename,
414454
}
415455

416-
dump := methodName == "get_venue"
417-
if dump {
418-
debug.Dump(query)
419-
debug.Dump(gq)
420-
}
421-
422456
if len(query.Params) > 4 {
423457
var cols []pyColumn
424458
for _, p := range query.Params {
@@ -467,10 +501,6 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
467501
sameName := f.Name == columnName(c, i)
468502
sameType := f.Type == trimmedPyType
469503
sameTable := sameTableName(c.Table, s.Table, req.Catalog.DefaultSchema)
470-
if dump {
471-
debug.Dump(c.Table, s.Table, req.Catalog.DefaultSchema)
472-
debug.Dump(sameName, sameType, sameTable)
473-
}
474504
if !sameName || !sameType || !sameTable {
475505
same = false
476506
}

0 commit comments

Comments
 (0)