From 0800f135332adf57b68a7fb95ac2faee9125f77f Mon Sep 17 00:00:00 2001
From: Abhishek Shinde
Date: Sun, 14 Jun 2026 15:38:45 +0530
Subject: [PATCH 1/2] feat: Implement Databricks Unity Catalog offline store integration

Signed-off-by: Abhishek Shinde
---
Review notes (this section is ignored by git am):
- get_databricks_session now normalizes workspace_host (surrounding
  whitespace, scheme in any case, trailing slashes), escapes backticks in
  the default catalog/schema names, and logs a warning when only one of
  workspace_host / cluster_id is configured.
- test_get_databricks_session_new_remote no longer depends on whether
  databricks-connect is installed, and covers several host spellings.
- Commit ids and blob ids in the "From"/"index" lines are carried over from
  the original series and are stale; regenerate with git format-patch.

 .../spark_offline_store/databricks_uc.py      | 337 ++++++++++++++++++
 sdk/python/feast/repo_config.py               |   1 +
 .../spark_offline_store/test_databricks_uc.py | 218 +++++++++++
 3 files changed, 556 insertions(+)
 create mode 100644 sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
 create mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py

diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
new file mode 100644
index 00000000000..e6eb3ea47e6
--- /dev/null
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
@@ -0,0 +1,337 @@
+import logging
+from datetime import date, datetime
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import pandas as pd
+import pyarrow
+import pyspark
+from pydantic import StrictStr
+from pyspark import SparkConf
+from pyspark.sql import SparkSession
+
+from feast import FeatureView
+from feast.data_source import DataSource
+from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
+    SparkOfflineStore,
+    SparkOfflineStoreConfig,
+)
+from feast.infra.offline_stores.offline_store import RetrievalJob
+from feast.infra.registry.base_registry import BaseRegistry
+from feast.repo_config import RepoConfig
+
+logger = logging.getLogger(__name__)
+
+
+class DatabricksUCOfflineStoreConfig(SparkOfflineStoreConfig):
+    """Offline store config for Databricks Unity Catalog."""
+
+    type: StrictStr = "databricks_uc"
+    """Offline store type selector"""
+
+    workspace_host: Optional[StrictStr] = None
+    """Databricks workspace host (e.g. adb-xxxx.azuredatabricks.net)"""
+
+    token: Optional[StrictStr] = None
+    """Databricks Personal Access Token (PAT)"""
+
+    cluster_id: Optional[StrictStr] = None
+    """Databricks Cluster ID to connect to for Databricks Connect"""
+
+    default_catalog: Optional[StrictStr] = None
+    """Default catalog name to use in Unity Catalog"""
+
+    default_schema: Optional[StrictStr] = None
+    """Default schema name to use in Unity Catalog"""
+
+
+def _normalize_workspace_host(workspace_host: Optional[str]) -> Optional[str]:
+    """Reduce a workspace URL to the bare host used in a Spark Connect URI.
+
+    Surrounding whitespace, a leading ``http://`` or ``https://`` (matched
+    case-insensitively) and trailing slashes are removed, so
+    ``"https://adb-1.azuredatabricks.net/"`` becomes
+    ``"adb-1.azuredatabricks.net"``. Empty values are returned as ``None``.
+    """
+    if not workspace_host:
+        return None
+    host = workspace_host.strip()
+    for scheme in ("https://", "http://"):
+        if host.lower().startswith(scheme):
+            host = host[len(scheme) :]
+            break
+    return host.rstrip("/") or None
+
+
+def _quote_identifier(identifier: str) -> str:
+    """Backtick-quote a catalog/schema name, doubling embedded backticks."""
+    return "`" + identifier.replace("`", "``") + "`"
+
+
+def get_databricks_session(
+    store_config: DatabricksUCOfflineStoreConfig,
+) -> SparkSession:
+    """Return the Spark session used for Unity Catalog access.
+
+    The active session is reused when ``SparkSession.getActiveSession()``
+    returns one. Otherwise a session is built: when both ``workspace_host``
+    and ``cluster_id`` are configured, a Spark Connect URI is passed to
+    ``builder.remote``; otherwise the builder is used as-is.
+    ``databricks.connect.DatabricksSession`` is preferred, with plain PySpark
+    as the fallback when that package cannot be imported.
+
+    On every call ``spark.sql.parser.quotedRegexColumnNames`` is set to
+    ``"true"`` and ``USE CATALOG`` / ``USE SCHEMA`` are issued for the
+    configured ``default_catalog`` / ``default_schema``.
+    """
+    # Check if there is already an active session
+    spark_session = SparkSession.getActiveSession()
+    if not spark_session:
+        workspace_host = _normalize_workspace_host(store_config.workspace_host)
+        cluster_id = store_config.cluster_id
+
+        try:
+            from databricks.connect import DatabricksSession
+
+            builder = DatabricksSession.builder
+        except ImportError:
+            # Fall back to standard PySpark if databricks-connect is not installed
+            builder = SparkSession.builder
+
+        if workspace_host and cluster_id:
+            # Databricks Connect V2 initialization (Spark Connect URI format)
+            params = []
+            if store_config.token:
+                params.append(f"token={store_config.token}")
+            params.append(f"x-databricks-cluster-id={cluster_id}")
+            conn_str = f"sc://{workspace_host}:443/;{';'.join(params)}"
+            builder = builder.remote(conn_str)
+        elif workspace_host or cluster_id:
+            # A remote connection needs both values; report the partial config
+            # instead of silently using the builder's defaults.
+            logger.warning(
+                "workspace_host and cluster_id must both be set for a remote "
+                "Databricks connection; using the default session builder."
+            )
+
+        spark_conf = store_config.spark_conf
+        if spark_conf:
+            # NOTE(review): assumes the chosen builder accepts config(conf=...)
+            # the way pyspark's does -- confirm for DatabricksSession.builder.
+            builder = builder.config(conf=SparkConf().setAll(list(spark_conf.items())))
+
+        spark_session = builder.getOrCreate()
+
+    # Apply configuration defaults
+    spark_session.conf.set("spark.sql.parser.quotedRegexColumnNames", "true")
+
+    if store_config.default_catalog:
+        catalog = _quote_identifier(store_config.default_catalog)
+        spark_session.sql(f"USE CATALOG {catalog}")
+    if store_config.default_schema:
+        schema = _quote_identifier(store_config.default_schema)
+        spark_session.sql(f"USE SCHEMA {schema}")
+
+    return spark_session
+
+
+class DatabricksUCOfflineStore(SparkOfflineStore):
+    """Spark offline store variant that prepares a Databricks session first.
+
+    Every method calls ``get_databricks_session`` before doing its work, so
+    that the session is initialized and registered as active.
+    """
+
+    @staticmethod
+    def pull_latest_from_table_or_query(
+        config: RepoConfig,
+        data_source: DataSource,
+        join_key_columns: List[str],
+        feature_name_columns: List[str],
+        timestamp_field: str,
+        created_timestamp_column: Optional[str],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> RetrievalJob:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        # Initialize/Retrieve the Databricks Spark Session so it's registered as active
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.pull_latest_from_table_or_query(
+            config=config,
+            data_source=data_source,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+    @staticmethod
+    def get_historical_features(
+        config: RepoConfig,
+        feature_views: List[FeatureView],
+        feature_refs: List[str],
+        entity_df: Optional[Union[pd.DataFrame, str, pyspark.sql.DataFrame]],
+        registry: BaseRegistry,
+        project: str,
+        full_feature_names: bool = False,
+        **kwargs,
+    ) -> RetrievalJob:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.get_historical_features(
+            config=config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=registry,
+            project=project,
+            full_feature_names=full_feature_names,
+            **kwargs,
+        )
+
+    @staticmethod
+    def pull_all_from_table_or_query(
+        config: RepoConfig,
+        data_source: DataSource,
+        join_key_columns: List[str],
+        feature_name_columns: List[str],
+        timestamp_field: str,
+        created_timestamp_column: Optional[str] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> RetrievalJob:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.pull_all_from_table_or_query(
+            config=config,
+            data_source=data_source,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.offline_write_batch(
+            config=config,
+            feature_view=feature_view,
+            table=table,
+            progress=progress,
+        )
+
+    @staticmethod
+    def compute_monitoring_metrics(
+        config: RepoConfig,
+        data_source: DataSource,
+        feature_columns: List[Tuple[str, str]],
+        timestamp_field: str,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+        histogram_bins: int = 20,
+        top_n: int = 10,
+    ) -> List[Dict[str, Any]]:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.compute_monitoring_metrics(
+            config=config,
+            data_source=data_source,
+            feature_columns=feature_columns,
+            timestamp_field=timestamp_field,
+            start_date=start_date,
+            end_date=end_date,
+            histogram_bins=histogram_bins,
+            top_n=top_n,
+        )
+
+    @staticmethod
+    def get_monitoring_max_timestamp(
+        config: RepoConfig,
+        data_source: DataSource,
+        timestamp_field: str,
+    ) -> Optional[datetime]:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.get_monitoring_max_timestamp(
+            config=config,
+            data_source=data_source,
+            timestamp_field=timestamp_field,
+        )
+
+    @staticmethod
+    def ensure_monitoring_tables(config: RepoConfig) -> None:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.ensure_monitoring_tables(config=config)
+
+    @staticmethod
+    def save_monitoring_metrics(
+        config: RepoConfig,
+        metric_type: str,
+        metrics: List[Dict[str, Any]],
+    ) -> None:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.save_monitoring_metrics(
+            config=config,
+            metric_type=metric_type,
+            metrics=metrics,
+        )
+
+    @staticmethod
+    def query_monitoring_metrics(
+        config: RepoConfig,
+        project: str,
+        metric_type: str,
+        filters: Optional[Dict[str, Any]] = None,
+        start_date: Optional[date] = None,
+        end_date: Optional[date] = None,
+    ) -> List[Dict[str, Any]]:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.query_monitoring_metrics(
+            config=config,
+            project=project,
+            metric_type=metric_type,
+            filters=filters,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+    @staticmethod
+    def clear_monitoring_baseline(
+        config: RepoConfig,
+        project: str,
+        feature_view_name: Optional[str] = None,
+        feature_name: Optional[str] = None,
+        data_source_type: Optional[str] = None,
+    ) -> None:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+
+        return SparkOfflineStore.clear_monitoring_baseline(
+            config=config,
+            project=project,
+            feature_view_name=feature_view_name,
+            feature_name=feature_name,
+            data_source_type=data_source_type,
+        )
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index 06529cea0f2..8527a3591aa 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -95,6 +95,7 @@
     "redshift": "feast.infra.offline_stores.redshift.RedshiftOfflineStore",
     "snowflake.offline": "feast.infra.offline_stores.snowflake.SnowflakeOfflineStore",
     "spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore",
+    "databricks_uc": "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.DatabricksUCOfflineStore",
     "trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore",
     "postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore",
     "athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore",
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py
new file mode 100644
index 00000000000..07bfd1d5416
--- /dev/null
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py
@@ -0,0 +1,218 @@
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic import ValidationError
+
+from feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc import (
+    DatabricksUCOfflineStore,
+    DatabricksUCOfflineStoreConfig,
+    get_databricks_session,
+)
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
+from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
+from feast.repo_config import RepoConfig
+
+
+def test_config_parsing():
+    config_dict = {
+        "type": "databricks_uc",
+        "workspace_host": "adb-12345.azuredatabricks.net",
+        "token": "dapi123456",
+        "cluster_id": "0123-4567-abcde",
+        "default_catalog": "main",
+        "default_schema": "default",
+        "spark_conf": {"spark.sql.shuffle.partitions": "10"},
+    }
+    config = DatabricksUCOfflineStoreConfig(**config_dict)
+    assert config.type == "databricks_uc"
+    assert config.workspace_host == "adb-12345.azuredatabricks.net"
+    assert config.token == "dapi123456"
+    assert config.cluster_id == "0123-4567-abcde"
+    assert config.default_catalog == "main"
+    assert config.default_schema == "default"
+    assert config.spark_conf == {"spark.sql.shuffle.partitions": "10"}
+
+
+def test_config_forbidden_extra():
+    with pytest.raises(ValidationError):
+        DatabricksUCOfflineStoreConfig(type="databricks_uc", invalid_key="some_val")
+
+
+@patch("pyspark.sql.SparkSession.getActiveSession")
+def test_get_databricks_session_active(mock_get_active):
+    mock_session = MagicMock()
+    mock_get_active.return_value = mock_session
+
+    config = DatabricksUCOfflineStoreConfig(
+        type="databricks_uc",
+        default_catalog="my_catalog",
+        default_schema="my_schema",
+    )
+
+    session = get_databricks_session(config)
+
+    assert session == mock_session
+    mock_session.conf.set.assert_called_once_with(
+        "spark.sql.parser.quotedRegexColumnNames", "true"
+    )
+    mock_session.sql.assert_any_call("USE CATALOG `my_catalog`")
+    mock_session.sql.assert_any_call("USE SCHEMA `my_schema`")
+
+
+@pytest.mark.parametrize(
+    "workspace_host",
+    [
+        "adb-12345.azuredatabricks.net",
+        "https://adb-12345.azuredatabricks.net",
+        "HTTPS://adb-12345.azuredatabricks.net/",
+    ],
+)
+@patch("pyspark.sql.SparkSession.getActiveSession")
+@patch("pyspark.sql.SparkSession.builder")
+def test_get_databricks_session_new_remote(
+    mock_builder, mock_get_active, workspace_host
+):
+    mock_get_active.return_value = None
+    mock_session = MagicMock()
+    mock_builder.remote.return_value.config.return_value.getOrCreate.return_value = (
+        mock_session
+    )
+
+    config = DatabricksUCOfflineStoreConfig(
+        type="databricks_uc",
+        workspace_host=workspace_host,
+        token="dapi123",
+        cluster_id="0123-4567-abcde",
+        spark_conf={"spark.some.option": "value"},
+    )
+
+    # Hide databricks-connect so the plain PySpark builder patched above is the
+    # one exercised whether or not the optional package is installed.
+    with patch.dict("sys.modules", {"databricks.connect": None}):
+        session = get_databricks_session(config)
+
+    assert session == mock_session
+    mock_builder.remote.assert_called_once_with(
+        "sc://adb-12345.azuredatabricks.net:443/;token=dapi123;x-databricks-cluster-id=0123-4567-abcde"
+    )
+
+
+@patch("pyspark.sql.SparkSession.getActiveSession")
+def test_get_databricks_session_escapes_backticks(mock_get_active):
+    mock_session = MagicMock()
+    mock_get_active.return_value = mock_session
+
+    get_databricks_session(
+        DatabricksUCOfflineStoreConfig(type="databricks_uc", default_catalog="a`b")
+    )
+
+    mock_session.sql.assert_called_once_with("USE CATALOG `a``b`")
+
+
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.get_databricks_session"
+)
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore.get_historical_features"
+)
+def test_get_historical_features_delegation(mock_parent_features, mock_get_session):
+    mock_session = MagicMock()
+    mock_get_session.return_value = mock_session
+
+    repo_config = RepoConfig(
+        registry="file:///tmp/registry.db",
+        project="test",
+        provider="local",
+        online_store=SqliteOnlineStoreConfig(type="sqlite"),
+        offline_store=DatabricksUCOfflineStoreConfig(
+            type="databricks_uc",
+            workspace_host="adb-123.databricks.com",
+            cluster_id="123",
+        ),
+    )
+
+    feature_views = []
+    feature_refs = ["fv:f1"]
+    entity_df = MagicMock()
+    registry = MagicMock()
+
+    DatabricksUCOfflineStore.get_historical_features(
+        config=repo_config,
+        feature_views=feature_views,
+        feature_refs=feature_refs,
+        entity_df=entity_df,
+        registry=registry,
+        project="test",
+    )
+
+    mock_get_session.assert_called_once_with(repo_config.offline_store)
+    mock_parent_features.assert_called_once_with(
+        config=repo_config,
+        feature_views=feature_views,
+        feature_refs=feature_refs,
+        entity_df=entity_df,
+        registry=registry,
+        project="test",
+        full_feature_names=False,
+    )
+
+
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.get_databricks_session"
+)
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore.pull_latest_from_table_or_query"
+)
+def test_pull_latest_from_table_or_query_delegation(
+    mock_parent_pull_latest, mock_get_session
+):
+    mock_session = MagicMock()
+    mock_get_session.return_value = mock_session
+
+    repo_config = RepoConfig(
+        registry="file:///tmp/registry.db",
+        project="test",
+        provider="local",
+        online_store=SqliteOnlineStoreConfig(type="sqlite"),
+        offline_store=DatabricksUCOfflineStoreConfig(
+            type="databricks_uc",
+            workspace_host="adb-123.databricks.com",
+            cluster_id="123",
+        ),
+    )
+
+    data_source = SparkSource(
+        name="test_source",
+        path="catalog.schema.table",
+        file_format="parquet",
+        timestamp_field="ts",
+    )
+
+    start_date = datetime(2023, 1, 1, tzinfo=timezone.utc)
+    end_date = datetime(2023, 1, 2, tzinfo=timezone.utc)
+
+    DatabricksUCOfflineStore.pull_latest_from_table_or_query(
+        config=repo_config,
+        data_source=data_source,
+        join_key_columns=["id"],
+        feature_name_columns=["val"],
+        timestamp_field="ts",
+        created_timestamp_column=None,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    mock_get_session.assert_called_once_with(repo_config.offline_store)
+    mock_parent_pull_latest.assert_called_once_with(
+        config=repo_config,
+        data_source=data_source,
+        join_key_columns=["id"],
+        feature_name_columns=["val"],
+        timestamp_field="ts",
+        created_timestamp_column=None,
+        start_date=start_date,
+        end_date=end_date,
+    )

From 664107f0a27da3c2adf0aadeeacd9059c2a014d4 Mon Sep 17 00:00:00 2001
From: Abhishek Shinde
Date: Tue, 16 Jun 2026 09:27:15 +0530
Subject: [PATCH 2/2] fix: Initialize Databricks session in DatabricksUCOfflineStore validation methods

Signed-off-by: Abhishek Shinde
---
 .../spark_offline_store/databricks_uc.py      | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
index e6eb3ea47e6..9c8ec25e5d3 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py
@@ -1,6 +1,6 @@
 import logging
 from datetime import date, datetime
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import pandas as pd
 import pyarrow
@@ -335,3 +335,21 @@ def clear_monitoring_baseline(
             feature_name=feature_name,
             data_source_type=data_source_type,
         )
+
+    @staticmethod
+    def validate_data_source(
+        config: RepoConfig,
+        data_source: DataSource,
+    ):
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+        data_source.validate(config=config)
+
+    @staticmethod
+    def get_table_column_names_and_types_from_data_source(
+        config: RepoConfig,
+        data_source: DataSource,
+    ) -> Iterable[Tuple[str, str]]:
+        assert isinstance(config.offline_store, DatabricksUCOfflineStoreConfig)
+        get_databricks_session(config.offline_store)
+        return data_source.get_table_column_names_and_types(config=config)