diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index b9035b40dbf..b9254e72699 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -513,7 +513,7 @@ def chunk_helper(lst: pd.DataFrame, n: int) -> Iterator[Tuple[int, pd.DataFrame] def parse_private_key_path( - private_key_passphrase: str, + private_key_passphrase: Optional[str] = None, key_path: Optional[str] = None, private_key_content: Optional[bytes] = None, ) -> bytes: @@ -521,14 +521,18 @@ def parse_private_key_path( if private_key_content: p_key = serialization.load_pem_private_key( private_key_content, - password=private_key_passphrase.encode(), + password=private_key_passphrase.encode() + if private_key_passphrase is not None + else None, backend=default_backend(), ) elif key_path: with open(key_path, "rb") as key: p_key = serialization.load_pem_private_key( key.read(), - password=private_key_passphrase.encode(), + password=private_key_passphrase.encode() + if private_key_passphrase is not None + else None, backend=default_backend(), ) else: diff --git a/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py new file mode 100644 index 00000000000..8ae6ec63ba5 --- /dev/null +++ b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py @@ -0,0 +1,71 @@ +import tempfile +from typing import Optional + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +from feast.infra.utils.snowflake.snowflake_utils import parse_private_key_path + +PRIVATE_KEY_PASSPHRASE = "test" + + +def _pem_private_key(passphrase: Optional[str]): + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=( + serialization.BestAvailableEncryption(passphrase.encode()) + if passphrase + else serialization.NoEncryption() + ), + ) + + +@pytest.fixture +def unencrypted_private_key(): + return _pem_private_key(None) + + +@pytest.fixture +def encrypted_private_key(): + return _pem_private_key(PRIVATE_KEY_PASSPHRASE) + + +def test_parse_private_key_path_key_content_unencrypted(unencrypted_private_key): + parse_private_key_path( + None, + None, + unencrypted_private_key, + ) + + +def test_parse_private_key_path_key_content_encrypted(encrypted_private_key): + parse_private_key_path( + PRIVATE_KEY_PASSPHRASE, + None, + encrypted_private_key, + ) + + +def test_parse_private_key_path_key_path_unencrypted(unencrypted_private_key): + with tempfile.NamedTemporaryFile(mode="wb") as f: + f.write(unencrypted_private_key) + f.flush() + parse_private_key_path( + None, + f.name, + None, + ) + + +def test_parse_private_key_path_key_path_encrypted(encrypted_private_key): + with tempfile.NamedTemporaryFile(mode="wb") as f: + f.write(encrypted_private_key) + f.flush() + parse_private_key_path( + PRIVATE_KEY_PASSPHRASE, + f.name, + None, + )