diff --git a/.github/workflows/buf.yml b/.github/workflows/buf.yml index 7b7ef08214..fabfc40023 100644 --- a/.github/workflows/buf.yml +++ b/.github/workflows/buf.yml @@ -4,6 +4,6 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: bufbuild/buf-setup-action@v1 - uses: bufbuild/buf-lint-action@v1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ba18f8c832..b15a5ea75b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,10 +9,10 @@ jobs: name: build ${{ matrix.os }} runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.25.0' + go-version: '1.26.0' - name: install ./... run: go build ./... env: diff --git a/.github/workflows/ci-kotlin.yml b/.github/workflows/ci-kotlin.yml index 729bd3cc86..a324917ed7 100644 --- a/.github/workflows/ci-kotlin.yml +++ b/.github/workflows/ci-kotlin.yml @@ -10,13 +10,13 @@ jobs: name: test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.24.1' + go-version: '1.26.0' - name: install ./... run: go install ./... - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: repository: sqlc-dev/sqlc-gen-kotlin path: kotlin diff --git a/.github/workflows/ci-python.yml b/.github/workflows/ci-python.yml index 0eb11aeaae..a59bd402c3 100644 --- a/.github/workflows/ci-python.yml +++ b/.github/workflows/ci-python.yml @@ -10,13 +10,13 @@ jobs: name: test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.24.1' + go-version: '1.26.0' - name: install ./... run: go install ./... - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: repository: sqlc-dev/sqlc-gen-python path: python diff --git a/.github/workflows/ci-typescript.yml b/.github/workflows/ci-typescript.yml index 191e5949bd..7ec747a91f 100644 --- a/.github/workflows/ci-typescript.yml +++ b/.github/workflows/ci-typescript.yml @@ -10,13 +10,13 @@ jobs: name: test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.24.1' + go-version: '1.26.0' - name: install ./... run: go install ./... - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: repository: sqlc-dev/sqlc-gen-typescript path: typescript diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca36cef036..1ee0a8f696 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,10 +13,10 @@ jobs: name: build ${{ matrix.goos }}/${{ matrix.goarch }} runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.25.0' + go-version: '1.26.0' - run: go build ./... env: CGO_ENABLED: "0" @@ -25,10 +25,10 @@ jobs: test: runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: - go-version: '1.25.0' + go-version: '1.26.0' - name: install gotestsum run: go install gotest.tools/gotestsum@latest @@ -50,13 +50,20 @@ jobs: env: CGO_ENABLED: "0" + - name: install databases + run: go run ./cmd/sqlc-test-setup install + + - name: start databases + run: go run ./cmd/sqlc-test-setup start + - name: test ./... - run: gotestsum --junitfile junit.xml -- --tags=examples -timeout 20m ./... - if: ${{ matrix.os }} != "windows-2022" + run: gotestsum --junitfile junit.xml -- --tags=examples -timeout 20m -failfast ./... env: CI_SQLC_PROJECT_ID: ${{ secrets.CI_SQLC_PROJECT_ID }} CI_SQLC_AUTH_TOKEN: ${{ secrets.CI_SQLC_AUTH_TOKEN }} SQLC_AUTH_TOKEN: ${{ secrets.CI_SQLC_AUTH_TOKEN }} + POSTGRESQL_SERVER_URI: "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable" + MYSQL_SERVER_URI: "root:mysecretpassword@tcp(127.0.0.1:3306)/mysql?multiStatements=true&parseTime=true" CGO_ENABLED: "0" vuln_check: diff --git a/.github/workflows/gen.yml b/.github/workflows/gen.yml index 8d2c69a7e8..eb83825c39 100644 --- a/.github/workflows/gen.yml +++ b/.github/workflows/gen.yml @@ -17,7 +17,7 @@ jobs: # needed because the postgres container does not provide a healthcheck options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version-file: go.mod @@ -32,7 +32,7 @@ jobs: PG_PASSWORD: postgres PG_PORT: ${{ job.services.postgres.ports['5432'] }} - name: Save results - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: sqlc-pg-gen-results path: gen diff --git a/CLAUDE.md b/CLAUDE.md index 9d637256a1..46c623bebf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,41 +6,78 @@ This document provides essential information for working with the sqlc codebase, ### Prerequisites -- **Go 1.25.0+** - Required for building and testing -- **Docker & Docker Compose** - Required for integration tests with databases +- **Go 1.26.0+** - Required for building and testing +- **Docker & Docker Compose** - Required for integration tests with databases (local development) - **Git** - For version control -### Running Tests +## Database Setup with sqlc-test-setup -#### Basic Unit Tests (No Database Required) +The `sqlc-test-setup` tool (`cmd/sqlc-test-setup/`) automates installing and starting PostgreSQL and MySQL for tests. Both commands are idempotent and safe to re-run. + +### Install databases ```bash -# Simplest approach - runs all unit tests -go test ./... +go run ./cmd/sqlc-test-setup install +``` + +This will: +- Configure the apt proxy (if `http_proxy` is set, e.g. in Claude Code remote environments) +- Install PostgreSQL via apt +- Download and install MySQL 9 from Oracle's deb bundle +- Resolve all dependencies automatically +- Skip anything already installed -# Using make -make test +### Start databases + +```bash +go run ./cmd/sqlc-test-setup start ``` -#### Full Test Suite with Integration Tests +This will: +- Start PostgreSQL and configure password auth (`postgres`/`postgres`) +- Start MySQL via `mysqld_safe` and set root password (`mysecretpassword`) +- Verify both connections +- Skip steps that are already done (running services, existing config) + +Connection URIs after start: +- PostgreSQL: `postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable` +- MySQL: `root:mysecretpassword@tcp(127.0.0.1:3306)/mysql` + +### Run tests ```bash -# Step 1: Start database containers -docker compose up -d +# Full test suite (requires databases running) +go test --tags=examples -timeout 20m ./... +``` + +## Running Tests + +### Basic Unit Tests (No Database Required) + +```bash +go test ./... +``` -# Step 2: Run all tests including examples +### Full Test Suite with Docker (Local Development) + +```bash +docker compose up -d go test --tags=examples -timeout 20m ./... +``` -# Or use make for the full CI suite -make test-ci +### Full Test Suite without Docker (Remote / CI) + +```bash +go run ./cmd/sqlc-test-setup install +go run ./cmd/sqlc-test-setup start +go test --tags=examples -timeout 20m ./... ``` -#### Running Specific Tests +### Running Specific Tests ```bash # Test a specific package go test ./internal/config -go test ./internal/compiler # Run with verbose output go test -v ./internal/config @@ -94,21 +131,6 @@ The `docker-compose.yml` provides test databases: - Password: `mysecretpassword` - Database: `dinotest` -### Managing Databases - -```bash -# Start databases -make start -# or -docker compose up -d - -# Stop databases -docker compose down - -# View logs -docker compose logs -f -``` - ## Makefile Targets ```bash @@ -125,23 +147,11 @@ make start # Start database containers ### GitHub Actions Workflow - **File:** `.github/workflows/ci.yml` -- **Go Version:** 1.25.0 +- **Go Version:** 1.26.0 +- **Database Setup:** Uses `sqlc-test-setup` (not Docker) to install and start PostgreSQL and MySQL directly on the runner - **Test Command:** `gotestsum --junitfile junit.xml -- --tags=examples -timeout 20m ./...` - **Additional Checks:** `govulncheck` for vulnerability scanning -### Running Tests Like CI Locally - -```bash -# Install CI tools (optional) -go install gotest.tools/gotestsum@latest - -# Run tests with same timeout as CI -go test --tags=examples -timeout 20m ./... - -# Or use the CI make target -make test-ci -``` - ## Development Workflow ### Building Development Versions @@ -156,37 +166,18 @@ go build -o ~/go/bin/sqlc-gen-json ./cmd/sqlc-gen-json ### Environment Variables for Tests -You can customize database connections: +You can override database connections via environment variables: -**PostgreSQL:** ```bash -PG_HOST=127.0.0.1 -PG_PORT=5432 -PG_USER=postgres -PG_PASSWORD=mysecretpassword -PG_DATABASE=dinotest -``` - -**MySQL:** -```bash -MYSQL_HOST=127.0.0.1 -MYSQL_PORT=3306 -MYSQL_USER=root -MYSQL_ROOT_PASSWORD=mysecretpassword -MYSQL_DATABASE=dinotest -``` - -**Example:** -```bash -POSTGRESQL_SERVER_URI="postgres://postgres:mysecretpassword@localhost:5432/postgres" \ - go test -v ./... +POSTGRESQL_SERVER_URI="postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" +MYSQL_SERVER_URI="root:mysecretpassword@tcp(127.0.0.1:3306)/mysql?multiStatements=true&parseTime=true" ``` ## Code Structure ### Key Directories -- `/cmd/` - Main binaries (sqlc, sqlc-gen-json) +- `/cmd/` - Main binaries (sqlc, sqlc-gen-json, sqlc-test-setup) - `/internal/cmd/` - Command implementations (vet, generate, etc.) - `/internal/engine/` - Database engine implementations - `/postgresql/` - PostgreSQL parser and converter @@ -196,6 +187,7 @@ POSTGRESQL_SERVER_URI="postgres://postgres:mysecretpassword@localhost:5432/postg - `/internal/codegen/` - Code generation for different languages - `/internal/config/` - Configuration file parsing - `/internal/endtoend/` - End-to-end tests +- `/internal/sqltest/` - Test database setup (Docker, native, local detection) - `/examples/` - Example projects for testing ### Important Files @@ -203,13 +195,12 @@ POSTGRESQL_SERVER_URI="postgres://postgres:mysecretpassword@localhost:5432/postg - `/Makefile` - Build and test targets - `/docker-compose.yml` - Database services for testing - `/.github/workflows/ci.yml` - CI configuration -- `/docs/guides/development.md` - Developer documentation ## Common Issues & Solutions ### Network Connectivity Issues -If you see errors about `storage.googleapis.com`, the Go proxy may be unreachable. Tests may still pass for packages that don't require network dependencies. +If you see errors about `storage.googleapis.com`, the Go proxy may be unreachable. Use `GOPROXY=direct go mod download` to fetch modules directly from source. ### Test Timeouts @@ -227,19 +218,23 @@ go test -race ./... ### Database Connection Failures -Ensure Docker containers are running: +If using Docker: ```bash docker compose ps docker compose up -d ``` +If using sqlc-test-setup: +```bash +go run ./cmd/sqlc-test-setup start +``` + ## Tips for Contributors -1. **Run tests before committing:** `make test-ci` +1. **Run tests before committing:** `go test --tags=examples -timeout 20m ./...` 2. **Check for race conditions:** Use `-race` flag when testing concurrent code 3. **Use specific package tests:** Faster iteration during development -4. **Start databases early:** `docker compose up -d` before running integration tests -5. **Read existing tests:** Good examples in `/internal/engine/postgresql/*_test.go` +4. **Read existing tests:** Good examples in `/internal/engine/postgresql/*_test.go` ## Git Workflow @@ -251,34 +246,18 @@ docker compose up -d ### Committing Changes ```bash -# Stage changes git add - -# Commit with descriptive message -git commit -m "Brief description - -Detailed explanation of changes. - -🤖 Generated with [Claude Code](https://claude.com/claude-code) - -Co-Authored-By: Claude " - -# Push to remote +git commit -m "Brief description of changes" git push -u origin ``` ### Rebasing ```bash -# Update main git checkout main git pull origin main - -# Rebase feature branch git checkout git rebase main - -# Force push rebased branch git push --force-with-lease origin ``` @@ -288,21 +267,3 @@ git push --force-with-lease origin - **Development Guide:** `/docs/guides/development.md` - **CI Configuration:** `/.github/workflows/ci.yml` - **Docker Compose:** `/docker-compose.yml` - -## Recent Fixes & Improvements - -### Fixed Issues - -1. **Typo in create_function_stmt.go** - Fixed "Undertand" → "Understand" -2. **Race condition in vet.go** - Fixed Client initialization using `sync.Once` -3. **Nil pointer dereference in parse.go** - Fixed unsafe type assertion in primary key parsing - -These fixes demonstrate common patterns: -- Using `sync.Once` for thread-safe lazy initialization -- Using comma-ok idiom for safe type assertions: `if val, ok := x.(Type); ok { ... }` -- Adding proper nil checks and defensive programming - ---- - -**Last Updated:** 2025-10-21 -**Maintainer:** Claude Code diff --git a/Dockerfile b/Dockerfile index 06d3008d07..0c2b2595e3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # STEP 1: Build sqlc -FROM golang:1.25.3 AS builder +FROM golang:1.26.1 AS builder COPY . /workspace WORKDIR /workspace diff --git a/cmd/sqlc-test-setup/main.go b/cmd/sqlc-test-setup/main.go new file mode 100644 index 0000000000..2a0d04dc5b --- /dev/null +++ b/cmd/sqlc-test-setup/main.go @@ -0,0 +1,705 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" +) + +const ( + // pgVersion is the PostgreSQL version to install. + pgVersion = "18.2.0" +) + +// pgBinary contains the download information for a PostgreSQL binary release. +type pgBinary struct { + URL string + SHA256 string +} + +// pgBinaries maps "/" to the corresponding binary download info. +var pgBinaries = map[string]pgBinary{ + "linux/amd64": { + URL: "https://github.com/theseus-rs/postgresql-binaries/releases/download/" + pgVersion + "/postgresql-" + pgVersion + "-x86_64-unknown-linux-gnu.tar.gz", + SHA256: "cc2674e1641aa2a62b478971a22c131a768eb783f313e6a3385888f58a604074", + }, + "linux/arm64": { + URL: "https://github.com/theseus-rs/postgresql-binaries/releases/download/" + pgVersion + "/postgresql-" + pgVersion + "-aarch64-unknown-linux-gnu.tar.gz", + SHA256: "8b415a11c7a5484e5fbf7a57fca71554d2d1d7acd34faf066606d2fee1261854", + }, +} + +func main() { + log.SetFlags(log.Ltime) + log.SetPrefix("[sqlc-test-setup] ") + + if len(os.Args) < 2 { + fmt.Fprintln(os.Stderr, "usage: sqlc-test-setup ") + os.Exit(1) + } + + switch os.Args[1] { + case "install": + if err := runInstall(); err != nil { + log.Fatalf("install failed: %s", err) + } + case "start": + if err := runStart(); err != nil { + log.Fatalf("start failed: %s", err) + } + default: + fmt.Fprintf(os.Stderr, "unknown command: %s\nusage: sqlc-test-setup \n", os.Args[1]) + os.Exit(1) + } +} + +// run executes a command with verbose logging, streaming output to stderr. +func run(name string, args ...string) error { + log.Printf("exec: %s %s", name, strings.Join(args, " ")) + cmd := exec.Command(name, args...) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd.Run() +} + +// runOutput executes a command and returns its combined output. +func runOutput(name string, args ...string) (string, error) { + log.Printf("exec: %s %s", name, strings.Join(args, " ")) + cmd := exec.Command(name, args...) + out, err := cmd.CombinedOutput() + return string(out), err +} + +// commandExists checks if a binary is available in PATH. +func commandExists(name string) bool { + _, err := exec.LookPath(name) + return err == nil +} + +// isMySQLVersionOK checks if the mysqld --version output indicates MySQL 9+. +// Example version string: "/usr/sbin/mysqld Ver 8.0.44-0ubuntu0.24.04.2 ..." +func isMySQLVersionOK(versionOutput string) bool { + // Look for "Ver X.Y.Z" pattern + fields := strings.Fields(versionOutput) + for i, f := range fields { + if strings.EqualFold(f, "Ver") && i+1 < len(fields) { + ver := strings.Split(fields[i+1], ".") + if len(ver) > 0 { + major := strings.TrimLeft(ver[0], "0") + if major == "" { + return false + } + return major[0] >= '9' + } + } + } + return false +} + +// pgBaseDir returns the sqlc-specific directory where PostgreSQL is installed, +// using the user's cache directory (~/.cache/sqlc/postgresql on Linux). +func pgBaseDir() string { + cacheDir, err := os.UserCacheDir() + if err != nil { + cacheDir = filepath.Join(os.Getenv("HOME"), ".cache") + } + return filepath.Join(cacheDir, "sqlc", "postgresql") +} + +// pgBinDir returns the path to the PostgreSQL bin directory. +func pgBinDir() string { + return filepath.Join(pgBaseDir(), "bin") +} + +// pgDataDir returns the path to the PostgreSQL data directory. +func pgDataDir() string { + return filepath.Join(pgBaseDir(), "data") +} + +// pgBin returns the full path to a PostgreSQL binary. +func pgBin(name string) string { + return filepath.Join(pgBinDir(), name) +} + +// ---- install ---- + +func runInstall() error { + log.Println("=== Installing PostgreSQL and MySQL for test setup ===") + + if err := installAptProxy(); err != nil { + return fmt.Errorf("configuring apt proxy: %w", err) + } + + if err := installPostgreSQL(); err != nil { + return fmt.Errorf("installing postgresql: %w", err) + } + + if err := installMySQL(); err != nil { + return fmt.Errorf("installing mysql: %w", err) + } + + log.Println("=== Install complete ===") + return nil +} + +func installAptProxy() error { + proxy := os.Getenv("http_proxy") + if proxy == "" { + log.Println("http_proxy is not set, skipping apt proxy configuration") + return nil + } + + const confPath = "/etc/apt/apt.conf.d/99proxy" + if _, err := os.Stat(confPath); err == nil { + log.Printf("apt proxy config already exists at %s, skipping", confPath) + return nil + } + + log.Printf("configuring apt proxy to use %s", proxy) + proxyConf := fmt.Sprintf("Acquire::http::Proxy \"%s\";", proxy) + cmd := fmt.Sprintf("echo '%s' | sudo tee /etc/apt/apt.conf.d/99proxy", proxyConf) + return run("bash", "-c", cmd) +} + +func installPostgreSQL() error { + log.Println("--- Installing PostgreSQL ---") + + // Install runtime dependencies needed by PostgreSQL extensions (e.g. + // uuid-ossp requires libossp-uuid16). + if err := installPgDeps(); err != nil { + return fmt.Errorf("installing postgresql dependencies: %w", err) + } + + // Check if already installed in our directory + if _, err := os.Stat(pgBin("postgres")); err == nil { + out, err := runOutput(pgBin("postgres"), "--version") + if err == nil { + log.Printf("postgresql is already installed: %s", strings.TrimSpace(out)) + log.Println("skipping postgresql installation") + return nil + } + } + + platform := runtime.GOOS + "/" + runtime.GOARCH + bin, ok := pgBinaries[platform] + if !ok { + return fmt.Errorf("unsupported platform: %s (supported: %s)", platform, supportedPlatforms()) + } + + // Download to a temp file + tarball := filepath.Join(os.TempDir(), fmt.Sprintf("postgresql-%s.tar.gz", pgVersion)) + + if _, err := os.Stat(tarball); err != nil { + log.Printf("downloading PostgreSQL %s from %s", pgVersion, bin.URL) + if err := downloadFile(tarball, bin.URL); err != nil { + os.Remove(tarball) + return fmt.Errorf("downloading postgresql: %w", err) + } + } else { + log.Printf("postgresql tarball already downloaded at %s", tarball) + } + + // Verify SHA256 checksum + log.Printf("verifying SHA256 checksum") + actualHash, err := sha256File(tarball) + if err != nil { + return fmt.Errorf("computing sha256: %w", err) + } + if actualHash != bin.SHA256 { + os.Remove(tarball) + return fmt.Errorf("SHA256 mismatch: expected %s, got %s", bin.SHA256, actualHash) + } + log.Printf("SHA256 checksum verified: %s", actualHash) + + baseDir := pgBaseDir() + + // Create the base directory in the user cache + if err := os.MkdirAll(baseDir, 0o755); err != nil { + return fmt.Errorf("creating %s: %w", baseDir, err) + } + + // Extract the tarball - it contains a top-level directory like + // postgresql-18.2.0-x86_64-unknown-linux-gnu/ with bin/, lib/, share/ inside. + // We strip that top-level directory and extract directly into the base dir. + log.Printf("extracting postgresql to %s", baseDir) + if err := run("tar", "-xzf", tarball, "-C", baseDir, "--strip-components=1"); err != nil { + return fmt.Errorf("extracting postgresql: %w", err) + } + + // Verify the binary works + out, err := runOutput(pgBin("postgres"), "--version") + if err != nil { + return fmt.Errorf("postgres --version failed after install: %w", err) + } + log.Printf("postgresql installed successfully: %s", strings.TrimSpace(out)) + return nil +} + +// installPgDeps installs shared libraries required by PostgreSQL extensions at +// runtime (e.g. libossp-uuid16 for uuid-ossp). +func installPgDeps() error { + log.Println("installing postgresql runtime dependencies") + if err := run("sudo", "apt-get", "install", "-y", "--no-install-recommends", "libossp-uuid16"); err != nil { + return fmt.Errorf("apt-get install libossp-uuid16: %w", err) + } + return nil +} + +// supportedPlatforms returns a comma-separated list of supported platforms. +func supportedPlatforms() string { + platforms := make([]string, 0, len(pgBinaries)) + for p := range pgBinaries { + platforms = append(platforms, p) + } + return strings.Join(platforms, ", ") +} + +// downloadFile downloads a URL to a local file path. +func downloadFile(filepath string, url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) + } + + out, err := os.Create(filepath) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + return err +} + +// sha256File computes the SHA256 hash of a file and returns the hex string. +func sha256File(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +func installMySQL() error { + log.Println("--- Installing MySQL 9 ---") + + if commandExists("mysqld") { + out, err := runOutput("mysqld", "--version") + if err == nil { + version := strings.TrimSpace(out) + log.Printf("mysql is already installed: %s", version) + if isMySQLVersionOK(version) { + log.Println("mysql version is 9+, skipping installation") + return nil + } + log.Println("mysql version is too old, upgrading to MySQL 9") + // Stop existing MySQL before upgrading + _ = exec.Command("sudo", "service", "mysql", "stop").Run() + _ = exec.Command("sudo", "pkill", "-f", "mysqld").Run() + time.Sleep(2 * time.Second) + // Remove old MySQL packages to avoid conflicts + log.Println("removing old mysql packages") + _ = run("sudo", "apt-get", "remove", "-y", "mysql-server", "mysql-client", "mysql-common", + "mysql-server-core-*", "mysql-client-core-*") + // Clear old data directory so MySQL 9 can initialize fresh + log.Println("clearing old mysql data directory") + _ = run("sudo", "rm", "-rf", "/var/lib/mysql") + _ = run("sudo", "mkdir", "-p", "/var/lib/mysql") + _ = run("sudo", "chown", "mysql:mysql", "/var/lib/mysql") + } + } + + bundleURL := "https://dev.mysql.com/get/Downloads/MySQL-9.1/mysql-server_9.1.0-1ubuntu24.04_amd64.deb-bundle.tar" + bundleTar := "/tmp/mysql-server-bundle.tar" + extractDir := "/tmp/mysql9" + + if _, err := os.Stat(bundleTar); err != nil { + log.Printf("downloading MySQL 9 bundle from %s", bundleURL) + if err := run("curl", "-L", "-o", bundleTar, bundleURL); err != nil { + return fmt.Errorf("downloading mysql bundle: %w", err) + } + } else { + log.Printf("mysql bundle already downloaded at %s, skipping download", bundleTar) + } + + log.Printf("extracting bundle to %s", extractDir) + if err := os.MkdirAll(extractDir, 0o755); err != nil { + return fmt.Errorf("creating extract dir: %w", err) + } + if err := run("tar", "-xf", bundleTar, "-C", extractDir); err != nil { + return fmt.Errorf("extracting mysql bundle: %w", err) + } + + // Install packages in dependency order using dpkg. + // Some packages may fail due to missing dependencies, which is expected. + // We fix them all at the end with apt-get install -f. + packages := []string{ + "mysql-common_*.deb", + "mysql-community-client-plugins_*.deb", + "mysql-community-client-core_*.deb", + "mysql-community-client_*.deb", + "mysql-client_*.deb", + "mysql-community-server-core_*.deb", + "mysql-community-server_*.deb", + "mysql-server_*.deb", + } + + for _, pkg := range packages { + log.Printf("installing %s (dependency errors will be fixed afterwards)", pkg) + cmd := fmt.Sprintf("sudo dpkg -i %s/%s", extractDir, pkg) + if err := run("bash", "-c", cmd); err != nil { + log.Printf("dpkg reported errors for %s (will fix with apt-get install -f)", pkg) + } + } + + log.Println("fixing missing dependencies with apt-get install -f") + if err := run("sudo", "apt-get", "install", "-f", "-y"); err != nil { + return fmt.Errorf("apt-get install -f: %w", err) + } + + log.Println("mysql 9 installed successfully") + return nil +} + +// ---- start ---- + +func runStart() error { + log.Println("=== Starting PostgreSQL and MySQL ===") + + if err := startPostgreSQL(); err != nil { + return fmt.Errorf("starting postgresql: %w", err) + } + + if err := startMySQL(); err != nil { + return fmt.Errorf("starting mysql: %w", err) + } + + log.Println("=== Both databases are running and configured ===") + log.Println("PostgreSQL: postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") + log.Println("MySQL: root:mysecretpassword@tcp(127.0.0.1:3306)/mysql") + return nil +} + +func startPostgreSQL() error { + log.Println("--- Starting PostgreSQL ---") + + dataDir := pgDataDir() + logFile := filepath.Join(pgBaseDir(), "postgresql.log") + + // Check if already running + if pgIsReady() { + log.Println("postgresql is already running and accepting connections") + return nil + } + + // Initialize data directory if needed + if _, err := os.Stat(filepath.Join(dataDir, "PG_VERSION")); os.IsNotExist(err) { + log.Println("initializing postgresql data directory") + if err := os.MkdirAll(dataDir, 0o700); err != nil { + return fmt.Errorf("creating data directory: %w", err) + } + if err := run(pgBin("initdb"), + "-D", dataDir, + "--username=postgres", + "--auth=trust", + ); err != nil { + return fmt.Errorf("initdb: %w", err) + } + + // Configure pg_hba.conf for md5 password authentication on TCP + hbaPath := filepath.Join(dataDir, "pg_hba.conf") + if err := configurePgHBA(hbaPath); err != nil { + return fmt.Errorf("configuring pg_hba.conf: %w", err) + } + + // Configure postgresql.conf to listen on localhost + confPath := filepath.Join(dataDir, "postgresql.conf") + if err := appendToFile(confPath, + "\n# sqlc-test-setup configuration\n"+ + "listen_addresses = '127.0.0.1'\n"+ + "port = 5432\n", + ); err != nil { + return fmt.Errorf("configuring postgresql.conf: %w", err) + } + } else { + log.Println("postgresql data directory already initialized") + } + + // Start PostgreSQL using pg_ctl + log.Println("starting postgresql") + if err := run(pgBin("pg_ctl"), + "-D", dataDir, + "-l", logFile, + "-o", fmt.Sprintf("-k %s", dataDir), + "start", + ); err != nil { + return fmt.Errorf("pg_ctl start: %w", err) + } + + // Wait for PostgreSQL to be ready + log.Println("waiting for postgresql to accept connections") + if err := waitForPostgreSQL(30 * time.Second); err != nil { + return fmt.Errorf("postgresql did not start in time: %w", err) + } + + // Set the postgres user password + log.Println("setting password for postgres user") + if err := run(pgBin("psql"), + "-h", "127.0.0.1", + "-U", "postgres", + "-c", "ALTER USER postgres PASSWORD 'postgres';", + ); err != nil { + return fmt.Errorf("setting postgres password: %w", err) + } + + // Update pg_hba.conf to require md5 auth now that password is set + hbaPath := filepath.Join(dataDir, "pg_hba.conf") + if err := configurePgHBAWithMD5(hbaPath); err != nil { + return fmt.Errorf("updating pg_hba.conf for md5: %w", err) + } + + // Reload configuration + log.Println("reloading postgresql configuration") + if err := run(pgBin("pg_ctl"), "-D", dataDir, "reload"); err != nil { + return fmt.Errorf("pg_ctl reload: %w", err) + } + + // Verify connection with password + log.Println("verifying postgresql connection") + cmd := exec.Command(pgBin("psql"), + "-h", "127.0.0.1", + "-U", "postgres", + "-c", "SELECT 1;", + ) + cmd.Env = append(os.Environ(), "PGPASSWORD=postgres") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("postgresql connection test failed: %w", err) + } + + log.Println("postgresql is running and configured") + return nil +} + +// configurePgHBA writes a pg_hba.conf that allows trust auth initially (for +// setting the password), then we switch to md5. +func configurePgHBA(hbaPath string) error { + content := `# pg_hba.conf - generated by sqlc-test-setup +# TYPE DATABASE USER ADDRESS METHOD +local all all trust +host all all 127.0.0.1/32 trust +host all all ::1/128 trust +` + return os.WriteFile(hbaPath, []byte(content), 0o600) +} + +// configurePgHBAWithMD5 rewrites pg_hba.conf to use md5 for TCP connections. +func configurePgHBAWithMD5(hbaPath string) error { + content := `# pg_hba.conf - generated by sqlc-test-setup +# TYPE DATABASE USER ADDRESS METHOD +local all all trust +host all all 127.0.0.1/32 md5 +host all all ::1/128 md5 +` + return os.WriteFile(hbaPath, []byte(content), 0o600) +} + +// appendToFile appends text to a file. +func appendToFile(path, text string) error { + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer f.Close() + _, err = f.WriteString(text) + return err +} + +// pgIsReady checks if PostgreSQL is running and accepting connections. +func pgIsReady() bool { + cmd := exec.Command(pgBin("pg_isready"), "-h", "127.0.0.1", "-p", "5432") + return cmd.Run() == nil +} + +// waitForPostgreSQL polls until PostgreSQL accepts connections or times out. +func waitForPostgreSQL(timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if pgIsReady() { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return fmt.Errorf("timed out after %s waiting for postgresql", timeout) +} + +func startMySQL() error { + log.Println("--- Starting MySQL ---") + + // Check if MySQL is already running and accessible with the expected password + if mysqlReady() { + log.Println("mysql is already running and accepting connections") + return verifyMySQL() + } + + // Stop any existing MySQL service that might be running (e.g. pre-installed + // on GitHub Actions runners) to avoid port conflicts. + log.Println("stopping any existing mysql service") + _ = exec.Command("sudo", "service", "mysql", "stop").Run() + _ = exec.Command("sudo", "mysqladmin", "shutdown").Run() + // Give MySQL time to fully shut down + time.Sleep(2 * time.Second) + + if err := ensureMySQLDirs(); err != nil { + return err + } + + // Check if data directory already exists and has been initialized + needsPasswordReset := false + if mysqlInitialized() { + log.Println("mysql data directory already initialized, skipping initialization") + // Existing data dir may have an unknown root password (e.g. pre-installed + // MySQL on GitHub Actions). We'll need to use --skip-grant-tables to reset it. + needsPasswordReset = true + } else { + log.Println("initializing mysql data directory") + if err := run("sudo", "mysqld", "--initialize-insecure", "--user=mysql"); err != nil { + return fmt.Errorf("mysqld --initialize-insecure: %w", err) + } + } + + if needsPasswordReset { + // Start with --skip-grant-tables to reset the unknown root password. + if err := startMySQLDaemon("--skip-grant-tables"); err != nil { + return err + } + + log.Println("resetting root password via --skip-grant-tables") + resetSQL := "FLUSH PRIVILEGES; ALTER USER 'root'@'localhost' IDENTIFIED WITH caching_sha2_password BY 'mysecretpassword';" + if err := run("mysql", "-u", "root", "-e", resetSQL); err != nil { + return fmt.Errorf("resetting mysql root password: %w", err) + } + + // Restart without --skip-grant-tables + log.Println("restarting mysql normally") + if err := run("sudo", "mysqladmin", "-u", "root", "-pmysecretpassword", "shutdown"); err != nil { + // If mysqladmin fails, try killing the process directly + _ = run("sudo", "pkill", "-f", "mysqld") + } + time.Sleep(2 * time.Second) + + if err := startMySQLDaemon(); err != nil { + return err + } + } else { + // Fresh initialization — start normally and set password + if err := startMySQLDaemon(); err != nil { + return err + } + + log.Println("setting mysql root password") + alterSQL := "ALTER USER 'root'@'localhost' IDENTIFIED WITH caching_sha2_password BY 'mysecretpassword'; FLUSH PRIVILEGES;" + if err := run("mysql", "-u", "root", "-e", alterSQL); err != nil { + return fmt.Errorf("setting mysql root password: %w", err) + } + } + + return verifyMySQL() +} + +// ensureMySQLDirs creates the directories MySQL needs at runtime. +func ensureMySQLDirs() error { + if err := run("sudo", "mkdir", "-p", "/var/run/mysqld"); err != nil { + return fmt.Errorf("creating /var/run/mysqld: %w", err) + } + if err := run("sudo", "chown", "mysql:mysql", "/var/run/mysqld"); err != nil { + return fmt.Errorf("chowning /var/run/mysqld: %w", err) + } + return nil +} + +// startMySQLDaemon starts mysqld_safe in the background and waits for it to +// accept connections. Extra args (e.g. "--skip-grant-tables") are appended. +func startMySQLDaemon(extraArgs ...string) error { + args := append([]string{"mysqld_safe", "--user=mysql"}, extraArgs...) + log.Printf("starting mysql via mysqld_safe %v", extraArgs) + cmd := exec.Command("sudo", args...) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + return fmt.Errorf("starting mysqld_safe: %w", err) + } + + log.Println("waiting for mysql to accept connections") + if err := waitForMySQL(30 * time.Second); err != nil { + return fmt.Errorf("mysql did not start in time: %w", err) + } + log.Println("mysql is accepting connections") + return nil +} + +// mysqlReady checks if MySQL is running and accepting connections with the expected password. +func mysqlReady() bool { + err := exec.Command("mysqladmin", "-h", "127.0.0.1", "-u", "root", "-pmysecretpassword", "ping").Run() + return err == nil +} + +// waitForMySQL polls until MySQL accepts connections or the timeout expires. +func waitForMySQL(timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + // Try connecting without password (fresh) or with password (already configured) + if exec.Command("mysqladmin", "-u", "root", "ping").Run() == nil { + return nil + } + if exec.Command("mysqladmin", "-h", "127.0.0.1", "-u", "root", "-pmysecretpassword", "ping").Run() == nil { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return fmt.Errorf("timed out after %s waiting for mysql", timeout) +} + +func verifyMySQL() error { + log.Println("verifying mysql connection") + if err := run("mysql", "-h", "127.0.0.1", "-u", "root", "-pmysecretpassword", "-e", "SELECT VERSION();"); err != nil { + return fmt.Errorf("mysql connection test failed: %w", err) + } + log.Println("mysql is running and configured") + return nil +} + +// mysqlInitialized checks if the MySQL data directory has been initialized. +// We use sudo ls because /var/lib/mysql is typically only readable by the +// mysql user, so filepath.Glob from a non-root process would silently fail. +func mysqlInitialized() bool { + out, err := exec.Command("sudo", "ls", "/var/lib/mysql").CombinedOutput() + if err != nil { + return false + } + // If the directory has any contents, consider it initialized. + // mysqld --initialize-insecure requires an empty directory. + return strings.TrimSpace(string(out)) != "" +} diff --git a/docs/reference/environment-variables.md b/docs/reference/environment-variables.md index 185807168c..837dd13980 100644 --- a/docs/reference/environment-variables.md +++ b/docs/reference/environment-variables.md @@ -1,5 +1,22 @@ # Environment variables +## SQLCEXPERIMENT + +The `SQLCEXPERIMENT` variable controls experimental features within sqlc. It is +a comma-separated list of experiment names. This is modeled after Go's +[GOEXPERIMENT](https://pkg.go.dev/internal/goexperiment) environment variable. + +Experiment names can be prefixed with `no` to explicitly disable them. + +``` +SQLCEXPERIMENT=foo,bar # enable foo and bar experiments +SQLCEXPERIMENT=nofoo # explicitly disable foo experiment +SQLCEXPERIMENT=foo,nobar # enable foo, disable bar +``` + +Currently, no experiments are defined. Experiments will be documented here as +they are introduced. + ## SQLCCACHE The `SQLCCACHE` environment variable dictates where `sqlc` will store cached diff --git a/docs/reference/language-support.rst b/docs/reference/language-support.rst index d6532ba543..20de2817d6 100644 --- a/docs/reference/language-support.rst +++ b/docs/reference/language-support.rst @@ -16,17 +16,20 @@ Community language support New languages can be added via :doc:`plugins <../guides/plugins>`. -======== ================================= =============== =============== =============== -Language Plugin MySQL PostgreSQL SQLite -======== ================================= =============== =============== =============== -C# `DaredevilOSS/sqlc-gen-csharp`_ Stable Stable Stable -F# `kaashyapan/sqlc-gen-fsharp`_ N/A Beta Beta -Java `tandemdude/sqlc-gen-java`_ Beta Beta N/A -PHP `lcarilla/sqlc-plugin-php-dbal`_ Beta N/A N/A -Ruby `DaredevilOSS/sqlc-gen-ruby`_ Beta Beta Beta -Zig `tinyzimmer/sqlc-gen-zig`_ N/A Beta Beta -[Any] `fdietze/sqlc-gen-from-template`_ Stable Stable Stable -======== ================================= =============== =============== =============== +======== ================================== =============== =============== =============== +Language Plugin MySQL PostgreSQL SQLite +======== ================================== =============== =============== =============== +C# `DaredevilOSS/sqlc-gen-csharp`_ Stable Stable Stable +F# `kaashyapan/sqlc-gen-fsharp`_ N/A Beta Beta +Java `tandemdude/sqlc-gen-java`_ Beta Beta N/A +PHP `lcarilla/sqlc-plugin-php-dbal`_ Beta N/A N/A +Ruby `DaredevilOSS/sqlc-gen-ruby`_ Beta Beta Beta +Zig `tinyzimmer/sqlc-gen-zig`_ N/A Beta Beta +Python `rayakame/sqlc-gen-better-python`_ N/A Beta Beta +[Any] `fdietze/sqlc-gen-from-template`_ Stable Stable Stable +======== ================================== =============== =============== =============== + +Plugins developed by our Community can also be found using our `github topic`_. Community projects ****************** @@ -49,3 +52,5 @@ Gleam `daniellionel01/parrot`_ Stable Stable S .. _tandemdude/sqlc-gen-java: https://github.com/tandemdude/sqlc-gen-java .. _tinyzimmer/sqlc-gen-zig: https://github.com/tinyzimmer/sqlc-gen-zig .. _daniellionel01/parrot: https://github.com/daniellionel01/parrot +.. _rayakame/sqlc-gen-better-python: https://github.com/rayakame/sqlc-gen-better-python +.. _github topic: https://github.com/topics/sqlc-plugin diff --git a/docs/requirements.txt b/docs/requirements.txt index d2720c5c21..b67f41549a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,7 +3,7 @@ Jinja2==3.1.6 MarkupSafe==3.0.3 Pygments==2.19.2 Sphinx==7.4.7 -certifi==2025.10.5 +certifi==2026.1.4 chardet==5.2.0 commonmark==0.9.1 docutils==0.20.1 @@ -11,7 +11,7 @@ idna==3.11 imagesize==1.4.1 myst-parser==4.0.1 packaging==25.0 -pyparsing==3.2.5 +pyparsing==3.3.1 pytz==2025.2 requests==2.32.5 snowballstemmer==3.0.1 @@ -24,4 +24,4 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==2.0.0 sphinxcontrib-serializinghtml==2.0.0 sphinxext-rediraffe==0.3.0 -urllib3==2.5.0 +urllib3==2.6.3 diff --git a/go.mod b/go.mod index e0f585b9fd..57624c33b1 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/sqlc-dev/sqlc -go 1.24.0 - -toolchain go1.24.1 +go 1.26.0 require ( github.com/antlr4-go/antlr/v4 v4.13.1 @@ -10,32 +8,31 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/fatih/structtag v1.2.0 github.com/go-sql-driver/mysql v1.9.3 - github.com/google/cel-go v0.26.1 + github.com/google/cel-go v0.27.0 github.com/google/go-cmp v0.7.0 github.com/jackc/pgx/v4 v4.18.3 - github.com/jackc/pgx/v5 v5.7.6 + github.com/jackc/pgx/v5 v5.8.0 github.com/jinzhu/inflection v1.0.0 - github.com/lib/pq v1.10.9 - github.com/pganalyze/pg_query_go/v6 v6.1.0 + github.com/lib/pq v1.12.0 + github.com/ncruces/go-sqlite3 v0.32.0 + github.com/pganalyze/pg_query_go/v6 v6.2.2 github.com/pingcap/tidb/pkg/parser v0.0.0-20250324122243-d51e00e5bbf0 github.com/riza-io/grpc-go v0.2.0 - github.com/spf13/cobra v1.10.1 + github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 - github.com/tetratelabs/wazero v1.9.0 + github.com/sqlc-dev/doubleclick v1.0.0 + github.com/tetratelabs/wazero v1.11.0 github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 - golang.org/x/sync v0.17.0 - google.golang.org/grpc v1.76.0 - google.golang.org/protobuf v1.36.10 + golang.org/x/sync v0.20.0 + google.golang.org/grpc v1.79.3 + google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - modernc.org/sqlite v1.39.1 ) require ( - cel.dev/expr v0.24.0 // indirect - filippo.io/edwards25519 v1.1.0 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect - github.com/google/uuid v1.6.0 // indirect + cel.dev/expr v0.25.1 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect @@ -45,29 +42,25 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/ncruces/julianday v1.0.0 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/log v1.1.0 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect - github.com/stoewer/go-strcase v1.2.0 // indirect github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/crypto v0.40.0 // indirect + golang.org/x/crypto v0.48.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/net v0.42.0 // indirect - golang.org/x/sys v0.36.0 // indirect - golang.org/x/text v0.27.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect - modernc.org/libc v1.66.10 // indirect - modernc.org/mathutil v1.7.1 // indirect - modernc.org/memory v1.11.0 // indirect ) + +replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 diff --git a/go.sum b/go.sum index 2d91a24ae4..ca0a5c97f4 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,14 @@ -cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= -cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -18,8 +20,6 @@ github.com/cubicdaiya/gonp v1.0.4/go.mod h1:iWGuP/7+JVTn02OWhRemVbMmG1DOUnmrGTYY github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -28,21 +28,15 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= -github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo= +github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -92,8 +86,8 @@ github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQ github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.18.3 h1:dE2/TrEsGX3RBprb3qryqSV9Y60iZN1C6i8IrmW9/BA= github.com/jackc/pgx/v4 v4.18.3/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= @@ -116,19 +110,19 @@ github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo= +github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= -github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pganalyze/pg_query_go/v6 v6.1.0 h1:jG5ZLhcVgL1FAw4C/0VNQaVmX1SUJx71wBGdtTtBvls= -github.com/pganalyze/pg_query_go/v6 v6.1.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50= +github.com/ncruces/go-sqlite3 v0.32.0 h1:hNBUXp88LrfQCsuyXLqWTbTUG35sUuktDsqhhgHvU20= +github.com/ncruces/go-sqlite3 v0.32.0/go.mod h1:MIWTK60ONDl0oVY073zYvJP21C3Dly6P9bxVpgkLwdQ= +github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= +github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= +github.com/pganalyze/pg_query_go/v6 v6.2.2 h1:O0L6zMC226R82RF3X5n0Ki6HjytDsoAzuzp4ATVAHNo= +github.com/pganalyze/pg_query_go/v6 v6.2.2/go.mod h1:Cn6+j4870kJz3iYNsb0VsNG04vpSWgEvBwc590J4qD0= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb h1:3pSi4EDG6hg0orE1ndHkXvX6Qdq2cZn8gAPir8ymKZk= github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= @@ -143,8 +137,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/riza-io/grpc-go v0.2.0 h1:2HxQKFVE7VuYstcJ8zqpN84VnAoJ4dCL6YFhJewNcHQ= github.com/riza-io/grpc-go v0.2.0/go.mod h1:2bDvR9KkKC3KhtlSHfR3dAXjUMT86kg4UfWFyVGWqi8= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -160,13 +152,15 @@ github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXY github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= -github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= -github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= +github.com/sqlc-dev/doubleclick v1.0.0 h1:2/OApfQ2eLgcfa/Fqs8WSMA6atH0G8j9hHbQIgMfAXI= +github.com/sqlc-dev/doubleclick v1.0.0/go.mod h1:ODHRroSrk/rr5neRHlWMSRijqOak8YmNaO3VAZCNl5Y= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -175,10 +169,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= -github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA= +github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU= github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 h1:mJdDDPblDfPe7z7go8Dvv1AJQDI3eQ/5xith3q2mFlo= github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07/go.mod h1:Ak17IJ037caFp4jpCw/iQQ7/W74Sqpb1YuKJU6HTKfM= github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 h1:OvLBa8SqJnZ6P+mjlzc2K7PM22rRUPE1x32G9DTPrC4= @@ -190,18 +184,18 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1: github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= -go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= -go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= -go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= -go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= -go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= -go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= -go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= -go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= -go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -227,6 +221,8 @@ go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -236,25 +232,23 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= -golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -267,9 +261,8 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -277,8 +270,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= @@ -288,26 +281,21 @@ golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b h1:ULiyYQ0FdsJhwwZUwbaXpZF5yUE3h+RA+gxvBu37ucc= -google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b/go.mod h1:oDOGiMSXHL4sDTJvFvIB9nRQCGdLP1o/iVaqQK8zB+M= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b h1:zPKJod4w6F1+nRGDI9ubnXYhU9NSWoFAijkHkUXeTK8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= -google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -324,29 +312,3 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= -modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= -modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A= -modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q= -modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= -modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= -modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= -modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= -modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= -modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= -modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A= -modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I= -modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= -modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= -modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= -modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= -modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= -modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= -modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= -modernc.org/sqlite v1.39.1 h1:H+/wGFzuSCIEVCvXYVHX5RQglwhMOvtHSv+VtidL2r4= -modernc.org/sqlite v1.39.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE= -modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= -modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= -modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= -modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/analyzer/analyzer.go b/internal/analyzer/analyzer.go index 3d7e3a0287..674f283db9 100644 --- a/internal/analyzer/analyzer.go +++ b/internal/analyzer/analyzer.go @@ -110,7 +110,21 @@ func (c *CachedAnalyzer) Close(ctx context.Context) error { return c.a.Close(ctx) } +func (c *CachedAnalyzer) EnsureConn(ctx context.Context, migrations []string) error { + return c.a.EnsureConn(ctx, migrations) +} + +func (c *CachedAnalyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) { + return c.a.GetColumnNames(ctx, query) +} + type Analyzer interface { Analyze(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysis.Analysis, error) Close(context.Context) error + // EnsureConn initializes the database connection with the given migrations. + // This is required for database-only mode where we need to connect before analyzing queries. + EnsureConn(ctx context.Context, migrations []string) error + // GetColumnNames returns the column names for a query by preparing it against the database. + // This is used for star expansion in database-only mode. + GetColumnNames(ctx context.Context, query string) ([]string, error) } diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 93fd6bbeaa..80a167353e 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -30,6 +30,7 @@ func init() { initCmd.Flags().BoolP("v1", "", false, "generate v1 config yaml file") initCmd.Flags().BoolP("v2", "", true, "generate v2 config yaml file") initCmd.MarkFlagsMutuallyExclusive("v1", "v2") + parseCmd.Flags().StringP("dialect", "d", "", "SQL dialect to use (postgresql, mysql, or sqlite)") } // Do runs the command logic. @@ -44,6 +45,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.AddCommand(diffCmd) rootCmd.AddCommand(genCmd) rootCmd.AddCommand(initCmd) + rootCmd.AddCommand(parseCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(verifyCmd) rootCmd.AddCommand(pushCmd) @@ -136,10 +138,11 @@ var initCmd = &cobra.Command{ } type Env struct { - DryRun bool - Debug opts.Debug - Remote bool - NoRemote bool + DryRun bool + Debug opts.Debug + Experiment opts.Experiment + Remote bool + NoRemote bool } func ParseEnv(c *cobra.Command) Env { @@ -147,10 +150,11 @@ func ParseEnv(c *cobra.Command) Env { r := c.Flag("remote") nr := c.Flag("no-remote") return Env{ - DryRun: dr != nil && dr.Changed, - Debug: opts.DebugFromEnv(), - Remote: r != nil && r.Value.String() == "true", - NoRemote: nr != nil && nr.Value.String() == "true", + DryRun: dr != nil && dr.Changed, + Debug: opts.DebugFromEnv(), + Experiment: opts.ExperimentFromEnv(), + Remote: r != nil && r.Value.String() == "true", + NoRemote: nr != nil && nr.Value.String() == "true", } } diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 00e8871c7e..05b5445ebb 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -295,7 +295,7 @@ func remoteGenerate(ctx context.Context, configPath string, conf *config.Config, func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() - c, err := compiler.NewCompiler(sql, combo) + c, err := compiler.NewCompiler(sql, combo, parserOpts) defer func() { if c != nil { c.Close(ctx) diff --git a/internal/cmd/parse.go b/internal/cmd/parse.go new file mode 100644 index 0000000000..aca01511f1 --- /dev/null +++ b/internal/cmd/parse.go @@ -0,0 +1,101 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "io" + "os" + + "github.com/spf13/cobra" + + "github.com/sqlc-dev/sqlc/internal/engine/clickhouse" + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +var parseCmd = &cobra.Command{ + Use: "parse [file]", + Short: "Parse SQL and output the AST as JSON", + Long: `Parse SQL from a file or stdin and output the abstract syntax tree as JSON. + +Examples: + # Parse a SQL file with PostgreSQL dialect + sqlc parse --dialect postgresql schema.sql + + # Parse from stdin with MySQL dialect + echo "SELECT * FROM users" | sqlc parse --dialect mysql + + # Parse SQLite SQL + sqlc parse --dialect sqlite queries.sql + + # Parse ClickHouse SQL + sqlc parse --dialect clickhouse queries.sql`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + dialect, err := cmd.Flags().GetString("dialect") + if err != nil { + return err + } + if dialect == "" { + return fmt.Errorf("--dialect flag is required (postgresql, mysql, sqlite, or clickhouse)") + } + + // Determine input source + var input io.Reader + if len(args) == 1 { + file, err := os.Open(args[0]) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + input = file + } else { + // Check if stdin has data + stat, err := os.Stdin.Stat() + if err != nil { + return fmt.Errorf("failed to stat stdin: %w", err) + } + if (stat.Mode() & os.ModeCharDevice) != 0 { + return fmt.Errorf("no input provided. Specify a file path or pipe SQL via stdin") + } + input = cmd.InOrStdin() + } + + // Parse SQL based on dialect + var stmts []ast.Statement + switch dialect { + case "postgresql", "postgres", "pg": + parser := postgresql.NewParser() + stmts, err = parser.Parse(input) + case "mysql": + parser := dolphin.NewParser() + stmts, err = parser.Parse(input) + case "sqlite": + parser := sqlite.NewParser() + stmts, err = parser.Parse(input) + case "clickhouse": + parser := clickhouse.NewParser() + stmts, err = parser.Parse(input) + default: + return fmt.Errorf("unsupported dialect: %s (use postgresql, mysql, sqlite, or clickhouse)", dialect) + } + if err != nil { + return fmt.Errorf("parse error: %w", err) + } + + // Output AST as JSON + stdout := cmd.OutOrStdout() + encoder := json.NewEncoder(stdout) + encoder.SetIndent("", " ") + + for _, stmt := range stmts { + if err := encoder.Encode(stmt.Raw); err != nil { + return fmt.Errorf("failed to encode AST: %w", err) + } + } + + return nil + }, +} diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index fe3ece38f3..4dbd3c3b7b 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -391,6 +391,19 @@ type checker struct { Replacer *shfmt.Replacer } +// isInMemorySQLite checks if a SQLite URI refers to an in-memory database +func isInMemorySQLite(uri string) bool { + if uri == ":memory:" || uri == "" { + return true + } + // Check for file URI with mode=memory parameter + // e.g., "file:test?mode=memory&cache=shared" + if strings.Contains(uri, "mode=memory") { + return true + } + return false +} + func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, func() error, error) { cleanup := func() error { return nil @@ -517,7 +530,7 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { prep = &dbPreparer{db} expl = &mysqlExplainer{db} case config.EngineSQLite: - db, err := sql.Open("sqlite", dburl) + db, err := sql.Open("sqlite3", dburl) if err != nil { return fmt.Errorf("database: connection error: %s", err) } @@ -525,6 +538,23 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { return fmt.Errorf("database: connection error: %s", err) } defer db.Close() + // For in-memory SQLite databases, apply migrations + if isInMemorySQLite(dburl) { + files, err := sqlpath.Glob(s.Schema) + if err != nil { + return fmt.Errorf("schema: %w", err) + } + for _, schema := range files { + contents, err := os.ReadFile(schema) + if err != nil { + return fmt.Errorf("read schema file: %w", err) + } + ddl := migrations.RemoveRollbackStatements(string(contents)) + if _, err := db.ExecContext(ctx, ddl); err != nil { + return fmt.Errorf("apply schema %s: %w", schema, err) + } + } + } prep = &dbPreparer{db} // SQLite really doesn't want us to depend on the output of EXPLAIN // QUERY PLAN: https://www.sqlite.org/eqp.html diff --git a/internal/cmd/vet_modernc.go b/internal/cmd/vet_modernc.go deleted file mode 100644 index 74313007af..0000000000 --- a/internal/cmd/vet_modernc.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !wasm - -package cmd - -import ( - _ "modernc.org/sqlite" -) diff --git a/internal/cmd/vet_sqlite.go b/internal/cmd/vet_sqlite.go new file mode 100644 index 0000000000..e1f8c7f9a8 --- /dev/null +++ b/internal/cmd/vet_sqlite.go @@ -0,0 +1,6 @@ +package cmd + +import ( + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index b8e8aa43c7..252e291f58 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -64,7 +64,11 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C } return "sql.NullInt32" - case "bigint": + case "bigint", "bigint unsigned", "bigint signed": + // "bigint unsigned" and "bigint signed" are MySQL CAST types + // Note: We use int64 for CAST AS UNSIGNED to match original behavior, + // even though uint64 would be more semantically correct. + // The Unsigned flag on columns (from table schema) still uses uint64. if notNull { if unsigned { return "uint64" diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 515d0a654f..0820488f9d 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -141,9 +141,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string } fields := make([]Field, len(s.Fields)) - for i, f := range s.Fields { - fields[i] = f - } + copy(fields, s.Fields) return &goEmbed{ modelType: s.Name, diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 38d66fce19..0d7d507575 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -79,9 +79,13 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis { } if len(prev.Columns) == len(cols) { for i := range prev.Columns { - prev.Columns[i].DataType = cols[i].DataType - prev.Columns[i].IsArray = cols[i].IsArray - prev.Columns[i].ArrayDims = cols[i].ArrayDims + // Only override column types if the analyzer provides a specific type + // (not "any"), since the catalog-based inference may have better info + if cols[i].DataType != "any" { + prev.Columns[i].DataType = cols[i].DataType + prev.Columns[i].IsArray = cols[i].IsArray + prev.Columns[i].ArrayDims = cols[i].ArrayDims + } } } else { embedding := false @@ -96,9 +100,13 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis { } if len(prev.Parameters) == len(params) { for i := range prev.Parameters { - prev.Parameters[i].Column.DataType = params[i].Column.DataType - prev.Parameters[i].Column.IsArray = params[i].Column.IsArray - prev.Parameters[i].Column.ArrayDims = params[i].Column.ArrayDims + // Only override parameter types if the analyzer provides a specific type + // (not "any"), since the catalog-based inference may have better info + if params[i].Column.DataType != "any" { + prev.Parameters[i].Column.DataType = params[i].Column.DataType + prev.Parameters[i].Column.IsArray = params[i].Column.IsArray + prev.Parameters[i].Column.ArrayDims = params[i].Column.ArrayDims + } } } else { prev.Parameters = params diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 84fbb20a3c..1a95b586f4 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -1,6 +1,7 @@ package compiler import ( + "context" "errors" "fmt" "io" @@ -39,11 +40,20 @@ func (c *Compiler) parseCatalog(schemas []string) error { } contents := migrations.RemoveRollbackStatements(string(blob)) c.schema = append(c.schema, contents) + + // In database-only mode, we parse the schema to validate syntax + // but don't update the catalog - the database will be the source of truth stmts, err := c.parser.Parse(strings.NewReader(contents)) if err != nil { merr.Add(filename, contents, 0, err) continue } + + // Skip catalog updates in database-only mode + if c.databaseOnlyMode { + continue + } + for i := range stmts { if err := c.catalog.Update(stmts[i], c); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err) @@ -58,6 +68,15 @@ func (c *Compiler) parseCatalog(schemas []string) error { } func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { + ctx := context.Background() + + // In database-only mode, initialize the database connection before parsing queries + if c.databaseOnlyMode && c.analyzer != nil { + if err := c.analyzer.EnsureConn(ctx, c.schema); err != nil { + return nil, fmt.Errorf("failed to initialize database connection: %w", err) + } + } + var q []*Query merr := multierr.New() set := map[string]struct{}{} @@ -113,6 +132,7 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { if len(q) == 0 { return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(c.conf.Queries, ",")) } + return &Result{ Catalog: c.catalog, Queries: q, diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index f742bfd999..64fdf3d5c7 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -11,8 +11,10 @@ import ( "github.com/sqlc-dev/sqlc/internal/engine/postgresql" pganalyze "github.com/sqlc-dev/sqlc/internal/engine/postgresql/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/sqlite" + sqliteanalyze "github.com/sqlc-dev/sqlc/internal/engine/sqlite/analyzer" "github.com/sqlc-dev/sqlc/internal/opts" "github.com/sqlc-dev/sqlc/internal/sql/catalog" + "github.com/sqlc-dev/sqlc/internal/x/expander" ) type Compiler struct { @@ -26,9 +28,15 @@ type Compiler struct { selector selector schema []string + + // databaseOnlyMode indicates that the compiler should use database-only analysis + // and skip building the internal catalog from schema files (analyzer.database: only) + databaseOnlyMode bool + // expander is used to expand SELECT * and RETURNING * in database-only mode + expander *expander.Expander } -func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) { +func NewCompiler(conf config.SQL, combo config.CombinedSettings, parserOpts opts.Parser) (*Compiler, error) { c := &Compiler{conf: conf, combo: combo} if conf.Database != nil && conf.Database.Managed { @@ -36,21 +44,66 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err c.client = client } + // Check for database-only mode (analyzer.database: only) + // This feature requires the analyzerv2 experiment to be enabled + databaseOnlyMode := conf.Analyzer.Database.IsOnly() && parserOpts.Experiment.AnalyzerV2 + switch conf.Engine { case config.EngineSQLite: - c.parser = sqlite.NewParser() + parser := sqlite.NewParser() + c.parser = parser c.catalog = sqlite.NewCatalog() c.selector = newSQLiteSelector() + + if databaseOnlyMode { + // Database-only mode requires a database connection + if conf.Database == nil { + return nil, fmt.Errorf("analyzer.database: only requires database configuration") + } + if conf.Database.URI == "" && !conf.Database.Managed { + return nil, fmt.Errorf("analyzer.database: only requires database.uri or database.managed") + } + c.databaseOnlyMode = true + // Create the SQLite analyzer (implements Analyzer interface) + sqliteAnalyzer := sqliteanalyze.New(*conf.Database) + c.analyzer = analyzer.Cached(sqliteAnalyzer, combo.Global, *conf.Database) + // Create the expander using the analyzer as the column getter + c.expander = expander.New(c.analyzer, parser, parser) + } else if conf.Database != nil { + if conf.Analyzer.Database.IsEnabled() { + c.analyzer = analyzer.Cached( + sqliteanalyze.New(*conf.Database), + combo.Global, + *conf.Database, + ) + } + } case config.EngineMySQL: c.parser = dolphin.NewParser() c.catalog = dolphin.NewCatalog() c.selector = newDefaultSelector() case config.EnginePostgreSQL: - c.parser = postgresql.NewParser() + parser := postgresql.NewParser() + c.parser = parser c.catalog = postgresql.NewCatalog() c.selector = newDefaultSelector() - if conf.Database != nil { - if conf.Analyzer.Database == nil || *conf.Analyzer.Database { + + if databaseOnlyMode { + // Database-only mode requires a database connection + if conf.Database == nil { + return nil, fmt.Errorf("analyzer.database: only requires database configuration") + } + if conf.Database.URI == "" && !conf.Database.Managed { + return nil, fmt.Errorf("analyzer.database: only requires database.uri or database.managed") + } + c.databaseOnlyMode = true + // Create the PostgreSQL analyzer (implements Analyzer interface) + pgAnalyzer := pganalyze.New(c.client, *conf.Database) + c.analyzer = analyzer.Cached(pgAnalyzer, combo.Global, *conf.Database) + // Create the expander using the analyzer as the column getter + c.expander = expander.New(c.analyzer, parser, parser) + } else if conf.Database != nil { + if conf.Analyzer.Database.IsEnabled() { c.analyzer = analyzer.Cached( pganalyze.New(c.client, *conf.Database), combo.Global, diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..dbd486359a 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -482,7 +482,14 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro list := &ast.List{} switch n := node.(type) { case *ast.DeleteStmt: - list = n.Relations + if n.Relations != nil { + list = n.Relations + } else if n.FromClause != nil { + // Multi-table DELETE: walk FromClause to find tables + var tv tableVisitor + astutils.Walk(&tv, n.FromClause) + list = &tv.list + } case *ast.InsertStmt: list = &ast.List{ Items: []ast.Node{n.Relation}, diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 681d291122..751cb3271a 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -71,7 +71,56 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } var anlys *analysis - if c.analyzer != nil { + if c.databaseOnlyMode && c.expander != nil { + // In database-only mode, use the expander for star expansion + // and rely entirely on the database analyzer for type resolution + expandedQuery, err := c.expander.Expand(ctx, rawSQL) + if err != nil { + return nil, fmt.Errorf("star expansion failed: %w", err) + } + + // Parse named parameters from the expanded query + expandedStmts, err := c.parser.Parse(strings.NewReader(expandedQuery)) + if err != nil { + return nil, fmt.Errorf("parsing expanded query failed: %w", err) + } + if len(expandedStmts) == 0 { + return nil, errors.New("no statements in expanded query") + } + expandedRaw := expandedStmts[0].Raw + + // Use the analyzer to get type information from the database + result, err := c.analyzer.Analyze(ctx, expandedRaw, expandedQuery, c.schema, nil) + if err != nil { + return nil, err + } + + // Convert the analyzer result to the internal analysis format + var cols []*Column + for _, col := range result.Columns { + cols = append(cols, convertColumn(col)) + } + var params []Parameter + for _, p := range result.Params { + params = append(params, Parameter{ + Number: int(p.Number), + Column: convertColumn(p.Column), + }) + } + + // Determine the insert table if applicable + var table *ast.TableName + if insert, ok := expandedRaw.Stmt.(*ast.InsertStmt); ok { + table, _ = ParseTableName(insert.Relation) + } + + anlys = &analysis{ + Table: table, + Columns: cols, + Parameters: params, + Query: expandedQuery, + } + } else if c.analyzer != nil { inference, _ := c.inferQuery(raw, rawSQL) if inference == nil { inference = &analysis{} diff --git a/internal/config/config.go b/internal/config/config.go index 0ff805fccd..d3e610ef05 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -122,8 +122,75 @@ type SQL struct { Analyzer Analyzer `json:"analyzer" yaml:"analyzer"` } +// AnalyzerDatabase represents the database analyzer setting. +// It can be a boolean (true/false) or the string "only" for database-only mode. +type AnalyzerDatabase struct { + value *bool // nil means not set, true/false for boolean values + isOnly bool // true when set to "only" +} + +// IsEnabled returns true if the database analyzer should be used. +// Returns true for both `true` and `"only"` settings. +func (a AnalyzerDatabase) IsEnabled() bool { + if a.isOnly { + return true + } + return a.value == nil || *a.value +} + +// IsOnly returns true if the analyzer is set to "only" mode. +func (a AnalyzerDatabase) IsOnly() bool { + return a.isOnly +} + +func (a *AnalyzerDatabase) UnmarshalJSON(data []byte) error { + // Try to unmarshal as boolean first + var b bool + if err := json.Unmarshal(data, &b); err == nil { + a.value = &b + a.isOnly = false + return nil + } + + // Try to unmarshal as string + var s string + if err := json.Unmarshal(data, &s); err == nil { + if s == "only" { + a.isOnly = true + a.value = nil + return nil + } + return errors.New("analyzer.database must be true, false, or \"only\"") + } + + return errors.New("analyzer.database must be true, false, or \"only\"") +} + +func (a *AnalyzerDatabase) UnmarshalYAML(unmarshal func(interface{}) error) error { + // Try to unmarshal as boolean first + var b bool + if err := unmarshal(&b); err == nil { + a.value = &b + a.isOnly = false + return nil + } + + // Try to unmarshal as string + var s string + if err := unmarshal(&s); err == nil { + if s == "only" { + a.isOnly = true + a.value = nil + return nil + } + return errors.New("analyzer.database must be true, false, or \"only\"") + } + + return errors.New("analyzer.database must be true, false, or \"only\"") +} + type Analyzer struct { - Database *bool `json:"database" yaml:"database"` + Database AnalyzerDatabase `json:"database" yaml:"database"` } // TODO: Figure out a better name for this diff --git a/internal/config/v_one.json b/internal/config/v_one.json index a0667a7c9c..e5ce9ec549 100644 --- a/internal/config/v_one.json +++ b/internal/config/v_one.json @@ -79,7 +79,10 @@ "type": "object", "properties": { "database": { - "type": "boolean" + "oneOf": [ + {"type": "boolean"}, + {"const": "only"} + ] } } }, diff --git a/internal/config/v_two.json b/internal/config/v_two.json index acf914997d..22591d7335 100644 --- a/internal/config/v_two.json +++ b/internal/config/v_two.json @@ -82,7 +82,10 @@ "type": "object", "properties": { "database": { - "type": "boolean" + "oneOf": [ + {"type": "boolean"}, + {"const": "only"} + ] } } }, diff --git a/internal/endtoend/CLAUDE.md b/internal/endtoend/CLAUDE.md new file mode 100644 index 0000000000..b9c995c9df --- /dev/null +++ b/internal/endtoend/CLAUDE.md @@ -0,0 +1,117 @@ +# End-to-End Tests - Native Database Setup + +This document describes how to set up MySQL and PostgreSQL for running end-to-end tests in environments without Docker, particularly when using an HTTP proxy. + +## Overview + +The end-to-end tests support three methods for connecting to databases: + +1. **Environment Variables**: Set `POSTGRESQL_SERVER_URI` and `MYSQL_SERVER_URI` directly +2. **Docker**: Automatically starts containers via the docker package +3. **Native Installation**: Starts existing database services on Linux + +## Installing Databases with HTTP Proxy + +In environments where DNS doesn't work directly but an HTTP proxy is available (e.g., some CI environments), you need to configure apt to use the proxy before installing packages. + +### Configure apt Proxy + +```bash +# Check if HTTP_PROXY is set +echo $HTTP_PROXY + +# Configure apt to use the proxy +sudo tee /etc/apt/apt.conf.d/99proxy << EOF +Acquire::http::Proxy "$HTTP_PROXY"; +Acquire::https::Proxy "$HTTPS_PROXY"; +EOF + +# Update package lists +sudo apt-get update -qq +``` + +### Install PostgreSQL + +```bash +# Install PostgreSQL +sudo DEBIAN_FRONTEND=noninteractive apt-get install -y postgresql postgresql-contrib + +# Start the service +sudo service postgresql start + +# Set password for postgres user +sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';" + +# Configure pg_hba.conf for password authentication +# Find the hba_file location: +sudo -u postgres psql -t -c "SHOW hba_file;" + +# Add md5 authentication for localhost (add to the beginning of pg_hba.conf): +# host all all 127.0.0.1/32 md5 + +# Reload PostgreSQL +sudo service postgresql reload +``` + +### Install MySQL + +```bash +# Pre-configure MySQL root password +echo "mysql-server mysql-server/root_password password mysecretpassword" | sudo debconf-set-selections +echo "mysql-server mysql-server/root_password_again password mysecretpassword" | sudo debconf-set-selections + +# Install MySQL +sudo DEBIAN_FRONTEND=noninteractive apt-get install -y mysql-server + +# Start the service +sudo service mysql start + +# Verify connection +mysql -uroot -pmysecretpassword -e "SELECT 1;" +``` + +## Expected Database Credentials + +The native database support expects the following credentials: + +### PostgreSQL +- **URI**: `postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable` +- **User**: `postgres` +- **Password**: `postgres` +- **Port**: `5432` + +### MySQL +- **URI**: `root:mysecretpassword@tcp(localhost:3306)/mysql?multiStatements=true&parseTime=true` +- **User**: `root` +- **Password**: `mysecretpassword` +- **Port**: `3306` + +## Running Tests + +```bash +# Run end-to-end tests +go test -v -run TestReplay -timeout 20m ./internal/endtoend/... + +# With verbose logging +go test -v -run TestReplay -timeout 20m ./internal/endtoend/... 2>&1 | tee test.log +``` + +## Troubleshooting + +### apt-get times out or fails +- Ensure HTTP proxy is configured in `/etc/apt/apt.conf.d/99proxy` +- Check that the proxy URL is correct: `echo $HTTP_PROXY` +- Try running `sudo apt-get update` first to verify connectivity + +### MySQL connection refused +- Check if MySQL is running: `sudo service mysql status` +- Verify the password: `mysql -uroot -pmysecretpassword -e "SELECT 1;"` +- Check if MySQL is listening on TCP: `netstat -tlnp | grep 3306` + +### PostgreSQL authentication failed +- Verify pg_hba.conf has md5 authentication for localhost +- Check password: `PGPASSWORD=postgres psql -h localhost -U postgres -c "SELECT 1;"` +- Reload PostgreSQL after config changes: `sudo service postgresql reload` + +### DNS resolution fails +This is expected in some environments. Configure apt proxy as shown above. diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 311eba9825..7634918446 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -18,6 +18,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/opts" "github.com/sqlc-dev/sqlc/internal/sqltest/docker" + "github.com/sqlc-dev/sqlc/internal/sqltest/native" ) func lineEndings() cmp.Option { @@ -113,23 +114,63 @@ func TestReplay(t *testing.T) { ctx := context.Background() var mysqlURI, postgresURI string - if err := docker.Installed(); err == nil { - { - host, err := docker.StartPostgreSQLServer(ctx) - if err != nil { - t.Fatalf("starting postgresql failed: %s", err) + + // First, check environment variables + if uri := os.Getenv("POSTGRESQL_SERVER_URI"); uri != "" { + postgresURI = uri + } + if uri := os.Getenv("MYSQL_SERVER_URI"); uri != "" { + mysqlURI = uri + } + + // Try Docker for any missing databases + if postgresURI == "" || mysqlURI == "" { + if err := docker.Installed(); err == nil { + if postgresURI == "" { + host, err := docker.StartPostgreSQLServer(ctx) + if err != nil { + t.Logf("docker postgresql startup failed: %s", err) + } else { + postgresURI = host + } + } + if mysqlURI == "" { + host, err := docker.StartMySQLServer(ctx) + if err != nil { + t.Logf("docker mysql startup failed: %s", err) + } else { + mysqlURI = host + } } - postgresURI = host } - { - host, err := docker.StartMySQLServer(ctx) - if err != nil { - t.Fatalf("starting mysql failed: %s", err) + } + + // Try native installation for any missing databases (Linux only) + if postgresURI == "" || mysqlURI == "" { + if err := native.Supported(); err == nil { + if postgresURI == "" { + host, err := native.StartPostgreSQLServer(ctx) + if err != nil { + t.Logf("native postgresql startup failed: %s", err) + } else { + postgresURI = host + } + } + if mysqlURI == "" { + host, err := native.StartMySQLServer(ctx) + if err != nil { + t.Logf("native mysql startup failed: %s", err) + } else { + mysqlURI = host + } } - mysqlURI = host } } + // Log which databases are available + t.Logf("PostgreSQL available: %v (URI: %s)", postgresURI != "", postgresURI) + t.Logf("MySQL available: %v (URI: %s)", mysqlURI != "", mysqlURI) + contexts := map[string]textContext{ "base": { Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) {} }, @@ -138,19 +179,20 @@ func TestReplay(t *testing.T) { "managed-db": { Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) { + // Add all servers - tests will fail if database isn't available c.Servers = []config.Server{ { Name: "postgres", Engine: config.EnginePostgreSQL, URI: postgresURI, }, - { Name: "mysql", Engine: config.EngineMySQL, URI: mysqlURI, }, } + for i := range c.SQL { switch c.SQL[i].Engine { case config.EnginePostgreSQL: @@ -161,6 +203,10 @@ func TestReplay(t *testing.T) { c.SQL[i].Database = &config.Database{ Managed: true, } + case config.EngineSQLite: + c.SQL[i].Database = &config.Database{ + Managed: true, + } default: // pass } @@ -168,8 +214,8 @@ func TestReplay(t *testing.T) { } }, Enabled: func() bool { - err := docker.Installed() - return err == nil + // Enabled if at least one database URI is available + return postgresURI != "" || mysqlURI != "" }, }, } @@ -217,8 +263,9 @@ func TestReplay(t *testing.T) { opts := cmd.Options{ Env: cmd.Env{ - Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), - NoRemote: true, + Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), + Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]), + NoRemote: true, }, Stderr: &stderr, MutateConfig: testctx.Mutate(t, path), diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index 04e753e5b7..eac3fa0390 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -3,71 +3,183 @@ package main import ( "bytes" "fmt" + "io" "os" "path/filepath" "strings" "testing" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/format" ) +// sqlParser is an interface for SQL parsers +type sqlParser interface { + Parse(r io.Reader) ([]ast.Statement, error) +} + +// sqlFormatter is an interface for formatters +type sqlFormatter interface { + format.Dialect +} + func TestFormat(t *testing.T) { t.Parallel() - parse := postgresql.NewParser() for _, tc := range FindTests(t, "testdata", "base") { tc := tc - - if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) { - continue - } - - q := filepath.Join(tc.Path, "query.sql") - if _, err := os.Stat(q); os.IsNotExist(err) { - continue - } - t.Run(tc.Name, func(t *testing.T) { - contents, err := os.ReadFile(q) + // Parse the config file to determine the engine + configPath := filepath.Join(tc.Path, tc.ConfigName) + configFile, err := os.Open(configPath) if err != nil { t.Fatal(err) } - for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) { - if len(query) <= 1 { - continue - } - query := query - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - expected, err := postgresql.Fingerprint(string(query)) + conf, err := config.ParseConfig(configFile) + configFile.Close() + if err != nil { + t.Fatal(err) + } + + // Skip if there are no SQL packages configured + if len(conf.SQL) == 0 { + return + } + + engine := conf.SQL[0].Engine + + // Select the appropriate parser and fingerprint function based on engine + var parse sqlParser + var formatter sqlFormatter + var fingerprint func(string) (string, error) + + switch engine { + case config.EnginePostgreSQL: + pgParser := postgresql.NewParser() + parse = pgParser + formatter = pgParser + fingerprint = postgresql.Fingerprint + case config.EngineMySQL: + mysqlParser := dolphin.NewParser() + parse = mysqlParser + formatter = mysqlParser + // For MySQL, we use a "round-trip" fingerprint: parse the SQL, format it, + // and return the formatted string. This tests that our formatting produces + // valid SQL that parses to the same AST structure. + fingerprint = func(sql string) (string, error) { + stmts, err := mysqlParser.Parse(strings.NewReader(sql)) if err != nil { - t.Fatal(err) + return "", err } - stmts, err := parse.Parse(bytes.NewReader(query)) - if err != nil { - t.Fatal(err) + if len(stmts) == 0 { + return "", nil } - if len(stmts) != 1 { - t.Fatal("expected one statement") + return ast.Format(stmts[0].Raw, mysqlParser), nil + } + case config.EngineSQLite: + sqliteParser := sqlite.NewParser() + parse = sqliteParser + formatter = sqliteParser + // For SQLite, we use the same "round-trip" fingerprint strategy as MySQL: + // parse the SQL, format it, and return the formatted string. + fingerprint = func(sql string) (string, error) { + stmts, err := sqliteParser.Parse(strings.NewReader(sql)) + if err != nil { + return "", err } - if false { - r, err := postgresql.Parse(string(query)) - debug.Dump(r, err) + if len(stmts) == 0 { + return "", nil } + return strings.ToLower(ast.Format(stmts[0].Raw, sqliteParser)), nil + } + default: + // Skip unsupported engines + return + } - out := ast.Format(stmts[0].Raw) - actual, err := postgresql.Fingerprint(out) + // Find query files from config + var queryFiles []string + for _, sql := range conf.SQL { + for _, q := range sql.Queries { + queryPath := filepath.Join(tc.Path, q) + info, err := os.Stat(queryPath) if err != nil { - t.Error(err) + continue } - if expected != actual { - debug.Dump(stmts[0].Raw) - t.Errorf("- %s", expected) - t.Errorf("- %s", string(query)) - t.Errorf("+ %s", actual) - t.Errorf("+ %s", out) + if info.IsDir() { + // If it's a directory, glob for .sql files + matches, err := filepath.Glob(filepath.Join(queryPath, "*.sql")) + if err != nil { + continue + } + queryFiles = append(queryFiles, matches...) + } else { + queryFiles = append(queryFiles, queryPath) } - }) + } + } + + if len(queryFiles) == 0 { + return + } + + for _, queryFile := range queryFiles { + if _, err := os.Stat(queryFile); os.IsNotExist(err) { + continue + } + + contents, err := os.ReadFile(queryFile) + if err != nil { + t.Fatal(err) + } + + // Parse the entire file to get proper statement boundaries + stmts, err := parse.Parse(bytes.NewReader(contents)) + if err != nil { + // Skip files with parse errors (e.g., syntax_errors test cases) + return + } + + for i, stmt := range stmts { + stmt := stmt + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + // Extract the original query text using statement location and length + start := stmt.Raw.StmtLocation + length := stmt.Raw.StmtLen + if length == 0 { + // If StmtLen is 0, it means the statement goes to the end of the input + length = len(contents) - start + } + query := strings.TrimSpace(string(contents[start : start+length])) + + expected, err := fingerprint(query) + if err != nil { + t.Fatal(err) + } + + if false { + r, err := postgresql.Parse(query) + debug.Dump(r, err) + } + + out := ast.Format(stmt.Raw, formatter) + actual, err := fingerprint(out) + if err != nil { + t.Error(err) + } + if expected != actual { + debug.Dump(stmt.Raw) + t.Errorf("- %s", expected) + t.Errorf("- %s", query) + t.Errorf("+ %s", actual) + t.Errorf("+ %s", out) + } + }) + } } }) } diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json new file mode 100644 index 0000000000..aaf587c793 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/exec.json @@ -0,0 +1,6 @@ +{ + "contexts": ["managed-db"], + "env": { + "SQLCEXPERIMENT": "analyzerv2" + } +} diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..90b88c3389 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +type Product struct { + ID int32 + Name string + Price string +} diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..8d31d41cdf --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,65 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const getProductStats = `-- name: GetProductStats :one +WITH product_stats AS ( + SELECT COUNT(*) as total, AVG(price) as avg_price FROM products +) +SELECT total, avg_price FROM product_stats +` + +type GetProductStatsRow struct { + Total int64 + AvgPrice string +} + +func (q *Queries) GetProductStats(ctx context.Context) (GetProductStatsRow, error) { + row := q.db.QueryRowContext(ctx, getProductStats) + var i GetProductStatsRow + err := row.Scan(&i.Total, &i.AvgPrice) + return i, err +} + +const listExpensiveProducts = `-- name: ListExpensiveProducts :many +WITH expensive AS ( + SELECT id, name, price FROM products WHERE price > 100 +) +SELECT id, name, price FROM expensive +` + +type ListExpensiveProductsRow struct { + ID int32 + Name string + Price string +} + +func (q *Queries) ListExpensiveProducts(ctx context.Context) ([]ListExpensiveProductsRow, error) { + rows, err := q.db.QueryContext(ctx, listExpensiveProducts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListExpensiveProductsRow + for rows.Next() { + var i ListExpensiveProductsRow + if err := rows.Scan(&i.ID, &i.Name, &i.Price); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..4626fe0f04 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/query.sql @@ -0,0 +1,11 @@ +-- name: ListExpensiveProducts :many +WITH expensive AS ( + SELECT * FROM products WHERE price > 100 +) +SELECT * FROM expensive; + +-- name: GetProductStats :one +WITH product_stats AS ( + SELECT COUNT(*) as total, AVG(price) as avg_price FROM products +) +SELECT * FROM product_stats; diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..17aaa6e650 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE products ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + price NUMERIC(10,2) NOT NULL +); diff --git a/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml new file mode 100644 index 0000000000..629b01dea6 --- /dev/null +++ b/internal/endtoend/testdata/accurate_cte/postgresql/stdlib/sqlc.yaml @@ -0,0 +1,13 @@ +version: "2" +sql: + - engine: postgresql + schema: "schema.sql" + queries: "query.sql" + database: + managed: true + analyzer: + database: "only" + gen: + go: + package: "querytest" + out: "go" diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json new file mode 100644 index 0000000000..aaf587c793 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/exec.json @@ -0,0 +1,6 @@ +{ + "contexts": ["managed-db"], + "env": { + "SQLCEXPERIMENT": "analyzerv2" + } +} diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..2b42787339 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/models.go @@ -0,0 +1,59 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type Status string + +const ( + StatusPending Status = "pending" + StatusActive Status = "active" + StatusCompleted Status = "completed" +) + +func (e *Status) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = Status(s) + case string: + *e = Status(s) + default: + return fmt.Errorf("unsupported scan type for Status: %T", src) + } + return nil +} + +type NullStatus struct { + Status Status + Valid bool // Valid is true if Status is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullStatus) Scan(value interface{}) error { + if value == nil { + ns.Status, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.Status.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.Status), nil +} + +type Task struct { + ID int32 + Title string + Status Status +} diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..263a6b6736 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,80 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const createTask = `-- name: CreateTask :one +INSERT INTO tasks (title, status) VALUES ($1, $2) RETURNING id, title, status +` + +type CreateTaskParams struct { + Title string + Status Status +} + +func (q *Queries) CreateTask(ctx context.Context, arg CreateTaskParams) (Task, error) { + row := q.db.QueryRowContext(ctx, createTask, arg.Title, arg.Status) + var i Task + err := row.Scan(&i.ID, &i.Title, &i.Status) + return i, err +} + +const getTasksByStatus = `-- name: GetTasksByStatus :many +SELECT id, title, status FROM tasks WHERE status = $1 +` + +func (q *Queries) GetTasksByStatus(ctx context.Context, status Status) ([]Task, error) { + rows, err := q.db.QueryContext(ctx, getTasksByStatus, status) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Task + for rows.Next() { + var i Task + if err := rows.Scan(&i.ID, &i.Title, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listTasks = `-- name: ListTasks :many +SELECT id, title, status FROM tasks +` + +func (q *Queries) ListTasks(ctx context.Context) ([]Task, error) { + rows, err := q.db.QueryContext(ctx, listTasks) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Task + for rows.Next() { + var i Task + if err := rows.Scan(&i.ID, &i.Title, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..11dcd9bf48 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/query.sql @@ -0,0 +1,8 @@ +-- name: ListTasks :many +SELECT * FROM tasks; + +-- name: GetTasksByStatus :many +SELECT * FROM tasks WHERE status = $1; + +-- name: CreateTask :one +INSERT INTO tasks (title, status) VALUES ($1, $2) RETURNING *; diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..443ae9845f --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/schema.sql @@ -0,0 +1,7 @@ +CREATE TYPE status AS ENUM ('pending', 'active', 'completed'); + +CREATE TABLE tasks ( + id SERIAL PRIMARY KEY, + title TEXT NOT NULL, + status status NOT NULL DEFAULT 'pending' +); diff --git a/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml new file mode 100644 index 0000000000..629b01dea6 --- /dev/null +++ b/internal/endtoend/testdata/accurate_enum/postgresql/stdlib/sqlc.yaml @@ -0,0 +1,13 @@ +version: "2" +sql: + - engine: postgresql + schema: "schema.sql" + queries: "query.sql" + database: + managed: true + analyzer: + database: "only" + gen: + go: + package: "querytest" + out: "go" diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json new file mode 100644 index 0000000000..aaf587c793 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/exec.json @@ -0,0 +1,6 @@ +{ + "contexts": ["managed-db"], + "env": { + "SQLCEXPERIMENT": "analyzerv2" + } +} diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go new file mode 100644 index 0000000000..eaf05e5c00 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "database/sql" +) + +type Author struct { + ID int64 + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go new file mode 100644 index 0000000000..203224ead2 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/go/query.sql.go @@ -0,0 +1,65 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const createAuthor = `-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES (?, ?) RETURNING id, name, bio +` + +type CreateAuthorParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) { + row := q.db.QueryRowContext(ctx, createAuthor, arg.Name, arg.Bio) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const getAuthor = `-- name: GetAuthor :one +SELECT id, name, bio FROM authors WHERE id = ? +` + +func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql new file mode 100644 index 0000000000..8fe23a8600 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/query.sql @@ -0,0 +1,8 @@ +-- name: GetAuthor :one +SELECT * FROM authors WHERE id = ?; + +-- name: ListAuthors :many +SELECT * FROM authors; + +-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES (?, ?) RETURNING *; diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql new file mode 100644 index 0000000000..22fc0993c1 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT +); diff --git a/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml new file mode 100644 index 0000000000..d2da6c31b2 --- /dev/null +++ b/internal/endtoend/testdata/accurate_sqlite/sqlite/stdlib/sqlc.yaml @@ -0,0 +1,13 @@ +version: "2" +sql: + - engine: sqlite + schema: "schema.sql" + queries: "query.sql" + database: + managed: true + analyzer: + database: "only" + gen: + go: + package: "querytest" + out: "go" diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json new file mode 100644 index 0000000000..aaf587c793 --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/exec.json @@ -0,0 +1,6 @@ +{ + "contexts": ["managed-db"], + "env": { + "SQLCEXPERIMENT": "analyzerv2" + } +} diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..ec1cb8d670 --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "database/sql" +) + +type Author struct { + ID int32 + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..9e2820cdbd --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,93 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const createAuthor = `-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING id, name, bio +` + +type CreateAuthorParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) { + row := q.db.QueryRowContext(ctx, createAuthor, arg.Name, arg.Bio) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const deleteAuthor = `-- name: DeleteAuthor :one +DELETE FROM authors WHERE id = $1 RETURNING id, name, bio +` + +func (q *Queries) DeleteAuthor(ctx context.Context, id int32) (Author, error) { + row := q.db.QueryRowContext(ctx, deleteAuthor, id) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const getAuthor = `-- name: GetAuthor :one +SELECT id, name, bio FROM authors WHERE id = $1 +` + +func (q *Queries) GetAuthor(ctx context.Context, id int32) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateAuthor = `-- name: UpdateAuthor :one +UPDATE authors SET name = $1, bio = $2 WHERE id = $3 RETURNING id, name, bio +` + +type UpdateAuthorParams struct { + Name string + Bio sql.NullString + ID int32 +} + +func (q *Queries) UpdateAuthor(ctx context.Context, arg UpdateAuthorParams) (Author, error) { + row := q.db.QueryRowContext(ctx, updateAuthor, arg.Name, arg.Bio, arg.ID) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..e091a5eaef --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/query.sql @@ -0,0 +1,14 @@ +-- name: ListAuthors :many +SELECT * FROM authors; + +-- name: GetAuthor :one +SELECT * FROM authors WHERE id = $1; + +-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING *; + +-- name: UpdateAuthor :one +UPDATE authors SET name = $1, bio = $2 WHERE id = $3 RETURNING *; + +-- name: DeleteAuthor :one +DELETE FROM authors WHERE id = $1 RETURNING *; diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..ca6ad1e2cf --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT +); diff --git a/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml new file mode 100644 index 0000000000..629b01dea6 --- /dev/null +++ b/internal/endtoend/testdata/accurate_star_expansion/postgresql/stdlib/sqlc.yaml @@ -0,0 +1,13 @@ +version: "2" +sql: + - engine: postgresql + schema: "schema.sql" + queries: "query.sql" + database: + managed: true + analyzer: + database: "only" + gen: + go: + package: "querytest" + out: "go" diff --git a/internal/endtoend/testdata/builtins/sqlite/go/scalarfunc.sql.go b/internal/endtoend/testdata/builtins/sqlite/go/scalarfunc.sql.go index a54cb8134c..0e7d271c32 100644 --- a/internal/endtoend/testdata/builtins/sqlite/go/scalarfunc.sql.go +++ b/internal/endtoend/testdata/builtins/sqlite/go/scalarfunc.sql.go @@ -384,17 +384,6 @@ func (q *Queries) GetSQLiteCompileOptionUsed(ctx context.Context) (int64, error) return sqlite_compileoption_used, err } -const getSQLiteOffset = `-- name: GetSQLiteOffset :one -SELECT sqlite_offset(1) -` - -func (q *Queries) GetSQLiteOffset(ctx context.Context) (sql.NullInt64, error) { - row := q.db.QueryRowContext(ctx, getSQLiteOffset) - var sqlite_offset sql.NullInt64 - err := row.Scan(&sqlite_offset) - return sqlite_offset, err -} - const getSQLiteSourceID = `-- name: GetSQLiteSourceID :one SELECT sqlite_source_id() ` diff --git a/internal/endtoend/testdata/builtins/sqlite/queries/scalarfunc.sql b/internal/endtoend/testdata/builtins/sqlite/queries/scalarfunc.sql index 728a6be1a5..571cdb958a 100644 --- a/internal/endtoend/testdata/builtins/sqlite/queries/scalarfunc.sql +++ b/internal/endtoend/testdata/builtins/sqlite/queries/scalarfunc.sql @@ -106,9 +106,6 @@ SELECT sqlite_compileoption_get(1); -- name: GetSQLiteCompileOptionUsed :one SELECT sqlite_compileoption_used(1); --- name: GetSQLiteOffset :one -SELECT sqlite_offset(1); - -- name: GetSQLiteSourceID :one SELECT sqlite_source_id(); diff --git a/internal/endtoend/testdata/ddl_create_trigger/sqlite/go/models.go b/internal/endtoend/testdata/ddl_create_trigger/sqlite/go/models.go index b4ca845334..e651fe2f3d 100644 --- a/internal/endtoend/testdata/ddl_create_trigger/sqlite/go/models.go +++ b/internal/endtoend/testdata/ddl_create_trigger/sqlite/go/models.go @@ -18,3 +18,14 @@ type CustomerAddress struct { CustID int64 CustAddr sql.NullString } + +type TriggerCustomer struct { + Name string + Address sql.NullString +} + +type TriggerOrder struct { + ID int64 + CustomerName sql.NullString + Address sql.NullString +} diff --git a/internal/endtoend/testdata/ddl_create_trigger/sqlite/schema.sql b/internal/endtoend/testdata/ddl_create_trigger/sqlite/schema.sql index 9143d0c069..59df748064 100644 --- a/internal/endtoend/testdata/ddl_create_trigger/sqlite/schema.sql +++ b/internal/endtoend/testdata/ddl_create_trigger/sqlite/schema.sql @@ -1,9 +1,20 @@ /* examples copied from https://www.sqlite.org/lang_createtrigger.html only expectation in sqlc is that they parse, codegen is unaffected */ -CREATE TRIGGER update_customer_address UPDATE OF address ON customers +CREATE TABLE trigger_customers ( + name TEXT PRIMARY KEY, + address TEXT +); + +CREATE TABLE trigger_orders ( + id INTEGER PRIMARY KEY, + customer_name TEXT, + address TEXT +); + +CREATE TRIGGER update_customer_address UPDATE OF address ON trigger_customers BEGIN - UPDATE orders SET address = new.address WHERE customer_name = old.name; + UPDATE trigger_orders SET address = new.address WHERE customer_name = old.name; END; CREATE TABLE customer( diff --git a/internal/endtoend/testdata/insert_select_invalid/sqlite/exec.json b/internal/endtoend/testdata/insert_select_invalid/sqlite/exec.json new file mode 100644 index 0000000000..e5dfda7818 --- /dev/null +++ b/internal/endtoend/testdata/insert_select_invalid/sqlite/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["managed-db"] +} diff --git a/internal/endtoend/testdata/insert_select_invalid/sqlite/query.sql b/internal/endtoend/testdata/insert_select_invalid/sqlite/query.sql index cfd90fe55d..3311b32009 100644 --- a/internal/endtoend/testdata/insert_select_invalid/sqlite/query.sql +++ b/internal/endtoend/testdata/insert_select_invalid/sqlite/query.sql @@ -1,5 +1,3 @@ -CREATE TABLE foo (bar text); - -- name: InsertFoo :exec INSERT INTO foo (bar) SELECT 1, ?, ?; diff --git a/internal/endtoend/testdata/insert_select_invalid/sqlite/schema.sql b/internal/endtoend/testdata/insert_select_invalid/sqlite/schema.sql new file mode 100644 index 0000000000..d849628fb1 --- /dev/null +++ b/internal/endtoend/testdata/insert_select_invalid/sqlite/schema.sql @@ -0,0 +1 @@ +CREATE TABLE foo (bar text); diff --git a/internal/endtoend/testdata/insert_select_invalid/sqlite/sqlc.json b/internal/endtoend/testdata/insert_select_invalid/sqlite/sqlc.json index 13e65f3ffd..f8e8051087 100644 --- a/internal/endtoend/testdata/insert_select_invalid/sqlite/sqlc.json +++ b/internal/endtoend/testdata/insert_select_invalid/sqlite/sqlc.json @@ -5,7 +5,7 @@ "engine": "sqlite", "path": "go", "name": "querytest", - "schema": "query.sql", + "schema": "schema.sql", "queries": "query.sql" } ] diff --git a/internal/endtoend/testdata/insert_select_invalid/sqlite/stderr.txt b/internal/endtoend/testdata/insert_select_invalid/sqlite/stderr.txt index 063b2a149a..20a7ac053a 100644 --- a/internal/endtoend/testdata/insert_select_invalid/sqlite/stderr.txt +++ b/internal/endtoend/testdata/insert_select_invalid/sqlite/stderr.txt @@ -1,2 +1,2 @@ # package querytest -query.sql:4:1: INSERT has more expressions than target columns +query.sql:1:1: sqlite3: SQL logic error: 3 values for 1 columns diff --git a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/exec.json b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/exec.json new file mode 100644 index 0000000000..e5dfda7818 --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["managed-db"] +} diff --git a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/query.sql b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/query.sql index 41ed0cf32c..b036fba240 100644 --- a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/query.sql +++ b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/query.sql @@ -1,10 +1,3 @@ -CREATE TABLE authors ( - id integer NOT NULL PRIMARY KEY AUTOINCREMENT, - name text NOT NULL, - bio text, - UNIQUE(name) -); - -- name: ListAuthors :many SELECT * FROM authors diff --git a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/schema.sql b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/schema.sql new file mode 100644 index 0000000000..e3ed6b0dba --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE authors ( + id integer NOT NULL PRIMARY KEY AUTOINCREMENT, + name text NOT NULL, + bio text, + UNIQUE(name) +); diff --git a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/sqlc.json b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/sqlc.json index fcb288cb35..d4963e751f 100644 --- a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/sqlc.json +++ b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/sqlc.json @@ -5,7 +5,7 @@ "path": "go", "engine": "sqlite", "name": "querytest", - "schema": "query.sql", + "schema": "schema.sql", "queries": "query.sql" } ] diff --git a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/stderr.txt b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/stderr.txt index 1fc9998d4c..d255c11c94 100644 --- a/internal/endtoend/testdata/invalid_group_by_reference/sqlite/stderr.txt +++ b/internal/endtoend/testdata/invalid_group_by_reference/sqlite/stderr.txt @@ -1,2 +1,2 @@ # package querytest -query.sql:11:10: column reference "invalid_reference" not found +query.sql:1:1: sqlite3: SQL logic error: no such column: invalid_reference diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/exec.json b/internal/endtoend/testdata/invalid_table_alias/sqlite/exec.json new file mode 100644 index 0000000000..e5dfda7818 --- /dev/null +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["managed-db"] +} diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/query.sql b/internal/endtoend/testdata/invalid_table_alias/sqlite/query.sql index 22482fb724..52f5aae051 100644 --- a/internal/endtoend/testdata/invalid_table_alias/sqlite/query.sql +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/query.sql @@ -1,10 +1,3 @@ --- https://github.com/sqlc-dev/sqlc/issues/437 -CREATE TABLE authors ( - id INT PRIMARY KEY, - name VARCHAR(255) NOT NULL, - bio text -); - -- name: GetAuthor :one SELECT * FROM authors a diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/schema.sql b/internal/endtoend/testdata/invalid_table_alias/sqlite/schema.sql new file mode 100644 index 0000000000..fe5a44f601 --- /dev/null +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/schema.sql @@ -0,0 +1,6 @@ +-- https://github.com/sqlc-dev/sqlc/issues/437 +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name VARCHAR(255) NOT NULL, + bio text +); diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/sqlc.json b/internal/endtoend/testdata/invalid_table_alias/sqlite/sqlc.json index fcb288cb35..d4963e751f 100644 --- a/internal/endtoend/testdata/invalid_table_alias/sqlite/sqlc.json +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/sqlc.json @@ -5,7 +5,7 @@ "path": "go", "engine": "sqlite", "name": "querytest", - "schema": "query.sql", + "schema": "schema.sql", "queries": "query.sql" } ] diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt b/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt index 810c893a70..97e43851e0 100644 --- a/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt @@ -1,2 +1,2 @@ # package querytest -query.sql:11:9: table alias "p" does not exist +query.sql:1:1: sqlite3: SQL logic error: no such column: p.id diff --git a/internal/endtoend/testdata/join_left_same_table/sqlite/go/query.sql.go b/internal/endtoend/testdata/join_left_same_table/sqlite/go/query.sql.go index 82a6d25562..c25e22e249 100644 --- a/internal/endtoend/testdata/join_left_same_table/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/join_left_same_table/sqlite/go/query.sql.go @@ -17,7 +17,7 @@ SELECT a.id, p.name as alias_name FROM authors AS a LEFT JOIN authors AS p - ON (authors.parent_id = p.id) + ON (a.parent_id = p.id) ` type AllAuthorsRow struct { diff --git a/internal/endtoend/testdata/join_left_same_table/sqlite/query.sql b/internal/endtoend/testdata/join_left_same_table/sqlite/query.sql index 11f6c6903b..79daa2dfd5 100644 --- a/internal/endtoend/testdata/join_left_same_table/sqlite/query.sql +++ b/internal/endtoend/testdata/join_left_same_table/sqlite/query.sql @@ -5,4 +5,4 @@ SELECT a.id, p.name as alias_name FROM authors AS a LEFT JOIN authors AS p - ON (authors.parent_id = p.id); + ON (a.parent_id = p.id); diff --git a/internal/endtoend/testdata/limit/sqlite/go/query.sql.go b/internal/endtoend/testdata/limit/sqlite/go/query.sql.go index 31a0ab2993..3612dc12ef 100644 --- a/internal/endtoend/testdata/limit/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/limit/sqlite/go/query.sql.go @@ -9,15 +9,6 @@ import ( "context" ) -const deleteLimit = `-- name: DeleteLimit :exec -DELETE FROM foo LIMIT ? -` - -func (q *Queries) DeleteLimit(ctx context.Context, limit int64) error { - _, err := q.db.ExecContext(ctx, deleteLimit, limit) - return err -} - const limitMe = `-- name: LimitMe :many SELECT bar FROM foo LIMIT ? ` @@ -44,12 +35,3 @@ func (q *Queries) LimitMe(ctx context.Context, limit int64) ([]bool, error) { } return items, nil } - -const updateLimit = `-- name: UpdateLimit :exec -UPDATE foo SET bar='baz' LIMIT ? -` - -func (q *Queries) UpdateLimit(ctx context.Context, limit int64) error { - _, err := q.db.ExecContext(ctx, updateLimit, limit) - return err -} diff --git a/internal/endtoend/testdata/limit/sqlite/query.sql b/internal/endtoend/testdata/limit/sqlite/query.sql index 025e2a812b..8514c9b476 100644 --- a/internal/endtoend/testdata/limit/sqlite/query.sql +++ b/internal/endtoend/testdata/limit/sqlite/query.sql @@ -1,8 +1,2 @@ -- name: LimitMe :many SELECT bar FROM foo LIMIT ?; - --- name: UpdateLimit :exec -UPDATE foo SET bar='baz' LIMIT ?; - --- name: DeleteLimit :exec -DELETE FROM foo LIMIT ?; diff --git a/internal/endtoend/testdata/quoted_names_complex/sqlite/schema.sql b/internal/endtoend/testdata/quoted_names_complex/sqlite/schema.sql index fc6a73756e..5486831199 100644 --- a/internal/endtoend/testdata/quoted_names_complex/sqlite/schema.sql +++ b/internal/endtoend/testdata/quoted_names_complex/sqlite/schema.sql @@ -12,7 +12,7 @@ ALTER TABLE products ADD COLUMN "Price Info" text; -- Test mixed case operations across different statement types INSERT INTO "user profiles" ("profile data") VALUES ('test data'); -UPDATE "ORDERS" SET data = 'updated' WHERE id = 1; +UPDATE "customer_orders" SET data = 'updated' WHERE id = 1; DELETE FROM products WHERE id = 1; -- Test DROP with various identifier formats diff --git a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go index e22e5b6f33..b30fa7d95a 100644 --- a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarExists(ctx context.Context, id int64) (int64, error) { +func (q *Queries) BarExists(ctx context.Context, id int64) (bool, error) { row := q.db.QueryRowContext(ctx, barExists, id) - var column_1 int64 - err := row.Scan(&column_1) - return column_1, err + var exists bool + err := row.Scan(&exists) + return exists, err } diff --git a/internal/endtoend/testdata/select_exists/sqlite/schema.sql b/internal/endtoend/testdata/select_exists/sqlite/schema.sql index 52799a37db..cf6a8b9507 100644 --- a/internal/endtoend/testdata/select_exists/sqlite/schema.sql +++ b/internal/endtoend/testdata/select_exists/sqlite/schema.sql @@ -1,2 +1 @@ -CREATE TABLE bar (id int not null primary key autoincrement); - +CREATE TABLE bar (id integer not null primary key autoincrement); diff --git a/internal/endtoend/testdata/select_not_exists/sqlite/exec.json b/internal/endtoend/testdata/select_not_exists/sqlite/exec.json new file mode 100644 index 0000000000..e5dfda7818 --- /dev/null +++ b/internal/endtoend/testdata/select_not_exists/sqlite/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["managed-db"] +} diff --git a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go index ee1b8e548b..91dea13570 100644 --- a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go @@ -10,7 +10,7 @@ import ( ) const barNotExists = `-- name: BarNotExists :one -SELECT +SELECT NOT EXISTS ( SELECT 1 @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarNotExists(ctx context.Context) (interface{}, error) { - row := q.db.QueryRowContext(ctx, barNotExists) - var column_1 interface{} - err := row.Scan(&column_1) - return column_1, err +func (q *Queries) BarNotExists(ctx context.Context, id int64) (bool, error) { + row := q.db.QueryRowContext(ctx, barNotExists, id) + var not_exists bool + err := row.Scan(¬_exists) + return not_exists, err } diff --git a/internal/endtoend/testdata/select_not_exists/sqlite/query.sql b/internal/endtoend/testdata/select_not_exists/sqlite/query.sql index d868c64a0b..f7e76ae92c 100644 --- a/internal/endtoend/testdata/select_not_exists/sqlite/query.sql +++ b/internal/endtoend/testdata/select_not_exists/sqlite/query.sql @@ -1,5 +1,5 @@ -- name: BarNotExists :one -SELECT +SELECT NOT EXISTS ( SELECT 1 diff --git a/internal/endtoend/testdata/sqlc_embed/sqlite/go/query.sql.go b/internal/endtoend/testdata/sqlc_embed/sqlite/go/query.sql.go index aafc0897a8..6b7b33ae28 100644 --- a/internal/endtoend/testdata/sqlc_embed/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_embed/sqlite/go/query.sql.go @@ -35,7 +35,7 @@ func (q *Queries) Duplicate(ctx context.Context) (DuplicateRow, error) { const join = `-- name: Join :one SELECT u.id, u.name, u.age, p.id, p.user_id FROM posts AS p -INNER JOIN users AS u ON p.user_id = u.users.id +INNER JOIN users AS u ON p.user_id = u.id ` type JoinRow struct { diff --git a/internal/endtoend/testdata/sqlc_embed/sqlite/query.sql b/internal/endtoend/testdata/sqlc_embed/sqlite/query.sql index 4b999b5629..1d0a02f109 100644 --- a/internal/endtoend/testdata/sqlc_embed/sqlite/query.sql +++ b/internal/endtoend/testdata/sqlc_embed/sqlite/query.sql @@ -15,7 +15,7 @@ SELECT sqlc.embed(users), sqlc.embed(users) FROM users; -- name: Join :one SELECT sqlc.embed(u), sqlc.embed(p) FROM posts AS p -INNER JOIN users AS u ON p.user_id = u.users.id; +INNER JOIN users AS u ON p.user_id = u.id; -- name: WithSchema :one SELECT sqlc.embed(bu) FROM baz.users AS bu; diff --git a/internal/endtoend/testdata/sqlc_embed/sqlite/schema.sql b/internal/endtoend/testdata/sqlc_embed/sqlite/schema.sql index a67026ba33..5a1d371b7e 100644 --- a/internal/endtoend/testdata/sqlc_embed/sqlite/schema.sql +++ b/internal/endtoend/testdata/sqlc_embed/sqlite/schema.sql @@ -1,4 +1,4 @@ -ATTACH 'baz.db' AS baz; +ATTACH ':memory:' AS baz; CREATE TABLE users ( id integer PRIMARY KEY, diff --git a/internal/engine/clickhouse/catalog.go b/internal/engine/clickhouse/catalog.go new file mode 100644 index 0000000000..fb0511f72e --- /dev/null +++ b/internal/engine/clickhouse/catalog.go @@ -0,0 +1,16 @@ +package clickhouse + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func NewCatalog() *catalog.Catalog { + def := "default" // ClickHouse default database + return &catalog.Catalog{ + DefaultSchema: def, + Schemas: []*catalog.Schema{ + defaultSchema(def), + }, + Extensions: map[string]struct{}{}, + } +} diff --git a/internal/engine/clickhouse/convert.go b/internal/engine/clickhouse/convert.go new file mode 100644 index 0000000000..ba2817e2bb --- /dev/null +++ b/internal/engine/clickhouse/convert.go @@ -0,0 +1,1020 @@ +package clickhouse + +import ( + "strconv" + "strings" + + chast "github.com/sqlc-dev/doubleclick/ast" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +type cc struct { + paramCount int +} + +func (c *cc) convert(node chast.Node) ast.Node { + switch n := node.(type) { + case *chast.SelectWithUnionQuery: + return c.convertSelectWithUnionQuery(n) + case *chast.SelectQuery: + return c.convertSelectQuery(n) + case *chast.InsertQuery: + return c.convertInsertQuery(n) + case *chast.CreateQuery: + return c.convertCreateQuery(n) + case *chast.UpdateQuery: + return c.convertUpdateQuery(n) + case *chast.DeleteQuery: + return c.convertDeleteQuery(n) + case *chast.DropQuery: + return c.convertDropQuery(n) + case *chast.AlterQuery: + return c.convertAlterQuery(n) + case *chast.TruncateQuery: + return c.convertTruncateQuery(n) + default: + return todo(n) + } +} + +func (c *cc) convertSelectWithUnionQuery(n *chast.SelectWithUnionQuery) ast.Node { + if len(n.Selects) == 0 { + return &ast.TODO{} + } + + // Single select without union + if len(n.Selects) == 1 { + return c.convert(n.Selects[0]) + } + + // Build a chain of SelectStmt with UNION operations + var result *ast.SelectStmt + for i, sel := range n.Selects { + stmt, ok := c.convert(sel).(*ast.SelectStmt) + if !ok { + continue + } + if i == 0 { + result = stmt + } else { + unionMode := ast.Union + if i-1 < len(n.UnionModes) { + switch strings.ToUpper(n.UnionModes[i-1]) { + case "ALL": + unionMode = ast.Union + case "DISTINCT": + unionMode = ast.Union + } + } + result = &ast.SelectStmt{ + Op: unionMode, + All: n.UnionAll || (i-1 < len(n.UnionModes) && strings.ToUpper(n.UnionModes[i-1]) == "ALL"), + Larg: result, + Rarg: stmt, + } + } + } + return result +} + +func (c *cc) convertSelectQuery(n *chast.SelectQuery) *ast.SelectStmt { + stmt := &ast.SelectStmt{} + + // Convert target list (SELECT columns) + if len(n.Columns) > 0 { + stmt.TargetList = &ast.List{} + for _, col := range n.Columns { + target := c.convertToResTarget(col) + if target != nil { + stmt.TargetList.Items = append(stmt.TargetList.Items, target) + } + } + } + + // Convert FROM clause + if n.From != nil { + stmt.FromClause = c.convertTablesInSelectQuery(n.From) + } + + // Convert WHERE clause + if n.Where != nil { + stmt.WhereClause = c.convertExpr(n.Where) + } + + // Convert GROUP BY clause + if len(n.GroupBy) > 0 { + stmt.GroupClause = &ast.List{} + for _, expr := range n.GroupBy { + stmt.GroupClause.Items = append(stmt.GroupClause.Items, c.convertExpr(expr)) + } + } + + // Convert HAVING clause + if n.Having != nil { + stmt.HavingClause = c.convertExpr(n.Having) + } + + // Convert ORDER BY clause + if len(n.OrderBy) > 0 { + stmt.SortClause = &ast.List{} + for _, orderBy := range n.OrderBy { + stmt.SortClause.Items = append(stmt.SortClause.Items, c.convertOrderByElement(orderBy)) + } + } + + // Convert LIMIT clause + if n.Limit != nil { + stmt.LimitCount = c.convertExpr(n.Limit) + } + + // Convert OFFSET clause + if n.Offset != nil { + stmt.LimitOffset = c.convertExpr(n.Offset) + } + + // Convert DISTINCT clause + if n.Distinct { + stmt.DistinctClause = &ast.List{} + } + + // Convert DISTINCT ON clause + if len(n.DistinctOn) > 0 { + stmt.DistinctClause = &ast.List{} + for _, expr := range n.DistinctOn { + stmt.DistinctClause.Items = append(stmt.DistinctClause.Items, c.convertExpr(expr)) + } + } + + // Convert WITH clause (CTEs) + if len(n.With) > 0 { + stmt.WithClause = &ast.WithClause{ + Ctes: &ast.List{}, + } + for _, cte := range n.With { + if aliased, ok := cte.(*chast.AliasedExpr); ok { + cteNode := &ast.CommonTableExpr{ + Ctename: &aliased.Alias, + } + // CTE expression may be a Subquery containing the actual SELECT + if subq, ok := aliased.Expr.(*chast.Subquery); ok { + cteNode.Ctequery = c.convert(subq.Query) + } else { + // Fallback: treat the expression itself as the query + cteNode.Ctequery = c.convertExpr(aliased.Expr) + } + stmt.WithClause.Ctes.Items = append(stmt.WithClause.Ctes.Items, cteNode) + } + } + } + + return stmt +} + +func (c *cc) convertToResTarget(expr chast.Expression) *ast.ResTarget { + res := &ast.ResTarget{ + Location: expr.Pos().Offset, + } + + switch e := expr.(type) { + case *chast.Asterisk: + if e.Table != "" { + // table.* + res.Val = &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + NewIdentifier(e.Table), + &ast.A_Star{}, + }, + }, + } + } else { + // Just * + res.Val = &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{&ast.A_Star{}}, + }, + } + } + case *chast.AliasedExpr: + res.Name = &e.Alias + res.Val = c.convertExpr(e.Expr) + case *chast.Identifier: + if e.Alias != "" { + res.Name = &e.Alias + } + res.Val = c.convertIdentifier(e) + case *chast.FunctionCall: + if e.Alias != "" { + res.Name = &e.Alias + } + res.Val = c.convertFunctionCall(e) + default: + res.Val = c.convertExpr(expr) + } + + return res +} + +func (c *cc) convertTablesInSelectQuery(n *chast.TablesInSelectQuery) *ast.List { + if n == nil || len(n.Tables) == 0 { + return nil + } + + result := &ast.List{} + + for i, elem := range n.Tables { + if elem.Table != nil { + tableExpr := c.convertTableExpression(elem.Table) + if i == 0 { + result.Items = append(result.Items, tableExpr) + } else if elem.Join != nil { + // This element has a join + joinExpr := c.convertTableJoin(elem.Join, result.Items[len(result.Items)-1], tableExpr) + result.Items[len(result.Items)-1] = joinExpr + } else { + result.Items = append(result.Items, tableExpr) + } + } else if elem.Join != nil && len(result.Items) > 0 { + // Join without table (should not happen normally) + continue + } + } + + return result +} + +func (c *cc) convertTableExpression(n *chast.TableExpression) ast.Node { + var result ast.Node + + switch t := n.Table.(type) { + case *chast.TableIdentifier: + rv := parseTableIdentifierToRangeVar(t) + if n.Alias != "" { + alias := n.Alias + rv.Alias = &ast.Alias{Aliasname: &alias} + } + result = rv + case *chast.Subquery: + subselect := &ast.RangeSubselect{ + Subquery: c.convert(t.Query), + } + alias := n.Alias + if alias == "" && t.Alias != "" { + alias = t.Alias + } + if alias != "" { + subselect.Alias = &ast.Alias{Aliasname: &alias} + } + result = subselect + case *chast.FunctionCall: + // Table function like file(), url(), etc. + rf := &ast.RangeFunction{ + Functions: &ast.List{ + Items: []ast.Node{c.convertFunctionCall(t)}, + }, + } + if n.Alias != "" { + alias := n.Alias + rf.Alias = &ast.Alias{Aliasname: &alias} + } + result = rf + default: + result = &ast.TODO{} + } + + return result +} + +func (c *cc) convertTableJoin(n *chast.TableJoin, left, right ast.Node) *ast.JoinExpr { + join := &ast.JoinExpr{ + Larg: left, + Rarg: right, + } + + // Convert join type + switch n.Type { + case chast.JoinInner: + join.Jointype = ast.JoinTypeInner + case chast.JoinLeft: + join.Jointype = ast.JoinTypeLeft + case chast.JoinRight: + join.Jointype = ast.JoinTypeRight + case chast.JoinFull: + join.Jointype = ast.JoinTypeFull + case chast.JoinCross: + join.Jointype = ast.JoinTypeInner + join.IsNatural = false + default: + join.Jointype = ast.JoinTypeInner + } + + // Convert ON clause + if n.On != nil { + join.Quals = c.convertExpr(n.On) + } + + // Convert USING clause + if len(n.Using) > 0 { + join.UsingClause = &ast.List{} + for _, u := range n.Using { + if id, ok := u.(*chast.Identifier); ok { + join.UsingClause.Items = append(join.UsingClause.Items, NewIdentifier(id.Name())) + } + } + } + + return join +} + +func (c *cc) convertExpr(expr chast.Expression) ast.Node { + if expr == nil { + return nil + } + + switch e := expr.(type) { + case *chast.Identifier: + return c.convertIdentifier(e) + case *chast.Literal: + return c.convertLiteral(e) + case *chast.BinaryExpr: + return c.convertBinaryExpr(e) + case *chast.FunctionCall: + return c.convertFunctionCall(e) + case *chast.AliasedExpr: + return c.convertExpr(e.Expr) + case *chast.Parameter: + return c.convertParameter(e) + case *chast.Asterisk: + return c.convertAsterisk(e) + case *chast.CaseExpr: + return c.convertCaseExpr(e) + case *chast.CastExpr: + return c.convertCastExpr(e) + case *chast.BetweenExpr: + return c.convertBetweenExpr(e) + case *chast.InExpr: + return c.convertInExpr(e) + case *chast.IsNullExpr: + return c.convertIsNullExpr(e) + case *chast.LikeExpr: + return c.convertLikeExpr(e) + case *chast.Subquery: + return c.convertSubquery(e) + case *chast.ArrayAccess: + return c.convertArrayAccess(e) + case *chast.UnaryExpr: + return c.convertUnaryExpr(e) + case *chast.Lambda: + // Lambda expressions are ClickHouse-specific, return as-is for now + return &ast.TODO{} + default: + return &ast.TODO{} + } +} + +func (c *cc) convertIdentifier(n *chast.Identifier) *ast.ColumnRef { + fields := &ast.List{} + for _, part := range n.Parts { + fields.Items = append(fields.Items, NewIdentifier(part)) + } + return &ast.ColumnRef{ + Fields: fields, + Location: n.Pos().Offset, + } +} + +func (c *cc) convertLiteral(n *chast.Literal) *ast.A_Const { + switch n.Type { + case chast.LiteralString: + str := n.Value.(string) + return &ast.A_Const{ + Val: &ast.String{Str: str}, + Location: n.Pos().Offset, + } + case chast.LiteralInteger: + var ival int64 + switch v := n.Value.(type) { + case int64: + ival = v + case int: + ival = int64(v) + case float64: + ival = int64(v) + case string: + ival, _ = strconv.ParseInt(v, 10, 64) + } + return &ast.A_Const{ + Val: &ast.Integer{Ival: ival}, + Location: n.Pos().Offset, + } + case chast.LiteralFloat: + var fval float64 + switch v := n.Value.(type) { + case float64: + fval = v + case string: + fval, _ = strconv.ParseFloat(v, 64) + } + str := strconv.FormatFloat(fval, 'f', -1, 64) + return &ast.A_Const{ + Val: &ast.Float{Str: str}, + Location: n.Pos().Offset, + } + case chast.LiteralBoolean: + // ClickHouse booleans are typically 0/1 + bval := n.Value.(bool) + if bval { + return &ast.A_Const{ + Val: &ast.Integer{Ival: 1}, + Location: n.Pos().Offset, + } + } + return &ast.A_Const{ + Val: &ast.Integer{Ival: 0}, + Location: n.Pos().Offset, + } + case chast.LiteralNull: + return &ast.A_Const{ + Val: &ast.Null{}, + Location: n.Pos().Offset, + } + default: + return &ast.A_Const{ + Location: n.Pos().Offset, + } + } +} + +func (c *cc) convertBinaryExpr(n *chast.BinaryExpr) ast.Node { + op := strings.ToUpper(n.Op) + + // Handle logical operators + if op == "AND" || op == "OR" { + var boolop ast.BoolExprType + if op == "AND" { + boolop = ast.BoolExprTypeAnd + } else { + boolop = ast.BoolExprTypeOr + } + return &ast.BoolExpr{ + Boolop: boolop, + Args: &ast.List{ + Items: []ast.Node{ + c.convertExpr(n.Left), + c.convertExpr(n.Right), + }, + }, + Location: n.Pos().Offset, + } + } + + // Handle other operators + return &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{&ast.String{Str: n.Op}}, + }, + Lexpr: c.convertExpr(n.Left), + Rexpr: c.convertExpr(n.Right), + Location: n.Pos().Offset, + } +} + +func (c *cc) convertFunctionCall(n *chast.FunctionCall) *ast.FuncCall { + fc := &ast.FuncCall{ + Funcname: &ast.List{ + Items: []ast.Node{&ast.String{Str: n.Name}}, + }, + Location: n.Pos().Offset, + AggDistinct: n.Distinct, + } + + // Convert arguments + if len(n.Arguments) > 0 { + fc.Args = &ast.List{} + for _, arg := range n.Arguments { + fc.Args.Items = append(fc.Args.Items, c.convertExpr(arg)) + } + } + + // Convert window function + if n.Over != nil { + fc.Over = &ast.WindowDef{} + if len(n.Over.PartitionBy) > 0 { + fc.Over.PartitionClause = &ast.List{} + for _, p := range n.Over.PartitionBy { + fc.Over.PartitionClause.Items = append(fc.Over.PartitionClause.Items, c.convertExpr(p)) + } + } + if len(n.Over.OrderBy) > 0 { + fc.Over.OrderClause = &ast.List{} + for _, o := range n.Over.OrderBy { + fc.Over.OrderClause.Items = append(fc.Over.OrderClause.Items, c.convertOrderByElement(o)) + } + } + } + + return fc +} + +func (c *cc) convertParameter(n *chast.Parameter) ast.Node { + c.paramCount++ + // Use the parameter name if available + name := n.Name + if name == "" { + name = strconv.Itoa(c.paramCount) + } + return &ast.ParamRef{ + Number: c.paramCount, + Location: n.Pos().Offset, + } +} + +func (c *cc) convertAsterisk(n *chast.Asterisk) *ast.ColumnRef { + fields := &ast.List{} + if n.Table != "" { + fields.Items = append(fields.Items, NewIdentifier(n.Table)) + } + fields.Items = append(fields.Items, &ast.A_Star{}) + return &ast.ColumnRef{ + Fields: fields, + Location: n.Pos().Offset, + } +} + +func (c *cc) convertCaseExpr(n *chast.CaseExpr) *ast.CaseExpr { + ce := &ast.CaseExpr{ + Location: n.Pos().Offset, + } + + // Convert test expression (CASE expr WHEN ...) + if n.Operand != nil { + ce.Arg = c.convertExpr(n.Operand) + } + + // Convert WHEN clauses + if len(n.Whens) > 0 { + ce.Args = &ast.List{} + for _, when := range n.Whens { + caseWhen := &ast.CaseWhen{ + Expr: c.convertExpr(when.Condition), + Result: c.convertExpr(when.Result), + } + ce.Args.Items = append(ce.Args.Items, caseWhen) + } + } + + // Convert ELSE clause + if n.Else != nil { + ce.Defresult = c.convertExpr(n.Else) + } + + return ce +} + +func (c *cc) convertCastExpr(n *chast.CastExpr) *ast.TypeCast { + tc := &ast.TypeCast{ + Arg: c.convertExpr(n.Expr), + Location: n.Pos().Offset, + } + + if n.Type != nil { + tc.TypeName = &ast.TypeName{ + Name: n.Type.Name, + } + } + + return tc +} + +func (c *cc) convertBetweenExpr(n *chast.BetweenExpr) *ast.BetweenExpr { + return &ast.BetweenExpr{ + Expr: c.convertExpr(n.Expr), + Left: c.convertExpr(n.Low), + Right: c.convertExpr(n.High), + Not: n.Not, + Location: n.Pos().Offset, + } +} + +func (c *cc) convertInExpr(n *chast.InExpr) *ast.In { + in := &ast.In{ + Expr: c.convertExpr(n.Expr), + Not: n.Not, + Location: n.Pos().Offset, + } + + // Convert the list + if len(n.List) > 0 { + in.List = make([]ast.Node, 0, len(n.List)) + for _, item := range n.List { + in.List = append(in.List, c.convertExpr(item)) + } + } + + // Handle subquery + if n.Query != nil { + in.Sel = c.convert(n.Query) + } + + return in +} + +func (c *cc) convertIsNullExpr(n *chast.IsNullExpr) *ast.NullTest { + nullTest := &ast.NullTest{ + Arg: c.convertExpr(n.Expr), + Location: n.Pos().Offset, + } + if n.Not { + nullTest.Nulltesttype = ast.NullTestTypeIsNotNull + } else { + nullTest.Nulltesttype = ast.NullTestTypeIsNull + } + return nullTest +} + +func (c *cc) convertLikeExpr(n *chast.LikeExpr) *ast.A_Expr { + kind := ast.A_Expr_Kind(0) + opName := "~~" + if n.CaseInsensitive { + opName = "~~*" + } + if n.Not { + opName = "!~~" + if n.CaseInsensitive { + opName = "!~~*" + } + } + + return &ast.A_Expr{ + Kind: kind, + Name: &ast.List{ + Items: []ast.Node{&ast.String{Str: opName}}, + }, + Lexpr: c.convertExpr(n.Expr), + Rexpr: c.convertExpr(n.Pattern), + Location: n.Pos().Offset, + } +} + +func (c *cc) convertSubquery(n *chast.Subquery) *ast.SubLink { + return &ast.SubLink{ + SubLinkType: ast.EXISTS_SUBLINK, + Subselect: c.convert(n.Query), + } +} + +func (c *cc) convertArrayAccess(n *chast.ArrayAccess) *ast.A_Indirection { + return &ast.A_Indirection{ + Arg: c.convertExpr(n.Array), + Indirection: &ast.List{ + Items: []ast.Node{ + &ast.A_Indices{ + Uidx: c.convertExpr(n.Index), + }, + }, + }, + } +} + +func (c *cc) convertUnaryExpr(n *chast.UnaryExpr) ast.Node { + op := strings.ToUpper(n.Op) + + if op == "NOT" { + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{c.convertExpr(n.Operand)}, + }, + Location: n.Pos().Offset, + } + } + + return &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{&ast.String{Str: n.Op}}, + }, + Rexpr: c.convertExpr(n.Operand), + Location: n.Pos().Offset, + } +} + +func (c *cc) convertOrderByElement(n *chast.OrderByElement) *ast.SortBy { + sortBy := &ast.SortBy{ + Node: c.convertExpr(n.Expression), + Location: n.Expression.Pos().Offset, + } + + if n.Descending { + sortBy.SortbyDir = ast.SortByDirDesc + } else { + sortBy.SortbyDir = ast.SortByDirAsc + } + + if n.NullsFirst != nil { + if *n.NullsFirst { + sortBy.SortbyNulls = ast.SortByNullsFirst + } else { + sortBy.SortbyNulls = ast.SortByNullsLast + } + } + + return sortBy +} + +func (c *cc) convertInsertQuery(n *chast.InsertQuery) *ast.InsertStmt { + stmt := &ast.InsertStmt{ + Relation: &ast.RangeVar{ + Relname: &n.Table, + }, + } + + if n.Database != "" { + stmt.Relation.Schemaname = &n.Database + } + + // Convert column list + if len(n.Columns) > 0 { + stmt.Cols = &ast.List{} + for _, col := range n.Columns { + name := col.Name() + stmt.Cols.Items = append(stmt.Cols.Items, &ast.ResTarget{ + Name: &name, + }) + } + } + + // Convert SELECT subquery if present + if n.Select != nil { + stmt.SelectStmt = c.convert(n.Select) + } + + // Convert VALUES clause + if len(n.Values) > 0 { + selectStmt := &ast.SelectStmt{ + ValuesLists: &ast.List{}, + } + for _, row := range n.Values { + rowList := &ast.List{} + for _, val := range row { + rowList.Items = append(rowList.Items, c.convertExpr(val)) + } + selectStmt.ValuesLists.Items = append(selectStmt.ValuesLists.Items, rowList) + } + stmt.SelectStmt = selectStmt + } + + return stmt +} + +func (c *cc) convertCreateQuery(n *chast.CreateQuery) ast.Node { + // Handle CREATE DATABASE + if n.CreateDatabase { + return &ast.CreateSchemaStmt{ + Name: &n.Database, + IfNotExists: n.IfNotExists, + } + } + + // Handle CREATE TABLE + if n.Table != "" { + stmt := &ast.CreateTableStmt{ + Name: &ast.TableName{ + Name: identifier(n.Table), + }, + IfNotExists: n.IfNotExists, + } + + if n.Database != "" { + stmt.Name.Schema = identifier(n.Database) + } + + // Convert columns + for _, col := range n.Columns { + colDef := c.convertColumnDeclaration(col) + stmt.Cols = append(stmt.Cols, colDef) + } + + // Convert AS SELECT + if n.AsSelect != nil { + // This is a CREATE TABLE ... AS SELECT + // The AsSelect field contains the SELECT statement + } + + return stmt + } + + // Handle CREATE VIEW + if n.View != "" { + return &ast.ViewStmt{ + View: &ast.RangeVar{ + Relname: &n.View, + }, + Query: c.convert(n.AsSelect), + Replace: n.OrReplace, + } + } + + return &ast.TODO{} +} + +func (c *cc) convertColumnDeclaration(n *chast.ColumnDeclaration) *ast.ColumnDef { + colDef := &ast.ColumnDef{ + Colname: identifier(n.Name), + IsNotNull: isNotNull(n), + } + + if n.Type != nil { + colDef.TypeName = &ast.TypeName{ + Name: n.Type.Name, + } + // Handle type parameters (e.g., Decimal(10, 2)) + if len(n.Type.Parameters) > 0 { + colDef.TypeName.Typmods = &ast.List{} + for _, param := range n.Type.Parameters { + colDef.TypeName.Typmods.Items = append(colDef.TypeName.Typmods.Items, c.convertExpr(param)) + } + } + } + + // Handle PRIMARY KEY constraint + if n.PrimaryKey { + colDef.PrimaryKey = true + } + + // Handle DEFAULT + if n.Default != nil { + // colDef.RawDefault = c.convertExpr(n.Default) + } + + // Handle comment + if n.Comment != "" { + colDef.Comment = n.Comment + } + + return colDef +} + +func (c *cc) convertUpdateQuery(n *chast.UpdateQuery) *ast.UpdateStmt { + rv := &ast.RangeVar{ + Relname: &n.Table, + } + if n.Database != "" { + rv.Schemaname = &n.Database + } + stmt := &ast.UpdateStmt{ + Relations: &ast.List{ + Items: []ast.Node{rv}, + }, + } + + // Convert assignments + if len(n.Assignments) > 0 { + stmt.TargetList = &ast.List{} + for _, assign := range n.Assignments { + name := identifier(assign.Column) + stmt.TargetList.Items = append(stmt.TargetList.Items, &ast.ResTarget{ + Name: &name, + Val: c.convertExpr(assign.Value), + }) + } + } + + // Convert WHERE clause + if n.Where != nil { + stmt.WhereClause = c.convertExpr(n.Where) + } + + return stmt +} + +func (c *cc) convertDeleteQuery(n *chast.DeleteQuery) *ast.DeleteStmt { + rv := &ast.RangeVar{ + Relname: &n.Table, + } + if n.Database != "" { + rv.Schemaname = &n.Database + } + stmt := &ast.DeleteStmt{ + Relations: &ast.List{ + Items: []ast.Node{rv}, + }, + } + + // Convert WHERE clause + if n.Where != nil { + stmt.WhereClause = c.convertExpr(n.Where) + } + + return stmt +} + +func (c *cc) convertDropQuery(n *chast.DropQuery) ast.Node { + // Handle DROP TABLE + if n.Table != "" { + tableName := &ast.TableName{ + Name: identifier(n.Table), + } + if n.Database != "" { + tableName.Schema = identifier(n.Database) + } + return &ast.DropTableStmt{ + IfExists: n.IfExists, + Tables: []*ast.TableName{tableName}, + } + } + + // Handle DROP TABLE with multiple tables + if len(n.Tables) > 0 { + tables := make([]*ast.TableName, 0, len(n.Tables)) + for _, t := range n.Tables { + tables = append(tables, parseTableName(t)) + } + return &ast.DropTableStmt{ + IfExists: n.IfExists, + Tables: tables, + } + } + + // Handle DROP DATABASE - return TODO for now + // Handle DROP VIEW - return TODO for now + return &ast.TODO{} +} + +func (c *cc) convertAlterQuery(n *chast.AlterQuery) ast.Node { + alt := &ast.AlterTableStmt{ + Table: &ast.TableName{ + Name: identifier(n.Table), + }, + Cmds: &ast.List{}, + } + + if n.Database != "" { + alt.Table.Schema = identifier(n.Database) + } + + for _, cmd := range n.Commands { + switch cmd.Type { + case chast.AlterAddColumn: + if cmd.Column != nil { + name := cmd.Column.Name + alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_AddColumn, + Def: c.convertColumnDeclaration(cmd.Column), + }) + } + case chast.AlterDropColumn: + name := cmd.ColumnName + alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropColumn, + MissingOk: cmd.IfExists, + }) + case chast.AlterModifyColumn: + if cmd.Column != nil { + name := cmd.Column.Name + // Drop and re-add to simulate modify + alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropColumn, + }) + alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_AddColumn, + Def: c.convertColumnDeclaration(cmd.Column), + }) + } + case chast.AlterRenameColumn: + oldName := cmd.ColumnName + newName := cmd.NewName + return &ast.RenameColumnStmt{ + Table: alt.Table, + Col: &ast.ColumnRef{Name: oldName}, + NewName: &newName, + } + } + } + + return alt +} + +func (c *cc) convertTruncateQuery(n *chast.TruncateQuery) *ast.TruncateStmt { + stmt := &ast.TruncateStmt{ + Relations: &ast.List{}, + } + + tableName := n.Table + schemaName := n.Database + + rv := &ast.RangeVar{ + Relname: &tableName, + } + if schemaName != "" { + rv.Schemaname = &schemaName + } + + stmt.Relations.Items = append(stmt.Relations.Items, rv) + + return stmt +} diff --git a/internal/engine/clickhouse/format.go b/internal/engine/clickhouse/format.go new file mode 100644 index 0000000000..c103c7803f --- /dev/null +++ b/internal/engine/clickhouse/format.go @@ -0,0 +1,35 @@ +package clickhouse + +// QuoteIdent returns a quoted identifier if it needs quoting. +// ClickHouse uses backticks or double quotes for quoting identifiers. +func (p *Parser) QuoteIdent(s string) string { + // For now, don't quote - can be extended to quote when necessary + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +func (p *Parser) TypeName(ns, name string) string { + if ns != "" { + return ns + "." + name + } + return name +} + +// Param returns the parameter placeholder for the given number. +// ClickHouse uses {name:Type} for named parameters, but for positional +// parameters we use ? which is supported by the clickhouse-go driver. +func (p *Parser) Param(n int) string { + return "?" +} + +// NamedParam returns the named parameter placeholder for the given name. +// ClickHouse uses {name:Type} syntax for named parameters. +func (p *Parser) NamedParam(name string) string { + return "{" + name + ":String}" +} + +// Cast returns a type cast expression. +// ClickHouse uses CAST(expr AS type) syntax, same as MySQL. +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} diff --git a/internal/engine/clickhouse/parse.go b/internal/engine/clickhouse/parse.go new file mode 100644 index 0000000000..282089f31d --- /dev/null +++ b/internal/engine/clickhouse/parse.go @@ -0,0 +1,64 @@ +package clickhouse + +import ( + "bytes" + "context" + "io" + + "github.com/sqlc-dev/doubleclick/parser" + + "github.com/sqlc-dev/sqlc/internal/source" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func NewParser() *Parser { + return &Parser{} +} + +type Parser struct{} + +func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { + blob, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + ctx := context.Background() + stmtNodes, err := parser.Parse(ctx, bytes.NewReader(blob)) + if err != nil { + return nil, err + } + + var stmts []ast.Statement + for _, stmt := range stmtNodes { + converter := &cc{} + out := converter.convert(stmt) + if _, ok := out.(*ast.TODO); ok { + continue + } + + // Get position information from the statement + pos := stmt.Pos() + end := stmt.End() + stmtLen := end.Offset - pos.Offset + + stmts = append(stmts, ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: out, + StmtLocation: pos.Offset, + StmtLen: stmtLen, + }, + }) + } + + return stmts, nil +} + +// https://clickhouse.com/docs/en/sql-reference/syntax#comments +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ + Dash: true, // -- comment + SlashStar: true, // /* comment */ + Hash: true, // # comment (ClickHouse supports this) + } +} diff --git a/internal/engine/clickhouse/reserved.go b/internal/engine/clickhouse/reserved.go new file mode 100644 index 0000000000..1a9ac45f3a --- /dev/null +++ b/internal/engine/clickhouse/reserved.go @@ -0,0 +1,150 @@ +package clickhouse + +import "strings" + +// https://clickhouse.com/docs/en/sql-reference/syntax#keywords +func (p *Parser) IsReservedKeyword(s string) bool { + switch strings.ToLower(s) { + case "add": + case "after": + case "alias": + case "all": + case "alter": + case "and": + case "anti": + case "any": + case "array": + case "as": + case "asc": + case "asof": + case "between": + case "both": + case "by": + case "case": + case "cast": + case "check": + case "cluster": + case "collate": + case "column": + case "comment": + case "constraint": + case "create": + case "cross": + case "cube": + case "database": + case "databases": + case "default": + case "delete": + case "desc": + case "describe": + case "detach": + case "distinct": + case "distributed": + case "drop": + case "else": + case "end": + case "engine": + case "exists": + case "explain": + case "expression": + case "extract": + case "false": + case "fetch": + case "final": + case "first": + case "for": + case "format": + case "from": + case "full": + case "function": + case "global": + case "grant": + case "group": + case "having": + case "if": + case "ilike": + case "in": + case "index": + case "inner": + case "insert": + case "interpolate": + case "interval": + case "into": + case "is": + case "join": + case "key": + case "kill": + case "last": + case "leading": + case "left": + case "like": + case "limit": + case "live": + case "local": + case "logs": + case "materialized": + case "modify": + case "natural": + case "not": + case "null": + case "nulls": + case "offset": + case "on": + case "optimize": + case "or": + case "order": + case "outer": + case "outfile": + case "over": + case "partition": + case "paste": + case "populate": + case "prewhere": + case "primary": + case "projection": + case "rename": + case "replace": + case "right": + case "rollup": + case "sample": + case "select": + case "semi": + case "set": + case "settings": + case "show": + case "storage": + case "substring": + case "sync": + case "system": + case "table": + case "tables": + case "temporary": + case "test": + case "then": + case "ties": + case "to": + case "top": + case "totals": + case "trailing": + case "trim": + case "true": + case "truncate": + case "ttl": + case "type": + case "union": + case "update": + case "use": + case "using": + case "uuid": + case "values": + case "view": + case "watch": + case "when": + case "where": + case "window": + case "with": + default: + return false + } + return true +} diff --git a/internal/engine/clickhouse/stdlib.go b/internal/engine/clickhouse/stdlib.go new file mode 100644 index 0000000000..da7b53ab21 --- /dev/null +++ b/internal/engine/clickhouse/stdlib.go @@ -0,0 +1,9 @@ +package clickhouse + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func defaultSchema(name string) *catalog.Schema { + return &catalog.Schema{Name: name} +} diff --git a/internal/engine/clickhouse/utils.go b/internal/engine/clickhouse/utils.go new file mode 100644 index 0000000000..9e52f4d5a7 --- /dev/null +++ b/internal/engine/clickhouse/utils.go @@ -0,0 +1,59 @@ +package clickhouse + +import ( + "log" + "strings" + + chast "github.com/sqlc-dev/doubleclick/ast" + + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func todo(n chast.Node) *ast.TODO { + if debug.Active { + log.Printf("clickhouse.convert: Unknown node type %T\n", n) + } + return &ast.TODO{} +} + +func identifier(id string) string { + return strings.ToLower(id) +} + +func NewIdentifier(t string) *ast.String { + return &ast.String{Str: identifier(t)} +} + +func parseTableName(n *chast.TableIdentifier) *ast.TableName { + return &ast.TableName{ + Schema: identifier(n.Database), + Name: identifier(n.Table), + } +} + +func parseTableIdentifierToRangeVar(n *chast.TableIdentifier) *ast.RangeVar { + schemaname := identifier(n.Database) + relname := identifier(n.Table) + return &ast.RangeVar{ + Schemaname: &schemaname, + Relname: &relname, + } +} + +func isNotNull(n *chast.ColumnDeclaration) bool { + if n.Type == nil { + return false + } + // Check if type is wrapped in Nullable() + // If it's Nullable, it can be null, so return false + // If it's not Nullable, it's NOT NULL by default in ClickHouse + if n.Type.Name != "" && strings.ToLower(n.Type.Name) == "nullable" { + return false + } + // Also check if Nullable field is explicitly set + if n.Nullable != nil && *n.Nullable { + return false + } + return true +} diff --git a/internal/engine/dolphin/CLAUDE.md b/internal/engine/dolphin/CLAUDE.md new file mode 100644 index 0000000000..20142fafaa --- /dev/null +++ b/internal/engine/dolphin/CLAUDE.md @@ -0,0 +1,224 @@ +# Dolphin Engine (MySQL) - Claude Code Guide + +The dolphin engine handles MySQL parsing and AST conversion using the TiDB parser. + +## Architecture + +### Parser Flow +``` +SQL String → TiDB Parser → TiDB AST → sqlc AST → Analysis/Codegen +``` + +### Key Files +- `convert.go` - Converts TiDB AST nodes to sqlc AST nodes +- `format.go` - MySQL-specific formatting (identifiers, types, parameters) +- `parse.go` - Entry point for parsing MySQL SQL + +## TiDB Parser + +The TiDB parser (`github.com/pingcap/tidb/pkg/parser`) is used for MySQL parsing: + +```go +import ( + pcast "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" +) +``` + +### Common TiDB Types +- `pcast.SelectStmt`, `pcast.InsertStmt`, etc. - Statement types +- `pcast.ColumnNameExpr` - Column reference +- `pcast.FuncCallExpr` - Function call +- `pcast.BinaryOperationExpr` - Binary expression +- `pcast.VariableExpr` - MySQL user variable (@var) +- `pcast.Join` - JOIN clause with Left, Right, On, Using + +## Conversion Pattern + +Each TiDB node type has a corresponding converter method: + +```go +func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt { + return &ast.SelectStmt{ + FromClause: c.convertTableRefsClause(n.From), + WhereClause: c.convert(n.Where), + // ... + } +} +``` + +The main `convert()` method dispatches to specific converters: +```go +func (c *cc) convert(node pcast.Node) ast.Node { + switch n := node.(type) { + case *pcast.SelectStmt: + return c.convertSelectStmt(n) + case *pcast.InsertStmt: + return c.convertInsertStmt(n) + // ... + } +} +``` + +## Key Conversions + +### Column References +```go +func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { + var items []ast.Node + if schema := n.Name.Schema.String(); schema != "" { + items = append(items, NewIdentifier(schema)) + } + if table := n.Name.Table.String(); table != "" { + items = append(items, NewIdentifier(table)) + } + items = append(items, NewIdentifier(n.Name.Name.String())) + return &ast.ColumnRef{Fields: &ast.List{Items: items}} +} +``` + +### JOINs +```go +func (c *cc) convertJoin(n *pcast.Join) *ast.List { + if n.Right != nil && n.Left != nil { + return &ast.List{ + Items: []ast.Node{&ast.JoinExpr{ + Jointype: ast.JoinType(n.Tp), + Larg: c.convert(n.Left), + Rarg: c.convert(n.Right), + Quals: c.convert(n.On), + UsingClause: convertUsing(n.Using), + }}, + } + } + // No join - just return tables + // ... +} +``` + +### MySQL User Variables +MySQL user variables (`@var`) are different from sqlc's `@param` syntax: +```go +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + // Use VariableExpr to preserve as-is (NOT A_Expr which would be treated as sqlc param) + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } +} +``` + +### Type Casts (CAST AS) +```go +func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node { + typeName := types.TypeStr(n.Tp.GetType()) + // Handle UNSIGNED/SIGNED specially + if typeName == "bigint" { + if mysql.HasUnsignedFlag(n.Tp.GetFlag()) { + typeName = "bigint unsigned" + } else { + typeName = "bigint signed" + } + } + return &ast.TypeCast{ + Arg: c.convert(n.Expr), + TypeName: &ast.TypeName{Name: typeName}, + } +} +``` + +### Column Definitions +```go +func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { + typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())} + + // Only add Typmods for types where length is meaningful + tp := def.Tp.GetType() + flen := def.Tp.GetFlen() + switch tp { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + if flen >= 0 { + typeName.Typmods = &ast.List{ + Items: []ast.Node{&ast.Integer{Ival: int64(flen)}}, + } + } + // Don't add for DATETIME, TIMESTAMP - internal flen is not user-specified + } + // ... +} +``` + +### Multi-Table DELETE +MySQL supports `DELETE t1, t2 FROM t1 JOIN t2 ...`: +```go +func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { + if n.IsMultiTable && n.Tables != nil { + // Convert targets (t1.*, t2.*) + targets := &ast.List{} + for _, table := range n.Tables.Tables { + // Build ColumnRef for each target + } + stmt.Targets = targets + + // Preserve JOINs in FromClause + stmt.FromClause = c.convertTableRefsClause(n.TableRefs).Items[0] + } else { + // Single-table DELETE + stmt.Relations = c.convertTableRefsClause(n.TableRefs) + } +} +``` + +## MySQL-Specific Formatting + +### format.go +```go +func (p *Parser) TypeName(ns, name string) string { + switch name { + case "bigint unsigned": + return "UNSIGNED" + case "bigint signed": + return "SIGNED" + } + return name +} + +func (p *Parser) Param(n int) string { + return "?" // MySQL uses ? for all parameters +} +``` + +## Common Issues and Solutions + +### Issue: Panic in Walk/Apply +**Cause**: New AST node type not handled in `astutils/walk.go` or `astutils/rewrite.go` +**Solution**: Add case for the node type in both files + +### Issue: sqlc.arg() not converted in ON DUPLICATE KEY UPDATE +**Cause**: `InsertStmt` case in `rewrite.go` didn't traverse `OnDuplicateKeyUpdate` +**Solution**: Add `a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate)` + +### Issue: MySQL @variable being treated as parameter +**Cause**: Converting `VariableExpr` to `A_Expr` with `@` operator +**Solution**: Use `ast.VariableExpr` instead, which is not detected by `named.IsParamSign()` + +### Issue: Type length appearing incorrectly (e.g., datetime(39)) +**Cause**: Using internal `flen` for all types +**Solution**: Only populate `Typmods` for types where length is user-specified (varchar, char, etc.) + +## Testing + +### TestFormat +Tests that SQL can be: +1. Parsed +2. Formatted back to SQL +3. Re-parsed +4. Re-formatted to match + +### TestReplay +Tests the full sqlc pipeline: +1. Parse schema and queries +2. Analyze +3. Generate code +4. Compare with expected output diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 33b89ae8f4..1f68358ce4 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -2,6 +2,7 @@ package dolphin import ( "log" + "strconv" "strings" pcast "github.com/pingcap/tidb/pkg/parser/ast" @@ -187,8 +188,14 @@ func opToName(o opcode.Op) string { func (c *cc) convertBinaryOperationExpr(n *pcast.BinaryOperationExpr) ast.Node { if n.Op == opcode.LogicAnd || n.Op == opcode.LogicOr { + var boolop ast.BoolExprType + if n.Op == opcode.LogicAnd { + boolop = ast.BoolExprTypeAnd + } else { + boolop = ast.BoolExprTypeOr + } return &ast.BoolExpr{ - // TODO: Set op + Boolop: boolop, Args: &ast.List{ Items: []ast.Node{ c.convert(n.L), @@ -249,9 +256,36 @@ func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { } } } + + // Build TypeName with modifiers for proper formatting + typeName := &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())} + + // Add type modifiers (e.g., length for varchar(255), char(32)) + // Only for types where length is meaningful and user-specified + tp := def.Tp.GetType() + flen := def.Tp.GetFlen() + needsLength := false + switch tp { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + // VARCHAR(n), CHAR(n) - always need length + needsLength = flen >= 0 + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + // BLOB types - only if user specified length (VARBINARY(n), BINARY(n)) + // Default blob types don't need length + needsLength = false + } + + if needsLength { + typeName.Typmods = &ast.List{ + Items: []ast.Node{ + &ast.Integer{Ival: int64(flen)}, + }, + } + } + columnDef := ast.ColumnDef{ Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, + TypeName: typeName, IsNotNull: isNotNull(def), IsUnsigned: isUnsigned(def), Comment: comment, @@ -294,22 +328,54 @@ func (c *cc) convertColumnNames(cols []*pcast.ColumnName) *ast.List { } func (c *cc) convertDeleteStmt(n *pcast.DeleteStmt) *ast.DeleteStmt { - rels := c.convertTableRefsClause(n.TableRefs) - if len(rels.Items) != 1 { - panic("expected one range var") - } - relations := &ast.List{} - convertToRangeVarList(rels, relations) - stmt := &ast.DeleteStmt{ - Relations: relations, WhereClause: c.convert(n.Where), ReturningList: &ast.List{}, WithClause: c.convertWithClause(n.With), } + if n.Limit != nil { stmt.LimitCount = c.convert(n.Limit.Count) } + + // Handle multi-table DELETE (DELETE t1, t2 FROM t1 JOIN t2 ...) + if n.IsMultiTable && n.Tables != nil && len(n.Tables.Tables) > 0 { + // Convert delete targets (e.g., jt.*, pt.*) + targets := &ast.List{} + for _, table := range n.Tables.Tables { + // Each table in the delete list is a ColumnRef like "jt.*" or "pt.*" + items := []ast.Node{} + if table.Schema.String() != "" { + items = append(items, NewIdentifier(table.Schema.String())) + } + items = append(items, NewIdentifier(table.Name.String())) + items = append(items, &ast.A_Star{}) + targets.Items = append(targets.Items, &ast.ColumnRef{ + Fields: &ast.List{Items: items}, + }) + } + stmt.Targets = targets + + // Convert FROM clause preserving JOINs + if n.TableRefs != nil { + fromList := c.convertTableRefsClause(n.TableRefs) + if len(fromList.Items) == 1 { + stmt.FromClause = fromList.Items[0] + } else { + stmt.FromClause = fromList + } + } + } else { + // Single-table DELETE + rels := c.convertTableRefsClause(n.TableRefs) + if len(rels.Items) != 1 { + panic("expected one range var") + } + relations := &ast.List{} + convertToRangeVarList(rels, relations) + stmt.Relations = relations + } + return stmt } @@ -333,9 +399,11 @@ func (c *cc) convertRenameTableStmt(n *pcast.RenameTableStmt) ast.Node { } func (c *cc) convertExistsSubqueryExpr(n *pcast.ExistsSubqueryExpr) *ast.SubLink { - sublink := &ast.SubLink{} - if ss, ok := c.convert(n.Sel).(*ast.SelectStmt); ok { - sublink.Subselect = ss + sublink := &ast.SubLink{ + SubLinkType: ast.EXISTS_SUBLINK, + } + if n.Sel != nil { + sublink.Subselect = c.convert(n.Sel) } return sublink } @@ -359,6 +427,33 @@ func (c *cc) convertFuncCallExpr(n *pcast.FuncCallExpr) ast.Node { } items = append(items, NewIdentifier(name)) + // Handle DATE_ADD/DATE_SUB specially to construct INTERVAL expressions + // These functions have args: [date, interval_value, TimeUnitExpr] + if (name == "date_add" || name == "date_sub") && len(n.Args) == 3 { + if timeUnit, ok := n.Args[2].(*pcast.TimeUnitExpr); ok { + args := &ast.List{ + Items: []ast.Node{ + c.convert(n.Args[0]), + &ast.IntervalExpr{ + Value: c.convert(n.Args[1]), + Unit: timeUnit.Unit.String(), + }, + }, + } + return &ast.FuncCall{ + Args: args, + Func: &ast.FuncName{ + Schema: schema, + Name: name, + }, + Funcname: &ast.List{ + Items: items, + }, + Location: n.OriginTextPosition(), + } + } + } + args := &ast.List{} for _, arg := range n.Args { args.Items = append(args.Items, c.convert(arg)) @@ -415,7 +510,7 @@ func (c *cc) convertInsertStmt(n *pcast.InsertStmt) *ast.InsertStmt { for _, a := range n.OnDuplicate { targetList.Items = append(targetList.Items, c.convertAssignment(a)) } - insert.OnConflictClause = &ast.OnConflictClause{ + insert.OnDuplicateKeyUpdate = &ast.OnDuplicateKeyUpdate{ TargetList: targetList, Location: n.OriginTextPosition(), } @@ -492,7 +587,11 @@ func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt { } func (c *cc) convertSubqueryExpr(n *pcast.SubqueryExpr) ast.Node { - return c.convert(n.Query) + // Wrap subquery in SubLink to ensure parentheses are added + return &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Query), + } } func (c *cc) convertTableRefsClause(n *pcast.TableRefsClause) *ast.List { @@ -514,9 +613,17 @@ func (c *cc) convertCommonTableExpression(n *pcast.CommonTableExpression) *ast.C columns.Items = append(columns.Items, NewIdentifier(col.String())) } + // CTE Query is wrapped in SubqueryExpr by TiDB parser. + // We need to unwrap it to get the SelectStmt directly, + // otherwise it would be double-wrapped with parentheses. + var cteQuery ast.Node + if n.Query != nil { + cteQuery = c.convert(n.Query.Query) + } + return &ast.CommonTableExpr{ Ctename: &name, - Ctequery: c.convert(n.Query), + Ctequery: cteQuery, Ctecolnames: columns, } } @@ -596,7 +703,7 @@ func (c *cc) convertValueExpr(n *driver.ValueExpr) *ast.A_Const { mysql.TypeNewDecimal: return &ast.A_Const{ Val: &ast.Float{ - // TODO: Extract the value from n.TexprNode + Str: strconv.FormatFloat(n.Datum.GetFloat64(), 'f', -1, 64), }, Location: n.OriginTextPosition(), } @@ -643,7 +750,21 @@ func (c *cc) convertAggregateFuncExpr(n *pcast.AggregateFuncExpr) *ast.FuncCall Args: &ast.List{}, AggOrder: &ast.List{}, } - for _, a := range n.Args { + + // GROUP_CONCAT has special handling: + // TiDB always adds the separator as the last argument + // We need to extract it and use SEPARATOR syntax + args := n.Args + var separator string + if name == "group_concat" && len(args) >= 2 { + // The last arg is always the separator + if value, ok := args[len(args)-1].(*driver.ValueExpr); ok { + separator = value.GetString() + args = args[:len(args)-1] + } + } + + for _, a := range args { if value, ok := a.(*driver.ValueExpr); ok { if value.GetInt64() == int64(1) { fn.AggStar = true @@ -655,6 +776,12 @@ func (c *cc) convertAggregateFuncExpr(n *pcast.AggregateFuncExpr) *ast.FuncCall if n.Distinct { fn.AggDistinct = true } + + // Store separator for GROUP_CONCAT (only if non-default) + if name == "group_concat" && separator != "" && separator != "," { + fn.Separator = &separator + } + return fn } @@ -871,9 +998,21 @@ func (c *cc) convertFrameClause(n *pcast.FrameClause) ast.Node { } func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node { + typeName := types.TypeStr(n.Tp.GetType()) + + // MySQL CAST AS UNSIGNED/SIGNED uses bigint internally. + // We need to preserve the signed/unsigned info for formatting. + if typeName == "bigint" { + if mysql.HasUnsignedFlag(n.Tp.GetFlag()) { + typeName = "bigint unsigned" + } else { + typeName = "bigint signed" + } + } + return &ast.TypeCast{ Arg: c.convert(n.Expr), - TypeName: &ast.TypeName{Name: types.TypeStr(n.Tp.GetType())}, + TypeName: &ast.TypeName{Name: typeName}, } } @@ -949,12 +1088,24 @@ func (c *cc) convertJoin(n *pcast.Join) *ast.List { joinType++ } + // Convert USING clause + var usingClause *ast.List + if len(n.Using) > 0 { + items := make([]ast.Node, len(n.Using)) + for i, col := range n.Using { + items[i] = &ast.String{Str: col.Name.O} + } + usingClause = &ast.List{Items: items} + } + return &ast.List{ Items: []ast.Node{&ast.JoinExpr{ - Jointype: joinType, - Larg: c.convert(n.Left), - Rarg: c.convert(n.Right), - Quals: c.convert(n.On), + Jointype: joinType, + IsNatural: n.NaturalJoin, + Larg: c.convert(n.Left), + Rarg: c.convert(n.Right), + UsingClause: usingClause, + Quals: c.convert(n.On), }}, } } @@ -1049,7 +1200,16 @@ func (c *cc) convertParenthesesExpr(n *pcast.ParenthesesExpr) ast.Node { if n == nil { return nil } - return c.convert(n.Expr) + inner := c.convert(n.Expr) + // Only wrap in ParenExpr for SELECT statements (needed for UNION with parenthesized subqueries) + // For other expressions, the BoolExpr already adds parentheses + if _, ok := inner.(*ast.SelectStmt); ok { + return &ast.ParenExpr{ + Expr: inner, + Location: n.OriginTextPosition(), + } + } + return inner } func (c *cc) convertPartitionByClause(n *pcast.PartitionByClause) ast.Node { @@ -1100,7 +1260,7 @@ func (c *cc) convertPatternRegexpExpr(n *pcast.PatternRegexpExpr) ast.Node { } func (c *cc) convertPositionExpr(n *pcast.PositionExpr) ast.Node { - return todo(n) + return &ast.Integer{Ival: int64(n.N)} } func (c *cc) convertPrepareStmt(n *pcast.PrepareStmt) ast.Node { @@ -1205,7 +1365,28 @@ func (c *cc) convertSetOprSelectList(n *pcast.SetOprSelectList) ast.Node { case *pcast.SelectStmt: selectStmts[i] = c.convertSelectStmt(node) case *pcast.SetOprSelectList: - selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + // If this is a single-select SetOprSelectList (e.g., from parenthesized SELECT), + // extract the inner select instead of building a UNION tree + if len(node.Selects) == 1 { + if innerSelect, ok := node.Selects[0].(*pcast.SelectStmt); ok { + selectStmts[i] = c.convertSelectStmt(innerSelect) + } else { + selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + } + } else { + selectStmts[i] = c.convertSetOprSelectList(node).(*ast.SelectStmt) + } + default: + // Handle other node types like ParenthesesExpr wrapping a SELECT + converted := c.convert(node) + if ss, ok := converted.(*ast.SelectStmt); ok { + selectStmts[i] = ss + } else if pe, ok := converted.(*ast.ParenExpr); ok { + // Unwrap ParenExpr to get the inner SelectStmt + if inner, ok := pe.Expr.(*ast.SelectStmt); ok { + selectStmts[i] = inner + } + } } } @@ -1396,7 +1577,12 @@ func (c *cc) convertVariableAssignment(n *pcast.VariableAssignment) ast.Node { } func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { - return todo(n) + // MySQL @variable references are user-defined variables, NOT sqlc named parameters. + // Use VariableExpr to preserve them as-is in the output. + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } } func (c *cc) convertWhenClause(n *pcast.WhenClause) ast.Node { diff --git a/internal/engine/dolphin/format.go b/internal/engine/dolphin/format.go new file mode 100644 index 0000000000..9c6346756c --- /dev/null +++ b/internal/engine/dolphin/format.go @@ -0,0 +1,43 @@ +package dolphin + +// QuoteIdent returns a quoted identifier if it needs quoting. +// MySQL uses backticks for quoting identifiers. +func (p *Parser) QuoteIdent(s string) string { + // For now, don't quote - MySQL is less strict about quoting + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +// Handles MySQL-specific type name mappings for formatting. +func (p *Parser) TypeName(ns, name string) string { + if ns != "" { + return ns + "." + name + } + // Map internal type names to MySQL CAST-compatible names for formatting + switch name { + case "bigint unsigned": + return "UNSIGNED" + case "bigint signed": + return "SIGNED" + } + return name +} + +// Param returns the parameter placeholder for the given number. +// MySQL uses ? for all parameters (positional). +func (p *Parser) Param(n int) string { + return "?" +} + +// NamedParam returns the named parameter placeholder for the given name. +// MySQL doesn't have native named parameters, so we use ? (positional). +// The actual parameter names are handled by sqlc's rewrite phase. +func (p *Parser) NamedParam(name string) string { + return "?" +} + +// Cast returns a type cast expression. +// MySQL uses CAST(expr AS type) syntax. +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} diff --git a/internal/engine/dolphin/stdlib.go b/internal/engine/dolphin/stdlib.go index 41469ca49d..46ce500eb5 100644 --- a/internal/engine/dolphin/stdlib.go +++ b/internal/engine/dolphin/stdlib.go @@ -636,6 +636,19 @@ func defaultSchema(name string) *catalog.Schema { }, ReturnType: &ast.TypeName{Name: "date"}, }, + { + // DATE_ADD with INTERVAL expression (2 args) + Name: "DATE_ADD", + Args: []*catalog.Argument{ + { + Type: &ast.TypeName{Name: "date"}, + }, + { + Type: &ast.TypeName{Name: "interval"}, + }, + }, + ReturnType: &ast.TypeName{Name: "date"}, + }, { Name: "DATE_ADD_INTERVAL", Args: []*catalog.Argument{ @@ -675,6 +688,19 @@ func defaultSchema(name string) *catalog.Schema { }, ReturnType: &ast.TypeName{Name: "date"}, }, + { + // DATE_SUB with INTERVAL expression (2 args) + Name: "DATE_SUB", + Args: []*catalog.Argument{ + { + Type: &ast.TypeName{Name: "date"}, + }, + { + Type: &ast.TypeName{Name: "interval"}, + }, + }, + ReturnType: &ast.TypeName{Name: "date"}, + }, { Name: "DATE_SUB_INTERVAL", Args: []*catalog.Argument{ diff --git a/internal/engine/postgresql/analyzer/analyze.go b/internal/engine/postgresql/analyzer/analyze.go index 5a08fa98ec..ee03e4d3c5 100644 --- a/internal/engine/postgresql/analyzer/analyze.go +++ b/internal/engine/postgresql/analyzer/analyze.go @@ -17,6 +17,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/opts" "github.com/sqlc-dev/sqlc/internal/shfmt" "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" "github.com/sqlc-dev/sqlc/internal/sql/named" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -320,3 +321,227 @@ func (a *Analyzer) Close(_ context.Context) error { } return nil } + +// SQL queries for schema introspection +const introspectTablesQuery = ` +SELECT + n.nspname AS schema_name, + c.relname AS table_name, + a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + a.attnotnull AS not_null, + a.attndims AS array_dims, + COALESCE( + (SELECT true FROM pg_index i + WHERE i.indrelid = c.oid + AND i.indisprimary + AND a.attnum = ANY(i.indkey)), + false + ) AS is_primary_key +FROM + pg_catalog.pg_class c +JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace +JOIN + pg_catalog.pg_attribute a ON a.attrelid = c.oid +WHERE + c.relkind IN ('r', 'v', 'p') -- tables, views, partitioned tables + AND a.attnum > 0 -- skip system columns + AND NOT a.attisdropped + AND n.nspname = ANY($1) +ORDER BY + n.nspname, c.relname, a.attnum +` + +const introspectEnumsQuery = ` +SELECT + n.nspname AS schema_name, + t.typname AS type_name, + e.enumlabel AS enum_value +FROM + pg_catalog.pg_type t +JOIN + pg_catalog.pg_namespace n ON n.oid = t.typnamespace +JOIN + pg_catalog.pg_enum e ON e.enumtypid = t.oid +WHERE + t.typtype = 'e' + AND n.nspname = ANY($1) +ORDER BY + n.nspname, t.typname, e.enumsortorder +` + +type introspectedColumn struct { + SchemaName string `db:"schema_name"` + TableName string `db:"table_name"` + ColumnName string `db:"column_name"` + DataType string `db:"data_type"` + NotNull bool `db:"not_null"` + ArrayDims int `db:"array_dims"` + IsPrimaryKey bool `db:"is_primary_key"` +} + +type introspectedEnum struct { + SchemaName string `db:"schema_name"` + TypeName string `db:"type_name"` + EnumValue string `db:"enum_value"` +} + +// IntrospectSchema queries the database to build a catalog containing +// tables, columns, and enum types for the specified schemas. +func (a *Analyzer) IntrospectSchema(ctx context.Context, schemas []string) (*catalog.Catalog, error) { + if a.pool == nil { + return nil, fmt.Errorf("database connection not initialized") + } + + c, err := a.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer c.Release() + + // Query tables and columns + rows, err := c.Query(ctx, introspectTablesQuery, schemas) + if err != nil { + return nil, fmt.Errorf("introspect tables: %w", err) + } + columns, err := pgx.CollectRows(rows, pgx.RowToStructByName[introspectedColumn]) + if err != nil { + return nil, fmt.Errorf("collect table rows: %w", err) + } + + // Query enums + enumRows, err := c.Query(ctx, introspectEnumsQuery, schemas) + if err != nil { + return nil, fmt.Errorf("introspect enums: %w", err) + } + enums, err := pgx.CollectRows(enumRows, pgx.RowToStructByName[introspectedEnum]) + if err != nil { + return nil, fmt.Errorf("collect enum rows: %w", err) + } + + // Build catalog + cat := &catalog.Catalog{ + DefaultSchema: "public", + SearchPath: schemas, + } + + // Create schema map for quick lookup + schemaMap := make(map[string]*catalog.Schema) + for _, schemaName := range schemas { + schema := &catalog.Schema{Name: schemaName} + cat.Schemas = append(cat.Schemas, schema) + schemaMap[schemaName] = schema + } + + // Group columns by table + tableMap := make(map[string]*catalog.Table) + for _, col := range columns { + key := col.SchemaName + "." + col.TableName + tbl, exists := tableMap[key] + if !exists { + tbl = &catalog.Table{ + Rel: &ast.TableName{ + Schema: col.SchemaName, + Name: col.TableName, + }, + } + tableMap[key] = tbl + if schema, ok := schemaMap[col.SchemaName]; ok { + schema.Tables = append(schema.Tables, tbl) + } + } + + dt, isArray, dims := parseType(col.DataType) + tbl.Columns = append(tbl.Columns, &catalog.Column{ + Name: col.ColumnName, + Type: ast.TypeName{Name: dt}, + IsNotNull: col.NotNull, + IsArray: isArray || col.ArrayDims > 0, + ArrayDims: max(dims, col.ArrayDims), + }) + } + + // Group enum values by type + enumMap := make(map[string]*catalog.Enum) + for _, e := range enums { + key := e.SchemaName + "." + e.TypeName + enum, exists := enumMap[key] + if !exists { + enum = &catalog.Enum{ + Name: e.TypeName, + } + enumMap[key] = enum + if schema, ok := schemaMap[e.SchemaName]; ok { + schema.Types = append(schema.Types, enum) + } + } + enum.Vals = append(enum.Vals, e.EnumValue) + } + + return cat, nil +} + +// EnsureConn initializes the database connection pool if not already done. +// This is useful for database-only mode where we need to connect before analyzing queries. +func (a *Analyzer) EnsureConn(ctx context.Context, migrations []string) error { + if a.pool != nil { + return nil + } + + var uri string + if a.db.Managed { + if a.client == nil { + return fmt.Errorf("client is nil") + } + edb, err := a.client.CreateDatabase(ctx, &dbmanager.CreateDatabaseRequest{ + Engine: "postgresql", + Migrations: migrations, + }) + if err != nil { + return err + } + uri = edb.Uri + } else if a.dbg.OnlyManagedDatabases { + return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } else { + uri = a.replacer.Replace(a.db.URI) + } + + conf, err := pgxpool.ParseConfig(uri) + if err != nil { + return err + } + pool, err := pgxpool.NewWithConfig(ctx, conf) + if err != nil { + return err + } + a.pool = pool + return nil +} + +// GetColumnNames implements the expander.ColumnGetter interface. +// It prepares a query and returns the column names from the result set description. +func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) { + if a.pool == nil { + return nil, fmt.Errorf("database connection not initialized") + } + + conn, err := a.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index f56a572c16..321294c59e 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -1965,6 +1965,22 @@ func convertNullTest(n *pg.NullTest) *ast.NullTest { } } +func convertNullIfExpr(n *pg.NullIfExpr) *ast.NullIfExpr { + if n == nil { + return nil + } + return &ast.NullIfExpr{ + Xpr: convertNode(n.Xpr), + Opno: ast.Oid(n.Opno), + Opresulttype: ast.Oid(n.Opresulttype), + Opretset: n.Opretset, + Opcollid: ast.Oid(n.Opcollid), + Inputcollid: ast.Oid(n.Inputcollid), + Args: convertSlice(n.Args), + Location: int(n.Location), + } +} + func convertObjectWithArgs(n *pg.ObjectWithArgs) *ast.ObjectWithArgs { if n == nil { return nil @@ -3420,6 +3436,9 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_NullTest: return convertNullTest(n.NullTest) + case *pg.Node_NullIfExpr: + return convertNullIfExpr(n.NullIfExpr) + case *pg.Node_ObjectWithArgs: return convertObjectWithArgs(n.ObjectWithArgs) diff --git a/internal/engine/postgresql/parse.go b/internal/engine/postgresql/parse.go index 40af125962..0c6b3a0fc2 100644 --- a/internal/engine/postgresql/parse.go +++ b/internal/engine/postgresql/parse.go @@ -494,6 +494,7 @@ func translate(node *nodes.Node) (ast.Node, error) { ReturnType: rt, Replace: n.Replace, Params: &ast.List{}, + Options: convertSlice(n.Options), } for _, item := range n.Parameters { arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 8f796ffa19..b03a6a7e9f 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -1,6 +1,80 @@ package postgresql -import "strings" +import ( + "fmt" + "strings" +) + +// hasMixedCase returns true if the string has any uppercase letters +// (identifiers with mixed case need quoting in PostgreSQL) +func hasMixedCase(s string) bool { + for _, r := range s { + if r >= 'A' && r <= 'Z' { + return true + } + } + return false +} + +// QuoteIdent returns a quoted identifier if it needs quoting. +// This implements the format.Dialect interface. +func (p *Parser) QuoteIdent(s string) string { + if p.IsReservedKeyword(s) || hasMixedCase(s) { + return `"` + s + `"` + } + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +// This implements the format.Dialect interface. +func (p *Parser) TypeName(ns, name string) string { + if ns == "pg_catalog" { + switch name { + case "int4": + return "integer" + case "int8": + return "bigint" + case "int2": + return "smallint" + case "float4": + return "real" + case "float8": + return "double precision" + case "bool": + return "boolean" + case "bpchar": + return "character" + case "timestamptz": + return "timestamp with time zone" + case "timetz": + return "time with time zone" + default: + return name + } + } + if ns != "" { + return ns + "." + name + } + return name +} + +// Param returns the parameter placeholder for the given number. +// PostgreSQL uses $1, $2, etc. +func (p *Parser) Param(n int) string { + return fmt.Sprintf("$%d", n) +} + +// NamedParam returns the named parameter placeholder for the given name. +// PostgreSQL/sqlc uses @name syntax. +func (p *Parser) NamedParam(name string) string { + return "@" + name +} + +// Cast returns a type cast expression. +// PostgreSQL uses expr::type syntax. +func (p *Parser) Cast(arg, typeName string) string { + return arg + "::" + typeName +} // https://www.postgresql.org/docs/current/sql-keywords-appendix.html func (p *Parser) IsReservedKeyword(s string) bool { diff --git a/internal/engine/sqlite/analyzer/analyze.go b/internal/engine/sqlite/analyzer/analyze.go new file mode 100644 index 0000000000..3af9f99a30 --- /dev/null +++ b/internal/engine/sqlite/analyzer/analyze.go @@ -0,0 +1,369 @@ +package analyzer + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + + core "github.com/sqlc-dev/sqlc/internal/analysis" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/shfmt" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" + "github.com/sqlc-dev/sqlc/internal/sql/named" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +type Analyzer struct { + db config.Database + conn *sqlite3.Conn + dbg opts.Debug + replacer *shfmt.Replacer + mu sync.Mutex +} + +func New(db config.Database) *Analyzer { + return &Analyzer{ + db: db, + dbg: opts.DebugFromEnv(), + replacer: shfmt.NewReplacer(nil), + } +} + +func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrations []string, ps *named.ParamSet) (*core.Analysis, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + var uri string + applyMigrations := a.db.Managed + if a.db.Managed { + // For managed databases, create an in-memory database + uri = ":memory:" + } else if a.dbg.OnlyManagedDatabases { + return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } else { + uri = a.replacer.Replace(a.db.URI) + // For in-memory databases, we need to apply migrations since the database starts empty + if isInMemoryDatabase(uri) { + applyMigrations = true + } + } + + conn, err := sqlite3.Open(uri) + if err != nil { + return nil, fmt.Errorf("failed to open sqlite database: %w", err) + } + a.conn = conn + + // Apply migrations for managed or in-memory databases + if applyMigrations { + for _, m := range migrations { + if len(strings.TrimSpace(m)) == 0 { + continue + } + if err := a.conn.Exec(m); err != nil { + a.conn.Close() + a.conn = nil + return nil, fmt.Errorf("migration failed: %s: %w", m, err) + } + } + } + } + + // Prepare the statement to get column and parameter information + stmt, _, err := a.conn.Prepare(query) + if err != nil { + return nil, a.extractSqlErr(n, err) + } + defer stmt.Close() + + var result core.Analysis + + // Get column information + colCount := stmt.ColumnCount() + for i := 0; i < colCount; i++ { + name := stmt.ColumnName(i) + declType := stmt.ColumnDeclType(i) + tableName := stmt.ColumnTableName(i) + originName := stmt.ColumnOriginName(i) + dbName := stmt.ColumnDatabaseName(i) + + // Normalize the data type + dataType := normalizeType(declType) + + // Determine if column is NOT NULL + // SQLite doesn't provide this info directly from prepared statements, + // so we default to nullable (false) + notNull := false + + col := &core.Column{ + Name: name, + OriginalName: originName, + DataType: dataType, + NotNull: notNull, + } + + if tableName != "" { + col.Table = &core.Identifier{ + Schema: dbName, + Name: tableName, + } + } + + result.Columns = append(result.Columns, col) + } + + // Get parameter information + bindCount := stmt.BindCount() + for i := 1; i <= bindCount; i++ { + paramName := stmt.BindName(i) + + // SQLite doesn't provide parameter types from prepared statements + // We use "any" as the default type + name := "" + if paramName != "" { + // Remove the prefix (?, :, @, $) from parameter names + name = strings.TrimLeft(paramName, "?:@$") + } + if ps != nil { + if n, ok := ps.NameFor(i); ok { + name = n + } + } + + result.Params = append(result.Params, &core.Parameter{ + Number: int32(i), + Column: &core.Column{ + Name: name, + DataType: "any", + NotNull: false, + }, + }) + } + + return &result, nil +} + +func (a *Analyzer) extractSqlErr(n ast.Node, err error) error { + if err == nil { + return nil + } + // Try to extract SQLite error details + var sqliteErr *sqlite3.Error + if e, ok := err.(*sqlite3.Error); ok { + sqliteErr = e + } + if sqliteErr != nil { + return &sqlerr.Error{ + Code: fmt.Sprintf("%d", sqliteErr.Code()), + Message: sqliteErr.Error(), + Location: n.Pos(), + } + } + return &sqlerr.Error{ + Message: err.Error(), + Location: n.Pos(), + } +} + +func (a *Analyzer) Close(_ context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + if a.conn != nil { + err := a.conn.Close() + a.conn = nil + return err + } + return nil +} + +// EnsureConn initializes the database connection if not already done. +// This is useful for database-only mode where we need to connect before analyzing queries. +func (a *Analyzer) EnsureConn(ctx context.Context, migrations []string) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn != nil { + return nil + } + + var uri string + applyMigrations := a.db.Managed + if a.db.Managed { + // For managed databases, create an in-memory database + uri = ":memory:" + } else if a.dbg.OnlyManagedDatabases { + return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } else { + uri = a.replacer.Replace(a.db.URI) + // For in-memory databases, we need to apply migrations since the database starts empty + if isInMemoryDatabase(uri) { + applyMigrations = true + } + } + + conn, err := sqlite3.Open(uri) + if err != nil { + return fmt.Errorf("failed to open sqlite database: %w", err) + } + a.conn = conn + + // Apply migrations for managed or in-memory databases + if applyMigrations { + for _, m := range migrations { + if len(strings.TrimSpace(m)) == 0 { + continue + } + if err := a.conn.Exec(m); err != nil { + a.conn.Close() + a.conn = nil + return fmt.Errorf("migration failed: %s: %w", m, err) + } + } + } + + return nil +} + +// GetColumnNames implements the expander.ColumnGetter interface. +// It prepares a query and returns the column names from the result set description. +func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + return nil, fmt.Errorf("database connection not initialized") + } + + stmt, _, err := a.conn.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + colCount := stmt.ColumnCount() + columns := make([]string, colCount) + for i := 0; i < colCount; i++ { + columns[i] = stmt.ColumnName(i) + } + + return columns, nil +} + +// IntrospectSchema queries the database to build a catalog containing +// tables and columns for the database. +func (a *Analyzer) IntrospectSchema(ctx context.Context, schemas []string) (*catalog.Catalog, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + return nil, fmt.Errorf("database connection not initialized") + } + + // Build catalog + cat := &catalog.Catalog{ + DefaultSchema: "main", + } + + // Create default schema + mainSchema := &catalog.Schema{Name: "main"} + cat.Schemas = append(cat.Schemas, mainSchema) + + // Query tables from sqlite_master + stmt, _, err := a.conn.Prepare("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") + if err != nil { + return nil, fmt.Errorf("introspect tables: %w", err) + } + + tableNames := []string{} + for stmt.Step() { + tableName := stmt.ColumnText(0) + tableNames = append(tableNames, tableName) + } + stmt.Close() + + // For each table, get column information using PRAGMA table_info + for _, tableName := range tableNames { + tbl := &catalog.Table{ + Rel: &ast.TableName{ + Name: tableName, + }, + } + + pragmaStmt, _, err := a.conn.Prepare(fmt.Sprintf("PRAGMA table_info('%s')", tableName)) + if err != nil { + return nil, fmt.Errorf("pragma table_info for %s: %w", tableName, err) + } + + for pragmaStmt.Step() { + // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk + colName := pragmaStmt.ColumnText(1) + colType := pragmaStmt.ColumnText(2) + notNull := pragmaStmt.ColumnInt(3) != 0 + + tbl.Columns = append(tbl.Columns, &catalog.Column{ + Name: colName, + Type: ast.TypeName{Name: normalizeType(colType)}, + IsNotNull: notNull, + }) + } + pragmaStmt.Close() + + mainSchema.Tables = append(mainSchema.Tables, tbl) + } + + return cat, nil +} + +// isInMemoryDatabase checks if a SQLite URI refers to an in-memory database +func isInMemoryDatabase(uri string) bool { + if uri == ":memory:" || uri == "" { + return true + } + // Check for file URI with mode=memory parameter + // e.g., "file:test?mode=memory&cache=shared" + if strings.Contains(uri, "mode=memory") { + return true + } + return false +} + +// normalizeType converts SQLite type declarations to standard type names +func normalizeType(declType string) string { + if declType == "" { + return "any" + } + + // Convert to lowercase for comparison + lower := strings.ToLower(declType) + + // SQLite type affinity rules (https://www.sqlite.org/datatype3.html) + switch { + case strings.Contains(lower, "int"): + return "integer" + case strings.Contains(lower, "char"), + strings.Contains(lower, "clob"), + strings.Contains(lower, "text"): + return "text" + case strings.Contains(lower, "blob"): + return "blob" + case strings.Contains(lower, "real"), + strings.Contains(lower, "floa"), + strings.Contains(lower, "doub"): + return "real" + case strings.Contains(lower, "bool"): + return "boolean" + case strings.Contains(lower, "date"), + strings.Contains(lower, "time"): + return "datetime" + default: + // Return as-is for numeric or other types + return lower + } +} diff --git a/internal/engine/sqlite/analyzer/analyze_test.go b/internal/engine/sqlite/analyzer/analyze_test.go new file mode 100644 index 0000000000..320b692597 --- /dev/null +++ b/internal/engine/sqlite/analyzer/analyze_test.go @@ -0,0 +1,119 @@ +package analyzer + +import ( + "context" + "testing" + + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestAnalyzer_Analyze(t *testing.T) { + db := config.Database{ + Managed: true, + } + a := New(db) + defer a.Close(context.Background()) + + ctx := context.Background() + + migrations := []string{ + `CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT + )`, + } + + query := `SELECT id, name, email FROM users WHERE id = ?` + node := &ast.TODO{} + + result, err := a.Analyze(ctx, node, query, migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 3 { + t.Errorf("Expected 3 columns, got %d", len(result.Columns)) + } + + expectedCols := []struct { + name string + dataType string + }{ + {"id", "integer"}, + {"name", "text"}, + {"email", "text"}, + } + + for i, expected := range expectedCols { + if i >= len(result.Columns) { + break + } + col := result.Columns[i] + if col.Name != expected.name { + t.Errorf("Column %d: expected name %q, got %q", i, expected.name, col.Name) + } + if col.DataType != expected.dataType { + t.Errorf("Column %d: expected dataType %q, got %q", i, expected.dataType, col.DataType) + } + if col.Table == nil || col.Table.Name != "users" { + t.Errorf("Column %d: expected table 'users', got %v", i, col.Table) + } + } + + if len(result.Params) != 1 { + t.Errorf("Expected 1 parameter, got %d", len(result.Params)) + } +} + +func TestAnalyzer_InvalidQuery(t *testing.T) { + db := config.Database{ + Managed: true, + } + a := New(db) + defer a.Close(context.Background()) + + ctx := context.Background() + + migrations := []string{ + `CREATE TABLE users (id INTEGER PRIMARY KEY)`, + } + + query := `SELECT * FROM nonexistent` + node := &ast.TODO{} + + _, err := a.Analyze(ctx, node, query, migrations, nil) + if err == nil { + t.Error("Expected error for invalid query, got nil") + } +} + +func TestNormalizeType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"INTEGER", "integer"}, + {"INT", "integer"}, + {"BIGINT", "integer"}, + {"TEXT", "text"}, + {"VARCHAR(255)", "text"}, + {"BLOB", "blob"}, + {"REAL", "real"}, + {"FLOAT", "real"}, + {"DOUBLE", "real"}, + {"BOOLEAN", "boolean"}, + {"DATETIME", "datetime"}, + {"", "any"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeType(tt.input) + if result != tt.expected { + t.Errorf("normalizeType(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 658a9d7f33..e9868f5be6 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -514,7 +514,10 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset - selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} + // Only set WithClause if there are CTEs + if len(ctes.Items) > 0 { + selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} + } return selectStmt } @@ -759,6 +762,13 @@ func (c *cc) convertLiteral(n *parser.Expr_literalContext) ast.Node { Location: n.GetStart().GetStart(), } } + + if literal.NULL_() != nil { + return &ast.A_Const{ + Val: &ast.Null{}, + Location: n.GetStart().GetStart(), + } + } } return todo("convertLiteral", n) } @@ -776,8 +786,14 @@ func (c *cc) convertBinaryNode(n *parser.Expr_binaryContext) ast.Node { } func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node { + var op ast.BoolExprType + if n.AND_() != nil { + op = ast.BoolExprTypeAnd + } else if n.OR_() != nil { + op = ast.BoolExprTypeOr + } return &ast.BoolExpr{ - // TODO: Set op + Boolop: op, Args: &ast.List{ Items: []ast.Node{ c.convert(n.Expr(0)), @@ -787,6 +803,49 @@ func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node { } } +func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { + op := n.Unary_operator() + if op == nil { + return c.convert(n.Expr()) + } + + // Get the inner expression + expr := c.convert(n.Expr()) + + // Check the operator type + if opCtx, ok := op.(*parser.Unary_operatorContext); ok { + if opCtx.NOT_() != nil { + // NOT expression + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{expr}, + }, + } + } + if opCtx.MINUS() != nil { + // Negative number: -expr + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Rexpr: expr, + } + } + if opCtx.PLUS() != nil { + // Positive number: +expr (just return expr) + return expr + } + if opCtx.TILDE() != nil { + // Bitwise NOT: ~expr + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Rexpr: expr, + } + } + } + + return expr +} + func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node { if n.NUMBERED_BIND_PARAMETER() != nil { // Parameter numbers start at one @@ -816,7 +875,52 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node { } func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node { - return c.convert(n.Select_stmt()) + // Check if this is EXISTS or NOT EXISTS + if n.EXISTS_() != nil { + linkType := ast.EXISTS_SUBLINK + sublink := &ast.SubLink{ + SubLinkType: linkType, + Subselect: c.convert(n.Select_stmt()), + } + if n.NOT_() != nil { + // NOT EXISTS is represented as a BoolExpr NOT wrapping the EXISTS + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{sublink}, + }, + } + } + return sublink + } + + // Check if this is an IN/NOT IN expression: expr IN (SELECT ...) + if n.IN_() != nil && len(n.AllExpr()) > 0 { + linkType := ast.ANY_SUBLINK + sublink := &ast.SubLink{ + SubLinkType: linkType, + Testexpr: c.convert(n.Expr(0)), + Subselect: c.convert(n.Select_stmt()), + } + if n.NOT_() != nil { + return &ast.A_Expr{ + Kind: ast.A_Expr_Kind_OP, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "NOT IN"}}}, + Lexpr: c.convert(n.Expr(0)), + Rexpr: &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Select_stmt()), + }, + } + } + return sublink + } + + // Plain subquery in parentheses (SELECT ...) + return &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Select_stmt()), + } } func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *ast.List { @@ -887,12 +991,8 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node { } if hasDefaultValues { - // For DEFAULT VALUES, create an empty select statement - insert.SelectStmt = &ast.SelectStmt{ - FromClause: &ast.List{}, - TargetList: &ast.List{}, - ValuesLists: &ast.List{Items: []ast.Node{&ast.List{}}}, // Single empty values list - } + // For DEFAULT VALUES, set the flag instead of creating an empty values list + insert.DefaultValues = true } else if n.Select_stmt() != nil { if ss, ok := c.convert(n.Select_stmt()).(*ast.SelectStmt); ok { ss.ValuesLists = &ast.List{} @@ -976,6 +1076,11 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast tables = append(tables, rv) } else if from.Table_function_name() != nil { rel := from.Table_function_name().GetText() + // Convert function arguments + var args []ast.Node + for _, expr := range from.AllExpr() { + args = append(args, c.convert(expr)) + } rf := &ast.RangeFunction{ Functions: &ast.List{ Items: []ast.Node{ @@ -989,7 +1094,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast }, }, Args: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, + Items: args, }, Location: from.GetStart().GetStart(), }, @@ -1189,6 +1294,9 @@ func (c *cc) convert(node node) ast.Node { case *parser.Expr_binaryContext: return c.convertBinaryNode(n) + case *parser.Expr_unaryContext: + return c.convertUnaryExpr(n) + case *parser.Expr_in_selectContext: return c.convertInSelectNode(n) diff --git a/internal/engine/sqlite/format.go b/internal/engine/sqlite/format.go new file mode 100644 index 0000000000..39ac859ca5 --- /dev/null +++ b/internal/engine/sqlite/format.go @@ -0,0 +1,35 @@ +package sqlite + +// QuoteIdent returns a quoted identifier if it needs quoting. +// SQLite uses double quotes for quoting identifiers (SQL standard), +// though backticks are also supported for MySQL compatibility. +func (p *Parser) QuoteIdent(s string) string { + // For now, don't quote - return as-is + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +func (p *Parser) TypeName(ns, name string) string { + if ns != "" { + return ns + "." + name + } + return name +} + +// Param returns the parameter placeholder for the given number. +// SQLite uses ? for positional parameters. +func (p *Parser) Param(n int) string { + return "?" +} + +// NamedParam returns the named parameter placeholder for the given name. +// SQLite uses :name syntax for named parameters. +func (p *Parser) NamedParam(name string) string { + return ":" + name +} + +// Cast returns a type cast expression. +// SQLite uses CAST(expr AS type) syntax. +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} diff --git a/internal/opts/experiment.go b/internal/opts/experiment.go new file mode 100644 index 0000000000..45a1c11e05 --- /dev/null +++ b/internal/opts/experiment.go @@ -0,0 +1,105 @@ +package opts + +import ( + "os" + "strings" +) + +// The SQLCEXPERIMENT variable controls experimental features within sqlc. It +// is a comma-separated list of experiment names. Experiment names can be +// prefixed with "no" to explicitly disable them. +// +// This is modeled after Go's GOEXPERIMENT environment variable. For more +// information, see https://pkg.go.dev/internal/goexperiment +// +// Available experiments: +// +// analyzerv2 - enables database-only analyzer mode +// +// Example usage: +// +// SQLCEXPERIMENT=foo,bar # enable foo and bar experiments +// SQLCEXPERIMENT=nofoo # explicitly disable foo experiment +// SQLCEXPERIMENT=foo,nobar # enable foo, disable bar + +// Experiment holds the state of all experimental features. +// Add new experiments as boolean fields to this struct. +type Experiment struct { + // AnalyzerV2 enables the database-only analyzer mode (analyzer.database: only) + // which uses the database for all type resolution instead of parsing schema files. + AnalyzerV2 bool +} + +// ExperimentFromEnv returns an Experiment initialized from the SQLCEXPERIMENT +// environment variable. +func ExperimentFromEnv() Experiment { + return ExperimentFromString(os.Getenv("SQLCEXPERIMENT")) +} + +// ExperimentFromString parses a comma-separated list of experiment names +// and returns an Experiment with the appropriate flags set. +// +// Experiment names can be prefixed with "no" to explicitly disable them. +// Unknown experiment names are silently ignored. +func ExperimentFromString(val string) Experiment { + e := Experiment{} + if val == "" { + return e + } + + for _, name := range strings.Split(val, ",") { + name = strings.TrimSpace(name) + if name == "" { + continue + } + + // Check if this is a negation (noFoo) + enabled := true + if strings.HasPrefix(strings.ToLower(name), "no") && len(name) > 2 { + // Could be a negation, check if the rest is a valid experiment + possibleExp := name[2:] + if isKnownExperiment(possibleExp) { + name = possibleExp + enabled = false + } + // If not a known experiment, treat "no..." as a potential experiment name itself + } + + setExperiment(&e, name, enabled) + } + + return e +} + +// isKnownExperiment returns true if the given name (case-insensitive) is a +// known experiment. +func isKnownExperiment(name string) bool { + switch strings.ToLower(name) { + case "analyzerv2": + return true + default: + return false + } +} + +// setExperiment sets the experiment flag with the given name to the given value. +func setExperiment(e *Experiment, name string, enabled bool) { + switch strings.ToLower(name) { + case "analyzerv2": + e.AnalyzerV2 = enabled + } +} + +// Enabled returns a slice of all enabled experiment names. +func (e Experiment) Enabled() []string { + var enabled []string + if e.AnalyzerV2 { + enabled = append(enabled, "analyzerv2") + } + return enabled +} + +// String returns a comma-separated list of enabled experiments. +func (e Experiment) String() string { + return strings.Join(e.Enabled(), ",") +} diff --git a/internal/opts/experiment_test.go b/internal/opts/experiment_test.go new file mode 100644 index 0000000000..e9a8618e89 --- /dev/null +++ b/internal/opts/experiment_test.go @@ -0,0 +1,176 @@ +package opts + +import "testing" + +func TestExperimentFromString(t *testing.T) { + tests := []struct { + name string + input string + want Experiment + }{ + { + name: "empty string", + input: "", + want: Experiment{}, + }, + { + name: "whitespace only", + input: " ", + want: Experiment{}, + }, + { + name: "unknown experiment", + input: "unknownexperiment", + want: Experiment{}, + }, + { + name: "multiple unknown experiments", + input: "foo,bar,baz", + want: Experiment{}, + }, + { + name: "unknown with no prefix", + input: "nounknown", + want: Experiment{}, + }, + { + name: "whitespace around experiments", + input: " foo , bar , baz ", + want: Experiment{}, + }, + { + name: "empty items in list", + input: "foo,,bar", + want: Experiment{}, + }, + { + name: "enable analyzerv2", + input: "analyzerv2", + want: Experiment{AnalyzerV2: true}, + }, + { + name: "disable analyzerv2", + input: "noanalyzerv2", + want: Experiment{AnalyzerV2: false}, + }, + { + name: "enable then disable analyzerv2", + input: "analyzerv2,noanalyzerv2", + want: Experiment{AnalyzerV2: false}, + }, + { + name: "analyzerv2 case insensitive", + input: "AnalyzerV2", + want: Experiment{AnalyzerV2: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExperimentFromString(tt.input) + if got != tt.want { + t.Errorf("ExperimentFromString(%q) = %+v, want %+v", tt.input, got, tt.want) + } + }) + } +} + +func TestExperimentEnabled(t *testing.T) { + tests := []struct { + name string + exp Experiment + want []string + }{ + { + name: "no experiments enabled", + exp: Experiment{}, + want: nil, + }, + { + name: "analyzerv2 enabled", + exp: Experiment{AnalyzerV2: true}, + want: []string{"analyzerv2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.exp.Enabled() + if len(got) != len(tt.want) { + t.Errorf("Experiment.Enabled() = %v, want %v", got, tt.want) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("Experiment.Enabled()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestExperimentString(t *testing.T) { + tests := []struct { + name string + exp Experiment + want string + }{ + { + name: "no experiments", + exp: Experiment{}, + want: "", + }, + { + name: "analyzerv2 enabled", + exp: Experiment{AnalyzerV2: true}, + want: "analyzerv2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.exp.String() + if got != tt.want { + t.Errorf("Experiment.String() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestIsKnownExperiment(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "unknown experiment", + input: "unknown", + want: false, + }, + { + name: "empty string", + input: "", + want: false, + }, + { + name: "analyzerv2 lowercase", + input: "analyzerv2", + want: true, + }, + { + name: "analyzerv2 mixed case", + input: "AnalyzerV2", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isKnownExperiment(tt.input) + if got != tt.want { + t.Errorf("isKnownExperiment(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/opts/parser.go b/internal/opts/parser.go index d6fb399552..2059d4f6a1 100644 --- a/internal/opts/parser.go +++ b/internal/opts/parser.go @@ -1,5 +1,6 @@ package opts type Parser struct { - Debug Debug + Debug Debug + Experiment Experiment } diff --git a/internal/sql/ast/CLAUDE.md b/internal/sql/ast/CLAUDE.md new file mode 100644 index 0000000000..e769fbfca6 --- /dev/null +++ b/internal/sql/ast/CLAUDE.md @@ -0,0 +1,116 @@ +# AST Package - Claude Code Guide + +This package defines the Abstract Syntax Tree (AST) nodes used by sqlc to represent SQL statements across all supported databases (PostgreSQL, MySQL, SQLite). + +## Key Concepts + +### Node Interface +All AST nodes implement the `Node` interface with: +- `Pos() int` - returns the source position +- `Format(buf *TrackedBuffer)` - formats the node back to SQL + +### TrackedBuffer +The `TrackedBuffer` type (`pg_query.go`) handles SQL formatting with dialect-specific behavior: +- `astFormat(node Node)` - formats any AST node +- `join(list *List, sep string)` - joins list items with separator +- `WriteString(s string)` - writes raw SQL +- `QuoteIdent(name string)` - quotes identifiers (dialect-specific) +- `TypeName(ns, name string)` - formats type names (dialect-specific) + +### Dialect Interface +Dialect-specific formatting is handled via the `Dialect` interface: +```go +type Dialect interface { + QuoteIdent(string) string + TypeName(ns, name string) string + Param(int) string // $1 for PostgreSQL, ? for MySQL + NamedParam(string) string // @name for PostgreSQL, :name for SQLite + Cast(string) string +} +``` + +## Adding New AST Nodes + +When adding a new AST node type: + +1. **Create the node file** (e.g., `variable_expr.go`): +```go +package ast + +type VariableExpr struct { + Name string + Location int +} + +func (n *VariableExpr) Pos() int { + return n.Location +} + +func (n *VariableExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.WriteString("@") + buf.WriteString(n.Name) +} +``` + +2. **Add to `astutils/walk.go`** - Add a case in the Walk function: +```go +case *ast.VariableExpr: + // Leaf node - no children to traverse +``` + +3. **Add to `astutils/rewrite.go`** - Add a case in the Apply function: +```go +case *ast.VariableExpr: + // Leaf node - no children to traverse +``` + +4. **Update the parser/converter** - In the relevant engine (e.g., `dolphin/convert.go` for MySQL) + +## Helper Functions for Format Methods + +- `set(node Node) bool` - returns true if node is non-nil and not an empty List +- `items(list *List) bool` - returns true if list has items +- `todo(node) Node` - placeholder for unimplemented conversions (returns nil) + +## Common Node Types + +### Statements +- `SelectStmt` - SELECT queries with FromClause, WhereClause, etc. +- `InsertStmt` - INSERT with Relation, Cols, SelectStmt, OnConflictClause +- `UpdateStmt` - UPDATE with Relations, TargetList, WhereClause +- `DeleteStmt` - DELETE with Relations, FromClause (for JOINs), Targets + +### Expressions +- `A_Expr` - General expression with operator (e.g., `a + b`, `@param`) +- `ColumnRef` - Column reference with Fields list +- `FuncCall` - Function call with Func, Args, aggregation options +- `TypeCast` - Type cast with Arg and TypeName +- `ParenExpr` - Parenthesized expression +- `VariableExpr` - MySQL user variable (e.g., `@user_id`) + +### Table References +- `RangeVar` - Table reference with schema, name, alias +- `JoinExpr` - JOIN with Larg, Rarg, Jointype, Quals/UsingClause + +## MySQL-Specific Nodes + +- `VariableExpr` - User variables (`@var`), distinct from sqlc's `@param` syntax +- `IntervalExpr` - INTERVAL expressions +- `OnDuplicateKeyUpdate` - MySQL's ON DUPLICATE KEY UPDATE clause +- `ParenExpr` - Explicit parentheses (TiDB parser wraps expressions) + +## Important Distinctions + +### MySQL @variable vs sqlc @param +- MySQL user variables (`@user_id`) use `VariableExpr` - preserved as-is in output +- sqlc named parameters (`@param`) use `A_Expr` with `@` operator - replaced with `?` +- The `named.IsParamSign()` function checks for `A_Expr` with `@` operator + +### Type Modifiers +- `TypeName.Typmods` holds type modifiers like `varchar(255)` +- For MySQL, only populate Typmods for types where length is user-specified: + - VARCHAR, CHAR, VARBINARY, BINARY - need length + - DATETIME, TIMESTAMP, DATE - internal flen should NOT be output diff --git a/internal/sql/ast/a_array_expr.go b/internal/sql/ast/a_array_expr.go index dafa0e8e85..0437dac84f 100644 --- a/internal/sql/ast/a_array_expr.go +++ b/internal/sql/ast/a_array_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_ArrayExpr struct { Elements *List Location int @@ -8,3 +10,12 @@ type A_ArrayExpr struct { func (n *A_ArrayExpr) Pos() int { return n.Location } + +func (n *A_ArrayExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("ARRAY[") + buf.join(n.Elements, d, ", ") + buf.WriteString("]") +} diff --git a/internal/sql/ast/a_const.go b/internal/sql/ast/a_const.go index ec1d780945..a6b610e349 100644 --- a/internal/sql/ast/a_const.go +++ b/internal/sql/ast/a_const.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Const struct { Val Node Location int @@ -9,15 +11,15 @@ func (n *A_Const) Pos() int { return n.Location } -func (n *A_Const) Format(buf *TrackedBuffer) { +func (n *A_Const) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if _, ok := n.Val.(*String); ok { buf.WriteString("'") - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) buf.WriteString("'") } else { - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) } } diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index b0b7f75367..4e67967baa 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Expr struct { Kind A_Expr_Kind Name *List @@ -12,23 +14,94 @@ func (n *A_Expr) Pos() int { return n.Location } -func (n *A_Expr) Format(buf *TrackedBuffer) { +// isNamedParam returns true if this A_Expr represents a named parameter (@name) +// and extracts the parameter name if so. +func (n *A_Expr) isNamedParam() (string, bool) { + if n.Name == nil || len(n.Name.Items) != 1 { + return "", false + } + s, ok := n.Name.Items[0].(*String) + if !ok || s.Str != "@" { + return "", false + } + if set(n.Lexpr) || !set(n.Rexpr) { + return "", false + } + if nameStr, ok := n.Rexpr.(*String); ok { + return nameStr.Str, true + } + return "", false +} + +func (n *A_Expr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Lexpr) - buf.WriteString(" ") + + // Check for named parameter first (works regardless of Kind) + if name, ok := n.isNamedParam(); ok { + buf.WriteString(d.NamedParam(name)) + return + } + switch n.Kind { case A_Expr_Kind_IN: + buf.astFormat(n.Lexpr, d) buf.WriteString(" IN (") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) buf.WriteString(")") case A_Expr_Kind_LIKE: + buf.astFormat(n.Lexpr, d) buf.WriteString(" LIKE ") - buf.astFormat(n.Rexpr) + buf.astFormat(n.Rexpr, d) + case A_Expr_Kind_ILIKE: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" ILIKE ") + buf.astFormat(n.Rexpr, d) + case A_Expr_Kind_SIMILAR: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" SIMILAR TO ") + buf.astFormat(n.Rexpr, d) + case A_Expr_Kind_BETWEEN: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0], d) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1], d) + } + case A_Expr_Kind_NOT_BETWEEN: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" NOT BETWEEN ") + if l, ok := n.Rexpr.(*List); ok && len(l.Items) == 2 { + buf.astFormat(l.Items[0], d) + buf.WriteString(" AND ") + buf.astFormat(l.Items[1], d) + } + case A_Expr_Kind_DISTINCT: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" IS DISTINCT FROM ") + buf.astFormat(n.Rexpr, d) + case A_Expr_Kind_NOT_DISTINCT: + buf.astFormat(n.Lexpr, d) + buf.WriteString(" IS NOT DISTINCT FROM ") + buf.astFormat(n.Rexpr, d) + case A_Expr_Kind_NULLIF: + buf.WriteString("NULLIF(") + buf.astFormat(n.Lexpr, d) + buf.WriteString(", ") + buf.astFormat(n.Rexpr, d) + buf.WriteString(")") default: - buf.astFormat(n.Name) - buf.WriteString(" ") - buf.astFormat(n.Rexpr) + // Standard operator (including A_Expr_Kind_OP) + if set(n.Lexpr) { + buf.astFormat(n.Lexpr, d) + buf.WriteString(" ") + } + buf.astFormat(n.Name, d) + if set(n.Rexpr) { + buf.WriteString(" ") + buf.astFormat(n.Rexpr, d) + } } } diff --git a/internal/sql/ast/a_expr_kind.go b/internal/sql/ast/a_expr_kind.go index 53a237896b..3adc9232cf 100644 --- a/internal/sql/ast/a_expr_kind.go +++ b/internal/sql/ast/a_expr_kind.go @@ -3,8 +3,20 @@ package ast type A_Expr_Kind uint const ( - A_Expr_Kind_IN A_Expr_Kind = 7 - A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_OP A_Expr_Kind = 1 + A_Expr_Kind_OP_ANY A_Expr_Kind = 2 + A_Expr_Kind_OP_ALL A_Expr_Kind = 3 + A_Expr_Kind_DISTINCT A_Expr_Kind = 4 + A_Expr_Kind_NOT_DISTINCT A_Expr_Kind = 5 + A_Expr_Kind_NULLIF A_Expr_Kind = 6 + A_Expr_Kind_IN A_Expr_Kind = 7 + A_Expr_Kind_LIKE A_Expr_Kind = 8 + A_Expr_Kind_ILIKE A_Expr_Kind = 9 + A_Expr_Kind_SIMILAR A_Expr_Kind = 10 + A_Expr_Kind_BETWEEN A_Expr_Kind = 11 + A_Expr_Kind_NOT_BETWEEN A_Expr_Kind = 12 + A_Expr_Kind_BETWEEN_SYM A_Expr_Kind = 13 + A_Expr_Kind_NOT_BETWEEN_SYM A_Expr_Kind = 14 ) func (n *A_Expr_Kind) Pos() int { diff --git a/internal/sql/ast/a_indices.go b/internal/sql/ast/a_indices.go index 8972f3a556..7180f220e7 100644 --- a/internal/sql/ast/a_indices.go +++ b/internal/sql/ast/a_indices.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Indices struct { IsSlice bool Lidx Node @@ -9,3 +11,22 @@ type A_Indices struct { func (n *A_Indices) Pos() int { return 0 } + +func (n *A_Indices) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("[") + if n.IsSlice { + if set(n.Lidx) { + buf.astFormat(n.Lidx, d) + } + buf.WriteString(":") + if set(n.Uidx) { + buf.astFormat(n.Uidx, d) + } + } else { + buf.astFormat(n.Uidx, d) + } + buf.WriteString("]") +} diff --git a/internal/sql/ast/a_star.go b/internal/sql/ast/a_star.go index a43b2ab5b7..7e5f07b96a 100644 --- a/internal/sql/ast/a_star.go +++ b/internal/sql/ast/a_star.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type A_Star struct { } @@ -7,7 +9,7 @@ func (n *A_Star) Pos() int { return 0 } -func (n *A_Star) Format(buf *TrackedBuffer) { +func (n *A_Star) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/alias.go b/internal/sql/ast/alias.go index 55965b55c9..7123982305 100644 --- a/internal/sql/ast/alias.go +++ b/internal/sql/ast/alias.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Alias struct { Aliasname *string Colnames *List @@ -9,7 +11,7 @@ func (n *Alias) Pos() int { return 0 } -func (n *Alias) Format(buf *TrackedBuffer) { +func (n *Alias) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -18,7 +20,7 @@ func (n *Alias) Format(buf *TrackedBuffer) { } if items(n.Colnames) { buf.WriteString("(") - buf.astFormat((n.Colnames)) + buf.astFormat(n.Colnames, d) buf.WriteString(")") } } diff --git a/internal/sql/ast/alter_table_cmd.go b/internal/sql/ast/alter_table_cmd.go index 80fad95eaf..90ffd891eb 100644 --- a/internal/sql/ast/alter_table_cmd.go +++ b/internal/sql/ast/alter_table_cmd.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + const ( AT_AddColumn AlterTableType = iota AT_AlterColumnType @@ -40,7 +42,7 @@ func (n *AlterTableCmd) Pos() int { return 0 } -func (n *AlterTableCmd) Format(buf *TrackedBuffer) { +func (n *AlterTableCmd) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -51,5 +53,5 @@ func (n *AlterTableCmd) Format(buf *TrackedBuffer) { buf.WriteString(" DROP COLUMN ") } - buf.astFormat(n.Def) + buf.astFormat(n.Def, d) } diff --git a/internal/sql/ast/alter_table_stmt.go b/internal/sql/ast/alter_table_stmt.go index 5d4a22f50e..4dc88707ff 100644 --- a/internal/sql/ast/alter_table_stmt.go +++ b/internal/sql/ast/alter_table_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type AlterTableStmt struct { // TODO: Only TableName or Relation should be defined Relation *RangeVar @@ -13,12 +15,12 @@ func (n *AlterTableStmt) Pos() int { return 0 } -func (n *AlterTableStmt) Format(buf *TrackedBuffer) { +func (n *AlterTableStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("ALTER TABLE ") - buf.astFormat(n.Relation) - buf.astFormat(n.Table) - buf.astFormat(n.Cmds) + buf.astFormat(n.Relation, d) + buf.astFormat(n.Table, d) + buf.astFormat(n.Cmds, d) } diff --git a/internal/sql/ast/between_expr.go b/internal/sql/ast/between_expr.go index 0811caee31..a160f1892c 100644 --- a/internal/sql/ast/between_expr.go +++ b/internal/sql/ast/between_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type BetweenExpr struct { // Expr is the value expression to be compared. Expr Node @@ -15,3 +17,18 @@ type BetweenExpr struct { func (n *BetweenExpr) Pos() int { return n.Location } + +func (n *BetweenExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.astFormat(n.Expr, d) + if n.Not { + buf.WriteString(" NOT BETWEEN ") + } else { + buf.WriteString(" BETWEEN ") + } + buf.astFormat(n.Left, d) + buf.WriteString(" AND ") + buf.astFormat(n.Right, d) +} diff --git a/internal/sql/ast/bool_expr.go b/internal/sql/ast/bool_expr.go index 6d15276a05..f2c0243a9c 100644 --- a/internal/sql/ast/bool_expr.go +++ b/internal/sql/ast/bool_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type BoolExpr struct { Xpr Node Boolop BoolExprType @@ -11,21 +13,37 @@ func (n *BoolExpr) Pos() int { return n.Location } -func (n *BoolExpr) Format(buf *TrackedBuffer) { +func (n *BoolExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.WriteString("(") - if items(n.Args) { - switch n.Boolop { - case BoolExprTypeAnd: - buf.join(n.Args, " AND ") - case BoolExprTypeOr: - buf.join(n.Args, " OR ") - case BoolExprTypeNot: - buf.WriteString(" NOT ") - buf.astFormat(n.Args) + switch n.Boolop { + case BoolExprTypeIsNull: + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0], d) + } + buf.WriteString(" IS NULL") + case BoolExprTypeIsNotNull: + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0], d) + } + buf.WriteString(" IS NOT NULL") + case BoolExprTypeNot: + // NOT expression: format as NOT + buf.WriteString("NOT ") + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0], d) + } + default: + buf.WriteString("(") + if items(n.Args) { + switch n.Boolop { + case BoolExprTypeAnd: + buf.join(n.Args, d, " AND ") + case BoolExprTypeOr: + buf.join(n.Args, d, " OR ") + } } + buf.WriteString(")") } - buf.WriteString(")") } diff --git a/internal/sql/ast/boolean.go b/internal/sql/ast/boolean.go index 522af84868..16a6db54da 100644 --- a/internal/sql/ast/boolean.go +++ b/internal/sql/ast/boolean.go @@ -1,6 +1,10 @@ package ast -import "fmt" +import ( + "fmt" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type Boolean struct { Boolval bool @@ -10,7 +14,7 @@ func (n *Boolean) Pos() int { return 0 } -func (n *Boolean) Format(buf *TrackedBuffer) { +func (n *Boolean) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/call_stmt.go b/internal/sql/ast/call_stmt.go index 5267a1ff3f..6cba39986e 100644 --- a/internal/sql/ast/call_stmt.go +++ b/internal/sql/ast/call_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CallStmt struct { FuncCall *FuncCall } @@ -11,7 +13,7 @@ func (n *CallStmt) Pos() int { return n.FuncCall.Pos() } -func (n *CallStmt) Format(buf *TrackedBuffer) { +func (n *CallStmt) Format(buf *TrackedBuffer, d format.Dialect) { buf.WriteString("CALL ") - buf.astFormat(n.FuncCall) + buf.astFormat(n.FuncCall, d) } diff --git a/internal/sql/ast/case_expr.go b/internal/sql/ast/case_expr.go index 1da54f0d78..52692b297b 100644 --- a/internal/sql/ast/case_expr.go +++ b/internal/sql/ast/case_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CaseExpr struct { Xpr Node Casetype Oid @@ -14,13 +16,19 @@ func (n *CaseExpr) Pos() int { return n.Location } -func (n *CaseExpr) Format(buf *TrackedBuffer) { +func (n *CaseExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("CASE ") - buf.astFormat(n.Args) - buf.WriteString(" ELSE ") - buf.astFormat(n.Defresult) - buf.WriteString(" END ") + if set(n.Arg) { + buf.astFormat(n.Arg, d) + buf.WriteString(" ") + } + buf.join(n.Args, d, " ") + if set(n.Defresult) { + buf.WriteString(" ELSE ") + buf.astFormat(n.Defresult, d) + } + buf.WriteString(" END") } diff --git a/internal/sql/ast/case_when.go b/internal/sql/ast/case_when.go index b036411d54..9636d24a97 100644 --- a/internal/sql/ast/case_when.go +++ b/internal/sql/ast/case_when.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CaseWhen struct { Xpr Node Expr Node @@ -11,12 +13,12 @@ func (n *CaseWhen) Pos() int { return n.Location } -func (n *CaseWhen) Format(buf *TrackedBuffer) { +func (n *CaseWhen) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("WHEN ") - buf.astFormat(n.Expr) + buf.astFormat(n.Expr, d) buf.WriteString(" THEN ") - buf.astFormat(n.Result) + buf.astFormat(n.Result, d) } diff --git a/internal/sql/ast/coalesce_expr.go b/internal/sql/ast/coalesce_expr.go index cbf7025748..0faee5bf4c 100644 --- a/internal/sql/ast/coalesce_expr.go +++ b/internal/sql/ast/coalesce_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CoalesceExpr struct { Xpr Node Coalescetype Oid @@ -12,11 +14,11 @@ func (n *CoalesceExpr) Pos() int { return n.Location } -func (n *CoalesceExpr) Format(buf *TrackedBuffer) { +func (n *CoalesceExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("COALESCE(") - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) buf.WriteString(")") } diff --git a/internal/sql/ast/collate_expr.go b/internal/sql/ast/collate_expr.go index 6c32eece77..80483f75ce 100644 --- a/internal/sql/ast/collate_expr.go +++ b/internal/sql/ast/collate_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CollateExpr struct { Xpr Node Arg Node @@ -10,3 +12,12 @@ type CollateExpr struct { func (n *CollateExpr) Pos() int { return n.Location } + +func (n *CollateExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.astFormat(n.Xpr, d) + buf.WriteString(" COLLATE ") + buf.astFormat(n.Arg, d) +} diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index f9504eefc7..225cdd4779 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ColumnDef struct { Colname string TypeName *TypeName @@ -32,17 +34,22 @@ func (n *ColumnDef) Pos() int { return n.Location } -func (n *ColumnDef) Format(buf *TrackedBuffer) { +func (n *ColumnDef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString(n.Colname) buf.WriteString(" ") - buf.astFormat(n.TypeName) + buf.astFormat(n.TypeName, d) + // Use IsArray from ColumnDef since TypeName.ArrayBounds may not be set + // (for type resolution compatibility) + if n.IsArray && !items(n.TypeName.ArrayBounds) { + buf.WriteString("[]") + } if n.PrimaryKey { buf.WriteString(" PRIMARY KEY") } else if n.IsNotNull { buf.WriteString(" NOT NULL") } - buf.astFormat(n.Constraints) + buf.astFormat(n.Constraints, d) } diff --git a/internal/sql/ast/column_ref.go b/internal/sql/ast/column_ref.go index e95b844896..943311799d 100644 --- a/internal/sql/ast/column_ref.go +++ b/internal/sql/ast/column_ref.go @@ -1,6 +1,10 @@ package ast -import "strings" +import ( + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type ColumnRef struct { Name string @@ -14,7 +18,7 @@ func (n *ColumnRef) Pos() int { return n.Location } -func (n *ColumnRef) Format(buf *TrackedBuffer) { +func (n *ColumnRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -24,11 +28,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) { for _, item := range n.Fields.Items { switch nn := item.(type) { case *String: - if nn.Str == "user" { - items = append(items, `"user"`) - } else { - items = append(items, nn.Str) - } + items = append(items, d.QuoteIdent(nn.Str)) case *A_Star: items = append(items, "*") } diff --git a/internal/sql/ast/common_table_expr.go b/internal/sql/ast/common_table_expr.go index f2edddff79..aa334167ce 100644 --- a/internal/sql/ast/common_table_expr.go +++ b/internal/sql/ast/common_table_expr.go @@ -1,8 +1,6 @@ package ast -import ( - "fmt" -) +import "github.com/sqlc-dev/sqlc/internal/sql/format" type CommonTableExpr struct { Ctename *string @@ -21,13 +19,19 @@ func (n *CommonTableExpr) Pos() int { return n.Location } -func (n *CommonTableExpr) Format(buf *TrackedBuffer) { +func (n *CommonTableExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.Ctename != nil { - fmt.Fprintf(buf, " %s AS (", *n.Ctename) + buf.WriteString(*n.Ctename) } - buf.astFormat(n.Ctequery) + if items(n.Aliascolnames) { + buf.WriteString("(") + buf.join(n.Aliascolnames, d, ", ") + buf.WriteString(")") + } + buf.WriteString(" AS (") + buf.astFormat(n.Ctequery, d) buf.WriteString(")") } diff --git a/internal/sql/ast/create_extension_stmt.go b/internal/sql/ast/create_extension_stmt.go index 2fe8755b6a..140a10da4c 100644 --- a/internal/sql/ast/create_extension_stmt.go +++ b/internal/sql/ast/create_extension_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateExtensionStmt struct { Extname *string IfNotExists bool @@ -9,3 +11,16 @@ type CreateExtensionStmt struct { func (n *CreateExtensionStmt) Pos() int { return 0 } + +func (n *CreateExtensionStmt) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("CREATE EXTENSION ") + if n.IfNotExists { + buf.WriteString("IF NOT EXISTS ") + } + if n.Extname != nil { + buf.WriteString(*n.Extname) + } +} diff --git a/internal/sql/ast/create_function_stmt.go b/internal/sql/ast/create_function_stmt.go index 86605344f7..f5200085ee 100644 --- a/internal/sql/ast/create_function_stmt.go +++ b/internal/sql/ast/create_function_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateFunctionStmt struct { Replace bool Params *List @@ -13,3 +15,31 @@ type CreateFunctionStmt struct { func (n *CreateFunctionStmt) Pos() int { return 0 } + +func (n *CreateFunctionStmt) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("CREATE ") + if n.Replace { + buf.WriteString("OR REPLACE ") + } + buf.WriteString("FUNCTION ") + buf.astFormat(n.Func, d) + buf.WriteString("(") + if items(n.Params) { + buf.join(n.Params, d, ", ") + } + buf.WriteString(")") + if n.ReturnType != nil { + buf.WriteString(" RETURNS ") + buf.astFormat(n.ReturnType, d) + } + // Format options (AS, LANGUAGE, etc.) + if items(n.Options) { + for _, opt := range n.Options.Items { + buf.WriteString(" ") + buf.astFormat(opt, d) + } + } +} diff --git a/internal/sql/ast/create_table_stmt.go b/internal/sql/ast/create_table_stmt.go index ce88a1b244..f7ab2f9f60 100644 --- a/internal/sql/ast/create_table_stmt.go +++ b/internal/sql/ast/create_table_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type CreateTableStmt struct { IfNotExists bool Name *TableName @@ -13,19 +15,19 @@ func (n *CreateTableStmt) Pos() int { return 0 } -func (n *CreateTableStmt) Format(buf *TrackedBuffer) { +func (n *CreateTableStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("CREATE TABLE ") - buf.astFormat(n.Name) + buf.astFormat(n.Name, d) buf.WriteString("(") for i, col := range n.Cols { if i > 0 { buf.WriteString(", ") } - buf.astFormat(col) + buf.astFormat(col, d) } buf.WriteString(")") } diff --git a/internal/sql/ast/def_elem.go b/internal/sql/ast/def_elem.go index 03ecf88e77..33aacaaa03 100644 --- a/internal/sql/ast/def_elem.go +++ b/internal/sql/ast/def_elem.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DefElem struct { Defnamespace *string Defname *string @@ -11,3 +13,56 @@ type DefElem struct { func (n *DefElem) Pos() int { return n.Location } + +func (n *DefElem) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + if n.Defname != nil { + switch *n.Defname { + case "as": + buf.WriteString("AS ") + // AS clause contains function body which needs quoting + if l, ok := n.Arg.(*List); ok { + for i, item := range l.Items { + if i > 0 { + buf.WriteString(", ") + } + if s, ok := item.(*String); ok { + buf.WriteString("'") + buf.WriteString(s.Str) + buf.WriteString("'") + } else { + buf.astFormat(item, d) + } + } + } else { + buf.astFormat(n.Arg, d) + } + case "language": + buf.WriteString("LANGUAGE ") + buf.astFormat(n.Arg, d) + case "volatility": + // VOLATILE, STABLE, IMMUTABLE + buf.astFormat(n.Arg, d) + case "strict": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("STRICT") + } else { + buf.WriteString("CALLED ON NULL INPUT") + } + case "security": + if s, ok := n.Arg.(*Boolean); ok && s.Boolval { + buf.WriteString("SECURITY DEFINER") + } else { + buf.WriteString("SECURITY INVOKER") + } + default: + buf.WriteString(*n.Defname) + if n.Arg != nil { + buf.WriteString(" ") + buf.astFormat(n.Arg, d) + } + } + } +} diff --git a/internal/sql/ast/delete_stmt.go b/internal/sql/ast/delete_stmt.go index d77f043a12..d23617881a 100644 --- a/internal/sql/ast/delete_stmt.go +++ b/internal/sql/ast/delete_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DeleteStmt struct { Relations *List UsingClause *List @@ -7,39 +9,60 @@ type DeleteStmt struct { LimitCount Node ReturningList *List WithClause *WithClause + // MySQL multi-table DELETE support + Targets *List // Tables to delete from (e.g., jt.*, pt.*) + FromClause Node // FROM clause with JOINs (Node to support JoinExpr) } func (n *DeleteStmt) Pos() int { return 0 } -func (n *DeleteStmt) Format(buf *TrackedBuffer) { +func (n *DeleteStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } - buf.WriteString("DELETE FROM ") - if items(n.Relations) { - buf.astFormat(n.Relations) + buf.WriteString("DELETE ") + + // MySQL multi-table DELETE: DELETE t1.*, t2.* FROM t1 JOIN t2 ... + if items(n.Targets) { + buf.join(n.Targets, d, ", ") + buf.WriteString(" FROM ") + if set(n.FromClause) { + buf.astFormat(n.FromClause, d) + } else if items(n.Relations) { + buf.astFormat(n.Relations, d) + } + } else { + buf.WriteString("FROM ") + if items(n.Relations) { + buf.astFormat(n.Relations, d) + } + } + + if items(n.UsingClause) { + buf.WriteString(" USING ") + buf.join(n.UsingClause, d, ", ") } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/do_stmt.go b/internal/sql/ast/do_stmt.go index edc831f15c..9becfb8e64 100644 --- a/internal/sql/ast/do_stmt.go +++ b/internal/sql/ast/do_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type DoStmt struct { Args *List } @@ -7,3 +9,22 @@ type DoStmt struct { func (n *DoStmt) Pos() int { return 0 } + +func (n *DoStmt) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("DO ") + // Find the "as" argument which contains the body + if items(n.Args) { + for _, arg := range n.Args.Items { + if de, ok := arg.(*DefElem); ok && de.Defname != nil && *de.Defname == "as" { + if s, ok := de.Arg.(*String); ok { + buf.WriteString("$$") + buf.WriteString(s.Str) + buf.WriteString("$$") + } + } + } + } +} diff --git a/internal/sql/ast/float.go b/internal/sql/ast/float.go index fee8655bbe..94e8c2652f 100644 --- a/internal/sql/ast/float.go +++ b/internal/sql/ast/float.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Float struct { Str string } @@ -8,7 +10,7 @@ func (n *Float) Pos() int { return 0 } -func (n *Float) Format(buf *TrackedBuffer) { +func (n *Float) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/func_call.go b/internal/sql/ast/func_call.go index 2bfe961b50..cb4f210fe4 100644 --- a/internal/sql/ast/func_call.go +++ b/internal/sql/ast/func_call.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncCall struct { Func *FuncName Funcname *List @@ -11,6 +13,7 @@ type FuncCall struct { AggDistinct bool FuncVariadic bool Over *WindowDef + Separator *string // MySQL GROUP_CONCAT SEPARATOR Location int } @@ -18,16 +21,46 @@ func (n *FuncCall) Pos() int { return n.Location } -func (n *FuncCall) Format(buf *TrackedBuffer) { +func (n *FuncCall) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Func) + buf.astFormat(n.Func, d) buf.WriteString("(") + if n.AggDistinct { + buf.WriteString("DISTINCT ") + } if n.AggStar { buf.WriteString("*") } else { - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) + } + // ORDER BY inside function call (not WITHIN GROUP) + if items(n.AggOrder) && !n.AggWithinGroup { + buf.WriteString(" ORDER BY ") + buf.join(n.AggOrder, d, ", ") + } + // SEPARATOR for GROUP_CONCAT (MySQL) + if n.Separator != nil { + buf.WriteString(" SEPARATOR ") + buf.WriteString("'") + buf.WriteString(*n.Separator) + buf.WriteString("'") } buf.WriteString(")") + // WITHIN GROUP clause for ordered-set aggregates + if items(n.AggOrder) && n.AggWithinGroup { + buf.WriteString(" WITHIN GROUP (ORDER BY ") + buf.join(n.AggOrder, d, ", ") + buf.WriteString(")") + } + if set(n.AggFilter) { + buf.WriteString(" FILTER (WHERE ") + buf.astFormat(n.AggFilter, d) + buf.WriteString(")") + } + if n.Over != nil { + buf.WriteString(" OVER ") + buf.astFormat(n.Over, d) + } } diff --git a/internal/sql/ast/func_name.go b/internal/sql/ast/func_name.go index 29b8e0fa61..cdf3e23d33 100644 --- a/internal/sql/ast/func_name.go +++ b/internal/sql/ast/func_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncName struct { Catalog string Schema string @@ -10,7 +12,7 @@ func (n *FuncName) Pos() int { return 0 } -func (n *FuncName) Format(buf *TrackedBuffer) { +func (n *FuncName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/func_param.go b/internal/sql/ast/func_param.go index b5cf8cfcf0..5881a1441f 100644 --- a/internal/sql/ast/func_param.go +++ b/internal/sql/ast/func_param.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type FuncParamMode int const ( @@ -21,3 +23,25 @@ type FuncParam struct { func (n *FuncParam) Pos() int { return 0 } + +func (n *FuncParam) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + // Parameter mode prefix (OUT, INOUT, VARIADIC) + switch n.Mode { + case FuncParamOut: + buf.WriteString("OUT ") + case FuncParamInOut: + buf.WriteString("INOUT ") + case FuncParamVariadic: + buf.WriteString("VARIADIC ") + } + // Parameter name (if present) + if n.Name != nil { + buf.WriteString(*n.Name) + buf.WriteString(" ") + } + // Parameter type + buf.astFormat(n.Type, d) +} diff --git a/internal/sql/ast/in.go b/internal/sql/ast/in.go index e11b2086a1..9bdad67eeb 100644 --- a/internal/sql/ast/in.go +++ b/internal/sql/ast/in.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + // In describes a 'select foo in (bar, baz)' type statement, though there are multiple important variants handled. type In struct { // Expr is the value expression to be compared. @@ -17,3 +19,30 @@ type In struct { func (n *In) Pos() int { return n.Location } + +// Format formats the In expression. +func (n *In) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.astFormat(n.Expr, d) + if n.Not { + buf.WriteString(" NOT IN ") + } else { + buf.WriteString(" IN ") + } + if n.Sel != nil { + buf.WriteString("(") + buf.astFormat(n.Sel, d) + buf.WriteString(")") + } else if len(n.List) > 0 { + buf.WriteString("(") + for i, item := range n.List { + if i > 0 { + buf.WriteString(", ") + } + buf.astFormat(item, d) + } + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/index_elem.go b/internal/sql/ast/index_elem.go index 52ac09688b..acc2a7fc23 100644 --- a/internal/sql/ast/index_elem.go +++ b/internal/sql/ast/index_elem.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type IndexElem struct { Name *string Expr Node @@ -13,3 +15,14 @@ type IndexElem struct { func (n *IndexElem) Pos() int { return 0 } + +func (n *IndexElem) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + if n.Name != nil && *n.Name != "" { + buf.WriteString(*n.Name) + } else if set(n.Expr) { + buf.astFormat(n.Expr, d) + } +} diff --git a/internal/sql/ast/infer_clause.go b/internal/sql/ast/infer_clause.go index 1e1d93c3d8..6df0db4a86 100644 --- a/internal/sql/ast/infer_clause.go +++ b/internal/sql/ast/infer_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type InferClause struct { IndexElems *List WhereClause Node @@ -10,3 +12,21 @@ type InferClause struct { func (n *InferClause) Pos() int { return n.Location } + +func (n *InferClause) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + if n.Conname != nil && *n.Conname != "" { + buf.WriteString("ON CONSTRAINT ") + buf.WriteString(*n.Conname) + } else if items(n.IndexElems) { + buf.WriteString("(") + buf.join(n.IndexElems, d, ", ") + buf.WriteString(")") + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause, d) + } + } +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index 3cdf854091..4d5c8d1df2 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -1,49 +1,62 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type InsertStmt struct { - Relation *RangeVar - Cols *List - SelectStmt Node - OnConflictClause *OnConflictClause - ReturningList *List - WithClause *WithClause - Override OverridingKind + Relation *RangeVar + Cols *List + SelectStmt Node + OnConflictClause *OnConflictClause + OnDuplicateKeyUpdate *OnDuplicateKeyUpdate // MySQL-specific + ReturningList *List + WithClause *WithClause + Override OverridingKind + DefaultValues bool // SQLite-specific: INSERT INTO ... DEFAULT VALUES } func (n *InsertStmt) Pos() int { return 0 } -func (n *InsertStmt) Format(buf *TrackedBuffer) { +func (n *InsertStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } buf.WriteString("INSERT INTO ") if n.Relation != nil { - buf.astFormat(n.Relation) + buf.astFormat(n.Relation, d) } if items(n.Cols) { buf.WriteString(" (") - buf.astFormat(n.Cols) - buf.WriteString(") ") + buf.astFormat(n.Cols, d) + buf.WriteString(")") } - if set(n.SelectStmt) { - buf.astFormat(n.SelectStmt) + if n.DefaultValues { + buf.WriteString(" DEFAULT VALUES") + } else if set(n.SelectStmt) { + buf.WriteString(" ") + buf.astFormat(n.SelectStmt, d) } if n.OnConflictClause != nil { - buf.WriteString(" ON CONFLICT DO NOTHING ") + buf.WriteString(" ") + buf.astFormat(n.OnConflictClause, d) + } + + if n.OnDuplicateKeyUpdate != nil { + buf.WriteString(" ") + buf.astFormat(n.OnDuplicateKeyUpdate, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/integer.go b/internal/sql/ast/integer.go index e9f911add2..c0c360f2f2 100644 --- a/internal/sql/ast/integer.go +++ b/internal/sql/ast/integer.go @@ -1,6 +1,10 @@ package ast -import "strconv" +import ( + "strconv" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type Integer struct { Ival int64 @@ -10,7 +14,7 @@ func (n *Integer) Pos() int { return 0 } -func (n *Integer) Format(buf *TrackedBuffer) { +func (n *Integer) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/interval_expr.go b/internal/sql/ast/interval_expr.go new file mode 100644 index 0000000000..dac73a0557 --- /dev/null +++ b/internal/sql/ast/interval_expr.go @@ -0,0 +1,24 @@ +package ast + +import "github.com/sqlc-dev/sqlc/internal/sql/format" + +// IntervalExpr represents a MySQL INTERVAL expression like "INTERVAL 1 DAY" +type IntervalExpr struct { + Value Node + Unit string + Location int +} + +func (n *IntervalExpr) Pos() int { + return n.Location +} + +func (n *IntervalExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("INTERVAL ") + buf.astFormat(n.Value, d) + buf.WriteString(" ") + buf.WriteString(n.Unit) +} diff --git a/internal/sql/ast/join_expr.go b/internal/sql/ast/join_expr.go index e316869560..8ac059d006 100644 --- a/internal/sql/ast/join_expr.go +++ b/internal/sql/ast/join_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type JoinExpr struct { Jointype JoinType IsNatural bool @@ -15,28 +17,38 @@ func (n *JoinExpr) Pos() int { return 0 } -func (n *JoinExpr) Format(buf *TrackedBuffer) { +func (n *JoinExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Larg) + buf.astFormat(n.Larg, d) + if n.IsNatural { + buf.WriteString(" NATURAL") + } switch n.Jointype { case JoinTypeLeft: buf.WriteString(" LEFT JOIN ") + case JoinTypeRight: + buf.WriteString(" RIGHT JOIN ") + case JoinTypeFull: + buf.WriteString(" FULL JOIN ") case JoinTypeInner: - buf.WriteString(" INNER JOIN ") + // CROSS JOIN has no ON or USING clause + if !items(n.UsingClause) && !set(n.Quals) { + buf.WriteString(" CROSS JOIN ") + } else { + buf.WriteString(" JOIN ") + } default: buf.WriteString(" JOIN ") } - buf.astFormat(n.Rarg) - buf.WriteString(" ON ") - if n.Jointype == JoinTypeInner { - if set(n.Quals) { - buf.astFormat(n.Quals) - } else { - buf.WriteString("TRUE") - } - } else { - buf.astFormat(n.Quals) + buf.astFormat(n.Rarg, d) + if items(n.UsingClause) { + buf.WriteString(" USING (") + buf.join(n.UsingClause, d, ", ") + buf.WriteString(")") + } else if set(n.Quals) { + buf.WriteString(" ON ") + buf.astFormat(n.Quals, d) } } diff --git a/internal/sql/ast/list.go b/internal/sql/ast/list.go index 1c89d55339..3bb9d90dcd 100644 --- a/internal/sql/ast/list.go +++ b/internal/sql/ast/list.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type List struct { Items []Node } @@ -8,9 +10,9 @@ func (n *List) Pos() int { return 0 } -func (n *List) Format(buf *TrackedBuffer) { +func (n *List) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.join(n, ",") + buf.join(n, d, ", ") } diff --git a/internal/sql/ast/listen_stmt.go b/internal/sql/ast/listen_stmt.go index 79c1b132c1..48c38419a8 100644 --- a/internal/sql/ast/listen_stmt.go +++ b/internal/sql/ast/listen_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ListenStmt struct { Conditionname *string } @@ -8,7 +10,7 @@ func (n *ListenStmt) Pos() int { return 0 } -func (n *ListenStmt) Format(buf *TrackedBuffer) { +func (n *ListenStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/locking_clause.go b/internal/sql/ast/locking_clause.go index 11a9159de2..6202b4ae02 100644 --- a/internal/sql/ast/locking_clause.go +++ b/internal/sql/ast/locking_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type LockingClause struct { LockedRels *List Strength LockClauseStrength @@ -10,15 +12,46 @@ func (n *LockingClause) Pos() int { return 0 } -func (n *LockingClause) Format(buf *TrackedBuffer) { +// LockClauseStrength values (matching pg_query_go) +const ( + LockClauseStrengthUndefined LockClauseStrength = 0 + LockClauseStrengthNone LockClauseStrength = 1 + LockClauseStrengthForKeyShare LockClauseStrength = 2 + LockClauseStrengthForShare LockClauseStrength = 3 + LockClauseStrengthForNoKeyUpdate LockClauseStrength = 4 + LockClauseStrengthForUpdate LockClauseStrength = 5 +) + +// LockWaitPolicy values +const ( + LockWaitPolicyBlock LockWaitPolicy = 1 + LockWaitPolicySkip LockWaitPolicy = 2 + LockWaitPolicyError LockWaitPolicy = 3 +) + +func (n *LockingClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("FOR ") switch n.Strength { - case 3: + case LockClauseStrengthForKeyShare: + buf.WriteString("KEY SHARE") + case LockClauseStrengthForShare: buf.WriteString("SHARE") - case 5: + case LockClauseStrengthForNoKeyUpdate: + buf.WriteString("NO KEY UPDATE") + case LockClauseStrengthForUpdate: buf.WriteString("UPDATE") } + if items(n.LockedRels) { + buf.WriteString(" OF ") + buf.join(n.LockedRels, d, ", ") + } + switch n.WaitPolicy { + case LockWaitPolicySkip: + buf.WriteString(" SKIP LOCKED") + case LockWaitPolicyError: + buf.WriteString(" NOWAIT") + } } diff --git a/internal/sql/ast/multi_assign_ref.go b/internal/sql/ast/multi_assign_ref.go index 16302b4e4c..94b783bcc1 100644 --- a/internal/sql/ast/multi_assign_ref.go +++ b/internal/sql/ast/multi_assign_ref.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type MultiAssignRef struct { Source Node Colno int @@ -10,9 +12,9 @@ func (n *MultiAssignRef) Pos() int { return 0 } -func (n *MultiAssignRef) Format(buf *TrackedBuffer) { +func (n *MultiAssignRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Source) + buf.astFormat(n.Source, d) } diff --git a/internal/sql/ast/named_arg_expr.go b/internal/sql/ast/named_arg_expr.go index e37427826e..a711fd2712 100644 --- a/internal/sql/ast/named_arg_expr.go +++ b/internal/sql/ast/named_arg_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NamedArgExpr struct { Xpr Node Arg Node @@ -12,7 +14,7 @@ func (n *NamedArgExpr) Pos() int { return n.Location } -func (n *NamedArgExpr) Format(buf *TrackedBuffer) { +func (n *NamedArgExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } @@ -20,5 +22,5 @@ func (n *NamedArgExpr) Format(buf *TrackedBuffer) { buf.WriteString(*n.Name) } buf.WriteString(" => ") - buf.astFormat(n.Arg) + buf.astFormat(n.Arg, d) } diff --git a/internal/sql/ast/notify_stmt.go b/internal/sql/ast/notify_stmt.go index 0c50a11123..abecb94360 100644 --- a/internal/sql/ast/notify_stmt.go +++ b/internal/sql/ast/notify_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NotifyStmt struct { Conditionname *string Payload *string @@ -9,7 +11,7 @@ func (n *NotifyStmt) Pos() int { return 0 } -func (n *NotifyStmt) Format(buf *TrackedBuffer) { +func (n *NotifyStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/null.go b/internal/sql/ast/null.go index 380c8e7372..e3606e2d7f 100644 --- a/internal/sql/ast/null.go +++ b/internal/sql/ast/null.go @@ -1,11 +1,13 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type Null struct { } func (n *Null) Pos() int { return 0 } -func (n *Null) Format(buf *TrackedBuffer) { +func (n *Null) Format(buf *TrackedBuffer, d format.Dialect) { buf.WriteString("NULL") } diff --git a/internal/sql/ast/null_test_expr.go b/internal/sql/ast/null_test_expr.go index 51fd37f6bb..3436bff0a5 100644 --- a/internal/sql/ast/null_test_expr.go +++ b/internal/sql/ast/null_test_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type NullTest struct { Xpr Node Arg Node @@ -11,3 +13,22 @@ type NullTest struct { func (n *NullTest) Pos() int { return n.Location } + +// NullTestType values +const ( + NullTestTypeIsNull NullTestType = 1 + NullTestTypeIsNotNull NullTestType = 2 +) + +func (n *NullTest) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.astFormat(n.Arg, d) + switch n.Nulltesttype { + case NullTestTypeIsNull: + buf.WriteString(" IS NULL") + case NullTestTypeIsNotNull: + buf.WriteString(" IS NOT NULL") + } +} diff --git a/internal/sql/ast/on_conflict_clause.go b/internal/sql/ast/on_conflict_clause.go index 25333d6d59..a71bae0a23 100644 --- a/internal/sql/ast/on_conflict_clause.go +++ b/internal/sql/ast/on_conflict_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type OnConflictClause struct { Action OnConflictAction Infer *InferClause @@ -11,3 +13,49 @@ type OnConflictClause struct { func (n *OnConflictClause) Pos() int { return n.Location } + +// OnConflictAction values matching pg_query_go +const ( + OnConflictActionUndefined OnConflictAction = 0 + OnConflictActionNone OnConflictAction = 1 + OnConflictActionNothing OnConflictAction = 2 + OnConflictActionUpdate OnConflictAction = 3 +) + +func (n *OnConflictClause) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("ON CONFLICT ") + if n.Infer != nil { + buf.astFormat(n.Infer, d) + buf.WriteString(" ") + } + switch n.Action { + case OnConflictActionNothing: + buf.WriteString("DO NOTHING") + case OnConflictActionUpdate: + buf.WriteString("DO UPDATE SET ") + // Format as assignment list: name = val + if n.TargetList != nil { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + if rt, ok := item.(*ResTarget); ok { + if rt.Name != nil { + buf.WriteString(*rt.Name) + } + buf.WriteString(" = ") + buf.astFormat(rt.Val, d) + } else { + buf.astFormat(item, d) + } + } + } + if set(n.WhereClause) { + buf.WriteString(" WHERE ") + buf.astFormat(n.WhereClause, d) + } + } +} diff --git a/internal/sql/ast/on_duplicate_key_update.go b/internal/sql/ast/on_duplicate_key_update.go new file mode 100644 index 0000000000..a11ce1ab18 --- /dev/null +++ b/internal/sql/ast/on_duplicate_key_update.go @@ -0,0 +1,37 @@ +package ast + +import "github.com/sqlc-dev/sqlc/internal/sql/format" + +// OnDuplicateKeyUpdate represents MySQL's ON DUPLICATE KEY UPDATE clause +type OnDuplicateKeyUpdate struct { + // TargetList contains the assignments (column = value pairs) + TargetList *List + Location int +} + +func (n *OnDuplicateKeyUpdate) Pos() int { + return n.Location +} + +func (n *OnDuplicateKeyUpdate) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("ON DUPLICATE KEY UPDATE ") + if n.TargetList != nil { + for i, item := range n.TargetList.Items { + if i > 0 { + buf.WriteString(", ") + } + if rt, ok := item.(*ResTarget); ok { + if rt.Name != nil { + buf.WriteString(*rt.Name) + } + buf.WriteString(" = ") + buf.astFormat(rt.Val, d) + } else { + buf.astFormat(item, d) + } + } + } +} diff --git a/internal/sql/ast/param_ref.go b/internal/sql/ast/param_ref.go index 8bd724993d..7ebc897a95 100644 --- a/internal/sql/ast/param_ref.go +++ b/internal/sql/ast/param_ref.go @@ -1,6 +1,6 @@ package ast -import "fmt" +import "github.com/sqlc-dev/sqlc/internal/sql/format" type ParamRef struct { Number int @@ -12,9 +12,9 @@ func (n *ParamRef) Pos() int { return n.Location } -func (n *ParamRef) Format(buf *TrackedBuffer) { +func (n *ParamRef) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - fmt.Fprintf(buf, "$%d", n.Number) + buf.WriteString(d.Param(n.Number)) } diff --git a/internal/sql/ast/paren_expr.go b/internal/sql/ast/paren_expr.go new file mode 100644 index 0000000000..831d461f3e --- /dev/null +++ b/internal/sql/ast/paren_expr.go @@ -0,0 +1,22 @@ +package ast + +import "github.com/sqlc-dev/sqlc/internal/sql/format" + +// ParenExpr represents a parenthesized expression +type ParenExpr struct { + Expr Node + Location int +} + +func (n *ParenExpr) Pos() int { + return n.Location +} + +func (n *ParenExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("(") + buf.astFormat(n.Expr, d) + buf.WriteString(")") +} diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index 867a53a177..87f6107622 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -4,10 +4,11 @@ import ( "strings" "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/sql/format" ) -type formatter interface { - Format(*TrackedBuffer) +type nodeFormatter interface { + Format(*TrackedBuffer, format.Dialect) } type TrackedBuffer struct { @@ -16,21 +17,20 @@ type TrackedBuffer struct { // NewTrackedBuffer creates a new TrackedBuffer. func NewTrackedBuffer() *TrackedBuffer { - buf := &TrackedBuffer{ + return &TrackedBuffer{ Builder: new(strings.Builder), } - return buf } -func (t *TrackedBuffer) astFormat(n Node) { - if ft, ok := n.(formatter); ok { - ft.Format(t) +func (t *TrackedBuffer) astFormat(n Node, d format.Dialect) { + if ft, ok := n.(nodeFormatter); ok { + ft.Format(t, d) } else { debug.Dump(n) } } -func (t *TrackedBuffer) join(n *List, sep string) { +func (t *TrackedBuffer) join(n *List, d format.Dialect, sep string) { if n == nil { return } @@ -41,14 +41,14 @@ func (t *TrackedBuffer) join(n *List, sep string) { if i > 0 { t.WriteString(sep) } - t.astFormat(item) + t.astFormat(item, d) } } -func Format(n Node) string { +func Format(n Node, d format.Dialect) string { tb := NewTrackedBuffer() - if ft, ok := n.(formatter); ok { - ft.Format(tb) + if ft, ok := n.(nodeFormatter); ok { + ft.Format(tb, d) } return tb.String() } diff --git a/internal/sql/ast/range_function.go b/internal/sql/ast/range_function.go index 299078d481..dca63595d8 100644 --- a/internal/sql/ast/range_function.go +++ b/internal/sql/ast/range_function.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeFunction struct { Lateral bool Ordinality bool @@ -13,13 +15,19 @@ func (n *RangeFunction) Pos() int { return 0 } -func (n *RangeFunction) Format(buf *TrackedBuffer) { +func (n *RangeFunction) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Functions) + if n.Lateral { + buf.WriteString("LATERAL ") + } + buf.astFormat(n.Functions, d) if n.Ordinality { - buf.WriteString(" WITH ORDINALITY ") + buf.WriteString(" WITH ORDINALITY") + } + if n.Alias != nil { + buf.WriteString(" AS ") + buf.astFormat(n.Alias, d) } - buf.astFormat(n.Alias) } diff --git a/internal/sql/ast/range_subselect.go b/internal/sql/ast/range_subselect.go index 1506ee7994..51a8825e2b 100644 --- a/internal/sql/ast/range_subselect.go +++ b/internal/sql/ast/range_subselect.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeSubselect struct { Lateral bool Subquery Node @@ -10,15 +12,18 @@ func (n *RangeSubselect) Pos() int { return 0 } -func (n *RangeSubselect) Format(buf *TrackedBuffer) { +func (n *RangeSubselect) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } + if n.Lateral { + buf.WriteString("LATERAL ") + } buf.WriteString("(") - buf.astFormat(n.Subquery) + buf.astFormat(n.Subquery, d) buf.WriteString(")") if n.Alias != nil { - buf.WriteString(" ") - buf.astFormat(n.Alias) + buf.WriteString(" AS ") + buf.astFormat(n.Alias, d) } } diff --git a/internal/sql/ast/range_var.go b/internal/sql/ast/range_var.go index 1d1656f6c0..250b2b3bbf 100644 --- a/internal/sql/ast/range_var.go +++ b/internal/sql/ast/range_var.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RangeVar struct { Catalogname *string Schemaname *string @@ -14,26 +16,19 @@ func (n *RangeVar) Pos() int { return n.Location } -func (n *RangeVar) Format(buf *TrackedBuffer) { +func (n *RangeVar) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - if n.Schemaname != nil { - buf.WriteString(*n.Schemaname) + if n.Schemaname != nil && *n.Schemaname != "" { + buf.WriteString(d.QuoteIdent(*n.Schemaname)) buf.WriteString(".") } if n.Relname != nil { - // TODO: What names need to be quoted - if *n.Relname == "user" { - buf.WriteString(`"`) - buf.WriteString(*n.Relname) - buf.WriteString(`"`) - } else { - buf.WriteString(*n.Relname) - } + buf.WriteString(d.QuoteIdent(*n.Relname)) } if n.Alias != nil { - buf.WriteString(" ") - buf.astFormat(n.Alias) + buf.WriteString(" AS ") + buf.astFormat(n.Alias, d) } } diff --git a/internal/sql/ast/raw_stmt.go b/internal/sql/ast/raw_stmt.go index 55192d2eec..fe02bed803 100644 --- a/internal/sql/ast/raw_stmt.go +++ b/internal/sql/ast/raw_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RawStmt struct { Stmt Node StmtLocation int @@ -10,9 +12,9 @@ func (n *RawStmt) Pos() int { return n.StmtLocation } -func (n *RawStmt) Format(buf *TrackedBuffer) { +func (n *RawStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n.Stmt != nil { - buf.astFormat(n.Stmt) + buf.astFormat(n.Stmt, d) } buf.WriteString(";") } diff --git a/internal/sql/ast/refresh_mat_view_stmt.go b/internal/sql/ast/refresh_mat_view_stmt.go index e9b3e26bfa..f627e7bf21 100644 --- a/internal/sql/ast/refresh_mat_view_stmt.go +++ b/internal/sql/ast/refresh_mat_view_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RefreshMatViewStmt struct { Concurrent bool SkipData bool @@ -10,10 +12,10 @@ func (n *RefreshMatViewStmt) Pos() int { return 0 } -func (n *RefreshMatViewStmt) Format(buf *TrackedBuffer) { +func (n *RefreshMatViewStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("REFRESH MATERIALIZED VIEW ") - buf.astFormat(n.Relation) + buf.astFormat(n.Relation, d) } diff --git a/internal/sql/ast/res_target.go b/internal/sql/ast/res_target.go index 4ee2e72112..dc34879942 100644 --- a/internal/sql/ast/res_target.go +++ b/internal/sql/ast/res_target.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ResTarget struct { Name *string Indirection *List @@ -11,19 +13,19 @@ func (n *ResTarget) Pos() int { return n.Location } -func (n *ResTarget) Format(buf *TrackedBuffer) { +func (n *ResTarget) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if set(n.Val) { - buf.astFormat(n.Val) + buf.astFormat(n.Val, d) if n.Name != nil { buf.WriteString(" AS ") - buf.WriteString(*n.Name) + buf.WriteString(d.QuoteIdent(*n.Name)) } } else { if n.Name != nil { - buf.WriteString(*n.Name) + buf.WriteString(d.QuoteIdent(*n.Name)) } } } diff --git a/internal/sql/ast/row_expr.go b/internal/sql/ast/row_expr.go index 14804f5821..0f8578355a 100644 --- a/internal/sql/ast/row_expr.go +++ b/internal/sql/ast/row_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type RowExpr struct { Xpr Node Args *List @@ -13,17 +15,17 @@ func (n *RowExpr) Pos() int { return n.Location } -func (n *RowExpr) Format(buf *TrackedBuffer) { +func (n *RowExpr) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if items(n.Args) { buf.WriteString("args") - buf.astFormat(n.Args) + buf.astFormat(n.Args, d) } - buf.astFormat(n.Xpr) + buf.astFormat(n.Xpr, d) if items(n.Colnames) { buf.WriteString("cols") - buf.astFormat(n.Colnames) + buf.astFormat(n.Colnames, d) } } diff --git a/internal/sql/ast/scalar_array_op_expr.go b/internal/sql/ast/scalar_array_op_expr.go index fc438c10b3..b4f36548b3 100644 --- a/internal/sql/ast/scalar_array_op_expr.go +++ b/internal/sql/ast/scalar_array_op_expr.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type ScalarArrayOpExpr struct { Xpr Node Opno Oid @@ -12,3 +14,22 @@ type ScalarArrayOpExpr struct { func (n *ScalarArrayOpExpr) Pos() int { return n.Location } + +func (n *ScalarArrayOpExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + // ScalarArrayOpExpr represents "scalar op ANY/ALL (array)" + // Args[0] is the left operand, Args[1] is the array + if n.Args != nil && len(n.Args.Items) >= 2 { + buf.astFormat(n.Args.Items[0], d) + buf.WriteString(" = ") // TODO: Use actual operator based on Opno + if n.UseOr { + buf.WriteString("ANY(") + } else { + buf.WriteString("ALL(") + } + buf.astFormat(n.Args.Items[1], d) + buf.WriteString(")") + } +} diff --git a/internal/sql/ast/select_stmt.go b/internal/sql/ast/select_stmt.go index 051dd5c8c5..62e6f1c9cf 100644 --- a/internal/sql/ast/select_stmt.go +++ b/internal/sql/ast/select_stmt.go @@ -2,6 +2,8 @@ package ast import ( "fmt" + + "github.com/sqlc-dev/sqlc/internal/sql/format" ) type SelectStmt struct { @@ -29,25 +31,32 @@ func (n *SelectStmt) Pos() int { return 0 } -func (n *SelectStmt) Format(buf *TrackedBuffer) { +func (n *SelectStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if items(n.ValuesLists) { - buf.WriteString("VALUES (") - buf.astFormat(n.ValuesLists) - buf.WriteString(")") + buf.WriteString("VALUES ") + // ValuesLists is a list of rows, where each row is a List of values + for i, row := range n.ValuesLists.Items { + if i > 0 { + buf.WriteString(", ") + } + buf.WriteString("(") + buf.astFormat(row, d) + buf.WriteString(")") + } return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } if n.Larg != nil && n.Rarg != nil { - buf.astFormat(n.Larg) + buf.astFormat(n.Larg, d) switch n.Op { case Union: buf.WriteString(" UNION ") @@ -59,7 +68,7 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { if n.All { buf.WriteString("ALL ") } - buf.astFormat(n.Rarg) + buf.astFormat(n.Rarg, d) } else { buf.WriteString("SELECT ") } @@ -68,45 +77,50 @@ func (n *SelectStmt) Format(buf *TrackedBuffer) { buf.WriteString("DISTINCT ") if !todo(n.DistinctClause) { fmt.Fprintf(buf, "ON (") - buf.astFormat(n.DistinctClause) + buf.astFormat(n.DistinctClause, d) fmt.Fprintf(buf, ")") } } - buf.astFormat(n.TargetList) + buf.astFormat(n.TargetList, d) if items(n.FromClause) { buf.WriteString(" FROM ") - buf.astFormat(n.FromClause) + buf.astFormat(n.FromClause, d) } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if items(n.GroupClause) { buf.WriteString(" GROUP BY ") - buf.astFormat(n.GroupClause) + buf.astFormat(n.GroupClause, d) + } + + if set(n.HavingClause) { + buf.WriteString(" HAVING ") + buf.astFormat(n.HavingClause, d) } if items(n.SortClause) { buf.WriteString(" ORDER BY ") - buf.astFormat(n.SortClause) + buf.astFormat(n.SortClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if set(n.LimitOffset) { buf.WriteString(" OFFSET ") - buf.astFormat(n.LimitOffset) + buf.astFormat(n.LimitOffset, d) } if items(n.LockingClause) { buf.WriteString(" ") - buf.astFormat(n.LockingClause) + buf.astFormat(n.LockingClause, d) } } diff --git a/internal/sql/ast/sort_by.go b/internal/sql/ast/sort_by.go index 21a7a079aa..b8634b7d6d 100644 --- a/internal/sql/ast/sort_by.go +++ b/internal/sql/ast/sort_by.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SortBy struct { Node Node SortbyDir SortByDir @@ -12,15 +14,21 @@ func (n *SortBy) Pos() int { return n.Location } -func (n *SortBy) Format(buf *TrackedBuffer) { +func (n *SortBy) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Node) + buf.astFormat(n.Node, d) switch n.SortbyDir { case SortByDirAsc: buf.WriteString(" ASC") case SortByDirDesc: buf.WriteString(" DESC") } + switch n.SortbyNulls { + case SortByNullsFirst: + buf.WriteString(" NULLS FIRST") + case SortByNullsLast: + buf.WriteString(" NULLS LAST") + } } diff --git a/internal/sql/ast/sql_value_function.go b/internal/sql/ast/sql_value_function.go index 0bd0777374..31bd008245 100644 --- a/internal/sql/ast/sql_value_function.go +++ b/internal/sql/ast/sql_value_function.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SQLValueFunction struct { Xpr Node Op SQLValueFunctionOp @@ -12,7 +14,7 @@ func (n *SQLValueFunction) Pos() int { return n.Location } -func (n *SQLValueFunction) Format(buf *TrackedBuffer) { +func (n *SQLValueFunction) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/string.go b/internal/sql/ast/string.go index 977fc19a2f..d167ef4575 100644 --- a/internal/sql/ast/string.go +++ b/internal/sql/ast/string.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type String struct { Str string } @@ -8,7 +10,7 @@ func (n *String) Pos() int { return 0 } -func (n *String) Format(buf *TrackedBuffer) { +func (n *String) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/sub_link.go b/internal/sql/ast/sub_link.go index 9463f98c54..99b8458afe 100644 --- a/internal/sql/ast/sub_link.go +++ b/internal/sql/ast/sub_link.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type SubLinkType uint const ( @@ -27,19 +29,31 @@ func (n *SubLink) Pos() int { return n.Location } -func (n *SubLink) Format(buf *TrackedBuffer) { +func (n *SubLink) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Testexpr) + // Format the test expression if present (for IN subqueries etc.) + hasTestExpr := n.Testexpr != nil + if hasTestExpr { + buf.astFormat(n.Testexpr, d) + } switch n.SubLinkType { case EXISTS_SUBLINK: - buf.WriteString(" EXISTS (") + buf.WriteString("EXISTS (") case ANY_SUBLINK: - buf.WriteString(" IN (") + if hasTestExpr { + buf.WriteString(" IN (") + } else { + buf.WriteString("IN (") + } default: - buf.WriteString(" (") + if hasTestExpr { + buf.WriteString(" (") + } else { + buf.WriteString("(") + } } - buf.astFormat(n.Subselect) + buf.astFormat(n.Subselect, d) buf.WriteString(")") } diff --git a/internal/sql/ast/table_name.go b/internal/sql/ast/table_name.go index a95a510c83..4f494a67e0 100644 --- a/internal/sql/ast/table_name.go +++ b/internal/sql/ast/table_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TableName struct { Catalog string Schema string @@ -10,7 +12,7 @@ func (n *TableName) Pos() int { return 0 } -func (n *TableName) Format(buf *TrackedBuffer) { +func (n *TableName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } diff --git a/internal/sql/ast/truncate_stmt.go b/internal/sql/ast/truncate_stmt.go index f23a5bbcb3..6636e9f9e8 100644 --- a/internal/sql/ast/truncate_stmt.go +++ b/internal/sql/ast/truncate_stmt.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TruncateStmt struct { Relations *List RestartSeqs bool @@ -10,10 +12,10 @@ func (n *TruncateStmt) Pos() int { return 0 } -func (n *TruncateStmt) Format(buf *TrackedBuffer) { +func (n *TruncateStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } buf.WriteString("TRUNCATE ") - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } diff --git a/internal/sql/ast/type_cast.go b/internal/sql/ast/type_cast.go index 0b549eb4b1..fe5b321abf 100644 --- a/internal/sql/ast/type_cast.go +++ b/internal/sql/ast/type_cast.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TypeCast struct { Arg Node TypeName *TypeName @@ -10,11 +12,16 @@ func (n *TypeCast) Pos() int { return n.Location } -func (n *TypeCast) Format(buf *TrackedBuffer) { +func (n *TypeCast) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.astFormat(n.Arg) - buf.WriteString("::") - buf.astFormat(n.TypeName) + // Format the arg and type to strings first + argBuf := NewTrackedBuffer() + argBuf.astFormat(n.Arg, d) + + typeBuf := NewTrackedBuffer() + typeBuf.astFormat(n.TypeName, d) + + buf.WriteString(d.Cast(argBuf.String(), typeBuf.String())) } diff --git a/internal/sql/ast/type_name.go b/internal/sql/ast/type_name.go index e26404b3ba..d8d91f4f87 100644 --- a/internal/sql/ast/type_name.go +++ b/internal/sql/ast/type_name.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type TypeName struct { Catalog string Schema string @@ -20,18 +22,37 @@ func (n *TypeName) Pos() int { return n.Location } -func (n *TypeName) Format(buf *TrackedBuffer) { +func (n *TypeName) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if items(n.Names) { - buf.join(n.Names, ".") - } else { - if n.Name == "int4" { - buf.WriteString("INTEGER") - } else { - buf.WriteString(n.Name) + // Check if this is a qualified type (e.g., pg_catalog.int4) + if len(n.Names.Items) == 2 { + first, _ := n.Names.Items[0].(*String) + second, _ := n.Names.Items[1].(*String) + if first != nil && second != nil { + buf.WriteString(d.TypeName(first.Str, second.Str)) + goto addMods + } + } + // For single name types, just output as-is + if len(n.Names.Items) == 1 { + if s, ok := n.Names.Items[0].(*String); ok { + buf.WriteString(d.TypeName("", s.Str)) + goto addMods + } } + buf.join(n.Names, d, ".") + } else { + buf.WriteString(d.TypeName(n.Schema, n.Name)) + } +addMods: + // Add type modifiers (e.g., varchar(255)) + if items(n.Typmods) { + buf.WriteString("(") + buf.join(n.Typmods, d, ", ") + buf.WriteString(")") } if items(n.ArrayBounds) { buf.WriteString("[]") diff --git a/internal/sql/ast/typedefs.go b/internal/sql/ast/typedefs.go index 351008e841..924fad767b 100644 --- a/internal/sql/ast/typedefs.go +++ b/internal/sql/ast/typedefs.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type AclMode uint32 func (n *AclMode) Pos() int { @@ -18,6 +20,15 @@ func (n *NullIfExpr) Pos() int { return 0 } +func (n *NullIfExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("NULLIF(") + buf.join(n.Args, d, ", ") + buf.WriteString(")") +} + type Selectivity float64 func (n *Selectivity) Pos() int { diff --git a/internal/sql/ast/update_stmt.go b/internal/sql/ast/update_stmt.go index efd496ad75..5376a8c6ce 100644 --- a/internal/sql/ast/update_stmt.go +++ b/internal/sql/ast/update_stmt.go @@ -1,6 +1,10 @@ package ast -import "strings" +import ( + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/format" +) type UpdateStmt struct { Relations *List @@ -16,18 +20,18 @@ func (n *UpdateStmt) Pos() int { return 0 } -func (n *UpdateStmt) Format(buf *TrackedBuffer) { +func (n *UpdateStmt) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } if n.WithClause != nil { - buf.astFormat(n.WithClause) + buf.astFormat(n.WithClause, d) buf.WriteString(" ") } buf.WriteString("UPDATE ") if items(n.Relations) { - buf.astFormat(n.Relations) + buf.astFormat(n.Relations, d) } if items(n.TargetList) { @@ -69,7 +73,7 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { buf.WriteString("(") buf.WriteString(strings.Join(names, ",")) buf.WriteString(") = (") - buf.join(vals, ",") + buf.join(vals, d, ",") buf.WriteString(")") } else { for i, item := range n.TargetList.Items { @@ -79,12 +83,18 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { switch nn := item.(type) { case *ResTarget: if nn.Name != nil { - buf.WriteString(*nn.Name) + buf.WriteString(d.QuoteIdent(*nn.Name)) + } + // Handle array subscript indirection (e.g., names[$1]) + if items(nn.Indirection) { + for _, ind := range nn.Indirection.Items { + buf.astFormat(ind, d) + } } buf.WriteString(" = ") - buf.astFormat(nn.Val) + buf.astFormat(nn.Val, d) default: - buf.astFormat(item) + buf.astFormat(item, d) } } } @@ -92,21 +102,21 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) { if items(n.FromClause) { buf.WriteString(" FROM ") - buf.astFormat(n.FromClause) + buf.astFormat(n.FromClause, d) } if set(n.WhereClause) { buf.WriteString(" WHERE ") - buf.astFormat(n.WhereClause) + buf.astFormat(n.WhereClause, d) } if set(n.LimitCount) { buf.WriteString(" LIMIT ") - buf.astFormat(n.LimitCount) + buf.astFormat(n.LimitCount, d) } if items(n.ReturningList) { buf.WriteString(" RETURNING ") - buf.astFormat(n.ReturningList) + buf.astFormat(n.ReturningList, d) } } diff --git a/internal/sql/ast/variable_expr.go b/internal/sql/ast/variable_expr.go new file mode 100644 index 0000000000..83223b482b --- /dev/null +++ b/internal/sql/ast/variable_expr.go @@ -0,0 +1,22 @@ +package ast + +import "github.com/sqlc-dev/sqlc/internal/sql/format" + +// VariableExpr represents a MySQL user variable (e.g., @user_id) +// This is distinct from sqlc's @param named parameter syntax. +type VariableExpr struct { + Name string + Location int +} + +func (n *VariableExpr) Pos() int { + return n.Location +} + +func (n *VariableExpr) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + buf.WriteString("@") + buf.WriteString(n.Name) +} diff --git a/internal/sql/ast/window_def.go b/internal/sql/ast/window_def.go index 29840767c9..caba3e643c 100644 --- a/internal/sql/ast/window_def.go +++ b/internal/sql/ast/window_def.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type WindowDef struct { Name *string Refname *string @@ -14,3 +16,99 @@ type WindowDef struct { func (n *WindowDef) Pos() int { return n.Location } + +// Frame option constants (from PostgreSQL's parsenodes.h) +const ( + FrameOptionNonDefault = 0x00001 + FrameOptionRange = 0x00002 + FrameOptionRows = 0x00004 + FrameOptionGroups = 0x00008 + FrameOptionBetween = 0x00010 + FrameOptionStartUnboundedPreceding = 0x00020 + FrameOptionEndUnboundedPreceding = 0x00040 + FrameOptionStartUnboundedFollowing = 0x00080 + FrameOptionEndUnboundedFollowing = 0x00100 + FrameOptionStartCurrentRow = 0x00200 + FrameOptionEndCurrentRow = 0x00400 + FrameOptionStartOffset = 0x00800 + FrameOptionEndOffset = 0x01000 + FrameOptionExcludeCurrentRow = 0x02000 + FrameOptionExcludeGroup = 0x04000 + FrameOptionExcludeTies = 0x08000 +) + +func (n *WindowDef) Format(buf *TrackedBuffer, d format.Dialect) { + if n == nil { + return + } + + // Named window reference + if n.Refname != nil && *n.Refname != "" { + buf.WriteString(*n.Refname) + return + } + + buf.WriteString("(") + needSpace := false + + if items(n.PartitionClause) { + buf.WriteString("PARTITION BY ") + buf.join(n.PartitionClause, d, ", ") + needSpace = true + } + + if items(n.OrderClause) { + if needSpace { + buf.WriteString(" ") + } + buf.WriteString("ORDER BY ") + buf.join(n.OrderClause, d, ", ") + needSpace = true + } + + // Frame clause + if n.FrameOptions&FrameOptionNonDefault != 0 { + if needSpace { + buf.WriteString(" ") + } + + // Frame type + if n.FrameOptions&FrameOptionRows != 0 { + buf.WriteString("ROWS ") + } else if n.FrameOptions&FrameOptionRange != 0 { + buf.WriteString("RANGE ") + } else if n.FrameOptions&FrameOptionGroups != 0 { + buf.WriteString("GROUPS ") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString("BETWEEN ") + } + + // Start bound + if n.FrameOptions&FrameOptionStartUnboundedPreceding != 0 { + buf.WriteString("UNBOUNDED PRECEDING") + } else if n.FrameOptions&FrameOptionStartCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionStartOffset != 0 { + buf.astFormat(n.StartOffset, d) + buf.WriteString(" PRECEDING") + } + + if n.FrameOptions&FrameOptionBetween != 0 { + buf.WriteString(" AND ") + + // End bound + if n.FrameOptions&FrameOptionEndUnboundedFollowing != 0 { + buf.WriteString("UNBOUNDED FOLLOWING") + } else if n.FrameOptions&FrameOptionEndCurrentRow != 0 { + buf.WriteString("CURRENT ROW") + } else if n.FrameOptions&FrameOptionEndOffset != 0 { + buf.astFormat(n.EndOffset, d) + buf.WriteString(" FOLLOWING") + } + } + } + + buf.WriteString(")") +} diff --git a/internal/sql/ast/with_clause.go b/internal/sql/ast/with_clause.go index 634326fa7e..0def53d382 100644 --- a/internal/sql/ast/with_clause.go +++ b/internal/sql/ast/with_clause.go @@ -1,5 +1,7 @@ package ast +import "github.com/sqlc-dev/sqlc/internal/sql/format" + type WithClause struct { Ctes *List Recursive bool @@ -10,13 +12,13 @@ func (n *WithClause) Pos() int { return n.Location } -func (n *WithClause) Format(buf *TrackedBuffer) { +func (n *WithClause) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.WriteString("WITH") + buf.WriteString("WITH ") if n.Recursive { - buf.WriteString(" RECURSIVE") + buf.WriteString("RECURSIVE ") } - buf.astFormat(n.Ctes) + buf.join(n.Ctes, d, ", ") } diff --git a/internal/sql/astutils/CLAUDE.md b/internal/sql/astutils/CLAUDE.md new file mode 100644 index 0000000000..b7903542c5 --- /dev/null +++ b/internal/sql/astutils/CLAUDE.md @@ -0,0 +1,117 @@ +# AST Utilities Package - Claude Code Guide + +This package provides utilities for traversing and transforming AST nodes. + +## Key Functions + +### Walk +`Walk(f Visitor, node ast.Node)` traverses the AST depth-first, calling `f.Visit()` on each node. + +```go +type Visitor interface { + Visit(node ast.Node) Visitor +} +``` + +**Important**: When adding new AST node types, you MUST add a case to the switch statement in `walk.go`, otherwise you'll get a panic: +``` +panic: walk: unexpected node type *ast.YourNewType +``` + +### Apply (Rewrite) +`Apply(root ast.Node, pre, post ApplyFunc) ast.Node` traverses and optionally transforms the AST. + +```go +type ApplyFunc func(*Cursor) bool +``` + +The `Cursor` provides: +- `Node()` - current node +- `Parent()` - parent node +- `Name()` - field name in parent +- `Index()` - index if in a list +- `Replace(node)` - replace current node + +**Important**: When adding new AST node types, you MUST add a case to the switch statement in `rewrite.go`, otherwise you'll get a panic: +``` +panic: Apply: unexpected node type *ast.YourNewType +``` + +### Search +`Search(root ast.Node, fn func(ast.Node) bool) *ast.List` finds all nodes matching a predicate. + +### Join +`Join(list *ast.List, sep string) string` joins string nodes with a separator. + +## Adding Support for New AST Nodes + +When you create a new AST node type, you must update BOTH `walk.go` and `rewrite.go`: + +### In walk.go +Add a case that walks all child nodes: +```go +case *ast.YourNewType: + if n.ChildField != nil { + Walk(f, n.ChildField) + } + if n.ChildList != nil { + Walk(f, n.ChildList) + } +``` + +For leaf nodes with no children: +```go +case *ast.YourNewType: + // Leaf node - no children to traverse +``` + +### In rewrite.go +Add a case that applies to all child nodes: +```go +case *ast.YourNewType: + a.apply(n, "ChildField", nil, n.ChildField) + a.apply(n, "ChildList", nil, n.ChildList) +``` + +For leaf nodes: +```go +case *ast.YourNewType: + // Leaf node - no children to traverse +``` + +## Common Patterns + +### Finding All Tables in a Statement +```go +var tv tableVisitor +astutils.Walk(&tv, stmt.FromClause) +// tv.list now contains all RangeVar nodes +``` + +### Replacing Named Parameters +The `rewrite/parameters.go` uses Apply to replace `sqlc.arg()` calls with `ParamRef`: +```go +astutils.Apply(root, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) { + cr.Replace(&ast.ParamRef{Number: nextParam()}) + } + return true +}, nil) +``` + +## Node Types That Must Be Handled + +All node types in `internal/sql/ast/` must have cases in both walk.go and rewrite.go. Key MySQL-specific nodes: +- `IntervalExpr` - INTERVAL expressions +- `OnDuplicateKeyUpdate` - MySQL ON DUPLICATE KEY UPDATE +- `ParenExpr` - Parenthesized expressions +- `VariableExpr` - MySQL user variables (@var) + +## Debugging Tips + +If you see a panic like: +``` +panic: walk: unexpected node type *ast.SomeType +``` + +Check that `SomeType` has a case in both `walk.go` and `rewrite.go`. diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index 93c5be3cfb..fc7996b5f5 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -687,6 +687,8 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "WhereClause", nil, n.WhereClause) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) + a.apply(n, "Targets", nil, n.Targets) + a.apply(n, "FromClause", nil, n.FromClause) case *ast.DiscardStmt: // pass @@ -812,12 +814,16 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Cols", nil, n.Cols) a.apply(n, "SelectStmt", nil, n.SelectStmt) a.apply(n, "OnConflictClause", nil, n.OnConflictClause) + a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate) a.apply(n, "ReturningList", nil, n.ReturningList) a.apply(n, "WithClause", nil, n.WithClause) case *ast.Integer: // pass + case *ast.IntervalExpr: + a.apply(n, "Value", nil, n.Value) + case *ast.IntoClause: a.apply(n, "Rel", nil, n.Rel) a.apply(n, "ColNames", nil, n.ColNames) @@ -883,6 +889,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "OnConflictWhere", nil, n.OnConflictWhere) a.apply(n, "ExclRelTlist", nil, n.ExclRelTlist) + case *ast.OnDuplicateKeyUpdate: + a.apply(n, "TargetList", nil, n.TargetList) + case *ast.OpExpr: a.apply(n, "Xpr", nil, n.Xpr) a.apply(n, "Args", nil, n.Args) @@ -902,6 +911,12 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. case *ast.ParamRef: // pass + case *ast.ParenExpr: + a.apply(n, "Expr", nil, n.Expr) + + case *ast.VariableExpr: + // Leaf node - no children to traverse + case *ast.PartitionBoundSpec: a.apply(n, "Listdatums", nil, n.Listdatums) a.apply(n, "Lowerdatums", nil, n.Lowerdatums) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 0943379f03..6d5e80bdc3 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1077,6 +1077,12 @@ func Walk(f Visitor, node ast.Node) { if n.WithClause != nil { Walk(f, n.WithClause) } + if n.Targets != nil { + Walk(f, n.Targets) + } + if n.FromClause != nil { + Walk(f, n.FromClause) + } case *ast.DiscardStmt: // pass @@ -1312,6 +1318,9 @@ func Walk(f Visitor, node ast.Node) { if n.OnConflictClause != nil { Walk(f, n.OnConflictClause) } + if n.OnDuplicateKeyUpdate != nil { + Walk(f, n.OnDuplicateKeyUpdate) + } if n.ReturningList != nil { Walk(f, n.ReturningList) } @@ -1336,6 +1345,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.ViewQuery) } + case *ast.IntervalExpr: + if n.Value != nil { + Walk(f, n.Value) + } + case *ast.JoinExpr: if n.Larg != nil { Walk(f, n.Larg) @@ -1445,6 +1459,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.ExclRelTlist) } + case *ast.OnDuplicateKeyUpdate: + if n.TargetList != nil { + Walk(f, n.TargetList) + } + case *ast.OpExpr: if n.Xpr != nil { Walk(f, n.Xpr) @@ -1470,6 +1489,14 @@ func Walk(f Visitor, node ast.Node) { case *ast.ParamRef: // pass + case *ast.ParenExpr: + if n.Expr != nil { + Walk(f, n.Expr) + } + + case *ast.VariableExpr: + // Leaf node - no children to traverse + case *ast.PartitionBoundSpec: if n.Listdatums != nil { Walk(f, n.Listdatums) diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go new file mode 100644 index 0000000000..b900c227ed --- /dev/null +++ b/internal/sql/format/format.go @@ -0,0 +1,24 @@ +package format + +// Dialect provides SQL dialect-specific formatting behavior +type Dialect interface { + // QuoteIdent returns a quoted identifier if it needs quoting + // (e.g., reserved words, mixed case identifiers) + QuoteIdent(s string) string + + // TypeName returns the SQL type name for the given namespace and name. + // This handles dialect-specific type name mappings (e.g., pg_catalog.int4 -> integer) + TypeName(ns, name string) string + + // Param returns the parameter placeholder for the given parameter number. + // PostgreSQL uses $1, $2, etc. MySQL uses ? + Param(n int) string + + // NamedParam returns the named parameter placeholder for the given name. + // PostgreSQL uses @name, SQLite uses :name + NamedParam(name string) string + + // Cast formats a type cast expression. + // PostgreSQL uses expr::type, MySQL uses CAST(expr AS type) + Cast(arg, typeName string) string +} diff --git a/internal/sql/named/CLAUDE.md b/internal/sql/named/CLAUDE.md new file mode 100644 index 0000000000..05ba358ee9 --- /dev/null +++ b/internal/sql/named/CLAUDE.md @@ -0,0 +1,94 @@ +# Named Parameters Package - Claude Code Guide + +This package provides utilities for identifying sqlc's named parameter syntax. + +## Named Parameter Styles + +sqlc supports two styles of named parameters: + +### 1. Function-style: `sqlc.arg(name)`, `sqlc.narg(name)`, `sqlc.slice(name)` +Identified by `IsParamFunc()`: +```go +func IsParamFunc(node ast.Node) bool { + call, ok := node.(*ast.FuncCall) + if !ok { + return false + } + return call.Func.Schema == "sqlc" && + (call.Func.Name == "arg" || call.Func.Name == "narg" || call.Func.Name == "slice") +} +``` + +### 2. At-sign style: `@param_name` (PostgreSQL only) +Identified by `IsParamSign()`: +```go +func IsParamSign(node ast.Node) bool { + expr, ok := node.(*ast.A_Expr) + return ok && astutils.Join(expr.Name, ".") == "@" +} +``` + +## Important Distinction: sqlc @param vs MySQL @variable + +**sqlc named parameters** (`@param` in PostgreSQL queries): +- Represented as `A_Expr` with `Kind=A_Expr_Kind_OP` and `Name=["@"]` +- Detected by `IsParamSign()` +- Replaced with positional parameters (`$1`, `$2` for PostgreSQL, `?` for MySQL) + +**MySQL user variables** (`@user_id` in MySQL queries): +- Represented as `VariableExpr` +- NOT detected by `IsParamSign()` (it checks for `A_Expr`, not `VariableExpr`) +- Preserved as-is in the output SQL + +This distinction is critical: +```sql +-- PostgreSQL with sqlc @param syntax: +SELECT * FROM users WHERE id = @user_id +-- Becomes: SELECT * FROM users WHERE id = $1 + +-- MySQL with user variable: +SELECT * FROM users WHERE id != @user_id +-- Stays: SELECT * FROM users WHERE id != @user_id +``` + +## Usage in Parameter Rewriting + +The `rewrite/parameters.go` package uses these functions to find and replace named parameters: + +```go +// Find all named parameters +params := astutils.Search(root, func(node ast.Node) bool { + return named.IsParamFunc(node) || named.IsParamSign(node) +}) + +// Replace with positional parameters +astutils.Apply(root, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) || named.IsParamSign(cr.Node()) { + cr.Replace(&ast.ParamRef{Number: nextParam()}) + } + return true +}, nil) +``` + +## Converting MySQL @variable Correctly + +When converting TiDB's `VariableExpr` in `dolphin/convert.go`: + +```go +// CORRECT - preserves MySQL user variable as-is +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + return &ast.VariableExpr{ + Name: n.Name, + Location: n.OriginTextPosition(), + } +} + +// WRONG - would be treated as sqlc named parameter +func (c *cc) convertVariableExpr(n *pcast.VariableExpr) ast.Node { + return &ast.A_Expr{ + Kind: ast.A_Expr_Kind_OP, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, + Rexpr: &ast.String{Str: n.Name}, + } +} +``` diff --git a/internal/sql/rewrite/CLAUDE.md b/internal/sql/rewrite/CLAUDE.md new file mode 100644 index 0000000000..6ea885016e --- /dev/null +++ b/internal/sql/rewrite/CLAUDE.md @@ -0,0 +1,104 @@ +# SQL Rewrite Package - Claude Code Guide + +This package handles AST transformations, primarily for parameter handling. + +## Key Functions + +### NamedParameters +`NamedParameters(engine config.Engine, raw *ast.RawStmt, ...) (*ast.RawStmt, map[int]Parameter, error)` + +Finds and replaces named parameters (`sqlc.arg()`, `@param`) with positional parameters. + +The function: +1. Searches for named parameters using `named.IsParamFunc()` and `named.IsParamSign()` +2. Extracts parameter names and types +3. Replaces them with `ast.ParamRef` nodes +4. Returns a map of parameter positions to their metadata + +### Expand +`Expand(raw *ast.RawStmt, expected int) error` + +Expands `sqlc.slice()` parameters into the correct number of positional parameters. + +## How Parameter Rewriting Works + +### Step 1: Find Named Parameters +```go +refs := astutils.Search(raw.Stmt, func(node ast.Node) bool { + return named.IsParamFunc(node) || named.IsParamSign(node) +}) +``` + +### Step 2: Replace with ParamRef +```go +astutils.Apply(raw.Stmt, func(cr *astutils.Cursor) bool { + if named.IsParamFunc(cr.Node()) { + // Extract name from sqlc.arg(name) + call := cr.Node().(*ast.FuncCall) + name := extractName(call.Args) + + cr.Replace(&ast.ParamRef{ + Number: nextParam(), + Location: call.Location, + }) + } + return true +}, nil) +``` + +## Important: AST Node Requirements + +For parameter rewriting to work correctly, the AST must be walkable. This means: + +1. All node types must have cases in `astutils/walk.go` +2. All node types must have cases in `astutils/rewrite.go` +3. New container types (like `OnDuplicateKeyUpdate`) must be traversed + +### Example: OnDuplicateKeyUpdate + +MySQL's `ON DUPLICATE KEY UPDATE` clause can contain `sqlc.arg()`: +```sql +INSERT INTO t (a) VALUES (sqlc.arg(val)) +ON DUPLICATE KEY UPDATE a = sqlc.arg(new_val) +``` + +For the parameter in `ON DUPLICATE KEY UPDATE` to be found and replaced: + +1. `InsertStmt` in `rewrite.go` must traverse `OnDuplicateKeyUpdate`: +```go +case *ast.InsertStmt: + a.apply(n, "Relation", nil, n.Relation) + a.apply(n, "Cols", nil, n.Cols) + a.apply(n, "SelectStmt", nil, n.SelectStmt) + a.apply(n, "OnConflictClause", nil, n.OnConflictClause) + a.apply(n, "OnDuplicateKeyUpdate", nil, n.OnDuplicateKeyUpdate) // Critical! + a.apply(n, "ReturningList", nil, n.ReturningList) + a.apply(n, "WithClause", nil, n.WithClause) +``` + +2. `OnDuplicateKeyUpdate` must have its own case: +```go +case *ast.OnDuplicateKeyUpdate: + a.apply(n, "List", nil, n.List) +``` + +## Debugging Parameter Issues + +If a `sqlc.arg()` isn't being converted to `?`: + +1. Check that the containing node type has a case in `rewrite.go` +2. Check that the case traverses all child fields +3. Add debug logging to see if the node is being visited: +```go +case *ast.YourType: + fmt.Printf("Visiting YourType with fields: %+v\n", n) + a.apply(n, "ChildField", nil, n.ChildField) +``` + +## Parameter Output Format by Engine + +- PostgreSQL: `$1`, `$2`, `$3`, ... +- MySQL: `?`, `?`, `?`, ... +- SQLite: `?`, `?`, `?`, ... + +The format is determined by the `Dialect.Param()` method in each engine. diff --git a/internal/sqltest/docker/enabled.go b/internal/sqltest/docker/enabled.go index e17c0201b2..251ae1f332 100644 --- a/internal/sqltest/docker/enabled.go +++ b/internal/sqltest/docker/enabled.go @@ -13,5 +13,11 @@ func Installed() error { if _, err := exec.LookPath("docker"); err != nil { return fmt.Errorf("docker not found: %w", err) } + // Verify the Docker daemon is actually running and accessible. + // Without this check, tests will try Docker, fail on docker pull, + // and t.Fatal instead of falling back to native database support. + if out, err := exec.Command("docker", "info").CombinedOutput(); err != nil { + return fmt.Errorf("docker daemon not available: %w\n%s", err, out) + } return nil } diff --git a/internal/sqltest/local/mysql.go b/internal/sqltest/local/mysql.go index dedd3dfd78..05733f6e8b 100644 --- a/internal/sqltest/local/mysql.go +++ b/internal/sqltest/local/mysql.go @@ -14,6 +14,7 @@ import ( migrate "github.com/sqlc-dev/sqlc/internal/migrations" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" "github.com/sqlc-dev/sqlc/internal/sqltest/docker" + "github.com/sqlc-dev/sqlc/internal/sqltest/native" ) var mysqlSync sync.Once @@ -31,8 +32,15 @@ func MySQL(t *testing.T, migrations []string) string { t.Fatal(err) } dburi = u + } else if ierr := native.Supported(); ierr == nil { + // Fall back to native installation when Docker is not available + u, err := native.StartMySQLServer(ctx) + if err != nil { + t.Fatal(err) + } + dburi = u } else { - t.Skip("MYSQL_SERVER_URI is empty") + t.Skip("MYSQL_SERVER_URI is empty and neither Docker nor native installation is available") } } diff --git a/internal/sqltest/local/postgres.go b/internal/sqltest/local/postgres.go index feda4cf7ac..243a7133ab 100644 --- a/internal/sqltest/local/postgres.go +++ b/internal/sqltest/local/postgres.go @@ -16,6 +16,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" "github.com/sqlc-dev/sqlc/internal/sqltest/docker" + "github.com/sqlc-dev/sqlc/internal/sqltest/native" ) var flight singleflight.Group @@ -41,8 +42,15 @@ func postgreSQL(t *testing.T, migrations []string, rw bool) string { t.Fatal(err) } dburi = u + } else if ierr := native.Supported(); ierr == nil { + // Fall back to native installation when Docker is not available + u, err := native.StartPostgreSQLServer(ctx) + if err != nil { + t.Fatal(err) + } + dburi = u } else { - t.Skip("POSTGRESQL_SERVER_URI is empty") + t.Skip("POSTGRESQL_SERVER_URI is empty and neither Docker nor native installation is available") } } diff --git a/internal/sqltest/native/enabled.go b/internal/sqltest/native/enabled.go new file mode 100644 index 0000000000..e5e12ccd80 --- /dev/null +++ b/internal/sqltest/native/enabled.go @@ -0,0 +1,20 @@ +package native + +import ( + "fmt" + "os/exec" + "runtime" +) + +// Supported returns nil if native database installation is supported on this platform. +// Currently only Linux (Ubuntu/Debian) is supported. +func Supported() error { + if runtime.GOOS != "linux" { + return fmt.Errorf("native database installation only supported on linux, got %s", runtime.GOOS) + } + // Check if apt-get is available (Debian/Ubuntu) + if _, err := exec.LookPath("apt-get"); err != nil { + return fmt.Errorf("apt-get not found: %w", err) + } + return nil +} diff --git a/internal/sqltest/native/mysql.go b/internal/sqltest/native/mysql.go new file mode 100644 index 0000000000..69482bace6 --- /dev/null +++ b/internal/sqltest/native/mysql.go @@ -0,0 +1,203 @@ +package native + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "os/exec" + "time" + + _ "github.com/go-sql-driver/mysql" + "golang.org/x/sync/singleflight" +) + +var mysqlFlight singleflight.Group +var mysqlURI string + +// StartMySQLServer starts an existing MySQL installation natively (without Docker). +func StartMySQLServer(ctx context.Context) (string, error) { + if err := Supported(); err != nil { + return "", err + } + if mysqlURI != "" { + return mysqlURI, nil + } + value, err, _ := mysqlFlight.Do("mysql", func() (interface{}, error) { + uri, err := startMySQLServer(ctx) + if err != nil { + return "", err + } + mysqlURI = uri + return uri, nil + }) + if err != nil { + return "", err + } + data, ok := value.(string) + if !ok { + return "", fmt.Errorf("returned value was not a string") + } + return data, nil +} + +func startMySQLServer(ctx context.Context) (string, error) { + // Standard URI for test MySQL + uri := "root:mysecretpassword@tcp(localhost:3306)/mysql?multiStatements=true&parseTime=true" + + // Try to connect first - it might already be running + if err := waitForMySQL(ctx, uri, 500*time.Millisecond); err == nil { + slog.Info("native/mysql", "status", "already running") + return uri, nil + } + + // Also try without password (default MySQL installation) + uriNoPassword := "root@tcp(localhost:3306)/mysql?multiStatements=true&parseTime=true" + if err := waitForMySQL(ctx, uriNoPassword, 500*time.Millisecond); err == nil { + slog.Info("native/mysql", "status", "already running (no password)") + // MySQL is running without password, try to set one + if err := setMySQLPassword(ctx); err != nil { + slog.Debug("native/mysql", "set-password-error", err) + // Return without password if we can't set one + return uriNoPassword, nil + } + // Try again with password + if err := waitForMySQL(ctx, uri, 1*time.Second); err == nil { + return uri, nil + } + // If password didn't work, use no password + return uriNoPassword, nil + } + + // Try to start existing MySQL service (might be installed but not running) + if _, err := exec.LookPath("mysqld"); err == nil { + slog.Info("native/mysql", "status", "starting existing service") + if err := startMySQLService(); err != nil { + slog.Debug("native/mysql", "start-error", err) + } else { + // Wait for MySQL to be ready + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Try with password first + if err := waitForMySQL(waitCtx, uri, 15*time.Second); err == nil { + return uri, nil + } + + // Try without password + if err := waitForMySQL(waitCtx, uriNoPassword, 15*time.Second); err == nil { + if err := setMySQLPassword(ctx); err != nil { + slog.Debug("native/mysql", "set-password-error", err) + return uriNoPassword, nil + } + if err := waitForMySQL(ctx, uri, 1*time.Second); err == nil { + return uri, nil + } + return uriNoPassword, nil + } + } + } + + return "", fmt.Errorf("MySQL is not installed or could not be started") +} + +func startMySQLService() error { + // Try systemctl first + cmd := exec.Command("sudo", "systemctl", "start", "mysql") + if err := cmd.Run(); err == nil { + // Give MySQL time to fully initialize + time.Sleep(2 * time.Second) + return nil + } + + // Try mysqld + cmd = exec.Command("sudo", "systemctl", "start", "mysqld") + if err := cmd.Run(); err == nil { + time.Sleep(2 * time.Second) + return nil + } + + // Try service command + cmd = exec.Command("sudo", "service", "mysql", "start") + if err := cmd.Run(); err == nil { + time.Sleep(2 * time.Second) + return nil + } + + cmd = exec.Command("sudo", "service", "mysqld", "start") + if err := cmd.Run(); err == nil { + time.Sleep(2 * time.Second) + return nil + } + + return fmt.Errorf("could not start MySQL service") +} + +func setMySQLPassword(ctx context.Context) error { + // Connect without password + db, err := sql.Open("mysql", "root@tcp(localhost:3306)/mysql") + if err != nil { + return err + } + defer db.Close() + + // Set root password using mysql_native_password for broader compatibility + _, err = db.ExecContext(ctx, "ALTER USER 'root'@'localhost' IDENTIFIED WITH mysql_native_password BY 'mysecretpassword';") + if err != nil { + // Try without specifying auth plugin + _, err = db.ExecContext(ctx, "ALTER USER 'root'@'localhost' IDENTIFIED BY 'mysecretpassword';") + if err != nil { + // Try older MySQL syntax + _, err = db.ExecContext(ctx, "SET PASSWORD FOR 'root'@'localhost' = PASSWORD('mysecretpassword');") + if err != nil { + return fmt.Errorf("could not set MySQL password: %w", err) + } + } + } + + // Flush privileges + _, _ = db.ExecContext(ctx, "FLUSH PRIVILEGES;") + + return nil +} + +func waitForMySQL(ctx context.Context, uri string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + // Make an immediate first attempt before waiting for the ticker + if err := tryMySQLConnection(ctx, uri); err == nil { + return nil + } + + var lastErr error + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w (last error: %v)", ctx.Err(), lastErr) + case <-ticker.C: + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for MySQL (last error: %v)", lastErr) + } + if err := tryMySQLConnection(ctx, uri); err != nil { + lastErr = err + continue + } + return nil + } + } +} + +func tryMySQLConnection(ctx context.Context, uri string) error { + db, err := sql.Open("mysql", uri) + if err != nil { + slog.Debug("native/mysql", "open-attempt", err) + return err + } + defer db.Close() + // Use a short timeout for ping to avoid hanging + pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + return db.PingContext(pingCtx) +} diff --git a/internal/sqltest/native/postgres.go b/internal/sqltest/native/postgres.go new file mode 100644 index 0000000000..f805a40a1c --- /dev/null +++ b/internal/sqltest/native/postgres.go @@ -0,0 +1,221 @@ +package native + +import ( + "context" + "fmt" + "log/slog" + "os/exec" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "golang.org/x/sync/singleflight" +) + +var postgresFlight singleflight.Group +var postgresURI string + +// StartPostgreSQLServer starts an existing PostgreSQL installation natively (without Docker). +func StartPostgreSQLServer(ctx context.Context) (string, error) { + if err := Supported(); err != nil { + return "", err + } + if postgresURI != "" { + return postgresURI, nil + } + value, err, _ := postgresFlight.Do("postgresql", func() (interface{}, error) { + uri, err := startPostgreSQLServer(ctx) + if err != nil { + return "", err + } + postgresURI = uri + return uri, nil + }) + if err != nil { + return "", err + } + data, ok := value.(string) + if !ok { + return "", fmt.Errorf("returned value was not a string") + } + return data, nil +} + +func startPostgreSQLServer(ctx context.Context) (string, error) { + // Standard URI for test PostgreSQL + uri := "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" + + // Try to connect first - it might already be running + if err := waitForPostgres(ctx, uri, 500*time.Millisecond); err == nil { + slog.Info("native/postgres", "status", "already running") + return uri, nil + } + + // Check if PostgreSQL is installed + if _, err := exec.LookPath("psql"); err != nil { + return "", fmt.Errorf("PostgreSQL is not installed (psql not found)") + } + + // Start PostgreSQL service + slog.Info("native/postgres", "status", "starting service") + + // Try systemctl first, fall back to pg_ctlcluster + if err := startPostgresService(); err != nil { + return "", fmt.Errorf("failed to start PostgreSQL: %w", err) + } + + // Configure PostgreSQL for password authentication + if err := configurePostgres(); err != nil { + return "", fmt.Errorf("failed to configure PostgreSQL: %w", err) + } + + // Wait for PostgreSQL to be ready + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForPostgres(waitCtx, uri, 30*time.Second); err != nil { + return "", fmt.Errorf("timeout waiting for PostgreSQL: %w", err) + } + + return uri, nil +} + +func startPostgresService() error { + // Try systemctl first + cmd := exec.Command("sudo", "systemctl", "start", "postgresql") + if err := cmd.Run(); err == nil { + return nil + } + + // Try service command + cmd = exec.Command("sudo", "service", "postgresql", "start") + if err := cmd.Run(); err == nil { + return nil + } + + // Try pg_ctlcluster (Debian/Ubuntu specific) + // Find the installed PostgreSQL version + output, err := exec.Command("ls", "/etc/postgresql/").CombinedOutput() + if err != nil { + return fmt.Errorf("could not find PostgreSQL version: %w", err) + } + + versions := strings.Fields(string(output)) + if len(versions) == 0 { + return fmt.Errorf("no PostgreSQL version found in /etc/postgresql/") + } + + version := versions[0] + cmd = exec.Command("sudo", "pg_ctlcluster", version, "main", "start") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("pg_ctlcluster start failed: %w\n%s", err, output) + } + + return nil +} + +func configurePostgres() error { + // Set password for postgres user using sudo -u postgres + cmd := exec.Command("sudo", "-u", "postgres", "psql", "-c", "ALTER USER postgres PASSWORD 'postgres';") + if output, err := cmd.CombinedOutput(); err != nil { + // This might fail if password is already set, which is fine + slog.Debug("native/postgres", "set-password", string(output)) + } + + // Update pg_hba.conf to allow password authentication + // First, find the pg_hba.conf file + output, err := exec.Command("sudo", "-u", "postgres", "psql", "-t", "-c", "SHOW hba_file;").CombinedOutput() + if err != nil { + return fmt.Errorf("could not find hba_file: %w", err) + } + + hbaFile := strings.TrimSpace(string(output)) + if hbaFile == "" { + return fmt.Errorf("empty hba_file path") + } + + // Check if we need to update pg_hba.conf + catOutput, err := exec.Command("sudo", "cat", hbaFile).CombinedOutput() + if err != nil { + return fmt.Errorf("could not read %s: %w", hbaFile, err) + } + + // If md5 or scram-sha-256 auth is not configured for local connections, add it + content := string(catOutput) + if !strings.Contains(content, "host all all 127.0.0.1/32 md5") && + !strings.Contains(content, "host all all 127.0.0.1/32 scram-sha-256") { + + // Prepend a rule for localhost password authentication + newRule := "host all all 127.0.0.1/32 md5\n" + + // Use sed to add the rule at the beginning (after comments) + cmd := exec.Command("sudo", "bash", "-c", + fmt.Sprintf(`echo '%s' | cat - %s > /tmp/pg_hba.conf.new && sudo mv /tmp/pg_hba.conf.new %s`, + newRule, hbaFile, hbaFile)) + if output, err := cmd.CombinedOutput(); err != nil { + slog.Debug("native/postgres", "update-hba-error", string(output)) + } + + // Reload PostgreSQL to apply changes + if err := reloadPostgres(); err != nil { + slog.Debug("native/postgres", "reload-error", err) + } + } + + return nil +} + +func reloadPostgres() error { + // Try systemctl reload + cmd := exec.Command("sudo", "systemctl", "reload", "postgresql") + if err := cmd.Run(); err == nil { + return nil + } + + // Try service reload + cmd = exec.Command("sudo", "service", "postgresql", "reload") + if err := cmd.Run(); err == nil { + return nil + } + + // Try pg_ctlcluster reload + output, _ := exec.Command("ls", "/etc/postgresql/").CombinedOutput() + versions := strings.Fields(string(output)) + if len(versions) > 0 { + cmd = exec.Command("sudo", "pg_ctlcluster", versions[0], "main", "reload") + return cmd.Run() + } + + return fmt.Errorf("could not reload PostgreSQL") +} + +func waitForPostgres(ctx context.Context, uri string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + var lastErr error + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w (last error: %v)", ctx.Err(), lastErr) + case <-ticker.C: + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for PostgreSQL (last error: %v)", lastErr) + } + conn, err := pgx.Connect(ctx, uri) + if err != nil { + lastErr = err + slog.Debug("native/postgres", "connect-attempt", err) + continue + } + if err := conn.Ping(ctx); err != nil { + lastErr = err + conn.Close(ctx) + continue + } + conn.Close(ctx) + return nil + } + } +} diff --git a/internal/sqltest/sqlite.go b/internal/sqltest/sqlite.go index 0e5161967d..3ad04bb78d 100644 --- a/internal/sqltest/sqlite.go +++ b/internal/sqltest/sqlite.go @@ -6,6 +6,9 @@ import ( "path/filepath" "testing" + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" ) @@ -26,7 +29,7 @@ func CreateSQLiteDatabase(t *testing.T, path string, migrations []string) (*sql. t.Helper() t.Logf("open %s\n", path) - sdb, err := sql.Open("sqlite", path) + sdb, err := sql.Open("sqlite3", path) if err != nil { t.Fatal(err) } diff --git a/internal/sqltest/sqlite_modernc.go b/internal/sqltest/sqlite_modernc.go deleted file mode 100644 index 708ea40e49..0000000000 --- a/internal/sqltest/sqlite_modernc.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !wasm - -package sqltest - -import ( - _ "modernc.org/sqlite" -) diff --git a/internal/x/expander/expander.go b/internal/x/expander/expander.go new file mode 100644 index 0000000000..af0cab26e8 --- /dev/null +++ b/internal/x/expander/expander.go @@ -0,0 +1,507 @@ +package expander + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + "github.com/sqlc-dev/sqlc/internal/sql/format" +) + +// Parser is an interface for SQL parsers that can parse SQL into AST statements. +type Parser interface { + Parse(r io.Reader) ([]ast.Statement, error) +} + +// ColumnGetter retrieves column names for a query by preparing it against a database. +type ColumnGetter interface { + GetColumnNames(ctx context.Context, query string) ([]string, error) +} + +// Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names +// obtained from preparing the query against a database. +type Expander struct { + colGetter ColumnGetter + parser Parser + dialect format.Dialect +} + +// New creates a new Expander with the given column getter, parser, and dialect. +func New(colGetter ColumnGetter, parser Parser, dialect format.Dialect) *Expander { + return &Expander{ + colGetter: colGetter, + parser: parser, + dialect: dialect, + } +} + +// Expand takes a SQL query, and if it contains * in SELECT or RETURNING clause, +// expands it to use explicit column names. Returns the expanded query string. +func (e *Expander) Expand(ctx context.Context, query string) (string, error) { + // Parse the query + stmts, err := e.parser.Parse(strings.NewReader(query)) + if err != nil { + return "", fmt.Errorf("failed to parse query: %w", err) + } + + if len(stmts) == 0 { + return query, nil + } + + stmt := stmts[0].Raw.Stmt + + // Check if there's any star in the statement (including CTEs, subqueries, etc.) + if !hasStarAnywhere(stmt) { + return query, nil + } + + // Expand all stars in the statement recursively + if err := e.expandNode(ctx, stmt); err != nil { + return "", err + } + + // Format the modified AST back to SQL + expanded := ast.Format(stmts[0].Raw, e.dialect) + + return expanded, nil +} + +// expandNode recursively expands * in all parts of the statement +func (e *Expander) expandNode(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.SelectStmt: + return e.expandSelectStmt(ctx, n) + case *ast.InsertStmt: + return e.expandInsertStmt(ctx, n) + case *ast.UpdateStmt: + return e.expandUpdateStmt(ctx, n) + case *ast.DeleteStmt: + return e.expandDeleteStmt(ctx, n) + case *ast.CommonTableExpr: + return e.expandNode(ctx, n.Ctequery) + } + return nil +} + +// expandSelectStmt expands * in a SELECT statement including CTEs and subqueries +func (e *Expander) expandSelectStmt(ctx context.Context, stmt *ast.SelectStmt) error { + // First expand any CTEs - must be done in order since later CTEs may depend on earlier ones + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cteNode := range stmt.WithClause.Ctes.Items { + cte, ok := cteNode.(*ast.CommonTableExpr) + if !ok { + continue + } + cteSelect, ok := cte.Ctequery.(*ast.SelectStmt) + if !ok { + continue + } + if hasStarInList(cteSelect.TargetList) { + // Get column names for this CTE + columns, err := e.getCTEColumnNames(ctx, stmt, cte) + if err != nil { + return err + } + cteSelect.TargetList = rewriteTargetList(cteSelect.TargetList, columns) + } + // Recursively handle nested CTEs/subqueries in this CTE + if err := e.expandSelectStmtInner(ctx, cteSelect); err != nil { + return err + } + } + } + + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + + // Expand the target list if it has stars + if hasStarInList(stmt.TargetList) { + // Format the current state to get columns + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.TargetList = rewriteTargetList(stmt.TargetList, columns) + } + + return nil +} + +// expandSelectStmtInner expands nested structures without re-processing the target list +func (e *Expander) expandSelectStmtInner(ctx context.Context, stmt *ast.SelectStmt) error { + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + return nil +} + +// getCTEColumnNames gets the column names for a CTE by constructing a query with proper context +func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *ast.SelectStmt, targetCTE *ast.CommonTableExpr) ([]string, error) { + // Build a temporary query: WITH SELECT * FROM + var ctesToInclude []ast.Node + for _, cteNode := range stmt.WithClause.Ctes.Items { + ctesToInclude = append(ctesToInclude, cteNode) + cte, ok := cteNode.(*ast.CommonTableExpr) + if ok && cte.Ctename != nil && targetCTE.Ctename != nil && *cte.Ctename == *targetCTE.Ctename { + break + } + } + + // Create a SELECT * FROM with the relevant CTEs + cteName := "" + if targetCTE.Ctename != nil { + cteName = *targetCTE.Ctename + } + + tempStmt := &ast.SelectStmt{ + WithClause: &ast.WithClause{ + Ctes: &ast.List{Items: ctesToInclude}, + Recursive: stmt.WithClause.Recursive, + }, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{&ast.A_Star{}}, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: &cteName, + }, + }, + }, + } + + tempRaw := &ast.RawStmt{Stmt: tempStmt} + tempQuery := ast.Format(tempRaw, e.dialect) + + return e.getColumnNames(ctx, tempQuery) +} + +// expandInsertStmt expands * in an INSERT statement's RETURNING clause +func (e *Expander) expandInsertStmt(ctx context.Context, stmt *ast.InsertStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand the SELECT part if present + if stmt.SelectStmt != nil { + if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { + return err + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandUpdateStmt expands * in an UPDATE statement's RETURNING clause +func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *ast.UpdateStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandDeleteStmt expands * in a DELETE statement's RETURNING clause +func (e *Expander) expandDeleteStmt(ctx context.Context, stmt *ast.DeleteStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandFromClause expands * in subqueries within FROM clause +func (e *Expander) expandFromClause(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.RangeSubselect: + if n.Subquery != nil { + return e.expandNode(ctx, n.Subquery) + } + case *ast.JoinExpr: + if err := e.expandFromClause(ctx, n.Larg); err != nil { + return err + } + if err := e.expandFromClause(ctx, n.Rarg); err != nil { + return err + } + } + return nil +} + +// hasStarAnywhere checks if there's a * anywhere in the statement using astutils.Search +func hasStarAnywhere(node ast.Node) bool { + if node == nil { + return false + } + // Use astutils.Search to find any A_Star node in the AST + stars := astutils.Search(node, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// hasStarInList checks if a target list contains a * expression using astutils.Search +func hasStarInList(targets *ast.List) bool { + if targets == nil { + return false + } + // Use astutils.Search to find any A_Star node in the target list + stars := astutils.Search(targets, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// getColumnNames prepares the query and returns the column names from the result +func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { + return e.colGetter.GetColumnNames(ctx, query) +} + +// countStarsInList counts the number of * expressions in a target list +func countStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + continue + } + if resTarget.Val == nil { + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + continue + } + if colRef.Fields == nil { + continue + } + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + count++ + break + } + } + } + return count +} + +// countNonStarsInList counts the number of non-* expressions in a target list +func countNonStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + count++ + continue + } + if resTarget.Val == nil { + count++ + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + count++ + continue + } + if colRef.Fields == nil { + count++ + continue + } + isStar := false + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + } + if !isStar { + count++ + } + } + return count +} + +// rewriteTargetList replaces * in a target list with explicit column references +func rewriteTargetList(targets *ast.List, columns []string) *ast.List { + if targets == nil { + return nil + } + + starCount := countStarsInList(targets) + nonStarCount := countNonStarsInList(targets) + + // Calculate how many columns each * expands to + // Total columns = (columns per star * number of stars) + non-star columns + // So: columns per star = (total - non-star) / stars + columnsPerStar := 0 + if starCount > 0 { + columnsPerStar = (len(columns) - nonStarCount) / starCount + } + + newItems := make([]ast.Node, 0, len(columns)) + colIndex := 0 + + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if resTarget.Val == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if colRef.Fields == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Check if this is a * (with or without table qualifier) + // and extract any table prefix + isStar := false + var tablePrefix []string + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + // Collect prefix parts (schema, table name) + if str, ok := field.(*ast.String); ok { + tablePrefix = append(tablePrefix, str.Str) + } + } + + if !isStar { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Replace * with explicit column references + for i := 0; i < columnsPerStar && colIndex < len(columns); i++ { + newItems = append(newItems, makeColumnTargetWithPrefix(columns[colIndex], tablePrefix)) + colIndex++ + } + } + + return &ast.List{Items: newItems} +} + +// makeColumnTargetWithPrefix creates a ResTarget node for a column reference with optional table prefix +func makeColumnTargetWithPrefix(colName string, prefix []string) ast.Node { + fields := make([]ast.Node, 0, len(prefix)+1) + + // Add prefix parts (schema, table name) + for _, p := range prefix { + fields = append(fields, &ast.String{Str: p}) + } + + // Add column name + fields = append(fields, &ast.String{Str: colName}) + + return &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: fields}, + }, + } +} diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go new file mode 100644 index 0000000000..52d62c6b5e --- /dev/null +++ b/internal/x/expander/expander_test.go @@ -0,0 +1,456 @@ +package expander + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "os" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" + "github.com/sqlc-dev/sqlc/internal/sqltest/docker" + "github.com/sqlc-dev/sqlc/internal/sqltest/native" +) + +// PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool. +type PostgreSQLColumnGetter struct { + pool *pgxpool.Pool +} + +func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := g.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} + +// MySQLColumnGetter implements ColumnGetter for MySQL using the forked driver's StmtMetadata. +type MySQLColumnGetter struct { + db *sql.DB +} + +func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := g.db.Conn(ctx) + if err != nil { + return nil, err + } + defer conn.Close() + + var columns []string + err = conn.Raw(func(driverConn any) error { + preparer, ok := driverConn.(driver.ConnPrepareContext) + if !ok { + return fmt.Errorf("driver connection does not support PrepareContext") + } + + stmt, err := preparer.PrepareContext(ctx, query) + if err != nil { + return err + } + defer stmt.Close() + + meta, ok := stmt.(mysql.StmtMetadata) + if !ok { + return fmt.Errorf("prepared statement does not implement StmtMetadata") + } + + for _, col := range meta.ColumnMetadata() { + columns = append(columns, col.Name) + } + return nil + }) + if err != nil { + return nil, err + } + + return columns, nil +} + +// SQLiteColumnGetter implements ColumnGetter for SQLite using the native ncruces/go-sqlite3 API. +type SQLiteColumnGetter struct { + conn *sqlite3.Conn +} + +func (g *SQLiteColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + // Prepare the statement - this gives us column metadata without executing + stmt, _, err := g.conn.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + // Get column names from the prepared statement + count := stmt.ColumnCount() + columns := make([]string, count) + for i := 0; i < count; i++ { + columns[i] = stmt.ColumnName(i) + } + + return columns, nil +} + +func TestExpandPostgreSQL(t *testing.T) { + ctx := context.Background() + + uri := os.Getenv("POSTGRESQL_SERVER_URI") + if uri == "" { + if err := docker.Installed(); err == nil { + u, err := docker.StartPostgreSQLServer(ctx) + if err != nil { + t.Fatal(err) + } + uri = u + } else if err := native.Supported(); err == nil { + u, err := native.StartPostgreSQLServer(ctx) + if err != nil { + t.Fatal(err) + } + uri = u + } else { + t.Skip("POSTGRESQL_SERVER_URI is empty and neither Docker nor native installation is available") + } + } + + pool, err := pgxpool.New(ctx, uri) + if err != nil { + t.Skipf("could not connect to database: %v", err) + } + defer pool.Close() + + // Create a test table + _, err = pool.Exec(ctx, ` + DROP TABLE IF EXISTS authors; + CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT + ); + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer pool.Exec(ctx, "DROP TABLE IF EXISTS authors") + + // Create the parser which also implements format.Dialect + parser := postgresql.NewParser() + + // Create the expander + colGetter := &PostgreSQLColumnGetter{pool: pool} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id, id, name, bio, name FROM authors;", + }, + { + name: "insert returning star", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, name, bio;", + }, + { + name: "insert returning mixed", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, id, name, bio;", + }, + { + name: "update returning star", + query: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING *", + expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id, name, bio;", + }, + { + name: "delete returning star", + query: "DELETE FROM authors WHERE id = 1 RETURNING *", + expected: "DELETE FROM authors WHERE id = 1 RETURNING id, name, bio;", + }, + { + name: "cte with select star", + query: "WITH a AS (SELECT * FROM authors) SELECT * FROM a", + expected: "WITH a AS (SELECT id, name, bio FROM authors) SELECT id, name, bio FROM a;", + }, + { + name: "multiple ctes with dependency", + query: "WITH a AS (SELECT * FROM authors), b AS (SELECT * FROM a) SELECT * FROM b", + expected: "WITH a AS (SELECT id, name, bio FROM authors), b AS (SELECT id, name, bio FROM a) SELECT id, name, bio FROM b;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} + +func TestExpandMySQL(t *testing.T) { + ctx := context.Background() + + source := os.Getenv("MYSQL_SERVER_URI") + if source == "" { + if err := docker.Installed(); err == nil { + u, err := docker.StartMySQLServer(ctx) + if err != nil { + t.Fatal(err) + } + source = u + } else if err := native.Supported(); err == nil { + u, err := native.StartMySQLServer(ctx) + if err != nil { + t.Fatal(err) + } + source = u + } else { + t.Skip("MYSQL_SERVER_URI is empty and neither Docker nor native installation is available") + } + } + + db, err := sql.Open("mysql", source) + if err != nil { + t.Skipf("could not connect to MySQL: %v", err) + } + defer db.Close() + + // Verify connection + if err := db.Ping(); err != nil { + t.Skipf("could not ping MySQL: %v", err) + } + + // Create a test table + _, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS authors`) + if err != nil { + t.Fatalf("failed to drop test table: %v", err) + } + _, err = db.ExecContext(ctx, ` + CREATE TABLE authors ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer db.ExecContext(ctx, "DROP TABLE IF EXISTS authors") + + // Create the parser which also implements format.Dialect + parser := dolphin.NewParser() + + // Create the expander + colGetter := &MySQLColumnGetter{db: db} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "double table qualified star", + query: "SELECT authors.*, authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio, authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns table qualified", + query: "SELECT id, authors.*, name FROM authors", + expected: "SELECT id, authors.id, authors.name, authors.bio, name FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} + +func TestExpandSQLite(t *testing.T) { + ctx := context.Background() + + // Create an in-memory SQLite database using native API + conn, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatalf("could not open SQLite: %v", err) + } + defer conn.Close() + + // Create a test table + err = conn.Exec(` + CREATE TABLE authors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + + // Create the parser which also implements format.Dialect + parser := sqlite.NewParser() + + // Create the expander using native SQLite column getter + colGetter := &SQLiteColumnGetter{conn: conn} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id, id, name, bio, name FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +}