Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix lint errors
Signed-off-by: Danny Chiao <danny@tecton.ai>
Signed-off-by: sfc-gh-madkins <miles.adkins@snowflake.com>
  • Loading branch information
adchia authored and sfc-gh-madkins committed Jan 31, 2022
commit 7c24117f032af8e8dfdc2859e87f2428b86df8ba
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,4 @@ def __init__(self, e: KeyError):

class SnowflakeQueryUnknownError(Exception):
def __init__(self, query: str):
super().__init__(f"Snowflake query failed: {query}")
super().__init__(f"Snowflake query failed: {query}")
8 changes: 4 additions & 4 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Optional,
Tuple,
Union,
cast,
)

import numpy as np
Expand Down Expand Up @@ -212,7 +213,7 @@ def get_historical_features(
)

entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df, entity_df_event_timestamp_col, snowflake_conn, config,
entity_df, entity_df_event_timestamp_col, snowflake_conn,
)

@contextlib.contextmanager
Expand Down Expand Up @@ -357,7 +358,7 @@ def to_sql(self) -> str:
with self._query_generator() as query:
return query

def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> list:
def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> Optional[List]:
with self._query_generator() as query:

arrow_batches = execute_snowflake_statement(
Expand Down Expand Up @@ -436,7 +437,6 @@ def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
snowflake_conn: SnowflakeConnection,
config: RepoConfig,
) -> Tuple[datetime, datetime]:
if isinstance(entity_df, pd.DataFrame):
entity_df_event_timestamp = entity_df.loc[
Expand All @@ -456,7 +456,7 @@ def _get_entity_df_event_timestamp_range(
query = f'SELECT MIN("{entity_df_event_timestamp_col}") AS "min_value", MAX("{entity_df_event_timestamp_col}") AS "max_value" FROM ({entity_df})'
results = execute_snowflake_statement(snowflake_conn, query).fetchall()

entity_df_event_timestamp_range = results[0]
entity_df_event_timestamp_range = cast(Tuple[datetime, datetime], results[0])
else:
raise InvalidEntityType(type(entity_df))

Expand Down
80 changes: 32 additions & 48 deletions sdk/python/feast/infra/utils/snowflake_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,33 @@
import string
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import Iterator, Optional, Sequence, Tuple, TypeVar, Union, Dict
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)

import pandas as pd
import snowflake.connector
from sdk.python.feast.errors import (
SnowflakeIncompleteConfig,
SnowflakeQueryUnknownError,
)
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from tenacity import retry, wait_exponential, retry_if_exception_type, stop_after_attempt

from sdk.python.feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

getLogger("snowflake.connector.cursor").disabled = True
getLogger("snowflake.connector.connection").disabled = True
Expand Down Expand Up @@ -74,25 +92,7 @@ def write_pandas(
quote_identifiers: bool = True,
auto_create_table: bool = False,
create_temp_table: bool = False,
) -> Tuple[
bool,
int,
int,
Sequence[
Tuple[
str,
str,
int,
int,
int,
int,
Optional[str],
Optional[int],
Optional[int],
Optional[str],
]
],
]:
):
"""Allows users to most efficiently write back a pandas DataFrame to Snowflake.

It works by dumping the DataFrame into Parquet files, uploading them and finally copying their data into the table.
Expand Down Expand Up @@ -128,10 +128,6 @@ def write_pandas(
auto_create_table: When true, will automatically create a table with corresponding columns for each column in
the passed in DataFrame. The table will not be created if it already exists
create_temp_table: Will make the auto-created table as a temporary table

Returns:
Returns the COPY INTO command's results to verify ingestion in the form of a tuple of whether all chunks were
ingested correctly, # of chunks, # of ingested rows, and ingest's output.
"""
if database is not None and schema is None:
raise ProgrammingError(
Expand Down Expand Up @@ -197,9 +193,8 @@ def write_pandas(
result_cursor = cursor.execute(infer_schema_sql, _is_internal=True)
if result_cursor is None:
raise SnowflakeQueryUnknownError(infer_schema_sql)
column_type_mapping: Dict[str, str] = dict(
result_cursor.fetchall()
)
result = cast(List[Tuple[str, str]], result_cursor.fetchall())
column_type_mapping: Dict[str, str] = dict(result)
# Infer schema can return the columns out of order depending on the chunking we do when uploading
# so we have to iterate through the dataframe columns to make sure we create the table with its
# columns in order
Expand Down Expand Up @@ -243,14 +238,8 @@ def write_pandas(
result_cursor = cursor.execute(copy_into_sql, _is_internal=True)
if result_cursor is None:
raise SnowflakeQueryUnknownError(copy_into_sql)
copy_results = result_cursor.fetchall()
copy_results = cast(List[Tuple], result_cursor.fetchall())
result_cursor.close()
return (
all(e[1] == "LOADED" for e in copy_results),
len(copy_results),
sum(int(e[3]) for e in copy_results),
copy_results,
)


@retry(
Expand All @@ -259,11 +248,11 @@ def write_pandas(
stop=stop_after_attempt(5),
reraise=True,
)
def create_file_format(compression: str, compression_map: Dict[str, str], cursor: SnowflakeCursor) -> str:
def create_file_format(
compression: str, compression_map: Dict[str, str], cursor: SnowflakeCursor
) -> str:
file_format_name = (
'"'
+ "".join(random.choice(string.ascii_lowercase) for _ in range(5))
+ '"'
'"' + "".join(random.choice(string.ascii_lowercase) for _ in range(5)) + '"'
)
file_format_sql = (
f"CREATE FILE FORMAT {file_format_name} "
Expand All @@ -282,9 +271,7 @@ def create_file_format(compression: str, compression_map: Dict[str, str], cursor
reraise=True,
)
def create_temporary_sfc_stage(cursor: SnowflakeCursor) -> str:
stage_name = "".join(
random.choice(string.ascii_lowercase) for _ in range(5)
)
stage_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
create_stage_sql = (
"create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
'"{stage_name}"'
Expand All @@ -297,10 +284,7 @@ def create_temporary_sfc_stage(cursor: SnowflakeCursor) -> str:
return stage_name


T = TypeVar("T", bound=Sequence)


def chunk_helper(lst: T, n: int) -> Iterator[Tuple[int, T]]:
def chunk_helper(lst: pd.DataFrame, n: int) -> Iterator[Tuple[int, pd.DataFrame]]:
"""Helper generator to chunk a sequence efficiently with current index like if enumerate was called on sequence."""
for i in range(0, len(lst), n):
yield int(i / n), lst[i : i + n]