diff --git a/sdk/python/feast/base_feature_view.py b/sdk/python/feast/base_feature_view.py index 67435fa44c8..80b3b0cec82 100644 --- a/sdk/python/feast/base_feature_view.py +++ b/sdk/python/feast/base_feature_view.py @@ -110,7 +110,7 @@ def __str__(self): return str(MessageToJson(self.to_proto())) def __hash__(self): - return hash((id(self), self.name)) + return hash((self.name)) def __getitem__(self, item): assert isinstance(item, list) @@ -134,6 +134,7 @@ def __eq__(self, other): if ( self.name != other.name or sorted(self.features) != sorted(other.features) + or self.projection != other.projection or self.description != other.description or self.tags != other.tags or self.owner != other.owner diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 4a3762031e4..6040654784c 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -245,7 +245,7 @@ def __init__( self.owner = owner or "" def __hash__(self): - return hash((id(self), self.name)) + return hash((self.name, self.timestamp_field)) def __str__(self): return str(MessageToJson(self.to_proto())) @@ -263,9 +263,9 @@ def __eq__(self, other): or self.created_timestamp_column != other.created_timestamp_column or self.field_mapping != other.field_mapping or self.date_partition_column != other.date_partition_column + or self.description != other.description or self.tags != other.tags or self.owner != other.owner - or self.description != other.description ): return False @@ -392,6 +392,9 @@ def __eq__(self, other): "Comparisons should only involve KafkaSource class objects." ) + if not super().__eq__(other): + return False + if ( self.kafka_options.bootstrap_servers != other.kafka_options.bootstrap_servers @@ -402,6 +405,9 @@ def __eq__(self, other): return True + def __hash__(self): + return super().__hash__() + @staticmethod def from_proto(data_source: DataSourceProto): return KafkaSource( @@ -507,13 +513,10 @@ def __eq__(self, other): raise TypeError( "Comparisons should only involve RequestSource class objects." ) - if ( - self.name != other.name - or self.description != other.description - or self.owner != other.owner - or self.tags != other.tags - ): + + if not super().__eq__(other): return False + if isinstance(self.schema, List) and isinstance(other.schema, List): for field1, field2 in zip(self.schema, other.schema): if field1 != field2: @@ -671,17 +674,16 @@ def __init__( ) def __eq__(self, other): - if other is None: - return False - if not isinstance(other, KinesisSource): raise TypeError( "Comparisons should only involve KinesisSource class objects." ) + if not super().__eq__(other): + return False + if ( - self.name != other.name - or self.kinesis_options.record_format != other.kinesis_options.record_format + self.kinesis_options.record_format != other.kinesis_options.record_format or self.kinesis_options.region != other.kinesis_options.region or self.kinesis_options.stream_name != other.kinesis_options.stream_name ): @@ -689,6 +691,9 @@ def __eq__(self, other): return True + def __hash__(self): + return super().__hash__() + def to_proto(self) -> DataSourceProto: data_source_proto = DataSourceProto( name=self.name, @@ -744,6 +749,21 @@ def __init__( if not self.batch_source: raise ValueError(f"batch_source is needed for push source {self.name}") + def __eq__(self, other): + if not isinstance(other, PushSource): + raise TypeError("Comparisons should only involve PushSource class objects.") + + if not super().__eq__(other): + return False + + if self.batch_source != other.batch_source: + return False + + return True + + def __hash__(self): + return super().__hash__() + def validate(self, config: RepoConfig): pass diff --git a/sdk/python/feast/diff/registry_diff.py b/sdk/python/feast/diff/registry_diff.py index 10bd88c56f8..b2caec2b687 100644 --- a/sdk/python/feast/diff/registry_diff.py +++ b/sdk/python/feast/diff/registry_diff.py @@ -177,7 +177,7 @@ def extract_objects_for_keep_delete_update_add( FeastObjectType, List[Any] ] = FeastObjectType.get_objects_from_registry(registry, current_project) registry_object_type_to_repo_contents: Dict[ - FeastObjectType, Set[Any] + FeastObjectType, List[Any] ] = FeastObjectType.get_objects_from_repo_contents(desired_repo_contents) for object_type in FEAST_OBJECT_TYPES: diff --git a/sdk/python/feast/entity.py b/sdk/python/feast/entity.py index e504fc1822d..3aaf0f9b695 100644 --- a/sdk/python/feast/entity.py +++ b/sdk/python/feast/entity.py @@ -132,7 +132,7 @@ def __init__( self.last_updated_timestamp = None def __hash__(self) -> int: - return hash((id(self), self.name)) + return hash((self.name, self.join_key)) def __eq__(self, other): if not isinstance(other, Entity): diff --git a/sdk/python/feast/feature_service.py b/sdk/python/feast/feature_service.py index 40030b34ceb..2febad3b1b9 100644 --- a/sdk/python/feast/feature_service.py +++ b/sdk/python/feast/feature_service.py @@ -85,7 +85,7 @@ def __str__(self): return str(MessageToJson(self.to_proto())) def __hash__(self): - return hash((id(self), self.name)) + return hash((self.name)) def __eq__(self, other): if not isinstance(other, FeatureService): diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 2311f78e9ba..4f456be3846 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -533,25 +533,25 @@ def _plan( ... batch_source=driver_hourly_stats, ... ) >>> registry_diff, infra_diff, new_infra = fs._plan(RepoContents( - ... data_sources={driver_hourly_stats}, - ... feature_views={driver_hourly_stats_view}, - ... on_demand_feature_views=set(), - ... request_feature_views=set(), - ... entities={driver}, - ... feature_services=set())) # register entity and feature view + ... data_sources=[driver_hourly_stats], + ... feature_views=[driver_hourly_stats_view], + ... on_demand_feature_views=list(), + ... request_feature_views=list(), + ... entities=[driver], + ... feature_services=list())) # register entity and feature view """ # Validate and run inference on all the objects to be registered. self._validate_all_feature_views( - list(desired_repo_contents.feature_views), - list(desired_repo_contents.on_demand_feature_views), - list(desired_repo_contents.request_feature_views), + desired_repo_contents.feature_views, + desired_repo_contents.on_demand_feature_views, + desired_repo_contents.request_feature_views, ) - _validate_data_sources(list(desired_repo_contents.data_sources)) + _validate_data_sources(desired_repo_contents.data_sources) self._make_inferences( - list(desired_repo_contents.data_sources), - list(desired_repo_contents.entities), - list(desired_repo_contents.feature_views), - list(desired_repo_contents.on_demand_feature_views), + desired_repo_contents.data_sources, + desired_repo_contents.entities, + desired_repo_contents.feature_views, + desired_repo_contents.on_demand_feature_views, ) # Compute the desired difference between the current objects in the registry and diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 7d29a4b69bb..cea8f619cb1 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -270,7 +270,6 @@ def _initialize_sources(self, name, batch_source, stream_source, source): self.batch_source = batch_source self.source = source - # Note: Python requires redefining hash in child classes that override __eq__ def __hash__(self): return super().__hash__() @@ -298,19 +297,15 @@ def __eq__(self, other): return False if ( - self.tags != other.tags + sorted(self.entities) != sorted(other.entities) or self.ttl != other.ttl or self.online != other.online + or self.batch_source != other.batch_source + or self.stream_source != other.stream_source + or self.schema != other.schema ): return False - if sorted(self.entities) != sorted(other.entities): - return False - if self.batch_source != other.batch_source: - return False - if self.stream_source != other.stream_source: - return False - return True def ensure_valid(self): diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index 790891b0781..a807f3b4a40 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -234,14 +234,19 @@ def __copy__(self): return fv def __eq__(self, other): + if not isinstance(other, OnDemandFeatureView): + raise TypeError( + "Comparisons should only involve OnDemandFeatureView class objects." + ) + if not super().__eq__(other): return False if ( - not self.source_feature_view_projections - == other.source_feature_view_projections - or not self.source_request_sources == other.source_request_sources - or not self.udf.__code__.co_code == other.udf.__code__.co_code + self.source_feature_view_projections + != other.source_feature_view_projections + or self.source_request_sources != other.source_request_sources + or self.udf.__code__.co_code != other.udf.__code__.co_code ): return False diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index da9c6c6b217..5f5d27318a9 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -18,7 +18,7 @@ from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional from urllib.parse import urlparse import dill @@ -98,7 +98,7 @@ def get_objects_from_registry( @staticmethod def get_objects_from_repo_contents( repo_contents: RepoContents, - ) -> Dict["FeastObjectType", Set[Any]]: + ) -> Dict["FeastObjectType", List[Any]]: return { FeastObjectType.DATA_SOURCE: repo_contents.data_sources, FeastObjectType.ENTITY: repo_contents.entities, diff --git a/sdk/python/feast/repo_contents.py b/sdk/python/feast/repo_contents.py index b59adc34db4..4d7c92f2a6d 100644 --- a/sdk/python/feast/repo_contents.py +++ b/sdk/python/feast/repo_contents.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple, Set +from typing import List, NamedTuple from feast.data_source import DataSource from feast.entity import Entity @@ -27,12 +27,12 @@ class RepoContents(NamedTuple): Represents the objects in a Feast feature repo. """ - data_sources: Set[DataSource] - feature_views: Set[FeatureView] - on_demand_feature_views: Set[OnDemandFeatureView] - request_feature_views: Set[RequestFeatureView] - entities: Set[Entity] - feature_services: Set[FeatureService] + data_sources: List[DataSource] + feature_views: List[FeatureView] + on_demand_feature_views: List[OnDemandFeatureView] + request_feature_views: List[RequestFeatureView] + entities: List[Entity] + feature_services: List[FeatureService] def to_registry_proto(self) -> RegistryProto: registry_proto = RegistryProto() diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py index 5e223aac8af..8a5e6b39f97 100644 --- a/sdk/python/feast/repo_operations.py +++ b/sdk/python/feast/repo_operations.py @@ -94,14 +94,20 @@ def get_repo_files(repo_root: Path) -> List[Path]: def parse_repo(repo_root: Path) -> RepoContents: - """Collect feature table definitions from feature repo""" + """ + Collects unique Feast object definitions from the given feature repo. + + Specifically, if an object foo has already been added, bar will still be added if + (bar == foo), but not if (bar is foo). This ensures that import statements will + not result in duplicates, but defining two equal objects will. + """ res = RepoContents( - data_sources=set(), - entities=set(), - feature_views=set(), - feature_services=set(), - on_demand_feature_views=set(), - request_feature_views=set(), + data_sources=[], + entities=[], + feature_views=[], + feature_services=[], + on_demand_feature_views=[], + request_feature_views=[], ) for repo_file in get_repo_files(repo_root): @@ -109,21 +115,35 @@ def parse_repo(repo_root: Path) -> RepoContents: module = importlib.import_module(module_path) for attr_name in dir(module): obj = getattr(module, attr_name) - if isinstance(obj, DataSource): - res.data_sources.add(obj) - if isinstance(obj, FeatureView): - res.feature_views.add(obj) - if isinstance(obj.stream_source, PushSource): - res.data_sources.add(obj.stream_source.batch_source) - elif isinstance(obj, Entity): - res.entities.add(obj) - elif isinstance(obj, FeatureService): - res.feature_services.add(obj) - elif isinstance(obj, OnDemandFeatureView): - res.on_demand_feature_views.add(obj) - elif isinstance(obj, RequestFeatureView): - res.request_feature_views.add(obj) - res.entities.add(DUMMY_ENTITY) + if isinstance(obj, DataSource) and not any( + (obj is ds) for ds in res.data_sources + ): + res.data_sources.append(obj) + if isinstance(obj, FeatureView) and not any( + (obj is fv) for fv in res.feature_views + ): + res.feature_views.append(obj) + if isinstance(obj.stream_source, PushSource) and not any( + (obj is ds) for ds in res.data_sources + ): + res.data_sources.append(obj.stream_source.batch_source) + elif isinstance(obj, Entity) and not any( + (obj is entity) for entity in res.entities + ): + res.entities.append(obj) + elif isinstance(obj, FeatureService) and not any( + (obj is fs) for fs in res.feature_services + ): + res.feature_services.append(obj) + elif isinstance(obj, OnDemandFeatureView) and not any( + (obj is odfv) for odfv in res.on_demand_feature_views + ): + res.on_demand_feature_views.append(obj) + elif isinstance(obj, RequestFeatureView) and not any( + (obj is rfv) for rfv in res.request_feature_views + ): + res.request_feature_views.append(obj) + res.entities.append(DUMMY_ENTITY) return res diff --git a/sdk/python/feast/saved_dataset.py b/sdk/python/feast/saved_dataset.py index 7a05a9ca221..aead7fe8eff 100644 --- a/sdk/python/feast/saved_dataset.py +++ b/sdk/python/feast/saved_dataset.py @@ -92,17 +92,23 @@ def __str__(self): return str(MessageToJson(self.to_proto())) def __hash__(self): - return hash((id(self), self.name)) + return hash((self.name)) def __eq__(self, other): if not isinstance(other, SavedDataset): raise TypeError( - "Comparisons should only involve FeatureService class objects." + "Comparisons should only involve SavedDataset class objects." ) - if self.name != other.name: - return False - if sorted(self.features) != sorted(other.features): + if ( + self.name != other.name + or sorted(self.features) != sorted(other.features) + or sorted(self.join_keys) != sorted(other.join_keys) + or self.storage != other.storage + or self.full_feature_names != other.full_feature_names + or self.tags != other.tags + or self.feature_service_name != other.feature_service_name + ): return False return True diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_duplicated_featureview_names.py b/sdk/python/tests/example_repos/example_feature_repo_with_duplicated_featureview_names.py index 20ff666bd9c..cbcc3ad172b 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_with_duplicated_featureview_names.py +++ b/sdk/python/tests/example_repos/example_feature_repo_with_duplicated_featureview_names.py @@ -10,7 +10,7 @@ name="driver_hourly_stats", # Intentionally use the same FeatureView name entities=["driver_id"], online=False, - batch_source=driver_hourly_stats, + source=driver_hourly_stats, ttl=timedelta(days=1), tags={}, ) @@ -19,7 +19,7 @@ name="driver_hourly_stats", # Intentionally use the same FeatureView name entities=["driver_id"], online=False, - batch_source=driver_hourly_stats, + source=driver_hourly_stats, ttl=timedelta(days=1), tags={}, ) diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index ae1ce55da7f..558700dc9c8 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -30,7 +30,7 @@ SparkSource, ) from feast.on_demand_feature_view import on_demand_feature_view -from feast.types import Float32, PrimitiveFeastType, String, UnixTimestamp +from feast.types import Float32, String, UnixTimestamp from tests.utils.data_source_utils import ( prep_file_source, simple_bq_source_using_query_arg, @@ -229,7 +229,7 @@ def test_view_with_missing_feature(features_df: pd.DataFrame) -> pd.DataFrame: @pytest.mark.parametrize( "request_source_schema", [ - [Field(name="some_date", dtype=PrimitiveFeastType.UNIX_TIMESTAMP)], + [Field(name="some_date", dtype=UnixTimestamp)], {"some_date": ValueType.UNIX_TIMESTAMP}, ], ) diff --git a/sdk/python/tests/integration/registration/test_registry.py b/sdk/python/tests/integration/registration/test_registry.py index 072be15bfee..5f72fb7125b 100644 --- a/sdk/python/tests/integration/registration/test_registry.py +++ b/sdk/python/tests/integration/registration/test_registry.py @@ -29,7 +29,7 @@ from feast.protos.feast.types import Value_pb2 as ValueProto from feast.registry import Registry from feast.repo_config import RegistryConfig -from feast.types import Array, Bytes, Float32, Int32, Int64, PrimitiveFeastType, String +from feast.types import Array, Bytes, Float32, Int32, Int64, String from feast.value_type import ValueType @@ -240,10 +240,7 @@ def test_apply_feature_view_success(test_registry): # TODO(kevjumba): remove this in feast 0.23 when deprecating @pytest.mark.parametrize( "request_source_schema", - [ - [Field(name="my_input_1", dtype=PrimitiveFeastType.INT32)], - {"my_input_1": ValueType.INT32}, - ], + [[Field(name="my_input_1", dtype=Int32)], {"my_input_1": ValueType.INT32}], ) def test_modify_feature_views_success(test_registry, request_source_schema): # Create Feature Views diff --git a/sdk/python/tests/unit/test_data_sources.py b/sdk/python/tests/unit/test_data_sources.py index a0de42e1e22..883ab7ddc09 100644 --- a/sdk/python/tests/unit/test_data_sources.py +++ b/sdk/python/tests/unit/test_data_sources.py @@ -4,7 +4,7 @@ from feast.data_source import PushSource, RequestDataSource, RequestSource from feast.field import Field from feast.infra.offline_stores.bigquery_source import BigQuerySource -from feast.types import PrimitiveFeastType +from feast.types import Bool, Float32 def test_push_with_batch(): @@ -13,8 +13,6 @@ def test_push_with_batch(): ) push_source_proto = push_source.to_proto() assert push_source_proto.HasField("batch_source") - assert push_source_proto.timestamp_field is not None - assert push_source_proto.push_options is not None push_source_unproto = PushSource.from_proto(push_source_proto) @@ -35,8 +33,8 @@ def test_request_data_source_deprecation(): def test_request_source_primitive_type_to_proto(): schema = [ - Field(name="f1", dtype=PrimitiveFeastType.FLOAT32), - Field(name="f2", dtype=PrimitiveFeastType.BOOL), + Field(name="f1", dtype=Float32), + Field(name="f2", dtype=Bool), ] request_source = RequestSource( name="source", schema=schema, description="desc", tags={}, owner="feast", @@ -44,3 +42,32 @@ def test_request_source_primitive_type_to_proto(): request_proto = request_source.to_proto() deserialized_request_source = RequestSource.from_proto(request_proto) assert deserialized_request_source == request_source + + +def test_hash(): + push_source_1 = PushSource( + name="test", batch_source=BigQuerySource(table="test.test"), + ) + push_source_2 = PushSource( + name="test", batch_source=BigQuerySource(table="test.test"), + ) + push_source_3 = PushSource( + name="test", batch_source=BigQuerySource(table="test.test2"), + ) + push_source_4 = PushSource( + name="test", + batch_source=BigQuerySource(table="test.test2"), + description="test", + ) + + s1 = {push_source_1, push_source_2} + assert len(s1) == 1 + + s2 = {push_source_1, push_source_3} + assert len(s2) == 2 + + s3 = {push_source_3, push_source_4} + assert len(s3) == 2 + + s4 = {push_source_1, push_source_2, push_source_3, push_source_4} + assert len(s4) == 3 diff --git a/sdk/python/tests/unit/test_entity.py b/sdk/python/tests/unit/test_entity.py index fee8bd9f009..254a975f678 100644 --- a/sdk/python/tests/unit/test_entity.py +++ b/sdk/python/tests/unit/test_entity.py @@ -63,3 +63,22 @@ def test_multiple_args(): def test_name_keyword(recwarn): Entity(name="my-entity", value_type=ValueType.STRING) assert len(recwarn) == 0 + + +def test_hash(): + entity1 = Entity(name="my-entity", value_type=ValueType.STRING) + entity2 = Entity(name="my-entity", value_type=ValueType.STRING) + entity3 = Entity(name="my-entity", value_type=ValueType.FLOAT) + entity4 = Entity(name="my-entity", value_type=ValueType.FLOAT, description="test") + + s1 = {entity1, entity2} + assert len(s1) == 1 + + s2 = {entity1, entity3} + assert len(s2) == 2 + + s3 = {entity3, entity4} + assert len(s3) == 2 + + s4 = {entity1, entity2, entity3, entity4} + assert len(s4) == 3 diff --git a/sdk/python/tests/unit/test_feature_service.py b/sdk/python/tests/unit/test_feature_service.py index 522ac49de13..80445299f23 100644 --- a/sdk/python/tests/unit/test_feature_service.py +++ b/sdk/python/tests/unit/test_feature_service.py @@ -1,4 +1,8 @@ -from feast import FeatureService +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.field import Field +from feast.infra.offline_stores.file_source import FileSource +from feast.types import Float32 def test_feature_service_with_description(): @@ -12,3 +16,42 @@ def test_feature_service_without_description(): feature_service = FeatureService(name="my-feature-service", features=[]) # assert feature_service.to_proto().spec.description == "" + + +def test_hash(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + feature_service_1 = FeatureService( + name="my-feature-service", features=[feature_view[["feature1", "feature2"]]] + ) + feature_service_2 = FeatureService( + name="my-feature-service", features=[feature_view[["feature1", "feature2"]]] + ) + feature_service_3 = FeatureService( + name="my-feature-service", features=[feature_view[["feature1"]]] + ) + feature_service_4 = FeatureService( + name="my-feature-service", + features=[feature_view[["feature1"]]], + description="test", + ) + + s1 = {feature_service_1, feature_service_2} + assert len(s1) == 1 + + s2 = {feature_service_1, feature_service_3} + assert len(s2) == 2 + + s3 = {feature_service_3, feature_service_4} + assert len(s3) == 2 + + s4 = {feature_service_1, feature_service_2, feature_service_3, feature_service_4} + assert len(s4) == 3 diff --git a/sdk/python/tests/unit/test_feature_view.py b/sdk/python/tests/unit/test_feature_view.py new file mode 100644 index 00000000000..80a583806e7 --- /dev/null +++ b/sdk/python/tests/unit/test_feature_view.py @@ -0,0 +1,64 @@ +# Copyright 2022 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from feast.feature_view import FeatureView +from feast.field import Field +from feast.infra.offline_stores.file_source import FileSource +from feast.types import Float32 + + +def test_hash(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view_1 = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + feature_view_2 = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + feature_view_3 = FeatureView( + name="my-feature-view", + entities=[], + schema=[Field(name="feature1", dtype=Float32)], + source=file_source, + ) + feature_view_4 = FeatureView( + name="my-feature-view", + entities=[], + schema=[Field(name="feature1", dtype=Float32)], + source=file_source, + description="test", + ) + + s1 = {feature_view_1, feature_view_2} + assert len(s1) == 1 + + s2 = {feature_view_1, feature_view_3} + assert len(s2) == 2 + + s3 = {feature_view_3, feature_view_4} + assert len(s3) == 2 + + s4 = {feature_view_1, feature_view_2, feature_view_3, feature_view_4} + assert len(s4) == 3 diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py new file mode 100644 index 00000000000..9d45cfbb0b7 --- /dev/null +++ b/sdk/python/tests/unit/test_on_demand_feature_view.py @@ -0,0 +1,102 @@ +# Copyright 2022 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd + +from feast.feature_view import FeatureView +from feast.field import Field +from feast.infra.offline_stores.file_source import FileSource +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.types import Float32 + + +def udf1(features_df: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["output1"] = features_df["feature1"] + df["output2"] = features_df["feature2"] + return df + + +def udf2(features_df: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["output1"] = features_df["feature1"] + 100 + df["output2"] = features_df["feature2"] + 100 + return df + + +def test_hash(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + sources = {"my-feature-view": feature_view} + on_demand_feature_view_1 = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + udf=udf1, + ) + on_demand_feature_view_2 = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + udf=udf1, + ) + on_demand_feature_view_3 = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + udf=udf2, + ) + on_demand_feature_view_4 = OnDemandFeatureView( + name="my-on-demand-feature-view", + sources=sources, + schema=[ + Field(name="output1", dtype=Float32), + Field(name="output2", dtype=Float32), + ], + udf=udf2, + description="test", + ) + + s1 = {on_demand_feature_view_1, on_demand_feature_view_2} + assert len(s1) == 1 + + s2 = {on_demand_feature_view_1, on_demand_feature_view_3} + assert len(s2) == 2 + + s3 = {on_demand_feature_view_3, on_demand_feature_view_4} + assert len(s3) == 2 + + s4 = { + on_demand_feature_view_1, + on_demand_feature_view_2, + on_demand_feature_view_3, + on_demand_feature_view_4, + } + assert len(s4) == 3