From fd65fd27b2265be8a1976a0e0240ff0a475abeea Mon Sep 17 00:00:00 2001 From: msrathore-db Date: Fri, 24 Oct 2025 16:35:51 +0530 Subject: [PATCH 01/39] Added a workflow to parallelise the E2E tests(#697) * Added a workflow to parallelise the E2E tests. Updated E2E tests to create new table names for each run to avoid issue in parallelisation * Modified parallel code coverage workflow to fail when e2e tests fail * Fixed parallel coverage check pytest command * Modified e2e tests to support parallel execution * Added fallbacks for exit code 5 where we find no test for SEA * Fixed coverage artifact for parallel test workflow * Debugging coverage check merge * Improved coverage report merge and removed the test_driver test for faster testing * Debug commit for coverage merge * Debugging coverage merge 2 * Debugging coverage merge 3 * Removed unnecessary debug statements from the parallel code coverage workflow * Added unit test and common e2e tests * Added null checks for coverage workflow * Improved the null check for test list * Improved the visibility for test list * Added check for exit code 5 * Updated the workflowfor coverage check to use pytst -xdist to run the tests parallely * Enforced the e2e tests should pass * Changed name for workflow job * Updated poetry * Removed integration and previous code coverage workflow * Added the integration workflow again --- .../{coverage-check.yml => code-coverage.yml} | 25 +++--- .github/workflows/integration.yml | 2 +- poetry.lock | 82 ++++++++++++++++--- pyproject.toml | 1 + tests/e2e/test_complex_types.py | 20 +++-- tests/e2e/test_parameterized_queries.py | 64 +++++++++------ tests/e2e/test_variant_types.py | 13 +-- 7 files changed, 145 insertions(+), 62 deletions(-) rename .github/workflows/{coverage-check.yml => code-coverage.yml} (92%) diff --git a/.github/workflows/coverage-check.yml b/.github/workflows/code-coverage.yml similarity index 92% rename from .github/workflows/coverage-check.yml rename to .github/workflows/code-coverage.yml index 51e42f9e7..d9954d051 100644 --- a/.github/workflows/coverage-check.yml +++ b/.github/workflows/code-coverage.yml @@ -6,7 +6,7 @@ permissions: on: [pull_request, workflow_dispatch] jobs: - coverage: + test-with-coverage: runs-on: ubuntu-latest environment: azure-prod env: @@ -22,9 +22,9 @@ jobs: - name: Check out repository uses: actions/checkout@v4 with: - fetch-depth: 0 # Needed for coverage comparison + fetch-depth: 0 ref: ${{ github.event.pull_request.head.ref || github.ref_name }} - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} + repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - name: Set up python id: setup-python uses: actions/setup-python@v5 @@ -61,14 +61,18 @@ jobs: - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- - # run all tests + # run all tests with coverage #---------------------------------------------- - - name: Run tests with coverage - continue-on-error: true + - name: Run all tests with coverage + continue-on-error: false run: | - poetry run python -m pytest \ - tests/unit tests/e2e \ - --cov=src --cov-report=xml --cov-report=term -v + poetry run pytest tests/unit tests/e2e \ + -n auto \ + --cov=src \ + --cov-report=xml \ + --cov-report=term \ + -v + #---------------------------------------------- # check for coverage override #---------------------------------------------- @@ -128,4 +132,5 @@ jobs: echo "Please ensure this override is justified and temporary" else echo "✅ Coverage checks enforced - minimum 85% required" - fi \ No newline at end of file + fi + diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 127c8ff4f..9c9e30a24 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -55,4 +55,4 @@ jobs: # run test suite #---------------------------------------------- - name: Run e2e tests - run: poetry run python -m pytest tests/e2e + run: poetry run python -m pytest tests/e2e -n auto \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 5fd216330..1a8074c2a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -70,7 +70,7 @@ description = "Foreign Function Interface for Python calling C code." optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\"" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, @@ -475,7 +475,7 @@ description = "cryptography is a package which provides cryptographic recipes an optional = true python-versions = ">=3.7" groups = ["main"] -markers = "python_version < \"3.10\" and extra == \"true\"" +markers = "python_version < \"3.10\"" files = [ {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, @@ -526,7 +526,7 @@ description = "cryptography is a package which provides cryptographic recipes an optional = true python-versions = "!=3.9.0,!=3.9.1,>=3.7" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"true\"" +markers = "python_version >= \"3.10\"" files = [ {file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"}, {file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"}, @@ -587,7 +587,7 @@ description = "Decorators for Humans" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, @@ -637,6 +637,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "execnet" +version = "2.1.1" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"}, + {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + [[package]] name = "gssapi" version = "1.9.0" @@ -644,7 +659,7 @@ description = "Python GSSAPI Wrapper" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "gssapi-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:261e00ac426d840055ddb2199f4989db7e3ce70fa18b1538f53e392b4823e8f1"}, {file = "gssapi-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:14a1ae12fdf1e4c8889206195ba1843de09fe82587fa113112887cd5894587c6"}, @@ -725,7 +740,7 @@ description = "Kerberos API bindings for Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "krb5-0.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cbdcd2c4514af5ca32d189bc31f30fee2ab297dcbff74a53bd82f92ad1f6e0ef"}, {file = "krb5-0.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:40ad837d563865946cffd65a588f24876da2809aa5ce4412de49442d7cf11d50"}, @@ -1340,7 +1355,7 @@ description = "C parser in Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\"" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, @@ -1422,7 +1437,6 @@ description = "Windows Negotiate Authentication Client and Server" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\"" files = [ {file = "pyspnego-0.11.2-py3-none-any.whl", hash = "sha256:74abc1fb51e59360eb5c5c9086e5962174f1072c7a50cf6da0bda9a4bcfdfbd4"}, {file = "pyspnego-0.11.2.tar.gz", hash = "sha256:994388d308fb06e4498365ce78d222bf4f3570b6df4ec95738431f61510c971b"}, @@ -1496,6 +1510,50 @@ files = [ pytest = ">=5.0.0" python-dotenv = ">=0.9.1" +[[package]] +name = "pytest-xdist" +version = "3.6.1" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"}, + {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"}, +] + +[package.dependencies] +execnet = ">=2.1" +pytest = ">=7.0.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + +[[package]] +name = "pytest-xdist" +version = "3.8.0" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88"}, + {file = "pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1"}, +] + +[package.dependencies] +execnet = ">=2.1" +pytest = ">=7.0.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1567,7 +1625,6 @@ description = "A Kerberos authentication handler for python-requests" optional = true python-versions = ">=3.6" groups = ["main"] -markers = "extra == \"true\"" files = [ {file = "requests_kerberos-0.15.0-py2.py3-none-any.whl", hash = "sha256:ba9b0980b8489c93bfb13854fd118834e576d6700bfea3745cb2e62278cd16a6"}, {file = "requests_kerberos-0.15.0.tar.gz", hash = "sha256:437512e424413d8113181d696e56694ffa4259eb9a5fc4e803926963864eaf4e"}, @@ -1597,7 +1654,7 @@ description = "SSPI API bindings for Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform == \"win32\" and python_version < \"3.10\"" +markers = "python_version < \"3.10\" and sys_platform == \"win32\"" files = [ {file = "sspilib-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:34f566ba8b332c91594e21a71200de2d4ce55ca5a205541d4128ed23e3c98777"}, {file = "sspilib-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b11e4f030de5c5de0f29bcf41a6e87c9fd90cb3b0f64e446a6e1d1aef4d08f5"}, @@ -1644,7 +1701,7 @@ description = "SSPI API bindings for Python" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"true\" and sys_platform == \"win32\" and python_version >= \"3.10\"" +markers = "python_version >= \"3.10\" and sys_platform == \"win32\"" files = [ {file = "sspilib-0.3.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c45860bdc4793af572d365434020ff5a1ef78c42a2fc2c7a7d8e44eacaf475b6"}, {file = "sspilib-0.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:62cc4de547503dec13b81a6af82b398e9ef53ea82c3535418d7d069c7a05d5cd"}, @@ -1797,9 +1854,8 @@ zstd = ["zstandard (>=0.18.0)"] [extras] pyarrow = ["pyarrow", "pyarrow"] -true = ["requests-kerberos"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "ddc7354d47a940fa40b4d34c43a1c42488b01258d09d771d58d64a0dfaf0b955" +content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" diff --git a/pyproject.toml b/pyproject.toml index 3240e2c5b..c0eb8244d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ pylint = ">=2.12.0" black = "^22.3.0" pytest-dotenv = "^0.5.2" pytest-cov = "^4.0.0" +pytest-xdist = "^3.0.0" numpy = [ { version = ">=1.16.6", python = ">=3.8,<3.11" }, { version = ">=1.23.4", python = ">=3.11" }, diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index 212ddf916..d075a5670 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -1,6 +1,7 @@ import pytest from numpy import ndarray from typing import Sequence +from uuid import uuid4 from tests.e2e.test_driver import PySQLPytestTestCase @@ -10,12 +11,15 @@ class TestComplexTypes(PySQLPytestTestCase): def table_fixture(self, connection_details): self.arguments = connection_details.copy() """A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table""" + + table_name = f"pysql_test_complex_types_table_{str(uuid4()).replace('-', '_')}" + self.table_name = table_name with self.cursor() as cursor: # Create the table cursor.execute( - """ - CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table ( + f""" + CREATE TABLE IF NOT EXISTS {table_name} ( array_col ARRAY, map_col MAP, struct_col STRUCT, @@ -27,8 +31,8 @@ def table_fixture(self, connection_details): ) # Insert a record cursor.execute( - """ - INSERT INTO pysql_test_complex_types_table + f""" + INSERT INTO {table_name} VALUES ( ARRAY('a', 'b', 'c'), MAP('a', 1, 'b', 2, 'c', 3), @@ -40,10 +44,10 @@ def table_fixture(self, connection_details): """ ) try: - yield + yield table_name finally: # Clean up the table after the test - cursor.execute("DELETE FROM pysql_test_complex_types_table") + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") @pytest.mark.parametrize( "field,expected_type", @@ -61,7 +65,7 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): with self.cursor() as cursor: result = cursor.execute( - "SELECT * FROM pysql_test_complex_types_table LIMIT 1" + f"SELECT * FROM {table_fixture} LIMIT 1" ).fetchone() assert isinstance(result[field], expected_type) @@ -83,7 +87,7 @@ def test_read_complex_types_as_string(self, field, table_fixture): extra_params={"_use_arrow_native_complex_types": False} ) as cursor: result = cursor.execute( - "SELECT * FROM pysql_test_complex_types_table LIMIT 1" + f"SELECT * FROM {table_fixture} LIMIT 1" ).fetchone() assert isinstance(result[field], str) diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 79def9b72..7370eea93 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Dict, List, Type, Union from unittest.mock import patch +from uuid import uuid4 import time import numpy as np @@ -118,21 +119,10 @@ class TestParameterizedQueries(PySQLPytestTestCase): def _get_inline_table_column(self, value): return self.inline_type_map[Primitive(value)] - @pytest.fixture(scope="class") - def inline_table(self, connection_details): - self.arguments = connection_details.copy() - """This table is necessary to verify that a parameter sent with INLINE - approach can actually write to its analogous data type. - - For example, a Python Decimal(), when rendered inline, should be able - to read/write into a DECIMAL column in Databricks - - Note that this fixture doesn't clean itself up. So the table will remain - in the schema for use by subsequent test runs. - """ - - query = """ - CREATE TABLE IF NOT EXISTS pysql_e2e_inline_param_test_table ( + def _create_inline_table(self, table_name): + """Create the inline test table with all necessary columns""" + query = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( null_col INT, int_col INT, bigint_col BIGINT, @@ -155,6 +145,24 @@ def inline_table(self, connection_details): with conn.cursor() as cursor: cursor.execute(query) + @pytest.fixture(scope="class") + def inline_table(self, connection_details): + self.arguments = connection_details.copy() + """This table is necessary to verify that a parameter sent with INLINE + approach can actually write to its analogous data type. + + For example, a Python Decimal(), when rendered inline, should be able + to read/write into a DECIMAL column in Databricks + + Note that this fixture doesn't clean itself up. So the table will remain + in the schema for use by subsequent test runs. + """ + + # Generate unique table name to avoid conflicts in parallel execution + table_name = f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + self.inline_table_name = table_name + self._create_inline_table(table_name) + @contextmanager def patch_server_supports_native_params(self, supports_native_params: bool = True): """Applies a patch so we can test the connector's behaviour under different SPARK_CLI_SERVICE_PROTOCOL_VERSION conditions.""" @@ -179,17 +187,25 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column) :paramstyle: This is a no-op but is included to make the test-code easier to read. """ - INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" - SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" - DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" + if not hasattr(self, 'inline_table_name'): + table_name = f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + self.inline_table_name = table_name + self._create_inline_table(table_name) + + table_name = self.inline_table_name + INSERT_QUERY = f"INSERT INTO {table_name} (`{target_column}`) VALUES (%(p)s)" + SELECT_QUERY = f"SELECT {target_column} `col` FROM {table_name} LIMIT 1" + DELETE_QUERY = f"DELETE FROM {table_name}" with self.connection(extra_params={"use_inline_params": True}) as conn: - with conn.cursor() as cursor: - cursor.execute(INSERT_QUERY, parameters=params) - with conn.cursor() as cursor: - to_return = cursor.execute(SELECT_QUERY).fetchone() - with conn.cursor() as cursor: - cursor.execute(DELETE_QUERY) + try: + with conn.cursor() as cursor: + cursor.execute(INSERT_QUERY, parameters=params) + with conn.cursor() as cursor: + to_return = cursor.execute(SELECT_QUERY).fetchone() + finally: + with conn.cursor() as cursor: + cursor.execute(DELETE_QUERY) return to_return diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py index b5dc1f421..14be3aa3d 100644 --- a/tests/e2e/test_variant_types.py +++ b/tests/e2e/test_variant_types.py @@ -1,6 +1,7 @@ import pytest from datetime import datetime import json +from uuid import uuid4 try: import pyarrow @@ -19,14 +20,14 @@ class TestVariantTypes(PySQLPytestTestCase): def variant_table(self, connection_details): """A pytest fixture that creates a test table and cleans up after tests""" self.arguments = connection_details.copy() - table_name = "pysql_test_variant_types_table" + table_name = f"pysql_test_variant_types_table_{str(uuid4()).replace('-', '_')}" with self.cursor() as cursor: try: # Create the table with variant columns cursor.execute( - """ - CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table ( + f""" + CREATE TABLE IF NOT EXISTS {table_name} ( id INTEGER, variant_col VARIANT, regular_string_col STRING @@ -36,10 +37,10 @@ def variant_table(self, connection_details): # Insert test records with different variant values cursor.execute( - """ - INSERT INTO pysql_test_variant_types_table + f""" + INSERT INTO {table_name} VALUES - (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'), + (1, PARSE_JSON('{{\"name\": \"John\", \"age\": 30}}'), 'regular string'), (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string') """ ) From 3b37dd203ad40f6b0be15fa0cfb2f3a707b37159 Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Tue, 11 Nov 2025 02:05:06 +0530 Subject: [PATCH 02/39] Bring Python telemetry event model consistent with JDBC (#701) * Added driver connection params Signed-off-by: Nikhil Suri * Added model fields for chunk/result latency Signed-off-by: Nikhil Suri * fixed linting issues Signed-off-by: Nikhil Suri * lint issue fixing Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- src/databricks/sql/client.py | 35 +- .../sql/common/unified_http_client.py | 5 + src/databricks/sql/telemetry/models/event.py | 109 +++++- .../sql/telemetry/telemetry_client.py | 2 +- tests/unit/test_telemetry.py | 365 +++++++++++++++++- 5 files changed, 512 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5bb191ca2..5e5b9cedc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,6 +9,7 @@ import json import os import decimal +from urllib.parse import urlparse from uuid import UUID from databricks.sql import __version__ @@ -322,6 +323,20 @@ def read(self) -> Optional[OAuthToken]: session_id_hex=self.get_session_id_hex() ) + # Determine proxy usage + use_proxy = self.http_client.using_proxy() + proxy_host_info = None + if ( + use_proxy + and self.http_client.proxy_uri + and isinstance(self.http_client.proxy_uri, str) + ): + parsed = urlparse(self.http_client.proxy_uri) + proxy_host_info = HostDetails( + host_url=parsed.hostname or self.http_client.proxy_uri, + port=parsed.port or 8080, + ) + driver_connection_params = DriverConnectionParameters( http_path=http_path, mode=DatabricksClientType.SEA @@ -331,13 +346,31 @@ def read(self) -> Optional[OAuthToken]: auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), socket_timeout=kwargs.get("_socket_timeout", None), + azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id", None), + azure_tenant_id=kwargs.get("azure_tenant_id", None), + use_proxy=use_proxy, + use_system_proxy=use_proxy, + proxy_host_info=proxy_host_info, + use_cf_proxy=False, # CloudFlare proxy not yet supported in Python + cf_proxy_host_info=None, # CloudFlare proxy not yet supported in Python + non_proxy_hosts=None, + allow_self_signed_support=kwargs.get("_tls_no_verify", False), + use_system_trust_store=True, # Python uses system SSL by default + enable_arrow=pyarrow is not None, + enable_direct_results=True, # Always enabled in Python + enable_sea_hybrid_results=kwargs.get("use_hybrid_disposition", False), + http_connection_pool_size=kwargs.get("pool_maxsize", None), + rows_fetched_per_block=DEFAULT_ARRAY_SIZE, + async_poll_interval_millis=2000, # Default polling interval + support_many_parameters=True, # Native parameters supported + enable_complex_datatype_support=_use_arrow_native_complex_types, + allowed_volume_ingestion_paths=self.staging_allowed_local_path, ) self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, ) - self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 7ccd69c54..96fb9cbb9 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -301,6 +301,11 @@ def using_proxy(self) -> bool: """Check if proxy support is available (not whether it's being used for a specific request).""" return self._proxy_pool_manager is not None + @property + def proxy_uri(self) -> Optional[str]: + """Get the configured proxy URI, if any.""" + return self._proxy_uri + def close(self): """Close the underlying connection pools.""" if self._direct_pool_manager: diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index c7f9d9d17..2e6f63a6f 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -38,6 +38,25 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech (AuthMech): The authentication mechanism used auth_flow (AuthFlow): The authentication flow type socket_timeout (int): Connection timeout in milliseconds + azure_workspace_resource_id (str): Azure workspace resource ID + azure_tenant_id (str): Azure tenant ID + use_proxy (bool): Whether proxy is being used + use_system_proxy (bool): Whether system proxy is being used + proxy_host_info (HostDetails): Proxy host details if configured + use_cf_proxy (bool): Whether CloudFlare proxy is being used + cf_proxy_host_info (HostDetails): CloudFlare proxy host details if configured + non_proxy_hosts (list): List of hosts that bypass proxy + allow_self_signed_support (bool): Whether self-signed certificates are allowed + use_system_trust_store (bool): Whether system trust store is used + enable_arrow (bool): Whether Arrow format is enabled + enable_direct_results (bool): Whether direct results are enabled + enable_sea_hybrid_results (bool): Whether SEA hybrid results are enabled + http_connection_pool_size (int): HTTP connection pool size + rows_fetched_per_block (int): Number of rows fetched per block + async_poll_interval_millis (int): Async polling interval in milliseconds + support_many_parameters (bool): Whether many parameters are supported + enable_complex_datatype_support (bool): Whether complex datatypes are supported + allowed_volume_ingestion_paths (str): Allowed paths for volume ingestion """ http_path: str @@ -46,6 +65,25 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech: Optional[AuthMech] = None auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None + azure_workspace_resource_id: Optional[str] = None + azure_tenant_id: Optional[str] = None + use_proxy: Optional[bool] = None + use_system_proxy: Optional[bool] = None + proxy_host_info: Optional[HostDetails] = None + use_cf_proxy: Optional[bool] = None + cf_proxy_host_info: Optional[HostDetails] = None + non_proxy_hosts: Optional[list] = None + allow_self_signed_support: Optional[bool] = None + use_system_trust_store: Optional[bool] = None + enable_arrow: Optional[bool] = None + enable_direct_results: Optional[bool] = None + enable_sea_hybrid_results: Optional[bool] = None + http_connection_pool_size: Optional[int] = None + rows_fetched_per_block: Optional[int] = None + async_poll_interval_millis: Optional[int] = None + support_many_parameters: Optional[bool] = None + enable_complex_datatype_support: Optional[bool] = None + allowed_volume_ingestion_paths: Optional[str] = None @dataclass @@ -111,6 +149,69 @@ class DriverErrorInfo(JsonSerializableMixin): stack_trace: str +@dataclass +class ChunkDetails(JsonSerializableMixin): + """ + Contains detailed metrics about chunk downloads during result fetching. + + These metrics are accumulated across all chunk downloads for a single statement. + + Attributes: + initial_chunk_latency_millis (int): Latency of the first chunk download + slowest_chunk_latency_millis (int): Latency of the slowest chunk download + total_chunks_present (int): Total number of chunks available + total_chunks_iterated (int): Number of chunks actually downloaded + sum_chunks_download_time_millis (int): Total time spent downloading all chunks + """ + + initial_chunk_latency_millis: Optional[int] = None + slowest_chunk_latency_millis: Optional[int] = None + total_chunks_present: Optional[int] = None + total_chunks_iterated: Optional[int] = None + sum_chunks_download_time_millis: Optional[int] = None + + +@dataclass +class ResultLatency(JsonSerializableMixin): + """ + Contains latency metrics for different phases of query execution. + + This tracks two distinct phases: + 1. result_set_ready_latency_millis: Time from query submission until results are available (execute phase) + - Set when execute() completes + 2. result_set_consumption_latency_millis: Time spent iterating/fetching results (fetch phase) + - Measured from first fetch call until no more rows available + - In Java: tracked via markResultSetConsumption(hasNext) method + - Records start time on first fetch, calculates total on last fetch + + Attributes: + result_set_ready_latency_millis (int): Time until query results are ready (execution phase) + result_set_consumption_latency_millis (int): Time spent fetching/consuming results (fetch phase) + + """ + + result_set_ready_latency_millis: Optional[int] = None + result_set_consumption_latency_millis: Optional[int] = None + + +@dataclass +class OperationDetail(JsonSerializableMixin): + """ + Contains detailed information about the operation being performed. + + Attributes: + n_operation_status_calls (int): Number of status polling calls made + operation_status_latency_millis (int): Total latency of all status calls + operation_type (str): Specific operation type (e.g., EXECUTE_STATEMENT, LIST_TABLES, CANCEL_STATEMENT) + is_internal_call (bool): Whether this is an internal driver operation + """ + + n_operation_status_calls: Optional[int] = None + operation_status_latency_millis: Optional[int] = None + operation_type: Optional[str] = None + is_internal_call: Optional[bool] = None + + @dataclass class SqlExecutionEvent(JsonSerializableMixin): """ @@ -122,7 +223,10 @@ class SqlExecutionEvent(JsonSerializableMixin): is_compressed (bool): Whether the result is compressed execution_result (ExecutionResultFormat): Format of the execution result retry_count (int): Number of retry attempts made - chunk_id (int): ID of the chunk if applicable + chunk_id (int): ID of the chunk if applicable (used for error tracking) + chunk_details (ChunkDetails): Aggregated chunk download metrics + result_latency (ResultLatency): Latency breakdown by execution phase + operation_detail (OperationDetail): Detailed operation information """ statement_type: StatementType @@ -130,6 +234,9 @@ class SqlExecutionEvent(JsonSerializableMixin): execution_result: ExecutionResultFormat retry_count: Optional[int] chunk_id: Optional[int] + chunk_details: Optional[ChunkDetails] = None + result_latency: Optional[ResultLatency] = None + operation_detail: Optional[OperationDetail] = None @dataclass diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..134757fe5 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -380,7 +380,7 @@ class TelemetryClientFactory: # Shared flush thread for all clients _flush_thread = None _flush_event = threading.Event() - _flush_interval_seconds = 90 + _flush_interval_seconds = 300 # 5 minutes DEFAULT_BATCH_SIZE = 100 diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 2ff82cee5..36141ee2b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,6 +2,7 @@ import pytest from unittest.mock import patch, MagicMock import json +from dataclasses import asdict from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -9,7 +10,16 @@ TelemetryClientFactory, TelemetryHelper, ) -from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType +from databricks.sql.telemetry.models.event import ( + TelemetryEvent, + DriverConnectionParameters, + DriverSystemConfiguration, + SqlExecutionEvent, + DriverErrorInfo, + DriverVolumeOperation, + HostDetails, +) from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, DatabricksOAuthProvider, @@ -446,3 +456,356 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) + + +class TestTelemetryEventModels: + """Tests for telemetry event model data structures and JSON serialization.""" + + def test_host_details_serialization(self): + """Test HostDetails model serialization.""" + host = HostDetails(host_url="test-host.com", port=443) + + # Test JSON string generation + json_str = host.to_json() + assert isinstance(json_str, str) + parsed = json.loads(json_str) + assert parsed["host_url"] == "test-host.com" + assert parsed["port"] == 443 + + def test_driver_connection_parameters_all_fields(self): + """Test DriverConnectionParameters with all fields populated.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + proxy_info = HostDetails(host_url="proxy.company.com", port=8080) + cf_proxy_info = HostDetails(host_url="cf-proxy.company.com", port=8080) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.SEA, + host_info=host_info, + auth_mech=AuthMech.OAUTH, + auth_flow=AuthFlow.BROWSER_BASED_AUTHENTICATION, + socket_timeout=30000, + azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", + azure_tenant_id="tenant-123", + use_proxy=True, + use_system_proxy=True, + proxy_host_info=proxy_info, + use_cf_proxy=False, + cf_proxy_host_info=cf_proxy_info, + non_proxy_hosts=["localhost", "127.0.0.1"], + allow_self_signed_support=False, + use_system_trust_store=True, + enable_arrow=True, + enable_direct_results=True, + enable_sea_hybrid_results=True, + http_connection_pool_size=100, + rows_fetched_per_block=100000, + async_poll_interval_millis=2000, + support_many_parameters=True, + enable_complex_datatype_support=True, + allowed_volume_ingestion_paths="/Volumes/catalog/schema/volume", + ) + + # Serialize to JSON and parse back + json_str = params.to_json() + json_dict = json.loads(json_str) + + # Verify all new fields are in JSON + assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" + assert json_dict["mode"] == "SEA" + assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" + assert json_dict["auth_mech"] == "OAUTH" + assert json_dict["auth_flow"] == "BROWSER_BASED_AUTHENTICATION" + assert json_dict["socket_timeout"] == 30000 + assert json_dict["azure_workspace_resource_id"] == "/subscriptions/test/resourceGroups/test" + assert json_dict["azure_tenant_id"] == "tenant-123" + assert json_dict["use_proxy"] is True + assert json_dict["use_system_proxy"] is True + assert json_dict["proxy_host_info"]["host_url"] == "proxy.company.com" + assert json_dict["use_cf_proxy"] is False + assert json_dict["cf_proxy_host_info"]["host_url"] == "cf-proxy.company.com" + assert json_dict["non_proxy_hosts"] == ["localhost", "127.0.0.1"] + assert json_dict["allow_self_signed_support"] is False + assert json_dict["use_system_trust_store"] is True + assert json_dict["enable_arrow"] is True + assert json_dict["enable_direct_results"] is True + assert json_dict["enable_sea_hybrid_results"] is True + assert json_dict["http_connection_pool_size"] == 100 + assert json_dict["rows_fetched_per_block"] == 100000 + assert json_dict["async_poll_interval_millis"] == 2000 + assert json_dict["support_many_parameters"] is True + assert json_dict["enable_complex_datatype_support"] is True + assert json_dict["allowed_volume_ingestion_paths"] == "/Volumes/catalog/schema/volume" + + def test_driver_connection_parameters_minimal_fields(self): + """Test DriverConnectionParameters with only required fields.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.THRIFT, + host_info=host_info, + ) + + # Note: to_json() filters out None values, so we need to check asdict for complete structure + json_str = params.to_json() + json_dict = json.loads(json_str) + + # Required fields should be present + assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" + assert json_dict["mode"] == "THRIFT" + assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" + + # Optional fields with None are filtered out by to_json() + # This is expected behavior - None values are excluded from JSON output + + def test_driver_system_configuration_serialization(self): + """Test DriverSystemConfiguration model serialization.""" + sys_config = DriverSystemConfiguration( + driver_name="Databricks SQL Connector for Python", + driver_version="3.0.0", + runtime_name="CPython", + runtime_version="3.11.0", + runtime_vendor="Python Software Foundation", + os_name="Darwin", + os_version="23.0.0", + os_arch="arm64", + char_set_encoding="utf-8", + locale_name="en_US", + client_app_name="MyApp", + ) + + json_str = sys_config.to_json() + json_dict = json.loads(json_str) + + assert json_dict["driver_name"] == "Databricks SQL Connector for Python" + assert json_dict["driver_version"] == "3.0.0" + assert json_dict["runtime_name"] == "CPython" + assert json_dict["runtime_version"] == "3.11.0" + assert json_dict["runtime_vendor"] == "Python Software Foundation" + assert json_dict["os_name"] == "Darwin" + assert json_dict["os_version"] == "23.0.0" + assert json_dict["os_arch"] == "arm64" + assert json_dict["locale_name"] == "en_US" + assert json_dict["char_set_encoding"] == "utf-8" + assert json_dict["client_app_name"] == "MyApp" + + def test_telemetry_event_complete_serialization(self): + """Test complete TelemetryEvent serialization with all nested objects.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + proxy_info = HostDetails(host_url="proxy.company.com", port=8080) + + connection_params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.SEA, + host_info=host_info, + auth_mech=AuthMech.OAUTH, + use_proxy=True, + proxy_host_info=proxy_info, + enable_arrow=True, + rows_fetched_per_block=100000, + ) + + sys_config = DriverSystemConfiguration( + driver_name="Databricks SQL Connector for Python", + driver_version="3.0.0", + runtime_name="CPython", + runtime_version="3.11.0", + runtime_vendor="Python Software Foundation", + os_name="Darwin", + os_version="23.0.0", + os_arch="arm64", + char_set_encoding="utf-8", + ) + + error_info = DriverErrorInfo( + error_name="ConnectionError", + stack_trace="Traceback...", + ) + + event = TelemetryEvent( + session_id="test-session-123", + sql_statement_id="test-stmt-456", + operation_latency_ms=1500, + auth_type="OAUTH", + system_configuration=sys_config, + driver_connection_params=connection_params, + error_info=error_info, + ) + + # Test JSON serialization + json_str = event.to_json() + assert isinstance(json_str, str) + + # Parse and verify structure + parsed = json.loads(json_str) + assert parsed["session_id"] == "test-session-123" + assert parsed["sql_statement_id"] == "test-stmt-456" + assert parsed["operation_latency_ms"] == 1500 + assert parsed["auth_type"] == "OAUTH" + + # Verify nested objects + assert parsed["system_configuration"]["driver_name"] == "Databricks SQL Connector for Python" + assert parsed["driver_connection_params"]["http_path"] == "/sql/1.0/warehouses/abc123" + assert parsed["driver_connection_params"]["use_proxy"] is True + assert parsed["driver_connection_params"]["proxy_host_info"]["host_url"] == "proxy.company.com" + assert parsed["error_info"]["error_name"] == "ConnectionError" + + def test_json_serialization_excludes_none_values(self): + """Test that JSON serialization properly excludes None values.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.SEA, + host_info=host_info, + # All optional fields left as None + ) + + json_str = params.to_json() + parsed = json.loads(json_str) + + # Required fields present + assert parsed["http_path"] == "/sql/1.0/warehouses/abc123" + + # None values should be EXCLUDED from JSON (not included as null) + # This is the behavior of JsonSerializableMixin + assert "auth_mech" not in parsed + assert "azure_tenant_id" not in parsed + assert "proxy_host_info" not in parsed + + +@patch("databricks.sql.client.Session") +@patch("databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers") +class TestConnectionParameterTelemetry: + """Tests for connection parameter population in telemetry.""" + + def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_session): + """Test that proxy configuration is captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-proxy" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + # Verify export was called + mock_export.assert_called_once() + call_args = mock_export.call_args + + # Extract driver_connection_params + driver_params = call_args.kwargs.get("driver_connection_params") + assert driver_params is not None + assert isinstance(driver_params, DriverConnectionParameters) + + # Verify fields are populated + assert driver_params.http_path == "/sql/1.0/warehouses/test" + assert driver_params.mode == DatabricksClientType.SEA + assert driver_params.host_info.host_url == "workspace.databricks.com" + assert driver_params.host_info.port == 443 + + def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools, mock_session): + """Test that Azure-specific parameters are captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-azure" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = False + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.azuredatabricks.net" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.azuredatabricks.net", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", + azure_tenant_id="tenant-123", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # Verify Azure fields + assert driver_params.azure_workspace_resource_id == "/subscriptions/test/resourceGroups/test" + assert driver_params.azure_tenant_id == "tenant-123" + + def test_connection_populates_arrow_and_performance_params(self, mock_setup_pools, mock_session): + """Test that Arrow and performance parameters are captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-perf" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + # Import pyarrow availability check + try: + import pyarrow + arrow_available = True + except ImportError: + arrow_available = False + + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + pool_maxsize=200, + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # Verify performance fields + assert driver_params.enable_arrow == arrow_available + assert driver_params.enable_direct_results is True + assert driver_params.http_connection_pool_size == 200 + assert driver_params.rows_fetched_per_block == 100000 # DEFAULT_ARRAY_SIZE + assert driver_params.async_poll_interval_millis == 2000 + assert driver_params.support_many_parameters is True + + def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): + """Test that CloudFlare proxy fields default to False/None (not yet supported).""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-cfproxy" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # CF proxy not yet supported - should be False/None + assert driver_params.use_cf_proxy is False + assert driver_params.cf_proxy_host_info is None From f9f5fc6f986789fdd1b4faf6c9d8a43dfb1e1634 Mon Sep 17 00:00:00 2001 From: jayant <167047871+jayantsing-db@users.noreply.github.com> Date: Fri, 14 Nov 2025 02:45:30 +0530 Subject: [PATCH 03/39] feat: Add multi-statement transaction support (#704) Implement PEP 249-compliant transaction control with extensions for manual commit/rollback operations. This enables atomic multi-table operations with REPEATABLE_READ isolation semantics. Core API additions: - connection.autocommit property for enabling/disabling auto-commit mode - connection.commit() to commit active transactions - connection.rollback() to rollback active transactions - connection.get_transaction_isolation() returns current isolation level - connection.set_transaction_isolation() validates isolation level - TransactionError exception for transaction-specific failures Implementation details: - Added autocommit state caching in Session with optional server query - Added TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ constant - All transaction operations include proper error handling and telemetry - Supports fetch_autocommit_from_server connection parameter Testing: - Unit tests covering all transaction methods and error scenarios - e2e integration tests validating transaction behavior including multi-table atomicity, sequential transactions, and isolation semantics Documentation: - Comprehensive TRANSACTIONS.md guide with examples and best practices - Updated README.md with basic usage and reference to detailed docs Requires MST-enabled Databricks SQL warehouse and Delta tables with 'delta.feature.catalogOwned-preview' table property. --------- Signed-off-by: Jayant Singh --- README.md | 6 + TRANSACTIONS.md | 387 +++++++++++++++++++++ examples/README.md | 1 + examples/transactions.py | 47 +++ src/databricks/sql/__init__.py | 3 + src/databricks/sql/client.py | 275 ++++++++++++++- src/databricks/sql/exc.py | 17 + src/databricks/sql/session.py | 21 ++ tests/e2e/test_transactions.py | 597 +++++++++++++++++++++++++++++++++ tests/unit/test_client.py | 385 ++++++++++++++++++++- 10 files changed, 1719 insertions(+), 20 deletions(-) create mode 100644 TRANSACTIONS.md create mode 100644 examples/transactions.py create mode 100644 tests/e2e/test_transactions.py diff --git a/README.md b/README.md index d57efda1f..ec82a3637 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,12 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789 > to authenticate the target Databricks user account and needs to open the browser for authentication. So it > can only run on the user's machine. +## Transaction Support + +The connector supports multi-statement transactions with manual commit/rollback control. Set `connection.autocommit = False` to disable autocommit mode, then use `connection.commit()` and `connection.rollback()` to control transactions. + +For detailed documentation, examples, and best practices, see **[TRANSACTIONS.md](TRANSACTIONS.md)**. + ## SQLAlchemy Starting from `databricks-sql-connector` version 4.0.0 SQLAlchemy support has been extracted to a new library `databricks-sqlalchemy`. diff --git a/TRANSACTIONS.md b/TRANSACTIONS.md new file mode 100644 index 000000000..590c298c0 --- /dev/null +++ b/TRANSACTIONS.md @@ -0,0 +1,387 @@ +# Transaction Support + +The Databricks SQL Connector for Python supports multi-statement transactions (MST). This allows you to group multiple SQL statements into atomic units that either succeed completely or fail completely. + +## Autocommit Behavior + +By default, every SQL statement executes in its own transaction and commits immediately (autocommit mode). This is the standard behavior for most database connectors. + +```python +from databricks import sql + +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123" +) + +# Default: autocommit is True +print(connection.autocommit) # True + +# Each statement commits immediately +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Already committed - data is visible to other connections +``` + +To use explicit transactions, disable autocommit: + +```python +connection.autocommit = False + +# Now statements are grouped into a transaction +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Not committed yet - must call connection.commit() + +connection.commit() # Now it's visible +``` + +## Basic Transaction Operations + +### Committing Changes + +When autocommit is disabled, you must explicitly commit your changes: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO orders VALUES (1, 100.00)") + cursor.execute("INSERT INTO order_items VALUES (1, 'Widget', 2)") + connection.commit() # Both inserts succeed together +except Exception as e: + connection.rollback() # Neither insert is saved + raise +finally: + connection.autocommit = True # Restore default state +``` + +### Rolling Back Changes + +Use `rollback()` to discard all changes made in the current transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +cursor.execute("INSERT INTO accounts VALUES (1, 1000)") +cursor.execute("UPDATE accounts SET balance = balance - 500 WHERE id = 1") + +# Changed your mind? +connection.rollback() # All changes discarded +``` + +Note: Calling `rollback()` when autocommit is enabled is safe (it's a no-op), but calling `commit()` will raise a `TransactionError`. + +### Sequential Transactions + +After a commit or rollback, a new transaction starts automatically: + +```python +connection.autocommit = False + +# First transaction +cursor.execute("INSERT INTO logs VALUES (1, 'event1')") +connection.commit() + +# Second transaction starts automatically +cursor.execute("INSERT INTO logs VALUES (2, 'event2')") +connection.rollback() # Only the second insert is discarded +``` + +## Multi-Table Transactions + +Transactions span multiple tables atomically. Either all changes are committed, or all are rolled back: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + # Insert into multiple tables + cursor.execute("INSERT INTO customers VALUES (1, 'Alice')") + cursor.execute("INSERT INTO orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO shipments VALUES (1, 1, 'pending')") + + connection.commit() # All three inserts succeed atomically +except Exception as e: + connection.rollback() # All three inserts are discarded + raise +finally: + connection.autocommit = True # Restore default state +``` + +This is particularly useful for maintaining data consistency across related tables. + +## Transaction Isolation + +Databricks uses **Snapshot Isolation** (mapped to `REPEATABLE_READ` in standard SQL terminology). This means: + +- **Repeatable reads**: Once you read data in a transaction, subsequent reads will see the same data (even if other transactions modify it) +- **Atomic commits**: Changes are visible to other connections only after commit +- **Write serializability within a single table**: Concurrent writes to the same table will cause conflicts +- **Snapshot isolation across tables**: Concurrent writes to different tables can succeed + +### Getting the Isolation Level + +```python +level = connection.get_transaction_isolation() +print(level) # Output: REPEATABLE_READ +``` + +### Setting the Isolation Level + +Currently, only `REPEATABLE_READ` is supported: + +```python +from databricks import sql + +# Using the constant +connection.set_transaction_isolation(sql.TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ) + +# Or using a string +connection.set_transaction_isolation("REPEATABLE_READ") + +# Other levels will raise NotSupportedError +connection.set_transaction_isolation("READ_COMMITTED") # Raises NotSupportedError +``` + +### What Repeatable Read Means in Practice + +Within a transaction, you'll always see a consistent snapshot of the data: + +```python +connection.autocommit = False +cursor = connection.cursor() + +# First read +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance1 = cursor.fetchone()[0] # Returns 1000 + +# Another connection updates the balance +# (In a separate connection: UPDATE accounts SET balance = 500 WHERE id = 1) + +# Second read in the same transaction +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance2 = cursor.fetchone()[0] # Still returns 1000 (repeatable read!) + +connection.commit() + +# After commit, new transactions will see the updated value (500) +``` + +## Error Handling + +### Setting Autocommit During a Transaction + +You cannot change autocommit mode while a transaction is active: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO logs VALUES (1, 'data')") + + # This will raise TransactionError + connection.autocommit = True # Error: transaction is active + +except sql.TransactionError as e: + print(f"Cannot change autocommit: {e}") + connection.rollback() # Clean up the transaction +finally: + connection.autocommit = True # Now it's safe to restore +``` + +### Committing Without an Active Transaction + +If autocommit is enabled, there's no active transaction, so calling `commit()` will fail: + +```python +connection.autocommit = True # Default + +try: + connection.commit() # Raises TransactionError +except sql.TransactionError as e: + print(f"No active transaction: {e}") +``` + +However, `rollback()` is safe in this case (it's a no-op). + +### Recovering from Query Failures + +If a statement fails during a transaction, roll back and start a new transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO valid_table VALUES (1, 'data')") + cursor.execute("INSERT INTO nonexistent_table VALUES (2, 'data')") # Fails + connection.commit() +except Exception as e: + connection.rollback() # Discard the partial transaction + + # Log the error (with autocommit still disabled) + try: + cursor.execute("INSERT INTO error_log VALUES (1, 'Query failed')") + connection.commit() + except Exception: + connection.rollback() +finally: + connection.autocommit = True # Restore default state +``` + +## Querying Server State + +By default, the `autocommit` property returns a cached value for performance. If you need to query the server each time (for instance, when strong consistency is required): + +```python +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123", + fetch_autocommit_from_server=True +) + +# Each access queries the server +state = connection.autocommit # Executes "SET AUTOCOMMIT" query +``` + +This is generally not needed for normal usage. + +## Write Conflicts + +### Within a Single Table + +Databricks enforces **write serializability** within a single table. If two transactions try to modify the same table concurrently, one will fail: + +```python +# Connection 1 +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO accounts VALUES (1, 100)") + +# Connection 2 (concurrent) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO accounts VALUES (2, 200)") + +# First commit succeeds +conn1.commit() # OK + +# Second commit fails with concurrent write conflict +try: + conn2.commit() # Raises error about concurrent writes +except Exception as e: + conn2.rollback() + print(f"Concurrent write detected: {e}") +``` + +This happens even when the rows being modified are different. The conflict detection is at the table level. + +### Across Multiple Tables + +Concurrent writes to *different* tables can succeed. Each table tracks its own write conflicts independently: + +```python +# Connection 1: writes to table_a +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO table_a VALUES (1, 'data')") + +# Connection 2: writes to table_b (different table) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO table_b VALUES (1, 'data')") + +# Both commits succeed (different tables) +conn1.commit() # OK +conn2.commit() # Also OK +``` + +## Best Practices + +1. **Keep transactions short**: Long-running transactions can cause conflicts with other connections. Commit as soon as your atomic unit of work is complete. + +2. **Always handle exceptions**: Wrap transaction code in try/except/finally and call `rollback()` on errors. + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO table1 VALUES (1, 'data')") + cursor.execute("UPDATE table2 SET status = 'updated'") + connection.commit() +except Exception as e: + connection.rollback() + logger.error(f"Transaction failed: {e}") + raise +finally: + connection.autocommit = True # Restore default state +``` + +3. **Use context managers**: If you're writing helper functions, consider using a context manager pattern: + +```python +from contextlib import contextmanager + +@contextmanager +def transaction(connection): + connection.autocommit = False + try: + yield connection + connection.commit() + except Exception: + connection.rollback() + raise + finally: + connection.autocommit = True + +# Usage +with transaction(connection): + cursor = connection.cursor() + cursor.execute("INSERT INTO logs VALUES (1, 'message')") + # Auto-commits on success, auto-rolls back on exception +``` + +4. **Reset autocommit when done**: Use a `finally` block to restore autocommit to `True`. This is especially important if the connection is reused or part of a connection pool: + +```python +connection.autocommit = False +try: + # ... transaction code ... + connection.commit() +except Exception: + connection.rollback() + raise +finally: + connection.autocommit = True # Restore to default state +``` + +5. **Be aware of isolation semantics**: Remember that repeatable read means you see a snapshot from the start of your transaction. If you need to see recent changes from other transactions, commit your current transaction and start a new one. + +## Requirements + +To use transactions, you need: +- A Databricks SQL warehouse that supports Multi-Statement Transactions (MST) +- Tables created with the `delta.feature.catalogOwned-preview` table property: + +```sql +CREATE TABLE my_table (id INT, value STRING) +USING DELTA +TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') +``` + +## Related APIs + +- `connection.autocommit` - Get or set autocommit mode (boolean) +- `connection.commit()` - Commit the current transaction +- `connection.rollback()` - Roll back the current transaction +- `connection.get_transaction_isolation()` - Get the isolation level (returns `"REPEATABLE_READ"`) +- `connection.set_transaction_isolation(level)` - Validate/set isolation level (only `"REPEATABLE_READ"` supported) +- `sql.TransactionError` - Exception raised for transaction-specific errors + +All of these are extensions to [PEP 249](https://www.python.org/dev/peps/pep-0249/) (Python Database API Specification v2.0). diff --git a/examples/README.md b/examples/README.md index d73c58a6b..f52dede1d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,6 +31,7 @@ To run all of these examples you can clone the entire repository to your disk. O - **`query_execute.py`** connects to the `samples` database of your default catalog, runs a small query, and prints the result to screen. - **`insert_data.py`** adds a tables called `squares` to your default catalog and inserts one hundred rows of example data. Then it fetches this data and prints it to the screen. +- **`transactions.py`** demonstrates multi-statement transaction support with explicit commit/rollback control. Shows how to group multiple SQL statements into an atomic unit that either succeeds completely or fails completely. - **`query_cancel.py`** shows how to cancel a query assuming that you can access the `Cursor` executing that query from a different thread. This is necessary because `databricks-sql-connector` does not yet implement an asynchronous API; calling `.execute()` blocks the current thread until execution completes. Therefore, the connector can't cancel queries from the same thread where they began. - **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script. - **`m2m_oauth.py`** shows the simplest example of authenticating by using OAuth M2M (machine-to-machine) for service principal. diff --git a/examples/transactions.py b/examples/transactions.py new file mode 100644 index 000000000..6f58dbd2d --- /dev/null +++ b/examples/transactions.py @@ -0,0 +1,47 @@ +from databricks import sql +import os + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + # Disable autocommit to use explicit transactions + connection.autocommit = False + + with connection.cursor() as cursor: + try: + # Create tables for demonstration + cursor.execute("CREATE TABLE IF NOT EXISTS accounts (id int, balance int)") + cursor.execute( + "CREATE TABLE IF NOT EXISTS transfers (from_id int, to_id int, amount int)" + ) + connection.commit() + + # Start a new transaction - transfer money between accounts + cursor.execute("INSERT INTO accounts VALUES (1, 1000), (2, 500)") + cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1") + cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2") + cursor.execute("INSERT INTO transfers VALUES (1, 2, 100)") + + # Commit the transaction - all changes succeed together + connection.commit() + print("Transaction committed successfully") + + # Verify the results + cursor.execute("SELECT * FROM accounts ORDER BY id") + print("Accounts:", cursor.fetchall()) + + cursor.execute("SELECT * FROM transfers") + print("Transfers:", cursor.fetchall()) + + except Exception as e: + # Roll back on error - all changes are discarded + connection.rollback() + print(f"Transaction rolled back due to error: {e}") + raise + + finally: + # Restore autocommit to default state + connection.autocommit = True diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 403a4d130..df44dd534 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -8,6 +8,9 @@ paramstyle = "named" +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + import re from typing import TYPE_CHECKING diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5e5b9cedc..fedfafdf3 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -21,6 +21,8 @@ InterfaceError, NotSupportedError, ProgrammingError, + TransactionError, + DatabaseError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -87,6 +89,9 @@ NO_NATIVE_PARAMS: List = [] +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + class Connection: def __init__( @@ -207,6 +212,11 @@ def read(self) -> Optional[OAuthToken]: This allows 1. cursor.tables() to return METRIC_VIEW table type 2. cursor.columns() to return "measure" column type + :param fetch_autocommit_from_server: `bool`, optional (default is False) + When True, the connection.autocommit property queries the server for current state + using SET AUTOCOMMIT instead of returning cached value. + Set to True if autocommit might be changed by external means (e.g., external SQL commands). + When False (default), uses cached state for better performance. """ # Internal arguments in **kwargs: @@ -305,6 +315,9 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) + self._fetch_autocommit_from_server = kwargs.get( + "fetch_autocommit_from_server", False + ) self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -506,15 +519,261 @@ def _close(self, close_cursors=True) -> None: if self.http_client: self.http_client.close() - def commit(self): - """No-op because Databricks does not support transactions""" - pass + @property + def autocommit(self) -> bool: + """ + Get auto-commit mode for this connection. - def rollback(self): - raise NotSupportedError( - "Transactions are not supported on Databricks", - session_id_hex=self.get_session_id_hex(), - ) + Extension to PEP 249. Returns cached value by default. + If fetch_autocommit_from_server=True was set during connection, + queries server for current state. + + Returns: + bool: True if auto-commit is enabled, False otherwise + + Raises: + InterfaceError: If connection is closed + TransactionError: If fetch_autocommit_from_server=True and query fails + """ + if not self.open: + raise InterfaceError( + "Cannot get autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + if self._fetch_autocommit_from_server: + return self._fetch_autocommit_state_from_server() + + return self.session.get_autocommit() + + @autocommit.setter + def autocommit(self, value: bool) -> None: + """ + Set auto-commit mode for this connection. + + Extension to PEP 249. Executes SET AUTOCOMMIT command on server. + + Args: + value: True to enable auto-commit, False to disable + + Raises: + InterfaceError: If connection is closed + TransactionError: If server rejects the change + """ + if not self.open: + raise InterfaceError( + "Cannot set autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Create internal cursor for transaction control + cursor = None + try: + cursor = self.cursor() + sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}" + cursor.execute(sql) + + # Update cached state on success + self.session.set_autocommit(value) + + except DatabaseError as e: + # Wrap in TransactionError with context + raise TransactionError( + f"Failed to set autocommit to {value}: {e.message}", + context={ + **e.context, + "operation": "set_autocommit", + "autocommit_value": value, + }, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def _fetch_autocommit_state_from_server(self) -> bool: + """ + Query server for current autocommit state using SET AUTOCOMMIT. + + Returns: + bool: Server's autocommit state + + Raises: + TransactionError: If query fails + """ + cursor = None + try: + cursor = self.cursor() + cursor.execute("SET AUTOCOMMIT") + + # Fetch result: should return row with value column + result = cursor.fetchone() + if result is None: + raise TransactionError( + "No result returned from SET AUTOCOMMIT query", + context={"operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) + + # Parse value (first column should be "true" or "false") + value_str = str(result[0]).lower() + autocommit_state = value_str == "true" + + # Update cache + self.session.set_autocommit(autocommit_state) + + return autocommit_state + + except TransactionError: + # Re-raise TransactionError as-is + raise + except DatabaseError as e: + # Wrap other DatabaseErrors + raise TransactionError( + f"Failed to fetch autocommit state from server: {e.message}", + context={**e.context, "operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def commit(self) -> None: + """ + Commit the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Commits the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - Server may throw error if no active transaction + + Raises: + InterfaceError: If connection is closed + TransactionError: If commit fails (e.g., no active transaction) + """ + if not self.open: + raise InterfaceError( + "Cannot commit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("COMMIT") + + except DatabaseError as e: + raise TransactionError( + f"Failed to commit transaction: {e.message}", + context={**e.context, "operation": "commit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def rollback(self) -> None: + """ + Rollback the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Rolls back the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - ROLLBACK is forgiving (no-op, doesn't throw exception) + + Note: ROLLBACK is safe to call even without active transaction. + + Raises: + InterfaceError: If connection is closed + TransactionError: If rollback fails + """ + if not self.open: + raise InterfaceError( + "Cannot rollback on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("ROLLBACK") + + except DatabaseError as e: + raise TransactionError( + f"Failed to rollback transaction: {e.message}", + context={**e.context, "operation": "rollback"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def get_transaction_isolation(self) -> str: + """ + Get the transaction isolation level. + + Extension to PEP 249. + + Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation), + which is the default and only supported level. + + Returns: + str: "REPEATABLE_READ" - the transaction isolation level constant + + Raises: + InterfaceError: If connection is closed + """ + if not self.open: + raise InterfaceError( + "Cannot get transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ + + def set_transaction_isolation(self, level: str) -> None: + """ + Set transaction isolation level. + + Extension to PEP 249. + + Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation). + This method validates that the requested level is supported but does not + execute any SQL, as REPEATABLE_READ is the default server behavior. + + Args: + level: Isolation level. Must be "REPEATABLE_READ" or "REPEATABLE READ" + (case-insensitive, underscores and spaces are interchangeable) + + Raises: + InterfaceError: If connection is closed + NotSupportedError: If isolation level not supported + """ + if not self.open: + raise InterfaceError( + "Cannot set transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Normalize and validate isolation level + normalized_level = level.upper().replace("_", " ") + + if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace( + "_", " " + ): + raise NotSupportedError( + f"Setting transaction isolation level '{level}' is not supported. " + f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 4a772c49b..3a3a6b3c5 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -70,6 +70,23 @@ class NotSupportedError(DatabaseError): pass +class TransactionError(DatabaseError): + """ + Exception raised for transaction-specific errors. + + This exception is used when transaction control operations fail, such as: + - Setting autocommit mode (AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION) + - Committing a transaction (MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION) + - Rolling back a transaction + - Setting transaction isolation level + + The exception includes context about which transaction operation failed + and preserves the underlying cause via exception chaining. + """ + + pass + + ### Custom error classes ### class InvalidServerResponseError(OperationalError): """Thrown if the server does not set the initial namespace correctly""" diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d8ba5d125..0f723d144 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -45,6 +45,9 @@ def __init__( self.schema = schema self.http_path = http_path + # Initialize autocommit state (JDBC default is True) + self._autocommit = True + user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -168,6 +171,24 @@ def guid_hex(self) -> str: """Get the session ID in hex format""" return self._session_id.hex_guid + def get_autocommit(self) -> bool: + """ + Get the cached autocommit state for this session. + + Returns: + bool: True if autocommit is enabled, False otherwise + """ + return self._autocommit + + def set_autocommit(self, value: bool) -> None: + """ + Update the cached autocommit state for this session. + + Args: + value: True to cache autocommit as enabled, False as disabled + """ + self._autocommit = value + def close(self) -> None: """Close the underlying session.""" logger.info("Closing session %s", self.guid_hex) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py new file mode 100644 index 000000000..09cbdae24 --- /dev/null +++ b/tests/e2e/test_transactions.py @@ -0,0 +1,597 @@ +""" +End-to-end integration tests for Multi-Statement Transaction (MST) APIs. + +These tests verify: +- autocommit property (getter/setter) +- commit() and rollback() methods +- get_transaction_isolation() and set_transaction_isolation() methods +- Transaction error handling + +Requirements: +- DBSQL warehouse that supports Multi-Statement Transactions (MST) +- Test environment configured via test.env file or environment variables + +Setup: +Set the following environment variables: +- DATABRICKS_SERVER_HOSTNAME +- DATABRICKS_HTTP_PATH +- DATABRICKS_ACCESS_TOKEN (or use OAuth) + +Usage: + pytest tests/e2e/test_transactions.py -v +""" + +import logging +import os +import pytest +from typing import Any, Dict + +import databricks.sql as sql +from databricks.sql import TransactionError, NotSupportedError, InterfaceError + +logger = logging.getLogger(__name__) + + +@pytest.mark.skip( + reason="Test environment does not yet support multi-statement transactions" +) +class TestTransactions: + """E2E tests for transaction control methods (MST support).""" + + # Test table name + TEST_TABLE_NAME = "transaction_test_table" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self, connection_details): + """Setup test environment before each test and cleanup after.""" + self.connection_params = { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + } + + # Get catalog and schema from environment or use defaults + self.catalog = os.getenv("DATABRICKS_CATALOG", "main") + self.schema = os.getenv("DATABRICKS_SCHEMA", "default") + + # Create connection for setup + self.connection = sql.connect(**self.connection_params) + + # Setup: Create test table + self._create_test_table() + + yield + + # Teardown: Cleanup + self._cleanup() + + def _get_fully_qualified_table_name(self) -> str: + """Get the fully qualified table name.""" + return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" + + def _create_test_table(self): + """Create the test table with Delta format and MST support.""" + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + + try: + # Drop if exists + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + + # Create table with Delta and catalog-owned feature for MST compatibility + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table_name} + (id INT, value STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + + logger.info(f"Created test table: {fq_table_name}") + finally: + cursor.close() + + def _cleanup(self): + """Cleanup after test: rollback pending transactions, drop table, close connection.""" + try: + # Try to rollback any pending transaction + if ( + self.connection + and self.connection.open + and not self.connection.autocommit + ): + try: + self.connection.rollback() + except Exception as e: + logger.debug( + f"Rollback during cleanup failed (may be expected): {e}" + ) + + # Reset to autocommit mode + try: + self.connection.autocommit = True + except Exception as e: + logger.debug(f"Reset autocommit during cleanup failed: {e}") + + # Drop test table + if self.connection and self.connection.open: + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + try: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + logger.info(f"Dropped test table: {fq_table_name}") + except Exception as e: + logger.warning(f"Failed to drop test table: {e}") + finally: + cursor.close() + + finally: + # Close connection + if self.connection: + self.connection.close() + + # ==================== BASIC AUTOCOMMIT TESTS ==================== + + def test_default_autocommit_is_true(self): + """Test that new connection defaults to autocommit=true.""" + assert ( + self.connection.autocommit is True + ), "New connection should have autocommit=true by default" + + def test_set_autocommit_to_false(self): + """Test successfully setting autocommit to false.""" + self.connection.autocommit = False + assert ( + self.connection.autocommit is False + ), "autocommit should be false after setting to false" + + def test_set_autocommit_to_true(self): + """Test successfully setting autocommit back to true.""" + # First disable + self.connection.autocommit = False + assert self.connection.autocommit is False + + # Then enable + self.connection.autocommit = True + assert ( + self.connection.autocommit is True + ), "autocommit should be true after setting to true" + + # ==================== COMMIT TESTS ==================== + + def test_commit_single_insert(self): + """Test successfully committing a transaction with single INSERT.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" + ) + cursor.close() + + # Commit + self.connection.commit() + + # Verify data is persisted using a new connection + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result is not None, "Should find inserted row after commit" + assert result[0] == "test_value", "Value should match inserted value" + finally: + verify_conn.close() + + def test_commit_multiple_inserts(self): + """Test successfully committing a transaction with multiple INSERTs.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert multiple rows + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") + cursor.close() + + self.connection.commit() + + # Verify all rows persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 3, "Should have 3 rows after commit" + finally: + verify_conn.close() + + # ==================== ROLLBACK TESTS ==================== + + def test_rollback_single_insert(self): + """Test successfully rolling back a transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" + ) + cursor.close() + + # Rollback + self.connection.rollback() + + # Verify data is NOT persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + ) + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 0, "Rolled back data should not be persisted" + finally: + verify_conn.close() + + # ==================== SEQUENTIAL TRANSACTION TESTS ==================== + + def test_multiple_sequential_transactions(self): + """Test executing multiple sequential transactions (commit, commit, rollback).""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") + cursor.close() + self.connection.commit() + + # Second transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") + cursor.close() + self.connection.commit() + + # Third transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") + cursor.close() + self.connection.rollback() + + # Verify only first two transactions persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" + ) + result = verify_cursor.fetchone() + assert result[0] == 2, "Should have 2 committed rows" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") + result = verify_cursor.fetchone() + assert result[0] == 0, "Rolled back row should not exist" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_commit(self): + """Test that new transaction automatically starts after commit.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.commit() + + # New transaction should start automatically - insert and rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.rollback() + + # Verify: first committed, second rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 1, "First insert should be committed" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 0, "Second insert should be rolled back" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_rollback(self): + """Test that new transaction automatically starts after rollback.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.rollback() + + # New transaction should start automatically - insert and commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.commit() + + # Verify: first rolled back, second committed + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 0, "First insert should be rolled back" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 1, "Second insert should be committed" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== UPDATE/DELETE OPERATION TESTS ==================== + + def test_update_in_transaction(self): + """Test UPDATE operation in transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + # First insert a row with autocommit + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + ) + cursor.close() + + # Start transaction and update + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") + cursor.close() + self.connection.commit() + + # Verify update persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == "updated", "Value should be updated after commit" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== MULTI-TABLE TRANSACTION TESTS ==================== + + def test_multi_table_transaction_commit(self): + """Test atomic commit across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" + ) + cursor.close() + + # Commit both atomically + self.connection.commit() + + # Verify both inserts persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table1 insert should be committed" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table2 insert should be committed" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + def test_multi_table_transaction_rollback(self): + """Test atomic rollback across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" + ) + cursor.close() + + # Rollback both atomically + self.connection.rollback() + + # Verify both inserts were rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table1 insert should be rolled back" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table2 insert should be rolled back" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + # ==================== ERROR HANDLING TESTS ==================== + + def test_set_autocommit_during_active_transaction(self): + """Test that setting autocommit during an active transaction throws error.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") + cursor.close() + + # Try to set autocommit=True during active transaction + with pytest.raises(TransactionError) as exc_info: + self.connection.autocommit = True + + # Verify error message mentions autocommit or active transaction + error_msg = str(exc_info.value).lower() + assert ( + "autocommit" in error_msg or "active transaction" in error_msg + ), "Error should mention autocommit or active transaction" + + # Cleanup - rollback the transaction + self.connection.rollback() + + def test_commit_without_active_transaction_throws_error(self): + """Test that commit() throws error when autocommit=true (no active transaction).""" + # Ensure autocommit is true (default) + assert self.connection.autocommit is True + + # Attempt commit without active transaction should throw + with pytest.raises(TransactionError) as exc_info: + self.connection.commit() + + # Verify error message indicates no active transaction + error_message = str(exc_info.value) + assert ( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message + or "no active transaction" in error_message.lower() + ), "Error should indicate no active transaction" + + def test_rollback_without_active_transaction_is_safe(self): + """Test that rollback() without active transaction is a safe no-op.""" + # With autocommit=true (no active transaction) + assert self.connection.autocommit is True + + # ROLLBACK should be safe (no exception) + self.connection.rollback() + + # Verify connection is still usable + assert self.connection.autocommit is True + assert self.connection.open is True + + # ==================== TRANSACTION ISOLATION TESTS ==================== + + def test_get_transaction_isolation_returns_repeatable_read(self): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + isolation_level = self.connection.get_transaction_isolation() + assert ( + isolation_level == "REPEATABLE_READ" + ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + + def test_set_transaction_isolation_accepts_repeatable_read(self): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + # Should not raise - these are all valid formats + self.connection.set_transaction_isolation("REPEATABLE_READ") + self.connection.set_transaction_isolation("REPEATABLE READ") + self.connection.set_transaction_isolation("repeatable_read") + self.connection.set_transaction_isolation("repeatable read") + + def test_set_transaction_isolation_rejects_unsupported_level(self): + """Test that set_transaction_isolation() rejects unsupported levels.""" + with pytest.raises(NotSupportedError) as exc_info: + self.connection.set_transaction_isolation("READ_COMMITTED") + + error_message = str(exc_info.value) + assert "not supported" in error_message.lower() + assert "READ_COMMITTED" in error_message diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 19375cde3..cb810afbb 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -22,7 +22,13 @@ import databricks.sql import databricks.sql.client as client -from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql import ( + InterfaceError, + DatabaseError, + Error, + NotSupportedError, + TransactionError, +) from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState @@ -439,11 +445,6 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): "last operation", ) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_commit_a_noop(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - c.commit() - def test_setinputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setinputsizes(1) @@ -452,12 +453,6 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_rollback_not_supported(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - with self.assertRaises(NotSupportedError): - c.rollback() - @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): @@ -639,11 +634,377 @@ def mock_close_normal(): ) +class TransactionTestSuite(unittest.TestCase): + """ + Unit tests for transaction control methods (MST support). + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + def _create_mock_connection(self, mock_session_class): + """Helper to create a mocked connection for transaction tests.""" + # Mock session + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session.get_autocommit.return_value = True + mock_session_class.return_value = mock_session + + # Create connection + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + return conn + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_getter_returns_cached_value(self, mock_session_class): + """Test that autocommit property returns cached session value by default.""" + conn = self._create_mock_connection(mock_session_class) + + # Get autocommit (should use cached value) + result = conn.autocommit + + conn.session.get_autocommit.assert_called_once() + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_executes_sql(self, mock_session_class): + """Test that setting autocommit executes SET AUTOCOMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = False + + # Verify SQL was executed + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = FALSE") + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(False) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_with_true_value(self, mock_session_class): + """Test setting autocommit to True.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = True + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = TRUE") + conn.session.set_autocommit.assert_called_once_with(True) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_wraps_database_error(self, mock_session_class): + """Test that autocommit setter wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertIn("Failed to set autocommit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "set_autocommit") + self.assertEqual(ctx.exception.context["autocommit_value"], False) + + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): + """Test that exception chaining is preserved.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + original_error = DatabaseError( + "Original error", session_id_hex="test-session-id" + ) + mock_cursor.execute.side_effect = original_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertEqual(ctx.exception.__cause__, original_error) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_executes_sql(self, mock_session_class): + """Test that commit() executes COMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.commit() + + mock_cursor.execute.assert_called_once_with("COMMIT") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_wraps_database_error(self, mock_session_class): + """Test that commit() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.commit() + + self.assertIn("Failed to commit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "commit") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that commit() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.commit() + + self.assertIn("Cannot commit on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_executes_sql(self, mock_session_class): + """Test that rollback() executes ROLLBACK command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_wraps_database_error(self, mock_session_class): + """Test that rollback() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "Unexpected rollback error", + context={"sql_state": "HY000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.rollback() + + self.assertIn("Failed to rollback", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "rollback") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that rollback() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.rollback() + + self.assertIn("Cannot rollback on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_returns_repeatable_read( + self, mock_session_class + ): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + result = conn.get_transaction_isolation() + + self.assertEqual(result, "REPEATABLE_READ") + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that get_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.get_transaction_isolation() + + self.assertIn( + "Cannot get transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_accepts_repeatable_read( + self, mock_session_class + ): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + # Should not raise + conn.set_transaction_isolation("REPEATABLE_READ") + conn.set_transaction_isolation("REPEATABLE READ") # With space + conn.set_transaction_isolation("repeatable_read") # Lowercase with underscore + conn.set_transaction_isolation("repeatable read") # Lowercase with space + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_rejects_other_levels(self, mock_session_class): + """Test that set_transaction_isolation() rejects non-REPEATABLE_READ levels.""" + conn = self._create_mock_connection(mock_session_class) + + with self.assertRaises(NotSupportedError) as ctx: + conn.set_transaction_isolation("READ_COMMITTED") + + self.assertIn("not supported", str(ctx.exception)) + self.assertIn("READ_COMMITTED", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that set_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.set_transaction_isolation("REPEATABLE_READ") + + self.assertIn( + "Cannot set transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): + """Test that fetch_autocommit_from_server=True queries server.""" + # Create connection with fetch_autocommit_from_server=True + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="true") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT") + mock_cursor.fetchone.assert_called_once() + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(True) + + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_class): + """Test that fetch_autocommit_from_server correctly parses false value.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="false") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + conn.session.set_autocommit.assert_called_once_with(False) + self.assertFalse(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_class): + """Test that fetch_autocommit_from_server raises error when no result.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_cursor.fetchone.return_value = None + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + _ = conn.autocommit + + self.assertIn("No result returned", str(ctx.exception)) + mock_cursor.close.assert_called_once() + + conn.close() + + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() test_classes = [ ClientTestSuite, + TransactionTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite, From cca421b6dcab315df7dfd66332da9b4b106d56d9 Mon Sep 17 00:00:00 2001 From: jayant <167047871+jayantsing-db@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:48:54 +0530 Subject: [PATCH 04/39] Bump to version 4.2.0 (#707) Bump to version 4.2.0 Signed-off-by: Jayant Singh --- CHANGELOG.md | 5 +++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa6bfb66..0f5402ccb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Release History +# 4.2.0 (2025-11-14) +- Add multi-statement transaction support (databricks/databricks-sql-python#704 by @jayantsing-db) +- Add a workflow to parallelise the E2E tests (databricks/databricks-sql-python#697 by @msrathore-db) +- Bring Python telemetry event model consistent with JDBC (databricks/databricks-sql-python#701 by @nikhilsuri-db) + # 4.1.4 (2025-10-15) - Add support for Token Federation (databricks/databricks-sql-python#691 by @madhav-db) - Add metric view support (databricks/databricks-sql-python#688 by @shivam2680) diff --git a/pyproject.toml b/pyproject.toml index c0eb8244d..7bfc3851f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.1.4" +version = "4.2.0" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index df44dd534..741845d11 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.1.4" +__version__ = "4.2.0" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From a4899cb16811de86daee1e4293670aa625c38303 Mon Sep 17 00:00:00 2001 From: jayant <167047871+jayantsing-db@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:47:48 +0530 Subject: [PATCH 05/39] Add ignore_transactions config to disable transaction operations (#711) Introduces a new `ignore_transactions` configuration parameter (default: True) to control transaction-related behavior in the Connection class. When ignore_transactions=True (default): - commit(): no-op, returns immediately - rollback(): raises NotSupportedError with message "Transactions are not supported on Databricks" - autocommit setter: no-op, returns immediately When ignore_transactions=False: - All transaction methods execute normally Changes: - Added ignore_transactions parameter to Connection.__init__() with default value True - Modified commit(), rollback(), and autocommit setter to check ignore_transactions flag - Updated unit tests to pass ignore_transactions=False when testing transaction functionality - Updated e2e transaction tests to pass ignore_transactions=False - Added three new unit tests to verify ignore_transactions --- src/databricks/sql/client.py | 33 +++++++++++ tests/e2e/test_transactions.py | 1 + tests/unit/test_client.py | 102 +++++++++++++++++++++++++++++++-- 3 files changed, 131 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index fedfafdf3..a7f802dcd 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -104,6 +104,7 @@ def __init__( catalog: Optional[str] = None, schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, + ignore_transactions: bool = True, **kwargs, ) -> None: """ @@ -217,6 +218,12 @@ def read(self) -> Optional[OAuthToken]: using SET AUTOCOMMIT instead of returning cached value. Set to True if autocommit might be changed by external means (e.g., external SQL commands). When False (default), uses cached state for better performance. + :param ignore_transactions: `bool`, optional (default is True) + When True, transaction-related operations behave as follows: + - commit(): no-op (does nothing) + - rollback(): raises NotSupportedError + - autocommit setter: no-op (does nothing) + When False, transaction operations execute normally. """ # Internal arguments in **kwargs: @@ -318,6 +325,7 @@ def read(self) -> Optional[OAuthToken]: self._fetch_autocommit_from_server = kwargs.get( "fetch_autocommit_from_server", False ) + self.ignore_transactions = ignore_transactions self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -556,10 +564,17 @@ def autocommit(self, value: bool) -> None: Args: value: True to enable auto-commit, False to disable + When ignore_transactions is True: + - This method is a no-op (does nothing) + Raises: InterfaceError: If connection is closed TransactionError: If server rejects the change """ + # No-op when ignore_transactions is True + if self.ignore_transactions: + return + if not self.open: raise InterfaceError( "Cannot set autocommit on closed connection", @@ -651,10 +666,17 @@ def commit(self) -> None: When autocommit is True: - Server may throw error if no active transaction + When ignore_transactions is True: + - This method is a no-op (does nothing) + Raises: InterfaceError: If connection is closed TransactionError: If commit fails (e.g., no active transaction) """ + # No-op when ignore_transactions is True + if self.ignore_transactions: + return + if not self.open: raise InterfaceError( "Cannot commit on closed connection", @@ -689,12 +711,23 @@ def rollback(self) -> None: When autocommit is True: - ROLLBACK is forgiving (no-op, doesn't throw exception) + When ignore_transactions is True: + - Raises NotSupportedError + Note: ROLLBACK is safe to call even without active transaction. Raises: InterfaceError: If connection is closed + NotSupportedError: If ignore_transactions is True TransactionError: If rollback fails """ + # Raise NotSupportedError when ignore_transactions is True + if self.ignore_transactions: + raise NotSupportedError( + "Transactions are not supported on Databricks", + session_id_hex=self.get_session_id_hex(), + ) + if not self.open: raise InterfaceError( "Cannot rollback on closed connection", diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index 09cbdae24..d4f6a790a 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -48,6 +48,7 @@ def setup_and_teardown(self, connection_details): "server_hostname": connection_details["host"], "http_path": connection_details["http_path"], "access_token": connection_details.get("access_token"), + "ignore_transactions": False, # Enable actual transaction functionality for these tests } # Get catalog and schema from environment or use defaults diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index cb810afbb..b515756e8 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -655,8 +655,10 @@ def _create_mock_connection(self, mock_session_class): mock_session.get_autocommit.return_value = True mock_session_class.return_value = mock_session - # Create connection - conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + # Create connection with ignore_transactions=False to test actual transaction functionality + conn = client.Connection( + ignore_transactions=False, **self.DUMMY_CONNECTION_ARGS + ) return conn @patch("%s.client.Session" % PACKAGE_NAME) @@ -928,7 +930,9 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): mock_session_class.return_value = mock_session conn = client.Connection( - fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, ) mock_cursor = Mock() @@ -958,7 +962,9 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla mock_session_class.return_value = mock_session conn = client.Connection( - fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, ) mock_cursor = Mock() @@ -983,7 +989,9 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla mock_session_class.return_value = mock_session conn = client.Connection( - fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, ) mock_cursor = Mock() @@ -998,6 +1006,90 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla conn.close() + # ==================== IGNORE_TRANSACTIONS TESTS ==================== + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): + """Test that commit() is a no-op when ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + # Call commit - should be no-op + conn.commit() + + # Verify that execute was NOT called (no-op) + mock_cursor.execute.assert_not_called() + mock_cursor.close.assert_not_called() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_raises_not_supported_when_ignore_transactions_true( + self, mock_session_class + ): + """Test that rollback() raises NotSupportedError when ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + # Call rollback - should raise NotSupportedError + with self.assertRaises(NotSupportedError) as ctx: + conn.rollback() + + self.assertIn("Transactions are not supported", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_is_noop_when_ignore_transactions_true( + self, mock_session_class + ): + """Test that autocommit setter is a no-op when ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + # Set autocommit - should be no-op + conn.autocommit = False + + # Verify that execute was NOT called (no-op) + mock_cursor.execute.assert_not_called() + mock_cursor.close.assert_not_called() + + # Session set_autocommit should also not be called + conn.session.set_autocommit.assert_not_called() + + conn.close() + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) From ad227ca8ffb41319b851d207e5926dda50e8f937 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 20 Nov 2025 18:14:39 +0530 Subject: [PATCH 06/39] Ready for 4.2.1 release (#713) Signed-off-by: Vikrant Puppala --- CHANGELOG.md | 3 +++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f5402ccb..5b902e976 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Release History +# 4.2.1 (2025-11-20) +- Ignore transactions by default (databricks/databricks-sql-python#711 by @jayantsing-db) + # 4.2.0 (2025-11-14) - Add multi-statement transaction support (databricks/databricks-sql-python#704 by @jayantsing-db) - Add a workflow to parallelise the E2E tests (databricks/databricks-sql-python#697 by @msrathore-db) diff --git a/pyproject.toml b/pyproject.toml index 7bfc3851f..d26a71667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.0" +version = "4.2.1" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 741845d11..cd37e6ce1 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.0" +__version__ = "4.2.1" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From b8494ff84cff5fc8b0229460835f97f5409f6934 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:05:04 +0530 Subject: [PATCH 07/39] Change default use_hybrid_disposition to False (#714) This changes the default value of use_hybrid_disposition from True to False in the SEA backend, disabling hybrid disposition by default. --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 75d2c665c..1427226d2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -157,7 +157,7 @@ def __init__( "_use_arrow_native_complex_types", True ) - self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", False) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) # Extract warehouse ID from http_path From 73580fe298168f553d99f99b8628b0771a91abf5 Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Wed, 26 Nov 2025 18:33:12 +0530 Subject: [PATCH 08/39] Circuit breaker changes using pybreaker (#705) * Added driver connection params Signed-off-by: Nikhil Suri * Added model fields for chunk/result latency Signed-off-by: Nikhil Suri * fixed linting issues Signed-off-by: Nikhil Suri * lint issue fixing Signed-off-by: Nikhil Suri * circuit breaker changes using pybreaker Signed-off-by: Nikhil Suri * Added interface layer top of http client to use circuit rbeaker Signed-off-by: Nikhil Suri * Added test cases to validate ciruit breaker Signed-off-by: Nikhil Suri * fixing broken tests Signed-off-by: Nikhil Suri * fixed linting issues Signed-off-by: Nikhil Suri * fixed failing test cases Signed-off-by: Nikhil Suri * fixed urllib3 issue Signed-off-by: Nikhil Suri * added more test cases for telemetry Signed-off-by: Nikhil Suri * simplified CB config Signed-off-by: Nikhil Suri * poetry lock Signed-off-by: Nikhil Suri * fix minor issues & improvement Signed-off-by: Nikhil Suri * improved circuit breaker for handling only 429/503 Signed-off-by: Nikhil Suri * linting issue fixed Signed-off-by: Nikhil Suri * raise CB only for 429/503 Signed-off-by: Nikhil Suri * fix broken test cases Signed-off-by: Nikhil Suri * fixed untyped references Signed-off-by: Nikhil Suri * added more test to verify the changes Signed-off-by: Nikhil Suri * description changed Signed-off-by: Nikhil Suri * remove cb congig class to constants Signed-off-by: Nikhil Suri * removed mocked reponse and use a new exlucded exception in CB Signed-off-by: Nikhil Suri * fixed broken test Signed-off-by: Nikhil Suri * added e2e test to verify circuit breaker Signed-off-by: Nikhil Suri * lower log level for telemetry Signed-off-by: Nikhil Suri * fixed broken test, removed tests on log assertions Signed-off-by: Nikhil Suri * modified unit to reduce the noise and follow dry principle Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- poetry.lock | 36 ++- pyproject.toml | 1 + src/databricks/sql/auth/common.py | 2 + .../sql/common/unified_http_client.py | 47 +++- src/databricks/sql/exc.py | 21 ++ .../sql/telemetry/circuit_breaker_manager.py | 112 +++++++++ .../sql/telemetry/telemetry_client.py | 32 ++- .../sql/telemetry/telemetry_push_client.py | 201 +++++++++++++++ src/databricks/sql/utils.py | 3 + tests/e2e/test_circuit_breaker.py | 232 ++++++++++++++++++ .../unit/test_circuit_breaker_http_client.py | 208 ++++++++++++++++ tests/unit/test_circuit_breaker_manager.py | 160 ++++++++++++ tests/unit/test_telemetry.py | 32 ++- tests/unit/test_telemetry_push_client.py | 213 ++++++++++++++++ .../test_telemetry_request_error_handling.py | 96 ++++++++ tests/unit/test_unified_http_client.py | 136 ++++++++++ 16 files changed, 1512 insertions(+), 20 deletions(-) create mode 100644 src/databricks/sql/telemetry/circuit_breaker_manager.py create mode 100644 src/databricks/sql/telemetry/telemetry_push_client.py create mode 100644 tests/e2e/test_circuit_breaker.py create mode 100644 tests/unit/test_circuit_breaker_http_client.py create mode 100644 tests/unit/test_circuit_breaker_manager.py create mode 100644 tests/unit/test_telemetry_push_client.py create mode 100644 tests/unit/test_telemetry_request_error_handling.py create mode 100644 tests/unit/test_unified_http_client.py diff --git a/poetry.lock b/poetry.lock index 1a8074c2a..193efa109 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "astroid" @@ -1348,6 +1348,38 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, + {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + +[[package]] +name = "pybreaker" +version = "1.4.1" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, + {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + [[package]] name = "pycparser" version = "2.22" @@ -1858,4 +1890,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" +content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" diff --git a/pyproject.toml b/pyproject.toml index d26a71667..61c248e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" +pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 3e0be0d2b..a764b036d 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,6 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -83,6 +84,7 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent + self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled) def get_effective_azure_login_app_id(hostname) -> str: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 96fb9cbb9..d5f7d3c8d 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -28,6 +28,42 @@ logger = logging.getLogger(__name__) +def _extract_http_status_from_max_retry_error(e: MaxRetryError) -> Optional[int]: + """ + Extract HTTP status code from MaxRetryError if available. + + urllib3 structures MaxRetryError in different ways depending on the failure scenario: + - e.reason.response.status: Most common case when retries are exhausted + - e.response.status: Alternate structure in some scenarios + + Args: + e: MaxRetryError exception from urllib3 + + Returns: + HTTP status code as int if found, None otherwise + """ + # Try primary structure: e.reason.response.status + if ( + hasattr(e, "reason") + and e.reason is not None + and hasattr(e.reason, "response") + and e.reason.response is not None + ): + http_code = getattr(e.reason.response, "status", None) + if http_code is not None: + return http_code + + # Try alternate structure: e.response.status + if ( + hasattr(e, "response") + and e.response is not None + and hasattr(e.response, "status") + ): + return e.response.status + + return None + + class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. @@ -264,7 +300,16 @@ def request_context( yield response except MaxRetryError as e: logger.error("HTTP request failed after retries: %s", e) - raise RequestError(f"HTTP request failed: {e}") + + # Extract HTTP status code from MaxRetryError if available + http_code = _extract_http_status_from_max_retry_error(e) + + context = {} + if http_code is not None: + context["http-code"] = http_code + logger.error("HTTP request failed with status code: %d", http_code) + + raise RequestError(f"HTTP request failed: {e}", context=context) except Exception as e: logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3a3a6b3c5..24844d573 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -143,3 +143,24 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" + + +class TelemetryRateLimitError(Exception): + """Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable. + This exception is used exclusively by the circuit breaker to track telemetry rate limiting events.""" + + +class TelemetryNonRateLimitError(Exception): + """Wrapper for telemetry errors that should NOT trigger circuit breaker. + + This exception wraps non-rate-limiting errors (network errors, timeouts, server errors, etc.) + and is excluded from circuit breaker failure counting. Only TelemetryRateLimitError should + open the circuit breaker. + + Attributes: + original_exception: The actual exception that occurred + """ + + def __init__(self, original_exception: Exception): + self.original_exception = original_exception + super().__init__(f"Non-rate-limit telemetry error: {original_exception}") diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py new file mode 100644 index 000000000..852f0d916 --- /dev/null +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -0,0 +1,112 @@ +""" +Circuit breaker implementation for telemetry requests. + +This module provides circuit breaker functionality to prevent telemetry failures +from impacting the main SQL operations. It uses pybreaker library to implement +the circuit breaker pattern. +""" + +import logging +import threading +from typing import Dict + +import pybreaker +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener + +from databricks.sql.exc import TelemetryNonRateLimitError + +logger = logging.getLogger(__name__) + +# Circuit Breaker Constants +MINIMUM_CALLS = 20 # Number of failures before circuit opens +RESET_TIMEOUT = 30 # Seconds to wait before trying to close circuit +NAME_PREFIX = "telemetry-circuit-breaker" + +# Circuit Breaker State Constants (used in logging) +CIRCUIT_BREAKER_STATE_OPEN = "open" +CIRCUIT_BREAKER_STATE_CLOSED = "closed" +CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" + +# Logging Message Constants +LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" +LOG_CIRCUIT_BREAKER_OPENED = ( + "Circuit breaker opened for %s - telemetry requests will be blocked" +) +LOG_CIRCUIT_BREAKER_CLOSED = ( + "Circuit breaker closed for %s - telemetry requests will be allowed" +) +LOG_CIRCUIT_BREAKER_HALF_OPEN = ( + "Circuit breaker half-open for %s - testing telemetry requests" +) + + +class CircuitBreakerStateListener(CircuitBreakerListener): + """Listener for circuit breaker state changes.""" + + def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: + """Called before the circuit breaker calls a function.""" + pass + + def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: + """Called when a function called by the circuit breaker fails.""" + pass + + def success(self, cb: CircuitBreaker) -> None: + """Called when a function called by the circuit breaker succeeds.""" + pass + + def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: + """Called when the circuit breaker state changes.""" + old_state_name = old_state.name if old_state else "None" + new_state_name = new_state.name if new_state else "None" + + logger.info( + LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name + ) + + if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: + logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: + logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: + logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + + +class CircuitBreakerManager: + """ + Manages circuit breaker instances for telemetry requests. + + Creates and caches circuit breaker instances per host to ensure telemetry + failures don't impact main SQL operations. + """ + + _instances: Dict[str, CircuitBreaker] = {} + _lock = threading.RLock() + + @classmethod + def get_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Get or create a circuit breaker instance for the specified host. + + Args: + host: The hostname for which to get the circuit breaker + + Returns: + CircuitBreaker instance for the host + """ + with cls._lock: + if host not in cls._instances: + breaker = CircuitBreaker( + fail_max=MINIMUM_CALLS, + reset_timeout=RESET_TIMEOUT, + name=f"{NAME_PREFIX}-{host}", + exclude=[ + TelemetryNonRateLimitError + ], # Don't count these as failures + ) + # Add state change listener for logging + breaker.add_listener(CircuitBreakerStateListener()) + cls._instances[host] = breaker + logger.debug("Created circuit breaker for host: %s", host) + + return cls._instances[host] diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 134757fe5..177d5445c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -41,6 +41,11 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -166,21 +171,21 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, - telemetry_enabled, - session_id_hex, + telemetry_enabled: bool, + session_id_hex: str, auth_provider, - host_url, + host_url: str, executor, - batch_size, + batch_size: int, client_context, - ): + ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch = [] + self._events_batch: list = [] self._lock = threading.RLock() self._driver_connection_params = None self._host_url = host_url @@ -189,6 +194,19 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) + # Create telemetry push client based on circuit breaker enabled flag + if client_context.telemetry_circuit_breaker_enabled: + # Create circuit breaker telemetry push client (circuit breakers created on-demand) + self._telemetry_push_client: ITelemetryPushClient = ( + CircuitBreakerTelemetryPushClient( + TelemetryPushClient(self._http_client), + host_url, + ) + ) + else: + # Circuit breaker disabled - use direct telemetry push client + self._telemetry_push_client = TelemetryPushClient(self._http_client) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -254,7 +272,7 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request( + response = self._telemetry_push_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py new file mode 100644 index 000000000..461a57738 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -0,0 +1,201 @@ +""" +Telemetry push client interface and implementations. + +This module provides an interface for telemetry push clients with two implementations: +1. TelemetryPushClient - Direct HTTP client implementation +2. CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation +""" + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +try: + from urllib3 import BaseHTTPResponse +except ImportError: + from urllib3 import HTTPResponse as BaseHTTPResponse +from pybreaker import CircuitBreakerError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import ( + TelemetryRateLimitError, + TelemetryNonRateLimitError, + RequestError, +) +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + +logger = logging.getLogger(__name__) + + +class ITelemetryPushClient(ABC): + """Interface for telemetry push clients.""" + + @abstractmethod + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request.""" + pass + + +class TelemetryPushClient(ITelemetryPushClient): + """Direct HTTP client implementation for telemetry requests.""" + + def __init__(self, http_client: UnifiedHttpClient): + """ + Initialize the telemetry push client. + + Args: + http_client: The underlying HTTP client + """ + self._http_client = http_client + logger.debug("TelemetryPushClient initialized") + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request using the underlying HTTP client.""" + return self._http_client.request(method, url, headers, **kwargs) + + +class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): + """Circuit breaker wrapper implementation for telemetry requests.""" + + def __init__(self, delegate: ITelemetryPushClient, host: str): + """ + Initialize the circuit breaker telemetry push client. + + Args: + delegate: The underlying telemetry push client to wrap + host: The hostname for circuit breaker identification + """ + self._delegate = delegate + self._host = host + + # Get circuit breaker for this host (creates if doesn't exist) + self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) + + logger.debug( + "CircuitBreakerTelemetryPushClient initialized for host %s", + host, + ) + + def _make_request_and_check_status( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]], + **kwargs, + ) -> BaseHTTPResponse: + """ + Make the request and check response status. + + Raises TelemetryRateLimitError for 429/503 (circuit breaker counts these). + Wraps other errors in TelemetryNonRateLimitError (circuit breaker excludes these). + + Args: + method: HTTP method + url: Request URL + headers: Request headers + **kwargs: Additional request parameters + + Returns: + HTTP response + + Raises: + TelemetryRateLimitError: For 429/503 status codes (circuit breaker counts) + TelemetryNonRateLimitError: For other errors (circuit breaker excludes) + """ + try: + response = self._delegate.request(method, url, headers, **kwargs) + + # Check for rate limiting or service unavailable + if response.status in [429, 503]: + logger.warning( + "Telemetry endpoint returned %d for host %s, triggering circuit breaker", + response.status, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry endpoint rate limited or unavailable: {response.status}" + ) + + return response + + except Exception as e: + # Don't catch TelemetryRateLimitError - let it propagate to circuit breaker + if isinstance(e, TelemetryRateLimitError): + raise + + # Check if it's a RequestError with rate limiting status code (exhausted retries) + if isinstance(e, RequestError): + http_code = ( + e.context.get("http-code") + if hasattr(e, "context") and e.context + else None + ) + + if http_code in [429, 503]: + logger.debug( + "Telemetry retries exhausted with status %d for host %s, triggering circuit breaker", + http_code, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry rate limited after retries: {http_code}" + ) + + # NOT rate limiting (500 errors, network errors, timeouts, etc.) + # Wrap in TelemetryNonRateLimitError so circuit breaker excludes it + logger.debug( + "Non-rate-limit telemetry error for host %s: %s, wrapping to exclude from circuit breaker", + self._host, + e, + ) + raise TelemetryNonRateLimitError(e) from e + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """ + Make an HTTP request with circuit breaker protection. + + Circuit breaker only opens for TelemetryRateLimitError (429/503 responses). + Other errors are wrapped in TelemetryNonRateLimitError and excluded from circuit breaker. + All exceptions propagate to caller (TelemetryClient callback handles them). + """ + try: + # Use circuit breaker to protect the request + # TelemetryRateLimitError will trigger circuit breaker + # TelemetryNonRateLimitError is excluded from circuit breaker + return self._circuit_breaker.call( + self._make_request_and_check_status, + method, + url, + headers, + **kwargs, + ) + + except TelemetryNonRateLimitError as e: + # Unwrap and re-raise original exception + # Circuit breaker didn't count this, but caller should handle it + logger.debug( + "Non-rate-limit telemetry error for host %s, re-raising original: %s", + self._host, + e.original_exception, + ) + raise e.original_exception from e + # All other exceptions (TelemetryRateLimitError, CircuitBreakerError) propagate as-is diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9f96e8743..b46784b10 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -922,4 +922,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), + telemetry_circuit_breaker_enabled=kwargs.get( + "_telemetry_circuit_breaker_enabled" + ), ) diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py new file mode 100644 index 000000000..45c494d19 --- /dev/null +++ b/tests/e2e/test_circuit_breaker.py @@ -0,0 +1,232 @@ +""" +E2E tests for circuit breaker functionality in telemetry. + +This test suite verifies: +1. Circuit breaker opens after rate limit failures (429/503) +2. Circuit breaker blocks subsequent calls while open +3. Circuit breaker does not trigger for non-rate-limit errors +4. Circuit breaker can be disabled via configuration flag +5. Circuit breaker closes after reset timeout + +Run with: + pytest tests/e2e/test_circuit_breaker.py -v -s +""" + +import time +from unittest.mock import patch, MagicMock + +import pytest +from pybreaker import STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN +from urllib3 import HTTPResponse + +import databricks.sql as sql +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +@pytest.fixture(autouse=True) +def aggressive_circuit_breaker_config(): + """ + Configure circuit breaker to be aggressive for faster testing. + Opens after 2 failures instead of 20, with 5 second timeout. + """ + from databricks.sql.telemetry import circuit_breaker_manager + + original_minimum_calls = circuit_breaker_manager.MINIMUM_CALLS + original_reset_timeout = circuit_breaker_manager.RESET_TIMEOUT + + circuit_breaker_manager.MINIMUM_CALLS = 2 + circuit_breaker_manager.RESET_TIMEOUT = 5 + + CircuitBreakerManager._instances.clear() + + yield + + circuit_breaker_manager.MINIMUM_CALLS = original_minimum_calls + circuit_breaker_manager.RESET_TIMEOUT = original_reset_timeout + CircuitBreakerManager._instances.clear() + + +class TestCircuitBreakerTelemetry: + """Tests for circuit breaker functionality with telemetry""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + """Get connection details from pytest fixture""" + self.arguments = connection_details.copy() + + def create_mock_response(self, status_code): + """Helper to create mock HTTP response.""" + response = MagicMock(spec=HTTPResponse) + response.status = status_code + response.data = { + 429: b"Too Many Requests", + 503: b"Service Unavailable", + 500: b"Internal Server Error", + }.get(status_code, b"Response") + return response + + @pytest.mark.parametrize("status_code,should_trigger", [ + (429, True), + (503, True), + (500, False), + ]) + def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + """ + Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). + """ + request_count = {"count": 0} + + def mock_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(status_code) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + assert circuit_breaker.current_state == STATE_CLOSED + + cursor = conn.cursor() + + # Execute queries to trigger telemetry + for i in range(1, 6): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.5) + + if should_trigger: + # Circuit should be OPEN after 2 rate-limit failures + assert circuit_breaker.current_state == STATE_OPEN + assert circuit_breaker.fail_counter == 2 + + # Track requests before another query + requests_before = request_count["count"] + cursor.execute("SELECT 99") + cursor.fetchone() + time.sleep(1) + + # No new telemetry requests (circuit is open) + assert request_count["count"] == requests_before + else: + # Circuit should remain CLOSED for non-rate-limit errors + assert circuit_breaker.current_state == STATE_CLOSED + assert circuit_breaker.fail_counter == 0 + assert request_count["count"] >= 5 + + def test_circuit_breaker_disabled_allows_all_calls(self): + """ + Verify that when circuit breaker is disabled, all calls go through + even with rate limit errors. + """ + request_count = {"count": 0} + + def mock_rate_limited_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(429) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_rate_limited_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=False, # Disabled + ) as conn: + cursor = conn.cursor() + + for i in range(5): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.3) + + assert request_count["count"] >= 5 + + def test_circuit_breaker_recovers_after_reset_timeout(self): + """ + Verify circuit breaker transitions to HALF_OPEN after reset timeout + and eventually CLOSES if requests succeed. + """ + request_count = {"count": 0} + fail_requests = {"enabled": True} + + def mock_conditional_request(*args, **kwargs): + request_count["count"] += 1 + status = 429 if fail_requests["enabled"] else 200 + return self.create_mock_response(status) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_conditional_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + cursor = conn.cursor() + + # Trigger failures to open circuit + cursor.execute("SELECT 1") + cursor.fetchone() + time.sleep(1) + + cursor.execute("SELECT 2") + cursor.fetchone() + time.sleep(2) + + assert circuit_breaker.current_state == STATE_OPEN + + # Wait for reset timeout (5 seconds in test) + time.sleep(6) + + # Now make requests succeed + fail_requests["enabled"] = False + + # Execute query to trigger HALF_OPEN state + cursor.execute("SELECT 3") + cursor.fetchone() + time.sleep(1) + + # Circuit should be recovering + assert circuit_breaker.current_state in [ + STATE_HALF_OPEN, + STATE_CLOSED, + ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + + # Execute more queries to fully recover + cursor.execute("SELECT 4") + cursor.fetchone() + time.sleep(1) + + current_state = circuit_breaker.current_state + assert current_state in [ + STATE_CLOSED, + STATE_HALF_OPEN, + ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py new file mode 100644 index 000000000..432ca1be3 --- /dev/null +++ b/tests/unit/test_circuit_breaker_http_client.py @@ -0,0 +1,208 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open - should raise CircuitBreakerError.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Circuit breaker open should raise (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs - should raise original exception.""" + # Mock delegate to raise a different error (not rate limiting) + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Non-rate-limit errors are unwrapped and raised + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker errors are raised (no longer silent).""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Should raise CircuitBreakerError (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_other_error_logging(self): + """Test that other errors are wrapped, logged, then unwrapped and raised.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Should raise the original ValueError + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged (for wrapping and/or unwrapping) + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + from databricks.sql.exc import TelemetryRateLimitError + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures (429) + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + # All calls should raise TelemetryRateLimitError + # After MINIMUM_CALLS failures, circuit breaker opens + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + # Should have some rate limit errors before circuit opens, then circuit breaker errors + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures first (429) + from databricks.sql.exc import TelemetryRateLimitError + from pybreaker import CircuitBreakerError + + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + # Trigger enough rate limit failures to open circuit + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass # Expected - circuit breaker opens after MINIMUM_CALLS failures + + # Circuit should be open now - raises CircuitBreakerError + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate successful calls (200 response) + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + # Should work again with actual success response + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py new file mode 100644 index 000000000..e8ed4e809 --- /dev/null +++ b/tests/unit/test_circuit_breaker_manager.py @@ -0,0 +1,160 @@ +""" +Unit tests for circuit breaker manager functionality. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + NAME_PREFIX as CIRCUIT_BREAKER_NAME, +) +from pybreaker import CircuitBreakerError + + +class TestCircuitBreakerManager: + """Test cases for CircuitBreakerManager.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_get_circuit_breaker_creates_instance(self): + """Test getting circuit breaker creates instance with correct config.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.name == "telemetry-circuit-breaker-test-host" + assert breaker.fail_max == MINIMUM_CALLS + + def test_get_circuit_breaker_same_host_returns_same_instance(self): + """Test that same host returns same circuit breaker instance.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") + breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker1 is breaker2 + + def test_get_circuit_breaker_different_hosts_return_different_instances(self): + """Test that different hosts return different circuit breaker instances.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") + breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") + + assert breaker1 is not breaker2 + assert breaker1.name != breaker2.name + + def test_thread_safety(self): + """Test thread safety of circuit breaker manager.""" + results = [] + + def get_breaker(host): + breaker = CircuitBreakerManager.get_circuit_breaker(host) + results.append(breaker) + + threads = [] + for i in range(10): + thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(results) == 10 + + # All breakers for same host should be same instance + host0_breakers = [b for b in results if b.name.endswith("host0")] + assert all(b is host0_breakers[0] for b in host0_breakers) + + +class TestCircuitBreakerIntegration: + """Integration tests for circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_state_transitions(self): + """Test circuit breaker state transitions from closed to open.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.current_state == "closed" + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold (MINIMUM_CALLS = 20) + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + def test_circuit_breaker_recovery(self): + """Test circuit breaker recovery after failures.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Try successful call to close circuit breaker + def successful_func(): + return "success" + + try: + result = breaker.call(successful_func) + assert result == "success" + except CircuitBreakerError: + pass # Circuit might still be open, acceptable + + assert breaker.current_state in ["closed", "half-open", "open"] + + @pytest.mark.parametrize("old_state,new_state", [ + ("closed", "open"), + ("open", "half-open"), + ("half-open", "closed"), + ("closed", "half-open"), + ]) + def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): + """Test circuit breaker state listener logs all state transitions.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + ) + + listener = CircuitBreakerStateListener() + mock_cb = Mock() + mock_cb.name = "test-breaker" + + mock_old_state = Mock() + mock_old_state.name = old_state + + mock_new_state = Mock() + mock_new_state.name = new_state + + with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: + listener.state_change(mock_cb, mock_old_state, mock_new_state) + mock_logger.info.assert_called() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 36141ee2b..6f5a01c7b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -37,7 +37,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -95,7 +97,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -231,7 +233,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -299,7 +303,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -382,8 +388,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -410,8 +418,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -438,8 +448,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py new file mode 100644 index 000000000..0e9455e1f --- /dev/null +++ b/tests/unit/test_telemetry_push_client.py @@ -0,0 +1,213 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import TelemetryRateLimitError +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_circuit_breaker_open(self): + """Test request when circuit breaker is open raises CircuitBreakerError.""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_other_error(self): + """Test request when other error occurs raises original exception.""" + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code,expected_error", [ + (429, TelemetryRateLimitError), + (503, TelemetryRateLimitError), + ]) + def test_request_rate_limit_codes(self, status_code, expected_error): + """Test that rate-limit status codes raise TelemetryRateLimitError.""" + mock_response = Mock() + mock_response.status = status_code + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(expected_error): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_non_rate_limit_code(self): + """Test that non-rate-limit status codes return response.""" + mock_response = Mock() + mock_response.status = 500 + mock_response.data = b'Server error' + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 500 + + def test_rate_limit_error_logging(self): + """Test that rate limit errors are logged with circuit breaker context.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(TelemetryRateLimitError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "429" in str(warning_args) + assert "circuit breaker" in warning_args[0] + + def test_other_error_logging(self): + """Test that other errors are logged during wrapping/unwrapping.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + import time + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Trigger failures + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass + + # Circuit should be open + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate success + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 + + def test_urllib3_import_fallback(self): + """Test that the urllib3 import fallback works correctly.""" + from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py new file mode 100644 index 000000000..aa31f6628 --- /dev/null +++ b/tests/unit/test_telemetry_request_error_handling.py @@ -0,0 +1,96 @@ +""" +Unit tests specifically for telemetry_push_client RequestError handling +with http-code context extraction for rate limiting detection. +""" + +import pytest +from unittest.mock import Mock + +from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + TelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError, TelemetryRateLimitError +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +class TestTelemetryPushClientRequestErrorHandling: + """Test RequestError handling and http-code context extraction.""" + + @pytest.fixture + def setup_circuit_breaker(self): + """Setup circuit breaker for testing.""" + CircuitBreakerManager._instances.clear() + yield + CircuitBreakerManager._instances.clear() + + @pytest.fixture + def mock_delegate(self): + """Create mock delegate client.""" + return Mock(spec=TelemetryPushClient) + + @pytest.fixture + def client(self, mock_delegate, setup_circuit_breaker): + """Create CircuitBreakerTelemetryPushClient instance.""" + return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com") + + @pytest.mark.parametrize("status_code", [429, 503]) + def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with rate-limit codes raises TelemetryRateLimitError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code", [500, 400, 404]) + def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with non-rate-limit codes raises original RequestError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("context", [{}, None, "429"]) + def test_request_error_with_invalid_context(self, client, mock_delegate, context): + """Test RequestError with invalid/missing context raises original error.""" + request_error = RequestError("HTTP request failed") + if context == "429": + # Edge case: http-code as string instead of int + request_error.context = {"http-code": context} + else: + request_error.context = context + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_error_missing_context_attribute(self, client, mock_delegate): + """Test RequestError without context attribute raises original error.""" + request_error = RequestError("HTTP request failed") + if hasattr(request_error, "context"): + delattr(request_error, "context") + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_http_code_extraction_prioritization(self, client, mock_delegate): + """Test that http-code from RequestError context is correctly extracted.""" + request_error = RequestError( + "HTTP request failed after retries", context={"http-code": 503} + ) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_non_request_error_exceptions_raised(self, client, mock_delegate): + """Test that non-RequestError exceptions are wrapped then unwrapped.""" + generic_error = ValueError("Network timeout") + mock_delegate.request.side_effect = generic_error + + with pytest.raises(ValueError, match="Network timeout"): + client.request(HttpMethod.POST, "https://test.com", {}) diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py new file mode 100644 index 000000000..4e9ce1bbf --- /dev/null +++ b/tests/unit/test_unified_http_client.py @@ -0,0 +1,136 @@ +""" +Unit tests for UnifiedHttpClient, specifically testing MaxRetryError handling +and HTTP status code extraction. +""" + +import pytest +from unittest.mock import Mock, patch +from urllib3.exceptions import MaxRetryError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError +from databricks.sql.auth.common import ClientContext +from databricks.sql.types import SSLOptions + + +class TestUnifiedHttpClientMaxRetryError: + """Test MaxRetryError handling and HTTP status code extraction.""" + + @pytest.fixture + def client_context(self): + """Create a minimal ClientContext for testing.""" + context = Mock(spec=ClientContext) + context.hostname = "https://test.databricks.com" + context.ssl_options = SSLOptions( + tls_verify=True, + tls_verify_hostname=True, + tls_trusted_ca_file=None, + tls_client_cert_file=None, + tls_client_cert_key_file=None, + tls_client_cert_key_password=None, + ) + context.socket_timeout = 30 + context.retry_stop_after_attempts_count = 3 + context.retry_delay_min = 1.0 + context.retry_delay_max = 10.0 + context.retry_stop_after_attempts_duration = 300.0 + context.retry_delay_default = 5.0 + context.retry_dangerous_codes = [] + context.proxy_auth_method = None + context.pool_connections = 10 + context.pool_maxsize = 20 + context.user_agent = "test-agent" + return context + + @pytest.fixture + def http_client(self, client_context): + """Create UnifiedHttpClient instance.""" + return UnifiedHttpClient(client_context) + + @pytest.mark.parametrize("status_code,path", [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ]) + def test_max_retry_error_with_status_codes(self, http_client, status_code, path): + """Test MaxRetryError with various status codes and response paths.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + if path == "reason.response": + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = status_code + else: # direct_response + max_retry_error.response = Mock() + max_retry_error.response.status = status_code + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request( + HttpMethod.POST, "http://test.com", headers={"test": "header"} + ) + + error = exc_info.value + assert hasattr(error, "context") + assert "http-code" in error.context + assert error.context["http-code"] == status_code + + @pytest.mark.parametrize("setup_func", [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr + ]) + def test_max_retry_error_missing_status(self, http_client, setup_func): + """Test MaxRetryError without status code (no crash, empty context).""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + setup_func(max_retry_error) + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context == {} + + def test_max_retry_error_prefers_reason_response(self, http_client): + """Test that e.reason.response.status is preferred over e.response.status.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + # Set both structures with different status codes + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = 429 # Should use this + + max_retry_error.response = Mock() + max_retry_error.response.status = 500 # Should be ignored + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context["http-code"] == 429 + + def test_generic_exception_no_crash(self, http_client): + """Test that generic exceptions don't crash when checking for status code.""" + generic_error = Exception("Network error") + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=generic_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.POST, "http://test.com") + + error = exc_info.value + assert "HTTP request error" in str(error) From 8d5e1552531092f7892f424b1efcfb313e7db8ae Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:41:07 +0530 Subject: [PATCH 09/39] perf: Optimize telemetry latency logging to reduce overhead (#715) perf: Optimize telemetry latency logging to reduce overhead Optimizations implemented: 1. Eliminated extractor pattern - replaced wrapper classes with direct attribute access functions, removing object creation overhead 2. Added feature flag early exit - checks cached telemetry_enabled flag to skip heavy work when telemetry is disabled 3. Simplified code structure with early returns for better readability Signed-off-by: Samikshya Chand --- src/databricks/sql/common/feature_flag.py | 8 +- .../sql/telemetry/latency_logger.py | 289 +++++++++--------- .../sql/telemetry/telemetry_client.py | 62 +++- tests/unit/test_telemetry.py | 70 ++++- 4 files changed, 264 insertions(+), 165 deletions(-) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 8a1cf5bd5..032701f63 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -165,8 +165,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: cls._initialize() assert cls._executor is not None - # Use the unique session ID as the key - key = connection.get_session_id_hex() + # Cache at HOST level - share feature flags across connections to same host + # Feature flags are per-host, not per-session + key = connection.session.host if key not in cls._context_map: cls._context_map[key] = FeatureFlagsContext( connection, cls._executor, connection.session.http_client @@ -177,7 +178,8 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: def remove_instance(cls, connection: "Connection"): """Removes the context for a given connection and shuts down the executor if no clients remain.""" with cls._lock: - key = connection.get_session_id_hex() + # Use host as key to match get_instance + key = connection.session.host if key in cls._context_map: cls._context_map.pop(key, None) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 12cacd851..36ebee2b8 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,6 +1,6 @@ import time import functools -from typing import Optional +from typing import Optional, Dict, Any import logging from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( @@ -11,127 +11,141 @@ logger = logging.getLogger(__name__) -class TelemetryExtractor: +def _extract_cursor_data(cursor) -> Dict[str, Any]: """ - Base class for extracting telemetry information from various object types. + Extract telemetry data directly from a Cursor object. - This class serves as a proxy that delegates attribute access to the wrapped object - while providing a common interface for extracting telemetry-related data. - """ - - def __init__(self, obj): - self._obj = obj - - def __getattr__(self, name): - return getattr(self._obj, name) - - def get_session_id_hex(self): - pass - - def get_statement_id(self): - pass - - def get_is_compressed(self): - pass - - def get_execution_result_format(self): - pass - - def get_retry_count(self): - pass - - def get_chunk_id(self): - pass + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + This eliminates object creation overhead and method call indirection. + Args: + cursor: The Cursor object to extract data from -class CursorExtractor(TelemetryExtractor): + Returns: + Dict with telemetry data (values may be None if extraction fails) """ - Telemetry extractor specialized for Cursor objects. - - Extracts telemetry information from database cursor objects, including - statement IDs, session information, compression settings, and result formats. + data = {} + + # Extract statement_id (query_id) - direct attribute access + try: + data["statement_id"] = cursor.query_id + except (AttributeError, Exception): + data["statement_id"] = None + + # Extract session_id_hex - direct method call + try: + data["session_id_hex"] = cursor.connection.get_session_id_hex() + except (AttributeError, Exception): + data["session_id_hex"] = None + + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = cursor.connection.lz4_compression + except (AttributeError, Exception): + data["is_compressed"] = False + + # Extract execution_result_format - inline logic + try: + if cursor.active_result_set is None: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + else: + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + + results = cursor.active_result_set.results + if isinstance(results, ColumnQueue): + data["execution_result"] = ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(results, CloudFetchQueue): + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(results, ArrowQueue): + data["execution_result"] = ExecutionResultFormat.INLINE_ARROW + else: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + except (AttributeError, Exception): + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + + # Extract retry_count - direct attribute access + try: + if hasattr(cursor.backend, "retry_policy") and cursor.backend.retry_policy: + data["retry_count"] = len(cursor.backend.retry_policy.history) + else: + data["retry_count"] = 0 + except (AttributeError, Exception): + data["retry_count"] = 0 + + # chunk_id is always None for Cursor + data["chunk_id"] = None + + return data + + +def _extract_result_set_handler_data(handler) -> Dict[str, Any]: """ + Extract telemetry data directly from a ResultSetDownloadHandler object. - def get_statement_id(self) -> Optional[str]: - return self.query_id - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.connection.lz4_compression - - def get_execution_result_format(self) -> ExecutionResultFormat: - if self.active_result_set is None: - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - if isinstance(self.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.active_result_set.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.active_result_set.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: - return len(self.backend.retry_policy.history) - return 0 - - def get_chunk_id(self): - return None + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + Args: + handler: The ResultSetDownloadHandler object to extract data from -class ResultSetDownloadHandlerExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSetDownloadHandler objects. + Returns: + Dict with telemetry data (values may be None if extraction fails) """ + data = {} - def get_session_id_hex(self) -> Optional[str]: - return self._obj.session_id_hex + # Extract session_id_hex - direct attribute access + try: + data["session_id_hex"] = handler.session_id_hex + except (AttributeError, Exception): + data["session_id_hex"] = None - def get_statement_id(self) -> Optional[str]: - return self._obj.statement_id + # Extract statement_id - direct attribute access + try: + data["statement_id"] = handler.statement_id + except (AttributeError, Exception): + data["statement_id"] = None - def get_is_compressed(self) -> bool: - return self._obj.settings.is_lz4_compressed + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = handler.settings.is_lz4_compressed + except (AttributeError, Exception): + data["is_compressed"] = False - def get_execution_result_format(self) -> ExecutionResultFormat: - return ExecutionResultFormat.EXTERNAL_LINKS + # execution_result is always EXTERNAL_LINKS for result set handlers + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> Optional[int]: - # standard requests and urllib3 libraries don't expose retry count - return None + # retry_count is not available for result set handlers + data["retry_count"] = None + + # Extract chunk_id - direct attribute access + try: + data["chunk_id"] = handler.chunk_id + except (AttributeError, Exception): + data["chunk_id"] = None - def get_chunk_id(self) -> Optional[int]: - return self._obj.chunk_id + return data -def get_extractor(obj): +def _extract_telemetry_data(obj) -> Optional[Dict[str, Any]]: """ - Factory function to create the appropriate telemetry extractor for an object. + Extract telemetry data from an object based on its type. - Determines the object type and returns the corresponding specialized extractor - that can extract telemetry information from that object type. + OPTIMIZATION: Returns a simple dict instead of creating wrapper objects. + This dict will be used to create the SqlExecutionEvent in the background thread. Args: - obj: The object to create an extractor for. Can be a Cursor, - ResultSetDownloadHandler, or any other object. + obj: The object to extract data from (Cursor, ResultSetDownloadHandler, etc.) Returns: - TelemetryExtractor: A specialized extractor instance: - - CursorExtractor for Cursor objects - - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - - None for all other objects + Dict with telemetry data, or None if object type is not supported """ - if obj.__class__.__name__ == "Cursor": - return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSetDownloadHandler": - return ResultSetDownloadHandlerExtractor(obj) + obj_type = obj.__class__.__name__ + + if obj_type == "Cursor": + return _extract_cursor_data(obj) + elif obj_type == "ResultSetDownloadHandler": + return _extract_result_set_handler_data(obj) else: - logger.debug("No extractor found for %s", obj.__class__.__name__) + logger.debug("No telemetry extraction available for %s", obj_type) return None @@ -143,12 +157,6 @@ def log_latency(statement_type: StatementType = StatementType.NONE): data about the operation, including latency, statement information, and execution context. - The decorator automatically: - - Measures execution time using high-precision performance counters - - Extracts telemetry information from the method's object (self) - - Creates a SqlExecutionEvent with execution details - - Sends the telemetry data asynchronously via TelemetryClient - Args: statement_type (StatementType): The type of SQL statement being executed. @@ -162,54 +170,49 @@ def execute(self, query): function: A decorator that wraps methods to add latency logging. Note: - The wrapped method's object (self) must be compatible with the - telemetry extractor system (e.g., Cursor or ResultSet objects). + The wrapped method's object (self) must be a Cursor or + ResultSetDownloadHandler for telemetry data extraction. """ def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - start_time = time.perf_counter() - result = None + start_time = time.monotonic() try: - result = func(self, *args, **kwargs) - return result + return func(self, *args, **kwargs) finally: - - def _safe_call(func_to_call): - """Calls a function and returns a default value on any exception.""" - try: - return func_to_call() - except Exception: - return None - - end_time = time.perf_counter() - duration_ms = int((end_time - start_time) * 1000) - - extractor = get_extractor(self) - - if extractor is not None: - session_id_hex = _safe_call(extractor.get_session_id_hex) - statement_id = _safe_call(extractor.get_statement_id) - - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=_safe_call(extractor.get_is_compressed), - execution_result=_safe_call( - extractor.get_execution_result_format - ), - retry_count=_safe_call(extractor.get_retry_count), - chunk_id=_safe_call(extractor.get_chunk_id), - ) - - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=statement_id, - ) + duration_ms = int((time.monotonic() - start_time) * 1000) + + # Always log for debugging + logger.debug("%s completed in %dms", func.__name__, duration_ms) + + # Fast check: use cached telemetry_enabled flag from connection + # Avoids dictionary lookup + instance check on every operation + connection = getattr(self, "connection", None) + if connection and getattr(connection, "telemetry_enabled", False): + session_id_hex = connection.get_session_id_hex() + if session_id_hex: + # Telemetry enabled - extract and send + telemetry_data = _extract_telemetry_data(self) + if telemetry_data: + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=telemetry_data.get("is_compressed"), + execution_result=telemetry_data.get("execution_result"), + retry_count=telemetry_data.get("retry_count"), + chunk_id=telemetry_data.get("chunk_id"), + ) + + telemetry_client = ( + TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=telemetry_data.get("statement_id"), + ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 177d5445c..d5f5b575c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,6 +2,7 @@ import time import logging import json +from queue import Queue, Full from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future from datetime import datetime, timezone @@ -114,18 +115,21 @@ def get_auth_flow(auth_provider): @staticmethod def is_telemetry_enabled(connection: "Connection") -> bool: + # Fast path: force enabled - skip feature flag fetch entirely if connection.force_enable_telemetry: return True - if connection.enable_telemetry: - context = FeatureFlagsContextFactory.get_instance(connection) - flag_value = context.get_flag_value( - TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False - ) - return str(flag_value).lower() == "true" - else: + # Fast path: disabled - no need to check feature flag + if not connection.enable_telemetry: return False + # Only fetch feature flags when enable_telemetry=True and not forced + context = FeatureFlagsContextFactory.get_instance(connection) + flag_value = context.get_flag_value( + TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False + ) + return str(flag_value).lower() == "true" + class NoopTelemetryClient(BaseTelemetryClient): """ @@ -185,8 +189,11 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch: list = [] - self._lock = threading.RLock() + + # OPTIMIZATION: Use lock-free Queue instead of list + lock + # Queue is thread-safe internally and has better performance under concurrency + self._events_queue: Queue[TelemetryFrontendLog] = Queue(maxsize=batch_size * 2) + self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -196,7 +203,8 @@ def __init__( # Create telemetry push client based on circuit breaker enabled flag if client_context.telemetry_circuit_breaker_enabled: - # Create circuit breaker telemetry push client (circuit breakers created on-demand) + # Create circuit breaker telemetry push client + # (circuit breakers created on-demand) self._telemetry_push_client: ITelemetryPushClient = ( CircuitBreakerTelemetryPushClient( TelemetryPushClient(self._http_client), @@ -210,9 +218,24 @@ def __init__( def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) - with self._lock: - self._events_batch.append(event) - if len(self._events_batch) >= self._batch_size: + + # OPTIMIZATION: Use non-blocking put with queue + # No explicit lock needed - Queue is thread-safe internally + try: + self._events_queue.put_nowait(event) + except Full: + # Queue is full, trigger immediate flush + logger.debug("Event queue full, triggering flush") + self._flush() + # Try again after flush + try: + self._events_queue.put_nowait(event) + except Full: + # Still full, drop event (acceptable for telemetry) + logger.debug("Dropped telemetry event - queue still full") + + # Check if we should flush based on queue size + if self._events_queue.qsize() >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) @@ -220,9 +243,16 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # OPTIMIZATION: Drain queue without locks + # Collect all events currently in the queue + events_to_flush = [] + while not self._events_queue.empty(): + try: + event = self._events_queue.get_nowait() + events_to_flush.append(event) + except: + # Queue is empty + break if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 6f5a01c7b..96a2f87d8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -10,6 +10,10 @@ TelemetryClientFactory, TelemetryHelper, ) +from databricks.sql.common.feature_flag import ( + FeatureFlagsContextFactory, + FeatureFlagsContext, +) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType from databricks.sql.telemetry.models.event import ( TelemetryEvent, @@ -82,12 +86,12 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() - assert len(client._events_batch) == 2 + assert client._events_queue.qsize() == 2 # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() - assert len(client._events_batch) == 0 # Batch cleared after flush + assert client._events_queue.qsize() == 0 # Queue cleared after flush @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_network_request_flow(self, mock_http_request, mock_telemetry_client): @@ -817,7 +821,67 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - + # CF proxy not yet supported - should be False/None assert driver_params.use_cf_proxy is False assert driver_params.cf_proxy_host_info is None + + +class TestFeatureFlagsContextFactory: + """Tests for FeatureFlagsContextFactory host-level caching.""" + + @pytest.fixture(autouse=True) + def reset_factory(self): + """Reset factory state before/after each test.""" + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + yield + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + + @pytest.mark.parametrize( + "hosts,expected_contexts", + [ + (["host1.com", "host1.com"], 1), # Same host shares context + (["host1.com", "host2.com"], 2), # Different hosts get separate contexts + (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario + ], + ) + def test_host_level_caching(self, hosts, expected_contexts): + """Test that contexts are cached by host correctly.""" + contexts = [] + for host in hosts: + conn = MagicMock() + conn.session.host = host + conn.session.http_client = MagicMock() + contexts.append(FeatureFlagsContextFactory.get_instance(conn)) + + assert len(FeatureFlagsContextFactory._context_map) == expected_contexts + if expected_contexts == 1: + assert all(ctx is contexts[0] for ctx in contexts) + + def test_remove_instance_and_executor_cleanup(self): + """Test removal uses host key and cleans up executor when empty.""" + conn1 = MagicMock() + conn1.session.host = "host1.com" + conn1.session.http_client = MagicMock() + + conn2 = MagicMock() + conn2.session.host = "host2.com" + conn2.session.http_client = MagicMock() + + FeatureFlagsContextFactory.get_instance(conn1) + FeatureFlagsContextFactory.get_instance(conn2) + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn1) + assert len(FeatureFlagsContextFactory._context_map) == 1 + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn2) + assert len(FeatureFlagsContextFactory._context_map) == 0 + assert FeatureFlagsContextFactory._executor is None From d524f0e4bd730ec7397f3c0088c6e80761ebf8c9 Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Fri, 28 Nov 2025 13:33:02 +0530 Subject: [PATCH 10/39] basic e2e test for force telemetry verification (#708) * basic e2e test for force telemetry verification Signed-off-by: Nikhil Suri * Added more integration test scenarios Signed-off-by: Nikhil Suri * default on telemetry + logs to investigate failing test Signed-off-by: Nikhil Suri * fixed linting issue Signed-off-by: Nikhil Suri * added more logs to identify server side flag evaluation Signed-off-by: Nikhil Suri * remove unused logs Signed-off-by: Nikhil Suri * fix broken test case for default enable telemetry Signed-off-by: Nikhil Suri * redcude test length and made more reusable code Signed-off-by: Nikhil Suri * removed telemetry e2e to daily single run Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- .github/workflows/daily-telemetry-e2e.yml | 87 ++++++ .github/workflows/integration.yml | 8 +- tests/e2e/test_telemetry_e2e.py | 343 ++++++++++++++++++++++ 3 files changed, 436 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/daily-telemetry-e2e.yml create mode 100644 tests/e2e/test_telemetry_e2e.py diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml new file mode 100644 index 000000000..3d61cf177 --- /dev/null +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -0,0 +1,87 @@ +name: Daily Telemetry E2E Tests + +on: + schedule: + - cron: '0 0 * * 0' # Run every Sunday at midnight UTC + + workflow_dispatch: # Allow manual triggering + inputs: + test_pattern: + description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' + required: false + default: 'tests/e2e/test_telemetry_e2e.py' + type: string + +jobs: + telemetry-e2e-tests: + runs-on: ubuntu-latest + environment: azure-prod + + env: + DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} + DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} + DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} + DATABRICKS_CATALOG: peco + DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} + + steps: + #---------------------------------------------- + # check-out repo and set-up python + #---------------------------------------------- + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up python + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + #---------------------------------------------- + # ----- install & configure poetry ----- + #---------------------------------------------- + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + #---------------------------------------------- + # load cached venv if cache exists + #---------------------------------------------- + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + + #---------------------------------------------- + # install dependencies if cache does not exist + #---------------------------------------------- + - name: Install dependencies + run: poetry install --no-interaction --all-extras + + #---------------------------------------------- + # run telemetry E2E tests + #---------------------------------------------- + - name: Run telemetry E2E tests + run: | + TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" + echo "Running tests: $TEST_PATTERN" + poetry run python -m pytest $TEST_PATTERN -v -s + + #---------------------------------------------- + # upload test results on failure + #---------------------------------------------- + - name: Upload test results on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: telemetry-test-results + path: | + .pytest_cache/ + tests-unsafe.log + retention-days: 7 + diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 9c9e30a24..ad5369997 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -54,5 +54,9 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run e2e tests - run: poetry run python -m pytest tests/e2e -n auto \ No newline at end of file + - name: Run e2e tests (excluding daily-only tests) + run: | + # Exclude telemetry E2E tests from PR runs (run daily instead) + poetry run python -m pytest tests/e2e \ + --ignore=tests/e2e/test_telemetry_e2e.py \ + -n auto \ No newline at end of file diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py new file mode 100644 index 000000000..917c8e5eb --- /dev/null +++ b/tests/e2e/test_telemetry_e2e.py @@ -0,0 +1,343 @@ +""" +E2E test for telemetry - verifies telemetry behavior with different scenarios +""" +import time +import threading +import logging +from contextlib import contextmanager +from unittest.mock import patch +import pytest +from concurrent.futures import wait + +import databricks.sql as sql +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClient, + TelemetryClientFactory, +) + +log = logging.getLogger(__name__) + + +class TelemetryTestBase: + """Simplified test base class for telemetry e2e tests""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + self.arguments = connection_details.copy() + + def connection_params(self): + return { + "server_hostname": self.arguments["host"], + "http_path": self.arguments["http_path"], + "access_token": self.arguments.get("access_token"), + } + + @contextmanager + def connection(self, extra_params=()): + connection_params = dict(self.connection_params(), **dict(extra_params)) + log.info("Connecting with args: {}".format(connection_params)) + conn = sql.connect(**connection_params) + try: + yield conn + finally: + conn.close() + + +class TestTelemetryE2E(TelemetryTestBase): + """E2E tests for telemetry scenarios""" + + @pytest.fixture(autouse=True) + def telemetry_setup_teardown(self): + """Clean up telemetry client state before and after each test""" + try: + yield + finally: + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._initialized = False + + @pytest.fixture + def telemetry_interceptors(self): + """Setup reusable telemetry interceptors as a fixture""" + capture_lock = threading.Lock() + captured_events = [] + captured_futures = [] + + original_export = TelemetryClient._export_event + original_callback = TelemetryClient._telemetry_request_callback + + def export_wrapper(self_client, event): + with capture_lock: + captured_events.append(event) + return original_export(self_client, event) + + def callback_wrapper(self_client, future, sent_count): + with capture_lock: + captured_futures.append(future) + original_callback(self_client, future, sent_count) + + return captured_events, captured_futures, export_wrapper, callback_wrapper + + # ==================== ASSERTION HELPERS ==================== + + def assert_system_config(self, event): + """Assert system configuration fields""" + sys_config = event.entry.sql_driver_log.system_configuration + assert sys_config is not None + + # Check all required fields are non-empty + for field in ['driver_name', 'driver_version', 'os_name', 'os_version', + 'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor', + 'locale_name', 'char_set_encoding']: + value = getattr(sys_config, field) + assert value and len(value) > 0, f"{field} should not be None or empty" + + assert sys_config.driver_name == "Databricks SQL Python Connector" + + def assert_connection_params(self, event, expected_http_path=None): + """Assert connection parameters""" + conn_params = event.entry.sql_driver_log.driver_connection_params + assert conn_params is not None + assert conn_params.http_path + assert conn_params.host_info is not None + assert conn_params.auth_mech is not None + + if expected_http_path: + assert conn_params.http_path == expected_http_path + + if conn_params.socket_timeout is not None: + assert conn_params.socket_timeout > 0 + + def assert_statement_execution(self, event): + """Assert statement execution details""" + sql_op = event.entry.sql_driver_log.sql_operation + assert sql_op is not None + assert sql_op.statement_type is not None + assert sql_op.execution_result is not None + assert hasattr(sql_op, "retry_count") + + if sql_op.retry_count is not None: + assert sql_op.retry_count >= 0 + + latency = event.entry.sql_driver_log.operation_latency_ms + assert latency is not None and latency >= 0 + + def assert_error_info(self, event, expected_error_name=None): + """Assert error information""" + error_info = event.entry.sql_driver_log.error_info + assert error_info is not None + assert error_info.error_name and len(error_info.error_name) > 0 + assert error_info.stack_trace and len(error_info.stack_trace) > 0 + + if expected_error_name: + assert error_info.error_name == expected_error_name + + def verify_events(self, captured_events, captured_futures, expected_count): + """Common verification for event count and HTTP responses""" + if expected_count == 0: + assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}" + assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}" + else: + assert len(captured_events) == expected_count, \ + f"Expected {expected_count} events, got {len(captured_events)}" + + time.sleep(2) + done, _ = wait(captured_futures, timeout=10) + assert len(done) == expected_count, \ + f"Expected {expected_count} responses, got {len(done)}" + + for future in done: + response = future.result() + assert 200 <= response.status < 300 + + # Assert common fields for all events + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + # ==================== PARAMETERIZED TESTS ==================== + + @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [ + (True, False, 2, "enable_on_force_off"), + (False, True, 2, "enable_off_force_on"), + (False, False, 0, "both_off"), + (None, None, 0, "default_behavior"), + ]) + def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, + force_enable, expected_count, test_id): + """Test telemetry behavior with different flag combinations""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + extra_params = {"telemetry_batch_size": 1} + if enable_telemetry is not None: + extra_params["enable_telemetry"] = enable_telemetry + if force_enable is not None: + extra_params["force_enable_telemetry"] = force_enable + + with self.connection(extra_params=extra_params) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + cursor.fetchone() + + self.verify_events(captured_events, captured_futures, expected_count) + + # Assert statement execution on latency event (if events exist) + if expected_count > 0: + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("query,expected_error", [ + ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), + ("SELECT * FROM non_existent_table_xyz_12345", None), + ]) + def test_sql_errors(self, telemetry_interceptors, query, expected_error): + """Test telemetry captures error information for different SQL errors""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + with pytest.raises(Exception): + cursor.execute(query) + cursor.fetchone() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + + # Find event with error_info + error_event = next((e for e in captured_events + if e.entry.sql_driver_log.error_info), None) + assert error_event is not None + + self.assert_system_config(error_event) + self.assert_connection_params(error_event, self.arguments["http_path"]) + self.assert_error_info(error_event, expected_error) + + def test_metadata_operation(self, telemetry_interceptors): + """Test telemetry for metadata operations (getCatalogs)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + catalogs = cursor.catalogs() + catalogs.fetchall() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + def test_direct_results(self, telemetry_interceptors): + """Test telemetry with direct results (use_cloud_fetch=False)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": False, + }) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 100") + result = cursor.fetchall() + assert len(result) == 1 and result[0][0] == 100 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("close_type", [ + "context_manager", + "explicit_cursor", + "explicit_connection", + "implicit_fetchall", + ]) + def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, + close_type): + """Test telemetry with cloud fetch using different resource closing patterns""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + if close_type == "explicit_connection": + # Test explicit connection close + conn = sql.connect( + **self.connection_params(), + force_enable_telemetry=True, + telemetry_batch_size=1, + use_cloud_fetch=True, + ) + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + conn.close() + else: + # Other patterns use connection context manager + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": True, + }) as conn: + if close_type == "context_manager": + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + + elif close_type == "explicit_cursor": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + cursor.close() + + elif close_type == "implicit_fetchall": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) From ebe4b07a83499bb725338a10cf18444b31e3ec44 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:22:39 +0530 Subject: [PATCH 11/39] feat: Implement host-level telemetry batching to reduce rate limiting (#718) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Implement host-level telemetry batching to reduce rate limiting Changes telemetry client architecture from per-session to per-host batching, matching the JDBC driver implementation. This reduces the number of HTTP requests to the telemetry endpoint and prevents rate limiting in test environments. Key changes: - Add _TelemetryClientHolder with reference counting for shared clients - Change TelemetryClientFactory to key clients by host_url instead of session_id - Add getHostUrlSafely() helper for defensive null handling - Update all callers (client.py, exc.py, latency_logger.py) to pass host_url Before: 100 connections to same host = 100 separate TelemetryClients After: 100 connections to same host = 1 shared TelemetryClient (refcount=100) This fixes rate limiting issues seen in e2e tests where 300+ parallel connections were overwhelming the telemetry endpoint with 429 errors. * chore: Change all telemetry logging to DEBUG level Reduces log noise by changing all telemetry-related log statements (info, warning, error) to debug level. Telemetry operations are background tasks and should not clutter logs with operational messages. Changes: - Circuit breaker state changes: info/warning -> debug - Telemetry send failures: error -> debug - All telemetry operations now consistently use debug level * chore: Fix remaining telemetry warning log to debug Changes remaining logger.warning in telemetry_push_client.py to debug level for consistency with other telemetry logging. * fix: Update tests to use host_url instead of session_id_hex - Update circuit breaker test to check logger.debug instead of logger.info - Replace all session_id_hex test parameters with host_url - Apply Black formatting to exc.py and telemetry_client.py This fixes test failures caused by the signature change from session_id_hex to host_url in the Error class and TelemetryClientFactory. * fix: Revert session_id_hex in tests for functions that still use it Only Error classes changed from session_id_hex to host_url. Other classes (TelemetryClient, ResultSetDownloadHandler, etc.) still use session_id_hex. Reverted: - test_telemetry.py: TelemetryClient and initialize_telemetry_client - test_downloader.py: ResultSetDownloadHandler - test_download_manager.py: ResultFileDownloadManager Kept as host_url: - test_client.py: Error class instantiation * fix: Update all Error raises and test calls to use host_url Changes: 1. client.py: Changed all error raises from session_id_hex to host_url - Connection class: session_id_hex=self.get_session_id_hex() -> host_url=self.session.host - Cursor class: session_id_hex=self.connection.get_session_id_hex() -> host_url=self.connection.session.host 2. test_telemetry.py: Updated get_telemetry_client() and close() calls - get_telemetry_client(session_id) -> get_telemetry_client(host_url) - close(session_id) -> close(host_url=host_url) 3. test_telemetry_push_client.py: Changed logger.warning to logger.debug - Updated test assertion to match debug logging level These changes complete the migration from session-level to host-level telemetry client management. * fix: Update thrift_backend.py to use host_url instead of session_id_hex Changes: 1. Added self._host attribute to store server_hostname 2. Updated all error raises to use host_url=self._host 3. Changed method signatures from session_id_hex to host_url: - _check_response_for_error - _hive_schema_to_arrow_schema - _col_to_description - _hive_schema_to_description - _check_direct_results_for_error 4. Updated all method calls to pass self._host instead of self._session_id_hex This completes the migration from session-level to host-level error reporting. * Fix Black formatting by adjusting fmt directive placement Moved the `# fmt: on` directive to the except block level instead of inside the if statement to resolve Black parsing confusion. * Fix telemetry feature flag tests to set mock session host The tests were failing because they called get_telemetry_client("test") but the mock session didn't have .host set, so the telemetry client was registered under a different key (likely None or MagicMock). This caused the factory to return NoopTelemetryClient instead of the expected client. Fixed by setting mock_session_instance.host = "test" in all three tests. * Add teardown_method to clear telemetry factory state between tests Without this cleanup, tests were sharing telemetry clients because they all used the same host key ("test"), causing test pollution. The first test would create an enabled client, and subsequent tests would reuse it even when they expected a disabled client. * Clear feature flag context cache in teardown to fix test pollution The FeatureFlagsContextFactory caches feature flag contexts per session, causing tests to share the same feature flag state. This resulted in the first test creating a context with telemetry enabled, and subsequent tests incorrectly reusing that enabled state even when they expected disabled. * fix: Access actual client from holder in flush worker The flush worker was calling _flush() on _TelemetryClientHolder objects instead of the actual TelemetryClient. Fixed by accessing holder.client before calling _flush(). Fixes AttributeError in e2e tests: '_TelemetryClientHolder' object has no attribute '_flush' * Clear telemetry client cache in e2e test teardown Added _clients.clear() to the teardown fixture to prevent telemetry clients from persisting across e2e tests, which was causing session ID pollution in test_concurrent_queries_sends_telemetry. * Pass session_id parameter to telemetry export methods With host-level telemetry batching, multiple connections share one TelemetryClient. Each client stores session_id_hex from the first connection that created it. This caused all subsequent connections' telemetry events to use the wrong session ID. Changes: - Modified telemetry export method signatures to accept optional session_id - Updated Connection.export_initial_telemetry_log() to pass session_id - Updated latency_logger.py export_latency_log() to pass session_id - Updated Error.__init__() to accept optional session_id_hex and pass it - Updated all error raises in Connection and Cursor to pass session_id_hex 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * Fix Black formatting in telemetry_client.py * Use 'test-host' instead of 'test' for mock host in telemetry tests * Replace test-session-id with test-host in test_client.py * Fix telemetry client lookup to use test-host in tests * Make session_id_hex keyword-only parameter in Error.__init__ --------- Co-authored-by: Claude --- src/databricks/sql/backend/thrift_backend.py | 72 ++++--- src/databricks/sql/client.py | 36 +++- src/databricks/sql/exc.py | 16 +- .../sql/telemetry/circuit_breaker_manager.py | 8 +- .../sql/telemetry/latency_logger.py | 3 +- .../sql/telemetry/telemetry_client.py | 186 ++++++++++++++---- .../sql/telemetry/telemetry_push_client.py | 2 +- tests/e2e/test_concurrent_telemetry.py | 1 + tests/unit/test_circuit_breaker_manager.py | 2 +- tests/unit/test_client.py | 8 +- tests/unit/test_telemetry.py | 28 ++- tests/unit/test_telemetry_push_client.py | 8 +- 12 files changed, 262 insertions(+), 108 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d2b10e718..edee02bfa 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -163,6 +163,7 @@ def __init__( else: raise ValueError("No valid connection settings.") + self._host = server_hostname self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -279,14 +280,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, session_id_hex=None): + def _check_response_for_error(response, host_url=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - session_id_hex=session_id_hex, + host_url=host_url, ) @staticmethod @@ -340,7 +341,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._session_id_hex, + self._host, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -461,13 +462,12 @@ def attempt_request(attempt): errno.ECONNRESET, # | 104 | 54 | errno.ETIMEDOUT, # | 110 | 60 | ] + # fmt: on gos_name = TCLIServiceClient.GetOperationStatus.__name__ # retry on timeout. Happens a lot in Azure and it is safe as data has not been sent to server yet if method.__name__ == gos_name or err.errno == errno.ETIMEDOUT: retry_delay = bound_retry_delay(attempt, self._retry_delay_default) - - # fmt: on log_string = f"{gos_name} failed with code {err.errno} and will attempt to retry" if err.errno in info_errs: logger.info(log_string) @@ -516,9 +516,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftDatabricksClient._check_response_for_error( - response, self._session_id_hex - ) + ThriftDatabricksClient._check_response_for_error(response, self._host) return response error_info = response_or_error_info @@ -533,7 +531,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_initial_namespace(self, catalog, schema, response): @@ -547,7 +545,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - session_id_hex=self._session_id_hex, + host_url=self._host, ) if catalog: @@ -555,7 +553,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_session_configuration(self, session_configuration): @@ -570,7 +568,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def open_session(self, session_configuration, catalog, schema) -> SessionId: @@ -639,7 +637,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) else: raise ServerOperationError( @@ -649,7 +647,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -660,7 +658,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and guid_to_hex_id(op_handle.operationId.guid) }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _poll_for_status(self, op_handle): @@ -683,7 +681,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - session_id_hex=self._session_id_hex, + host_url=self._host, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -692,7 +690,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): + def _hive_schema_to_arrow_schema(t_table_schema, host_url=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -724,7 +722,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) def convert_col(t_column_desc): @@ -735,7 +733,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, field=None, session_id_hex=None): + def _col_to_description(col, field=None, host_url=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -745,7 +743,7 @@ def _col_to_description(col, field=None, session_id_hex=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -759,7 +757,7 @@ def _col_to_description(col, field=None, session_id_hex=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - session_id_hex=session_id_hex, + host_url=host_url, ) else: precision, scale = None, None @@ -778,9 +776,7 @@ def _col_to_description(col, field=None, session_id_hex=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description( - t_table_schema, schema_bytes=None, session_id_hex=None - ): + def _hive_schema_to_description(t_table_schema, schema_bytes=None, host_url=None): field_dict = {} if pyarrow and schema_bytes: try: @@ -795,7 +791,7 @@ def _hive_schema_to_description( ThriftDatabricksClient._col_to_description( col, field_dict.get(col.columnName) if field_dict else None, - session_id_hex, + host_url, ) for col in t_table_schema.columns ] @@ -818,7 +814,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -833,7 +829,7 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -844,7 +840,7 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -895,7 +891,7 @@ def get_execution_result( schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -906,7 +902,7 @@ def get_execution_result( description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -971,27 +967,27 @@ def get_query_state(self, command_id: CommandId) -> CommandState: return state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): + def _check_direct_results_for_error(t_spark_direct_results, host_url=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSetMetadata: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSet: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet, - session_id_hex, + host_url, ) if t_spark_direct_results.closeOperation: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation, - session_id_hex, + host_url, ) def execute_command( @@ -1260,7 +1256,7 @@ def _handle_execute_response(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1275,7 +1271,7 @@ def _handle_execute_response_async(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) def fetch_results( self, @@ -1313,7 +1309,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) queue = ThriftResultSetQueueFactory.build_queue( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a7f802dcd..c873700bc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -341,7 +341,7 @@ def read(self) -> Optional[OAuthToken]: ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex=self.get_session_id_hex() + host_url=self.session.host ) # Determine proxy usage @@ -391,6 +391,7 @@ def read(self) -> Optional[OAuthToken]: self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, + session_id=self.get_session_id_hex(), ) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): @@ -494,6 +495,7 @@ def cursor( if not self.open: raise InterfaceError( "Cannot create cursor from closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -521,7 +523,7 @@ def _close(self, close_cursors=True) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - TelemetryClientFactory.close(self.get_session_id_hex()) + TelemetryClientFactory.close(host_url=self.session.host) # Close HTTP client that was created by this connection if self.http_client: @@ -546,6 +548,7 @@ def autocommit(self) -> bool: if not self.open: raise InterfaceError( "Cannot get autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -578,6 +581,7 @@ def autocommit(self, value: bool) -> None: if not self.open: raise InterfaceError( "Cannot set autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -600,6 +604,7 @@ def autocommit(self, value: bool) -> None: "operation": "set_autocommit", "autocommit_value": value, }, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -627,6 +632,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( "No result returned from SET AUTOCOMMIT query", context={"operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -647,6 +653,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( f"Failed to fetch autocommit state from server: {e.message}", context={**e.context, "operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -680,6 +687,7 @@ def commit(self) -> None: if not self.open: raise InterfaceError( "Cannot commit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -692,6 +700,7 @@ def commit(self) -> None: raise TransactionError( f"Failed to commit transaction: {e.message}", context={**e.context, "operation": "commit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -725,12 +734,14 @@ def rollback(self) -> None: if self.ignore_transactions: raise NotSupportedError( "Transactions are not supported on Databricks", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) if not self.open: raise InterfaceError( "Cannot rollback on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -743,6 +754,7 @@ def rollback(self) -> None: raise TransactionError( f"Failed to rollback transaction: {e.message}", context={**e.context, "operation": "rollback"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -767,6 +779,7 @@ def get_transaction_isolation(self) -> str: if not self.open: raise InterfaceError( "Cannot get transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -793,6 +806,7 @@ def set_transaction_isolation(self, level: str) -> None: if not self.open: raise InterfaceError( "Cannot set transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -805,6 +819,7 @@ def set_transaction_isolation(self, level: str) -> None: raise NotSupportedError( f"Setting transaction isolation level '{level}' is not supported. " f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -857,6 +872,7 @@ def __iter__(self): else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -997,6 +1013,7 @@ def _check_not_closed(self): if not self.open: raise InterfaceError( "Attempting operation on closed cursor", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1041,6 +1058,7 @@ def _handle_staging_operation( else: raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1067,6 +1085,7 @@ def _handle_staging_operation( if not allow_operation: raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1095,6 +1114,7 @@ def _handle_staging_operation( raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1110,6 +1130,7 @@ def _handle_staging_put( if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1135,6 +1156,7 @@ def _handle_staging_http_response(self, r): error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1166,6 +1188,7 @@ def _handle_staging_put_stream( if not stream: raise ProgrammingError( "No input stream provided for streaming operation", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1187,6 +1210,7 @@ def _handle_staging_get( if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1201,6 +1225,7 @@ def _handle_staging_get( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1222,6 +1247,7 @@ def _handle_staging_remove( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1413,6 +1439,7 @@ def get_async_execution_result(self): else: raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1541,6 +1568,7 @@ def fetchall(self) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1558,6 +1586,7 @@ def fetchone(self) -> Optional[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1583,6 +1612,7 @@ def fetchmany(self, size: int) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1593,6 +1623,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1603,6 +1634,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 24844d573..f4770f3c4 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -12,20 +12,28 @@ class Error(Exception): """ def __init__( - self, message=None, context=None, session_id_hex=None, *args, **kwargs + self, + message=None, + context=None, + host_url=None, + *args, + session_id_hex=None, + **kwargs, ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ - if session_id_hex: + if host_url: from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=host_url + ) + telemetry_client.export_failure_log( + error_name, self.message, session_id=session_id_hex ) - telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py index 852f0d916..a5df7371e 100644 --- a/src/databricks/sql/telemetry/circuit_breaker_manager.py +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -60,16 +60,16 @@ def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: old_state_name = old_state.name if old_state else "None" new_state_name = new_state.name if new_state else "None" - logger.info( + logger.debug( LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name ) if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: - logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_OPENED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: - logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: - logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) class CircuitBreakerManager: diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 36ebee2b8..2445c25c2 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -205,13 +205,14 @@ def wrapper(self, *args, **kwargs): telemetry_client = ( TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=connection.session.host ) ) telemetry_client.export_latency_log( latency_ms=duration_ms, sql_execution_event=sql_exec_event, sql_statement_id=telemetry_data.get("statement_id"), + session_id=session_id_hex, ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d5f5b575c..77d1a2f9c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -147,13 +147,17 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): pass - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): pass - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): pass def close(self): @@ -307,7 +311,7 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): ) return response except Exception as e: - logger.error("Failed to send telemetry with unified client: %s", e) + logger.debug("Failed to send telemetry with unified client: %s", e) raise def _telemetry_request_callback(self, future, sent_count: int): @@ -352,19 +356,22 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) - def _export_telemetry_log(self, **telemetry_event_kwargs): + def _export_telemetry_log(self, session_id=None, **telemetry_event_kwargs): """ Common helper method for exporting telemetry logs. Args: + session_id: Optional session ID for this event. If not provided, uses the client's session ID. **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor """ - logger.debug("Exporting telemetry log for connection %s", self._session_id_hex) + # Use provided session_id or fall back to client's session_id + actual_session_id = session_id or self._session_id_hex + logger.debug("Exporting telemetry log for connection %s", actual_session_id) try: # Set common fields for all telemetry events event_kwargs = { - "session_id": self._session_id_hex, + "session_id": actual_session_id, "system_configuration": TelemetryHelper.get_driver_system_configuration(), "driver_connection_params": self._driver_connection_params, } @@ -387,17 +394,22 @@ def _export_telemetry_log(self, **telemetry_event_kwargs): except Exception as e: logger.debug("Failed to export telemetry log: %s", e) - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): self._driver_connection_params = driver_connection_params self._user_agent = user_agent - self._export_telemetry_log() + self._export_telemetry_log(session_id=session_id) - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) - self._export_telemetry_log(error_info=error_info) + self._export_telemetry_log(session_id=session_id, error_info=error_info) - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): self._export_telemetry_log( + session_id=session_id, sql_statement_id=sql_statement_id, sql_operation=sql_execution_event, operation_latency_ms=latency_ms, @@ -409,15 +421,39 @@ def close(self): self._flush() +class _TelemetryClientHolder: + """ + Holds a telemetry client with reference counting. + Multiple connections to the same host share one client. + """ + + def __init__(self, client: BaseTelemetryClient): + self.client = client + self.refcount = 1 + + def increment(self): + """Increment reference count when a new connection uses this client""" + self.refcount += 1 + + def decrement(self): + """Decrement reference count when a connection closes""" + self.refcount -= 1 + return self.refcount + + class TelemetryClientFactory: """ Static factory class for creating and managing telemetry clients. It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. + + Clients are shared at the HOST level - multiple connections to the same host + share a single TelemetryClient to enable efficient batching and reduce load + on the telemetry endpoint. """ _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + str, _TelemetryClientHolder + ] = {} # Map of host_url -> TelemetryClientHolder _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -431,6 +467,22 @@ class TelemetryClientFactory: _flush_interval_seconds = 300 # 5 minutes DEFAULT_BATCH_SIZE = 100 + UNKNOWN_HOST = "unknown-host" + + @staticmethod + def getHostUrlSafely(host_url): + """ + Safely get host URL with fallback to UNKNOWN_HOST. + + Args: + host_url: The host URL to validate + + Returns: + The host_url if valid, otherwise UNKNOWN_HOST + """ + if not host_url or not isinstance(host_url, str) or not host_url.strip(): + return TelemetryClientFactory.UNKNOWN_HOST + return host_url @classmethod def _initialize(cls): @@ -464,8 +516,8 @@ def _flush_worker(cls): with cls._lock: clients_to_flush = list(cls._clients.values()) - for client in clients_to_flush: - client._flush() + for holder in clients_to_flush: + holder.client._flush() @classmethod def _stop_flush_thread(cls): @@ -506,21 +558,38 @@ def initialize_telemetry_client( batch_size, client_context, ): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" + """ + Initialize a telemetry client for a specific connection if telemetry is enabled. + + Clients are shared at the HOST level - multiple connections to the same host + will share a single TelemetryClient with reference counting. + """ try: + # Safely get host_url with fallback to UNKNOWN_HOST + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() - if session_id_hex not in TelemetryClientFactory._clients: + if host_url in TelemetryClientFactory._clients: + # Reuse existing client for this host + holder = TelemetryClientFactory._clients[host_url] + holder.increment() logger.debug( - "Creating new TelemetryClient for connection %s", + "Reusing TelemetryClient for host %s (session %s, refcount=%d)", + host_url, + session_id_hex, + holder.refcount, + ) + else: + # Create new client for this host + logger.debug( + "Creating new TelemetryClient for host %s (session %s)", + host_url, session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[ - session_id_hex - ] = TelemetryClient( + client = TelemetryClient( telemetry_enabled=telemetry_enabled, session_id_hex=session_id_hex, auth_provider=auth_provider, @@ -529,36 +598,73 @@ def initialize_telemetry_client( batch_size=batch_size, client_context=client_context, ) + TelemetryClientFactory._clients[ + host_url + ] = _TelemetryClientHolder(client) else: TelemetryClientFactory._clients[ - session_id_hex - ] = NoopTelemetryClient() + host_url + ] = _TelemetryClientHolder(NoopTelemetryClient()) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder( + NoopTelemetryClient() + ) @staticmethod - def get_telemetry_client(session_id_hex): - """Get the telemetry client for a specific connection""" - return TelemetryClientFactory._clients.get( - session_id_hex, NoopTelemetryClient() - ) + def get_telemetry_client(host_url): + """ + Get the shared telemetry client for a specific host. + + Args: + host_url: The host URL to look up the client. If None/empty, uses UNKNOWN_HOST. + + Returns: + The shared TelemetryClient for this host, or NoopTelemetryClient if not found + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) + + if host_url in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[host_url].client + return NoopTelemetryClient() @staticmethod - def close(session_id_hex): - """Close and remove the telemetry client for a specific connection""" + def close(host_url): + """ + Close the telemetry client for a specific host. + + Decrements the reference count for the host's client. Only actually closes + the client when the reference count reaches zero (all connections to this host closed). + + Args: + host_url: The host URL whose client to close. If None/empty, uses UNKNOWN_HOST. + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: - if ( - telemetry_client := TelemetryClientFactory._clients.pop( - session_id_hex, None - ) - ) is not None: + # Get the holder for this host + holder = TelemetryClientFactory._clients.get(host_url) + if holder is None: + logger.debug("No telemetry client found for host %s", host_url) + return + + # Decrement refcount + remaining_refs = holder.decrement() + logger.debug( + "Decremented refcount for host %s (refcount=%d)", + host_url, + remaining_refs, + ) + + # Only close if no more references + if remaining_refs <= 0: logger.debug( - "Removing telemetry client for connection %s", session_id_hex + "Closing telemetry client for host %s (no more references)", + host_url, ) - telemetry_client.close() + TelemetryClientFactory._clients.pop(host_url, None) + holder.client.close() # Shutdown executor if no more clients if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: @@ -597,7 +703,7 @@ def connection_failure_log( ) telemetry_client = TelemetryClientFactory.get_telemetry_client( - UNAUTH_DUMMY_SESSION_ID + host_url=host_url ) telemetry_client._driver_connection_params = DriverConnectionParameters( http_path=http_path, diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py index 461a57738..e77910007 100644 --- a/src/databricks/sql/telemetry/telemetry_push_client.py +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -120,7 +120,7 @@ def _make_request_and_check_status( # Check for rate limiting or service unavailable if response.status in [429, 503]: - logger.warning( + logger.debug( "Telemetry endpoint returned %d for host %s, triggering circuit breaker", response.status, self._host, diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index d2ac4227d..546a2b8b2 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -41,6 +41,7 @@ def telemetry_setup_teardown(self): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._clients.clear() TelemetryClientFactory._initialized = False def test_concurrent_queries_sends_telemetry(self): diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py index e8ed4e809..1e02556d9 100644 --- a/tests/unit/test_circuit_breaker_manager.py +++ b/tests/unit/test_circuit_breaker_manager.py @@ -157,4 +157,4 @@ def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: listener.state_change(mock_cb, mock_old_state, mock_new_state) - mock_logger.info.assert_called() + mock_logger.debug.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b515756e8..8f8a97eae 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -714,7 +714,7 @@ def test_autocommit_setter_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -737,7 +737,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): mock_cursor = Mock() original_error = DatabaseError( - "Original error", session_id_hex="test-session-id" + "Original error", host_url="test-host" ) mock_cursor.execute.side_effect = original_error @@ -772,7 +772,7 @@ def test_commit_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -822,7 +822,7 @@ def test_rollback_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "Unexpected rollback error", context={"sql_state": "HY000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 96a2f87d8..e9fa16649 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -249,13 +249,13 @@ def test_client_lifecycle_flow(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex # Close client with patch.object(client, "close") as mock_close: - TelemetryClientFactory.close(session_id_hex) + TelemetryClientFactory.close(host_url="test-host.com") mock_close.assert_called_once() # Should get NoopTelemetryClient after close @@ -274,7 +274,7 @@ def test_disabled_telemetry_creates_noop_client(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): @@ -297,7 +297,7 @@ def test_factory_error_handling(self): ) # Should fall back to NoopTelemetryClient - client = TelemetryClientFactory.get_telemetry_client(session_id) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_shutdown_flow(self): @@ -325,11 +325,11 @@ def test_factory_shutdown_flow(self): assert TelemetryClientFactory._executor is not None # Close first client - factory should stay initialized - TelemetryClientFactory.close(session1) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is True # Close second client - factory should shut down - TelemetryClientFactory.close(session2) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is False assert TelemetryClientFactory._executor is None @@ -367,6 +367,13 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" + def teardown_method(self): + """Clean up telemetry factory state after each test to prevent test pollution.""" + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + + TelemetryClientFactory._clients.clear() + FeatureFlagsContextFactory._context_map.clear() + def _mock_ff_response(self, mock_http_request, enabled: bool): """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() @@ -391,6 +398,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -410,7 +418,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio assert conn.telemetry_enabled is True mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, TelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -421,6 +429,7 @@ def test_telemetry_disabled_when_flag_is_false( self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -440,7 +449,7 @@ def test_telemetry_disabled_when_flag_is_false( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -451,6 +460,7 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -470,7 +480,7 @@ def test_telemetry_disabled_when_flag_request_fails( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py index 0e9455e1f..6555f1d02 100644 --- a/tests/unit/test_telemetry_push_client.py +++ b/tests/unit/test_telemetry_push_client.py @@ -114,10 +114,10 @@ def test_rate_limit_error_logging(self): with pytest.raises(TelemetryRateLimitError): self.client.request(HttpMethod.POST, "https://test.com", {}) - mock_logger.warning.assert_called() - warning_args = mock_logger.warning.call_args[0] - assert "429" in str(warning_args) - assert "circuit breaker" in warning_args[0] + mock_logger.debug.assert_called() + debug_args = mock_logger.debug.call_args[0] + assert "429" in str(debug_args) + assert "circuit breaker" in debug_args[0] def test_other_error_logging(self): """Test that other errors are logged during wrapping/unwrapping.""" From d2ae1e8fe6e1f4ad5dba13751cd0f91763aea9bf Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:21:20 +0530 Subject: [PATCH 12/39] Prepare for a release with telemetry on by default (#717) * Prepare for a release with telemetry on by default Signed-off-by: samikshya-chand_data * Make edits Signed-off-by: samikshya-chand_data * Update version Signed-off-by: samikshya-chand_data * Fix CHANGELOG formatting to match previous style Signed-off-by: samikshya-chand_data * Fix telemetry e2e tests for default-enabled behavior - Update test expectations to reflect telemetry being enabled by default - Add feature flags cache cleanup in teardown to prevent state leakage between tests - This ensures each test runs with fresh feature flag state * Add wait after connection close for async telemetry submission * Remove debug logging from telemetry tests * Mark telemetry e2e tests as serial - must not run in parallel Root cause: Telemetry tests share host-level client across pytest-xdist workers, causing test isolation issues with patches. Tests pass serially but fail with -n auto. Solution: Add @pytest.mark.serial marker. CI needs to run these separately without -n auto. * Split test execution to run serial tests separately Telemetry e2e tests must run serially due to shared host-level telemetry client across pytest-xdist workers. Running with -n auto causes test isolation issues where futures aren't properly captured. Changes: - Run parallel tests with -m 'not serial' -n auto - Run serial tests with -m 'serial' without parallelization - Use --cov-append for serial tests to combine coverage - Mark telemetry e2e tests with @pytest.mark.serial - Update test expectations for default telemetry behavior - Add feature flags cache cleanup in test teardown * Mark telemetry e2e tests as serial - must not run in parallel The concurrent telemetry e2e test globally patches telemetry methods to capture events. When run in parallel with other tests via pytest-xdist, it captures telemetry events from other concurrent tests, causing assertion failures (expected 60 events, got 88). All telemetry e2e tests must run serially to avoid cross-test interference with the shared host-level telemetry client. --------- Signed-off-by: samikshya-chand_data --- .github/workflows/code-coverage.yml | 19 +++++++++++++++++-- CHANGELOG.md | 7 +++++++ README.md | 2 +- pyproject.toml | 7 +++++-- src/databricks/sql/__init__.py | 2 +- src/databricks/sql/auth/common.py | 2 +- src/databricks/sql/client.py | 2 +- tests/e2e/test_concurrent_telemetry.py | 1 + tests/e2e/test_telemetry_e2e.py | 17 ++++++++++++++--- 9 files changed, 48 insertions(+), 11 deletions(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index d9954d051..3c76be728 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -61,17 +61,32 @@ jobs: - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- - # run all tests with coverage + # run parallel tests with coverage #---------------------------------------------- - - name: Run all tests with coverage + - name: Run parallel tests with coverage continue-on-error: false run: | poetry run pytest tests/unit tests/e2e \ + -m "not serial" \ -n auto \ --cov=src \ --cov-report=xml \ --cov-report=term \ -v + + #---------------------------------------------- + # run serial tests with coverage + #---------------------------------------------- + - name: Run serial tests with coverage + continue-on-error: false + run: | + poetry run pytest tests/e2e \ + -m "serial" \ + --cov=src \ + --cov-append \ + --cov-report=xml \ + --cov-report=term \ + -v #---------------------------------------------- # check for coverage override diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b902e976..6be2dacaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Release History +# 4.2.2 (2025-12-01) +- Change default use_hybrid_disposition to False (databricks/databricks-sql-python#714 by @samikshya-db) +- Circuit breaker changes using pybreaker (databricks/databricks-sql-python#705 by @nikhilsuri-db) +- perf: Optimize telemetry latency logging to reduce overhead (databricks/databricks-sql-python#715 by @samikshya-db) +- basic e2e test for force telemetry verification (databricks/databricks-sql-python#708 by @nikhilsuri-db) +- Telemetry is ON by default to track connection stats. (Note : This strictly excludes PII, query text, and results) (databricks/databricks-sql-python#717 by @samikshya-db) + # 4.2.1 (2025-11-20) - Ignore transactions by default (databricks/databricks-sql-python#711 by @jayantsing-db) diff --git a/README.md b/README.md index ec82a3637..047515ba4 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ You are welcome to file an issue here for general use cases. You can also contac ## Requirements -Python 3.8 or above is required. +Python 3.9 or above is required. ## Documentation diff --git a/pyproject.toml b/pyproject.toml index 61c248e98..d2739c7d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.1" +version = "4.2.2" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -62,7 +62,10 @@ exclude = ['ttypes\.py$', 'TCLIService\.py$'] exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' [tool.pytest.ini_options] -markers = {"reviewed" = "Test case has been reviewed by Databricks"} +markers = [ + "reviewed: Test case has been reviewed by Databricks", + "serial: Tests that must run serially (not parallelized)" +] minversion = "6.0" log_cli = "false" log_cli_level = "INFO" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index cd37e6ce1..7cf631e83 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.1" +__version__ = "4.2.2" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index a764b036d..0e3a01918 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,7 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, - telemetry_circuit_breaker_enabled: Optional[bool] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = True, ): self.hostname = hostname self.access_token = access_token diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c873700bc..1f17d54f2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -328,7 +328,7 @@ def read(self) -> Optional[OAuthToken]: self.ignore_transactions = ignore_transactions self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) - self.enable_telemetry = kwargs.get("enable_telemetry", False) + self.enable_telemetry = kwargs.get("enable_telemetry", True) self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self) TelemetryClientFactory.initialize_telemetry_client( diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 546a2b8b2..bed348c2c 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -26,6 +26,7 @@ def run_in_threads(target, num_threads, pass_index=False): t.join() +@pytest.mark.serial class TestE2ETelemetry(PySQLPytestTestCase): @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py index 917c8e5eb..0a57edd3c 100644 --- a/tests/e2e/test_telemetry_e2e.py +++ b/tests/e2e/test_telemetry_e2e.py @@ -43,8 +43,9 @@ def connection(self, extra_params=()): conn.close() +@pytest.mark.serial class TestTelemetryE2E(TelemetryTestBase): - """E2E tests for telemetry scenarios""" + """E2E tests for telemetry scenarios - must run serially due to shared host-level telemetry client""" @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): @@ -58,6 +59,14 @@ def telemetry_setup_teardown(self): TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._initialized = False + # Clear feature flags cache to prevent state leakage between tests + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + with FeatureFlagsContextFactory._lock: + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + @pytest.fixture def telemetry_interceptors(self): """Setup reusable telemetry interceptors as a fixture""" @@ -142,7 +151,7 @@ def verify_events(self, captured_events, captured_futures, expected_count): else: assert len(captured_events) == expected_count, \ f"Expected {expected_count} events, got {len(captured_events)}" - + time.sleep(2) done, _ = wait(captured_futures, timeout=10) assert len(done) == expected_count, \ @@ -163,7 +172,7 @@ def verify_events(self, captured_events, captured_futures, expected_count): (True, False, 2, "enable_on_force_off"), (False, True, 2, "enable_off_force_on"), (False, False, 0, "both_off"), - (None, None, 0, "default_behavior"), + (None, None, 2, "default_behavior"), ]) def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, force_enable, expected_count, test_id): @@ -185,6 +194,8 @@ def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, cursor.execute("SELECT 1") cursor.fetchone() + # Give time for async telemetry submission after connection closes + time.sleep(0.5) self.verify_events(captured_events, captured_futures, expected_count) # Assert statement execution on latency event (if events exist) From 7c6adeebb97749654a3300863de0f8b692c23707 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 11 Dec 2025 15:04:19 +0530 Subject: [PATCH 13/39] added pandas < 2.4.0 support and tests for py 3.14 (#720) * added pandas 2.3.3 support and tests for py 3.14 Signed-off-by: Sreekanth Vadigi * generated poetry.lock Signed-off-by: Sreekanth Vadigi * lz4 version update for py 3.14 Signed-off-by: Sreekanth Vadigi * dependency selection based on py version Signed-off-by: Sreekanth Vadigi * pyarrow version update for py 3.14 Signed-off-by: Sreekanth Vadigi * poetry.lock with latest poetry version Signed-off-by: Sreekanth Vadigi --------- Signed-off-by: Sreekanth Vadigi --- .github/workflows/code-quality-checks.yml | 8 +- poetry.lock | 141 +++++++++++++++++++++- pyproject.toml | 10 +- scripts/dependency_manager.py | 19 ++- 4 files changed, 167 insertions(+), 11 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 22db995c5..3c368abef 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] dependency-version: ["default", "min"] # Optimize matrix - test min/max on subset of Python versions exclude: @@ -91,7 +91,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] dependency-version: ["default", "min"] exclude: - python-version: "3.12" @@ -173,7 +173,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: #---------------------------------------------- # check-out repo and set-up python @@ -225,7 +225,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: #---------------------------------------------- # check-out repo and set-up python diff --git a/poetry.lock b/poetry.lock index 193efa109..64a95add5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -764,6 +764,7 @@ description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" groups = ["main"] +markers = "python_version < \"3.10\"" files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -808,6 +809,79 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] flake8 = ["flake8"] tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] +[[package]] +name = "lz4" +version = "4.4.5" +description = "LZ4 Bindings for Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "lz4-4.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d221fa421b389ab2345640a508db57da36947a437dfe31aeddb8d5c7b646c22d"}, + {file = "lz4-4.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7dc1e1e2dbd872f8fae529acd5e4839efd0b141eaa8ae7ce835a9fe80fbad89f"}, + {file = "lz4-4.4.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e928ec2d84dc8d13285b4a9288fd6246c5cde4f5f935b479f50d986911f085e3"}, + {file = "lz4-4.4.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:daffa4807ef54b927451208f5f85750c545a4abbff03d740835fc444cd97f758"}, + {file = "lz4-4.4.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a2b7504d2dffed3fd19d4085fe1cc30cf221263fd01030819bdd8d2bb101cf1"}, + {file = "lz4-4.4.5-cp310-cp310-win32.whl", hash = "sha256:0846e6e78f374156ccf21c631de80967e03cc3c01c373c665789dc0c5431e7fc"}, + {file = "lz4-4.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:7c4e7c44b6a31de77d4dc9772b7d2561937c9588a734681f70ec547cfbc51ecd"}, + {file = "lz4-4.4.5-cp310-cp310-win_arm64.whl", hash = "sha256:15551280f5656d2206b9b43262799c89b25a25460416ec554075a8dc568e4397"}, + {file = "lz4-4.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d6da84a26b3aa5da13a62e4b89ab36a396e9327de8cd48b436a3467077f8ccd4"}, + {file = "lz4-4.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:61d0ee03e6c616f4a8b69987d03d514e8896c8b1b7cc7598ad029e5c6aedfd43"}, + {file = "lz4-4.4.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:33dd86cea8375d8e5dd001e41f321d0a4b1eb7985f39be1b6a4f466cd480b8a7"}, + {file = "lz4-4.4.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:609a69c68e7cfcfa9d894dc06be13f2e00761485b62df4e2472f1b66f7b405fb"}, + {file = "lz4-4.4.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:75419bb1a559af00250b8f1360d508444e80ed4b26d9d40ec5b09fe7875cb989"}, + {file = "lz4-4.4.5-cp311-cp311-win32.whl", hash = "sha256:12233624f1bc2cebc414f9efb3113a03e89acce3ab6f72035577bc61b270d24d"}, + {file = "lz4-4.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:8a842ead8ca7c0ee2f396ca5d878c4c40439a527ebad2b996b0444f0074ed004"}, + {file = "lz4-4.4.5-cp311-cp311-win_arm64.whl", hash = "sha256:83bc23ef65b6ae44f3287c38cbf82c269e2e96a26e560aa551735883388dcc4b"}, + {file = "lz4-4.4.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:df5aa4cead2044bab83e0ebae56e0944cc7fcc1505c7787e9e1057d6d549897e"}, + {file = "lz4-4.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d0bf51e7745484d2092b3a51ae6eb58c3bd3ce0300cf2b2c14f76c536d5697a"}, + {file = "lz4-4.4.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7b62f94b523c251cf32aa4ab555f14d39bd1a9df385b72443fd76d7c7fb051f5"}, + {file = "lz4-4.4.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c3ea562c3af274264444819ae9b14dbbf1ab070aff214a05e97db6896c7597e"}, + {file = "lz4-4.4.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24092635f47538b392c4eaeff14c7270d2c8e806bf4be2a6446a378591c5e69e"}, + {file = "lz4-4.4.5-cp312-cp312-win32.whl", hash = "sha256:214e37cfe270948ea7eb777229e211c601a3e0875541c1035ab408fbceaddf50"}, + {file = "lz4-4.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:713a777de88a73425cf08eb11f742cd2c98628e79a8673d6a52e3c5f0c116f33"}, + {file = "lz4-4.4.5-cp312-cp312-win_arm64.whl", hash = "sha256:a88cbb729cc333334ccfb52f070463c21560fca63afcf636a9f160a55fac3301"}, + {file = "lz4-4.4.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6bb05416444fafea170b07181bc70640975ecc2a8c92b3b658c554119519716c"}, + {file = "lz4-4.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b424df1076e40d4e884cfcc4c77d815368b7fb9ebcd7e634f937725cd9a8a72a"}, + {file = "lz4-4.4.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:216ca0c6c90719731c64f41cfbd6f27a736d7e50a10b70fad2a9c9b262ec923d"}, + {file = "lz4-4.4.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:533298d208b58b651662dd972f52d807d48915176e5b032fb4f8c3b6f5fe535c"}, + {file = "lz4-4.4.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:451039b609b9a88a934800b5fc6ee401c89ad9c175abf2f4d9f8b2e4ef1afc64"}, + {file = "lz4-4.4.5-cp313-cp313-win32.whl", hash = "sha256:a5f197ffa6fc0e93207b0af71b302e0a2f6f29982e5de0fbda61606dd3a55832"}, + {file = "lz4-4.4.5-cp313-cp313-win_amd64.whl", hash = "sha256:da68497f78953017deb20edff0dba95641cc86e7423dfadf7c0264e1ac60dc22"}, + {file = "lz4-4.4.5-cp313-cp313-win_arm64.whl", hash = "sha256:c1cfa663468a189dab510ab231aad030970593f997746d7a324d40104db0d0a9"}, + {file = "lz4-4.4.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:67531da3b62f49c939e09d56492baf397175ff39926d0bd5bd2d191ac2bff95f"}, + {file = "lz4-4.4.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a1acbbba9edbcbb982bc2cac5e7108f0f553aebac1040fbec67a011a45afa1ba"}, + {file = "lz4-4.4.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a482eecc0b7829c89b498fda883dbd50e98153a116de612ee7c111c8bcf82d1d"}, + {file = "lz4-4.4.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e099ddfaa88f59dd8d36c8a3c66bd982b4984edf127eb18e30bb49bdba68ce67"}, + {file = "lz4-4.4.5-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a2af2897333b421360fdcce895c6f6281dc3fab018d19d341cf64d043fc8d90d"}, + {file = "lz4-4.4.5-cp313-cp313t-win32.whl", hash = "sha256:66c5de72bf4988e1b284ebdd6524c4bead2c507a2d7f172201572bac6f593901"}, + {file = "lz4-4.4.5-cp313-cp313t-win_amd64.whl", hash = "sha256:cdd4bdcbaf35056086d910d219106f6a04e1ab0daa40ec0eeef1626c27d0fddb"}, + {file = "lz4-4.4.5-cp313-cp313t-win_arm64.whl", hash = "sha256:28ccaeb7c5222454cd5f60fcd152564205bcb801bd80e125949d2dfbadc76bbd"}, + {file = "lz4-4.4.5-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c216b6d5275fc060c6280936bb3bb0e0be6126afb08abccde27eed23dead135f"}, + {file = "lz4-4.4.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c8e71b14938082ebaf78144f3b3917ac715f72d14c076f384a4c062df96f9df6"}, + {file = "lz4-4.4.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b5e6abca8df9f9bdc5c3085f33ff32cdc86ed04c65e0355506d46a5ac19b6e9"}, + {file = "lz4-4.4.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3b84a42da86e8ad8537aabef062e7f661f4a877d1c74d65606c49d835d36d668"}, + {file = "lz4-4.4.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bba042ec5a61fa77c7e380351a61cb768277801240249841defd2ff0a10742f"}, + {file = "lz4-4.4.5-cp314-cp314-win32.whl", hash = "sha256:bd85d118316b53ed73956435bee1997bd06cc66dd2fa74073e3b1322bd520a67"}, + {file = "lz4-4.4.5-cp314-cp314-win_amd64.whl", hash = "sha256:92159782a4502858a21e0079d77cdcaade23e8a5d252ddf46b0652604300d7be"}, + {file = "lz4-4.4.5-cp314-cp314-win_arm64.whl", hash = "sha256:d994b87abaa7a88ceb7a37c90f547b8284ff9da694e6afcfaa8568d739faf3f7"}, + {file = "lz4-4.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f6538aaaedd091d6e5abdaa19b99e6e82697d67518f114721b5248709b639fad"}, + {file = "lz4-4.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:13254bd78fef50105872989a2dc3418ff09aefc7d0765528adc21646a7288294"}, + {file = "lz4-4.4.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e64e61f29cf95afb43549063d8433b46352baf0c8a70aa45e2585618fcf59d86"}, + {file = "lz4-4.4.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ff1b50aeeec64df5603f17984e4b5be6166058dcf8f1e26a3da40d7a0f6ab547"}, + {file = "lz4-4.4.5-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1dd4d91d25937c2441b9fc0f4af01704a2d09f30a38c5798bc1d1b5a15ec9581"}, + {file = "lz4-4.4.5-cp39-cp39-win32.whl", hash = "sha256:d64141085864918392c3159cdad15b102a620a67975c786777874e1e90ef15ce"}, + {file = "lz4-4.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:f32b9e65d70f3684532358255dc053f143835c5f5991e28a5ac4c93ce94b9ea7"}, + {file = "lz4-4.4.5-cp39-cp39-win_arm64.whl", hash = "sha256:f9b8bde9909a010c75b3aea58ec3910393b758f3c219beed67063693df854db0"}, + {file = "lz4-4.4.5.tar.gz", hash = "sha256:5f0b9e53c1e82e88c10d7c180069363980136b9d7a8306c4dca4f760d60c39f0"}, +] + +[package.extras] +docs = ["sphinx (>=1.6.0)", "sphinx_bootstrap_theme"] +flake8 = ["flake8"] +tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] + [[package]] name = "mccabe" version = "0.7.0" @@ -1299,7 +1373,7 @@ description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" +markers = "python_version == \"3.13\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -1348,6 +1422,67 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyarrow" +version = "22.0.0" +description = "Python library for Apache Arrow" +optional = true +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"pyarrow\" and (python_version < \"3.13\" or python_version >= \"3.14\")" +files = [ + {file = "pyarrow-22.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:77718810bd3066158db1e95a63c160ad7ce08c6b0710bc656055033e39cdad88"}, + {file = "pyarrow-22.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:44d2d26cda26d18f7af7db71453b7b783788322d756e81730acb98f24eb90ace"}, + {file = "pyarrow-22.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:b9d71701ce97c95480fecb0039ec5bb889e75f110da72005743451339262f4ce"}, + {file = "pyarrow-22.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:710624ab925dc2b05a6229d47f6f0dac1c1155e6ed559be7109f684eba048a48"}, + {file = "pyarrow-22.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f963ba8c3b0199f9d6b794c90ec77545e05eadc83973897a4523c9e8d84e9340"}, + {file = "pyarrow-22.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd0d42297ace400d8febe55f13fdf46e86754842b860c978dfec16f081e5c653"}, + {file = "pyarrow-22.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:00626d9dc0f5ef3a75fe63fd68b9c7c8302d2b5bbc7f74ecaedba83447a24f84"}, + {file = "pyarrow-22.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:3e294c5eadfb93d78b0763e859a0c16d4051fc1c5231ae8956d61cb0b5666f5a"}, + {file = "pyarrow-22.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:69763ab2445f632d90b504a815a2a033f74332997052b721002298ed6de40f2e"}, + {file = "pyarrow-22.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:b41f37cabfe2463232684de44bad753d6be08a7a072f6a83447eeaf0e4d2a215"}, + {file = "pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:35ad0f0378c9359b3f297299c3309778bb03b8612f987399a0333a560b43862d"}, + {file = "pyarrow-22.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8382ad21458075c2e66a82a29d650f963ce51c7708c7c0ff313a8c206c4fd5e8"}, + {file = "pyarrow-22.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1a812a5b727bc09c3d7ea072c4eebf657c2f7066155506ba31ebf4792f88f016"}, + {file = "pyarrow-22.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:ec5d40dd494882704fb876c16fa7261a69791e784ae34e6b5992e977bd2e238c"}, + {file = "pyarrow-22.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bea79263d55c24a32b0d79c00a1c58bb2ee5f0757ed95656b01c0fb310c5af3d"}, + {file = "pyarrow-22.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:12fe549c9b10ac98c91cf791d2945e878875d95508e1a5d14091a7aaa66d9cf8"}, + {file = "pyarrow-22.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:334f900ff08ce0423407af97e6c26ad5d4e3b0763645559ece6fbf3747d6a8f5"}, + {file = "pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c6c791b09c57ed76a18b03f2631753a4960eefbbca80f846da8baefc6491fcfe"}, + {file = "pyarrow-22.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c3200cb41cdbc65156e5f8c908d739b0dfed57e890329413da2748d1a2cd1a4e"}, + {file = "pyarrow-22.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ac93252226cf288753d8b46280f4edf3433bf9508b6977f8dd8526b521a1bbb9"}, + {file = "pyarrow-22.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:44729980b6c50a5f2bfcc2668d36c569ce17f8b17bccaf470c4313dcbbf13c9d"}, + {file = "pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a"}, + {file = "pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901"}, + {file = "pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691"}, + {file = "pyarrow-22.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e0a15757fccb38c410947df156f9749ae4a3c89b2393741a50521f39a8cf202a"}, + {file = "pyarrow-22.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cedb9dd9358e4ea1d9bce3665ce0797f6adf97ff142c8e25b46ba9cdd508e9b6"}, + {file = "pyarrow-22.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:252be4a05f9d9185bb8c18e83764ebcfea7185076c07a7a662253af3a8c07941"}, + {file = "pyarrow-22.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:a4893d31e5ef780b6edcaf63122df0f8d321088bb0dee4c8c06eccb1ca28d145"}, + {file = "pyarrow-22.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:f7fe3dbe871294ba70d789be16b6e7e52b418311e166e0e3cba9522f0f437fb1"}, + {file = "pyarrow-22.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:ba95112d15fd4f1105fb2402c4eab9068f0554435e9b7085924bcfaac2cc306f"}, + {file = "pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c064e28361c05d72eed8e744c9605cbd6d2bb7481a511c74071fd9b24bc65d7d"}, + {file = "pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6f9762274496c244d951c819348afbcf212714902742225f649cf02823a6a10f"}, + {file = "pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a9d9ffdc2ab696f6b15b4d1f7cec6658e1d788124418cb30030afbae31c64746"}, + {file = "pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ec1a15968a9d80da01e1d30349b2b0d7cc91e96588ee324ce1b5228175043e95"}, + {file = "pyarrow-22.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:bba208d9c7decf9961998edf5c65e3ea4355d5818dd6cd0f6809bec1afb951cc"}, + {file = "pyarrow-22.0.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:9bddc2cade6561f6820d4cd73f99a0243532ad506bc510a75a5a65a522b2d74d"}, + {file = "pyarrow-22.0.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:e70ff90c64419709d38c8932ea9fe1cc98415c4f87ea8da81719e43f02534bc9"}, + {file = "pyarrow-22.0.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:92843c305330aa94a36e706c16209cd4df274693e777ca47112617db7d0ef3d7"}, + {file = "pyarrow-22.0.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:6dda1ddac033d27421c20d7a7943eec60be44e0db4e079f33cc5af3b8280ccde"}, + {file = "pyarrow-22.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:84378110dd9a6c06323b41b56e129c504d157d1a983ce8f5443761eb5256bafc"}, + {file = "pyarrow-22.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:854794239111d2b88b40b6ef92aa478024d1e5074f364033e73e21e3f76b25e0"}, + {file = "pyarrow-22.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:b883fe6fd85adad7932b3271c38ac289c65b7337c2c132e9569f9d3940620730"}, + {file = "pyarrow-22.0.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:7a820d8ae11facf32585507c11f04e3f38343c1e784c9b5a8b1da5c930547fe2"}, + {file = "pyarrow-22.0.0-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:c6ec3675d98915bf1ec8b3c7986422682f7232ea76cad276f4c8abd5b7319b70"}, + {file = "pyarrow-22.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3e739edd001b04f654b166204fc7a9de896cf6007eaff33409ee9e50ceaff754"}, + {file = "pyarrow-22.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7388ac685cab5b279a41dfe0a6ccd99e4dbf322edfb63e02fc0443bf24134e91"}, + {file = "pyarrow-22.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f633074f36dbc33d5c05b5dc75371e5660f1dbf9c8b1d95669def05e5425989c"}, + {file = "pyarrow-22.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4c19236ae2402a8663a2c8f21f1870a03cc57f0bef7e4b6eb3238cc82944de80"}, + {file = "pyarrow-22.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:0c34fe18094686194f204a3b1787a27456897d8a2d62caf84b61e8dfbc0252ae"}, + {file = "pyarrow-22.0.0.tar.gz", hash = "sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9"}, +] + [[package]] name = "pybreaker" version = "1.2.0" @@ -1885,9 +2020,9 @@ socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [extras] -pyarrow = ["pyarrow", "pyarrow"] +pyarrow = ["pyarrow", "pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" +content-hash = "01373096b340b4e384eb7b7dcc15c41f93c1fb3197145937f3eb07ea2719533e" diff --git a/pyproject.toml b/pyproject.toml index d2739c7d4..596260674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,12 @@ python = "^3.8.0" thrift = ">=0.16.0,<0.21.0" pandas = [ { version = ">=1.2.5,<2.3.0", python = ">=3.8,<3.13" }, - { version = ">=2.2.3,<2.3.0", python = ">=3.13" } + { version = ">=2.2.3,<2.4.0", python = ">=3.13" } +] +lz4 = [ + { version = "^4.0.2", python = ">=3.8,<3.14" }, + { version = "^4.4.5", python = ">=3.14" } ] -lz4 = "^4.0.2" requests = "^2.18.1" oauthlib = "^3.1.0" openpyxl = "^3.0.10" @@ -23,7 +26,8 @@ urllib3 = ">=1.26" python-dateutil = "^2.8.0" pyarrow = [ { version = ">=14.0.1", python = ">=3.8,<3.13", optional=true }, - { version = ">=18.0.0", python = ">=3.13", optional=true } + { version = ">=18.0.0,<22.0.0", python = ">=3.13,<3.14", optional=true }, + { version = ">=22.0.0", python = ">=3.14", optional=true } ] pyjwt = "^2.0.0" pybreaker = "^1.0.0" diff --git a/scripts/dependency_manager.py b/scripts/dependency_manager.py index d73d095f2..15e119841 100644 --- a/scripts/dependency_manager.py +++ b/scripts/dependency_manager.py @@ -39,7 +39,24 @@ def _parse_constraint(self, name, constraint): if isinstance(constraint, str): return constraint, False # version_constraint, is_optional elif isinstance(constraint, list): - # Handle complex constraints like pandas/pyarrow + # Handle complex constraints like pandas/pyarrow with Python version markers + current_python = sys.version_info + current_version = f"{current_python.major}.{current_python.minor}" + + # Find the constraint that matches the current Python version + for item in constraint: + if 'python' in item: + python_spec = item['python'] + # Parse the Python version specifier + spec_set = SpecifierSet(python_spec) + + # Check if current Python version matches this constraint + if current_version in spec_set: + version = item['version'] + is_optional = item.get('optional', False) + return version, is_optional + + # Fallback to first constraint if no Python version match first_constraint = constraint[0] version = first_constraint['version'] is_optional = first_constraint.get('optional', False) From f7822fdb9814089504d871e97321b02267a2e5c9 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 18 Dec 2025 15:13:01 +0530 Subject: [PATCH 14/39] pandas 2.3.3 support for py < 3.14 (#721) * pandas 2.3.3 support for py < 3.14 Signed-off-by: Sreekanth Vadigi * poetry lock Signed-off-by: Sreekanth Vadigi --------- Signed-off-by: Sreekanth Vadigi --- poetry.lock | 60 ++------------------------------------------------ pyproject.toml | 4 ++-- 2 files changed, 4 insertions(+), 60 deletions(-) diff --git a/poetry.lock b/poetry.lock index 64a95add5..7d0845a58 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1366,62 +1366,6 @@ numpy = ">=1.16.6" [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] -[[package]] -name = "pyarrow" -version = "19.0.1" -description = "Python library for Apache Arrow" -optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "python_version == \"3.13\" and extra == \"pyarrow\"" -files = [ - {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, - {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, - {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad76aef7f5f7e4a757fddcdcf010a8290958f09e3470ea458c80d26f4316ae89"}, - {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d03c9d6f2a3dffbd62671ca070f13fc527bb1867b4ec2b98c7eeed381d4f389a"}, - {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:65cf9feebab489b19cdfcfe4aa82f62147218558d8d3f0fc1e9dea0ab8e7905a"}, - {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:41f9706fbe505e0abc10e84bf3a906a1338905cbbcf1177b71486b03e6ea6608"}, - {file = "pyarrow-19.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6cb2335a411b713fdf1e82a752162f72d4a7b5dbc588e32aa18383318b05866"}, - {file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc55d71898ea30dc95900297d191377caba257612f384207fe9f8293b5850f90"}, - {file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:7a544ec12de66769612b2d6988c36adc96fb9767ecc8ee0a4d270b10b1c51e00"}, - {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0148bb4fc158bfbc3d6dfe5001d93ebeed253793fff4435167f6ce1dc4bddeae"}, - {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f24faab6ed18f216a37870d8c5623f9c044566d75ec586ef884e13a02a9d62c5"}, - {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:4982f8e2b7afd6dae8608d70ba5bd91699077323f812a0448d8b7abdff6cb5d3"}, - {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49a3aecb62c1be1d822f8bf629226d4a96418228a42f5b40835c1f10d42e4db6"}, - {file = "pyarrow-19.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:008a4009efdb4ea3d2e18f05cd31f9d43c388aad29c636112c2966605ba33466"}, - {file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:80b2ad2b193e7d19e81008a96e313fbd53157945c7be9ac65f44f8937a55427b"}, - {file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:ee8dec072569f43835932a3b10c55973593abc00936c202707a4ad06af7cb294"}, - {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5d1ec7ec5324b98887bdc006f4d2ce534e10e60f7ad995e7875ffa0ff9cb14"}, - {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ad4c0eb4e2a9aeb990af6c09e6fa0b195c8c0e7b272ecc8d4d2b6574809d34"}, - {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d383591f3dcbe545f6cc62daaef9c7cdfe0dff0fb9e1c8121101cabe9098cfa6"}, - {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b4c4156a625f1e35d6c0b2132635a237708944eb41df5fbe7d50f20d20c17832"}, - {file = "pyarrow-19.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bd1618ae5e5476b7654c7b55a6364ae87686d4724538c24185bbb2952679960"}, - {file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e45274b20e524ae5c39d7fc1ca2aa923aab494776d2d4b316b49ec7572ca324c"}, - {file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d9dedeaf19097a143ed6da37f04f4051aba353c95ef507764d344229b2b740ae"}, - {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ebfb5171bb5f4a52319344ebbbecc731af3f021e49318c74f33d520d31ae0c4"}, - {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a21d39fbdb948857f67eacb5bbaaf36802de044ec36fbef7a1c8f0dd3a4ab2"}, - {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:99bc1bec6d234359743b01e70d4310d0ab240c3d6b0da7e2a93663b0158616f6"}, - {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1b93ef2c93e77c442c979b0d596af45e4665d8b96da598db145b0fec014b9136"}, - {file = "pyarrow-19.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:d9d46e06846a41ba906ab25302cf0fd522f81aa2a85a71021826f34639ad31ef"}, - {file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:c0fe3dbbf054a00d1f162fda94ce236a899ca01123a798c561ba307ca38af5f0"}, - {file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:96606c3ba57944d128e8a8399da4812f56c7f61de8c647e3470b417f795d0ef9"}, - {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f04d49a6b64cf24719c080b3c2029a3a5b16417fd5fd7c4041f94233af732f3"}, - {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a9137cf7e1640dce4c190551ee69d478f7121b5c6f323553b319cac936395f6"}, - {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:7c1bca1897c28013db5e4c83944a2ab53231f541b9e0c3f4791206d0c0de389a"}, - {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:58d9397b2e273ef76264b45531e9d552d8ec8a6688b7390b5be44c02a37aade8"}, - {file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b9766a47a9cb56fefe95cb27f535038b5a195707a08bf61b180e642324963b46"}, - {file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:6c5941c1aac89a6c2f2b16cd64fe76bcdb94b2b1e99ca6459de4e6f07638d755"}, - {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd44d66093a239358d07c42a91eebf5015aa54fccba959db899f932218ac9cc8"}, - {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:335d170e050bcc7da867a1ed8ffb8b44c57aaa6e0843b156a501298657b1e972"}, - {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:1c7556165bd38cf0cd992df2636f8bcdd2d4b26916c6b7e646101aff3c16f76f"}, - {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:699799f9c80bebcf1da0983ba86d7f289c5a2a5c04b945e2f2bcf7e874a91911"}, - {file = "pyarrow-19.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8464c9fbe6d94a7fe1599e7e8965f350fd233532868232ab2596a71586c5a429"}, - {file = "pyarrow-19.0.1.tar.gz", hash = "sha256:3bf266b485df66a400f282ac0b6d1b500b9d2ae73314a153dbe97d6d5cc8a99e"}, -] - -[package.extras] -test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] - [[package]] name = "pyarrow" version = "22.0.0" @@ -1429,7 +1373,7 @@ description = "Python library for Apache Arrow" optional = true python-versions = ">=3.10" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"pyarrow\" and (python_version < \"3.13\" or python_version >= \"3.14\")" +markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-22.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:77718810bd3066158db1e95a63c160ad7ce08c6b0710bc656055033e39cdad88"}, {file = "pyarrow-22.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:44d2d26cda26d18f7af7db71453b7b783788322d756e81730acb98f24eb90ace"}, @@ -2025,4 +1969,4 @@ pyarrow = ["pyarrow", "pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "01373096b340b4e384eb7b7dcc15c41f93c1fb3197145937f3eb07ea2719533e" +content-hash = "ec311bf26ec866de2f427bcdf4ec69ceed721bfd70edfae3aba1ac12882a09d6" diff --git a/pyproject.toml b/pyproject.toml index 596260674..39533d711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ include = ["CHANGELOG.md"] python = "^3.8.0" thrift = ">=0.16.0,<0.21.0" pandas = [ - { version = ">=1.2.5,<2.3.0", python = ">=3.8,<3.13" }, + { version = ">=1.2.5,<2.4.0", python = ">=3.8,<3.13" }, { version = ">=2.2.3,<2.4.0", python = ">=3.13" } ] lz4 = [ @@ -26,7 +26,7 @@ urllib3 = ">=1.26" python-dateutil = "^2.8.0" pyarrow = [ { version = ">=14.0.1", python = ">=3.8,<3.13", optional=true }, - { version = ">=18.0.0,<22.0.0", python = ">=3.13,<3.14", optional=true }, + { version = ">=18.0.0", python = ">=3.13,<3.14", optional=true }, { version = ">=22.0.0", python = ">=3.14", optional=true } ] pyjwt = "^2.0.0" From ce55e7ba1d19190d38a699232345a92bd7c7ee34 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Fri, 19 Dec 2025 01:07:41 +0530 Subject: [PATCH 15/39] New minor release (#722) --- CHANGELOG.md | 4 ++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6be2dacaa..99e6ce839 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 4.2.3 (2025-12-18) +- added pandas < 2.4.0 support and tests for py 3.14 (databricks/databricks-sql-python#720 by @sreekanth-db) +- pandas 2.3.3 support for py < 3.14 (databricks/databricks-sql-python#721 by @sreekanth-db) + # 4.2.2 (2025-12-01) - Change default use_hybrid_disposition to False (databricks/databricks-sql-python#714 by @samikshya-db) - Circuit breaker changes using pybreaker (databricks/databricks-sql-python#705 by @nikhilsuri-db) diff --git a/pyproject.toml b/pyproject.toml index 39533d711..87312530b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.2" +version = "4.2.3" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 7cf631e83..49e1f9ee0 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.2" +__version__ = "4.2.3" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 9b4e5770ab21ad54f86e8c7dd5be62f82a5cf8a3 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 2 Jan 2026 01:18:14 +0530 Subject: [PATCH 16/39] Fixed the exception handler close() on _TelemetryClientHolder (#723) Fixed the exception handler calls close() on _TelemetryClientHolder objects instead of accessing the client inside them. --- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 77d1a2f9c..9a38776b1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -542,8 +542,8 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): logger.debug("Handling unhandled exception: %s", exc_type.__name__) clients_to_close = list(cls._clients.values()) - for client in clients_to_close: - client.close() + for holder in clients_to_close: + holder.client.close() # Call the original exception handler to maintain normal behavior if cls._original_excepthook: From 946a265a3d9a1a86690a795b9d4e388e36be3bff Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Mon, 5 Jan 2026 13:17:03 +0530 Subject: [PATCH 17/39] created util method to normalise http protocol in http path (#724) * created util method to normalise http protocol in http path Signed-off-by: Nikhil Suri * Added impacted files using util method Signed-off-by: Nikhil Suri * Fixed linting issues Signed-off-by: Nikhil Suri * fixed broken test with mock host string Signed-off-by: Nikhil Suri * mocked http client Signed-off-by: Nikhil Suri * made case sensitive check in url utils Signed-off-by: Nikhil Suri * linting issue resolved Signed-off-by: Nikhil Suri * removed unnecessary md files Signed-off-by: Nikhil Suri * made test readbale Signed-off-by: Nikhil Suri * changes done in auth util as well as sea http Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- src/databricks/sql/auth/auth_utils.py | 17 ---- src/databricks/sql/auth/token_federation.py | 6 +- .../sql/backend/sea/utils/http_client.py | 6 +- src/databricks/sql/common/feature_flag.py | 4 +- src/databricks/sql/common/url_utils.py | 45 +++++++++++ .../sql/telemetry/telemetry_client.py | 3 +- tests/e2e/test_circuit_breaker.py | 79 ++++++++++++++----- tests/unit/test_client.py | 36 ++++++++- tests/unit/test_sea_http_client.py | 34 ++++++++ tests/unit/test_token_federation.py | 24 +++--- tests/unit/test_url_utils.py | 41 ++++++++++ 11 files changed, 234 insertions(+), 61 deletions(-) create mode 100644 src/databricks/sql/common/url_utils.py create mode 100644 tests/unit/test_url_utils.py diff --git a/src/databricks/sql/auth/auth_utils.py b/src/databricks/sql/auth/auth_utils.py index 439aabc51..a21ce843b 100644 --- a/src/databricks/sql/auth/auth_utils.py +++ b/src/databricks/sql/auth/auth_utils.py @@ -7,23 +7,6 @@ logger = logging.getLogger(__name__) -def parse_hostname(hostname: str) -> str: - """ - Normalize the hostname to include scheme and trailing slash. - - Args: - hostname: The hostname to normalize - - Returns: - Normalized hostname with scheme and trailing slash - """ - if not hostname.startswith("http://") and not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" - return hostname - - def decode_token(access_token: str) -> Optional[Dict]: """ Decode a JWT token without verification to extract claims. diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7b62f6762..f75b904fb 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -6,10 +6,10 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.auth_utils import ( - parse_hostname, decode_token, is_same_host, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -99,7 +99,7 @@ def __init__( if not http_client: raise ValueError("http_client is required for TokenFederationProvider") - self.hostname = parse_hostname(hostname) + self.hostname = normalize_host_with_protocol(hostname) self.external_provider = external_provider self.http_client = http_client self.identity_federation_client_id = identity_federation_client_id @@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool: def _exchange_token(self, access_token: str) -> Token: """Exchange the external token for a Databricks token.""" - token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" + token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}" data = { "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE, diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b47f2add2..caefe9929 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -18,6 +18,7 @@ from databricks.sql.common.http_utils import ( detect_and_parse_proxy, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol logger = logging.getLogger(__name__) @@ -66,8 +67,9 @@ def __init__( self.auth_provider = auth_provider self.ssl_options = ssl_options - # Build base URL - self.base_url = f"https://{server_hostname}:{self.port}" + # Build base URL using url_utils for consistent normalization + normalized_host = normalize_host_with_protocol(server_hostname) + self.base_url = f"{normalized_host}:{self.port}" # Parse URL for proxy handling parsed_url = urllib.parse.urlparse(self.base_url) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 032701f63..36e4b8a02 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -6,6 +6,7 @@ from typing import Dict, Optional, List, Any, TYPE_CHECKING from databricks.sql.common.http import HttpMethod +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -67,7 +68,8 @@ def __init__( endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__) self._feature_flag_endpoint = ( - f"https://{self._connection.session.host}{endpoint_suffix}" + normalize_host_with_protocol(self._connection.session.host) + + endpoint_suffix ) # Use the provided HTTP client diff --git a/src/databricks/sql/common/url_utils.py b/src/databricks/sql/common/url_utils.py new file mode 100644 index 000000000..1c4d10369 --- /dev/null +++ b/src/databricks/sql/common/url_utils.py @@ -0,0 +1,45 @@ +""" +URL utility functions for the Databricks SQL connector. +""" + + +def normalize_host_with_protocol(host: str) -> str: + """ + Normalize a connection hostname by ensuring it has a protocol. + + This is useful for handling cases where users may provide hostnames with or without protocols + (common with dbt-databricks users copying URLs from their browser). + + Args: + host: Connection hostname which may or may not include a protocol prefix (https:// or http://) + and may or may not have a trailing slash + + Returns: + Normalized hostname with protocol prefix and no trailing slashes + + Examples: + normalize_host_with_protocol("myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com" + normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080" + + Raises: + ValueError: If host is None or empty string + """ + # Handle None or empty host + if not host or not host.strip(): + raise ValueError("Host cannot be None or empty") + + # Remove trailing slashes + host = host.rstrip("/") + + # Add protocol if not present (case-insensitive check) + host_lower = host.lower() + if not host_lower.startswith("https://") and not host_lower.startswith("http://"): + host = f"https://{host}" + elif host_lower.startswith("https://") or host_lower.startswith("http://"): + # Normalize protocol to lowercase + protocol_end = host.index("://") + 3 + host = host[:protocol_end].lower() + host[protocol_end:] + + return host diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9a38776b1..523fcc1dc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -47,6 +47,7 @@ TelemetryPushClient, CircuitBreakerTelemetryPushClient, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -278,7 +279,7 @@ def _send_telemetry(self, events): if self._auth_provider else self.TELEMETRY_UNAUTHENTICATED_PATH ) - url = f"https://{self._host_url}{path}" + url = normalize_host_with_protocol(self._host_url) + path headers = {"Accept": "application/json", "Content-Type": "application/json"} diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py index 45c494d19..f9df0f377 100644 --- a/tests/e2e/test_circuit_breaker.py +++ b/tests/e2e/test_circuit_breaker.py @@ -23,6 +23,34 @@ from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager +def wait_for_circuit_state(circuit_breaker, expected_states, timeout=5): + """ + Wait for circuit breaker to reach one of the expected states with polling. + + Args: + circuit_breaker: The circuit breaker instance to monitor + expected_states: List of acceptable states + (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN) + timeout: Maximum time to wait in seconds + + Returns: + True if state reached, False if timeout + + Examples: + # Single state - pass list with one element + wait_for_circuit_state(cb, [STATE_OPEN]) + + # Multiple states + wait_for_circuit_state(cb, [STATE_CLOSED, STATE_HALF_OPEN]) + """ + start = time.time() + while time.time() - start < timeout: + if circuit_breaker.current_state in expected_states: + return True + time.sleep(0.1) # Poll every 100ms + return False + + @pytest.fixture(autouse=True) def aggressive_circuit_breaker_config(): """ @@ -65,12 +93,17 @@ def create_mock_response(self, status_code): }.get(status_code, b"Response") return response - @pytest.mark.parametrize("status_code,should_trigger", [ - (429, True), - (503, True), - (500, False), - ]) - def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + @pytest.mark.parametrize( + "status_code,should_trigger", + [ + (429, True), + (503, True), + (500, False), + ], + ) + def test_circuit_breaker_triggers_for_rate_limit_codes( + self, status_code, should_trigger + ): """ Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). """ @@ -107,9 +140,14 @@ def mock_request(*args, **kwargs): time.sleep(0.5) if should_trigger: - # Circuit should be OPEN after 2 rate-limit failures + # Wait for circuit to open (async telemetry may take time) + assert wait_for_circuit_state( + circuit_breaker, [STATE_OPEN], timeout=5 + ), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}" + + # Circuit should be OPEN after rate-limit failures assert circuit_breaker.current_state == STATE_OPEN - assert circuit_breaker.fail_counter == 2 + assert circuit_breaker.fail_counter >= 2 # At least 2 failures # Track requests before another query requests_before = request_count["count"] @@ -197,7 +235,10 @@ def mock_conditional_request(*args, **kwargs): cursor.fetchone() time.sleep(2) - assert circuit_breaker.current_state == STATE_OPEN + # Wait for circuit to open + assert wait_for_circuit_state( + circuit_breaker, [STATE_OPEN], timeout=5 + ), f"Circuit didn't open, state: {circuit_breaker.current_state}" # Wait for reset timeout (5 seconds in test) time.sleep(6) @@ -208,24 +249,20 @@ def mock_conditional_request(*args, **kwargs): # Execute query to trigger HALF_OPEN state cursor.execute("SELECT 3") cursor.fetchone() - time.sleep(1) - # Circuit should be recovering - assert circuit_breaker.current_state in [ - STATE_HALF_OPEN, - STATE_CLOSED, - ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + # Wait for circuit to start recovering + assert wait_for_circuit_state( + circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5 + ), f"Circuit didn't recover, state: {circuit_breaker.current_state}" # Execute more queries to fully recover cursor.execute("SELECT 4") cursor.fetchone() - time.sleep(1) - current_state = circuit_breaker.current_state - assert current_state in [ - STATE_CLOSED, - STATE_HALF_OPEN, - ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + # Wait for full recovery + assert wait_for_circuit_state( + circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5 + ), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}" if __name__ == "__main__": diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 8f8a97eae..5b6991931 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase): "access_token": "tok", } + def _setup_mock_session_with_http_client(self, mock_session): + """ + Helper to configure a mock session with HTTP client mocks. + This prevents feature flag network requests during Connection initialization. + """ + mock_session.host = "foo" + + # Mock HTTP client to prevent feature flag network requests + mock_http_client = Mock() + mock_session.http_client = mock_http_client + + # Mock feature flag response to prevent blocking HTTP calls + mock_ff_response = Mock() + mock_ff_response.status = 200 + mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}' + mock_http_client.request.return_value = mock_ff_response + def _create_mock_connection(self, mock_session_class): """Helper to create a mocked connection for transaction tests.""" - # Mock session mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" mock_session.get_autocommit.return_value = True + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=False to test actual transaction functionality @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): conn = self._create_mock_connection(mock_session_class) mock_cursor = Mock() - original_error = DatabaseError( - "Original error", host_url="test-host" - ) + original_error = DatabaseError("Original error", host_url="test-host") mock_cursor.execute.side_effect = original_error with patch.object(conn, "cursor", return_value=mock_cursor): @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) diff --git a/tests/unit/test_sea_http_client.py b/tests/unit/test_sea_http_client.py index 39ecb58a7..5100a5cb0 100644 --- a/tests/unit/test_sea_http_client.py +++ b/tests/unit/test_sea_http_client.py @@ -44,6 +44,40 @@ def sea_http_client(self, mock_auth_provider, ssl_options): client._pool = Mock() return client + @pytest.mark.parametrize( + "server_hostname,port,expected_base_url", + [ + # Basic hostname without protocol + ("myserver.com", 443, "https://myserver.com:443"), + # Hostname with trailing slash + ("myserver.com/", 443, "https://myserver.com:443"), + # Hostname with https:// protocol + ("https://myserver.com", 443, "https://myserver.com:443"), + # Hostname with http:// protocol (preserved as-is) + ("http://myserver.com", 443, "http://myserver.com:443"), + # Hostname with protocol and trailing slash + ("https://myserver.com/", 443, "https://myserver.com:443"), + # Custom port + ("myserver.com", 8080, "https://myserver.com:8080"), + # Protocol with custom port + ("https://myserver.com", 8080, "https://myserver.com:8080"), + ], + ) + def test_base_url_construction( + self, server_hostname, port, expected_base_url, mock_auth_provider, ssl_options + ): + """Test that base_url is constructed correctly from various hostname inputs.""" + with patch("databricks.sql.backend.sea.utils.http_client.HTTPSConnectionPool"): + client = SeaHttpClient( + server_hostname=server_hostname, + port=port, + http_path="/sql/1.0/warehouses/test", + http_headers=[], + auth_provider=mock_auth_provider, + ssl_options=ssl_options, + ) + assert client.base_url == expected_base_url + def test_get_command_type_from_path(self, sea_http_client): """Test the _get_command_type_from_path method with various paths and methods.""" # Test statement execution diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 2e671c33e..9c209e894 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -6,10 +6,10 @@ from databricks.sql.auth.token_federation import TokenFederationProvider, Token from databricks.sql.auth.auth_utils import ( - parse_hostname, decode_token, is_same_host, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol from databricks.sql.common.http import HttpMethod @@ -78,10 +78,10 @@ def test_init_requires_http_client(self, mock_external_provider): @pytest.mark.parametrize( "input_hostname,expected", [ - ("test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com/", "https://test.databricks.com/"), - ("test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com/", "https://test.databricks.com"), + ("test.databricks.com/", "https://test.databricks.com"), ], ) def test_hostname_normalization( @@ -305,15 +305,15 @@ class TestUtilityFunctions: @pytest.mark.parametrize( "input_hostname,expected", [ - ("test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com/", "https://test.databricks.com/"), - ("test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com/", "https://test.databricks.com"), + ("test.databricks.com/", "https://test.databricks.com"), ], ) - def test_parse_hostname(self, input_hostname, expected): - """Test hostname parsing.""" - assert parse_hostname(input_hostname) == expected + def test_normalize_hostname(self, input_hostname, expected): + """Test hostname normalization.""" + assert normalize_host_with_protocol(input_hostname) == expected @pytest.mark.parametrize( "url1,url2,expected", diff --git a/tests/unit/test_url_utils.py b/tests/unit/test_url_utils.py new file mode 100644 index 000000000..95f42408d --- /dev/null +++ b/tests/unit/test_url_utils.py @@ -0,0 +1,41 @@ +"""Tests for URL utility functions.""" +import pytest +from databricks.sql.common.url_utils import normalize_host_with_protocol + + +class TestNormalizeHostWithProtocol: + """Tests for normalize_host_with_protocol function.""" + + @pytest.mark.parametrize( + "input_url,expected_output", + [ + ("myserver.com", "https://myserver.com"), # Add https:// + ("https://myserver.com", "https://myserver.com"), # No duplicate + ("http://localhost:8080", "http://localhost:8080"), # Preserve http:// + ("myserver.com:443", "https://myserver.com:443"), # With port + ("myserver.com/", "https://myserver.com"), # Remove trailing slash + ("https://myserver.com///", "https://myserver.com"), # Multiple slashes + ("HTTPS://MyServer.COM", "https://MyServer.COM"), # Case handling + ], + ) + def test_normalize_host_with_protocol(self, input_url, expected_output): + """Test host normalization with various input formats.""" + result = normalize_host_with_protocol(input_url) + assert result == expected_output + + # Additional assertions + assert result.startswith("https://") or result.startswith("http://") + assert not result.endswith("/") + + @pytest.mark.parametrize( + "invalid_host", + [ + None, + "", + " ", # Whitespace only + ], + ) + def test_normalize_host_with_protocol_raises_on_invalid_input(self, invalid_host): + """Test that function raises ValueError for None or empty host.""" + with pytest.raises(ValueError, match="Host cannot be None or empty"): + normalize_host_with_protocol(invalid_host) From 03eb369882cfda15f98c653ec6c9654cee3220c9 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 8 Jan 2026 15:20:03 +0530 Subject: [PATCH 18/39] New minor version release 4.2.4 (#725) New minor version release --- CHANGELOG.md | 4 ++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99e6ce839..41105b072 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 4.2.4 (2026-01-07) +- Fixed the exception handler close() on _TelemetryClientHolder (databricks/databricks-sql-python#723 by @msrathore-db) +- Created util method to normalise http protocol in http path (databricks/databricks-sql-python#724 by @nikhilsuri-db) + # 4.2.3 (2025-12-18) - added pandas < 2.4.0 support and tests for py 3.14 (databricks/databricks-sql-python#720 by @sreekanth-db) - pandas 2.3.3 support for py < 3.14 (databricks/databricks-sql-python#721 by @sreekanth-db) diff --git a/pyproject.toml b/pyproject.toml index 87312530b..8a635588c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.3" +version = "4.2.4" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 49e1f9ee0..41784bc45 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.3" +__version__ = "4.2.4" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 4b7df5b0fd4da7e9caecbd8042c12e363c6d3d5f Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 10 Jan 2026 01:39:01 +0530 Subject: [PATCH 19/39] [PECOBLR-1168] query tags telemetry (#716) * query tags telemetry Signed-off-by: Sreekanth Vadigi * code linting fix Signed-off-by: Sreekanth Vadigi --------- Signed-off-by: Sreekanth Vadigi --- src/databricks/sql/client.py | 2 ++ src/databricks/sql/telemetry/models/event.py | 2 ++ src/databricks/sql/utils.py | 15 +++++++++++++++ tests/unit/test_telemetry.py | 2 ++ 4 files changed, 21 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1f17d54f2..a0215aae5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -35,6 +35,7 @@ ColumnTable, ColumnQueue, build_client_context, + get_session_config_value, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -386,6 +387,7 @@ def read(self) -> Optional[OAuthToken]: support_many_parameters=True, # Native parameters supported enable_complex_datatype_support=_use_arrow_native_complex_types, allowed_volume_ingestion_paths=self.staging_allowed_local_path, + query_tags=get_session_config_value(session_configuration, "query_tags"), ) self._telemetry_client.export_initial_telemetry_log( diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 2e6f63a6f..4d5a45038 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -57,6 +57,7 @@ class DriverConnectionParameters(JsonSerializableMixin): support_many_parameters (bool): Whether many parameters are supported enable_complex_datatype_support (bool): Whether complex datatypes are supported allowed_volume_ingestion_paths (str): Allowed paths for volume ingestion + query_tags (str): Query tags for tracking and attribution """ http_path: str @@ -84,6 +85,7 @@ class DriverConnectionParameters(JsonSerializableMixin): support_many_parameters: Optional[bool] = None enable_complex_datatype_support: Optional[bool] = None allowed_volume_ingestion_paths: Optional[str] = None + query_tags: Optional[str] = None @dataclass diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index b46784b10..043183ac2 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -38,6 +38,21 @@ logger = logging.getLogger(__name__) +def get_session_config_value( + session_configuration: Optional[Dict[str, Any]], key: str +) -> Optional[str]: + """Get a session configuration value with case-insensitive key matching""" + if not session_configuration: + return None + + key_upper = key.upper() + for k, v in session_configuration.items(): + if k.upper() == key_upper: + return str(v) if v is not None else None + + return None + + class ResultSetQueue(ABC): @abstractmethod def next_n_rows(self, num_rows: int): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index e9fa16649..86f06aa8a 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -530,6 +530,7 @@ def test_driver_connection_parameters_all_fields(self): support_many_parameters=True, enable_complex_datatype_support=True, allowed_volume_ingestion_paths="/Volumes/catalog/schema/volume", + query_tags="team:engineering,project:telemetry", ) # Serialize to JSON and parse back @@ -562,6 +563,7 @@ def test_driver_connection_parameters_all_fields(self): assert json_dict["support_many_parameters"] is True assert json_dict["enable_complex_datatype_support"] is True assert json_dict["allowed_volume_ingestion_paths"] == "/Volumes/catalog/schema/volume" + assert json_dict["query_tags"] == "team:engineering,project:telemetry" def test_driver_connection_parameters_minimal_fields(self): """Test DriverConnectionParameters with only required fields.""" From cafed60ffc8cc7eff94ceee63e7a8b47eee522ea Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 5 Feb 2026 10:47:11 +0530 Subject: [PATCH 20/39] [ES-1717039] Fix 60 seconds delay in gov cloud connections + Fix PR check failures in the repo (#735) * Fix 60 seconds delay in gov cloud connections * keep it simple :) * Add fix for krb error * pin poetry * Pin for publish flow too * Fix failing tests * Edit order for pypi * One last fix : pls work --- .github/workflows/code-coverage.yml | 5 +++ .github/workflows/code-quality-checks.yml | 8 +++++ .github/workflows/daily-telemetry-e2e.yml | 5 +++ .github/workflows/integration.yml | 5 +++ .github/workflows/publish-manual.yml | 9 ++++++ .github/workflows/publish-test.yml | 23 ++++++++++---- .github/workflows/publish.yml | 18 ++++++++--- src/databricks/sql/auth/retry.py | 34 +++++---------------- tests/e2e/common/retry_test_mixins.py | 26 +++++++--------- tests/e2e/common/staging_ingestion_tests.py | 2 +- tests/e2e/common/uc_volume_tests.py | 2 +- tests/unit/test_retry.py | 12 ++++++++ 12 files changed, 94 insertions(+), 55 deletions(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index 3c76be728..9cb68dbc9 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -36,6 +36,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -58,6 +59,10 @@ jobs: #---------------------------------------------- # install your root project, if required #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 3c368abef..cc3952920 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -35,6 +35,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -118,6 +119,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -140,6 +142,10 @@ jobs: #---------------------------------------------- # install your root project, if required #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- @@ -191,6 +197,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -243,6 +250,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml index 3d61cf177..d60b7f5a9 100644 --- a/.github/workflows/daily-telemetry-e2e.yml +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -43,6 +43,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -60,6 +61,10 @@ jobs: #---------------------------------------------- # install dependencies if cache does not exist #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install dependencies run: poetry install --no-interaction --all-extras diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ad5369997..7fd6d98f1 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -33,6 +33,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -49,6 +50,10 @@ jobs: #---------------------------------------------- # install dependencies if cache does not exist #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install dependencies run: poetry install --no-interaction --all-extras #---------------------------------------------- diff --git a/.github/workflows/publish-manual.yml b/.github/workflows/publish-manual.yml index ecad71a29..2f2a7a4dd 100644 --- a/.github/workflows/publish-manual.yml +++ b/.github/workflows/publish-manual.yml @@ -31,10 +31,19 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 # Install Poetry, the Python package manager with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true + #---------------------------------------------- + # Step 3.5: Install Kerberos system dependencies + #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev + # #---------------------------------------------- # # Step 4: Load cached virtual environment (if available) # #---------------------------------------------- diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index 2e6359a78..ea4158934 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -21,6 +21,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -36,6 +37,10 @@ jobs: #---------------------------------------------- # install dependencies if cache does not exist #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --no-interaction --no-root @@ -54,11 +59,17 @@ jobs: - name: Update pyproject.toml run: poetry version ${{ steps.version.outputs.major-version }}.${{ steps.version.outputs.minor-version }}.dev$(date +%s) #---------------------------------------------- + # Build the package (before publish action) + #---------------------------------------------- + - name: Build package + run: poetry build + #---------------------------------------------- + # Configure test-pypi repository + #---------------------------------------------- + - name: Configure test-pypi repository + run: poetry config repositories.testpypi https://test.pypi.org/legacy/ + #---------------------------------------------- # Attempt push to test-pypi #---------------------------------------------- - - name: Build and publish to pypi - uses: JRubics/poetry-publish@v1.10 - with: - pypi_token: ${{ secrets.TEST_PYPI_TOKEN }} - repository_name: "testpypi" - repository_url: "https://test.pypi.org/legacy/" + - name: Publish to test-pypi + run: poetry publish --username __token__ --password ${{ secrets.TEST_PYPI_TOKEN }} --repository testpypi diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index dde6cc2dc..b101f421c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,6 +23,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -38,6 +39,10 @@ jobs: #---------------------------------------------- # install dependencies if cache does not exist #---------------------------------------------- + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --no-interaction --no-root @@ -56,9 +61,12 @@ jobs: - name: Update pyproject.toml run: poetry version ${{ steps.version.outputs.current-version }} #---------------------------------------------- - # Attempt push to test-pypi + # Build the package (before publish) #---------------------------------------------- - - name: Build and publish to pypi - uses: JRubics/poetry-publish@v1.10 - with: - pypi_token: ${{ secrets.PROD_PYPI_TOKEN }} + - name: Build package + run: poetry build + #---------------------------------------------- + # Publish to pypi + #---------------------------------------------- + - name: Publish to pypi + run: poetry publish --username __token__ --password ${{ secrets.PROD_PYPI_TOKEN }} diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 4281883da..b0c2f497d 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -373,6 +373,13 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if status_code == 403: return False, "403 codes are not retried" + # Request failed with 404. Don't retry for any command type. + if status_code == 404: + return ( + False, + "Received 404 - NOT_FOUND. The requested resource does not exist.", + ) + # Request failed and server said NotImplemented. This isn't recoverable. Don't retry. if status_code == 501: return False, "Received code 501 from server." @@ -381,33 +388,6 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if not self._is_method_retryable(method): return False, "Only POST requests are retried" - # Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry. - if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS: - return ( - False, - "GetOperationStatus received 404 code from Databricks. Operation was canceled.", - ) - - # Request failed with 404 because CloseSession returns 404 if you repeat the request. - if ( - status_code == 404 - and self.command_type == CommandType.CLOSE_SESSION - and len(self.history) > 0 - ): - raise SessionAlreadyClosedError( - "CloseSession received 404 code from Databricks. Session is already closed." - ) - - # Request failed with 404 because CloseOperation returns 404 if you repeat the request. - if ( - status_code == 404 - and self.command_type == CommandType.CLOSE_OPERATION - and len(self.history) > 0 - ): - raise CursorAlreadyClosedError( - "CloseOperation received 404 code from Databricks. Cursor is already closed." - ) - # Request failed, was an ExecuteStatement and the command may have reached the server if ( self.command_type == CommandType.EXECUTE_STATEMENT diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b2350bd98..80822ba47 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -278,7 +278,7 @@ def test_retry_max_count_not_exceeded(self, mock_send_telemetry, extra_params): THEN the connector issues six request (original plus five retries) before raising an exception """ - with mocked_server_response(status=404) as mock_obj: + with mocked_server_response(status=429, headers={"Retry-After": "0"}) as mock_obj: with pytest.raises(MaxRetryError) as cm: extra_params = {**extra_params, **self._retry_policy} with self.connection(extra_params=extra_params) as conn: @@ -467,22 +467,21 @@ def test_retry_safe_execute_statement_retry_condition(self, extra_params): ) def test_retry_abort_close_session_on_404(self, extra_params, caplog): """GIVEN the connector sends a CloseSession command - WHEN server sends a 404 (which is normally retried) - THEN nothing is retried because 404 means the session already closed + WHEN server sends a 404 (which is not retried since commit 41b28159) + THEN nothing is retried because 404 is globally non-retryable """ - # First response is a Bad Gateway -> Result is the command actually goes through - # Second response is a 404 because the session is no longer found + # With the idempotency-based retry refactor, 404 is now globally non-retryable + # regardless of command type. The close() method catches RequestError and proceeds. responses = [ - {"status": 502, "headers": {"Retry-After": "1"}, "redirect_location": None}, {"status": 404, "headers": {}, "redirect_location": None}, ] extra_params = {**extra_params, **self._retry_policy} with self.connection(extra_params=extra_params) as conn: with mock_sequential_server_responses(responses): + # Should not raise an exception, the error is caught internally conn.close() - assert "Session was closed by a prior request" in caplog.text @pytest.mark.parametrize( "extra_params", @@ -493,14 +492,13 @@ def test_retry_abort_close_session_on_404(self, extra_params, caplog): ) def test_retry_abort_close_operation_on_404(self, extra_params, caplog): """GIVEN the connector sends a CancelOperation command - WHEN server sends a 404 (which is normally retried) - THEN nothing is retried because 404 means the operation was already canceled + WHEN server sends a 404 (which is not retried since commit 41b28159) + THEN nothing is retried because 404 is globally non-retryable """ - # First response is a Bad Gateway -> Result is the command actually goes through - # Second response is a 404 because the session is no longer found + # With the idempotency-based retry refactor, 404 is now globally non-retryable + # regardless of command type. The close() method catches RequestError and proceeds. responses = [ - {"status": 502, "headers": {"Retry-After": "1"}, "redirect_location": None}, {"status": 404, "headers": {}, "redirect_location": None}, ] @@ -515,10 +513,8 @@ def test_retry_abort_close_operation_on_404(self, extra_params, caplog): # This call guarantees we have an open cursor at the server curs.execute("SELECT 1") with mock_sequential_server_responses(responses): + # Should not raise an exception, the error is caught internally curs.close() - assert ( - "Operation was canceled by a prior request" in caplog.text - ) @pytest.mark.parametrize( "extra_params", diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 73aa0a113..a88f55238 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -81,7 +81,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail with pytest.raises( - Error, match="too many 404 error responses" + Error, match="Staging operation over HTTP was unsuccessful: 404" ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 93e63bd28..5b4086f91 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -81,7 +81,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail with pytest.raises( - Error, match="too many 404 error responses" + Error, match="Staging operation over HTTP was unsuccessful: 404" ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 897a1d111..0d01d8675 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -83,3 +83,15 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy): retry_policy.sleep(HTTPResponse(status=503)) # Internally urllib3 calls the increment function generating a new instance for every retry retry_policy = retry_policy.increment() + + def test_404_does_not_retry_for_any_command_type(self, retry_policy): + """Test that 404 never retries for any CommandType""" + retry_policy._retry_start_time = time.time() + + # Test for each CommandType + for command_type in CommandType: + retry_policy.command_type = command_type + should_retry, msg = retry_policy.should_retry("POST", 404) + + assert should_retry is False, f"404 should not retry for {command_type}" + assert "404" in msg or "NOT_FOUND" in msg From 61f80298e301e62c51d87e7b0f2427827c60f540 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 6 Feb 2026 21:30:31 +0530 Subject: [PATCH 21/39] [PECOBLR-1735] Fix #729 and #731: Telemetry lifecycle management (#734) * Fix #729 and #731: Telemetry lifecycle management Signed-off-by: Madhavendra Rathore * Address review comments: revert timeout and telemetry_enabled changes Per reviewer feedback on PR #734: 1. Revert timeout from 30s back to 900s (line 299) - Reviewer noted that with wait=False, timeout is not critical - The async nature and wait=False handle the exit speed 2. Revert telemetry_enabled parameter back to True (line 734) - Reviewer noted this is redundant given the early return - If enable_telemetry=False, we return early (line 729) - Line 734 only executes when enable_telemetry=True - Therefore using the parameter here is unnecessary These changes address the reviewer's valid technical concerns while keeping the core fixes intact: - wait=False for non-blocking shutdown (critical for Issue #729) - Early return when enable_telemetry=False (critical for Issue #729) - All Issue #731 fixes (null-safety, __del__, documentation) Signed-off-by: Madhavendra Rathore * Fix Black formatting violations Apply Black formatting to files modified in previous commits: - src/databricks/sql/common/unified_http_client.py - src/databricks/sql/telemetry/telemetry_client.py Changes are purely cosmetic (quote style consistency). Signed-off-by: Madhavendra Rathore * Fix CI test failure: Prevent parallel execution of telemetry tests Add @pytest.mark.xdist_group to telemetry test classes to ensure they run sequentially on the same worker when using pytest-xdist (-n auto). Root cause: Tests marked @pytest.mark.serial were still being parallelized in CI because pytest-xdist doesn't respect custom markers by default. With host-level telemetry batching (PR #718), tests running in parallel would share the same TelemetryClient and interfere with each other's event counting, causing test_concurrent_queries_sends_telemetry to see 88 events instead of the expected 60. The xdist_group marker ensures all tests in the "serial_telemetry" group run on the same worker sequentially, preventing state interference. Signed-off-by: Claude Sonnet 4.5 * Fix telemetry test fixtures: Clean up state before AND after tests Modified telemetry_setup_teardown fixtures to clean up TelemetryClientFactory state both BEFORE and AFTER each test, not just after. This prevents leftover state from previous tests (pending events, active executors) from interfering with the current test. Root cause: In CI with sequential execution on the same worker, if a previous test left pending telemetry events in the executor, those events could be captured by the next test's mock, causing inflated event counts (88 instead of 60). Now ensures complete isolation between tests by resetting all shared state before each test starts. Signed-off-by: Claude Sonnet 4.5 * Fix CI test failure: Clear _flush_event between tests The _flush_event threading.Event was never cleared after stopping the flush thread, remaining in "set" state. This caused timing issues in subsequent tests where the Event was already signaled, triggering unexpected flush behavior and causing extra telemetry events to be captured (88 instead of 60). Now explicitly clear the _flush_event flag in both setup (before test) and teardown (after test) to ensure clean state isolation between tests. This explains why CI consistently got 88 events - the flush_event from previous tests triggered additional flushes during test execution. Signed-off-by: Claude Sonnet 4.5 * Add debug workflow and output to diagnose CI test failure 1. Created new workflow 'test-telemetry-only.yml' that runs only the failing telemetry test with -n auto, mimicking real CI but much faster 2. Added debug output to test showing: - Client-side captured events - Number of futures/batches - Number of server responses - Server-reported successful events This will help identify why CI gets 88 events vs local 60 events. Signed-off-by: Claude Sonnet 4.5 * Fix workflow: Add krb5 system dependency The workflow was failing during poetry install due to missing krb5 system libraries needed for kerberos dependencies. Signed-off-by: Claude Sonnet 4.5 * Fix xdist_group: Add --dist=loadgroup to pytest commands The @pytest.mark.xdist_group markers were being ignored because pytest-xdist uses --dist=load by default, which doesn't respect groups. With --dist=loadgroup, tests in the same xdist_group run sequentially on the same worker, preventing telemetry state interference between tests. This is the ROOT CAUSE of the 88 vs 60 events issue - tests were running in parallel across workers instead of sequentially on one worker as intended. Signed-off-by: Claude Sonnet 4.5 * Add aggressive flush before test to prevent event interference CI shows 72 events instead of 60. Debug output reveals: - Client captured: 60 events (correct) - Server received: 72 events across 2 batches The 12 extra events accumulate in the timing window between fixture cleanup and mock setup. Other tests (like circuit breaker tests not in our xdist_group) may be sending telemetry concurrently. Solution: Add an explicit flush+shutdown RIGHT BEFORE setting up the mock to ensure a completely clean slate with zero buffered events. Signed-off-by: Claude Sonnet 4.5 * Split workflow: Isolate telemetry tests in separate job To prevent interference from other e2e tests, split into two jobs: Job 1 (run-non-telemetry-tests): - Runs all e2e tests EXCEPT telemetry tests - Uses -n auto for parallel execution Job 2 (run-telemetry-tests): - Runs ONLY telemetry tests - Depends on Job 1 completing (needs: run-non-telemetry-tests) - Fresh Python process = complete isolation - No ambient telemetry from other tests This eliminates the 68 vs 60 event discrepancy by ensuring telemetry tests run in a clean environment with zero interference. Signed-off-by: Claude Sonnet 4.5 * Fix workflows: Add krb5 deps and cleanup debug code Changes across multiple workflows: 1. integration.yml: - Add krb5 system dependency to telemetry job - Fixes: krb5-config command not found error during poetry install 2. code-coverage.yml: - Add krb5 system dependency - Split telemetry tests into separate step for isolation - Maintains coverage accumulation with --cov-append 3. publish-test.yml: - Add krb5 system dependency for consistent builds 4. test_concurrent_telemetry.py: - Remove debug print statements 5. Delete test-telemetry-only.yml: - Remove temporary debug workflow All workflows now have proper telemetry test isolation and required system dependencies for kerberos packages. Signed-off-by: Claude Sonnet 4.5 * Fix publish-test.yml: Update Python 3.9 -> 3.10 Poetry 2.3.2 installation fails with Python 3.9: Installing Poetry (2.3.2): An error occurred. Other workflows use Python 3.10 and work fine. Updating to match ensures consistency and avoids Poetry installation issues. Signed-off-by: Claude Sonnet 4.5 * Fix integration workflow: Remove --dist=loadgroup from non-telemetry tests - Remove --dist=loadgroup from non-telemetry job (only needed for telemetry) - Remove test_telemetry_e2e.py from telemetry job (was skipped before) - This should fix test_uc_volume_life_cycle failure caused by changed test distribution * Fix code-coverage workflow: Remove test_telemetry_e2e.py from coverage tests - Only run test_concurrent_telemetry.py in isolated telemetry step - test_telemetry_e2e.py was excluded in original workflow, keep it excluded * Fix publish-test workflow: Remove cache conditional - Always run poetry install (not just on cache miss) - Ensures fresh install with system dependencies (krb5) - Matches pattern used in integration.yml * Fix publish-test.yml: Remove duplicate krb5 install, restore cache conditional - Remove duplicate system dependencies step - Restore cache conditional to match main branch - Keep Python 3.10 (our change from 3.9) * Fix code-coverage: Remove serial tests step - All serial tests are telemetry tests (test_concurrent_telemetry.py and test_telemetry_e2e.py) - They're already run in the isolated telemetry step - Running -m serial with --ignore on both files results in 0 tests (exit code 5) --------- Signed-off-by: Madhavendra Rathore Signed-off-by: Claude Sonnet 4.5 --- .github/workflows/code-coverage.yml | 15 ++++-- .github/workflows/integration.yml | 52 +++++++++++++++++-- .github/workflows/publish-test.yml | 2 +- src/databricks/sql/client.py | 3 ++ .../sql/common/unified_http_client.py | 12 ++++- .../sql/telemetry/telemetry_client.py | 40 +++++++++++++- tests/e2e/test_concurrent_telemetry.py | 26 +++++++++- tests/e2e/test_telemetry_e2e.py | 28 ++++++++-- 8 files changed, 161 insertions(+), 17 deletions(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index 9cb68dbc9..5c961757e 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -31,6 +31,13 @@ jobs: with: python-version: "3.10" #---------------------------------------------- + # ----- install system dependencies ----- + #---------------------------------------------- + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev + #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry @@ -80,13 +87,13 @@ jobs: -v #---------------------------------------------- - # run serial tests with coverage + # run telemetry tests with coverage (isolated) #---------------------------------------------- - - name: Run serial tests with coverage + - name: Run telemetry tests with coverage (isolated) continue-on-error: false run: | - poetry run pytest tests/e2e \ - -m "serial" \ + # Run test_concurrent_telemetry.py separately for isolation + poetry run pytest tests/e2e/test_concurrent_telemetry.py \ --cov=src \ --cov-append \ --cov-report=xml \ diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 7fd6d98f1..c915ee6c1 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -7,7 +7,7 @@ on: pull_request: jobs: - run-e2e-tests: + run-non-telemetry-tests: runs-on: ubuntu-latest environment: azure-prod env: @@ -59,9 +59,53 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run e2e tests (excluding daily-only tests) + - name: Run non-telemetry e2e tests run: | - # Exclude telemetry E2E tests from PR runs (run daily instead) + # Exclude all telemetry tests - they run in separate job for isolation poetry run python -m pytest tests/e2e \ --ignore=tests/e2e/test_telemetry_e2e.py \ - -n auto \ No newline at end of file + --ignore=tests/e2e/test_concurrent_telemetry.py \ + -n auto + + run-telemetry-tests: + runs-on: ubuntu-latest + needs: run-non-telemetry-tests # Run after non-telemetry tests complete + environment: azure-prod + env: + DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} + DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} + DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} + DATABRICKS_CATALOG: peco + DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} + steps: + - name: Check out repository + uses: actions/checkout@v4 + - name: Set up python + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies + run: poetry install --no-interaction --all-extras + - name: Run telemetry tests in isolation + run: | + # Run test_concurrent_telemetry.py in isolation with complete process separation + # Use --dist=loadgroup to respect @pytest.mark.xdist_group markers + poetry run python -m pytest tests/e2e/test_concurrent_telemetry.py \ + -n auto --dist=loadgroup -v \ No newline at end of file diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index ea4158934..97a444e68 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -14,7 +14,7 @@ jobs: id: setup-python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: "3.10" #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a0215aae5..1a246b7c1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -306,6 +306,8 @@ def read(self) -> Optional[OAuthToken]: ) self.session.open() except Exception as e: + # Respect user's telemetry preference even during connection failure + enable_telemetry = kwargs.get("enable_telemetry", True) TelemetryClientFactory.connection_failure_log( error_name="Exception", error_message=str(e), @@ -316,6 +318,7 @@ def read(self) -> Optional[OAuthToken]: user_agent=self.session.useragent_header if hasattr(self, "session") else None, + enable_telemetry=enable_telemetry, ) raise e diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index d5f7d3c8d..ef55564c8 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -217,7 +217,7 @@ def _should_use_proxy(self, target_host: str) -> bool: logger.debug("Error checking proxy bypass for host %s: %s", target_host, e) return True - def _get_pool_manager_for_url(self, url: str) -> urllib3.PoolManager: + def _get_pool_manager_for_url(self, url: str) -> Optional[urllib3.PoolManager]: """ Get the appropriate pool manager for the given URL. @@ -225,7 +225,7 @@ def _get_pool_manager_for_url(self, url: str) -> urllib3.PoolManager: url: The target URL Returns: - PoolManager instance (either direct or proxy) + PoolManager instance (either direct or proxy), or None if client is closed """ parsed_url = urllib.parse.urlparse(url) target_host = parsed_url.hostname @@ -291,6 +291,14 @@ def request_context( # Select appropriate pool manager based on target URL pool_manager = self._get_pool_manager_for_url(url) + # DEFENSIVE: Check if pool_manager is None (client closing/closed) + # This prevents AttributeError race condition when telemetry cleanup happens + if pool_manager is None: + logger.debug( + "HTTP client closing or closed, cannot make request to %s", url + ) + raise RequestError("HTTP client is closing or has been closed") + response = None try: diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 523fcc1dc..408162400 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -42,6 +42,7 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError from databricks.sql.telemetry.telemetry_push_client import ( ITelemetryPushClient, TelemetryPushClient, @@ -417,10 +418,38 @@ def export_latency_log( ) def close(self): - """Flush remaining events before closing""" + """Flush remaining events before closing + + IMPORTANT: This method does NOT close self._http_client. + + Rationale: + - _flush() submits async work to the executor that uses _http_client + - If we closed _http_client here, async callbacks would fail with AttributeError + - Instead, we let _http_client live as long as needed: + * Pending futures hold references to self (via bound methods) + * This keeps self alive, which keeps self._http_client alive + * When all futures complete, Python GC will clean up naturally + - The __del__ method ensures eventual cleanup during garbage collection + + This design prevents race conditions while keeping telemetry truly async. + """ logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + def __del__(self): + """Cleanup when TelemetryClient is garbage collected + + This ensures _http_client is eventually closed when the TelemetryClient + object is destroyed. By this point, all async work should be complete + (since the futures held references keeping us alive), so it's safe to + close the http client. + """ + try: + if hasattr(self, "_http_client") and self._http_client: + self._http_client.close() + except Exception: + pass + class _TelemetryClientHolder: """ @@ -674,7 +703,8 @@ def close(host_url): ) try: TelemetryClientFactory._stop_flush_thread() - TelemetryClientFactory._executor.shutdown(wait=True) + # Use wait=False to allow process to exit immediately + TelemetryClientFactory._executor.shutdown(wait=False) except Exception as e: logger.debug("Failed to shutdown thread pool executor: %s", e) TelemetryClientFactory._executor = None @@ -689,9 +719,15 @@ def connection_failure_log( port: int, client_context, user_agent: Optional[str] = None, + enable_telemetry: bool = True, ): """Send error telemetry when connection creation fails, using provided client context""" + # Respect user's telemetry preference - don't force-enable + if not enable_telemetry: + logger.debug("Telemetry disabled, skipping connection failure log") + return + UNAUTH_DUMMY_SESSION_ID = "unauth_session_id" TelemetryClientFactory.initialize_telemetry_client( diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index bed348c2c..6a317cbfa 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -27,6 +27,7 @@ def run_in_threads(target, num_threads, pass_index=False): @pytest.mark.serial +@pytest.mark.xdist_group(name="serial_telemetry") class TestE2ETelemetry(PySQLPytestTestCase): @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): @@ -35,13 +36,27 @@ def telemetry_setup_teardown(self): before each test and shuts it down afterward. Using a fixture makes this robust and automatic. """ + # Clean up BEFORE test starts to ensure no leftover state from previous tests + # Use wait=True to ensure all pending telemetry from previous tests completes + # This prevents those events from being captured by this test's mock + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) # WAIT for pending telemetry + TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._flush_event.clear() # Clear the event flag + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._initialized = False + try: yield finally: + # Clean up AFTER test ends + # Use wait=True to ensure this test's telemetry completes before next test starts if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor.shutdown(wait=True) # WAIT for this test's telemetry TelemetryClientFactory._executor = None TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._flush_event.clear() # Clear the event flag TelemetryClientFactory._clients.clear() TelemetryClientFactory._initialized = False @@ -50,6 +65,14 @@ def test_concurrent_queries_sends_telemetry(self): An E2E test where concurrent threads execute real queries against the staging endpoint, while we capture and verify the generated telemetry. """ + # Extra flush right before test starts to clear any events that accumulated + # between fixture cleanup and now (e.g., from other tests on same worker) + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._initialized = False + num_threads = 30 capture_lock = threading.Lock() captured_telemetry = [] @@ -139,6 +162,7 @@ def execute_query_worker(thread_id): assert "errors" not in response or not response["errors"] if "numProtoSuccess" in response: total_successful_events += response["numProtoSuccess"] + assert total_successful_events == num_threads * 2 assert ( diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py index 0a57edd3c..83c2dbf81 100644 --- a/tests/e2e/test_telemetry_e2e.py +++ b/tests/e2e/test_telemetry_e2e.py @@ -44,23 +44,45 @@ def connection(self, extra_params=()): @pytest.mark.serial +@pytest.mark.xdist_group(name="serial_telemetry") class TestTelemetryE2E(TelemetryTestBase): """E2E tests for telemetry scenarios - must run serially due to shared host-level telemetry client""" @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): """Clean up telemetry client state before and after each test""" + # Clean up BEFORE test starts + # Use wait=True to ensure all pending telemetry from previous tests completes + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) # WAIT for pending telemetry + TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._flush_event.clear() # Clear the event flag + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._initialized = False + + # Clear feature flags cache before test starts + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + with FeatureFlagsContextFactory._lock: + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + try: yield finally: + # Clean up AFTER test ends + # Use wait=True to ensure this test's telemetry completes if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor.shutdown(wait=True) # WAIT for this test's telemetry TelemetryClientFactory._executor = None TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._flush_event.clear() # Clear the event flag + TelemetryClientFactory._clients.clear() TelemetryClientFactory._initialized = False - # Clear feature flags cache to prevent state leakage between tests - from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + # Clear feature flags cache after test ends with FeatureFlagsContextFactory._lock: FeatureFlagsContextFactory._context_map.clear() if FeatureFlagsContextFactory._executor: From 9fe7356a18d611ae18943c2f972160657b08eea2 Mon Sep 17 00:00:00 2001 From: jayant <167047871+jayantsing-db@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:52:45 +0530 Subject: [PATCH 22/39] Bump to version 4.2.5 (#737) Signed-off-by: Jayant Singh --- CHANGELOG.md | 4 ++++ pyproject.toml | 4 ++-- src/databricks/sql/__init__.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 41105b072..0ba3bb1a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 4.2.5 (2026-02-09) +- Fix feature-flag endpoint retries in gov region (databricks/databricks-sql-python#735 by @samikshya-db) +- Improve telemetry lifecycle management (databricks/databricks-sql-python#734 by @msrathore-db) + # 4.2.4 (2026-01-07) - Fixed the exception handler close() on _TelemetryClientHolder (databricks/databricks-sql-python#723 by @msrathore-db) - Created util method to normalise http protocol in http path (databricks/databricks-sql-python#724 by @nikhilsuri-db) diff --git a/pyproject.toml b/pyproject.toml index 8a635588c..911f1b79c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.4" +version = "4.2.5" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -92,4 +92,4 @@ show_missing = true skip_covered = false [tool.coverage.xml] -output = "coverage.xml" \ No newline at end of file +output = "coverage.xml" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 41784bc45..c9195b89f 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.4" +__version__ = "4.2.5" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 38097f26ea7c21c6a7e9672397ee0deb3843f0e3 Mon Sep 17 00:00:00 2001 From: Jiabin Hu Date: Mon, 2 Mar 2026 02:18:54 -0800 Subject: [PATCH 23/39] Add query_tags parameter support for execute methods (#736) * Add statement level query tag support by introducing it as a parameter on execute* methods Signed-off-by: Jiabin Hu * Add query_tags support to executemany method - Added query_tags parameter to executemany() method - Query tags are applied to all queries in the batch - Updated example to demonstrate executemany usage with query_tags - All tests pass (122/122 client tests) Signed-off-by: Jiabin Hu * add example that doesn't have tag Signed-off-by: Jiabin Hu * fix presubmit errors Signed-off-by: Jiabin Hu * another lint Signed-off-by: Jiabin Hu * address review comments Signed-off-by: Jiabin Hu --------- Signed-off-by: Jiabin Hu --- examples/query_tags_example.py | 97 ++++++++++++++++++- .../sql/backend/databricks_client.py | 2 + src/databricks/sql/backend/sea/backend.py | 3 + src/databricks/sql/backend/thrift_backend.py | 22 ++++- src/databricks/sql/client.py | 24 ++++- src/databricks/sql/utils.py | 43 ++++++++ tests/unit/test_util.py | 63 ++++++++++++ 7 files changed, 244 insertions(+), 10 deletions(-) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index f615d082c..687ce4140 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -7,11 +7,23 @@ Query Tags are key-value pairs that can be attached to SQL executions and will appear in the system.query.history table for analytical purposes. -Format: "key1:value1,key2:value2,key3:value3" +There are two ways to set query tags: +1. Session-level: Set in session_configuration (applies to all queries in the session) +2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query) + +Format: Dictionary with string keys and optional string values +Example: {"team": "engineering", "application": "etl", "priority": "high"} + +Special cases: +- If a value is None, only the key is included (no colon or value) +- Special characters (comma, colon and backslash) in values are automatically escaped +- Keys are not escaped (should be controlled identifiers) """ print("=== Query Tags Example ===\n") +# Example 1: Session-level query tags (old approach) +print("Example 1: Session-level query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path=os.getenv("DATABRICKS_HTTP_PATH"), @@ -21,10 +33,89 @@ 'ansi_mode': False } ) as connection: - + with connection.cursor() as cursor: cursor.execute("SELECT 1") result = cursor.fetchone() print(f" Result: {result[0]}") -print("\n=== Query Tags Example Complete ===") \ No newline at end of file +print() + +# Example 2: Per-query query tags (new approach) +print("Example 2: Per-query query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Query 1: Tags for a critical ETL job + cursor.execute( + "SELECT 1", + query_tags={"team": "data-eng", "application": "etl", "priority": "high"} + ) + result = cursor.fetchone() + print(f" ETL Query Result: {result[0]}") + + # Query 2: Tags with None value (key-only tag) + cursor.execute( + "SELECT 2", + query_tags={"team": "analytics", "experimental": None} + ) + result = cursor.fetchone() + print(f" Experimental Query Result: {result[0]}") + + # Query 3: Tags with special characters (automatically escaped) + cursor.execute( + "SELECT 3", + query_tags={"description": "test:with:colons,and,commas"} + ) + result = cursor.fetchone() + print(f" Special Chars Query Result: {result[0]}") + + # Query 4: No tags (demonstrates tags don't persist from previous queries) + cursor.execute("SELECT 4") + result = cursor.fetchone() + print(f" No Tags Query Result: {result[0]}") + +print() + +# Example 3: Async execution with query tags +print("Example 3: Async execution with query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + cursor.execute_async( + "SELECT 5", + query_tags={"team": "data-eng", "mode": "async"} + ) + cursor.get_async_execution_result() + result = cursor.fetchone() + print(f" Async Query Result: {result[0]}") + +print() + +# Example 4: executemany with query tags +print("Example 4: executemany with query tags") +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Execute multiple queries with the same tags + cursor.executemany( + "SELECT ?", + [[6], [7], [8]], + query_tags={"team": "data-eng", "batch": "executemany"} + ) + result = cursor.fetchone() + print(f" Executemany Query Result (last): {result[0]}") + +print("\n=== Query Tags Example Complete ===") diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 2213635fe..b772e7ddd 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -83,6 +83,7 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union[ResultSet, None]: """ Executes a SQL command or query within the specified session. @@ -102,6 +103,7 @@ def execute_command( async_op: Whether to execute the command asynchronously enforce_embedded_schema_correctness: Whether to enforce schema correctness row_limit: Maximum number of rows in the response. + query_tags: Optional dictionary of query tags to apply for this query only. Returns: If async_op is False, returns a ResultSet object containing the diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1427226d2..a6cff4913 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -463,6 +463,9 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, + query_tags: Optional[ + Dict[str, Optional[str]] + ] = None, # TODO: implement query_tags for SEA backend ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index edee02bfa..e23f3389b 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,7 +5,7 @@ import math import time import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING from uuid import UUID from databricks.sql.common.unified_http_client import UnifiedHttpClient @@ -53,6 +53,7 @@ convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, + serialize_query_tags, ) from databricks.sql.types import SSLOptions from databricks.sql.backend.databricks_client import DatabricksClient @@ -1003,6 +1004,7 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union["ResultSet", None]: thrift_handle = session_id.to_thrift_handle() if not thrift_handle: @@ -1022,6 +1024,19 @@ def execute_command( # DBR should be changed to use month_day_nano_interval intervalTypesAsArrow=False, ) + + # Build confOverlay with default configs and query_tags + merged_conf_overlay = { + # We want to receive proper Timestamp arrow types. + "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" + } + + # Serialize and add query_tags to confOverlay if provided + if query_tags: + serialized_tags = serialize_query_tags(query_tags) + if serialized_tags: + merged_conf_overlay["query_tags"] = serialized_tags + req = ttypes.TExecuteStatementReq( sessionHandle=thrift_handle, statement=operation, @@ -1036,10 +1051,7 @@ def execute_command( canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, - confOverlay={ - # We want to receive proper Timestamp arrow types. - "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" - }, + confOverlay=merged_conf_overlay, useArrowNativeTypes=spark_arrow_types, parameters=parameters, enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness, diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1a246b7c1..efaf6ae4d 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1263,6 +1263,7 @@ def execute( parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, input_stream: Optional[BinaryIO] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -1293,6 +1294,10 @@ def execute( Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo' being sent to the server + :param query_tags: Optional dictionary of query tags to apply for this query only. + Tags are key-value pairs that can be used to identify and categorize queries. + Example: {"team": "data-eng", "application": "etl"} + :returns self """ @@ -1333,6 +1338,7 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + query_tags=query_tags, ) if self.active_result_set and self.active_result_set.is_staging_operation: @@ -1349,6 +1355,7 @@ def execute_async( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> "Cursor": """ @@ -1356,6 +1363,9 @@ def execute_async( :param operation: :param parameters: + :param query_tags: Optional dictionary of query tags to apply for this query only. + Tags are key-value pairs that can be used to identify and categorize queries. + Example: {"team": "data-eng", "application": "etl"} :return: """ @@ -1392,6 +1402,7 @@ def execute_async( async_op=True, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, row_limit=self.row_limit, + query_tags=query_tags, ) return self @@ -1448,7 +1459,12 @@ def get_async_execution_result(self): session_id_hex=self.connection.get_session_id_hex(), ) - def executemany(self, operation, seq_of_parameters): + def executemany( + self, + operation, + seq_of_parameters, + query_tags: Optional[Dict[str, Optional[str]]] = None, + ): """ Execute the operation once for every set of passed in parameters. @@ -1457,10 +1473,14 @@ def executemany(self, operation, seq_of_parameters): Only the final result set is retained. + :param query_tags: Optional dictionary of query tags to apply for all queries in this batch. + Tags are key-value pairs that can be used to identify and categorize queries. + Example: {"team": "data-eng", "application": "etl"} + :returns self """ for parameters in seq_of_parameters: - self.execute(operation, parameters) + self.execute(operation, parameters, query_tags=query_tags) return self @log_latency(StatementType.METADATA) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 043183ac2..125edbbaa 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -898,6 +898,49 @@ def concat_table_chunks( return pyarrow.concat_tables(table_chunks) +def serialize_query_tags( + query_tags: Optional[Dict[str, Optional[str]]] +) -> Optional[str]: + """ + Serialize query_tags dictionary to a string format. + + Format: "key1:value1,key2:value2" + Special cases: + - If value is None, omit the colon and value (e.g., "key1:value1,key2,key3:value3") + - Escape special characters (:, ,, \\) in values with a leading backslash + - Backslashes in keys are escaped; other special characters in keys are not escaped + + Args: + query_tags: Dictionary of query tags where keys are strings and values are optional strings + + Returns: + Serialized string or None if query_tags is None or empty + """ + if not query_tags: + return None + + def escape_value(value: str) -> str: + """Escape special characters in tag values.""" + # Escape backslash first to avoid double-escaping + value = value.replace("\\", r"\\") + # Escape colon and comma + value = value.replace(":", r"\:") + value = value.replace(",", r"\,") + return value + + serialized_parts = [] + for key, value in query_tags.items(): + escaped_key = key.replace("\\", r"\\") + if value is None: + # No colon or value when value is None + serialized_parts.append(escaped_key) + else: + escaped_value = escape_value(value) + serialized_parts.append(f"{escaped_key}:{escaped_value}") + + return ",".join(serialized_parts) + + def build_client_context(server_hostname: str, version: str, **kwargs): """Build ClientContext for HTTP client configuration.""" from databricks.sql.auth.common import ClientContext diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 713342b2e..687bdd391 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -6,6 +6,7 @@ convert_to_assigned_datatypes_in_column_table, ColumnTable, concat_table_chunks, + serialize_query_tags, ) try: @@ -161,3 +162,65 @@ def test_concat_table_chunks__incorrect_column_names_error(self): with pytest.raises(ValueError): concat_table_chunks([column_table1, column_table2]) + + def test_serialize_query_tags_basic(self): + """Test basic query tags serialization""" + query_tags = {"team": "data-eng", "application": "etl"} + result = serialize_query_tags(query_tags) + assert result == "team:data-eng,application:etl" + + def test_serialize_query_tags_with_none_value(self): + """Test query tags with None value (should omit colon and value)""" + query_tags = {"key1": "value1", "key2": None, "key3": "value3"} + result = serialize_query_tags(query_tags) + assert result == "key1:value1,key2,key3:value3" + + def test_serialize_query_tags_with_special_chars(self): + """Test query tags with special characters (colon, comma, backslash)""" + query_tags = { + "key1": "value:with:colons", + "key2": "value,with,commas", + "key3": r"value\with\backslashes", + } + result = serialize_query_tags(query_tags) + assert ( + result + == r"key1:value\:with\:colons,key2:value\,with\,commas,key3:value\\with\\backslashes" + ) + + def test_serialize_query_tags_with_mixed_special_chars(self): + """Test query tags with mixed special characters""" + query_tags = {"key1": r"a:b,c\d"} + result = serialize_query_tags(query_tags) + assert result == r"key1:a\:b\,c\\d" + + def test_serialize_query_tags_empty_dict(self): + """Test serialization with empty dictionary""" + query_tags = {} + result = serialize_query_tags(query_tags) + assert result is None + + def test_serialize_query_tags_none(self): + """Test serialization with None input""" + result = serialize_query_tags(None) + assert result is None + + def test_serialize_query_tags_with_special_chars_in_key(self): + """Test query tags with special characters in keys (only backslashes are escaped in keys)""" + query_tags = { + "key:with:colons": "value1", + "key,with,commas": "value2", + r"key\with\backslashes": "value3", + } + result = serialize_query_tags(query_tags) + # Only backslashes are escaped in keys; colons and commas in keys are not escaped + assert ( + result + == r"key:with:colons:value1,key,with,commas:value2,key\\with\\backslashes:value3" + ) + + def test_serialize_query_tags_all_none_values(self): + """Test query tags where all values are None""" + query_tags = {"key1": None, "key2": None, "key3": None} + result = serialize_query_tags(query_tags) + assert result == "key1,key2,key3" From e916f716521ea16d4087d2d909b3e4f1a896f05f Mon Sep 17 00:00:00 2001 From: Jiabin Hu Date: Sun, 8 Mar 2026 23:20:30 -0700 Subject: [PATCH 24/39] [QI-3367] Allow specifiying query tags as a dict upon connection creation (#749) * Allow specifiying query tags as a dict upon connection creation Signed-off-by: Jiabin Hu * fix comment Signed-off-by: Jiabin Hu --------- Signed-off-by: Jiabin Hu --- examples/query_tags_example.py | 15 ++++++--------- src/databricks/sql/client.py | 11 +++++++++++ tests/unit/test_session.py | 21 +++++++++++++++++++++ 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index 687ce4140..977dc6ad5 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -8,7 +8,7 @@ in the system.query.history table for analytical purposes. There are two ways to set query tags: -1. Session-level: Set in session_configuration (applies to all queries in the session) +1. Connection-level: Pass query_tags parameter to sql.connect() (applies to all queries in the session) 2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query) Format: Dictionary with string keys and optional string values @@ -17,21 +17,18 @@ Special cases: - If a value is None, only the key is included (no colon or value) - Special characters (comma, colon and backslash) in values are automatically escaped -- Keys are not escaped (should be controlled identifiers) +- Backslashes in keys are automatically escaped; other special characters in keys are not allowed """ print("=== Query Tags Example ===\n") -# Example 1: Session-level query tags (old approach) -print("Example 1: Session-level query tags") +# Example 1: Connection-level query tags +print("Example 1: Connection-level query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), - session_configuration={ - 'QUERY_TAGS': 'team:engineering,test:query-tags', - 'ansi_mode': False - } + query_tags={"team": "engineering", "application": "etl"}, ) as connection: with connection.cursor() as cursor: @@ -41,7 +38,7 @@ print() -# Example 2: Per-query query tags (new approach) +# Example 2: Per-query query tags print("Example 2: Per-query query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index efaf6ae4d..2aeea175e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -36,6 +36,7 @@ ColumnQueue, build_client_context, get_session_config_value, + serialize_query_tags, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -106,6 +107,7 @@ def __init__( schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, ignore_transactions: bool = True, + query_tags: Optional[Dict[str, Optional[str]]] = None, **kwargs, ) -> None: """ @@ -281,6 +283,15 @@ def read(self) -> Optional[OAuthToken]: "spark.sql.thriftserver.metadata.metricview.enabled" ] = "true" + if query_tags is not None: + if session_configuration is None: + session_configuration = {} + serialized = serialize_query_tags(query_tags) + if serialized: + session_configuration["QUERY_TAGS"] = serialized + else: + session_configuration.pop("QUERY_TAGS", None) + self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1d70ec4c4..3a43c1a75 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -202,3 +202,24 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_call_args = instance.close_session.call_args[0][0] assert close_session_call_args.guid == b"\x22" assert close_session_call_args.secret == b"\x33" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_sets_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "data-eng", "project": "etl"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:data-eng,project:etl" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "new-team"}, + session_configuration={"QUERY_TAGS": "team:old-team,other:value"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team" From 12bfd5b7a86408da7ebbb61513dca72f90c50c4b Mon Sep 17 00:00:00 2001 From: Shubhambhusate <99666676+Shubhambhusate@users.noreply.github.com> Date: Tue, 10 Mar 2026 16:29:20 +0530 Subject: [PATCH 25/39] =?UTF-8?q?Fix=20float=20inference=20to=20use=20Doub?= =?UTF-8?q?leParameter=20(64-bit)=20instead=20of=20FloatP=E2=80=A6=20(#742?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix float inference to use DoubleParameter (64-bit) instead of FloatParameter (32-bit) Signed-off-by: Shubhambhusate * Add DoubleParameter with Primitive.DOUBLE to test_inference coverage --------- Signed-off-by: Shubhambhusate --- src/databricks/sql/parameters/native.py | 2 +- tests/unit/test_parameters.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/parameters/native.py b/src/databricks/sql/parameters/native.py index b7c448254..d0fb8d82c 100644 --- a/src/databricks/sql/parameters/native.py +++ b/src/databricks/sql/parameters/native.py @@ -659,7 +659,7 @@ def dbsql_parameter_from_primitive( elif isinstance(value, str): return StringParameter(value=value, name=name) elif isinstance(value, float): - return FloatParameter(value=value, name=name) + return DoubleParameter(value=value, name=name) elif isinstance(value, datetime.datetime): return TimestampParameter(value=value, name=name) elif isinstance(value, datetime.date): diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index cf2e24951..0588eb499 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -295,7 +295,8 @@ def test_tspark_param_ordinal(self): (BigIntegerParameter, Primitive.BIGINT), (BooleanParameter, Primitive.BOOL), (DateParameter, Primitive.DATE), - (FloatParameter, Primitive.FLOAT), + (DoubleParameter, Primitive.DOUBLE), + (DoubleParameter, Primitive.FLOAT), (VoidParameter, Primitive.NONE), (TimestampParameter, Primitive.TIMESTAMP), (MapParameter, Primitive.MAP), @@ -305,7 +306,7 @@ def test_tspark_param_ordinal(self): def test_inference(self, _type: TDbsqlParameter, prim: Primitive): """This method only tests inferrable types. - Not tested are TinyIntParameter, SmallIntParameter DoubleParameter and TimestampNTZParameter + Not tested are TinyIntParameter, SmallIntParameter, FloatParameter and TimestampNTZParameter """ inferred_type = dbsql_parameter_from_primitive(prim.value) From 36fb3760b9e1eec552376004ec9494fe6b425ff1 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 10 Mar 2026 16:29:57 +0530 Subject: [PATCH 26/39] Updated the PyArrow concatenation of tables to use promote_options as default (#751) Updated pyarrow-concat --- src/databricks/sql/utils.py | 2 +- tests/unit/test_cloud_fetch_queue.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 125edbbaa..b1fff7202 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -895,7 +895,7 @@ def concat_table_chunks( result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) else: - return pyarrow.concat_tables(table_chunks) + return pyarrow.concat_tables(table_chunks, promote_options="default") def serialize_query_tags( diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 0c3fc7103..97bb99ad9 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -174,7 +174,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert ( result == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] + [self.make_arrow_table(), self.make_arrow_table()],promote_options="default" )[:7] ) @@ -266,7 +266,7 @@ def test_remaining_rows_multiple_tables_fully_returned( assert ( result == pyarrow.concat_tables( - [self.make_arrow_table(), self.make_arrow_table()] + [self.make_arrow_table(), self.make_arrow_table()], promote_options="default" )[3:] ) From ca4d7bcbcd6846bd9be1b079dcd84ce252e525b2 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 14:50:33 +0530 Subject: [PATCH 27/39] Add statement-level query_tags support for SEA backend (#754) * Add statement-level query_tags support for SEA backend Signed-off-by: Sreekanth Vadigi * Simplify None handling in query_tags serialization Signed-off-by: Sreekanth Vadigi --------- Signed-off-by: Sreekanth Vadigi --- src/databricks/sql/backend/sea/backend.py | 5 +- .../sql/backend/sea/models/requests.py | 8 ++ tests/unit/test_sea_backend.py | 110 +++++++++++++++++- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a6cff4913..ff130cd39 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -463,9 +463,7 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, - query_tags: Optional[ - Dict[str, Optional[str]] - ] = None, # TODO: implement query_tags for SEA backend + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. @@ -532,6 +530,7 @@ def execute_command( row_limit=row_limit, parameters=sea_parameters if sea_parameters else None, result_compression=result_compression, + query_tags=query_tags, ) response_data = self._http_client._make_request( diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index ad046ff54..eb156fb1a 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -31,6 +31,7 @@ class ExecuteStatementRequest: wait_timeout: str = "10s" on_wait_timeout: str = "CONTINUE" row_limit: Optional[int] = None + query_tags: Optional[Dict[str, Optional[str]]] = None def to_dict(self) -> Dict[str, Any]: """Convert the request to a dictionary for JSON serialization.""" @@ -60,6 +61,13 @@ def to_dict(self) -> Dict[str, Any]: for param in self.parameters ] + # SEA API expects query_tags as an array of {key, value} objects. + # None/empty values are left to the server to handle as key-only tags. + if self.query_tags: + result["query_tags"] = [ + {"key": k, "value": v} for k, v in self.query_tags.items() + ] + return result diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 26a898cb8..f71e60943 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,7 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter - "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -197,7 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", - "query_tags": "team:marketing,dashboard:abc123", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -416,6 +416,112 @@ def test_command_execution_advanced( ) assert "Command failed" in str(excinfo.value) + def _execute_response(self): + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + def _run_execute_command(self, sea_client, sea_session_id, mock_cursor, **kwargs): + """Helper to invoke execute_command with default args.""" + return sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + **kwargs, + ) + + def test_execute_command_query_tags_string_values( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with string values are included in the request payload.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": "data"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team", "value": "data"}, + ] + + def test_execute_command_query_tags_none_value( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with a None value omit the value field (key-only tag).""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": None}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team", "value": None}, + ] + + def test_execute_command_no_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags field is absent from the request when not provided.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command(sea_client, sea_session_id, mock_cursor) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_empty_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Empty query_tags dict is treated as absent — field omitted from request.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, sea_session_id, mock_cursor, query_tags={} + ) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_async_query_tags( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags are included in async execute requests (execute_async path).""" + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-async", + "status": {"state": "PENDING"}, + } + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, + enforce_embedded_schema_correctness=False, + query_tags={"job": "nightly-etl"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [{"key": "job", "value": "nightly-etl"}] + def test_command_management( self, sea_client, From 330c4454d2425a259f600b2126e01b9ce408274d Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Fri, 27 Mar 2026 15:14:42 +0530 Subject: [PATCH 28/39] Harden CI/CD workflows: fix secret exposure, script injection, and pin all actions to SHA (#762) Updaed the security --- .github/workflows/code-coverage.yml | 21 ++++++++------ .github/workflows/code-quality-checks.yml | 35 ++++++++++++----------- .github/workflows/daily-telemetry-e2e.yml | 13 +++++---- .github/workflows/dco-check.yml | 12 ++++---- .github/workflows/integration.yml | 21 ++++++++------ .github/workflows/publish-manual.yml | 9 ++++-- .github/workflows/publish-test.yml | 19 ++++++++---- .github/workflows/publish.yml | 14 +++++---- 8 files changed, 86 insertions(+), 58 deletions(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index 5c961757e..9250d0ca5 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -20,14 +20,12 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 with: fetch-depth: 0 - ref: ${{ github.event.pull_request.head.ref || github.ref_name }} - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.10" #---------------------------------------------- @@ -41,7 +39,7 @@ jobs: # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -53,7 +51,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -105,8 +103,10 @@ jobs: #---------------------------------------------- - name: Check for coverage override id: override + env: + PR_BODY: ${{ github.event.pull_request.body }} run: | - OVERRIDE_COMMENT=$(echo "${{ github.event.pull_request.body }}" | grep -E "SKIP_COVERAGE_CHECK\s*=" || echo "") + OVERRIDE_COMMENT=$(echo "$PR_BODY" | grep -E "SKIP_COVERAGE_CHECK\s*=" || echo "") if [ -n "$OVERRIDE_COMMENT" ]; then echo "override=true" >> $GITHUB_OUTPUT REASON=$(echo "$OVERRIDE_COMMENT" | sed -E 's/.*SKIP_COVERAGE_CHECK\s*=\s*(.+)/\1/') @@ -153,9 +153,12 @@ jobs: # coverage enforcement summary #---------------------------------------------- - name: Coverage enforcement summary + env: + OVERRIDE: ${{ steps.override.outputs.override }} + REASON: ${{ steps.override.outputs.reason }} run: | - if [ "${{ steps.override.outputs.override }}" == "true" ]; then - echo "⚠️ Coverage checks bypassed: ${{ steps.override.outputs.reason }}" + if [ "$OVERRIDE" == "true" ]; then + echo "⚠️ Coverage checks bypassed: $REASON" echo "Please ensure this override is justified and temporary" else echo "✅ Coverage checks enforced - minimum 85% required" diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index cc3952920..13a889c56 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -2,6 +2,9 @@ name: Code Quality Checks on: [pull_request] +permissions: + contents: read + jobs: run-unit-tests: runs-on: ubuntu-latest @@ -23,17 +26,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -45,7 +48,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -107,17 +110,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v2 + uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -129,7 +132,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv-pyarrow key: venv-pyarrow-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -185,17 +188,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -207,7 +210,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -238,17 +241,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -260,7 +263,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml index d60b7f5a9..c3d0da5df 100644 --- a/.github/workflows/daily-telemetry-e2e.yml +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -12,6 +12,9 @@ on: default: 'tests/e2e/test_telemetry_e2e.py' type: string +permissions: + contents: read + jobs: telemetry-e2e-tests: runs-on: ubuntu-latest @@ -29,11 +32,11 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.10" @@ -41,7 +44,7 @@ jobs: # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -53,7 +56,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -82,7 +85,7 @@ jobs: #---------------------------------------------- - name: Upload test results on failure if: failure() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 with: name: telemetry-test-results path: | diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml index 050665ec9..770ffb9ec 100644 --- a/.github/workflows/dco-check.yml +++ b/.github/workflows/dco-check.yml @@ -2,17 +2,19 @@ name: DCO Check on: [pull_request] +permissions: + contents: read + pull-requests: write + jobs: check: - runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest + runs-on: ubuntu-latest steps: - name: Check for DCO id: dco-check - uses: tisonkun/actions-dco@v1.1 + uses: tisonkun/actions-dco@6d1f8a197db1b04df1769707b46b9366b1eca902 # v1.1 - name: Comment about DCO status - uses: actions/github-script@v7 + uses: actions/github-script@f28e40c7f34bde8b3046d885e986cb6290c5673b # v7 if: ${{ failure() }} with: script: | diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c915ee6c1..49dedfd91 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -4,7 +4,10 @@ on: push: branches: - main - pull_request: + pull_request: + +permissions: + contents: read jobs: run-non-telemetry-tests: @@ -21,17 +24,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.10" #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -43,7 +46,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -79,10 +82,10 @@ jobs: DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} steps: - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.10" - name: Install system dependencies @@ -90,14 +93,14 @@ jobs: sudo apt-get update sudo apt-get install -y libkrb5-dev - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} diff --git a/.github/workflows/publish-manual.yml b/.github/workflows/publish-manual.yml index 2f2a7a4dd..1802664a0 100644 --- a/.github/workflows/publish-manual.yml +++ b/.github/workflows/publish-manual.yml @@ -4,6 +4,9 @@ name: Publish to PyPI Manual [Production] on: workflow_dispatch: {} +permissions: + contents: read + jobs: publish: name: Publish @@ -14,14 +17,14 @@ jobs: # Step 1: Check out the repository code #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v2 # Check out the repository to access the code + uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2 #---------------------------------------------- # Step 2: Set up Python environment #---------------------------------------------- - name: Set up python id: setup-python - uses: actions/setup-python@v2 + uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 with: python-version: 3.9 # Specify the Python version to be used @@ -29,7 +32,7 @@ jobs: # Step 3: Install and configure Poetry #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 # Install Poetry, the Python package manager + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index 97a444e68..e967b8085 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -1,5 +1,12 @@ name: Publish to PyPI [Test] -on: [push] +on: + push: + branches: + - main + +permissions: + contents: read + jobs: test-pypi: name: Create patch version number and push to test-pypi @@ -9,17 +16,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.10" #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -30,7 +37,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -48,7 +55,7 @@ jobs: # Get the current version and increment it (test-pypi requires a unique version number) #---------------------------------------------- - name: Get next version - uses: reecetech/version-increment@2022.2.4 + uses: reecetech/version-increment@ddbbe72b7f76a996076fabfdce21a16384e8644a # 2022.2.4 id: version with: scheme: semver diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index b101f421c..0fd04d992 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -2,6 +2,10 @@ name: Publish to PyPI [Production] on: release: types: [published] + +permissions: + contents: read + jobs: publish: name: Publish @@ -11,17 +15,17 @@ jobs: # check-out repo and set-up python #---------------------------------------------- - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up python id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: 3.9 #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- - name: Install Poetry - uses: snok/install-poetry@v1 + uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 with: version: "2.2.1" virtualenvs-create: true @@ -32,7 +36,7 @@ jobs: #---------------------------------------------- - name: Load cached venv id: cached-poetry-dependencies - uses: actions/cache@v4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} @@ -50,7 +54,7 @@ jobs: # Here we use version-increment to fetch the latest tagged version (we won't increment it though) #------------------------------------------------------------------------------------------------ - name: Get next version - uses: reecetech/version-increment@2022.2.4 + uses: reecetech/version-increment@ddbbe72b7f76a996076fabfdce21a16384e8644a # 2022.2.4 id: version with: scheme: semver From 47933533071c11141b3e94248863644aa1735989 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Mon, 30 Mar 2026 09:59:34 +0530 Subject: [PATCH 29/39] Removed Publish workflow (#764) --- .github/workflows/publish-manual.yml | 90 ---------------------------- .github/workflows/publish-test.yml | 82 ------------------------- .github/workflows/publish.yml | 76 ----------------------- 3 files changed, 248 deletions(-) delete mode 100644 .github/workflows/publish-manual.yml delete mode 100644 .github/workflows/publish-test.yml delete mode 100644 .github/workflows/publish.yml diff --git a/.github/workflows/publish-manual.yml b/.github/workflows/publish-manual.yml deleted file mode 100644 index 1802664a0..000000000 --- a/.github/workflows/publish-manual.yml +++ /dev/null @@ -1,90 +0,0 @@ -name: Publish to PyPI Manual [Production] - -# Allow manual triggering of the workflow -on: - workflow_dispatch: {} - -permissions: - contents: read - -jobs: - publish: - name: Publish - runs-on: ubuntu-latest - - steps: - #---------------------------------------------- - # Step 1: Check out the repository code - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2 - - #---------------------------------------------- - # Step 2: Set up Python environment - #---------------------------------------------- - - name: Set up python - id: setup-python - uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 - with: - python-version: 3.9 # Specify the Python version to be used - - #---------------------------------------------- - # Step 3: Install and configure Poetry - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # Step 3.5: Install Kerberos system dependencies - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - -# #---------------------------------------------- -# # Step 4: Load cached virtual environment (if available) -# #---------------------------------------------- -# - name: Load cached venv -# id: cached-poetry-dependencies -# uses: actions/cache@v2 -# with: -# path: .venv # Path to the virtual environment -# key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} -# # Cache key is generated based on OS, Python version, repo name, and the `poetry.lock` file hash - -# #---------------------------------------------- -# # Step 5: Install dependencies if the cache is not found -# #---------------------------------------------- -# - name: Install dependencies -# if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' # Only run if the cache was not hit -# run: poetry install --no-interaction --no-root # Install dependencies without interaction - -# #---------------------------------------------- -# # Step 6: Update the version to the manually provided version -# #---------------------------------------------- -# - name: Update pyproject.toml with the specified version -# run: poetry version ${{ github.event.inputs.version }} # Use the version provided by the user input - - #---------------------------------------------- - # Step 7: Build and publish the first package to PyPI - #---------------------------------------------- - - name: Build and publish databricks sql connector to PyPI - working-directory: ./databricks_sql_connector - run: | - poetry build - poetry publish -u __token__ -p ${{ secrets.PROD_PYPI_TOKEN }} # Publish with PyPI token - #---------------------------------------------- - # Step 7: Build and publish the second package to PyPI - #---------------------------------------------- - - - name: Build and publish databricks sql connector core to PyPI - working-directory: ./databricks_sql_connector_core - run: | - poetry build - poetry publish -u __token__ -p ${{ secrets.PROD_PYPI_TOKEN }} # Publish with PyPI token \ No newline at end of file diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml deleted file mode 100644 index e967b8085..000000000 --- a/.github/workflows/publish-test.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: Publish to PyPI [Test] -on: - push: - branches: - - main - -permissions: - contents: read - -jobs: - test-pypi: - name: Create patch version number and push to test-pypi - runs-on: ubuntu-latest - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # Get the current version and increment it (test-pypi requires a unique version number) - #---------------------------------------------- - - name: Get next version - uses: reecetech/version-increment@ddbbe72b7f76a996076fabfdce21a16384e8644a # 2022.2.4 - id: version - with: - scheme: semver - increment: patch - #---------------------------------------------- - # Tell poetry to update the version number - #---------------------------------------------- - - name: Update pyproject.toml - run: poetry version ${{ steps.version.outputs.major-version }}.${{ steps.version.outputs.minor-version }}.dev$(date +%s) - #---------------------------------------------- - # Build the package (before publish action) - #---------------------------------------------- - - name: Build package - run: poetry build - #---------------------------------------------- - # Configure test-pypi repository - #---------------------------------------------- - - name: Configure test-pypi repository - run: poetry config repositories.testpypi https://test.pypi.org/legacy/ - #---------------------------------------------- - # Attempt push to test-pypi - #---------------------------------------------- - - name: Publish to test-pypi - run: poetry publish --username __token__ --password ${{ secrets.TEST_PYPI_TOKEN }} --repository testpypi diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index 0fd04d992..000000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: Publish to PyPI [Production] -on: - release: - types: [published] - -permissions: - contents: read - -jobs: - publish: - name: Publish - runs-on: ubuntu-latest - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: 3.9 - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #------------------------------------------------------------------------------------------------ - # Here we use version-increment to fetch the latest tagged version (we won't increment it though) - #------------------------------------------------------------------------------------------------ - - name: Get next version - uses: reecetech/version-increment@ddbbe72b7f76a996076fabfdce21a16384e8644a # 2022.2.4 - id: version - with: - scheme: semver - increment: patch - #----------------------------------------------------------------------------- - # Tell poetry to use the `current-version` that was found by the previous step - #----------------------------------------------------------------------------- - - name: Update pyproject.toml - run: poetry version ${{ steps.version.outputs.current-version }} - #---------------------------------------------- - # Build the package (before publish) - #---------------------------------------------- - - name: Build package - run: poetry build - #---------------------------------------------- - # Publish to pypi - #---------------------------------------------- - - name: Publish to pypi - run: poetry publish --username __token__ --password ${{ secrets.PROD_PYPI_TOKEN }} From e056275e4e88879749517a4607a54e05972dadc3 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 13 Apr 2026 17:33:31 +0530 Subject: [PATCH 30/39] Replace third-party DCO action with custom script (#769) The tisonkun/actions-dco action has been unreliable. Replace it with an inline bash script (matching databricks-sql-go) that checks each commit for a Signed-off-by line, provides clear per-commit feedback, and scopes the trigger to opened/synchronize/reopened events on main. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- .github/workflows/dco-check.yml | 81 +++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml index 770ffb9ec..b504ad84b 100644 --- a/.github/workflows/dco-check.yml +++ b/.github/workflows/dco-check.yml @@ -1,29 +1,72 @@ name: DCO Check -on: [pull_request] +on: + pull_request: + types: [opened, synchronize, reopened] + branches: [main] permissions: contents: read - pull-requests: write jobs: - check: + dco-check: runs-on: ubuntu-latest + name: Check DCO Sign-off steps: - - name: Check for DCO - id: dco-check - uses: tisonkun/actions-dco@6d1f8a197db1b04df1769707b46b9366b1eca902 # v1.1 - - name: Comment about DCO status - uses: actions/github-script@f28e40c7f34bde8b3046d885e986cb6290c5673b # v7 - if: ${{ failure() }} + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 with: - script: | - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: `Thanks for your contribution! To satisfy the DCO policy in our \ - [contributing guide](https://github.com/databricks/databricks-sql-python/blob/main/CONTRIBUTING.md) \ - every commit message must include a sign-off message. One or more of your commits is missing this message. \ - You can reword previous commit messages with an interactive rebase (\`git rebase -i main\`).` - }) + fetch-depth: 0 + + - name: Check DCO Sign-off + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + run: | + #!/bin/bash + set -e + + echo "Checking commits from $BASE_SHA to $HEAD_SHA" + + COMMITS=$(git rev-list --no-merges "$BASE_SHA..$HEAD_SHA") + + if [ -z "$COMMITS" ]; then + echo "No commits found in this PR" + exit 0 + fi + + FAILED_COMMITS=() + + for commit in $COMMITS; do + echo "Checking commit: $commit" + COMMIT_MSG=$(git log --format=%B -n 1 "$commit") + if echo "$COMMIT_MSG" | grep -q "^Signed-off-by: "; then + echo " Commit $commit has DCO sign-off" + else + echo " Commit $commit is missing DCO sign-off" + FAILED_COMMITS+=("$commit") + fi + done + + if [ ${#FAILED_COMMITS[@]} -ne 0 ]; then + echo "" + echo "DCO Check Failed!" + echo "The following commits are missing the required 'Signed-off-by' line:" + for commit in "${FAILED_COMMITS[@]}"; do + echo " - $commit: $(git log --format=%s -n 1 "$commit")" + done + echo "" + echo "To fix this, you need to sign off your commits. You can:" + echo "1. Add sign-off to new commits: git commit -s -m 'Your commit message'" + echo "2. Amend existing commits: git commit --amend --signoff" + echo "3. For multiple commits, use: git rebase --signoff HEAD~N (where N is the number of commits)" + echo "" + echo "The sign-off should be in the format:" + echo "Signed-off-by: Your Name " + echo "" + echo "For more details, see CONTRIBUTING.md" + exit 1 + else + echo "" + echo "All commits have proper DCO sign-off!" + fi From fbdcd32307ae65aa8a404954b8e132eed9f418a1 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 13 Apr 2026 18:28:47 +0530 Subject: [PATCH 31/39] Migrate CI to protected runners and JFrog PyPI proxy (#770) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Migrate CI to databricks-protected runners and route PyPI through JFrog Protected runners are required for Databricks OSS repos. Add a setup-jfrog composite action (OIDC-based, matching databricks-odbc) that sets PIP_INDEX_URL so all pip/poetry installs go through the JFrog PyPI proxy. Every workflow now runs on the databricks-protected-runner-group with id-token: write for the OIDC exchange. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Add Poetry JFrog source configuration to all workflows The previous commit only set PIP_INDEX_URL, but Poetry uses its own resolver and needs explicit source configuration. Add a "Configure Poetry for JFrog" step after poetry install in every job that sets up the JFrog repository and credentials, then adds it as the primary source for the project. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Fix step ordering: move JFrog setup after poetry install The snok/install-poetry action uses pip internally to install poetry. When PIP_INDEX_URL was set before this step, the installer tried to route through JFrog and failed with an SSL error. Move the JFrog OIDC token + PIP_INDEX_URL + poetry source configuration to run after Install Poetry but before poetry install. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Replace snok/install-poetry with pip install through JFrog The hardened runners block direct access to install.python-poetry.org, causing snok/install-poetry to fail with SSL errors. Replace it with `pip install poetry==2.2.1` which routes through the JFrog PyPI proxy. New step ordering: checkout → setup-python → Setup JFrog (OIDC + PIP_INDEX_URL) → pip install poetry → Configure Poetry for JFrog → poetry install. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Add poetry lock --no-update after source add to fix lock mismatch poetry source add modifies pyproject.toml, which makes poetry refuse to install from the existing lock file. Running poetry lock --no-update regenerates the lock file metadata without changing dependency versions. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Fix poetry lock flag and YAML indentation Poetry 2.x doesn't have --no-update flag, use poetry lock instead. Also fix indentation of poetry lock in the arrow test job. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Move JFrog setup before setup-python, matching sqlalchemy pattern Follow the proven pattern from databricks/databricks-sqlalchemy#59: checkout → Setup JFrog → setup-python → pip install poetry → poetry source add + poetry lock → poetry install. The hardened runners block pypi.org at the network level, so JFrog must be configured before actions/setup-python (which upgrades pip). Also simplified workflows by removing verbose section comments. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Extract setup-poetry composite action to remove duplication Create .github/actions/setup-poetry that bundles JFrog setup, setup-python, poetry install via pip, JFrog source config, cache, and dependency install into a single reusable action with inputs for python-version, install-args, cache-path, and cache-suffix. All workflows now call setup-poetry instead of repeating these steps, matching the pattern from databricks/databricks-sqlalchemy#59. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --------- Signed-off-by: Vikrant Puppala --- .github/actions/setup-jfrog/action.yml | 32 +++ .github/actions/setup-poetry/action.yml | 63 +++++ .github/workflows/code-coverage.yml | 84 +------ .github/workflows/code-quality-checks.yml | 265 +++++----------------- .github/workflows/daily-telemetry-e2e.yml | 61 +---- .github/workflows/dco-check.yml | 4 +- .github/workflows/integration.yml | 81 ++----- 7 files changed, 196 insertions(+), 394 deletions(-) create mode 100644 .github/actions/setup-jfrog/action.yml create mode 100644 .github/actions/setup-poetry/action.yml diff --git a/.github/actions/setup-jfrog/action.yml b/.github/actions/setup-jfrog/action.yml new file mode 100644 index 000000000..97ae146ba --- /dev/null +++ b/.github/actions/setup-jfrog/action.yml @@ -0,0 +1,32 @@ +name: Setup JFrog OIDC +description: Obtain a JFrog access token via GitHub OIDC and configure pip to use JFrog PyPI proxy + +runs: + using: composite + steps: + - name: Get JFrog OIDC token + shell: bash + run: | + set -euo pipefail + ID_TOKEN=$(curl -sLS \ + -H "User-Agent: actions/oidc-client" \ + -H "Authorization: Bearer $ACTIONS_ID_TOKEN_REQUEST_TOKEN" \ + "${ACTIONS_ID_TOKEN_REQUEST_URL}&audience=jfrog-github" | jq .value | tr -d '"') + echo "::add-mask::${ID_TOKEN}" + ACCESS_TOKEN=$(curl -sLS -XPOST -H "Content-Type: application/json" \ + "https://databricks.jfrog.io/access/api/v1/oidc/token" \ + -d "{\"grant_type\": \"urn:ietf:params:oauth:grant-type:token-exchange\", \"subject_token_type\":\"urn:ietf:params:oauth:token-type:id_token\", \"subject_token\": \"${ID_TOKEN}\", \"provider_name\": \"github-actions\"}" | jq .access_token | tr -d '"') + echo "::add-mask::${ACCESS_TOKEN}" + if [ -z "$ACCESS_TOKEN" ] || [ "$ACCESS_TOKEN" = "null" ]; then + echo "FAIL: Could not extract JFrog access token" + exit 1 + fi + echo "JFROG_ACCESS_TOKEN=${ACCESS_TOKEN}" >> "$GITHUB_ENV" + echo "JFrog OIDC token obtained successfully" + + - name: Configure pip + shell: bash + run: | + set -euo pipefail + echo "PIP_INDEX_URL=https://gha-service-account:${JFROG_ACCESS_TOKEN}@databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple" >> "$GITHUB_ENV" + echo "pip configured to use JFrog registry" diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml new file mode 100644 index 000000000..f7e15b1c0 --- /dev/null +++ b/.github/actions/setup-poetry/action.yml @@ -0,0 +1,63 @@ +name: Setup Poetry with JFrog +description: Install Poetry, configure JFrog as primary PyPI source, and install project dependencies + +inputs: + python-version: + description: Python version to set up + required: true + install-args: + description: Extra arguments for poetry install (e.g. --all-extras) + required: false + default: "" + cache-path: + description: Path to the virtualenv for caching (e.g. .venv or .venv-pyarrow) + required: false + default: ".venv" + cache-suffix: + description: Extra suffix for the cache key to avoid collisions across job variants + required: false + default: "" + +runs: + using: composite + steps: + - name: Setup JFrog + uses: ./.github/actions/setup-jfrog + + - name: Set up python ${{ inputs.python-version }} + id: setup-python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Install Poetry + shell: bash + run: | + pip install poetry==2.2.1 + poetry config virtualenvs.create true + poetry config virtualenvs.in-project true + poetry config installer.parallel true + + - name: Configure Poetry JFrog source + shell: bash + run: | + poetry config repositories.jfrog https://databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple + poetry config http-basic.jfrog gha-service-account "${JFROG_ACCESS_TOKEN}" + poetry source add --priority=primary jfrog https://databricks.jfrog.io/artifactory/api/pypi/db-pypi/simple + poetry lock + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + with: + path: ${{ inputs.cache-path }} + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ inputs.cache-suffix }}${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + shell: bash + run: poetry install --no-interaction --no-root + + - name: Install library + shell: bash + run: poetry install --no-interaction ${{ inputs.install-args }} diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index 9250d0ca5..c188c5b3a 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -2,12 +2,15 @@ name: Code Coverage permissions: contents: read + id-token: write on: [pull_request, workflow_dispatch] jobs: test-with-coverage: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest environment: azure-prod env: DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} @@ -16,63 +19,19 @@ jobs: DATABRICKS_CATALOG: peco DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 with: fetch-depth: 0 - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install system dependencies ----- - #---------------------------------------------- - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y libkrb5-dev - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install library - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # run parallel tests with coverage - #---------------------------------------------- + python-version: "3.10" + install-args: "--all-extras" - name: Run parallel tests with coverage continue-on-error: false run: | @@ -83,24 +42,15 @@ jobs: --cov-report=xml \ --cov-report=term \ -v - - #---------------------------------------------- - # run telemetry tests with coverage (isolated) - #---------------------------------------------- - name: Run telemetry tests with coverage (isolated) continue-on-error: false run: | - # Run test_concurrent_telemetry.py separately for isolation poetry run pytest tests/e2e/test_concurrent_telemetry.py \ --cov=src \ --cov-append \ --cov-report=xml \ --cov-report=term \ -v - - #---------------------------------------------- - # check for coverage override - #---------------------------------------------- - name: Check for coverage override id: override env: @@ -116,9 +66,6 @@ jobs: echo "override=false" >> $GITHUB_OUTPUT echo "No coverage override found" fi - #---------------------------------------------- - # check coverage percentage - #---------------------------------------------- - name: Check coverage percentage if: steps.override.outputs.override == 'false' run: | @@ -127,20 +74,14 @@ jobs: echo "ERROR: Coverage file not found at $COVERAGE_FILE" exit 1 fi - - # Install xmllint if not available if ! command -v xmllint &> /dev/null; then sudo apt-get update && sudo apt-get install -y libxml2-utils fi - COVERED=$(xmllint --xpath "string(//coverage/@lines-covered)" "$COVERAGE_FILE") TOTAL=$(xmllint --xpath "string(//coverage/@lines-valid)" "$COVERAGE_FILE") PERCENTAGE=$(python3 -c "covered=${COVERED}; total=${TOTAL}; print(round((covered/total)*100, 2))") - echo "Branch Coverage: $PERCENTAGE%" echo "Required Coverage: 85%" - - # Use Python to compare the coverage with 85 python3 -c "import sys; sys.exit(0 if float('$PERCENTAGE') >= 85 else 1)" if [ $? -eq 1 ]; then echo "ERROR: Coverage is $PERCENTAGE%, which is less than the required 85%" @@ -148,19 +89,14 @@ jobs: else echo "SUCCESS: Coverage is $PERCENTAGE%, which meets the required 85%" fi - - #---------------------------------------------- - # coverage enforcement summary - #---------------------------------------------- - name: Coverage enforcement summary env: OVERRIDE: ${{ steps.override.outputs.override }} REASON: ${{ steps.override.outputs.reason }} run: | if [ "$OVERRIDE" == "true" ]; then - echo "⚠️ Coverage checks bypassed: $REASON" + echo "Coverage checks bypassed: $REASON" echo "Please ensure this override is justified and temporary" else - echo "✅ Coverage checks enforced - minimum 85% required" + echo "Coverage checks enforced - minimum 85% required" fi - diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 13a889c56..ecc238263 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -4,95 +4,56 @@ on: [pull_request] permissions: contents: read + id-token: write jobs: run-unit-tests: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] dependency-version: ["default", "min"] - # Optimize matrix - test min/max on subset of Python versions exclude: - python-version: "3.12" dependency-version: "min" - python-version: "3.13" dependency-version: "min" - + name: "Unit Tests (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - + steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # override with custom dependency versions - #---------------------------------------------- + cache-suffix: "${{ matrix.dependency-version }}-" - name: Install Python tools for custom versions if: matrix.dependency-version != 'default' run: poetry run pip install toml packaging - - name: Generate requirements file if: matrix.dependency-version != 'default' run: | poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}.txt echo "Generated requirements for ${{ matrix.dependency-version }} versions:" cat requirements-${{ matrix.dependency-version }}.txt - - name: Override with custom dependency versions if: matrix.dependency-version != 'default' run: poetry run pip install -r requirements-${{ matrix.dependency-version }}.txt - - #---------------------------------------------- - # run test suite - #---------------------------------------------- - name: Show installed versions run: | echo "=== Dependency Version: ${{ matrix.dependency-version }} ===" poetry run pip list - - name: Run tests run: poetry run python -m pytest tests/unit + run-unit-tests-with-arrow: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] @@ -102,186 +63,74 @@ jobs: dependency-version: "min" - python-version: "3.13" dependency-version: "min" - - name: "Unit Tests + PyArrow (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 - with: - python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv-pyarrow - key: venv-pyarrow-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Install library - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # override with custom dependency versions - #---------------------------------------------- - - name: Install Python tools for custom versions - if: matrix.dependency-version != 'default' - run: poetry run pip install toml packaging - - name: Generate requirements file with pyarrow - if: matrix.dependency-version != 'default' - run: | - poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}-arrow.txt - echo "Generated requirements for ${{ matrix.dependency-version }} versions with PyArrow:" - cat requirements-${{ matrix.dependency-version }}-arrow.txt + name: "Unit Tests + PyArrow (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" - - name: Override with custom dependency versions - if: matrix.dependency-version != 'default' - run: poetry run pip install -r requirements-${{ matrix.dependency-version }}-arrow.txt - #---------------------------------------------- - # run test suite - #---------------------------------------------- - - name: Show installed versions - run: | - echo "=== Dependency Version: ${{ matrix.dependency-version }} with PyArrow ===" - poetry run pip list + steps: + - name: Check out repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Install Kerberos system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libkrb5-dev + - name: Setup Poetry + uses: ./.github/actions/setup-poetry + with: + python-version: ${{ matrix.python-version }} + install-args: "--all-extras" + cache-path: ".venv-pyarrow" + cache-suffix: "pyarrow-${{ matrix.dependency-version }}-" + - name: Install Python tools for custom versions + if: matrix.dependency-version != 'default' + run: poetry run pip install toml packaging + - name: Generate requirements file with pyarrow + if: matrix.dependency-version != 'default' + run: | + poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}-arrow.txt + echo "Generated requirements for ${{ matrix.dependency-version }} versions with PyArrow:" + cat requirements-${{ matrix.dependency-version }}-arrow.txt + - name: Override with custom dependency versions + if: matrix.dependency-version != 'default' + run: poetry run pip install -r requirements-${{ matrix.dependency-version }}-arrow.txt + - name: Show installed versions + run: | + echo "=== Dependency Version: ${{ matrix.dependency-version }} with PyArrow ===" + poetry run pip list + - name: Run tests + run: poetry run python -m pytest tests/unit - - name: Run tests - run: poetry run python -m pytest tests/unit check-linting: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # black the code - #---------------------------------------------- - name: Black run: poetry run black --check src check-types: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root - #---------------------------------------------- - # install your root project, if required - #---------------------------------------------- - - name: Install library - run: poetry install --no-interaction - #---------------------------------------------- - # mypy the code - #---------------------------------------------- - name: Mypy run: | - mkdir .mypy_cache # Workaround for bad error message "error: --install-types failed (no mypy cache directory)"; see https://github.com/python/mypy/issues/10768#issuecomment-2178450153 + mkdir .mypy_cache poetry run mypy --install-types --non-interactive src diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml index c3d0da5df..b6f78726c 100644 --- a/.github/workflows/daily-telemetry-e2e.yml +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -3,7 +3,7 @@ name: Daily Telemetry E2E Tests on: schedule: - cron: '0 0 * * 0' # Run every Sunday at midnight UTC - + workflow_dispatch: # Allow manual triggering inputs: test_pattern: @@ -14,75 +14,39 @@ on: permissions: contents: read + id-token: write jobs: telemetry-e2e-tests: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest environment: azure-prod - + env: DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} DATABRICKS_CATALOG: peco DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - + steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: "3.10" - - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - name: Install Kerberos system dependencies run: | sudo apt-get update sudo apt-get install -y libkrb5-dev - - name: Install dependencies - run: poetry install --no-interaction --all-extras - - #---------------------------------------------- - # run telemetry E2E tests - #---------------------------------------------- + - name: Setup Poetry + uses: ./.github/actions/setup-poetry + with: + python-version: "3.10" + install-args: "--all-extras" - name: Run telemetry E2E tests run: | TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" echo "Running tests: $TEST_PATTERN" poetry run python -m pytest $TEST_PATTERN -v -s - - #---------------------------------------------- - # upload test results on failure - #---------------------------------------------- - name: Upload test results on failure if: failure() uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 @@ -92,4 +56,3 @@ jobs: .pytest_cache/ tests-unsafe.log retention-days: 7 - diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml index b504ad84b..fdcf1b3bb 100644 --- a/.github/workflows/dco-check.yml +++ b/.github/workflows/dco-check.yml @@ -10,7 +10,9 @@ permissions: jobs: dco-check: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest name: Check DCO Sign-off steps: - name: Checkout diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 49dedfd91..6c0cc7059 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -1,17 +1,20 @@ name: Integration Tests on: - push: + push: branches: - main pull_request: permissions: contents: read + id-token: write jobs: run-non-telemetry-tests: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest environment: azure-prod env: DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} @@ -20,59 +23,29 @@ jobs: DATABRICKS_CATALOG: peco DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: "3.10" - #---------------------------------------------- - # ----- install & configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 - with: - version: "2.2.1" - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - name: Install Kerberos system dependencies run: | sudo apt-get update sudo apt-get install -y libkrb5-dev - - name: Install dependencies - run: poetry install --no-interaction --all-extras - #---------------------------------------------- - # run test suite - #---------------------------------------------- + - name: Setup Poetry + uses: ./.github/actions/setup-poetry + with: + python-version: "3.10" + install-args: "--all-extras" - name: Run non-telemetry e2e tests run: | - # Exclude all telemetry tests - they run in separate job for isolation poetry run python -m pytest tests/e2e \ --ignore=tests/e2e/test_telemetry_e2e.py \ --ignore=tests/e2e/test_concurrent_telemetry.py \ -n auto run-telemetry-tests: - runs-on: ubuntu-latest - needs: run-non-telemetry-tests # Run after non-telemetry tests complete + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + needs: run-non-telemetry-tests environment: azure-prod env: DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} @@ -83,32 +56,16 @@ jobs: steps: - name: Check out repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Set up python - id: setup-python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 - with: - python-version: "3.10" - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y libkrb5-dev - - name: Install Poetry - uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1 + - name: Setup Poetry + uses: ./.github/actions/setup-poetry with: - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - - name: Install dependencies - run: poetry install --no-interaction --all-extras + python-version: "3.10" + install-args: "--all-extras" - name: Run telemetry tests in isolation run: | - # Run test_concurrent_telemetry.py in isolation with complete process separation - # Use --dist=loadgroup to respect @pytest.mark.xdist_group markers poetry run python -m pytest tests/e2e/test_concurrent_telemetry.py \ - -n auto --dist=loadgroup -v \ No newline at end of file + -n auto --dist=loadgroup -v From 32c446b4e18487828bde5f97c9e26724fbbc4229 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 14 Apr 2026 01:05:16 +0530 Subject: [PATCH 32/39] [PECOBLR-1928] Add AI coding agent detection to User-Agent header (#740) Add AI coding agent detection to User-Agent header Detect when the Python SQL connector is invoked by an AI coding agent (e.g. Claude Code, Cursor, Gemini CLI) by checking well-known environment variables, and append `agent/` to the User-Agent string. This enables Databricks to understand how much driver usage originates from AI coding agents. Detection only succeeds when exactly one agent is detected to avoid ambiguous attribution. Mirrors the approach in databricks/cli#4287. Signed-off-by: Vikrant Puppala Co-authored-by: Claude Opus 4.6 --- src/databricks/sql/common/agent.py | 52 ++++++++++++++++++++++++++++++ src/databricks/sql/session.py | 5 +++ src/databricks/sql/utils.py | 6 ++++ tests/unit/test_agent_detection.py | 51 +++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+) create mode 100644 src/databricks/sql/common/agent.py create mode 100644 tests/unit/test_agent_detection.py diff --git a/src/databricks/sql/common/agent.py b/src/databricks/sql/common/agent.py new file mode 100644 index 000000000..79d1b2b7a --- /dev/null +++ b/src/databricks/sql/common/agent.py @@ -0,0 +1,52 @@ +""" +Detects whether the Python SQL connector is being invoked by an AI coding agent +by checking for well-known environment variables that agents set in their spawned +shell processes. + +Detection only succeeds when exactly one agent environment variable is present, +to avoid ambiguous attribution when multiple agent environments overlap. + +Adding a new agent requires only a new entry in KNOWN_AGENTS. + +References for each environment variable: + - ANTIGRAVITY_AGENT: Closed source. Google Antigravity sets this variable. + - CLAUDECODE: https://github.com/anthropics/claude-code (sets CLAUDECODE=1) + - CLINE_ACTIVE: https://github.com/cline/cline (shipped in v3.24.0) + - CODEX_CI: https://github.com/openai/codex (part of UNIFIED_EXEC_ENV array in codex-rs) + - CURSOR_AGENT: Closed source. Referenced in a gist by johnlindquist. + - GEMINI_CLI: https://google-gemini.github.io/gemini-cli/docs/tools/shell.html (sets GEMINI_CLI=1) + - OPENCODE: https://github.com/opencode-ai/opencode (sets OPENCODE=1) +""" + +import os + +KNOWN_AGENTS = [ + ("ANTIGRAVITY_AGENT", "antigravity"), + ("CLAUDECODE", "claude-code"), + ("CLINE_ACTIVE", "cline"), + ("CODEX_CI", "codex"), + ("CURSOR_AGENT", "cursor"), + ("GEMINI_CLI", "gemini-cli"), + ("OPENCODE", "opencode"), +] + + +def detect(env=None): + """Detect which AI coding agent (if any) is driving the current process. + + Args: + env: Optional dict-like object for environment variable lookup. + Defaults to os.environ. Exists for testability. + + Returns: + The agent product string if exactly one agent is detected, + or an empty string otherwise. + """ + if env is None: + env = os.environ + + detected = [product for var, product in KNOWN_AGENTS if env.get(var)] + + if len(detected) == 1: + return detected[0] + return "" diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0f723d144..1588d9f79 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -13,6 +13,7 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.agent import detect as detect_agent logger = logging.getLogger(__name__) @@ -64,6 +65,10 @@ def __init__( else: self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + agent_product = detect_agent() + if agent_product: + self.useragent_header += " agent/{}".format(agent_product) + base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index b1fff7202..ce2670969 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -957,12 +957,18 @@ def build_client_context(server_hostname: str, version: str, **kwargs): ) # Build user agent + from databricks.sql.common.agent import detect as detect_agent + user_agent_entry = kwargs.get("user_agent_entry", "") if user_agent_entry: user_agent = f"PyDatabricksSqlConnector/{version} ({user_agent_entry})" else: user_agent = f"PyDatabricksSqlConnector/{version}" + agent_product = detect_agent() + if agent_product: + user_agent += f" agent/{agent_product}" + # Explicitly construct ClientContext with proper types return ClientContext( hostname=server_hostname, diff --git a/tests/unit/test_agent_detection.py b/tests/unit/test_agent_detection.py new file mode 100644 index 000000000..0be404a1d --- /dev/null +++ b/tests/unit/test_agent_detection.py @@ -0,0 +1,51 @@ +import pytest +from databricks.sql.common.agent import detect, KNOWN_AGENTS + + +class TestAgentDetection: + def test_detects_single_agent_claude_code(self): + assert detect({"CLAUDECODE": "1"}) == "claude-code" + + def test_detects_single_agent_cursor(self): + assert detect({"CURSOR_AGENT": "1"}) == "cursor" + + def test_detects_single_agent_gemini_cli(self): + assert detect({"GEMINI_CLI": "1"}) == "gemini-cli" + + def test_detects_single_agent_cline(self): + assert detect({"CLINE_ACTIVE": "1"}) == "cline" + + def test_detects_single_agent_codex(self): + assert detect({"CODEX_CI": "1"}) == "codex" + + def test_detects_single_agent_opencode(self): + assert detect({"OPENCODE": "1"}) == "opencode" + + def test_detects_single_agent_antigravity(self): + assert detect({"ANTIGRAVITY_AGENT": "1"}) == "antigravity" + + def test_returns_empty_when_no_agent_detected(self): + assert detect({}) == "" + + def test_returns_empty_when_multiple_agents_detected(self): + assert detect({"CLAUDECODE": "1", "CURSOR_AGENT": "1"}) == "" + + def test_ignores_empty_env_var_values(self): + assert detect({"CLAUDECODE": ""}) == "" + + def test_all_known_agents_are_covered(self): + for env_var, product in KNOWN_AGENTS: + assert detect({env_var: "1"}) == product, ( + f"Agent with env var {env_var} should be detected as {product}" + ) + + def test_defaults_to_os_environ(self, monkeypatch): + monkeypatch.delenv("CLAUDECODE", raising=False) + monkeypatch.delenv("CURSOR_AGENT", raising=False) + monkeypatch.delenv("GEMINI_CLI", raising=False) + monkeypatch.delenv("CLINE_ACTIVE", raising=False) + monkeypatch.delenv("CODEX_CI", raising=False) + monkeypatch.delenv("OPENCODE", raising=False) + monkeypatch.delenv("ANTIGRAVITY_AGENT", raising=False) + # With all agent vars cleared, detect() should return empty + assert detect() == "" From c46b3a0f9fb8a41abb59281be6fcdecb83ed8aa1 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 14 Apr 2026 11:50:51 +0530 Subject: [PATCH 33/39] =?UTF-8?q?Optimize=20CI:=20consolidate=20workflows,?= =?UTF-8?q?=20fix=20caching,=20speed=20up=20e2e=20tests=20(47min=20?= =?UTF-8?q?=E2=86=92=2015min)=20(#772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Optimize CI: consolidate workflows, fix caching, speed up e2e tests Workflow consolidation: - Delete integration.yml and daily-telemetry-e2e.yml (redundant with coverage workflow which already runs all e2e tests) - Add push-to-main trigger to coverage workflow - Run all tests (including telemetry) in single pytest invocation with --dist=loadgroup to respect xdist_group markers for isolation Fix pyarrow cache: - Remove cache-path: .venv-pyarrow from pyarrow jobs. Poetry always creates .venv regardless of the cache-path input, so the cache was never saved ("Path does not exist" error). The cache-suffix already differentiates keys between variants. Fix 3.14 post-test DNS hang: - Add enable_telemetry=False to unit test DUMMY_CONNECTION_ARGS that use server_hostname="foo". This prevents FeatureFlagsContext from making real HTTP calls to fake hosts, eliminating ~8min hang from ThreadPoolExecutor threads timing out on DNS on protected runners. Improve e2e test parallelization: - Split TestPySQLLargeQueriesSuite into 3 separate classes (TestPySQLLargeWideResultSet, TestPySQLLargeNarrowResultSet, TestPySQLLongRunningQuery) so xdist distributes them across workers instead of all landing on one. Speed up slow tests: - Reduce large result set sizes from 300MB to 100MB (still validates large fetches, lz4, chunking, row integrity) - Start test_long_running_query at scale_factor=50 instead of 1 to skip ramp-up iterations that finish instantly Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Further optimize e2e: 4 workers, lower long-query threshold, split lz4 - Use -n 4 instead of -n auto in coverage workflow. The e2e tests are network-bound (waiting on warehouse), not CPU-bound, so 4 workers on a 2-CPU runner is fine and doubles parallelism. - Lower test_long_running_query min_duration from 3 min to 1 min. The test validates long-running query completion — 1 minute is sufficient and saves ~4 min per variant. - Split lz4 on/off loop in test_query_with_large_wide_result_set into separate parametrized test cases so xdist can run them on different workers instead of sequentially in one test. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Address review: inline test methods, drop mixin pattern Per review feedback from jprakash-db: - Remove mixin classes (LargeWideResultSetMixin, etc) — inline the test methods directly into the test classes in test_driver.py - Remove backward-compat LargeQueriesMixin alias (nothing uses it) - Rename _LargeQueryRowHelper — replaced entirely by inlining - Convert large_queries_mixin.py to just a fetch_rows() helper function Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --------- Signed-off-by: Vikrant Puppala --- .github/workflows/code-coverage.yml | 24 ++-- .github/workflows/code-quality-checks.yml | 1 - .github/workflows/daily-telemetry-e2e.yml | 58 -------- .github/workflows/integration.yml | 71 ---------- tests/e2e/common/large_queries_mixin.py | 158 +++++----------------- tests/e2e/test_driver.py | 89 ++++++++++-- tests/unit/test_client.py | 2 + tests/unit/test_session.py | 3 + 8 files changed, 123 insertions(+), 283 deletions(-) delete mode 100644 .github/workflows/daily-telemetry-e2e.yml delete mode 100644 .github/workflows/integration.yml diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index c188c5b3a..9f578ec9f 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -1,10 +1,15 @@ -name: Code Coverage +name: E2E Tests and Code Coverage permissions: contents: read id-token: write -on: [pull_request, workflow_dispatch] +on: + push: + branches: + - main + pull_request: + workflow_dispatch: jobs: test-with-coverage: @@ -32,25 +37,16 @@ jobs: with: python-version: "3.10" install-args: "--all-extras" - - name: Run parallel tests with coverage + - name: Run all tests with coverage continue-on-error: false run: | poetry run pytest tests/unit tests/e2e \ - -m "not serial" \ - -n auto \ + -n 4 \ + --dist=loadgroup \ --cov=src \ --cov-report=xml \ --cov-report=term \ -v - - name: Run telemetry tests with coverage (isolated) - continue-on-error: false - run: | - poetry run pytest tests/e2e/test_concurrent_telemetry.py \ - --cov=src \ - --cov-append \ - --cov-report=xml \ - --cov-report=term \ - -v - name: Check for coverage override id: override env: diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index ecc238263..4071a6e51 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -78,7 +78,6 @@ jobs: with: python-version: ${{ matrix.python-version }} install-args: "--all-extras" - cache-path: ".venv-pyarrow" cache-suffix: "pyarrow-${{ matrix.dependency-version }}-" - name: Install Python tools for custom versions if: matrix.dependency-version != 'default' diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml deleted file mode 100644 index b6f78726c..000000000 --- a/.github/workflows/daily-telemetry-e2e.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Daily Telemetry E2E Tests - -on: - schedule: - - cron: '0 0 * * 0' # Run every Sunday at midnight UTC - - workflow_dispatch: # Allow manual triggering - inputs: - test_pattern: - description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' - required: false - default: 'tests/e2e/test_telemetry_e2e.py' - type: string - -permissions: - contents: read - id-token: write - -jobs: - telemetry-e2e-tests: - runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest - environment: azure-prod - - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - - steps: - - name: Check out repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Setup Poetry - uses: ./.github/actions/setup-poetry - with: - python-version: "3.10" - install-args: "--all-extras" - - name: Run telemetry E2E tests - run: | - TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" - echo "Running tests: $TEST_PATTERN" - poetry run python -m pytest $TEST_PATTERN -v -s - - name: Upload test results on failure - if: failure() - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 - with: - name: telemetry-test-results - path: | - .pytest_cache/ - tests-unsafe.log - retention-days: 7 diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml deleted file mode 100644 index 6c0cc7059..000000000 --- a/.github/workflows/integration.yml +++ /dev/null @@ -1,71 +0,0 @@ -name: Integration Tests - -on: - push: - branches: - - main - pull_request: - -permissions: - contents: read - id-token: write - -jobs: - run-non-telemetry-tests: - runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest - environment: azure-prod - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - steps: - - name: Check out repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Install Kerberos system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Setup Poetry - uses: ./.github/actions/setup-poetry - with: - python-version: "3.10" - install-args: "--all-extras" - - name: Run non-telemetry e2e tests - run: | - poetry run python -m pytest tests/e2e \ - --ignore=tests/e2e/test_telemetry_e2e.py \ - --ignore=tests/e2e/test_concurrent_telemetry.py \ - -n auto - - run-telemetry-tests: - runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest - needs: run-non-telemetry-tests - environment: azure-prod - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - steps: - - name: Check out repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libkrb5-dev - - name: Setup Poetry - uses: ./.github/actions/setup-poetry - with: - python-version: "3.10" - install-args: "--all-extras" - - name: Run telemetry tests in isolation - run: | - poetry run python -m pytest tests/e2e/test_concurrent_telemetry.py \ - -n auto --dist=loadgroup -v diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index dd7c56996..7255ee095 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -2,139 +2,43 @@ import math import time -import pytest - log = logging.getLogger(__name__) -class LargeQueriesMixin: +def fetch_rows(test_case, cursor, row_count, fetchmany_size): """ - This mixin expects to be mixed with a CursorTest-like class + A generator for rows. Fetches until the end or up to 5 minutes. """ - - def fetch_rows(self, cursor, row_count, fetchmany_size): - """ - A generator for rows. Fetches until the end or up to 5 minutes. - """ - # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone - # in the Python client - max_fetch_time = 5 * 60 # Fetch for at most 5 minutes - - rows = self.get_some_rows(cursor, fetchmany_size) - start_time = time.time() - n = 0 - while rows: - for row in rows: - n += 1 - yield row - if time.time() - start_time >= max_fetch_time: - log.warning("Fetching rows timed out") - break - rows = self.get_some_rows(cursor, fetchmany_size) - if not rows: - # Read all the rows, row_count should match - self.assertEqual(n, row_count) - - num_fetches = max(math.ceil(n / 10000), 1) - latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 - print( - "Fetched {} rows with an avg latency of {} per fetch, ".format( - n, latency_ms - ) - + "assuming 10K fetch size." + max_fetch_time = 5 * 60 # Fetch for at most 5 minutes + + rows = _get_some_rows(cursor, fetchmany_size) + start_time = time.time() + n = 0 + while rows: + for row in rows: + n += 1 + yield row + if time.time() - start_time >= max_fetch_time: + log.warning("Fetching rows timed out") + break + rows = _get_some_rows(cursor, fetchmany_size) + if not rows: + # Read all the rows, row_count should match + test_case.assertEqual(n, row_count) + + num_fetches = max(math.ceil(n / 10000), 1) + latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 + print( + "Fetched {} rows with an avg latency of {} per fetch, ".format( + n, latency_ms ) - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], + + "assuming 10K fetch size." ) - def test_query_with_large_wide_result_set(self, extra_params): - resultSize = 300 * 1000 * 1000 # 300 MB - width = 8192 # B - rows = resultSize // width - cols = width // 36 - - # Set the fetchmany_size to get 10MB of data a go - fetchmany_size = 10 * 1024 * 1024 // width - # This is used by PyHive tests to determine the buffer size - self.arraysize = 1000 - with self.cursor(extra_params) as cursor: - for lz4_compression in [False, True]: - cursor.connection.lz4_compression = lz4_compression - uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) - cursor.execute( - "SELECT id, {uuids} FROM RANGE({rows})".format( - uuids=uuids, rows=rows - ) - ) - assert lz4_compression == cursor.active_result_set.lz4_compressed - for row_id, row in enumerate( - self.fetch_rows(cursor, rows, fetchmany_size) - ): - assert row[0] == row_id # Verify no rows are dropped in the middle. - assert len(row[1]) == 36 - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_query_with_large_narrow_result_set(self, extra_params): - resultSize = 300 * 1000 * 1000 # 300 MB - width = 8 # sizeof(long) - rows = resultSize / width - - # Set the fetchmany_size to get 10MB of data a go - fetchmany_size = 10 * 1024 * 1024 // width - # This is used by PyHive tests to determine the buffer size - self.arraysize = 10000000 - with self.cursor(extra_params) as cursor: - cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) - for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): - assert row[0] == row_id - - @pytest.mark.parametrize( - "extra_params", - [ - {}, - {"use_sea": True}, - ], - ) - def test_long_running_query(self, extra_params): - """Incrementally increase query size until it takes at least 3 minutes, - and asserts that the query completes successfully. - """ - minutes = 60 - min_duration = 3 * minutes - - duration = -1 - scale0 = 10000 - scale_factor = 1 - with self.cursor(extra_params) as cursor: - while duration < min_duration: - assert scale_factor < 4096, "Detected infinite loop" - start = time.time() - - cursor.execute( - """SELECT count(*) - FROM RANGE({scale}) x - JOIN RANGE({scale0}) y - ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" - """.format( - scale=scale_factor * scale0, scale0=scale0 - ) - ) - (n,) = cursor.fetchone() - assert n == 0 - duration = time.time() - start - current_fraction = duration / min_duration - print("Took {} s with scale factor={}".format(duration, scale_factor)) - # Extrapolate linearly to reach 3 min and add 50% padding to push over the limit - scale_factor = math.ceil(1.5 * scale_factor / current_fraction) +def _get_some_rows(cursor, fetchmany_size): + row = cursor.fetchone() + if row: + return [row] + else: + return None diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index e04e348c9..45b56ae08 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -39,7 +39,7 @@ ) from databricks.sql.thrift_api.TCLIService import ttypes from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin -from tests.e2e.common.large_queries_mixin import LargeQueriesMixin +from tests.e2e.common.large_queries_mixin import fetch_rows from tests.e2e.common.timestamp_tests import TimestampTestsMixin from tests.e2e.common.decimal_tests import DecimalTestsMixin from tests.e2e.common.retry_test_mixins import ( @@ -138,24 +138,89 @@ def assertEqualRowValues(self, actual, expected): assert act[i] == exp[i] -class TestPySQLLargeQueriesSuite(PySQLPytestTestCase, LargeQueriesMixin): - def get_some_rows(self, cursor, fetchmany_size): - row = cursor.fetchone() - if row: - return [row] - else: - return None +class TestPySQLLargeWideResultSet(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + @pytest.mark.parametrize("lz4_compression", [False, True]) + def test_query_with_large_wide_result_set(self, extra_params, lz4_compression): + resultSize = 100 * 1000 * 1000 # 100 MB + width = 8192 # B + rows = resultSize // width + cols = width // 36 + fetchmany_size = 10 * 1024 * 1024 // width + self.arraysize = 1000 + with self.cursor(extra_params) as cursor: + cursor.connection.lz4_compression = lz4_compression + uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) + cursor.execute( + "SELECT id, {uuids} FROM RANGE({rows})".format( + uuids=uuids, rows=rows + ) + ) + assert lz4_compression == cursor.active_result_set.lz4_compressed + for row_id, row in enumerate( + fetch_rows(self, cursor, rows, fetchmany_size) + ): + assert row[0] == row_id + assert len(row[1]) == 36 + + +class TestPySQLLargeNarrowResultSet(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + def test_query_with_large_narrow_result_set(self, extra_params): + resultSize = 100 * 1000 * 1000 # 100 MB + width = 8 # sizeof(long) + rows = resultSize / width + fetchmany_size = 10 * 1024 * 1024 // width + self.arraysize = 10000000 + with self.cursor(extra_params) as cursor: + cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) + for row_id, row in enumerate( + fetch_rows(self, cursor, rows, fetchmany_size) + ): + assert row[0] == row_id + + +class TestPySQLLongRunningQuery(PySQLPytestTestCase): + @pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}]) + def test_long_running_query(self, extra_params): + """Incrementally increase query size until it takes at least 1 minute, + and asserts that the query completes successfully. + """ + import math + + minutes = 60 + min_duration = 1 * minutes + duration = -1 + scale0 = 10000 + scale_factor = 50 + with self.cursor(extra_params) as cursor: + while duration < min_duration: + assert scale_factor < 4096, "Detected infinite loop" + start = time.time() + cursor.execute( + """SELECT count(*) + FROM RANGE({scale}) x + JOIN RANGE({scale0}) y + ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" + """.format( + scale=scale_factor * scale0, scale0=scale0 + ) + ) + (n,) = cursor.fetchone() + assert n == 0 + duration = time.time() - start + current_fraction = duration / min_duration + print("Took {} s with scale factor={}".format(duration, scale_factor)) + scale_factor = math.ceil(1.5 * scale_factor / current_fraction) + +class TestPySQLCloudFetch(PySQLPytestTestCase): @skipUnless(pysql_supports_arrow(), "needs arrow support") @pytest.mark.skip("This test requires a previously uploaded data set") def test_cloud_fetch(self): - # This test can take several minutes to run limits = [100000, 300000] threads = [10, 25] self.arraysize = 100000 - # This test requires a large table with many rows to properly initiate cloud fetch. - # e2-dogfood host > hive_metastore catalog > main schema has such a table called store_sales. - # If this table is deleted or this test is run on a different host, a different table may need to be used. base_query = "SELECT * FROM store_sales WHERE ss_sold_date_sk = 2452234 " for num_limit, num_threads, lz4_compression in itertools.product( limits, threads, [True, False] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5b6991931..4a8cb0b68 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -87,6 +87,7 @@ class ClientTestSuite(unittest.TestCase): "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -644,6 +645,7 @@ class TransactionTestSuite(unittest.TestCase): "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } def _setup_mock_session_with_http_client(self, mock_session): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 3a43c1a75..aa7e7f02b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -22,6 +22,7 @@ class TestSession: "server_hostname": "foo", "http_path": "dummy_path", "access_token": "tok", + "enable_telemetry": False, } @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @@ -50,6 +51,7 @@ def test_auth_args(self, mock_client_class): "server_hostname": "foo", "http_path": None, "access_token": "tok", + "enable_telemetry": False, }, { "server_hostname": "foo", @@ -57,6 +59,7 @@ def test_auth_args(self, mock_client_class): "_tls_client_cert_file": "something", "_use_cert_as_auth": True, "access_token": None, + "enable_telemetry": False, }, ] From 9031863fb6380730e7b58b223739ee9028c2a720 Mon Sep 17 00:00:00 2001 From: Korijn van Golen Date: Mon, 20 Apr 2026 05:37:56 +0200 Subject: [PATCH 34/39] Bump thrift to fix deprecation warning (#733) Signed-off-by: Korijn van Golen --- poetry.lock | 21 +++++++++------------ pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7d0845a58..5644190f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.1 and should not be changed by hand. [[package]] name = "astroid" @@ -1158,15 +1158,15 @@ pytz = ">=2020.1" tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] -aws = ["s3fs (>=2021.08.0)"] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.7.0)", "gcsfs (>=2021.7.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.8.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.8.0)"] clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] feather = ["pyarrow (>=7.0.0)"] -fss = ["fsspec (>=2021.07.0)"] -gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] +fss = ["fsspec (>=2021.7.0)"] +gcp = ["gcsfs (>=2021.7.0)", "pandas-gbq (>=0.15.0)"] hdf5 = ["tables (>=3.6.1)"] html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] @@ -1523,7 +1523,7 @@ files = [ ] [package.dependencies] -astroid = ">=3.2.4,<=3.3.0-dev0" +astroid = ">=3.2.4,<=3.3.0.dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -1849,18 +1849,15 @@ files = [ [[package]] name = "thrift" -version = "0.20.0" +version = "0.22.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, + {file = "thrift-0.22.0.tar.gz", hash = "sha256:42e8276afbd5f54fe1d364858b6877bc5e5a4a5ed69f6a005b94ca4918fe1466"}, ] -[package.dependencies] -six = ">=1.7.2" - [package.extras] all = ["tornado (>=4.0)", "twisted"] tornado = ["tornado (>=4.0)"] @@ -1969,4 +1966,4 @@ pyarrow = ["pyarrow", "pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "ec311bf26ec866de2f427bcdf4ec69ceed721bfd70edfae3aba1ac12882a09d6" +content-hash = "d1739e84dcbd6e7ac311eb6fbb9cf87ad110491f7d954f07fdfc32b704b4413f" diff --git a/pyproject.toml b/pyproject.toml index 911f1b79c..e1ce3b73f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = ["CHANGELOG.md"] [tool.poetry.dependencies] python = "^3.8.0" -thrift = ">=0.16.0,<0.21.0" +thrift = "~=0.22.0" pandas = [ { version = ">=1.2.5,<2.4.0", python = ">=3.8,<3.13" }, { version = ">=2.2.3,<2.4.0", python = ">=3.13" } From b088a356ade14ab474e88e2978bfc1052adbd248 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 21 Apr 2026 12:34:58 +0530 Subject: [PATCH 35/39] Fix dependency_manager: handle PEP 440 ~= compatible release syntax (#776) The _extract_versions_from_specifier function stripped a single `~` character from constraint strings, which corrupted PEP 440 compatible release syntax (`~=`) by leaving a stray `=`. For example, `thrift = "~=0.22.0"` produced the invalid constraint `thrift>==0.22.0,<=0.23.0`, breaking every PR's "Unit Tests (min deps)" job since #733 was merged. Add an explicit branch for `~=` that strips both characters before extracting the minimum version. The Poetry-style single `~` branch is preserved for backward compatibility. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- scripts/dependency_manager.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/dependency_manager.py b/scripts/dependency_manager.py index 15e119841..29c5fe828 100644 --- a/scripts/dependency_manager.py +++ b/scripts/dependency_manager.py @@ -69,16 +69,21 @@ def _parse_constraint(self, name, constraint): def _extract_versions_from_specifier(self, spec_set_str): """Extract minimum version from a specifier set""" try: - # Handle caret (^) and tilde (~) constraints that packaging doesn't support + # Handle caret (^) and tilde (~, ~=) constraints that packaging doesn't + # support (Poetry ^, Poetry ~, and PEP 440 ~=). if spec_set_str.startswith('^'): # ^1.2.3 means >=1.2.3, <2.0.0 min_version = spec_set_str[1:] # Remove ^ return min_version, None + elif spec_set_str.startswith('~='): + # PEP 440 compatible release: ~=1.2.3 means >=1.2.3, <1.3.0 + min_version = spec_set_str[2:] # Remove ~= + return min_version, None elif spec_set_str.startswith('~'): - # ~1.2.3 means >=1.2.3, <1.3.0 + # Poetry tilde: ~1.2.3 means >=1.2.3, <1.3.0 min_version = spec_set_str[1:] # Remove ~ return min_version, None - + spec_set = SpecifierSet(spec_set_str) min_version = None From d872075e2127858379267a11f37d44b5273640e3 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 21 Apr 2026 15:21:17 +0530 Subject: [PATCH 36/39] [PECOBLR-2461] Add comprehensive MST transaction E2E tests (#775) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add comprehensive MST transaction E2E tests Replaces the prior speculative test skeleton with 42 tests across 5 categories: - TestMstCorrectness (18): commit/rollback/isolation/multi-table atomicity/repeatable reads/write conflict/parameterized DML/etc. - TestMstApi (6): DB-API-specific — autocommit, isolation level, error handling. - TestMstMetadata (6): cursor.columns/tables/schemas/catalogs inside a transaction, plus two freshness tests asserting Thrift metadata RPCs are non-transactional (they see concurrent DDL that the txn should not see). - TestMstBlockedSql (9): MSTCheckRule enforcement. Some SHOW/DESCRIBE commands throw + abort txn, others succeed silently on Python/Thrift (diverges from JDBC). Both behaviors are explicitly tested so regressions in either direction are caught. - TestMstExecuteVariants (2): executemany commit/rollback. Parallelisation: - Each test uses a unique Delta table derived from its test name so pytest-xdist workers don't collide on shared state. - Tests that spawn concurrent connections to the same table (repeatable reads, write conflict, freshness) use xdist_group so the concurrent connections within a single test don't conflict with other tests on different workers. Runtime: ~2 minutes on 4 workers (pytest -n 4 --dist=loadgroup), well within the existing e2e budget. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Fix TestMstBlockedSql: SHOW COLUMNS and DESCRIBE QUERY are blocked CI caught that the initial "not blocked" assertions were wrong — the server returns TRANSACTION_NOT_SUPPORTED.COMMAND for SHOW COLUMNS (ShowDeltaTableColumnsCommand) and DESCRIBE QUERY (DescribeQueryCommand) inside an active transaction. The server's error message explicitly lists the allowed commands: "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." DESCRIBE TABLE (basic) remains the only DESCRIBE variant that is allowed. Earlier dogfood runs showed SHOW COLUMNS / DESCRIBE QUERY succeeding — likely because the dogfood warehouse DBR is older than CI. Aligning tests with the current/CI server behavior. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala * Address PR review comments - test_auto_start_after_commit: assert the rolled-back id=2 is NOT present (use _get_ids set equality instead of just row count). - test_auto_start_after_rollback: same pattern — assert the rolled-back id=1 is NOT present. - test_commit_without_active_txn_throws: match specific NO_ACTIVE_TRANSACTION server error code to ensure we're catching the right exception, not an unrelated one. Add _get_ids() helper for checking the exact set of persisted ids. Verified 42/42 pass against pecotesting in ~1:36 (4 workers). Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --------- Signed-off-by: Vikrant Puppala --- tests/e2e/test_transactions.py | 1214 ++++++++++++++++++-------------- 1 file changed, 692 insertions(+), 522 deletions(-) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index d4f6a790a..9bf74baad 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -1,598 +1,768 @@ """ End-to-end integration tests for Multi-Statement Transaction (MST) APIs. -These tests verify: -- autocommit property (getter/setter) -- commit() and rollback() methods -- get_transaction_isolation() and set_transaction_isolation() methods -- Transaction error handling +Tests driver behavior for MST across: +- Basic correctness (commit/rollback/isolation/multi-table) +- API-specific (autocommit, isolation level, error handling) +- Metadata RPCs inside transactions (non-transactional freshness) +- SQL statements blocked by MSTCheckRule (SHOW, DESCRIBE, information_schema) +- Execute variants (executemany) + +Parallelisation: +- Each test uses its own unique table (derived from test name) to allow + parallel execution with pytest-xdist. +- Tests requiring multiple concurrent connections to the same table are + tagged with xdist_group so the concurrent connections within a single + test don't conflict with other tests on different workers. Requirements: - DBSQL warehouse that supports Multi-Statement Transactions (MST) -- Test environment configured via test.env file or environment variables - -Setup: -Set the following environment variables: -- DATABRICKS_SERVER_HOSTNAME -- DATABRICKS_HTTP_PATH -- DATABRICKS_ACCESS_TOKEN (or use OAuth) - -Usage: - pytest tests/e2e/test_transactions.py -v +- Env vars: DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, + DATABRICKS_TOKEN, DATABRICKS_CATALOG, DATABRICKS_SCHEMA """ import logging import os +import re +import uuid + import pytest -from typing import Any, Dict import databricks.sql as sql -from databricks.sql import TransactionError, NotSupportedError, InterfaceError logger = logging.getLogger(__name__) -@pytest.mark.skip( - reason="Test environment does not yet support multi-statement transactions" -) -class TestTransactions: - """E2E tests for transaction control methods (MST support).""" +def _unique_table_name(request): + """Derive a unique Delta table name from the test node id.""" + node_id = request.node.name + sanitized = re.sub(r"[^a-z0-9_]", "_", node_id.lower()) + return f"mst_pysql_{sanitized}"[:80] - # Test table name - TEST_TABLE_NAME = "transaction_test_table" - @pytest.fixture(autouse=True) - def setup_and_teardown(self, connection_details): - """Setup test environment before each test and cleanup after.""" - self.connection_params = { - "server_hostname": connection_details["host"], - "http_path": connection_details["http_path"], - "access_token": connection_details.get("access_token"), - "ignore_transactions": False, # Enable actual transaction functionality for these tests - } +def _unique_table_name_raw(suffix): + """Non-fixture unique table name helper for extra tables within a test.""" + return f"mst_pysql_{suffix}_{uuid.uuid4().hex[:8]}" - # Get catalog and schema from environment or use defaults - self.catalog = os.getenv("DATABRICKS_CATALOG", "main") - self.schema = os.getenv("DATABRICKS_SCHEMA", "default") - # Create connection for setup - self.connection = sql.connect(**self.connection_params) +@pytest.fixture +def mst_conn_params(connection_details): + """Connection parameters with MST enabled.""" + return { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + "ignore_transactions": False, + } - # Setup: Create test table - self._create_test_table() - yield +@pytest.fixture +def mst_catalog(connection_details): + return connection_details.get("catalog") or os.getenv("DATABRICKS_CATALOG") or "main" - # Teardown: Cleanup - self._cleanup() - def _get_fully_qualified_table_name(self) -> str: - """Get the fully qualified table name.""" - return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" +@pytest.fixture +def mst_schema(connection_details): + return connection_details.get("schema") or os.getenv("DATABRICKS_SCHEMA") or "default" - def _create_test_table(self): - """Create the test table with Delta format and MST support.""" - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - try: - # Drop if exists - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") +@pytest.fixture +def mst_table(request, mst_conn_params, mst_catalog, mst_schema): + """Create a fresh Delta table for the test and drop it afterwards. + + Yields (fq_table_name, table_name). The table is unique per test so tests + can run in parallel without stepping on each other. + """ + table_name = _unique_table_name(request) + fq_table = f"{mst_catalog}.{mst_schema}.{table_name}" - # Create table with Delta and catalog-owned feature for MST compatibility + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table}") cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table_name} - (id INT, value STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + f"CREATE TABLE {fq_table} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" ) - logger.info(f"Created test table: {fq_table_name}") - finally: - cursor.close() - - def _cleanup(self): - """Cleanup after test: rollback pending transactions, drop table, close connection.""" - try: - # Try to rollback any pending transaction - if ( - self.connection - and self.connection.open - and not self.connection.autocommit - ): - try: - self.connection.rollback() - except Exception as e: - logger.debug( - f"Rollback during cleanup failed (may be expected): {e}" + yield fq_table, table_name + + try: + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table}") + except Exception as e: + logger.warning(f"Failed to drop {fq_table}: {e}") + + +def _get_row_count(mst_conn_params, fq_table): + """Count rows from a fresh connection (avoids in-txn caching).""" + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT COUNT(*) FROM {fq_table}") + return cursor.fetchone()[0] + + +def _get_ids(mst_conn_params, fq_table): + """Return the set of ids from a fresh connection.""" + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT id FROM {fq_table}") + return {row[0] for row in cursor.fetchall()} + + +# ==================== A. BASIC CORRECTNESS ==================== + + +class TestMstCorrectness: + """Core MST correctness: commit, rollback, isolation, multi-table.""" + + def test_commit_single_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'committed')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_commit_multiple_inserts(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'a')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'b')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (3, 'c')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_rollback_single_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'rolled_back')") + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table) == 0 + + def test_sequential_transactions(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.commit() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (3, 'txn3')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 2 + + def test_auto_start_after_commit(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.commit() + + # Second INSERT auto-starts a new transaction + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.rollback() + + assert _get_ids(mst_conn_params, fq_table) == {1} + + def test_auto_start_after_rollback(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'txn1')") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'txn2')") + conn.commit() + + assert _get_ids(mst_conn_params, fq_table) == {2} + + def test_update_in_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'original')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"UPDATE {fq_table} SET value = 'updated' WHERE id = 1") + conn.commit() + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + assert cursor.fetchone()[0] == "updated" + + def test_delete_in_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'a')") + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'b')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"DELETE FROM {fq_table} WHERE id = 1") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_multi_table_commit(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table1, _ = mst_table + fq_table2 = f"{mst_catalog}.{mst_schema}.{_unique_table_name_raw('multi_commit_t2')}" + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + cursor.execute( + f"CREATE TABLE {fq_table2} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + try: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table1} VALUES (1, 't1')") + cursor.execute(f"INSERT INTO {fq_table2} VALUES (1, 't2')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table1) == 1 + assert _get_row_count(mst_conn_params, fq_table2) == 1 + finally: + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + + def test_multi_table_rollback(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table1, _ = mst_table + fq_table2 = f"{mst_catalog}.{mst_schema}.{_unique_table_name_raw('multi_rb_t2')}" + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + cursor.execute( + f"CREATE TABLE {fq_table2} (id INT, value STRING) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + try: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table1} VALUES (1, 't1')") + cursor.execute(f"INSERT INTO {fq_table2} VALUES (1, 't2')") + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table1) == 0 + assert _get_row_count(mst_conn_params, fq_table2) == 0 + finally: + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2}") + + def test_multi_table_atomicity(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'should_rollback')") + with pytest.raises(Exception): + cursor.execute( + "INSERT INTO nonexistent_table_xyz_xyz VALUES (1, 'fail')" + ) + conn.rollback() + + assert _get_row_count(mst_conn_params, fq_table) == 0 + + @pytest.mark.xdist_group(name="mst_repeatable_reads") + def test_repeatable_reads(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'initial')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + first_read = cursor.fetchone()[0] + + # External connection modifies data + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"UPDATE {fq_table} SET value = 'modified' WHERE id = 1" ) - # Reset to autocommit mode - try: - self.connection.autocommit = True - except Exception as e: - logger.debug(f"Reset autocommit during cleanup failed: {e}") - - # Drop test table - if self.connection and self.connection.open: - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - try: - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") - logger.info(f"Dropped test table: {fq_table_name}") - except Exception as e: - logger.warning(f"Failed to drop test table: {e}") - finally: - cursor.close() + # Re-read in same txn — should see original value + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + second_read = cursor.fetchone()[0] - finally: - # Close connection - if self.connection: - self.connection.close() - - # ==================== BASIC AUTOCOMMIT TESTS ==================== - - def test_default_autocommit_is_true(self): - """Test that new connection defaults to autocommit=true.""" - assert ( - self.connection.autocommit is True - ), "New connection should have autocommit=true by default" - - def test_set_autocommit_to_false(self): - """Test successfully setting autocommit to false.""" - self.connection.autocommit = False - assert ( - self.connection.autocommit is False - ), "autocommit should be false after setting to false" - - def test_set_autocommit_to_true(self): - """Test successfully setting autocommit back to true.""" - # First disable - self.connection.autocommit = False - assert self.connection.autocommit is False - - # Then enable - self.connection.autocommit = True - assert ( - self.connection.autocommit is True - ), "autocommit should be true after setting to true" - - # ==================== COMMIT TESTS ==================== - - def test_commit_single_insert(self): - """Test successfully committing a transaction with single INSERT.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" - ) - cursor.close() + assert first_read == second_read, "Repeatable read: value should not change" + conn.rollback() - # Commit - self.connection.commit() + @pytest.mark.xdist_group(name="mst_write_conflict") + def test_write_conflict_single_table(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as setup_conn: + with setup_conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'initial')") - # Verify data is persisted using a new connection - verify_conn = sql.connect(**self.connection_params) + conn1 = sql.connect(**mst_conn_params) + conn2 = sql.connect(**mst_conn_params) try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - verify_cursor.close() + conn1.autocommit = False + conn2.autocommit = False - assert result is not None, "Should find inserted row after commit" - assert result[0] == "test_value", "Value should match inserted value" - finally: - verify_conn.close() + with conn1.cursor() as c1: + c1.execute(f"UPDATE {fq_table} SET value = 'conn1' WHERE id = 1") + with conn2.cursor() as c2: + c2.execute(f"UPDATE {fq_table} SET value = 'conn2' WHERE id = 1") - def test_commit_multiple_inserts(self): - """Test successfully committing a transaction with multiple INSERTs.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # Insert multiple rows - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") - cursor.close() - - self.connection.commit() - - # Verify all rows persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 3, "Should have 3 rows after commit" + conn1.commit() + with pytest.raises(Exception): + conn2.commit() finally: - verify_conn.close() - - # ==================== ROLLBACK TESTS ==================== - - def test_rollback_single_insert(self): - """Test successfully rolling back a transaction.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False + try: + conn1.close() + except Exception: + pass + try: + conn2.close() + except Exception: + pass + + def test_read_only_transaction(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'existing')") + + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"SELECT COUNT(*) FROM {fq_table}") + assert cursor.fetchone()[0] == 1 + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_rollback_after_query_failure(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before_error')") + with pytest.raises(Exception): + cursor.execute("SELECT * FROM nonexistent_xyz_xyz") + conn.rollback() + + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (2, 'after_recovery')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 1 + + def test_multiple_cursors_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as c1: + c1.execute(f"INSERT INTO {fq_table} VALUES (1, 'c1')") + with conn.cursor() as c2: + c2.execute(f"INSERT INTO {fq_table} VALUES (2, 'c2')") + with conn.cursor() as c3: + c3.execute(f"INSERT INTO {fq_table} VALUES (3, 'c3')") + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_parameterized_insert(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + {"id": 1, "value": "parameterized"}, + ) + conn.commit() + + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT value FROM {fq_table} WHERE id = 1") + assert cursor.fetchone()[0] == "parameterized" + + def test_empty_transaction_rollback(self, mst_conn_params, mst_table): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + # Rollback with no DML should not raise + conn.rollback() + + def test_close_connection_implicit_rollback(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + conn = sql.connect(**mst_conn_params) + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'pending')") + conn.close() + + assert _get_row_count(mst_conn_params, fq_table) == 0 + + +# ==================== B. API-SPECIFIC TESTS ==================== + + +class TestMstApi: + """DB-API-specific tests: autocommit, isolation, error handling.""" + + def test_default_autocommit_is_true(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + assert conn.autocommit is True + + def test_set_autocommit_false(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + assert conn.autocommit is False + + def test_commit_without_active_txn_throws(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + with pytest.raises(Exception, match=r"NO_ACTIVE_TRANSACTION"): + conn.commit() + + def test_set_autocommit_during_active_txn_throws( + self, mst_conn_params, mst_table + ): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'active_txn')") + with pytest.raises(Exception): + conn.autocommit = True + conn.rollback() + + def test_supported_isolation_level(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + conn.set_transaction_isolation("REPEATABLE_READ") + assert conn.get_transaction_isolation() == "REPEATABLE_READ" + + def test_unsupported_isolation_level_rejected(self, mst_conn_params): + with sql.connect(**mst_conn_params) as conn: + for level in ["READ_UNCOMMITTED", "READ_COMMITTED", "SERIALIZABLE"]: + with pytest.raises(Exception): + conn.set_transaction_isolation(level) + + +# ==================== C. METADATA RPCs ==================== + + +class TestMstMetadata: + """Metadata RPCs inside active transactions. + + Python uses Thrift RPCs for cursor.columns, cursor.tables, etc. These + RPCs bypass MST context and return non-transactional data — they see + concurrent DDL changes that the transaction shouldn't see. + """ + + def test_cursor_columns_in_mst( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + columns = cursor.fetchall() + assert len(columns) > 0 + conn.rollback() + + def test_cursor_tables_in_mst( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.tables( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + tables = cursor.fetchall() + assert len(tables) > 0 + conn.rollback() + + def test_cursor_schemas_in_mst(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.schemas(catalog_name=mst_catalog) + schemas = cursor.fetchall() + assert len(schemas) > 0 + conn.rollback() + + def test_cursor_catalogs_in_mst(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.catalogs() + catalogs = cursor.fetchall() + assert len(catalogs) > 0 + conn.rollback() + + @pytest.mark.xdist_group(name="mst_freshness_columns") + def test_cursor_columns_non_transactional_after_concurrent_ddl( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + """Thrift cursor.columns() bypasses MST — sees concurrent ALTER TABLE.""" + fq_table, table_name = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + before_cols = {row[3].lower() for row in cursor.fetchall()} - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" - ) - cursor.close() + # External connection alters schema + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"ALTER TABLE {fq_table} ADD COLUMN new_col STRING" + ) - # Rollback - self.connection.rollback() + # Re-read columns in same txn — Thrift RPC bypasses txn isolation, + # so new_col IS visible (proves non-transactional behavior) + with conn.cursor() as cursor: + cursor.columns( + catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name + ) + after_cols = {row[3].lower() for row in cursor.fetchall()} - # Verify data is NOT persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + assert "new_col" in after_cols, ( + "Thrift cursor.columns() should see concurrent DDL " + "(non-transactional behavior)" ) - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 0, "Rolled back data should not be persisted" - finally: - verify_conn.close() - - # ==================== SEQUENTIAL TRANSACTION TESTS ==================== - - def test_multiple_sequential_transactions(self): - """Test executing multiple sequential transactions (commit, commit, rollback).""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") - cursor.close() - self.connection.commit() - - # Second transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") - cursor.close() - self.connection.commit() - - # Third transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") - cursor.close() - self.connection.rollback() + assert before_cols != after_cols + conn.rollback() + + @pytest.mark.xdist_group(name="mst_freshness_tables") + def test_cursor_tables_non_transactional_after_concurrent_create( + self, mst_conn_params, mst_table, mst_catalog, mst_schema + ): + """Thrift cursor.tables() bypasses MST — sees concurrent CREATE TABLE.""" + fq_table, _ = mst_table + new_table_name = _unique_table_name_raw("freshness_new_tbl") + fq_new_table = f"{mst_catalog}.{mst_schema}.{new_table_name}" - # Verify only first two transactions persisted - verify_conn = sql.connect(**self.connection_params) try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" - ) - result = verify_cursor.fetchone() - assert result[0] == 2, "Should have 2 committed rows" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") - result = verify_cursor.fetchone() - assert result[0] == 0, "Rolled back row should not exist" - verify_cursor.close() + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") + cursor.tables( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=new_table_name, + ) + assert len(cursor.fetchall()) == 0 + + # External connection creates the table + with sql.connect(**mst_conn_params) as ext_conn: + with ext_conn.cursor() as ext_cursor: + ext_cursor.execute( + f"CREATE TABLE {fq_new_table} (id INT) USING DELTA " + f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" + ) + + # Re-read in same txn — should see the new table + with conn.cursor() as cursor: + cursor.tables( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=new_table_name, + ) + assert len(cursor.fetchall()) > 0, ( + "Thrift cursor.tables() should see concurrent CREATE TABLE " + "(non-transactional behavior)" + ) + conn.rollback() finally: - verify_conn.close() - - def test_auto_start_transaction_after_commit(self): - """Test that new transaction automatically starts after commit.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False + try: + with sql.connect(**mst_conn_params) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {fq_new_table}") + except Exception as e: + logger.warning(f"Failed to drop {fq_new_table}: {e}") - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.commit() - # New transaction should start automatically - insert and rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.rollback() +# ==================== D. BLOCKED SQL (MSTCheckRule) ==================== - # Verify: first committed, second rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 1, "First insert should be committed" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 0, "Second insert should be rolled back" - verify_cursor.close() - finally: - verify_conn.close() - def test_auto_start_transaction_after_rollback(self): - """Test that new transaction automatically starts after rollback.""" - fq_table_name = self._get_fully_qualified_table_name() +class TestMstBlockedSql: + """SQL introspection statements inside active transactions. - self.connection.autocommit = False + The server restricts MST to a specific allowlist of commands. The error + message from TRANSACTION_NOT_SUPPORTED.COMMAND is explicit: + "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." - # First transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.rollback() + Blocked (throw + abort txn): + - SHOW COLUMNS, SHOW TABLES, SHOW SCHEMAS, SHOW CATALOGS, SHOW FUNCTIONS + - DESCRIBE QUERY, DESCRIBE TABLE EXTENDED + - SELECT FROM information_schema - # New transaction should start automatically - insert and commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.commit() + Allowed: + - DESCRIBE TABLE (basic form — explicitly listed in server's allowlist) + """ - # Verify: first rolled back, second committed - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 0, "First insert should be rolled back" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 1, "Second insert should be committed" - verify_cursor.close() - finally: - verify_conn.close() + def _assert_blocked_and_txn_aborted(self, mst_conn_params, fq_table, blocked_sql): + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before_blocked')") - # ==================== UPDATE/DELETE OPERATION TESTS ==================== + with pytest.raises(Exception): + cursor.execute(blocked_sql) - def test_update_in_transaction(self): - """Test UPDATE operation in transaction.""" - fq_table_name = self._get_fully_qualified_table_name() - - # First insert a row with autocommit - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + with pytest.raises(Exception): + cursor.execute( + f"INSERT INTO {fq_table} VALUES (2, 'after_blocked')" + ) + try: + conn.rollback() + except Exception: + pass + + def _assert_not_blocked(self, mst_conn_params, fq_table, allowed_sql): + """Assert the SQL succeeds and returns rows inside an active txn.""" + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before')") + cursor.execute(allowed_sql) + rows = cursor.fetchall() + assert len(rows) > 0 + conn.rollback() + + def test_show_tables_blocked(self, mst_conn_params, mst_table, mst_catalog, mst_schema): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"SHOW TABLES IN {mst_catalog}.{mst_schema}" ) - cursor.close() - # Start transaction and update - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") - cursor.close() - self.connection.commit() - - # Verify update persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == "updated", "Value should be updated after commit" - verify_cursor.close() - finally: - verify_conn.close() - - # ==================== MULTI-TABLE TRANSACTION TESTS ==================== - - def test_multi_table_transaction_commit(self): - """Test atomic commit across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + def test_show_schemas_blocked(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"SHOW SCHEMAS IN {mst_catalog}" ) - cursor.close() - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" - ) - cursor.close() + def test_show_catalogs_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, "SHOW CATALOGS" + ) - # Commit both atomically - self.connection.commit() + def test_show_functions_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, "SHOW FUNCTIONS" + ) - # Verify both inserts persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() + def test_describe_table_extended_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"DESCRIBE TABLE EXTENDED {fq_table}" + ) - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table1 insert should be committed" + def test_information_schema_blocked(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, + fq_table, + f"SELECT * FROM {mst_catalog}.information_schema.columns LIMIT 1", + ) - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table2 insert should be committed" + def test_show_columns_blocked(self, mst_conn_params, mst_table): + """SHOW COLUMNS is blocked in MST (ShowDeltaTableColumnsCommand).""" + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, fq_table, f"SHOW COLUMNS IN {fq_table}" + ) - verify_cursor.close() - finally: - verify_conn.close() + def test_describe_query_blocked(self, mst_conn_params, mst_table): + """DESCRIBE QUERY is blocked in MST (DescribeQueryCommand).""" + fq_table, _ = mst_table + self._assert_blocked_and_txn_aborted( + mst_conn_params, + fq_table, + f"DESCRIBE QUERY SELECT * FROM {fq_table}", + ) - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - def test_multi_table_transaction_rollback(self): - """Test atomic rollback across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ + # DESCRIBE TABLE is explicitly listed as an allowed command in the server's + # TRANSACTION_NOT_SUPPORTED.COMMAND error message: + # "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." + def test_describe_table_not_blocked(self, mst_conn_params, mst_table): + """DESCRIBE TABLE succeeds in MST — explicitly allowed by the server.""" + fq_table, _ = mst_table + self._assert_not_blocked( + mst_conn_params, fq_table, f"DESCRIBE TABLE {fq_table}" ) - cursor.close() - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" - ) - cursor.close() +# ==================== E. EXECUTE VARIANTS ==================== - # Rollback both atomically - self.connection.rollback() - # Verify both inserts were rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() +class TestMstExecuteVariants: + """Execute method variants (executemany) inside MST.""" - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + def test_executemany_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.executemany( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + [ + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + {"id": 3, "value": "c"}, + ], ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table1 insert should be rolled back" - - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + conn.commit() + + assert _get_row_count(mst_conn_params, fq_table) == 3 + + def test_executemany_rollback_in_txn(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + with sql.connect(**mst_conn_params) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + cursor.executemany( + f"INSERT INTO {fq_table} VALUES (%(id)s, %(value)s)", + [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}], ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table2 insert should be rolled back" + conn.rollback() - verify_cursor.close() - finally: - verify_conn.close() - - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - # ==================== ERROR HANDLING TESTS ==================== - - def test_set_autocommit_during_active_transaction(self): - """Test that setting autocommit during an active transaction throws error.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") - cursor.close() - - # Try to set autocommit=True during active transaction - with pytest.raises(TransactionError) as exc_info: - self.connection.autocommit = True - - # Verify error message mentions autocommit or active transaction - error_msg = str(exc_info.value).lower() - assert ( - "autocommit" in error_msg or "active transaction" in error_msg - ), "Error should mention autocommit or active transaction" - - # Cleanup - rollback the transaction - self.connection.rollback() - - def test_commit_without_active_transaction_throws_error(self): - """Test that commit() throws error when autocommit=true (no active transaction).""" - # Ensure autocommit is true (default) - assert self.connection.autocommit is True - - # Attempt commit without active transaction should throw - with pytest.raises(TransactionError) as exc_info: - self.connection.commit() - - # Verify error message indicates no active transaction - error_message = str(exc_info.value) - assert ( - "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message - or "no active transaction" in error_message.lower() - ), "Error should indicate no active transaction" - - def test_rollback_without_active_transaction_is_safe(self): - """Test that rollback() without active transaction is a safe no-op.""" - # With autocommit=true (no active transaction) - assert self.connection.autocommit is True - - # ROLLBACK should be safe (no exception) - self.connection.rollback() - - # Verify connection is still usable - assert self.connection.autocommit is True - assert self.connection.open is True - - # ==================== TRANSACTION ISOLATION TESTS ==================== - - def test_get_transaction_isolation_returns_repeatable_read(self): - """Test that get_transaction_isolation() returns REPEATABLE_READ.""" - isolation_level = self.connection.get_transaction_isolation() - assert ( - isolation_level == "REPEATABLE_READ" - ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" - - def test_set_transaction_isolation_accepts_repeatable_read(self): - """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" - # Should not raise - these are all valid formats - self.connection.set_transaction_isolation("REPEATABLE_READ") - self.connection.set_transaction_isolation("REPEATABLE READ") - self.connection.set_transaction_isolation("repeatable_read") - self.connection.set_transaction_isolation("repeatable read") - - def test_set_transaction_isolation_rejects_unsupported_level(self): - """Test that set_transaction_isolation() rejects unsupported levels.""" - with pytest.raises(NotSupportedError) as exc_info: - self.connection.set_transaction_isolation("READ_COMMITTED") - - error_message = str(exc_info.value) - assert "not supported" in error_message.lower() - assert "READ_COMMITTED" in error_message + assert _get_row_count(mst_conn_params, fq_table) == 0 From ff921ed2d48a35ac5027c845efc11c90a7c7b7b1 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 21 Apr 2026 21:19:39 +0530 Subject: [PATCH 37/39] Add SPOG routing support for account-level vanity URLs (#767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add SPOG routing support for account-level vanity URLs SPOG replaces per-workspace hostnames with account-level URLs. When httpPath contains ?o=, the connector now extracts the workspace ID and injects x-databricks-org-id as an HTTP header on all non-OAuth endpoints (SEA, telemetry, feature flags). Changes: - Fix warehouse ID regex to stop at query params ([^?&]+ instead of .+) - Extract ?o= from httpPath once during session init, store as _spog_headers - Propagate org-id header to telemetry client via extra_headers param - Propagate org-id header to feature flags client - Do NOT propagate to OAuth endpoints (they reject it with 400) Signed-off-by: Madhavendra Rathore Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore * Add debug logging for SPOG x-databricks-org-id header extraction Mirrors the JDBC driver's logging pattern. Emits at DEBUG level in three code paths of _extract_spog_headers: 1. http_path has a query string but no ?o= param — log and skip. 2. x-databricks-org-id already set by the caller (via http_headers) — log and skip (don't override explicit user header). 3. Injection happens — log the extracted workspace ID so customers diagnosing SPOG routing can confirm the header was added. Helps with customer support: when a customer reports "SPOG isn't routing correctly", they can enable DEBUG logging and immediately see whether the connector saw their ?o= value. Signed-off-by: Madhavendra Rathore Signed-off-by: Madhavendra Rathore --------- Signed-off-by: Madhavendra Rathore Signed-off-by: Madhavendra Rathore --- src/databricks/sql/backend/sea/backend.py | 5 +- src/databricks/sql/client.py | 1 + src/databricks/sql/common/feature_flag.py | 1 + src/databricks/sql/session.py | 46 +++++++++++++++++++ .../sql/telemetry/telemetry_client.py | 6 +++ tests/unit/test_sea_backend.py | 33 +++++++++++++ tests/unit/test_session.py | 44 ++++++++++++++++++ 7 files changed, 134 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ff130cd39..04c79a18b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -188,8 +188,9 @@ def _extract_warehouse_id(self, http_path: str) -> str: ValueError: If the warehouse ID cannot be extracted from the path """ - warehouse_pattern = re.compile(r".*/warehouses/(.+)") - endpoint_pattern = re.compile(r".*/endpoints/(.+)") + # [^?&]+ stops at query params (e.g. ?o= for SPOG routing) + warehouse_pattern = re.compile(r".*/warehouses/([^?&]+)") + endpoint_pattern = re.compile(r".*/endpoints/([^?&]+)") for pattern in [warehouse_pattern, endpoint_pattern]: match = pattern.match(http_path) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 2aeea175e..fe52f0c79 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -353,6 +353,7 @@ def read(self) -> Optional[OAuthToken]: host_url=self.session.host, batch_size=self.telemetry_batch_size, client_context=client_context, + extra_headers=self.session.get_spog_headers(), ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 36e4b8a02..0b2c7490b 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -113,6 +113,7 @@ def _refresh_flags(self): # Authenticate the request self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header + headers.update(self._connection.session.get_spog_headers()) response = self._http_client.request( HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30 diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 1588d9f79..65c0d6aca 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -72,6 +72,14 @@ def __init__( base_headers = [("User-Agent", self.useragent_header)] all_headers = (http_headers or []) + base_headers + # Extract ?o= from http_path for SPOG routing. + # On SPOG hosts, the httpPath contains ?o= which routes Thrift + # requests via the URL. For SEA, telemetry, and feature flags (which use + # separate endpoints), we inject x-databricks-org-id as an HTTP header. + self._spog_headers = self._extract_spog_headers(http_path, all_headers) + if self._spog_headers: + all_headers = all_headers + list(self._spog_headers.items()) + self.ssl_options = SSLOptions( # Double negation is generally a bad thing, but we have to keep backward compatibility tls_verify=not kwargs.get( @@ -136,6 +144,44 @@ def _create_backend( } return databricks_client_class(**common_args) + @staticmethod + def _extract_spog_headers(http_path, existing_headers): + """Extract ?o= from http_path and return as a header dict for SPOG routing.""" + if not http_path or "?" not in http_path: + return {} + + from urllib.parse import parse_qs + + query_string = http_path.split("?", 1)[1] + params = parse_qs(query_string) + org_id = params.get("o", [None])[0] + if not org_id: + logger.debug( + "SPOG header extraction: http_path has query string but no ?o= param, " + "skipping x-databricks-org-id injection" + ) + return {} + + # Don't override if explicitly set + if any(k == "x-databricks-org-id" for k, _ in existing_headers): + logger.debug( + "SPOG header extraction: x-databricks-org-id already set by caller, " + "not overriding with ?o=%s from http_path", + org_id, + ) + return {} + + logger.debug( + "SPOG header extraction: injecting x-databricks-org-id=%s " + "(extracted from ?o= in http_path)", + org_id, + ) + return {"x-databricks-org-id": org_id} + + def get_spog_headers(self): + """Returns SPOG routing headers (x-databricks-org-id) if ?o= was in http_path.""" + return dict(self._spog_headers) + def open(self): self._session_id = self.backend.open_session( session_configuration=self.session_configuration, diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 408162400..55d845e46 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -188,6 +188,7 @@ def __init__( executor, batch_size: int, client_context, + extra_headers: Optional[Dict[str, str]] = None, ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -195,6 +196,7 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None + self._extra_headers = extra_headers or {} # OPTIMIZATION: Use lock-free Queue instead of list + lock # Queue is thread-safe internally and has better performance under concurrency @@ -287,6 +289,8 @@ def _send_telemetry(self, events): if self._auth_provider: self._auth_provider.add_headers(headers) + headers.update(self._extra_headers) + try: logger.debug("Submitting telemetry request to thread pool") @@ -587,6 +591,7 @@ def initialize_telemetry_client( host_url, batch_size, client_context, + extra_headers=None, ): """ Initialize a telemetry client for a specific connection if telemetry is enabled. @@ -627,6 +632,7 @@ def initialize_telemetry_client( executor=TelemetryClientFactory._executor, batch_size=batch_size, client_context=client_context, + extra_headers=extra_headers, ) TelemetryClientFactory._clients[ host_url diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f71e60943..24a5e8242 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -143,6 +143,39 @@ def test_initialization(self, mock_http_client): ) assert client2.warehouse_id == "def456" + # Test with SPOG query param ?o= in http_path + client_spog = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog.warehouse_id == "abc123" + + # Test with SPOG query param on endpoints path + client_spog_ep = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/endpoints/def456?o=6051921418418893", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_ep.warehouse_id == "def456" + + # Test with multiple query params + client_spog_multi = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/1.0/warehouses/abc123?o=123&extra=val", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client_spog_multi.warehouse_id == "abc123" + # Test with custom max_download_threads client3 = SeaDatabricksClient( server_hostname="test-server.databricks.com", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index aa7e7f02b..136c99e53 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -8,6 +8,7 @@ THandleIdentifier, ) from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.session import Session import databricks.sql @@ -226,3 +227,46 @@ def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_ call_kwargs = mock_client_class.return_value.open_session.call_args[1] assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team" + + +class TestSpogHeaders: + """Unit tests for SPOG header extraction from http_path.""" + + def test_extracts_org_id_from_query_param(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", [] + ) + assert result == {"x-databricks-org-id": "6051921418418893"} + + def test_no_query_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123", [] + ) + assert result == {} + + def test_no_o_param_returns_empty(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?other=value", [] + ) + assert result == {} + + def test_empty_http_path_returns_empty(self): + result = Session._extract_spog_headers("", []) + assert result == {} + + def test_none_http_path_returns_empty(self): + result = Session._extract_spog_headers(None, []) + assert result == {} + + def test_explicit_header_takes_precedence(self): + existing = [("x-databricks-org-id", "explicit-value")] + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=6051921418418893", existing + ) + assert result == {} + + def test_multiple_query_params(self): + result = Session._extract_spog_headers( + "/sql/1.0/warehouses/abc123?o=12345&extra=val", [] + ) + assert result == {"x-databricks-org-id": "12345"} From 2926daab0803304c47c5a9a850a021d0fa67bd52 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Thu, 23 Apr 2026 11:55:39 +0530 Subject: [PATCH 38/39] Bump to version 4.2.6 (#777) Signed-off-by: Madhavendra Rathore --- CHANGELOG.md | 11 +++++++++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ba3bb1a8..fc89750d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Release History +# 4.2.6 (2026-04-22) +- Add SPOG routing support for account-level vanity URLs (databricks/databricks-sql-python#767 by @msrathore-db) +- Fix dependency_manager: handle PEP 440 ~= compatible release syntax (databricks/databricks-sql-python#776 by @vikrantpuppala) +- Bump thrift to fix deprecation warning (databricks/databricks-sql-python#733 by @Korijn) +- Add AI coding agent detection to User-Agent header (databricks/databricks-sql-python#740 by @vikrantpuppala) +- Add statement-level query_tags support for SEA backend (databricks/databricks-sql-python#754 by @sreekanth-db) +- Update PyArrow concatenation of tables to use promote_options as default (databricks/databricks-sql-python#751 by @jprakash-db) +- Fix float inference to use DoubleParameter (64-bit) instead of FloatParameter (databricks/databricks-sql-python#742 by @Shubhambhusate) +- Allow specifying query_tags as a dict upon connection creation (databricks/databricks-sql-python#749 by @jiabin-hu) +- Add query_tags parameter support for execute methods (databricks/databricks-sql-python#736 by @jiabin-hu) + # 4.2.5 (2026-02-09) - Fix feature-flag endpoint retries in gov region (databricks/databricks-sql-python#735 by @samikshya-db) - Improve telemetry lifecycle management (databricks/databricks-sql-python#734 by @msrathore-db) diff --git a/pyproject.toml b/pyproject.toml index e1ce3b73f..5e9f7f0ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.5" +version = "4.2.6" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index c9195b89f..493ffe3a2 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.5" +__version__ = "4.2.6" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From ee63b811a3c71961495450ffee0a3bd5ad856466 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 23 Apr 2026 15:37:14 +0530 Subject: [PATCH 39/39] [PECOBLR-2461] Fix test_show_columns_blocked: SHOW COLUMNS now allowed in MST (#778) The server's MSTCheckRule allowlist has been broadened to include SHOW COLUMNS (ShowDeltaTableColumnsCommand). Flip the test to assert SHOW COLUMNS succeeds inside an MST transaction, matching the pattern already used by test_describe_table_not_blocked. Other SHOW variants (SHOW SCHEMAS/TABLES/CATALOGS/FUNCTIONS), DESCRIBE QUERY, DESCRIBE TABLE EXTENDED, and information_schema remain blocked as expected. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- tests/e2e/test_transactions.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index 9bf74baad..4fb7918b9 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -624,17 +624,21 @@ def test_cursor_tables_non_transactional_after_concurrent_create( class TestMstBlockedSql: """SQL introspection statements inside active transactions. - The server restricts MST to a specific allowlist of commands. The error - message from TRANSACTION_NOT_SUPPORTED.COMMAND is explicit: + The server restricts MST to an allowlist enforced by MSTCheckRule. The + TRANSACTION_NOT_SUPPORTED.COMMAND error originally advertised only: "Only SELECT / INSERT / MERGE / UPDATE / DELETE / DESCRIBE TABLE are supported." + The server has since broadened the allowlist to include SHOW COLUMNS + (ShowDeltaTableColumnsCommand), observed on current DBSQL warehouses. + Blocked (throw + abort txn): - - SHOW COLUMNS, SHOW TABLES, SHOW SCHEMAS, SHOW CATALOGS, SHOW FUNCTIONS + - SHOW TABLES, SHOW SCHEMAS, SHOW CATALOGS, SHOW FUNCTIONS - DESCRIBE QUERY, DESCRIBE TABLE EXTENDED - SELECT FROM information_schema Allowed: - - DESCRIBE TABLE (basic form — explicitly listed in server's allowlist) + - DESCRIBE TABLE (basic form) + - SHOW COLUMNS """ def _assert_blocked_and_txn_aborted(self, mst_conn_params, fq_table, blocked_sql): @@ -704,10 +708,10 @@ def test_information_schema_blocked(self, mst_conn_params, mst_table, mst_catalo f"SELECT * FROM {mst_catalog}.information_schema.columns LIMIT 1", ) - def test_show_columns_blocked(self, mst_conn_params, mst_table): - """SHOW COLUMNS is blocked in MST (ShowDeltaTableColumnsCommand).""" + def test_show_columns_not_blocked(self, mst_conn_params, mst_table): + """SHOW COLUMNS succeeds in MST — allowed by the server's MSTCheckRule allowlist.""" fq_table, _ = mst_table - self._assert_blocked_and_txn_aborted( + self._assert_not_blocked( mst_conn_params, fq_table, f"SHOW COLUMNS IN {fq_table}" )