diff --git a/.github/actions/setup-jfrog/action.yml b/.github/actions/setup-jfrog/action.yml new file mode 100644 index 000000000..97ae146ba --- /dev/null +++ b/.github/actions/setup-jfrog/action.yml @@ -0,0 +1,32 @@ +name: Setup JFrog OIDC +description: Obtain a JFrog access token via GitHub OIDC and configure pip to use JFrog PyPI proxy + +runs: + using: composite + steps: + - name: Get JFrog OIDC token + shell: bash + run: | + set -euo pipefail + ID_TOKEN=$(curl -sLS \ + -H "User-Agent: actions/oidc-client" \ + -H "Authorization: Bearer $ACTIONS_ID_TOKEN_REQUEST_TOKEN" \ + "${ACTIONS_ID_TOKEN_REQUEST_URL}&audience=jfrog-github" | jq .value | tr -d '"') + echo "::add-mask::${ID_TOKEN}" + ACCESS_TOKEN=$(curl -sLS -XPOST -H "Content-Type: application/json" \ + "https://databricks.jfrog.io/access/api/v1/oidc/token" \ + -d "{\"grant_type\": \"urn:ietf:params:oauth:grant-type:token-exchange\", \"subject_token_type\":\"urn:ietf:params:oauth:token-type:id_token\", \"subject_token\": \"${ID_TOKEN}\", \"provider_name\": \"github-actions\"}" | jq .access_token | tr -d '"') + echo "::add-mask::${ACCESS_TOKEN}" + if [ -z "$ACCESS_TOKEN" ] || [ "$ACCESS_TOKEN" = "null" ]; then + echo "FAIL: Could not extract JFrog access token" + exit 1 + fi + echo "JFROG_ACCESS_TOKEN=${ACCESS_TOKEN}" >> "$GITHUB_ENV" + echo "JFrog OIDC token obtained successfully" + + - name: Configure pip + shell: bash + run: | + set -euo pipefail + echo "PIP_INDEX_URL=https://gha-service-account:${JFROG_ACCESS_TOKEN}@databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple" >> "$GITHUB_ENV" + echo "pip configured to use JFrog registry" diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml new file mode 100644 index 000000000..f7e15b1c0 --- /dev/null +++ b/.github/actions/setup-poetry/action.yml @@ -0,0 +1,63 @@ +name: Setup Poetry with JFrog +description: Install Poetry, configure JFrog as primary PyPI source, and install project dependencies + +inputs: + python-version: + description: Python version to set up + required: true + install-args: + description: Extra arguments for poetry install (e.g. --all-extras) + required: false + default: "" + cache-path: + description: Path to the virtualenv for caching (e.g. 
.venv or .venv-pyarrow) + required: false + default: ".venv" + cache-suffix: + description: Extra suffix for the cache key to avoid collisions across job variants + required: false + default: "" + +runs: + using: composite + steps: + - name: Setup JFrog + uses: ./.github/actions/setup-jfrog + + - name: Set up python ${{ inputs.python-version }} + id: setup-python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Install Poetry + shell: bash + run: | + pip install poetry==2.2.1 + poetry config virtualenvs.create true + poetry config virtualenvs.in-project true + poetry config installer.parallel true + + - name: Configure Poetry JFrog source + shell: bash + run: | + poetry config repositories.jfrog https://databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple + poetry config http-basic.jfrog gha-service-account "${JFROG_ACCESS_TOKEN}" + poetry source add --priority=primary jfrog https://databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple + poetry lock + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + with: + path: ${{ inputs.cache-path }} + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ inputs.cache-suffix }}${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + shell: bash + run: poetry install --no-interaction --no-root + + - name: Install library + shell: bash + run: poetry install --no-interaction ${{ inputs.install-args }} diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index 5c961757e..9f578ec9f 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -1,13 +1,21 @@ -name: Code Coverage +name: E2E Tests and Code Coverage permissions: contents: read + id-token: write -on: [pull_request, workflow_dispatch] +on: + push: + branches: + - main + pull_request: + workflow_dispatch: jobs: test-with-coverage: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest environment: azure-prod env: DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} @@ -16,97 +24,35 @@ jobs: DATABRICKS_CATALOG: peco DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 with: fetch-depth: 0 - ref: ${{ github.event.pull_request.head.ref || github.ref_name }} - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install system dependencies ----- - #---------------------------------------------- - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y libkrb5-dev - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - 
#---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install library - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # run parallel tests with coverage - #---------------------------------------------- - - name: Run parallel tests with coverage + python-version: "3.10" + install-args: "--all-extras" + - name: Run all tests with coverage continue-on-error: false run: | poetry run pytest tests/unit tests/e2e \ - -m "not serial" \ - -n auto \ + -n 4 \ + --dist=loadgroup \ --cov=src \ --cov-report=xml \ --cov-report=term \ -v - - #---------------------------------------------- - # run telemetry tests with coverage (isolated) - #---------------------------------------------- - - name: Run telemetry tests with coverage (isolated) - continue-on-error: false - run: | - # Run test_concurrent_telemetry.py separately for isolation - poetry run pytest tests/e2e/test_concurrent_telemetry.py \ - --cov=src \ - --cov-append \ - --cov-report=xml \ - --cov-report=term \ - -v - - #---------------------------------------------- - # check for coverage override - #---------------------------------------------- - name: Check for coverage override id: override + env: + PR_BODY: ${{ github.event.pull_request.body }} run: | - OVERRIDE_COMMENT=$(echo "${{ github.event.pull_request.body }}" | grep -E "SKIP_COVERAGE_CHECK\s*=" || echo "") + OVERRIDE_COMMENT=$(echo "$PR_BODY" | grep -E "SKIP_COVERAGE_CHECK\s*=" || echo "") if [ -n "$OVERRIDE_COMMENT" ]; then echo "override=true" >> $GITHUB_OUTPUT REASON=$(echo "$OVERRIDE_COMMENT" | sed -E 's/.*SKIP_COVERAGE_CHECK\s*=\s*(.+)/\1/') @@ -116,9 +62,6 @@ jobs: echo "override=false" >> $GITHUB_OUTPUT echo "No coverage override found" fi - #---------------------------------------------- - # check coverage percentage - #---------------------------------------------- - name: Check coverage percentage if: steps.override.outputs.override == 'false' run: | @@ -127,20 +70,14 @@ jobs: echo "ERROR: Coverage file not found at $COVERAGE_FILE" exit 1 fi - - # Install xmllint if not available if ! 
command -v xmllint &> /dev/null; then sudo apt-get update && sudo apt-get install -y libxml2-utils fi - COVERED=$(xmllint --xpath "string(//coverage/@lines-covered)" "$COVERAGE_FILE") TOTAL=$(xmllint --xpath "string(//coverage/@lines-valid)" "$COVERAGE_FILE") PERCENTAGE=$(python3 -c "covered=${COVERED}; total=${TOTAL}; print(round((covered/total)*100, 2))") - echo "Branch Coverage: $PERCENTAGE%" echo "Required Coverage: 85%" - - # Use Python to compare the coverage with 85 python3 -c "import sys; sys.exit(0 if float('$PERCENTAGE') >= 85 else 1)" if [ $? -eq 1 ]; then echo "ERROR: Coverage is $PERCENTAGE%, which is less than the required 85%" @@ -148,16 +85,14 @@ jobs: else echo "SUCCESS: Coverage is $PERCENTAGE%, which meets the required 85%" fi - - #---------------------------------------------- - # coverage enforcement summary - #---------------------------------------------- - name: Coverage enforcement summary + env: + OVERRIDE: ${{ steps.override.outputs.override }} + REASON: ${{ steps.override.outputs.reason }} run: | - if [ "${{ steps.override.outputs.override }}" == "true" ]; then - echo "⚠️ Coverage checks bypassed: ${{ steps.override.outputs.reason }}" + if [ "$OVERRIDE" == "true" ]; then + echo "Coverage checks bypassed: $REASON" echo "Please ensure this override is justified and temporary" else - echo "✅ Coverage checks enforced - minimum 85% required" + echo "Coverage checks enforced - minimum 85% required" fi - diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index cc3952920..4071a6e51 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -2,94 +2,58 @@ name: Code Quality Checks on: [pull_request] +permissions: + contents: read + id-token: write + jobs: run-unit-tests: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] dependency-version: ["default", "min"] - # Optimize matrix - test min/max on subset of Python versions exclude: - python-version: "3.12" dependency-version: "min" - python-version: "3.13" dependency-version: "min" - + name: "Unit Tests (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - + steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v5 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - 
#---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # override with custom dependency versions - #---------------------------------------------- + cache-suffix: "${{ matrix.dependency-version }}-" - name: Install Python tools for custom versions if: matrix.dependency-version != 'default' run: poetry run pip install toml packaging - - name: Generate requirements file if: matrix.dependency-version != 'default' run: | poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}.txt echo "Generated requirements for ${{ matrix.dependency-version }} versions:" cat requirements-${{ matrix.dependency-version }}.txt - - name: Override with custom dependency versions if: matrix.dependency-version != 'default' run: poetry run pip install -r requirements-${{ matrix.dependency-version }}.txt - - #---------------------------------------------- - # run test suite - #---------------------------------------------- - name: Show installed versions run: | echo "=== Dependency Version: ${{ matrix.dependency-version }} ===" poetry run pip list - - name: Run tests run: poetry run python -m pytest tests/unit + run-unit-tests-with-arrow: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] @@ -99,186 +63,73 @@ jobs: dependency-version: "min" - python-version: "3.13" dependency-version: "min" - - name: "Unit Tests + PyArrow (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v2 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv-pyarrow - key: venv-pyarrow-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - 
#---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install library - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # override with custom dependency versions - #---------------------------------------------- - - name: Install Python tools for custom versions - if: matrix.dependency-version != 'default' - run: poetry run pip install toml packaging - - - name: Generate requirements file with pyarrow - if: matrix.dependency-version != 'default' - run: | - poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}-arrow.txt - echo "Generated requirements for ${{ matrix.dependency-version }} versions with PyArrow:" - cat requirements-${{ matrix.dependency-version }}-arrow.txt + name: "Unit Tests + PyArrow (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - - name: Override with custom dependency versions - if: matrix.dependency-version != 'default' - run: poetry run pip install -r requirements-${{ matrix.dependency-version }}-arrow.txt - #---------------------------------------------- - # run test suite - #---------------------------------------------- - - name: Show installed versions - run: | - echo "=== Dependency Version: ${{ matrix.dependency-version }} with PyArrow ===" - poetry run pip list + steps: + - name: Check out repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev + - name: Setup Poetry + uses: ./.github/actions/setup-poetry + with: + python-version: ${{ matrix.python-version }} + install-args: "--all-extras" + cache-suffix: "pyarrow-${{ matrix.dependency-version }}-" + - name: Install Python tools for custom versions + if: matrix.dependency-version != 'default' + run: poetry run pip install toml packaging + - name: Generate requirements file with pyarrow + if: matrix.dependency-version != 'default' + run: | + poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}-arrow.txt + echo "Generated requirements for ${{ matrix.dependency-version }} versions with PyArrow:" + cat requirements-${{ matrix.dependency-version }}-arrow.txt + - name: Override with custom dependency versions + if: matrix.dependency-version != 'default' + run: poetry run pip install -r requirements-${{ matrix.dependency-version }}-arrow.txt + - name: Show installed versions + run: | + echo "=== Dependency Version: ${{ matrix.dependency-version }} with PyArrow ===" + poetry run pip list + - name: Run tests + run: poetry run python -m pytest tests/unit - - name: Run tests - run: poetry run python -m pytest tests/unit check-linting: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v5 + uses: 
actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # black the code - #---------------------------------------------- - name: Black run: poetry run black --check src check-types: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v5 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # mypy the code - 
#---------------------------------------------- - name: Mypy run: | - mkdir .mypy_cache # Workaround for bad error message "error: --install-types failed (no mypy cache directory)"; see https://github.com/python/mypy/issues/10768#issuecomment-2178450153 + mkdir .mypy_cache poetry run mypy --install-types --non-interactive src diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml deleted file mode 100644 index d60b7f5a9..000000000 --- a/.github/workflows/daily-telemetry-e2e.yml +++ /dev/null @@ -1,92 +0,0 @@ -name: Daily Telemetry E2E Tests - -on: - schedule: - - cron: '0 0 * * 0' # Run every Sunday at midnight UTC - - workflow_dispatch: # Allow manual triggering - inputs: - test_pattern: - description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' - required: false - default: 'tests/e2e/test_telemetry_e2e.py' - type: string - -jobs: - telemetry-e2e-tests: - runs-on: ubuntu-latest - environment: azure-prod - - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v4 - - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - run: poetry install --no-interaction --all-extras - - #---------------------------------------------- - # run telemetry E2E tests - #---------------------------------------------- - - name: Run telemetry E2E tests - run: | - TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" - echo "Running tests: $TEST_PATTERN" - poetry run python -m pytest $TEST_PATTERN -v -s - - #---------------------------------------------- - # upload test results on failure - #---------------------------------------------- - - name: Upload test results on failure - if: failure() - uses: actions/upload-artifact@v4 - with: - name: telemetry-test-results - path: | - .pytest_cache/ - tests-unsafe.log - retention-days: 7 - diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml index 050665ec9..fdcf1b3bb 100644 --- a/.github/workflows/dco-check.yml +++ b/.github/workflows/dco-check.yml @@ -1,27 +1,74 @@ name: DCO Check -on: 
[pull_request] +on: + pull_request: + types: [opened, synchronize, reopened] + branches: [main] + +permissions: + contents: read jobs: - check: + dco-check: runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + name: Check DCO Sign-off steps: - - name: Check for DCO - id: dco-check - uses: tisonkun/actions-dco@v1.1 - - name: Comment about DCO status - uses: actions/github-script@v7 - if: ${{ failure() }} + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 with: - script: | - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: `Thanks for your contribution! To satisfy the DCO policy in our \ - [contributing guide](https://github.com/databricks/databricks-sql-python/blob/main/CONTRIBUTING.md) \ - every commit message must include a sign-off message. One or more of your commits is missing this message. \ - You can reword previous commit messages with an interactive rebase (\`git rebase -i main\`).` - }) + fetch-depth: 0 + + - name: Check DCO Sign-off + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + run: | + #!/bin/bash + set -e + + echo "Checking commits from $BASE_SHA to $HEAD_SHA" + + COMMITS=$(git rev-list --no-merges "$BASE_SHA..$HEAD_SHA") + + if [ -z "$COMMITS" ]; then + echo "No commits found in this PR" + exit 0 + fi + + FAILED_COMMITS=() + + for commit in $COMMITS; do + echo "Checking commit: $commit" + COMMIT_MSG=$(git log --format=%B -n 1 "$commit") + if echo "$COMMIT_MSG" | grep -q "^Signed-off-by: "; then + echo " Commit $commit has DCO sign-off" + else + echo " Commit $commit is missing DCO sign-off" + FAILED_COMMITS+=("$commit") + fi + done + + if [ ${#FAILED_COMMITS[@]} -ne 0 ]; then + echo "" + echo "DCO Check Failed!" + echo "The following commits are missing the required 'Signed-off-by' line:" + for commit in "${FAILED_COMMITS[@]}"; do + echo " - $commit: $(git log --format=%s -n 1 "$commit")" + done + echo "" + echo "To fix this, you need to sign off your commits. You can:" + echo "1. Add sign-off to new commits: git commit -s -m 'Your commit message'" + echo "2. Amend existing commits: git commit --amend --signoff" + echo "3. For multiple commits, use: git rebase --signoff HEAD~N (where N is the number of commits)" + echo "" + echo "The sign-off should be in the format:" + echo "Signed-off-by: Your Name " + echo "" + echo "For more details, see CONTRIBUTING.md" + exit 1 + else + echo "" + echo "All commits have proper DCO sign-off!" 
+ fi diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml deleted file mode 100644 index c915ee6c1..000000000 --- a/.github/workflows/integration.yml +++ /dev/null @@ -1,111 +0,0 @@ -name: Integration Tests - -on: - push: - branches: - - main - pull_request: - -jobs: - run-non-telemetry-tests: - runs-on: ubuntu-latest - environment: azure-prod - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # run test suite - #---------------------------------------------- - - name: Run non-telemetry e2e tests - run: | - # Exclude all telemetry tests - they run in separate job for isolation - poetry run python -m pytest tests/e2e \ - --ignore=tests/e2e/test_telemetry_e2e.py \ - --ignore=tests/e2e/test_concurrent_telemetry.py \ - -n auto - - run-telemetry-tests: - runs-on: ubuntu-latest - needs: run-non-telemetry-tests # Run after non-telemetry tests complete - environment: azure-prod - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - steps: - - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - - name: Install dependencies - run: poetry install 
--no-interaction --all-extras - - name: Run telemetry tests in isolation - run: | - # Run test_concurrent_telemetry.py in isolation with complete process separation - # Use --dist=loadgroup to respect @pytest.mark.xdist_group markers - poetry run python -m pytest tests/e2e/test_concurrent_telemetry.py \ - -n auto --dist=loadgroup -v \ No newline at end of file diff --git a/.github/workflows/publish-manual.yml b/.github/workflows/publish-manual.yml deleted file mode 100644 index 2f2a7a4dd..000000000 --- a/.github/workflows/publish-manual.yml +++ /dev/null @@ -1,87 +0,0 @@ -name: Publish to PyPI Manual [Production] - -# Allow manual triggering of the workflow -on: - workflow_dispatch: {} - -jobs: - publish: - name: Publish - runs-on: ubuntu-latest - - steps: - #---------------------------------------------- - # Step 1: Check out the repository code - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v2 # Check out the repository to access the code - - #---------------------------------------------- - # Step 2: Set up Python environment - #---------------------------------------------- - - name: Set up python - id: setup-python - uses: actions/setup-python@v2 - with: - python-version: 3.9 # Specify the Python version to be used - - #---------------------------------------------- - # Step 3: Install and configure Poetry - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 # Install Poetry, the Python package manager - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # Step 3.5: Install Kerberos system dependencies - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - -# #---------------------------------------------- -# # Step 4: Load cached virtual environment (if available) -# #---------------------------------------------- -# - name: Load cached venv -# id: cached-poetry-dependencies -# uses: actions/cache@v2 -# with: -# path: .venv # Path to the virtual environment -# key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} -# # Cache key is generated based on OS, Python version, repo name, and the `poetry.lock` file hash - -# #---------------------------------------------- -# # Step 5: Install dependencies if the cache is not found -# #---------------------------------------------- -# - name: Install dependencies -# if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' # Only run if the cache was not hit -# run: poetry install --no-interaction --no-root # Install dependencies without interaction - -# #---------------------------------------------- -# # Step 6: Update the version to the manually provided version -# #---------------------------------------------- -# - name: Update pyproject.toml with the specified version -# run: poetry version ${{ github.event.inputs.version }} # Use the version provided by the user input - - #---------------------------------------------- - # Step 7: Build and publish the first package to PyPI - #---------------------------------------------- - - name: Build and publish databricks sql connector to PyPI - working-directory: ./databricks_sql_connector - run: | - poetry build - poetry publish -u __token__ -p ${{ 
secrets.PROD_PYPI_TOKEN }} # Publish with PyPI token - #---------------------------------------------- - # Step 7: Build and publish the second package to PyPI - #---------------------------------------------- - - - name: Build and publish databricks sql connector core to PyPI - working-directory: ./databricks_sql_connector_core - run: | - poetry build - poetry publish -u __token__ -p ${{ secrets.PROD_PYPI_TOKEN }} # Publish with PyPI token \ No newline at end of file diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml deleted file mode 100644 index 97a444e68..000000000 --- a/.github/workflows/publish-test.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Publish to PyPI [Test] -on: [push] -jobs: - test-pypi: - name: Create patch version number and push to test-pypi - runs-on: ubuntu-latest - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # Get the current version and increment it (test-pypi requires a unique version number) - #---------------------------------------------- - - name: Get next version - uses: reecetech/version-increment@2022.2.4 - id: version - with: - scheme: semver - increment: patch - #---------------------------------------------- - # Tell poetry to update the version number - #---------------------------------------------- - - name: Update pyproject.toml - run: poetry version ${{ steps.version.outputs.major-version }}.${{ steps.version.outputs.minor-version }}.dev$(date +%s) - #---------------------------------------------- - # Build the package (before publish action) - #---------------------------------------------- - - name: Build package - run: poetry build - #---------------------------------------------- - # Configure test-pypi repository - #---------------------------------------------- - - name: Configure test-pypi repository - run: poetry config repositories.testpypi https://test.pypi.org/legacy/ - #---------------------------------------------- - # Attempt push to test-pypi - #---------------------------------------------- - - name: Publish to test-pypi - run: poetry publish --username __token__ --password ${{ 
secrets.TEST_PYPI_TOKEN }} --repository testpypi diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index b101f421c..000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: Publish to PyPI [Production] -on: - release: - types: [published] -jobs: - publish: - name: Publish - runs-on: ubuntu-latest - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: 3.9 - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #------------------------------------------------------------------------------------------------ - # Here we use version-increment to fetch the latest tagged version (we won't increment it though) - #------------------------------------------------------------------------------------------------ - - name: Get next version - uses: reecetech/version-increment@2022.2.4 - id: version - with: - scheme: semver - increment: patch - #----------------------------------------------------------------------------- - # Tell poetry to use the `current-version` that was found by the previous step - #----------------------------------------------------------------------------- - - name: Update pyproject.toml - run: poetry version ${{ steps.version.outputs.current-version }} - #---------------------------------------------- - # Build the package (before publish) - #---------------------------------------------- - - name: Build package - run: poetry build - #---------------------------------------------- - # Publish to pypi - #---------------------------------------------- - - name: Publish to pypi - run: poetry publish --username __token__ --password ${{ secrets.PROD_PYPI_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ba3bb1a8..fc89750d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Release History +# 4.2.6 (2026-04-22) +- Add SPOG routing support for account-level vanity URLs (databricks/databricks-sql-python#767 by @msrathore-db) +- Fix dependency_manager: handle PEP 440 ~= compatible release syntax (databricks/databricks-sql-python#776 by @vikrantpuppala) +- Bump thrift to fix deprecation warning (databricks/databricks-sql-python#733 by @Korijn) +- Add AI coding agent detection to User-Agent header 
(databricks/databricks-sql-python#740 by @vikrantpuppala) +- Add statement-level query_tags support for SEA backend (databricks/databricks-sql-python#754 by @sreekanth-db) +- Update PyArrow concatenation of tables to use promote_options as default (databricks/databricks-sql-python#751 by @jprakash-db) +- Fix float inference to use DoubleParameter (64-bit) instead of FloatParameter (databricks/databricks-sql-python#742 by @Shubhambhusate) +- Allow specifying query_tags as a dict upon connection creation (databricks/databricks-sql-python#749 by @jiabin-hu) +- Add query_tags parameter support for execute methods (databricks/databricks-sql-python#736 by @jiabin-hu) + # 4.2.5 (2026-02-09) - Fix feature-flag endpoint retries in gov region (databricks/databricks-sql-python#735 by @samikshya-db) - Improve telemetry lifecycle management (databricks/databricks-sql-python#734 by @msrathore-db) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index f615d082c..977dc6ad5 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -7,24 +7,112 @@ Query Tags are key-value pairs that can be attached to SQL executions and will appear in the system.query.history table for analytical purposes. -Format: "key1:value1,key2:value2,key3:value3" +There are two ways to set query tags: +1. Connection-level: Pass query_tags parameter to sql.connect() (applies to all queries in the session) +2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query) + +Format: Dictionary with string keys and optional string values +Example: {"team": "engineering", "application": "etl", "priority": "high"} + +Special cases: +- If a value is None, only the key is included (no colon or value) +- Special characters (comma, colon and backslash) in values are automatically escaped +- Backslashes in keys are automatically escaped; other special characters in keys are not allowed """ print("=== Query Tags Example ===\n") +# Example 1: Connection-level query tags +print("Example 1: Connection-level query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), - session_configuration={ - 'QUERY_TAGS': 'team:engineering,test:query-tags', - 'ansi_mode': False - } + query_tags={"team": "engineering", "application": "etl"}, ) as connection: - + with connection.cursor() as cursor: cursor.execute("SELECT 1") result = cursor.fetchone() print(f" Result: {result[0]}") -print("\n=== Query Tags Example Complete ===") \ No newline at end of file +print() + +# Example 2: Per-query query tags +print("Example 2: Per-query query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Query 1: Tags for a critical ETL job + cursor.execute( + "SELECT 1", + query_tags={"team": "data-eng", "application": "etl", "priority": "high"} + ) + result = cursor.fetchone() + print(f" ETL Query Result: {result[0]}") + + # Query 2: Tags with None value (key-only tag) + cursor.execute( + "SELECT 2", + query_tags={"team": "analytics", "experimental": None} + ) + result = cursor.fetchone() + print(f" Experimental Query Result: {result[0]}") + + # Query 3: Tags with special characters (automatically escaped) + cursor.execute( + "SELECT 3", + query_tags={"description": 
"test:with:colons,and,commas"} + ) + result = cursor.fetchone() + print(f" Special Chars Query Result: {result[0]}") + + # Query 4: No tags (demonstrates tags don't persist from previous queries) + cursor.execute("SELECT 4") + result = cursor.fetchone() + print(f" No Tags Query Result: {result[0]}") + +print() + +# Example 3: Async execution with query tags +print("Example 3: Async execution with query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + cursor.execute_async( + "SELECT 5", + query_tags={"team": "data-eng", "mode": "async"} + ) + cursor.get_async_execution_result() + result = cursor.fetchone() + print(f" Async Query Result: {result[0]}") + +print() + +# Example 4: executemany with query tags +print("Example 4: executemany with query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Execute multiple queries with the same tags + cursor.executemany( + "SELECT ?", + [[6], [7], [8]], + query_tags={"team": "data-eng", "batch": "executemany"} + ) + result = cursor.fetchone() + print(f" Executemany Query Result (last): {result[0]}") + +print("\n=== Query Tags Example Complete ===") diff --git a/poetry.lock b/poetry.lock index 7d0845a58..5644190f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.1 and should not be changed by hand. 
[[package]] name = "astroid" @@ -1158,15 +1158,15 @@ pytz = ">=2020.1" tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] -aws = ["s3fs (>=2021.08.0)"] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.7.0)", "gcsfs (>=2021.7.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.8.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.8.0)"] clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] feather = ["pyarrow (>=7.0.0)"] -fss = ["fsspec (>=2021.07.0)"] -gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] +fss = ["fsspec (>=2021.7.0)"] +gcp = ["gcsfs (>=2021.7.0)", "pandas-gbq (>=0.15.0)"] hdf5 = ["tables (>=3.6.1)"] html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] @@ -1523,7 +1523,7 @@ files = [ ] [package.dependencies] -astroid = ">=3.2.4,<=3.3.0-dev0" +astroid = ">=3.2.4,<=3.3.0.dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -1849,18 +1849,15 @@ files = [ [[package]] name = "thrift" -version = "0.20.0" +version = "0.22.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, + {file = "thrift-0.22.0.tar.gz", hash = "sha256:42e8276afbd5f54fe1d364858b6877bc5e5a4a5ed69f6a005b94ca4918fe1466"}, ] -[package.dependencies] -six = ">=1.7.2" - [package.extras] all = ["tornado (>=4.0)", "twisted"] tornado = ["tornado (>=4.0)"] @@ -1969,4 +1966,4 @@ pyarrow = ["pyarrow", "pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "ec311bf26ec866de2f427bcdf4ec69ceed721bfd70edfae3aba1ac12882a09d6" +content-hash = 
"d1739e84dcbd6e7ac311eb6fbb9cf87ad110491f7d954f07fdfc32b704b4413f" diff --git a/pyproject.toml b/pyproject.toml index 911f1b79c..5e9f7f0ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.5" +version = "4.2.6" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -10,7 +10,7 @@ include = ["CHANGELOG.md"] [tool.poetry.dependencies] python = "^3.8.0" -thrift = ">=0.16.0,<0.21.0" +thrift = "~=0.22.0" pandas = [ { version = ">=1.2.5,<2.4.0", python = ">=3.8,<3.13" }, { version = ">=2.2.3,<2.4.0", python = ">=3.13" } diff --git a/scripts/dependency_manager.py b/scripts/dependency_manager.py index 15e119841..29c5fe828 100644 --- a/scripts/dependency_manager.py +++ b/scripts/dependency_manager.py @@ -69,16 +69,21 @@ def _parse_constraint(self, name, constraint): def _extract_versions_from_specifier(self, spec_set_str): """Extract minimum version from a specifier set""" try: - # Handle caret (^) and tilde (~) constraints that packaging doesn't support + # Handle caret (^) and tilde (~, ~=) constraints that packaging doesn't + # support (Poetry ^, Poetry ~, and PEP 440 ~=). if spec_set_str.startswith('^'): # ^1.2.3 means >=1.2.3, <2.0.0 min_version = spec_set_str[1:] # Remove ^ return min_version, None + elif spec_set_str.startswith('~='): + # PEP 440 compatible release: ~=1.2.3 means >=1.2.3, <1.3.0 + min_version = spec_set_str[2:] # Remove ~= + return min_version, None elif spec_set_str.startswith('~'): - # ~1.2.3 means >=1.2.3, <1.3.0 + # Poetry tilde: ~1.2.3 means >=1.2.3, <1.3.0 min_version = spec_set_str[1:] # Remove ~ return min_version, None - + spec_set = SpecifierSet(spec_set_str) min_version = None diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index c9195b89f..493ffe3a2 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.5" +__version__ = "4.2.6" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 2213635fe..b772e7ddd 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -83,6 +83,7 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union[ResultSet, None]: """ Executes a SQL command or query within the specified session. @@ -102,6 +103,7 @@ def execute_command( async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness row_limit: Maximum number of rows in the response. + query_tags: Optional dictionary of query tags to apply for this query only. 
Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1427226d2..04c79a18b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -188,8 +188,9 @@ def _extract_warehouse_id(self, http_path: str) -> str: ValueError: If the warehouse ID cannot be extracted from the path """ - warehouse_pattern = re.compile(r".*/warehouses/(.+)") - endpoint_pattern = re.compile(r".*/endpoints/(.+)") + # [^?&]+ stops at query params (e.g. ?o= for SPOG routing) + warehouse_pattern = re.compile(r".*/warehouses/([^?&]+)") + endpoint_pattern = re.compile(r".*/endpoints/([^?&]+)") for pattern in [warehouse_pattern, endpoint_pattern]: match = pattern.match(http_path) @@ -463,6 +464,7 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. @@ -529,6 +531,7 @@ def execute_command( row_limit=row_limit, parameters=sea_parameters if sea_parameters else None, result_compression=result_compression, + query_tags=query_tags, ) response_data = self._http_client._make_request( diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index ad046ff54..eb156fb1a 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -31,6 +31,7 @@ class ExecuteStatementRequest: wait_timeout: str = "10s" on_wait_timeout: str = "CONTINUE" row_limit: Optional[int] = None + query_tags: Optional[Dict[str, Optional[str]]] = None def to_dict(self) -> Dict[str, Any]: """Convert the request to a dictionary for JSON serialization.""" @@ -60,6 +61,13 @@ def to_dict(self) -> Dict[str, Any]: for param in self.parameters ] + # SEA API expects query_tags as an array of {key, value} objects. + # None/empty values are left to the server to handle as key-only tags. 
+ if self.query_tags: + result["query_tags"] = [ + {"key": k, "value": v} for k, v in self.query_tags.items() + ] + return result diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index edee02bfa..e23f3389b 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,7 +5,7 @@ import math import time import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING from uuid import UUID from databricks.sql.common.unified_http_client import UnifiedHttpClient @@ -53,6 +53,7 @@ convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, + serialize_query_tags, ) from databricks.sql.types import SSLOptions from databricks.sql.backend.databricks_client import DatabricksClient @@ -1003,6 +1004,7 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union["ResultSet", None]: thrift_handle = session_id.to_thrift_handle() if not thrift_handle: @@ -1022,6 +1024,19 @@ def execute_command( # DBR should be changed to use month_day_nano_interval intervalTypesAsArrow=False, ) + + # Build confOverlay with default configs and query_tags + merged_conf_overlay = { + # We want to receive proper Timestamp arrow types. + "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" + } + + # Serialize and add query_tags to confOverlay if provided + if query_tags: + serialized_tags = serialize_query_tags(query_tags) + if serialized_tags: + merged_conf_overlay["query_tags"] = serialized_tags + req = ttypes.TExecuteStatementReq( sessionHandle=thrift_handle, statement=operation, @@ -1036,10 +1051,7 @@ def execute_command( canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, - confOverlay={ - # We want to receive proper Timestamp arrow types. 
- "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" - }, + confOverlay=merged_conf_overlay, useArrowNativeTypes=spark_arrow_types, parameters=parameters, enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness, diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1a246b7c1..fe52f0c79 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -36,6 +36,7 @@ ColumnQueue, build_client_context, get_session_config_value, + serialize_query_tags, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -106,6 +107,7 @@ def __init__( schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, ignore_transactions: bool = True, + query_tags: Optional[Dict[str, Optional[str]]] = None, **kwargs, ) -> None: """ @@ -281,6 +283,15 @@ def read(self) -> Optional[OAuthToken]: "spark.sql.thriftserver.metadata.metricview.enabled" ] = "true" + if query_tags is not None: + if session_configuration is None: + session_configuration = {} + serialized = serialize_query_tags(query_tags) + if serialized: + session_configuration["QUERY_TAGS"] = serialized + else: + session_configuration.pop("QUERY_TAGS", None) + self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) @@ -342,6 +353,7 @@ def read(self) -> Optional[OAuthToken]: host_url=self.session.host, batch_size=self.telemetry_batch_size, client_context=client_context, + extra_headers=self.session.get_spog_headers(), ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( @@ -1263,6 +1275,7 @@ def execute( parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, input_stream: Optional[BinaryIO] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -1293,6 +1306,10 @@ def execute( Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo' being sent to the server + :param query_tags: Optional dictionary of query tags to apply for this query only. + Tags are key-value pairs that can be used to identify and categorize queries. + Example: {"team": "data-eng", "application": "etl"} + :returns self """ @@ -1333,6 +1350,7 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + query_tags=query_tags, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -1349,6 +1367,7 @@ def execute_async( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> "Cursor": """ @@ -1356,6 +1375,9 @@ def execute_async( :param operation: :param parameters: + :param query_tags: Optional dictionary of query tags to apply for this query only. + Tags are key-value pairs that can be used to identify and categorize queries. 
+ Example: {"team": "data-eng", "application": "etl"} :return: """ @@ -1392,6 +1414,7 @@ def execute_async( async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + query_tags=query_tags, ) return self @@ -1448,7 +1471,12 @@ def get_async_execution_result(self): session_id_hex=self.connection.get_session_id_hex(), ) - def executemany(self, operation, seq_of_parameters): + def executemany( + self, + operation, + seq_of_parameters, + query_tags: Optional[Dict[str, Optional[str]]] = None, + ): """ Execute the operation once for every set of passed in parameters. @@ -1457,10 +1485,14 @@ def executemany(self, operation, seq_of_parameters): Only the final result set is retained. + :param query_tags: Optional dictionary of query tags to apply for all queries in this batch. + Tags are key-value pairs that can be used to identify and categorize queries. + Example: {"team": "data-eng", "application": "etl"} + :returns self """ for parameters in seq_of_parameters: - self.execute(operation, parameters) + self.execute(operation, parameters, query_tags=query_tags) return self @log_latency(StatementType.METADATA) diff --git a/src/databricks/sql/common/agent.py b/src/databricks/sql/common/agent.py new file mode 100644 index 000000000..79d1b2b7a --- /dev/null +++ b/src/databricks/sql/common/agent.py @@ -0,0 +1,52 @@ +""" +Detects whether the Python SQL connector is being invoked by an AI coding agent +by checking for well-known environment variables that agents set in their spawned +shell processes. + +Detection only succeeds when exactly one agent environment variable is present, +to avoid ambiguous attribution when multiple agent environments overlap. + +Adding a new agent requires only a new entry in KNOWN_AGENTS. + +References for each environment variable: + - ANTIGRAVITY_AGENT: Closed source. Google Antigravity sets this variable. + - CLAUDECODE: https://github.com/anthropics/claude-code (sets CLAUDECODE=1) + - CLINE_ACTIVE: https://github.com/cline/cline (shipped in v3.24.0) + - CODEX_CI: https://github.com/openai/codex (part of UNIFIED_EXEC_ENV array in codex-rs) + - CURSOR_AGENT: Closed source. Referenced in a gist by johnlindquist. + - GEMINI_CLI: https://google-gemini.github.io/gemini-cli/docs/tools/shell.html (sets GEMINI_CLI=1) + - OPENCODE: https://github.com/opencode-ai/opencode (sets OPENCODE=1) +""" + +import os + +KNOWN_AGENTS = [ + ("ANTIGRAVITY_AGENT", "antigravity"), + ("CLAUDECODE", "claude-code"), + ("CLINE_ACTIVE", "cline"), + ("CODEX_CI", "codex"), + ("CURSOR_AGENT", "cursor"), + ("GEMINI_CLI", "gemini-cli"), + ("OPENCODE", "opencode"), +] + + +def detect(env=None): + """Detect which AI coding agent (if any) is driving the current process. + + Args: + env: Optional dict-like object for environment variable lookup. + Defaults to os.environ. Exists for testability. + + Returns: + The agent product string if exactly one agent is detected, + or an empty string otherwise. 
+ """ + if env is None: + env = os.environ + + detected = [product for var, product in KNOWN_AGENTS if env.get(var)] + + if len(detected) == 1: + return detected[0] + return "" diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 36e4b8a02..0b2c7490b 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -113,6 +113,7 @@ def _refresh_flags(self): # Authenticate the request self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header + headers.update(self._connection.session.get_spog_headers()) response = self._http_client.request( HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30 diff --git a/src/databricks/sql/parameters/native.py b/src/databricks/sql/parameters/native.py index b7c448254..d0fb8d82c 100644 --- a/src/databricks/sql/parameters/native.py +++ b/src/databricks/sql/parameters/native.py @@ -659,7 +659,7 @@ def dbsql_parameter_from_primitive( elif isinstance(value, str): return StringParameter(value=value, name=name) elif isinstance(value, float): - return FloatParameter(value=value, name=name) + return DoubleParameter(value=value, name=name) elif isinstance(value, datetime.datetime): return TimestampParameter(value=value, name=name) elif isinstance(value, datetime.date): diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0f723d144..65c0d6aca 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -13,6 +13,7 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.agent import detect as detect_agent logger = logging.getLogger(__name__) @@ -64,9 +65,21 @@ def __init__( else: self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + agent_product = detect_agent() + if agent_product: + self.useragent_header += " agent/{}".format(agent_product) + base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers + # Extract ?o= from http_path for SPOG routing. + # On SPOG hosts, the httpPath contains ?o= which routes Thrift + # requests via the URL. For SEA, telemetry, and feature flags (which use + # separate endpoints), we inject x-databricks-org-id as an HTTP header. + self._spog_headers = self._extract_spog_headers(http_path, all_headers) + if self._spog_headers: + all_headers = all_headers + list(self._spog_headers.items()) + self.ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility tls_verify=not kwargs.get( @@ -131,6 +144,44 @@ def _create_backend( } return databricks_client_class(**common_args) + @staticmethod + def _extract_spog_headers(http_path, existing_headers): + """Extract ?o= from http_path and return as a header dict for SPOG routing.""" + if not http_path or "?" 
not in http_path: + return {} + + from urllib.parse import parse_qs + + query_string = http_path.split("?", 1)[1] + params = parse_qs(query_string) + org_id = params.get("o", [None])[0] + if not org_id: + logger.debug( + "SPOG header extraction: http_path has query string but no ?o= param, " + "skipping x-databricks-org-id injection" + ) + return {} + + # Don't override if explicitly set + if any(k == "x-databricks-org-id" for k, _ in existing_headers): + logger.debug( + "SPOG header extraction: x-databricks-org-id already set by caller, " + "not overriding with ?o=%s from http_path", + org_id, + ) + return {} + + logger.debug( + "SPOG header extraction: injecting x-databricks-org-id=%s " + "(extracted from ?o= in http_path)", + org_id, + ) + return {"x-databricks-org-id": org_id} + + def get_spog_headers(self): + """Returns SPOG routing headers (x-databricks-org-id) if ?o= was in http_path.""" + return dict(self._spog_headers) + def open(self): self._session_id = self.backend.open_session( session_configuration=self.session_configuration, diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 408162400..55d845e46 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -188,6 +188,7 @@ def __init__( executor, batch_size: int, client_context, + extra_headers: Optional[Dict[str, str]] = None, ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -195,6 +196,7 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None + self._extra_headers = extra_headers or {} # OPTIMIZATION: Use lock-free Queue instead of list + lock # Queue is thread-safe internally and has better performance under concurrency @@ -287,6 +289,8 @@ def _send_telemetry(self, events): if self._auth_provider: self._auth_provider.add_headers(headers) + headers.update(self._extra_headers) + try: logger.debug("Submitting telemetry request to thread pool") @@ -587,6 +591,7 @@ def initialize_telemetry_client( host_url, batch_size, client_context, + extra_headers=None, ): """ Initialize a telemetry client for a specific connection if telemetry is enabled. @@ -627,6 +632,7 @@ def initialize_telemetry_client( executor=TelemetryClientFactory._executor, batch_size=batch_size, client_context=client_context, + extra_headers=extra_headers, ) TelemetryClientFactory._clients[ host_url diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 043183ac2..ce2670969 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -895,7 +895,50 @@ def concat_table_chunks( result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) else: - return pyarrow.concat_tables(table_chunks) + return pyarrow.concat_tables(table_chunks, promote_options="default") + + +def serialize_query_tags( + query_tags: Optional[Dict[str, Optional[str]]] +) -> Optional[str]: + """ + Serialize query_tags dictionary to a string format. 
+ + Format: "key1:value1,key2:value2" + Special cases: + - If value is None, omit the colon and value (e.g., "key1:value1,key2,key3:value3") + - Escape special characters (:, ,, \\) in values with a leading backslash + - Backslashes in keys are escaped; other special characters in keys are not escaped + + Args: + query_tags: Dictionary of query tags where keys are strings and values are optional strings + + Returns: + Serialized string or None if query_tags is None or empty + """ + if not query_tags: + return None + + def escape_value(value: str) -> str: + """Escape special characters in tag values.""" + # Escape backslash first to avoid double-escaping + value = value.replace("\\", r"\\") + # Escape colon and comma + value = value.replace(":", r"\:") + value = value.replace(",", r"\,") + return value + + serialized_parts = [] + for key, value in query_tags.items(): + escaped_key = key.replace("\\", r"\\") + if value is None: + # No colon or value when value is None + serialized_parts.append(escaped_key) + else: + escaped_value = escape_value(value) + serialized_parts.append(f"{escaped_key}:{escaped_value}") + + return ",".join(serialized_parts) def build_client_context(server_hostname: str, version: str, **kwargs): @@ -914,12 +957,18 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ) # Build user agent + from databricks.sql.common.agent import detect as detect_agent + user_agent_entry = kwargs.get("user_agent_entry", "") if user_agent_entry: user_agent = f"PyDatabricksSqlConnector/{version} ({user_agent_entry})" else: user_agent = f"PyDatabricksSqlConnector/{version}" + agent_product = detect_agent() + if agent_product: + user_agent += f" agent/{agent_product}" + # Explicitly construct ClientContext with proper types return ClientContext( hostname=server_hostname, diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index dd7c56996..7255ee095 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -2,139 +2,43 @@ import math import time -import pytest - log = logging.getLogger(__name__) -class LargeQueriesMixin: +def fetch_rows(test_case, cursor, row_count, fetchmany_size): """ - This mixin expects to be mixed with a CursorTest-like class + A generator for rows. Fetches until the end or up to 5 minutes. """ - - def fetch_rows(self, cursor, row_count, fetchmany_size): - """ - A generator for rows. Fetches until the end or up to 5 minutes. - """ - # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone - # in the Python client - max_fetch_time = 5 * 60 # Fetch for at most 5 minutes - - rows = self.get_some_rows(cursor, fetchmany_size) - start_time = time.time() - n = 0 - while rows: - for row in rows: - n += 1 - yield row - if time.time() - start_time >= max_fetch_time: - log.warning("Fetching rows timed out") - break - rows = self.get_some_rows(cursor, fetchmany_size) - if not rows: - # Read all the rows, row_count should match - self.assertEqual(n, row_count) - - num_fetches = max(math.ceil(n / 10000), 1) - latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 - print( - "Fetched {} rows with an avg latency of {} per fetch, ".format( - n, latency_ms - ) - + "assuming 10K fetch size." 
+ max_fetch_time = 5 * 60 # Fetch for at most 5 minutes + + rows = _get_some_rows(cursor, fetchmany_size) + start_time = time.time() + n = 0 + while rows: + for row in rows: + n += 1 + yield row + if time.time() - start_time >= max_fetch_time: + log.warning("Fetching rows timed out") + break + rows = _get_some_rows(cursor, fetchmany_size) + if not rows: + # Read all the rows, row_count should match + test_case.assertEqual(n, row_count) + + num_fetches = max(math.ceil(n / 10000), 1) + latency_ms = int((time.time() - start_time) * 1000 / num_fetches) + print( + "Fetched {} rows with an avg latency of {} per fetch, ".format( + n, latency_ms ) - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], + + "assuming 10K fetch size." ) - def test_query_with_large_wide_result_set(self, extra_params): - resultSize = 300 * 1000 * 1000 # 300 MB - width = 8192 # B - rows = resultSize // width - cols = width // 36 - - # Set the fetchmany_size to get 10MB of data a go - fetchmany_size = 10 * 1024 * 1024 // width - # This is used by PyHive tests to determine the buffer size - self.arraysize = 1000 - with self.cursor(extra_params) as cursor: - for lz4_compression in [False, True]: - cursor.connection.lz4_compression = lz4_compression - uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) - cursor.execute( - "SELECT id, {uuids} FROM RANGE({rows})".format( - uuids=uuids, rows=rows - ) - ) - assert lz4_compression == cursor.active_result_set.lz4_compressed - for row_id, row in enumerate( - self.fetch_rows(cursor, rows, fetchmany_size) - ): - assert row[0] == row_id # Verify no rows are dropped in the middle. - assert len(row[1]) == 36 - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_query_with_large_narrow_result_set(self, extra_params): - resultSize = 300 * 1000 * 1000 # 300 MB - width = 8 # sizeof(long) - rows = resultSize / width - - # Set the fetchmany_size to get 10MB of data a go - fetchmany_size = 10 * 1024 * 1024 // width - # This is used by PyHive tests to determine the buffer size - self.arraysize = 10000000 - with self.cursor(extra_params) as cursor: - cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) - for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): - assert row[0] == row_id - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_long_running_query(self, extra_params): - """Incrementally increase query size until it takes at least 3 minutes, - and asserts that the query completes successfully.
- """ - minutes = 60 - min_duration = 3 * minutes - - duration = -1 - scale0 = 10000 - scale_factor = 1 - with self.cursor(extra_params) as cursor: - while duration < min_duration: - assert scale_factor < 4096, "Detected infinite loop" - start = time.time() - - cursor.execute( - """SELECT count(*) - FROM RANGE({scale}) x - JOIN RANGE({scale0}) y - ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" - """.format( - scale=scale_factor * scale0, scale0=scale0 - ) - ) - (n,) = cursor.fetchone() - assert n == 0 - duration = time.time() - start - current_fraction = duration / min_duration - print("Took {} s with scale factor={}".format(duration, scale_factor)) - # Extrapolate linearly to reach 3 min and add 50% padding to push over the limit - scale_factor = math.ceil(1.5 * scale_factor / current_fraction) +def _get_some_rows(cursor, fetchmany_size): + row = cursor.fetchone() + if row: + return [row] + else: + return None diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index e04e348c9..45b56ae08 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -39,7 +39,7 @@ ) from databricks.sql.thrift_api.TCLIService import ttypes from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin -from tests.e2e.common.large_queries_mixin import LargeQueriesMixin +from tests.e2e.common.large_queries_mixin import fetch_rows from tests.e2e.common.timestamp_tests import TimestampTestsMixin from tests.e2e.common.decimal_tests import DecimalTestsMixin from tests.e2e.common.retry_test_mixins import ( @@ -138,24 +138,89 @@ def assertEqualRowValues(self, actual, expected): assert act[i] == exp[i] -class TestPySQLLargeQueriesSuite(PySQLPytestTestCase, LargeQueriesMixin): - def get_some_rows(self, cursor, fetchmany_size): - row = cursor.fetchone() - if row: - return [row] - else: - return None +class TestPySQLLargeWideResultSet(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + @pytest.mark.parametrize("lz4_compression", [False, True]) + def test_query_with_large_wide_result_set(self, extra_params, lz4_compression): + resultSize = 100 * 1000 * 1000 # 100 MB + width = 8192 # B + rows = resultSize // width + cols = width // 36 + fetchmany_size = 10 * 1024 * 1024 // width + self.arraysize = 1000 + with self.cursor(extra_params) as cursor: + cursor.connection.lz4_compression = lz4_compression + uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) + cursor.execute( + "SELECT id, {uuids} FROM RANGE({rows})".format( + uuids=uuids, rows=rows + ) + ) + assert lz4_compression == cursor.active_result_set.lz4_compressed + for row_id, row in enumerate( + fetch_rows(self, cursor, rows, fetchmany_size) + ): + assert row[0] == row_id + assert len(row[1]) == 36 + + +class TestPySQLLargeNarrowResultSet(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + def test_query_with_large_narrow_result_set(self, extra_params): + resultSize = 100 * 1000 * 1000 # 100 MB + width = 8 # sizeof(long) + rows = resultSize / width + fetchmany_size = 10 * 1024 * 1024 // width + self.arraysize = 10000000 + with self.cursor(extra_params) as cursor: + cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) + for row_id, row in enumerate( + fetch_rows(self, cursor, rows, fetchmany_size) + ): + assert row[0] == row_id + + +class TestPySQLLongRunningQuery(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + def test_long_running_query(self, extra_params): + 
"""Incrementally increase query size until it takes at least 1 minute, + and asserts that the query completes successfully. + """ + import math + + minutes = 60 + min_duration = 1 * minutes + duration = -1 + scale0 = 10000 + scale_factor = 50 + with self.cursor(extra_params) as cursor: + while duration < min_duration: + assert scale_factor < 4096, "Detected infinite loop" + start = time.time() + cursor.execute( + """SELECT count(*) + FROM RANGE({scale}) x + JOIN RANGE({scale0}) y + ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" + """.format( + scale=scale_factor * scale0, scale0=scale0 + ) + ) + (n,) = cursor.fetchone() + assert n == 0 + duration = time.time() - start + current_fraction = duration / min_duration + print("Took {} s with scale factor={}".format(duration, scale_factor)) + scale_factor = math.ceil(1.5 * scale_factor / current_fraction) + +class TestPySQLCloudFetch(PySQLPytestTestCase): @skipUnless(pysql_supports_arrow(), "needs arrow support") @pytest.mark.skip("This test requires a previously uploaded data set") def test_cloud_fetch(self): - # This test can take several minutes to run limits = [100000, 300000] threads = [10, 25] self.arraysize = 100000 - # This test requires a large table with many rows to properly initiate cloud fetch. - # e2-dogfood host > hive_metastore catalog > main schema has such a table called store_sales. - # If this table is deleted or this test is run on a different host, a different table may need to be used. base_query = "SELECT * FROM store_sales WHERE ss_sold_date_sk = 2452234 " for num_limit, num_threads, lz4_compression in itertools.product( limits, threads, [True, False] diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index d4f6a790a..4fb7918b9 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -1,598 +1,772 @@ """ End-to-end integration tests for Multi-Statement Transaction (MST) APIs. -These tests verify: -- autocommit property (getter/setter) -- commit() and rollback() methods -- get_transaction_isolation() and set_transaction_isolation() methods -- Transaction error handling +Tests driver behavior for MST across: +- Basic correctness (commit/rollback/isolation/multi-table) +- API-specific (autocommit, isolation level, error handling) +- Metadata RPCs inside transactions (non-transactional freshness) +- SQL statements blocked by MSTCheckRule (SHOW, DESCRIBE, information_schema) +- Execute variants (executemany) + +Parallelisation: +- Each test uses its own unique table (derived from test name) to allow + parallel execution with pytest-xdist. +- Tests requiring multiple concurrent connections to the same table are + tagged with xdist_group so the concurrent connections within a single + test don't conflict with other tests on different workers. 
Requirements: - DBSQL warehouse that supports Multi-Statement Transactions (MST) -- Test environment configured via test.env file or environment variables - -Setup: -Set the following environment variables: -- DATABRICKS_SERVER_HOSTNAME -- DATABRICKS_HTTP_PATH -- DATABRICKS_ACCESS_TOKEN (or use OAuth) - -Usage: - pytest tests/e2e/test_transactions.py -v +- Env vars: DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, + DATABRICKS_TOKEN, DATABRICKS_CATALOG, DATABRICKS_SCHEMA """ import logging import os +import re +import uuid + import pytest -from typing import Any, Dict import databricks.sql as sql -from databricks.sql import TransactionError, NotSupportedError, InterfaceError logger = logging.getLogger(__name__) -@pytest.mark.skip( - reason="Test environment does not yet support multi-statement transactions" -) -class TestTransactions: - """E2E tests for transaction control methods (MST support).""" +def _unique_table_name(request): + """Derive a unique Delta table name from the test node id.""" + node_id = request.node.name + sanitized = re.sub(r"[^a-z0-9_]", "_", node_id.lower()) + return f"mst_pysql_{sanitized}"[:80] - # Test table name - TEST_TABLE_NAME = "transaction_test_table" - @pytest.fixture(autouse=True) - def setup_and_teardown(self, connection_details): - """Setup test environment before each test and cleanup after.""" - self.connection_params = { - "server_hostname": connection_details["host"], - "http_path": connection_details["http_path"], - "access_token": connection_details.get("access_token"), - "ignore_transactions": False, # Enable actual transaction functionality for these tests - } +def _unique_table_name_raw(suffix): + """Non-fixture unique table name helper for extra tables within a test.""" + return f"mst_pysql_{suffix}_{uuid.uuid4().hex[:8]}" - # Get catalog and schema from environment or use defaults - self.catalog = os.getenv("DATABRICKS_CATALOG", "main") - self.schema = os.getenv("DATABRICKS_SCHEMA", "default") - # Create connection for setup - self.connection = sql.connect(**self.connection_params) +@pytest.fixture +def mst_conn_params(connection_details): + """Connection parameters with MST enabled.""" + return { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + "ignore_transactions": False, + } - # Setup: Create test table - self._create_test_table() - yield +@pytest.fixture +def mst_catalog(connection_details): + return connection_details.get("catalog") or os.getenv("DATABRICKS_CATALOG") or "main" - # Teardown: Cleanup - self._cleanup() - def _get_fully_qualified_table_name(self) -> str: - """Get the fully qualified table name.""" - return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" +@pytest.fixture +def mst_schema(connection_details): + return connection_details.get("schema") or os.getenv("DATABRICKS_SCHEMA") or "default" - def _create_test_table(self): - """Create the test table with Delta format and MST support.""" - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - try: - # Drop if exists - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") +@pytest.fixture +def mst_table(request, mst_conn_params, mst_catalog, mst_schema): + """Create a fresh Delta table for the test and drop it afterwards. + + Yields (fq_table_name, table_name). The table is unique per test so tests + can run in parallel without stepping on each other. 
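+ For example, a test named test_commit_single_insert gets the table mst_pysql_test_commit_single_insert.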
+ """ + table_name = _unique_table_name(request) + fq_table = f"{mst_catalog}.{mst_schema}.{table_name}" - # Create table with Delta and catalog-owned feature for MST compatibility + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table}") cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table_name} - (id INT, value STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + f"CREATE TABLE {fq_table} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" ) - logger.info(f"Created test table: {fq_table_name}") - finally: - cursor.close() - - def _cleanup(self): - """Cleanup after test: rollback pending transactions, drop table, close connection.""" - try: - # Try to rollback any pending transaction - if ( - self.connection - and self.connection.open - and not self.connection.autocommit - ): - try: - self.connection.rollback() - except Exception as e: - logger.debug( - f"Rollback during cleanup failed (may be expected): {e}" + yield fq_table, table_name + + try: + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table}") + except Exception as e: + logger.warning(f"Failed to drop {fq_table}: {e}") + + +def _get_row_count(mst_conn_params, fq_table): + """Count rows from a fresh connection (avoids in-txn caching).""" + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT COUNT(*) FROM {fq_table}") + return cursor.fetchone()[0] + + +def _get_ids(mst_conn_params, fq_table): + """Return the set of ids from a fresh connection.""" + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT id FROM {fq_table}") + return {row[0] for row in cursor.fetchall()} + + +# ==================== A. 
BASIC CORRECTNESS ==================== + + +class TestMstCorrectness: + """Core MST correctness: commit, rollback, isolation, multi-table.""" + + def test_commit_single_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'committed')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_commit_multiple_inserts(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'a')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'b')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (3, 'c')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_rollback_single_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'rolled_back')") + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table) == 0 + + def test_sequential_transactions(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.commit() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (3, 'txn3')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 2 + + def test_auto_start_after_commit(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.commit() + + # Second INSERT auto-starts a new transaction + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.rollback() + + assert _get_ids(mst_conn_params, fq_table) == {1} + + def test_auto_start_after_rollback(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.commit() + + assert _get_ids(mst_conn_params, fq_table) == {2} + + def test_update_in_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'original')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"UPDATE {fq_table} SET value = 'updated' WHERE id = 1") + conn.commit() + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + assert cursor.fetchone()[0] == "updated" + + def test_delete_in_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as 
cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'a')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'b')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"DELETE FROM {fq_table} WHERE id = 1") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_multi_table_commit(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table1, _ = mst_table + fq_table2 = f"{mst_catalog}.{mst_schema}.{_unique_table_name_raw('multi_commit_t2')}" + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + cursor.execute( + f"CREATE TABLE {fq_table2} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + try: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table1} VALUES (1, 't1')") + cursor.execute(f"INSERT INTO {fq_table2} VALUES (1, 't2')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table1) == 1 + assert _get_row_count(mst_conn_params, fq_table2) == 1 + finally: + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + + def test_multi_table_rollback(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table1, _ = mst_table + fq_table2 = f"{mst_catalog}.{mst_schema}.{_unique_table_name_raw('multi_rb_t2')}" + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + cursor.execute( + f"CREATE TABLE {fq_table2} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + try: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table1} VALUES (1, 't1')") + cursor.execute(f"INSERT INTO {fq_table2} VALUES (1, 't2')") + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table1) == 0 + assert _get_row_count(mst_conn_params, fq_table2) == 0 + finally: + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + + def test_multi_table_atomicity(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'should_rollback')") + with pytest.raises(Exception): + cursor.execute( + "INSERT INTO nonexistent_table_xyz_xyz VALUES (1, 'fail')" + ) + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table) == 0 + + @pytest.mark.xdist_group(name="mst_repeatable_reads") + def test_repeatable_reads(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'initial')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + first_read = cursor.fetchone()[0] + + # External connection modifies data + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"UPDATE {fq_table} SET value = 'modified' WHERE id = 1" ) - # Reset to autocommit mode - try: - self.connection.autocommit = True - except Exception as e: - logger.debug(f"Reset autocommit during cleanup failed: {e}") - - # Drop test table - if self.connection and 
self.connection.open: - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - try: - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") - logger.info(f"Dropped test table: {fq_table_name}") - except Exception as e: - logger.warning(f"Failed to drop test table: {e}") - finally: - cursor.close() + # Re-read in same txn — should see original value + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + second_read = cursor.fetchone()[0] - finally: - # Close connection - if self.connection: - self.connection.close() - - # ==================== BASIC AUTOCOMMIT TESTS ==================== - - def test_default_autocommit_is_true(self): - """Test that new connection defaults to autocommit=true.""" - assert ( - self.connection.autocommit is True - ), "New connection should have autocommit=true by default" - - def test_set_autocommit_to_false(self): - """Test successfully setting autocommit to false.""" - self.connection.autocommit = False - assert ( - self.connection.autocommit is False - ), "autocommit should be false after setting to false" - - def test_set_autocommit_to_true(self): - """Test successfully setting autocommit back to true.""" - # First disable - self.connection.autocommit = False - assert self.connection.autocommit is False - - # Then enable - self.connection.autocommit = True - assert ( - self.connection.autocommit is True - ), "autocommit should be true after setting to true" - - # ==================== COMMIT TESTS ==================== - - def test_commit_single_insert(self): - """Test successfully committing a transaction with single INSERT.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" - ) - cursor.close() + assert first_read == second_read, "Repeatable read: value should not change" + conn.rollback() - # Commit - self.connection.commit() + @pytest.mark.xdist_group(name="mst_write_conflict") + def test_write_conflict_single_table(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as setup_conn: + with setup_conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'initial')") - # Verify data is persisted using a new connection - verify_conn = sql.connect(**self.connection_params) + conn1 = sql.connect(**mst_conn_params) + conn2 = sql.connect(**mst_conn_params) try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - verify_cursor.close() + conn1.autocommit = False + conn2.autocommit = False - assert result is not None, "Should find inserted row after commit" - assert result[0] == "test_value", "Value should match inserted value" - finally: - verify_conn.close() + with conn1.cursor() as c1: + c1.execute(f"UPDATE {fq_table} SET value = 'conn1' WHERE id = 1") + with conn2.cursor() as c2: + c2.execute(f"UPDATE {fq_table} SET value = 'conn2' WHERE id = 1") - def test_commit_multiple_inserts(self): - """Test successfully committing a transaction with multiple INSERTs.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # Insert multiple rows - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") - 
cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") - cursor.close() - - self.connection.commit() - - # Verify all rows persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 3, "Should have 3 rows after commit" + conn1.commit() + with pytest.raises(Exception): + conn2.commit() finally: - verify_conn.close() - - # ==================== ROLLBACK TESTS ==================== - - def test_rollback_single_insert(self): - """Test successfully rolling back a transaction.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False + try: + conn1.close() + except Exception: + pass + try: + conn2.close() + except Exception: + pass + + def test_read_only_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'existing')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"SELECT COUNT(*) FROM {fq_table}") + assert cursor.fetchone()[0] == 1 + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_rollback_after_query_failure(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before_error')") + with pytest.raises(Exception): + cursor.execute("SELECT * FROM nonexistent_xyz_xyz") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'after_recovery')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_multiple_cursors_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as c1: + c1.execute(f"INSERT INTO {fq_table} VALUES (1, 'c1')") + with conn.cursor() as c2: + c2.execute(f"INSERT INTO {fq_table} VALUES (2, 'c2')") + with conn.cursor() as c3: + c3.execute(f"INSERT INTO {fq_table} VALUES (3, 'c3')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_parameterized_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + {"id": 1, "value": "parameterized"}, + ) + conn.commit() + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + assert cursor.fetchone()[0] == "parameterized" + + def test_empty_transaction_rollback(self, mst_conn_params, mst_table): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + # Rollback with no DML should not raise + conn.rollback() + + def test_close_connection_implicit_rollback(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + conn = sql.connect(**mst_conn_params) + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'pending')") + conn.close() + + assert 
_get_row_count(mst_conn_params, fq_table) == 0 + + +# ==================== B. API-SPECIFIC TESTS ==================== + + +class TestMstApi: + """DB-API-specific tests: autocommit, isolation, error handling.""" + + def test_default_autocommit_is_true(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + assert conn.autocommit is True + + def test_set_autocommit_false(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + assert conn.autocommit is False + + def test_commit_without_active_txn_throws(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + with pytest.raises(Exception, match=r"NO_ACTIVE_TRANSACTION"): + conn.commit() + + def test_set_autocommit_during_active_txn_throws( + self, mst_conn_params, mst_table + ): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'active_txn')") + with pytest.raises(Exception): + conn.autocommit = True + conn.rollback() + + def test_supported_isolation_level(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + conn.set_transaction_isolation("REPEATABLE_READ") + assert conn.get_transaction_isolation() == "REPEATABLE_READ" + + def test_unsupported_isolation_level_rejected(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + for level in ["READ_UNCOMMITTED", "READ_COMMITTED", "SERIALIZABLE"]: + with pytest.raises(Exception): + conn.set_transaction_isolation(level) + + +# ==================== C. METADATA RPCs ==================== + + +class TestMstMetadata: + """Metadata RPCs inside active transactions. + + Python uses Thrift RPCs for cursor.columns, cursor.tables, etc. These + RPCs bypass MST context and return non-transactional data — they see + concurrent DDL changes that the transaction shouldn't see. 
+ """ + + def test_cursor_columns_in_mst( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + columns = cursor.fetchall() + assert len(columns) > 0 + conn.rollback() + + def test_cursor_tables_in_mst( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.tables( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + tables = cursor.fetchall() + assert len(tables) > 0 + conn.rollback() + + def test_cursor_schemas_in_mst(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.schemas(catalog_name=mst_catalog) + schemas = cursor.fetchall() + assert len(schemas) > 0 + conn.rollback() + + def test_cursor_catalogs_in_mst(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.catalogs() + catalogs = cursor.fetchall() + assert len(catalogs) > 0 + conn.rollback() + + @pytest.mark.xdist_group(name="mst_freshness_columns") + def test_cursor_columns_non_transactional_after_concurrent_ddl( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + """Thrift cursor.columns() bypasses MST — sees concurrent ALTER TABLE.""" + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + before_cols = {row[3].lower() for row in cursor.fetchall()} - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" - ) - cursor.close() + # External connection alters schema + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"ALTER TABLE {fq_table} ADD COLUMN new_col STRING" + ) - # Rollback - self.connection.rollback() + # Re-read columns in same txn — Thrift RPC bypasses txn isolation, + # so new_col IS visible (proves non-transactional behavior) + with conn.cursor() as cursor: + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + after_cols = {row[3].lower() for row in cursor.fetchall()} - # Verify data is NOT persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + assert "new_col" in after_cols, ( + "Thrift cursor.columns() should see concurrent DDL " + "(non-transactional behavior)" ) - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 0, "Rolled back data should not be persisted" - finally: - 
verify_conn.close() - - # ==================== SEQUENTIAL TRANSACTION TESTS ==================== - - def test_multiple_sequential_transactions(self): - """Test executing multiple sequential transactions (commit, commit, rollback).""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") - cursor.close() - self.connection.commit() - - # Second transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") - cursor.close() - self.connection.commit() - - # Third transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") - cursor.close() - self.connection.rollback() + assert before_cols != after_cols + conn.rollback() + + @pytest.mark.xdist_group(name="mst_freshness_tables") + def test_cursor_tables_non_transactional_after_concurrent_create( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + """Thrift cursor.tables() bypasses MST — sees concurrent CREATE TABLE.""" + fq_table, _ = mst_table + new_table_name = _unique_table_name_raw("freshness_new_tbl") + fq_new_table = f"{mst_catalog}.{mst_schema}.{new_table_name}" - # Verify only first two transactions persisted - verify_conn = sql.connect(**self.connection_params) try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" - ) - result = verify_cursor.fetchone() - assert result[0] == 2, "Should have 2 committed rows" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") - result = verify_cursor.fetchone() - assert result[0] == 0, "Rolled back row should not exist" - verify_cursor.close() + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.tables( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=new_table_name, + ) + assert len(cursor.fetchall()) == 0 + + # External connection creates the table + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"CREATE TABLE {fq_new_table} (id INT) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + + # Re-read in same txn — should see the new table + with conn.cursor() as cursor: + cursor.tables( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=new_table_name, + ) + assert len(cursor.fetchall()) > 0, ( + "Thrift cursor.tables() should see concurrent CREATE TABLE " + "(non-transactional behavior)" + ) + conn.rollback() finally: - verify_conn.close() - - def test_auto_start_transaction_after_commit(self): - """Test that new transaction automatically starts after commit.""" - fq_table_name = self._get_fully_qualified_table_name() + try: + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_new_table}") + except Exception as e: + logger.warning(f"Failed to drop {fq_new_table}: {e}") - self.connection.autocommit = False - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.commit() +# ==================== D. 
BLOCKED SQL (MSTCheckRule) ==================== - # New transaction should start automatically - insert and rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.rollback() - # Verify: first committed, second rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 1, "First insert should be committed" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 0, "Second insert should be rolled back" - verify_cursor.close() - finally: - verify_conn.close() +class TestMstBlockedSql: + """SQL introspection statements inside active transactions. - def test_auto_start_transaction_after_rollback(self): - """Test that new transaction automatically starts after rollback.""" - fq_table_name = self._get_fully_qualified_table_name() + The server restricts MST to an allowlist enforced by MSTCheckRule. The + TRANSACTION_NOT_SUPPORTED.COMMAND error originally advertised only: + "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." - self.connection.autocommit = False + The server has since broadened the allowlist to include SHOW COLUMNS + (ShowDeltaTableColumnsCommand), observed on current DBSQL warehouses. - # First transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.rollback() + Blocked (throw + abort txn): + - SHOW TABLES, SHOW SCHEMAS, SHOW CATALOGS, SHOW FUNCTIONS + - DESCRIBE QUERY, DESCRIBE TABLE EXTENDED + - SELECT FROM information_schema - # New transaction should start automatically - insert and commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.commit() - - # Verify: first rolled back, second committed - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 0, "First insert should be rolled back" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 1, "Second insert should be committed" - verify_cursor.close() - finally: - verify_conn.close() + Allowed: + - DESCRIBE TABLE (basic form) + - SHOW COLUMNS + """ - # ==================== UPDATE/DELETE OPERATION TESTS ==================== + def _assert_blocked_and_txn_aborted(self, mst_conn_params, fq_table, blocked_sql): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before_blocked')") - def test_update_in_transaction(self): - """Test UPDATE operation in transaction.""" - fq_table_name = self._get_fully_qualified_table_name() + with pytest.raises(Exception): + cursor.execute(blocked_sql) - # First insert a row with autocommit - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + with pytest.raises(Exception): + cursor.execute( + f"INSERT INTO {fq_table} VALUES (2, 'after_blocked')" + 
) + try: + conn.rollback() + except Exception: + pass + + def _assert_not_blocked(self, mst_conn_params, fq_table, allowed_sql): + """Assert the SQL succeeds and returns rows inside an active txn.""" + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before')") + cursor.execute(allowed_sql) + rows = cursor.fetchall() + assert len(rows) > 0 + conn.rollback() + + def test_show_tables_blocked(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"SHOW TABLES IN {mst_catalog}.{mst_schema}" ) - cursor.close() - # Start transaction and update - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") - cursor.close() - self.connection.commit() - - # Verify update persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == "updated", "Value should be updated after commit" - verify_cursor.close() - finally: - verify_conn.close() - - # ==================== MULTI-TABLE TRANSACTION TESTS ==================== - - def test_multi_table_transaction_commit(self): - """Test atomic commit across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + def test_show_schemas_blocked(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"SHOW SCHEMAS IN {mst_catalog}" ) - cursor.close() - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" - ) - cursor.close() + def test_show_catalogs_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, "SHOW CATALOGS" + ) - # Commit both atomically - self.connection.commit() + def test_show_functions_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, "SHOW FUNCTIONS" + ) - # Verify both inserts persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() + def test_describe_table_extended_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"DESCRIBE TABLE EXTENDED {fq_table}" + ) - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table1 insert should be committed" + def test_information_schema_blocked(self, 
mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, + fq_table, + f"SELECT * FROM {mst_catalog}.information_schema.columns LIMIT 1", + ) - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table2 insert should be committed" + def test_show_columns_not_blocked(self, mst_conn_params, mst_table): + """SHOW COLUMNS succeeds in MST — allowed by the server's MSTCheckRule allowlist.""" + fq_table, _ = mst_table + self._assert_not_blocked( + mst_conn_params, fq_table, f"SHOW COLUMNS IN {fq_table}" + ) - verify_cursor.close() - finally: - verify_conn.close() + def test_describe_query_blocked(self, mst_conn_params, mst_table): + """DESCRIBE QUERY is blocked in MST (DescribeQueryCommand).""" + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, + fq_table, + f"DESCRIBE QUERY SELECT * FROM {fq_table}", + ) - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - def test_multi_table_transaction_rollback(self): - """Test atomic rollback across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + # DESCRIBE TABLE is explicitly listed as an allowed command in the server's + # TRANSACTION_NOT_SUPPORTED.COMMAND error message: + # "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." + def test_describe_table_not_blocked(self, mst_conn_params, mst_table): + """DESCRIBE TABLE succeeds in MST — explicitly allowed by the server.""" + fq_table, _ = mst_table + self._assert_not_blocked( + mst_conn_params, fq_table, f"DESCRIBE TABLE {fq_table}" ) - cursor.close() - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" - ) - cursor.close() +# ==================== E. 
EXECUTE VARIANTS ==================== - # Rollback both atomically - self.connection.rollback() - # Verify both inserts were rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() +class TestMstExecuteVariants: + """Execute method variants (executemany) inside MST.""" - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + def test_executemany_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.executemany( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + [ + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + {"id": 3, "value": "c"}, + ], ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table1 insert should be rolled back" - - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_executemany_rollback_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.executemany( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}], ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table2 insert should be rolled back" + conn.rollback() - verify_cursor.close() - finally: - verify_conn.close() - - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - # ==================== ERROR HANDLING TESTS ==================== - - def test_set_autocommit_during_active_transaction(self): - """Test that setting autocommit during an active transaction throws error.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") - cursor.close() - - # Try to set autocommit=True during active transaction - with pytest.raises(TransactionError) as exc_info: - self.connection.autocommit = True - - # Verify error message mentions autocommit or active transaction - error_msg = str(exc_info.value).lower() - assert ( - "autocommit" in error_msg or "active transaction" in error_msg - ), "Error should mention autocommit or active transaction" - - # Cleanup - rollback the transaction - self.connection.rollback() - - def test_commit_without_active_transaction_throws_error(self): - """Test that commit() throws error when autocommit=true (no active transaction).""" - # Ensure autocommit is true (default) - assert self.connection.autocommit is True - - # Attempt commit without active transaction should throw - with pytest.raises(TransactionError) as exc_info: - self.connection.commit() - - # Verify error message indicates no active transaction - error_message = str(exc_info.value) - assert ( - "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message - or "no active transaction" in error_message.lower() - ), "Error should indicate no active transaction" - - def test_rollback_without_active_transaction_is_safe(self): - """Test that rollback() without active transaction is a safe no-op.""" - # With autocommit=true (no active transaction) - assert self.connection.autocommit is True - - 
# ROLLBACK should be safe (no exception) - self.connection.rollback() - - # Verify connection is still usable - assert self.connection.autocommit is True - assert self.connection.open is True - - # ==================== TRANSACTION ISOLATION TESTS ==================== - - def test_get_transaction_isolation_returns_repeatable_read(self): - """Test that get_transaction_isolation() returns REPEATABLE_READ.""" - isolation_level = self.connection.get_transaction_isolation() - assert ( - isolation_level == "REPEATABLE_READ" - ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" - - def test_set_transaction_isolation_accepts_repeatable_read(self): - """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" - # Should not raise - these are all valid formats - self.connection.set_transaction_isolation("REPEATABLE_READ") - self.connection.set_transaction_isolation("REPEATABLE READ") - self.connection.set_transaction_isolation("repeatable_read") - self.connection.set_transaction_isolation("repeatable read") - - def test_set_transaction_isolation_rejects_unsupported_level(self): - """Test that set_transaction_isolation() rejects unsupported levels.""" - with pytest.raises(NotSupportedError) as exc_info: - self.connection.set_transaction_isolation("READ_COMMITTED") - - error_message = str(exc_info.value) - assert "not supported" in error_message.lower() - assert "READ_COMMITTED" in error_message + assert _get_row_count(mst_conn_params, fq_table) == 0 diff --git a/tests/unit/test_agent_detection.py b/tests/unit/test_agent_detection.py new file mode 100644 index 000000000..0be404a1d --- /dev/null +++ b/tests/unit/test_agent_detection.py @@ -0,0 +1,51 @@ +import pytest +from databricks.sql.common.agent import detect, KNOWN_AGENTS + + +class TestAgentDetection: + def test_detects_single_agent_claude_code(self): + assert detect({"CLAUDECODE": "1"}) == "claude-code" + + def test_detects_single_agent_cursor(self): + assert detect({"CURSOR_AGENT": "1"}) == "cursor" + + def test_detects_single_agent_gemini_cli(self): + assert detect({"GEMINI_CLI": "1"}) == "gemini-cli" + + def test_detects_single_agent_cline(self): + assert detect({"CLINE_ACTIVE": "1"}) == "cline" + + def test_detects_single_agent_codex(self): + assert detect({"CODEX_CI": "1"}) == "codex" + + def test_detects_single_agent_opencode(self): + assert detect({"OPENCODE": "1"}) == "opencode" + + def test_detects_single_agent_antigravity(self): + assert detect({"ANTIGRAVITY_AGENT": "1"}) == "antigravity" + + def test_returns_empty_when_no_agent_detected(self): + assert detect({}) == "" + + def test_returns_empty_when_multiple_agents_detected(self): + assert detect({"CLAUDECODE": "1", "CURSOR_AGENT": "1"}) == "" + + def test_ignores_empty_env_var_values(self): + assert detect({"CLAUDECODE": ""}) == "" + + def test_all_known_agents_are_covered(self): + for env_var, product in KNOWN_AGENTS: + assert detect({env_var: "1"}) == product, ( + f"Agent with env var {env_var} should be detected as {product}" + ) + + def test_defaults_to_os_environ(self, monkeypatch): + monkeypatch.delenv("CLAUDECODE", raising=False) + monkeypatch.delenv("CURSOR_AGENT", raising=False) + monkeypatch.delenv("GEMINI_CLI", raising=False) + monkeypatch.delenv("CLINE_ACTIVE", raising=False) + monkeypatch.delenv("CODEX_CI", raising=False) + monkeypatch.delenv("OPENCODE", raising=False) + monkeypatch.delenv("ANTIGRAVITY_AGENT", raising=False) + # With all agent vars cleared, detect() should return empty + assert detect() == "" diff --git 
a/tests/unit/test_client.py b/tests/unit/test_client.py index 5b6991931..4a8cb0b68 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -87,6 +87,7 @@ class ClientTestSuite(unittest.TestCase): "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -644,6 +645,7 @@ class TransactionTestSuite(unittest.TestCase): "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } def _setup_mock_session_with_http_client(self, mock_session): diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 0c3fc7103..97bb99ad9 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -174,7 +174,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert ( result == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] + [self.make_arrow_table(), self.make_arrow_table()],promote_options="default" )[:7] ) @@ -266,7 +266,7 @@ def test_remaining_rows_multiple_tables_fully_returned( assert ( result == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] + [self.make_arrow_table(), self.make_arrow_table()], promote_options="default" )[3:] ) diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index cf2e24951..0588eb499 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -295,7 +295,8 @@ def test_tspark_param_ordinal(self): (BigIntegerParameter, Primitive.BIGINT), (BooleanParameter, Primitive.BOOL), (DateParameter, Primitive.DATE), - (FloatParameter, Primitive.FLOAT), + (DoubleParameter, Primitive.DOUBLE), + (DoubleParameter, Primitive.FLOAT), (VoidParameter, Primitive.NONE), (TimestampParameter, Primitive.TIMESTAMP), (MapParameter, Primitive.MAP), @@ -305,7 +306,7 @@ def test_tspark_param_ordinal(self): def test_inference(self, _type: TDbsqlParameter, prim: Primitive): """This method only tests inferrable types. 
- Not tested are TinyIntParameter, SmallIntParameter DoubleParameter and TimestampNTZParameter + Not tested are TinyIntParameter, SmallIntParameter, FloatParameter and TimestampNTZParameter """ inferred_type = dbsql_parameter_from_primitive(prim.value) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 26a898cb8..24a5e8242 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -143,6 +143,39 @@ def test_initialization(self, mock_http_client): ) assert client2.warehouse_id == "def456" + # Test with SPOG query param ?o= in http_path + client_spog = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog.warehouse_id == "abc123" + + # Test with SPOG query param on endpoints path + client_spog_ep = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/endpoints/def456?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_ep.warehouse_id == "def456" + + # Test with multiple query params + client_spog_multi = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=123&extra=val", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_multi.warehouse_id == "abc123" + # Test with custom max_download_threads client3 = SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -185,7 +218,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter - "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -197,7 +230,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", - "query_tags": "team:marketing,dashboard:abc123", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -416,6 +449,112 @@ def test_command_execution_advanced( ) assert "Command failed" in str(excinfo.value) + def _execute_response(self): + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + def _run_execute_command(self, sea_client, sea_session_id, mock_cursor, **kwargs): + """Helper to invoke execute_command with default args.""" + return sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + **kwargs, + ) + + def test_execute_command_query_tags_string_values( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with string values are included in the request payload.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + 
self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": "data"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team", "value": "data"}, + ] + + def test_execute_command_query_tags_none_value( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with a None value omit the value field (key-only tag).""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": None}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team", "value": None}, + ] + + def test_execute_command_no_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags field is absent from the request when not provided.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command(sea_client, sea_session_id, mock_cursor) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_empty_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Empty query_tags dict is treated as absent — field omitted from request.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, sea_session_id, mock_cursor, query_tags={} + ) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_async_query_tags( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags are included in async execute requests (execute_async path).""" + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-async", + "status": {"state": "PENDING"}, + } + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, + enforce_embedded_schema_correctness=False, + query_tags={"job": "nightly-etl"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [{"key": "job", "value": "nightly-etl"}] + def test_command_management( self, sea_client, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1d70ec4c4..136c99e53 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -8,6 +8,7 @@ THandleIdentifier, ) from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.session import Session import databricks.sql @@ -22,6 +23,7 @@ class TestSession: "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -50,6 +52,7 @@ def test_auth_args(self, mock_client_class): "server_hostname": "foo", "http_path": None, "access_token": "tok", + "enable_telemetry": False, }, { "server_hostname": "foo", @@ -57,6 +60,7 @@ def test_auth_args(self, 
mock_client_class): "_tls_client_cert_file": "something", "_use_cert_as_auth": True, "access_token": None, + "enable_telemetry": False, }, ] @@ -202,3 +206,67 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_call_args = instance.close_session.call_args[0][0] assert close_session_call_args.guid == b"\x22" assert close_session_call_args.secret == b"\x33" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_sets_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "data-eng", "project": "etl"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:data-eng,project:etl" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "new-team"}, + session_configuration={"QUERY_TAGS": "team:old-team,other:value"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team" + + +class TestSpogHeaders: + """Unit tests for SPOG header extraction from http_path.""" + + def test_extracts_org_id_from_query_param(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", [] + ) + assert result == {"x-databricks-org-id": "6051921418418893"} + + def test_no_query_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123", [] + ) + assert result == {} + + def test_no_o_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?other=value", [] + ) + assert result == {} + + def test_empty_http_path_returns_empty(self): + result = Session._extract_spog_headers("", []) + assert result == {} + + def test_none_http_path_returns_empty(self): + result = Session._extract_spog_headers(None, []) + assert result == {} + + def test_explicit_header_takes_precedence(self): + existing = [("x-databricks-org-id", "explicit-value")] + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", existing + ) + assert result == {} + + def test_multiple_query_params(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=12345&extra=val", [] + ) + assert result == {"x-databricks-org-id": "12345"} diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 713342b2e..687bdd391 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -6,6 +6,7 @@ convert_to_assigned_datatypes_in_column_table, ColumnTable, concat_table_chunks, + serialize_query_tags, ) try: @@ -161,3 +162,65 @@ def test_concat_table_chunks__incorrect_column_names_error(self): with pytest.raises(ValueError): concat_table_chunks([column_table1, column_table2]) + + def test_serialize_query_tags_basic(self): + """Test basic query tags serialization""" + query_tags = {"team": "data-eng", "application": "etl"} + result = serialize_query_tags(query_tags) + assert result == "team:data-eng,application:etl" + + def test_serialize_query_tags_with_none_value(self): + """Test query tags with None value (should omit colon and value)""" + query_tags = {"key1": "value1", "key2": None, "key3": "value3"} + result = serialize_query_tags(query_tags) + assert result == 
"key1:value1,key2,key3:value3" + + def test_serialize_query_tags_with_special_chars(self): + """Test query tags with special characters (colon, comma, backslash)""" + query_tags = { + "key1": "value:with:colons", + "key2": "value,with,commas", + "key3": r"value\with\backslashes", + } + result = serialize_query_tags(query_tags) + assert ( + result + == r"key1:value\:with\:colons,key2:value\,with\,commas,key3:value\\with\\backslashes" + ) + + def test_serialize_query_tags_with_mixed_special_chars(self): + """Test query tags with mixed special characters""" + query_tags = {"key1": r"a:b,c\d"} + result = serialize_query_tags(query_tags) + assert result == r"key1:a\:b\,c\\d" + + def test_serialize_query_tags_empty_dict(self): + """Test serialization with empty dictionary""" + query_tags = {} + result = serialize_query_tags(query_tags) + assert result is None + + def test_serialize_query_tags_none(self): + """Test serialization with None input""" + result = serialize_query_tags(None) + assert result is None + + def test_serialize_query_tags_with_special_chars_in_key(self): + """Test query tags with special characters in keys (only backslashes are escaped in keys)""" + query_tags = { + "key:with:colons": "value1", + "key,with,commas": "value2", + r"key\with\backslashes": "value3", + } + result = serialize_query_tags(query_tags) + # Only backslashes are escaped in keys; colons and commas in keys are not escaped + assert ( + result + == r"key:with:colons:value1,key,with,commas:value2,key\\with\\backslashes:value3" + ) + + def test_serialize_query_tags_all_none_values(self): + """Test query tags where all values are None""" + query_tags = {"key1": None, "key2": None, "key3": None} + result = serialize_query_tags(query_tags) + assert result == "key1,key2,key3"