diff --git a/.github/fork_workflows/fork_pr_integration_tests_snowflake.yml b/.github/fork_workflows/fork_pr_integration_tests_snowflake.yml index fe97341a5a2..0db580ce7db 100644 --- a/.github/fork_workflows/fork_pr_integration_tests_snowflake.yml +++ b/.github/fork_workflows/fork_pr_integration_tests_snowflake.yml @@ -72,7 +72,7 @@ jobs: SNOWFLAKE_CI_WAREHOUSE: ${{ secrets.SNOWFLAKE_CI_WAREHOUSE }} # Run only Snowflake BigQuery and File tests without dynamo and redshift tests. run: | - pytest -n 8 --cov=./ --cov-report=xml --color=yes sdk/python/tests --integration --durations=5 --timeout=1200 --timeout_method=thread -k "Snowflake and not dynamo and not Redshift and not Bigquery and not gcp and not minio_registry" + pytest -n 8 --cov=./ --cov-report=xml --color=yes sdk/python/tests --integration --durations=5 --timeout=1200 --timeout_method=thread -k "(Snowflake or snowflake_registry) and not dynamo and not Redshift and not Bigquery and not gcp and not minio_registry" pytest -n 8 --cov=./ --cov-report=xml --color=yes sdk/python/tests --integration --durations=5 --timeout=1200 --timeout_method=thread -k "File and not dynamo and not Redshift and not Bigquery and not gcp and not minio_registry" - name: Minimize uv cache run: uv cache prune --ci diff --git a/.secrets.baseline b/.secrets.baseline index 44c3ddc52b1..610429b78cf 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1340,7 +1340,7 @@ "filename": "sdk/python/tests/integration/registration/test_universal_registry.py", "hashed_secret": "53e9042a36213bf85ef29a4371896aef8ba9196a", "is_verified": false, - "line_number": 126 + "line_number": 127 } ], "sdk/python/tests/integration/rest_api/resource/feast_config_rhoai.yaml": [ diff --git a/Makefile b/Makefile index c9730be8131..d59d730c114 100644 --- a/Makefile +++ b/Makefile @@ -282,6 +282,7 @@ test-python-universal-spark: ## Run Python Spark integration tests not test_push_features_to_offline_store.py and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -316,6 +317,7 @@ test-python-universal-trino: ## Run Python Trino integration tests not test_push_features_to_offline_store.py and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -331,6 +333,7 @@ test-python-universal-mssql: ## Run Python MSSQL integration tests python -m pytest -n 8 --integration \ -k "not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_lambda_materialization and \ not test_snowflake and \ not test_historical_features_persisting and \ @@ -364,6 +367,7 @@ test-python-universal-athena: ## Run Python Athena integration tests not test_historical_retrieval_fails_on_validation and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake" \ sdk/python/tests @@ -384,6 +388,7 @@ test-python-universal-postgres-offline: ## Run Python Postgres integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake and \ not test_spark" \ sdk/python/tests @@ -405,6 +410,7 @@ test-python-universal-postgres-offline: ## Run Python Postgres integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake and \ not test_spark" \ sdk/python/tests @@ -427,6 +433,7 @@ test-python-universal-postgres-online: ## Run Python Postgres integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -446,6 +453,7 @@ test-python-universal-postgres-online: ## Run Python Postgres integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_validation and \ not test_spark_materialization_consistency and \ @@ -468,6 +476,7 @@ test-python-universal-mysql-online: ## Run Python MySQL integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -498,6 +507,7 @@ test-python-universal-hazelcast: ## Run Python Hazelcast integration tests not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -516,6 +526,7 @@ test-python-universal-cassandra-no-cloud-providers: ## Run Python Cassandra inte not test_nullable_online_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake" \ sdk/python/tests @@ -534,6 +545,7 @@ test-python-universal-elasticsearch-online: ## Run Python Elasticsearch online s not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -553,6 +565,7 @@ test-python-universal-mongodb-online: ## Run Python MongoDB online store integra not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests @@ -573,6 +586,7 @@ test-python-universal-singlestore-online: ## Run Python Singlestore online store -k "not test_universal_cli and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake" \ sdk/python/tests @@ -607,6 +621,7 @@ test-python-universal-couchbase-offline: ## Run Python Couchbase offline store i not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_snowflake and \ not test_universal_types" \ sdk/python/tests @@ -626,6 +641,7 @@ test-python-universal-couchbase-online: ## Run Python Couchbase online store int not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ + not snowflake_registry and \ not test_universal_types and \ not test_snowflake" \ sdk/python/tests diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 1a48e2962ef..6382fa1c010 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -147,13 +147,13 @@ def __init__( execute_snowflake_statement(conn, query) self.purge_feast_metadata = registry_config.purge_feast_metadata - self._sync_feast_metadata_to_projects_table() - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) + self.project = project - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() + # Initialize cache state before any method that may trigger + # _refresh_cached_registry_if_necessary (e.g. proto(), get_project()). self._refresh_lock = Lock() + self.cached_registry_proto = None + self.cached_registry_proto_created = None self.cached_registry_proto_ttl = timedelta( seconds=( registry_config.cache_ttl_seconds @@ -161,11 +161,17 @@ def __init__( else 0 ) ) - self.project = project + + self._sync_feast_metadata_to_projects_table() + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = _utc_now() def _sync_feast_metadata_to_projects_table(self): - feast_metadata_projects: set = [] - projects_set: set = [] + feast_metadata_projects: set[str] = set() + projects_set: set[str] = set() with GetSnowflakeConnection(self.registry_config) as conn: query = ( @@ -185,7 +191,7 @@ def _sync_feast_metadata_to_projects_table(self): projects_set.add(row[1]["PROJECT_ID"]) # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects) - set(projects_set) + projects_to_sync = feast_metadata_projects - projects_set for project_name in projects_to_sync: self.apply_project(Project(name=project_name), commit=True) @@ -200,7 +206,7 @@ def refresh(self, project: Optional[str] = None): self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() - def _refresh_cached_registry_if_necessary(self): + def _refresh_cached_registry_if_necessary(self) -> RegistryProto: with self._refresh_lock: expired = ( self.cached_registry_proto is None @@ -221,6 +227,10 @@ def _refresh_cached_registry_if_necessary(self): logger.info("Registry cache expired, so refreshing") self.refresh() + if self.cached_registry_proto is None: + raise RuntimeError("Registry cache is unexpectedly empty after refresh") + return self.cached_registry_proto + def teardown(self): with GetSnowflakeConnection(self.registry_config) as conn: sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" @@ -521,10 +531,8 @@ def get_data_source( self, name: str, project: str, allow_cache: bool = False ) -> DataSource: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_data_source( - self.cached_registry_proto, name, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_data_source(registry_proto, name, project) return self._get_object( "DATA_SOURCES", name, @@ -538,10 +546,8 @@ def get_data_source( def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_entity( - self.cached_registry_proto, name, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_entity(registry_proto, name, project) return self._get_object( "ENTITIES", name, @@ -557,9 +563,9 @@ def get_feature_service( self, name: str, project: str, allow_cache: bool = False ) -> FeatureService: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.get_feature_service( - self.cached_registry_proto, name, project + registry_proto, name, project ) return self._get_object( "FEATURE_SERVICES", @@ -576,10 +582,8 @@ def get_feature_view( self, name: str, project: str, allow_cache: bool = False ) -> FeatureView: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_view( - self.cached_registry_proto, name, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_view(registry_proto, name, project) return self._get_object( "FEATURE_VIEWS", name, @@ -595,9 +599,9 @@ def get_any_feature_view( self, name: str, project: str, allow_cache: bool = False ) -> BaseFeatureView: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.get_any_feature_view( - self.cached_registry_proto, name, project + registry_proto, name, project ) fv = self._get_object( "FEATURE_VIEWS", @@ -641,9 +645,9 @@ def list_all_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[BaseFeatureView]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_all_feature_views( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return ( @@ -679,9 +683,9 @@ def get_on_demand_feature_view( self, name: str, project: str, allow_cache: bool = False ) -> OnDemandFeatureView: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.get_on_demand_feature_view( - self.cached_registry_proto, name, project + registry_proto, name, project ) return self._get_object( "ON_DEMAND_FEATURE_VIEWS", @@ -698,10 +702,8 @@ def get_saved_dataset( self, name: str, project: str, allow_cache: bool = False ) -> SavedDataset: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_saved_dataset( - self.cached_registry_proto, name, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_saved_dataset(registry_proto, name, project) return self._get_object( "SAVED_DATASETS", name, @@ -717,9 +719,9 @@ def get_stream_feature_view( self, name: str, project: str, allow_cache: bool = False ): if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.get_stream_feature_view( - self.cached_registry_proto, name, project + registry_proto, name, project ) return self._get_object( "STREAM_FEATURE_VIEWS", @@ -736,9 +738,9 @@ def get_validation_reference( self, name: str, project: str, allow_cache: bool = False ) -> ValidationReference: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.get_validation_reference( - self.cached_registry_proto, name, project + registry_proto, name, project ) return self._get_object( "VALIDATION_REFERENCES", @@ -787,10 +789,8 @@ def get_permission( self, name: str, project: str, allow_cache: bool = False ) -> Permission: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_permission( - self.cached_registry_proto, name, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_permission(registry_proto, name, project) return self._get_object( "PERMISSIONS", name, @@ -810,10 +810,8 @@ def list_data_sources( tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project, tags - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_data_sources(registry_proto, project, tags) return self._list_objects( "DATA_SOURCES", project, @@ -830,10 +828,8 @@ def list_entities( tags: Optional[dict[str, str]] = None, ) -> List[Entity]: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_entities( - self.cached_registry_proto, project, tags - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_entities(registry_proto, project, tags) return self._list_objects( "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags ) @@ -845,9 +841,9 @@ def list_feature_services( tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return self._list_objects( "FEATURE_SERVICES", @@ -865,9 +861,9 @@ def list_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return self._list_objects( "FEATURE_VIEWS", @@ -885,9 +881,9 @@ def list_on_demand_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return self._list_objects( "ON_DEMAND_FEATURE_VIEWS", @@ -905,9 +901,9 @@ def list_saved_datasets( tags: Optional[dict[str, str]] = None, ) -> List[SavedDataset]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_saved_datasets( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return self._list_objects( "SAVED_DATASETS", @@ -925,9 +921,9 @@ def list_stream_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: if allow_cache: - self._refresh_cached_registry_if_necessary() + registry_proto = self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project, tags + registry_proto, project, tags ) return self._list_objects( "STREAM_FEATURE_VIEWS", @@ -990,10 +986,8 @@ def list_permissions( tags: Optional[dict[str, str]] = None, ) -> List[Permission]: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_permissions( - self.cached_registry_proto, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_permissions(registry_proto, project) return self._list_objects( "PERMISSIONS", project, @@ -1042,10 +1036,8 @@ def list_project_metadata( self, project: str, allow_cache: bool = False ) -> List[ProjectMetadata]: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_project_metadata( - self.cached_registry_proto, project - ) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_project_metadata(registry_proto, project) with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT @@ -1139,18 +1131,6 @@ def process_project(project: Project): project_name = project.name last_updated_timestamp = project.last_updated_timestamp - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - r.projects.extend([project.to_proto()]) last_updated_timestamps.append(last_updated_timestamp) @@ -1165,7 +1145,9 @@ def process_project(project: Project): (self.list_validation_references, r.validation_references), (self.list_permissions, r.permissions), ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore + # Always bypass cache here: proto() builds the cache, so using + # allow_cache=True would cause infinite recursion via refresh(). + objs: List[Any] = lister(project_name, False) # type: ignore if objs: obj_protos = [obj.to_proto() for obj in objs] for obj_proto in obj_protos: @@ -1354,8 +1336,8 @@ def get_project( allow_cache: bool = False, ) -> Project: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_project(self.cached_registry_proto, name) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_project(registry_proto, name) return self._get_project(name) def _list_projects( @@ -1371,7 +1353,7 @@ def _list_projects( objects = [] for row in df.iterrows(): obj = Project.from_proto( - ProjectProto.FromString(row[1]["project_proto"]) + ProjectProto.FromString(row[1]["PROJECT_PROTO"]) ) if has_all_tags(obj.tags, tags): objects.append(obj) @@ -1384,8 +1366,8 @@ def list_projects( tags: Optional[dict[str, str]] = None, ) -> List[Project]: if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_projects(self.cached_registry_proto, tags) + registry_proto = self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_projects(registry_proto, tags) return self._list_projects(tags) def set_project_metadata(self, project: str, key: str, value: str): diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index b9254e72699..1890cb6a087 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -86,9 +86,9 @@ def __enter__(self): # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication if "private_key" in kwargs or "private_key_content" in kwargs: kwargs["private_key"] = parse_private_key_path( - kwargs.get("private_key_passphrase"), - kwargs.get("private_key"), - kwargs.get("private_key_content"), + kwargs.pop("private_key_passphrase", None), + kwargs.pop("private_key", None), + kwargs.pop("private_key_content", None), ) try: diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index a08353921c5..d8322e55636 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -43,6 +43,7 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.infra.registry.registry import Registry from feast.infra.registry.remote import RemoteRegistry, RemoteRegistryConfig +from feast.infra.registry.snowflake import SnowflakeRegistry, SnowflakeRegistryConfig from feast.infra.registry.sql import SqlRegistry, SqlRegistryConfig from feast.on_demand_feature_view import on_demand_feature_view from feast.permissions.action import AuthzedAction @@ -284,6 +285,39 @@ def sqlite_registry(): yield SqlRegistry(registry_config, "project", None) +@pytest.fixture(scope="function") +def snowflake_registry(): + account = os.getenv("SNOWFLAKE_CI_DEPLOYMENT", "") + if not account: + pytest.skip("SNOWFLAKE_CI_DEPLOYMENT not set") + + config_kwargs = dict( + registry_type="snowflake.registry", + account=account, + user=os.getenv("SNOWFLAKE_CI_USER", ""), + role=os.getenv("SNOWFLAKE_CI_ROLE", ""), + warehouse=os.getenv("SNOWFLAKE_CI_WAREHOUSE", ""), + database=os.getenv("SNOWFLAKE_CI_DATABASE", "FEAST"), + schema=os.getenv("SNOWFLAKE_CI_SCHEMA", "REGISTRY_TEST"), + cache_ttl_seconds=2, + purge_feast_metadata=False, + ) + + private_key = os.getenv("SNOWFLAKE_CI_PRIVATE_KEY_PATH", "") + if private_key: + config_kwargs["private_key"] = private_key + passphrase = os.getenv("SNOWFLAKE_CI_PRIVATE_KEY_PASSPHRASE", "") + if passphrase: + config_kwargs["private_key_passphrase"] = passphrase + else: + config_kwargs["password"] = os.getenv("SNOWFLAKE_CI_PASSWORD", "") + + registry_config = SnowflakeRegistryConfig(**config_kwargs) + registry = SnowflakeRegistry(registry_config, "project", None) + yield registry + registry.teardown() + + @pytest.fixture(scope="function") def hdfs_registry(): HADOOP_NAMENODE_IMAGE = "bde2020/hadoop-namenode:2.0.0-hadoop3.2.1-java8" @@ -412,6 +446,10 @@ def mock_remote_registry(): lazy_fixture("hdfs_registry"), marks=pytest.mark.xdist_group(name="hdfs_registry"), ), + pytest.param( + lazy_fixture("snowflake_registry"), + marks=pytest.mark.xdist_group(name="snowflake_registry"), + ), ] ) else: diff --git a/sdk/python/tests/unit/infra/registry/test_snowflake_registry.py b/sdk/python/tests/unit/infra/registry/test_snowflake_registry.py index 526b2b4d35d..f1935f329b0 100644 --- a/sdk/python/tests/unit/infra/registry/test_snowflake_registry.py +++ b/sdk/python/tests/unit/infra/registry/test_snowflake_registry.py @@ -19,6 +19,7 @@ from feast.entity import Entity from feast.infra.registry.snowflake import SnowflakeRegistry, SnowflakeRegistryConfig +from feast.infra.utils.snowflake.snowflake_utils import GetSnowflakeConnection @pytest.fixture @@ -186,3 +187,144 @@ def simulated_snowflake(conn, query): f"feast#6208: UPDATE WHERE clause references {project_a!r} — unintended cross-project write.\n" f"Query: {update_query}" ) + + +class TestSyncFeastMetadataToProjectsTable: + def _make_registry(self): + """Create a SnowflakeRegistry with mocked __init__.""" + with patch.object(SnowflakeRegistry, "__init__", lambda self: None): + registry = SnowflakeRegistry() + registry.registry_config = MagicMock() + registry.registry_path = "test_db.test_schema" + registry.purge_feast_metadata = False + return registry + + @patch( + "feast.infra.registry.snowflake.GetSnowflakeConnection", + ) + @patch("feast.infra.registry.snowflake.execute_snowflake_statement") + def test_sync_with_feast_metadata_projects(self, mock_execute, mock_get_conn): + registry = self._make_registry() + + metadata_df = pd.DataFrame({"PROJECT_ID": ["project_a", "project_b"]}) + projects_df = pd.DataFrame({"PROJECT_ID": ["project_a"]}) + + mock_cursor = MagicMock() + mock_cursor.fetch_pandas_all.side_effect = [metadata_df, projects_df] + mock_execute.return_value = mock_cursor + + mock_conn = MagicMock() + mock_get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_get_conn.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(registry, "apply_project") as mock_apply: + registry._sync_feast_metadata_to_projects_table() + + mock_apply.assert_called_once() + applied_project = mock_apply.call_args[0][0] + assert applied_project.name == "project_b" + + @patch( + "feast.infra.registry.snowflake.GetSnowflakeConnection", + ) + @patch("feast.infra.registry.snowflake.execute_snowflake_statement") + def test_sync_with_no_feast_metadata(self, mock_execute, mock_get_conn): + registry = self._make_registry() + + empty_df = pd.DataFrame({"PROJECT_ID": []}) + mock_cursor = MagicMock() + mock_cursor.fetch_pandas_all.return_value = empty_df + mock_execute.return_value = mock_cursor + + mock_conn = MagicMock() + mock_get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_get_conn.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(registry, "apply_project") as mock_apply: + registry._sync_feast_metadata_to_projects_table() + + mock_apply.assert_not_called() + + @patch( + "feast.infra.registry.snowflake.GetSnowflakeConnection", + ) + @patch("feast.infra.registry.snowflake.execute_snowflake_statement") + def test_sync_deduplicates_project_ids(self, mock_execute, mock_get_conn): + """Sets should deduplicate project IDs; lists would not.""" + registry = self._make_registry() + + metadata_df = pd.DataFrame( + {"PROJECT_ID": ["project_a", "project_a", "project_b"]} + ) + projects_df = pd.DataFrame({"PROJECT_ID": []}) + + mock_cursor = MagicMock() + mock_cursor.fetch_pandas_all.side_effect = [metadata_df, projects_df] + mock_execute.return_value = mock_cursor + + mock_conn = MagicMock() + mock_get_conn.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_get_conn.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(registry, "apply_project") as mock_apply: + registry._sync_feast_metadata_to_projects_table() + + assert mock_apply.call_count == 2 + applied_names = {call[0][0].name for call in mock_apply.call_args_list} + assert applied_names == {"project_a", "project_b"} + + +class _DictableConfig: + """A config object that supports dict() conversion and attribute access.""" + + def __init__(self, data): + self._data = data + for k, v in data.items(): + setattr(self, k, v) + + def __iter__(self): + return iter(self._data) + + def keys(self): + return self._data.keys() + + def __getitem__(self, key): + return self._data[key] + + +class TestGetSnowflakeConnection: + @patch("feast.infra.utils.snowflake.snowflake_utils.parse_private_key_path") + @patch("feast.infra.utils.snowflake.snowflake_utils.snowflake.connector") + @patch("feast.infra.utils.snowflake.snowflake_utils._cache", {}) + def test_private_key_kwargs_not_leaked_to_connect( + self, mock_connector, mock_parse_key + ): + """private_key_passphrase and private_key_content must not be passed to connect().""" + mock_parse_key.return_value = b"parsed_key_bytes" + mock_conn = MagicMock() + mock_connector.connect.return_value = mock_conn + + config = _DictableConfig( + { + "type": "snowflake.registry", + "account": "test_account", + "user": "test_user", + "password": None, + "role": "test_role", + "warehouse": "test_wh", + "database": "test_db", + "schema_": "test_schema", + "config_path": "", + "private_key": "/path/to/key.p8", + "private_key_passphrase": "my_secret", # pragma: allowlist secret + "private_key_content": None, + } + ) + + with GetSnowflakeConnection(config): + pass + + connect_kwargs = mock_connector.connect.call_args[1] + assert "private_key_passphrase" not in connect_kwargs + assert "private_key_content" not in connect_kwargs + assert connect_kwargs["private_key"] == b"parsed_key_bytes"