diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06c12bdc6..f113127e6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Release History
 
+# 4.1.3 (2025-09-17)
+- Query tags integration (databricks/databricks-sql-python#663 by @sreekanth-db)
+- Add variant support (databricks/databricks-sql-python#560 by @shivam2680)
+
 # 4.1.2 (2025-08-22)
 - Streaming ingestion support for PUT operation (databricks/databricks-sql-python#643 by @sreekanth-db)
 - Removed use_threads argument on concat_tables for compatibility with pyarrow<14 (databricks/databricks-sql-python#684 by @jprakash-db)
diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py
new file mode 100644
index 000000000..f615d082c
--- /dev/null
+++ b/examples/query_tags_example.py
@@ -0,0 +1,30 @@
+import os
+import databricks.sql as sql
+
+"""
+This example demonstrates how to use Query Tags.
+
+Query Tags are key-value pairs that can be attached to SQL executions and will appear
+in the system.query.history table for analytical purposes.
+
+Format: "key1:value1,key2:value2,key3:value3"
+"""
+
+print("=== Query Tags Example ===\n")
+
+with sql.connect(
+    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
+    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
+    access_token=os.getenv("DATABRICKS_TOKEN"),
+    session_configuration={
+        'QUERY_TAGS': 'team:engineering,test:query-tags',
+        'ansi_mode': False
+    }
+) as connection:
+
+    with connection.cursor() as cursor:
+        cursor.execute("SELECT 1")
+        result = cursor.fetchone()
+        print(f" Result: {result[0]}")
+
+print("\n=== Query Tags Example Complete ===")
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 6f0f74710..a1f43bc70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "databricks-sql-connector"
-version = "4.1.2"
+version = "4.1.3"
 description = "Databricks SQL Connector for Python"
 authors = ["Databricks "]
 license = "Apache-2.0"
diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py
index 31b5cbb7f..a2d444f39 100644
--- a/src/databricks/sql/__init__.py
+++ b/src/databricks/sql/__init__.py
@@ -68,7 +68,7 @@ def __repr__(self):
 DATE = DBAPITypeObject("date")
 ROWID = DBAPITypeObject()
 
-__version__ = "4.1.2"
+__version__ = "4.1.3"
 USER_AGENT_NAME = "PyDatabricksSqlConnector"
 
 # These two functions are pyhive legacy
diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py
index 46ce8c98a..61ecf969e 100644
--- a/src/databricks/sql/backend/sea/utils/constants.py
+++ b/src/databricks/sql/backend/sea/utils/constants.py
@@ -15,6 +15,7 @@
     "STATEMENT_TIMEOUT": "0",
     "TIMEZONE": "UTC",
     "USE_CACHED_RESULT": "true",
+    "QUERY_TAGS": "",
 }
diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index 02c88aa63..d2b10e718 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
 
     @staticmethod
-    def _col_to_description(col, session_id_hex=None):
+    def _col_to_description(col, field=None, session_id_hex=None):
         type_entry = col.typeDesc.types[0]
 
         if type_entry.primitiveEntry:
@@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None):
         else:
             precision, scale = None, None
 
+        # Extract variant type from field if available
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None
 
     @staticmethod
-    def _hive_schema_to_description(t_table_schema, session_id_hex=None):
+    def _hive_schema_to_description(
+        t_table_schema, schema_bytes=None, session_id_hex=None
+    ):
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
         return [
-            ThriftDatabricksClient._col_to_description(col, session_id_hex)
+            ThriftDatabricksClient._col_to_description(
+                col,
+                field_dict.get(col.columnName) if field_dict else None,
+                session_id_hex,
+            )
             for col in t_table_schema.columns
         ]
 
@@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
                 or direct_results.resultSet.hasMoreRows
             )
 
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -819,6 +841,12 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             schema_bytes = None
 
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
 
@@ -863,11 +891,6 @@ def get_execution_result(
 
         t_result_set_metadata_resp = resp.resultSetMetadata
 
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -880,6 +903,12 @@ def get_execution_result(
         else:
             schema_bytes = None
 
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index 52f6e4a2e..e04e348c9 100644
--- a/tests/e2e/test_driver.py
+++ b/tests/e2e/test_driver.py
@@ -901,7 +901,7 @@ def test_timestamps_arrow(self):
     )
     def test_multi_timestamps_arrow(self, extra_params):
         with self.cursor(
-            {"session_configuration": {"ansi_mode": False}, **extra_params}
+            {"session_configuration": {"ansi_mode": False, "query_tags": "test:multi-timestamps,driver:python"}, **extra_params}
         ) as cursor:
             query, expected = self.multi_query()
             expected = [
diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py
new file mode 100644
index 000000000..b5dc1f421
--- /dev/null
+++ b/tests/e2e/test_variant_types.py
@@ -0,0 +1,91 @@
+import pytest
+from datetime import datetime
+import json
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+from tests.e2e.test_driver import PySQLPytestTestCase
+from tests.e2e.common.predicates import pysql_supports_arrow
+
+
+@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
+class TestVariantTypes(PySQLPytestTestCase):
+    """Tests for the proper detection and handling of VARIANT type columns"""
+
+    @pytest.fixture(scope="class")
+    def variant_table(self, connection_details):
+        """A pytest fixture that creates a test table and cleans up after tests"""
+        self.arguments = connection_details.copy()
+        table_name = "pysql_test_variant_types_table"
+
+        with self.cursor() as cursor:
+            try:
+                # Create the table with variant columns
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
+                        id INTEGER,
+                        variant_col VARIANT,
+                        regular_string_col STRING
+                    )
+                    """
+                )
+
+                # Insert test records with different variant values
+                cursor.execute(
+                    """
+                    INSERT INTO pysql_test_variant_types_table
+                    VALUES
+                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
+                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
+                    """
+                )
+                yield table_name
+            finally:
+                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+
+    def test_variant_type_detection(self, variant_table):
+        """Test that VARIANT type columns are properly detected in schema"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
+
+            # Verify column types in description
+            assert (
+                cursor.description[0][1] == "int"
+            ), "Integer column type not correctly identified"
+            assert (
+                cursor.description[1][1] == "variant"
+            ), "VARIANT column type not correctly identified"
+            assert (
+                cursor.description[2][1] == "string"
+            ), "String column type not correctly identified"
+
+    def test_variant_data_retrieval(self, variant_table):
+        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
+            rows = cursor.fetchall()
+
+            # First row should have a JSON object
+            json_obj = rows[0][1]
+            assert isinstance(
+                json_obj, str
+            ), "VARIANT column should be returned as string"
+
+            parsed = json.loads(json_obj)
+            assert parsed.get("name") == "John"
+            assert parsed.get("age") == 30
+
+            # Second row should have a JSON array
+            json_array = rows[1][1]
+            assert isinstance(
+                json_array, str
+            ), "VARIANT array should be returned as string"
+
+            # Parsing to verify it's valid JSON array
+            parsed_array = json.loads(json_array)
+            assert isinstance(parsed_array, list)
+            assert parsed_array == [1, 2, 3, 4]
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index f604f2874..26a898cb8 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -185,6 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
         session_config = {
             "ANSI_MODE": "FALSE",  # Supported parameter
             "STATEMENT_TIMEOUT": "3600",  # Supported parameter
+            "QUERY_TAGS": "team:marketing,dashboard:abc123",  # Supported parameter
             "unsupported_param": "value",  # Unsupported parameter
         }
         catalog = "test_catalog"
@@ -196,6 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
             "session_confs": {
                 "ansi_mode": "FALSE",
                 "statement_timeout": "3600",
+                "query_tags": "team:marketing,dashboard:abc123",
             },
             "catalog": catalog,
             "schema": schema,
@@ -641,6 +643,7 @@ def test_filter_session_configuration(self):
             "TIMEZONE": "UTC",
             "enable_photon": False,
             "MAX_FILE_PARTITION_BYTES": 128.5,
+            "QUERY_TAGS": "team:engineering,project:data-pipeline",
             "unsupported_param": "value",
             "ANOTHER_UNSUPPORTED": 42,
         }
@@ -663,6 +666,7 @@ def test_filter_session_configuration(self):
             "timezone": "UTC",  # string -> "UTC", key lowercased
             "enable_photon": "False",  # boolean False -> "False", key lowercased
             "max_file_partition_bytes": "128.5",  # float -> "128.5", key lowercased
+            "query_tags": "team:engineering,project:data-pipeline",
         }
 
         assert result == expected_result
@@ -683,12 +687,14 @@ def test_filter_session_configuration(self):
             "ansi_mode": "false",  # lowercase key
             "STATEMENT_TIMEOUT": 7200,  # uppercase key
             "TiMeZoNe": "America/New_York",  # mixed case key
+            "QueRy_TaGs": "team:marketing,test:case-insensitive",
         }
         result = _filter_session_configuration(case_insensitive_config)
         expected_case_result = {
             "ansi_mode": "false",
             "statement_timeout": "7200",
             "timezone": "America/New_York",
+            "query_tags": "team:marketing,test:case-insensitive",
         }
 
         assert result == expected_case_result
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index e019e05a2..c135a846b 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -155,7 +155,7 @@ def test_socket_timeout_passthrough(self, mock_client_class):
 
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
     def test_configuration_passthrough(self, mock_client_class):
-        mock_session_config = Mock()
+        mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"}
         databricks.sql.connect(
             session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
         )
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 0445ace3e..7254b66cb 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
                 [],
                 auth_provider=AuthProvider(),
                 ssl_options=SSLOptions(),
-                http_client=MagicMock(), 
+                http_client=MagicMock(),
                 **complex_arg_types,
            )
            thrift_backend.execute_command(
@@ -2356,6 +2356,86 @@ def test_execute_command_sets_complex_type_fields_correctly(
                 t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
             )
 
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_col_to_description(self):
+        test_cases = [
+            ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
+            ("normal_col", {}, "string"),
+            ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
+            ("missing_field", None, "string"),  # None field case
+        ]
+
+        for column_name, field_metadata, expected_type in test_cases:
+            with self.subTest(column_name=column_name, expected_type=expected_type):
+                col = ttypes.TColumnDesc(
+                    columnName=column_name,
+                    typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+                )
+
+                field = (
+                    None
+                    if field_metadata is None
+                    else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
+                )
+
+                result = ThriftDatabricksClient._col_to_description(col, field)
+
+                self.assertEqual(result[0], column_name)
+                self.assertEqual(result[1], expected_type)
+                self.assertIsNone(result[2])
+                self.assertIsNone(result[3])
+                self.assertIsNone(result[4])
+                self.assertIsNone(result[5])
+                self.assertIsNone(result[6])
+
+    @unittest.skipIf(pyarrow is None, "Requires pyarrow")
+    def test_hive_schema_to_description(self):
+        test_cases = [
+            (
+                [
+                    ("regular_col", ttypes.TTypeId.STRING_TYPE),
+                    ("variant_col", ttypes.TTypeId.STRING_TYPE),
+                ],
+                [
+                    ("regular_col", {}),
+                    ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
+                ],
+                [("regular_col", "string"), ("variant_col", "variant")],
+            ),
+            (
+                [("regular_col", ttypes.TTypeId.STRING_TYPE)],
+                None,  # No arrow schema
+                [("regular_col", "string")],
+            ),
+        ]
+
+        for columns, arrow_fields, expected_types in test_cases:
+            with self.subTest(arrow_fields=arrow_fields is not None):
+                t_table_schema = ttypes.TTableSchema(
+                    columns=[
+                        ttypes.TColumnDesc(
+                            columnName=name, typeDesc=self._make_type_desc(col_type)
+                        )
+                        for name, col_type in columns
+                    ]
+                )
+
+                schema_bytes = None
+                if arrow_fields:
+                    fields = [
+                        pyarrow.field(name, pyarrow.string(), metadata=metadata)
+                        for name, metadata in arrow_fields
+                    ]
+                    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
+
+                description = ThriftDatabricksClient._hive_schema_to_description(
+                    t_table_schema, schema_bytes
+                )
+
+                for i, (expected_name, expected_type) in enumerate(expected_types):
+                    self.assertEqual(description[i][0], expected_name)
+                    self.assertEqual(description[i][1], expected_type)
+
 
 if __name__ == "__main__":
     unittest.main()