@@ -23,7 +23,7 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
2323 return params , nil
2424 }
2525
26- parseLimitSubExp := func (node sqlparser.Expr ) {
26+ parseLimitSubExp := func (node sqlparser.Expr ) error {
2727 switch v := node .(type ) {
2828 case * sqlparser.SQLVal :
2929 if v .Type == sqlparser .ValArg {
@@ -33,11 +33,30 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
3333 Typ : "uint32" ,
3434 })
3535 }
36+ case * sqlparser.FuncExpr :
37+ name , raw , err := matchFuncExpr (v )
38+ if err != nil {
39+ return err
40+ }
41+ if name != "" && raw != "" {
42+ params = append (params , & Param {
43+ OriginalName : raw ,
44+ Name : name ,
45+ Typ : "uint32" ,
46+ })
47+ }
3648 }
49+ return nil
3750 }
3851
39- parseLimitSubExp (limit .Offset )
40- parseLimitSubExp (limit .Rowcount )
52+ err := parseLimitSubExp (limit .Offset )
53+ if err != nil {
54+ return nil , err
55+ }
56+ err = parseLimitSubExp (limit .Rowcount )
57+ if err != nil {
58+ return nil , err
59+ }
4160
4261 return params , nil
4362}
@@ -115,13 +134,26 @@ func paramInComparison(cond *sqlparser.ComparisonExpr, s *Schema, tableAliasMap
115134 if v .Type == sqlparser .ValArg {
116135 p .OriginalName = string (v .Val )
117136 }
137+ case * sqlparser.FuncExpr :
138+ name , raw , err := matchFuncExpr (v )
139+ if err != nil {
140+ return false , err
141+ }
142+ if name != "" && raw != "" {
143+ p .OriginalName = raw
144+ p .Name = name
145+ }
146+ return false , nil
118147 }
119148 return true , nil
120149 }
121150 err := sqlparser .Walk (walker , cond )
122151 if err != nil {
123152 return nil , false , err
124153 }
154+ if p .Name != "" {
155+ return p , true , nil
156+ }
125157 if p .OriginalName != "" && p .Typ != "" {
126158 p .Name = paramName (colIdent , p .OriginalName )
127159 return p , true , nil
@@ -143,11 +175,39 @@ func paramName(col sqlparser.ColIdent, originalName string) string {
143175
144176func replaceParamStrs (query string , params []* Param ) (string , error ) {
145177 for _ , p := range params {
146- re , err := regexp .Compile (fmt .Sprintf ("(%v)" , p .OriginalName ))
178+ re , err := regexp .Compile (fmt .Sprintf ("(%v)" , regexp . QuoteMeta ( p .OriginalName ) ))
147179 if err != nil {
148180 return "" , err
149181 }
150182 query = re .ReplaceAllString (query , "?" )
151183 }
152184 return query , nil
153185}
186+
187+ func matchFuncExpr (v * sqlparser.FuncExpr ) (name string , raw string , err error ) {
188+ namespace := "sqlc"
189+ fakeFunc := "arg"
190+ if v .Qualifier .String () == namespace {
191+ if v .Name .String () == fakeFunc {
192+ if expr , ok := v .Exprs [0 ].(* sqlparser.AliasedExpr ); ok {
193+ if colName , ok := expr .Expr .(* sqlparser.ColName ); ok {
194+ customName := colName .Name .String ()
195+ return customName , fmt .Sprintf ("%s.%s(%s)" , namespace , fakeFunc , customName ), nil
196+ }
197+ return "" , "" , fmt .Errorf ("invalid custom argument value \" %s.%s(%s)\" " , namespace , fakeFunc , replaceVParamExprs (sqlparser .String (v .Exprs [0 ])))
198+ }
199+ return "" , "" , fmt .Errorf ("invalid custom argument value \" %s.%s(%s)\" " , namespace , fakeFunc , replaceVParamExprs (sqlparser .String (v .Exprs [0 ])))
200+ }
201+ return "" , "" , fmt .Errorf ("invalid function call \" %s.%s\" , did you mean \" %s.%s\" ?" , namespace , v .Name .String (), namespace , fakeFunc )
202+ }
203+ return "" , "" , nil
204+ }
205+
206+ func replaceVParamExprs (sql string ) string {
207+ /*
208+ the sqlparser replaces "?" with ":v1"
209+ to display a helpful error message, these should be replaced back to "?"
210+ */
211+ matcher := regexp .MustCompile (":v[0-9]*" )
212+ return matcher .ReplaceAllString (sql , "?" )
213+ }
0 commit comments