Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix Snowflake proto conversion and add test
Signed-off-by: Felix Wang <wangfelix98@gmail.com>
  • Loading branch information
felixwang9817 committed Apr 23, 2022
commit da33eb4079c7687e8486fc15f60c03037fb85df7
81 changes: 14 additions & 67 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
"""
if table is None and query is None:
raise ValueError('No "table" argument provided.')

# The default Snowflake schema is named "PUBLIC".
_schema = "PUBLIC" if (database and table and not schema) else schema

Expand Down Expand Up @@ -112,6 +113,7 @@ def from_proto(data_source: DataSourceProto):
A SnowflakeSource object based on the data_source protobuf.
"""
return SnowflakeSource(
name=data_source.name,
field_mapping=dict(data_source.field_mapping),
database=data_source.snowflake_options.database,
schema=data_source.snowflake_options.schema,
Expand All @@ -136,18 +138,12 @@ def __eq__(self, other):
)

return (
self.name == other.name
and self.snowflake_options.database == other.snowflake_options.database
and self.snowflake_options.schema == other.snowflake_options.schema
and self.snowflake_options.table == other.snowflake_options.table
and self.snowflake_options.query == other.snowflake_options.query
and self.snowflake_options.warehouse == other.snowflake_options.warehouse
and self.timestamp_field == other.timestamp_field
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
and self.description == other.description
and self.tags == other.tags
and self.owner == other.owner
super().__eq__(other)
and self.database == other.database
and self.schema == other.schema
and self.table == other.table
and self.query == other.query
and self.warehouse == other.warehouse
)

@property
Expand Down Expand Up @@ -183,6 +179,7 @@ def to_proto(self) -> DataSourceProto:
A DataSourceProto object.
"""
data_source_proto = DataSourceProto(
name=self.name,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this the main bug leading to duplicate data sources?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes this was the root cause, see #2581 (comment)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

type=DataSourceProto.BATCH_SNOWFLAKE,
field_mapping=self.field_mapping,
snowflake_options=self.snowflake_options.to_proto(),
Expand Down Expand Up @@ -263,61 +260,11 @@ def __init__(
query: Optional[str],
warehouse: Optional[str],
):
self._database = database
self._schema = schema
self._table = table
self._query = query
self._warehouse = warehouse

@property
def query(self):
"""Returns the snowflake SQL query referenced by this source."""
return self._query

@query.setter
def query(self, query):
"""Sets the snowflake SQL query referenced by this source."""
self._query = query

@property
def database(self):
"""Returns the database name of this snowflake table."""
return self._database

@database.setter
def database(self, database):
"""Sets the database ref of this snowflake table."""
self._database = database

@property
def schema(self):
"""Returns the schema name of this snowflake table."""
return self._schema

@schema.setter
def schema(self, schema):
"""Sets the schema of this snowflake table."""
self._schema = schema

@property
def table(self):
"""Returns the table name of this snowflake table."""
return self._table

@table.setter
def table(self, table):
"""Sets the table ref of this snowflake table."""
self._table = table

@property
def warehouse(self):
"""Returns the warehouse name of this snowflake table."""
return self._warehouse

@warehouse.setter
def warehouse(self, warehouse):
"""Sets the warehouse name of this snowflake table."""
self._warehouse = warehouse
self.database = database or ""
self.schema = schema or ""
self.table = table or ""
self.query = query or ""
self.warehouse = warehouse or ""

@classmethod
def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
Expand Down
18 changes: 18 additions & 0 deletions sdk/python/tests/unit/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from feast.field import Field
from feast.infra.offline_stores.bigquery_source import BigQuerySource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
from feast.types import Bool, Float32, Int64


Expand Down Expand Up @@ -145,3 +146,20 @@ def test_default_data_source_kw_arg_warning():
message_format=ProtoFormat("class_path"),
topic="topic",
)


def test_proto_conversion():
snowflake_source = SnowflakeSource(
name="test_source",
database="test_database",
warehouse="test_warehouse",
schema="test_schema",
table="test_table",
timestamp_field="event_timestamp",
created_timestamp_column="created_timestamp",
field_mapping={"foo": "bar"},
description="test description",
owner="test@gmail.com",
)

assert SnowflakeSource.from_proto(snowflake_source.to_proto()) == snowflake_source