|
1 | 1 | import tempfile |
2 | 2 | from typing import Optional |
| 3 | +from unittest.mock import MagicMock |
3 | 4 |
|
4 | 5 | import pytest |
5 | 6 | from cryptography.hazmat.primitives import serialization |
6 | 7 | from cryptography.hazmat.primitives.asymmetric import rsa |
7 | 8 |
|
8 | | -from feast.infra.utils.snowflake.snowflake_utils import parse_private_key_path |
| 9 | +from feast.infra.utils.snowflake.snowflake_utils import ( |
| 10 | + execute_snowflake_statement, |
| 11 | + parse_private_key_path, |
| 12 | +) |
9 | 13 |
|
10 | 14 | PRIVATE_KEY_PASSPHRASE = "test" |
11 | 15 |
|
@@ -69,3 +73,48 @@ def test_parse_private_key_path_key_path_encrypted(encrypted_private_key): |
69 | 73 | f.name, |
70 | 74 | None, |
71 | 75 | ) |
| 76 | + |
| 77 | + |
| 78 | +class TestExecuteSnowflakeStatement: |
| 79 | + def test_empty_query_returns_cursor_without_executing(self): |
| 80 | + mock_conn = MagicMock() |
| 81 | + mock_cursor = MagicMock() |
| 82 | + mock_conn.cursor.return_value = mock_cursor |
| 83 | + |
| 84 | + result = execute_snowflake_statement(mock_conn, "") |
| 85 | + |
| 86 | + assert result is mock_cursor |
| 87 | + mock_conn.cursor.assert_called_once() |
| 88 | + mock_cursor.execute.assert_not_called() |
| 89 | + |
| 90 | + def test_whitespace_only_query_returns_cursor_without_executing(self): |
| 91 | + mock_conn = MagicMock() |
| 92 | + mock_cursor = MagicMock() |
| 93 | + mock_conn.cursor.return_value = mock_cursor |
| 94 | + |
| 95 | + result = execute_snowflake_statement(mock_conn, " \t\n ") |
| 96 | + |
| 97 | + assert result is mock_cursor |
| 98 | + mock_conn.cursor.assert_called_once() |
| 99 | + mock_cursor.execute.assert_not_called() |
| 100 | + |
| 101 | + def test_valid_query_executes_and_returns_cursor(self): |
| 102 | + mock_conn = MagicMock() |
| 103 | + mock_cursor = MagicMock() |
| 104 | + mock_executed_cursor = MagicMock() |
| 105 | + mock_conn.cursor.return_value = mock_cursor |
| 106 | + mock_cursor.execute.return_value = mock_executed_cursor |
| 107 | + |
| 108 | + result = execute_snowflake_statement(mock_conn, "SELECT 1") |
| 109 | + |
| 110 | + assert result is mock_executed_cursor |
| 111 | + mock_cursor.execute.assert_called_once_with("SELECT 1") |
| 112 | + |
| 113 | + def test_valid_query_raises_on_none_cursor(self): |
| 114 | + mock_conn = MagicMock() |
| 115 | + mock_cursor = MagicMock() |
| 116 | + mock_conn.cursor.return_value = mock_cursor |
| 117 | + mock_cursor.execute.return_value = None |
| 118 | + |
| 119 | + with pytest.raises(Exception, match="Snowflake query failed"): |
| 120 | + execute_snowflake_statement(mock_conn, "SELECT 1") |
0 commit comments