Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(compiler): correctly validate alias in order/group by clauses for…
… joins

Resolves #1886
Resolves #2398
Resolves #2399
  • Loading branch information
andrewmbenton committed Jul 27, 2023
commit 5bd0257b6ec3cd5e9347555f29846d437df6c578
18 changes: 10 additions & 8 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er

if n.GroupClause != nil {
for _, item := range n.GroupClause.Items {
if err := findColumnForNode(item, tables, n); err != nil {
if err := findColumnForNode(item, tables, targets); err != nil {
return nil, err
}
}
Expand All @@ -85,7 +85,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
if !ok {
continue
}
if err := findColumnForNode(sb.Node, tables, n); err != nil {
if err := findColumnForNode(sb.Node, tables, targets); err != nil {
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
}
}
Expand All @@ -101,7 +101,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
if !ok {
continue
}
if err := findColumnForNode(caseExpr.Xpr, tables, n); err != nil {
if err := findColumnForNode(caseExpr.Xpr, tables, targets); err != nil {
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
}
}
Expand Down Expand Up @@ -650,15 +650,15 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
return cols, nil
}

func findColumnForNode(item ast.Node, tables []*Table, n *ast.SelectStmt) error {
func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) error {
ref, ok := item.(*ast.ColumnRef)
if !ok {
return nil
}
return findColumnForRef(ref, tables, n)
return findColumnForRef(ref, tables, targetList)
}

func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.SelectStmt) error {
func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) error {
parts := stringSlice(ref.Fields)
var alias, name string
if len(parts) == 1 {
Expand Down Expand Up @@ -686,9 +686,11 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.
if foundColumn {
continue
}
}

// Find matching alias
for _, c := range selectStatement.TargetList.Items {
// Find matching alias if necessary
if found == 0 {
for _, c := range targetList.Items {
resTarget, ok := c.(*ast.ResTarget)
if !ok {
continue
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE foo (email text not null);

-- name: ColumnAsOrderBy :many
SELECT a.email AS id
FROM foo a JOIN foo b ON a.email = b.email
ORDER BY id;

-- name: ColumnAsGroupBy :many
SELECT a.email AS id
FROM foo a JOIN foo b ON a.email = b.email
GROUP BY id;
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"engine": "postgresql",
"path": "go",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}