added more type hints
Signed-off-by: Chester Ong <chester.ong.ch@gmail.com>
bushwhackr committed Feb 8, 2024
commit d9ddf1351bbf08909138f036247443766949d1fe
9 changes: 7 additions & 2 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -5,6 +5,7 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.functions import col, from_json
from pyspark.sql.streaming import StreamingQuery

from feast.data_format import AvroFormat, JsonFormat
from feast.data_source import KafkaSource, PushMode
@@ -63,7 +64,11 @@ def __init__(
self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
# Type hinting for data_source type.
# data_source type has been checked to be an instance of KafkaSource.
self.data_source: KafkaSource = self.data_source # type: ignore

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> StreamingQuery:
ingested_stream_df = self._ingest_stream_data()
transformed_df = self._construct_transformation_plan(ingested_stream_df)
online_store_query = self._write_stream_data(transformed_df, to)
@@ -122,7 +127,7 @@ def _ingest_stream_data(self) -> StreamTable:
def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
return self.sfv.udf.__call__(df) if self.sfv.udf else df

def _write_stream_data(self, df: StreamTable, to: PushMode):
def _write_stream_data(self, df: StreamTable, to: PushMode) -> StreamingQuery:
# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write(row: DataFrame, batch_id: int):
rows: pd.DataFrame = row.toPandas()
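For context, a minimal sketch (illustrative classes, not the Feast ones) of the attribute re-annotation trick used in __init__ above: the base class stores the source under a wider type, so the subclass re-declares the attribute with the narrower type it has already validated, silencing the resulting mypy assignment error with a type: ignore.

class GenericSource:
    """Stand-in for a generic data source type (illustration only)."""

class KafkaLikeSource(GenericSource):
    topic: str = "events"

class BaseProcessor:
    def __init__(self, data_source: GenericSource) -> None:
        self.data_source = data_source  # declared as GenericSource here

class KafkaLikeProcessor(BaseProcessor):
    def __init__(self, data_source: GenericSource) -> None:
        if not isinstance(data_source, KafkaLikeSource):
            raise ValueError("expected a Kafka-style source")
        super().__init__(data_source)
        # Runtime value is unchanged; the re-declaration only narrows the type
        # so that later accesses like self.data_source.topic type-check.
        self.data_source: KafkaLikeSource = self.data_source  # type: ignore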
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -51,13 +51,13 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["redshift"] = "redshift"
""" Offline store type selector"""

cluster_id: Optional[StrictStr]
cluster_id: Optional[StrictStr] = None
""" Redshift cluster identifier, for provisioned clusters """

user: Optional[StrictStr]
user: Optional[StrictStr] = None
""" Redshift user name, only required for provisioned clusters """

workgroup: Optional[StrictStr]
workgroup: Optional[StrictStr] = None
""" Redshift workgroup identifier, for serverless """

region: StrictStr
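Background for the new "= None" defaults above, as a sketch rather than a claim about this PR's motivation: assuming FeastConfigBaseModel is a pydantic model, pydantic v2 treats an Optional field without an explicit default as required, so the explicit defaults are what keep these settings truly optional. The model name below is hypothetical.

from typing import Optional

from pydantic import BaseModel, StrictStr

class ExampleOfflineStoreConfig(BaseModel):
    cluster_id: Optional[StrictStr] = None  # optional: may be omitted
    region: StrictStr  # required: has no default

config = ExampleOfflineStoreConfig(region="us-west-2")
print(config.cluster_id)  # None; omitting region instead would raise ValidationError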
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Dict, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, Optional, Tuple, Any

from typeguard import typechecked

@@ -223,7 +223,7 @@ def get_table_column_names_and_types(
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
cursor = execute_snowflake_statement(conn, query)

metadata = [
metadata: list[dict[str, Any]] = [
{
"column_name": column.name,
"type_code": column.type_code,
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/registry/snowflake.py
@@ -418,7 +418,7 @@ def _delete_object(
"""
cursor = execute_snowflake_statement(conn, query)

if cursor.rowcount < 1 and not_found_exception:
if cursor.rowcount < 1 and not_found_exception: # type: ignore
Collaborator

I tried this and it passes the mypy check:
if cursor.rowcount and (cursor.rowcount < 1) and not_found_exception:

Contributor Author

You're right, this is due to another change I made to the mypy config: --follow-imports=skip.

Run the command below to see the error:

python -m mypy --exclude=/tests/ sdk/python/feast/infra/registry/snowflake.py

`sdk/python/feast/infra/registry/snowflake.py:421: error: Unsupported operand types for > ("int" and "None")  [operator]`

Contributor Author

Essentially, cursor.rowcount can return None, per the Snowflake connector's signature:

    @property
    def rowcount(self) -> int | None:
        return self._total_rowcount if self._total_rowcount >= 0 else None

raise not_found_exception(name, project)
self._set_last_updated_metadata(datetime.utcnow(), project)

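A minimal sketch (not the Feast code) of a None-safe variant of the check discussed above, given that the connector types rowcount as int | None. Whether None should count as "not found" is a behavioural choice, so this is just one option alongside the type: ignore.

from typing import Optional

def no_rows_affected(rowcount: Optional[int]) -> bool:
    # Treat an unknown rowcount (None) the same as zero affected rows.
    return rowcount is None or rowcount < 1

assert no_rows_affected(None)
assert no_rows_affected(0)
assert not no_rows_affected(3)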
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/utils/snowflake/snowflake_utils.py
@@ -22,6 +22,8 @@
import feast
from feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError
from feast.feature_view import FeatureView
from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
from feast.infra.online_stores.snowflake import SnowflakeOnlineStoreConfig
from feast.repo_config import RepoConfig

try:
@@ -43,7 +45,7 @@


class GetSnowflakeConnection:
def __init__(self, config: str, autocommit=True):
def __init__(self, config: SnowflakeOfflineStoreConfig | SnowflakeOnlineStoreConfig, autocommit=True):
self.config = config
self.autocommit = autocommit
