Skip to content
Open
Prev Previous commit
Next Next commit
fix: some type bugs
  • Loading branch information
qykr committed Mar 10, 2026
commit f6af815c60afbcda4b141f6ff5fb95663def0a2f
159 changes: 146 additions & 13 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,137 @@ func compatibleParamTypes(a, b *Column) bool {
a.ArrayDims == b.ArrayDims
}

func sameTypeName(a, b *ast.TypeName) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
return a.Catalog == b.Catalog && a.Schema == b.Schema && a.Name == b.Name
}

func matchingFuncCallOverloads(c *catalog.Catalog, call *ast.FuncCall) []catalog.Function {
funs, err := c.ListFuncsByName(call.Func)
if err != nil {
return nil
}

var positional []ast.Node
var named []*ast.NamedArgExpr
if call.Args != nil {
for _, arg := range call.Args.Items {
if narg, ok := arg.(*ast.NamedArgExpr); ok {
named = append(named, narg)
continue
}
if len(named) > 0 {
return nil
}
positional = append(positional, arg)
}
}

var matches []catalog.Function
for _, fun := range funs {
args := fun.InArgs()
var defaults int
var variadic bool
known := map[string]struct{}{}
for _, arg := range args {
if arg.HasDefault {
defaults += 1
}
if arg.Mode == ast.FuncParamVariadic {
variadic = true
defaults += 1
}
if arg.Name != "" {
known[arg.Name] = struct{}{}
}
}

argc := len(named) + len(positional)
if variadic {
if argc < (len(args) - defaults) {
continue
}
} else {
if argc > len(args) || argc < (len(args)-defaults) {
continue
}
}

var unknownArgName bool
for _, expr := range named {
if expr.Name != nil {
if _, found := known[*expr.Name]; !found {
unknownArgName = true
}
}
}
if unknownArgName {
continue
}

matches = append(matches, fun)
}

return matches
}

func stableFuncCallArgType(c *catalog.Catalog, call *ast.FuncCall, argIndex int, argName string) *ast.TypeName {
var stable *ast.TypeName
var seen bool

for _, fun := range matchingFuncCallOverloads(c, call) {
args := fun.InArgs()
var current *ast.TypeName
if argName == "" {
if argIndex >= len(args) {
return nil
}
current = args[argIndex].Type
} else {
for _, arg := range args {
if arg.Name == argName {
current = arg.Type
break
}
}
if current == nil {
return nil
}
}

if !seen {
stable = current
seen = true
continue
}
if !sameTypeName(stable, current) {
return nil
}
}

return stable
}

func resolvedFuncCallArgType(fun *catalog.Function, argIndex int, argName string) *ast.TypeName {
if fun == nil {
return nil
}
if argName == "" {
if argIndex < len(fun.Args) {
return fun.Args[argIndex].Type
}
return nil
}
for _, arg := range fun.Args {
if arg.Name == argName {
return arg.Type
}
}
return nil
}

func mergeResolvedParam(existing, incoming Parameter) Parameter {
if existing.Column == nil {
return incoming
Expand Down Expand Up @@ -93,8 +224,8 @@ func mergeResolvedParam(existing, incoming Parameter) Parameter {

func (comp *Compiler) incompatibleParamRefError(ref paramRef, existing, incoming Parameter) error {
return &sqlerr.Error{
Code: "42P08",
Message: fmt.Sprintf(
Code: "42P08",
Message: fmt.Sprintf(
"parameter $%d has incompatible types: %s, %s",
ref.ref.Number,
comp.paramTypeString(existing.Column),
Expand Down Expand Up @@ -182,6 +313,10 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,

var a []Parameter
seen := map[int]int{}
paramCounts := map[int]int{}
for _, ref := range args {
paramCounts[ref.ref.Number] += 1
}

addParam := func(ref paramRef, p Parameter) error {
if idx, ok := seen[p.Number]; ok {
Expand Down Expand Up @@ -424,8 +559,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
}

case *ast.FuncCall:
fun, err := c.ResolveFuncCall(n)
if err != nil {
fun, resolveErr := c.ResolveFuncCall(n)
if resolveErr != nil {
// Synthesize a function on the fly to avoid returning with an error
// for an unknown Postgres function (e.g. defined in an extension)
var args []*catalog.Argument
Expand Down Expand Up @@ -503,22 +638,20 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
if argName == "" {
if i < len(fun.Args) {
paramName = fun.Args[i].Name
paramType = fun.Args[i].Type
}
} else {
paramName = argName
for _, arg := range fun.Args {
if arg.Name == argName {
paramType = arg.Type
}
}
if paramType == nil {
panic(fmt.Sprintf("named argument %s has no type", paramName))
}
}
if paramName == "" {
paramName = funcName
}
if resolveErr == nil {
if paramCounts[ref.ref.Number] > 1 {
paramType = stableFuncCallArgType(c, n, i, argName)
} else {
paramType = resolvedFuncCallArgType(fun, i, argName)
}
}
if paramType == nil {
paramType = &ast.TypeName{Name: ""}
}
Expand Down
36 changes: 36 additions & 0 deletions internal/compiler/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)

Expand Down Expand Up @@ -66,3 +67,38 @@ func TestIncompatibleParamRefErrorFormatsTypeNames(t *testing.T) {
t.Fatalf("unexpected message: %q", sqlErr.Message)
}
}

func TestMergeResolvedParamKeepsFirstNameForCompatibleTypes(t *testing.T) {
t.Parallel()

merged := mergeResolvedParam(
Parameter{Number: 1, Column: &Column{Name: "user", DataType: "text"}},
Parameter{Number: 1, Column: &Column{Name: "student_user", DataType: "text"}},
)

if merged.Column == nil {
t.Fatal("expected merged column")
}
if merged.Column.Name != "user" {
t.Fatalf("expected first inferred name to win, got %q", merged.Column.Name)
}
}

func TestResolvedFuncCallArgType(t *testing.T) {
t.Parallel()

fun := &catalog.Function{Args: []*catalog.Argument{
{Name: "lhs", Type: &ast.TypeName{Name: "int8"}},
{Name: "rhs", Type: &ast.TypeName{Name: "text"}},
}}

if got := resolvedFuncCallArgType(fun, 0, ""); got == nil || got.Name != "int8" {
t.Fatalf("expected positional arg type int8, got %#v", got)
}
if got := resolvedFuncCallArgType(fun, 0, "rhs"); got == nil || got.Name != "text" {
t.Fatalf("expected named arg type text, got %#v", got)
}
if got := resolvedFuncCallArgType(fun, 2, ""); got != nil {
t.Fatalf("expected nil for out-of-range positional arg, got %#v", got)
}
}
Loading