Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add apply test
Signed-off-by: Danny Chiao <danny@tecton.ai>
  • Loading branch information
adchia committed Mar 4, 2022
commit b4a73dfc1242ea2dd1251349692d2bd36f394731
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]:
Returns:
A list of data sources.
"""
return self._registry.list_data_sources(allow_cache=allow_cache)
return self._registry.list_data_sources(self.project, allow_cache=allow_cache)

@log_exceptions_and_usage
def get_entity(self, name: str) -> Entity:
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __eq__(self, other):
)

return (
self.bigquery_options.table_ref == other.bigquery_options.table_ref
self.name == other.name
and self.bigquery_options.table_ref == other.bigquery_options.table_ref
and self.bigquery_options.query == other.bigquery_options.query
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def __eq__(self, other):
raise TypeError("Comparisons should only involve FileSource class objects.")

return (
self.file_options.file_url == other.file_options.file_url
self.name == other.name
and self.file_options.file_url == other.file_options.file_url
and self.file_options.file_format == other.file_options.file_format
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def __eq__(self, other):
)

return (
self.redshift_options.table == other.redshift_options.table
self.name == other.name
and self.redshift_options.table == other.redshift_options.table
and self.redshift_options.schema == other.redshift_options.schema
and self.redshift_options.query == other.redshift_options.query
and self.event_timestamp_column == other.event_timestamp_column
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def __eq__(self, other):
)

return (
self.snowflake_options.database == other.snowflake_options.database
self.name == other.name
and self.snowflake_options.database == other.snowflake_options.database
and self.snowflake_options.schema == other.snowflake_options.schema
and self.snowflake_options.table == other.snowflake_options.table
and self.snowflake_options.query == other.snowflake_options.query
Expand Down
8 changes: 6 additions & 2 deletions sdk/python/feast/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,14 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]
entities.append(Entity.from_proto(entity_proto))
return entities

def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]:
def list_data_sources(
self, project: str, allow_cache: bool = False
) -> List[DataSource]:
"""
Retrieve a list of data sources from the registry

Args:
project: Filter data source based on project name
allow_cache: Whether to allow returning data sources from a cached registry

Returns:
Expand All @@ -295,7 +298,8 @@ def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]:
registry_proto = self._get_registry_proto(allow_cache=allow_cache)
data_sources = []
for data_source_proto in registry_proto.data_sources:
data_sources.append(DataSource.from_proto(data_source_proto))
if data_source_proto.project == project:
data_sources.append(DataSource.from_proto(data_source_proto))
return data_sources

def apply_data_source(
Expand Down
58 changes: 55 additions & 3 deletions sdk/python/tests/integration/registration/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@


@pytest.fixture
def local_registry():
def local_registry() -> Registry:
fd, registry_path = mkstemp()
registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
return Registry(registry_config, None)


@pytest.fixture
def gcs_registry():
def gcs_registry() -> Registry:
from google.cloud import storage

storage_client = storage.Client()
Expand All @@ -58,7 +58,7 @@ def gcs_registry():


@pytest.fixture
def s3_registry():
def s3_registry() -> Registry:
registry_config = RegistryConfig(
path=f"s3://feast-integration-tests/registries/{int(time.time() * 1000)}/registry.db",
cache_ttl_seconds=600,
Expand Down Expand Up @@ -428,6 +428,58 @@ def test_apply_feature_view_integration(test_registry):
test_registry._get_registry_proto()


@pytest.mark.integration
@pytest.mark.parametrize(
    "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")],
)
def test_apply_data_source(test_registry: Registry):
    """Check that re-applying a data source is reflected in the registry.

    The test registers a file source together with a feature view that uses
    it, changes the source's event timestamp column, applies the source
    again, and asserts that both the listed feature view's batch source and
    the listed data source compare equal to the updated object. Finally the
    registry is torn down and a reload is expected to fail.
    """
    # Data source shared by the feature view below.
    source = FileSource(
        name="test_source",
        file_format=ParquetFormat(),
        path="file://feast/*",
        event_timestamp_column="ts_col",
        created_timestamp_column="timestamp",
        date_partition_column="date_partition_col",
    )

    # Feature view backed by that source.
    feature_specs = [
        ("fs1_my_feature_1", ValueType.INT64),
        ("fs1_my_feature_2", ValueType.STRING),
        ("fs1_my_feature_3", ValueType.STRING_LIST),
        ("fs1_my_feature_4", ValueType.BYTES_LIST),
    ]
    feature_view = FeatureView(
        name="my_feature_view_1",
        features=[
            Feature(name=feature_name, dtype=feature_dtype)
            for feature_name, feature_dtype in feature_specs
        ],
        entities=["fs1_my_entity_1"],
        tags={"team": "matchmaking"},
        batch_source=source,
        ttl=timedelta(minutes=5),
    )

    project = "project"

    # Stage the data source without committing; the feature view apply commits.
    test_registry.apply_data_source(source, project, commit=False)
    test_registry.apply_feature_view(feature_view, project, commit=True)

    listed_view = test_registry.list_feature_views(project)[0]
    assert listed_view.batch_source == source

    # Mutate the source and apply it again on its own.
    source.event_timestamp_column = "new_ts_col"
    test_registry.apply_data_source(source, project)

    # The feature view read back from the registry should carry the new value ...
    listed_view = test_registry.list_feature_views(project)[0]
    assert listed_view.batch_source == source
    # ... and so should the data source listed for this project.
    listed_source = test_registry.list_data_sources(project)[0]
    assert listed_source == source

    test_registry.teardown()

    # Will try to reload registry, which will fail because the file has been deleted
    with pytest.raises(FileNotFoundError):
        test_registry._get_registry_proto()


def test_commit():
fd, registry_path = mkstemp()
registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
Expand Down