diff --git a/.secrets.baseline b/.secrets.baseline index 74979b4f91b..e9e8acd786b 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1426,7 +1426,7 @@ "filename": "sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py", "hashed_secret": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", "is_verified": false, - "line_number": 14 + "line_number": 15 } ], "sdk/python/tests/unit/local_feast_tests/test_init.py": [ diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index 1890cb6a087..e22d4857a19 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -76,10 +76,6 @@ def __enter__(self): kwargs.update((k, v) for k, v in config_dict.items() if v is not None) - for k, v in kwargs.items(): - if k in ["role", "warehouse", "database", "schema_"]: - kwargs[k] = f'"{v}"' - kwargs["schema"] = kwargs.pop("schema_") # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation diff --git a/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py index 14b42c9783b..4bb02076354 100644 --- a/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py +++ b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py @@ -1,12 +1,13 @@ import tempfile from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, parse_private_key_path, ) @@ -75,6 +76,64 @@ def test_parse_private_key_path_key_path_encrypted(encrypted_private_key): ) +class _AttrDict(dict): + __getattr__ = dict.__getitem__ + + +def _make_config(**overrides): + defaults = { + "type": "snowflake.offline", + "account": "test_account", + "user": "test_user", + "password": "test_password", # pragma: allowlist secret + "role": "test_role", + "warehouse": "test_wh", + "database": "test_db", + "schema_": "test_schema", + "config_path": "", + } + defaults.update(overrides) + return _AttrDict(defaults) + + +@patch("feast.infra.utils.snowflake.snowflake_utils.snowflake.connector") +class TestGetSnowflakeConnectionIdentifierQuoting: + @pytest.fixture(autouse=True) + def _clear_cache(self): + with patch("feast.infra.utils.snowflake.snowflake_utils._cache", {}): + yield + + @pytest.mark.parametrize( + "config_key,connect_key,value", + [ + ("warehouse", "warehouse", "MY_WH"), + ("role", "role", "ANALYST"), + ("database", "database", "PROD_DB"), + ("schema_", "schema", "PUBLIC"), + ], + ) + def test_identifier_passed_without_quoting( + self, mock_connector, config_key, connect_key, value + ): + mock_connector.connect.return_value = MagicMock() + + with GetSnowflakeConnection(_make_config(**{config_key: value})): + pass + + kwargs = mock_connector.connect.call_args[1] + assert kwargs[connect_key] == value + + def test_schema_key_renamed_from_schema_underscore(self, mock_connector): + mock_connector.connect.return_value = MagicMock() + + with GetSnowflakeConnection(_make_config(schema_="analytics")): + pass + + kwargs = mock_connector.connect.call_args[1] + assert "schema" in kwargs + assert "schema_" not in kwargs + + class TestExecuteSnowflakeStatement: def test_empty_query_is_passed_through_to_execute(self): mock_conn = MagicMock()