diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 52556eda153..063dda4f53e 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -608,7 +608,12 @@ def _make_inferences(
         update_feature_views_with_inferred_features_and_entities(
             sfvs_to_update, entities + entities_to_update, self.config
         )
-        # TODO(kevjumba): Update schema inferrence
+        # We need to attach the timestamp fields to the underlying data sources
+        # and cascade the dependencies
+        update_feature_views_with_inferred_features_and_entities(
+            odfvs_to_update, entities + entities_to_update, self.config
+        )
+        # TODO(kevjumba): Update schema inference
         for sfv in sfvs_to_update:
             if not sfv.schema:
                 raise ValueError(
@@ -618,8 +623,13 @@ def _make_inferences(
         for odfv in odfvs_to_update:
             odfv.infer_features()
 
+        odfvs_to_write = [
+            odfv for odfv in odfvs_to_update if odfv.write_to_online_store
+        ]
+        # Also include ODFVs that write to the online store
         fvs_to_update_map = {
-            view.name: view for view in [*views_to_update, *sfvs_to_update]
+            view.name: view
+            for view in [*views_to_update, *sfvs_to_update, *odfvs_to_write]
         }
         for feature_service in feature_services_to_update:
             feature_service.infer_features(fvs_to_update=fvs_to_update_map)
@@ -847,6 +857,11 @@ def apply(
         ]
         sfvs_to_update = [ob for ob in objects if isinstance(ob, StreamFeatureView)]
         odfvs_to_update = [ob for ob in objects if isinstance(ob, OnDemandFeatureView)]
+        odfvs_with_writes_to_update = [
+            ob
+            for ob in objects
+            if isinstance(ob, OnDemandFeatureView) and ob.write_to_online_store
+        ]
         services_to_update = [ob for ob in objects if isinstance(ob, FeatureService)]
         data_sources_set_to_update = {
             ob for ob in objects if isinstance(ob, DataSource)
@@ -868,10 +883,18 @@ def apply(
         for batch_source in batch_sources_to_add:
             data_sources_set_to_update.add(batch_source)
 
-        for fv in itertools.chain(views_to_update, sfvs_to_update):
-            data_sources_set_to_update.add(fv.batch_source)
-            if fv.stream_source:
-                data_sources_set_to_update.add(fv.stream_source)
+        for fv in itertools.chain(
+            views_to_update, sfvs_to_update, odfvs_with_writes_to_update
+        ):
+            if isinstance(fv, FeatureView):
+                data_sources_set_to_update.add(fv.batch_source)
+            if isinstance(fv, StreamFeatureView):
+                if fv.stream_source:
+                    data_sources_set_to_update.add(fv.stream_source)
+            if isinstance(fv, OnDemandFeatureView):
+                for source_fvp in fv.source_feature_view_projections.values():
+                    if source_fvp.batch_source:
+                        data_sources_set_to_update.add(source_fvp.batch_source)
 
         for odfv in odfvs_to_update:
             for v in odfv.source_request_sources.values():
@@ -884,7 +911,9 @@ def apply(
         # Validate all feature views and make inferences.
         self._validate_all_feature_views(
-            views_to_update, odfvs_to_update, sfvs_to_update
+            views_to_update,
+            odfvs_to_update,
+            sfvs_to_update,
         )
         self._make_inferences(
             data_sources_to_update,
@@ -989,7 +1018,9 @@ def apply(
         tables_to_delete: List[FeatureView] = (
             views_to_delete + sfvs_to_delete if not partial else []  # type: ignore
         )
-        tables_to_keep: List[FeatureView] = views_to_update + sfvs_to_update  # type: ignore
+        tables_to_keep: List[FeatureView] = (
+            views_to_update + sfvs_to_update + odfvs_with_writes_to_update
+        )  # type: ignore
 
         self._get_provider().update_infra(
             project=self.project,
@@ -1444,19 +1475,18 @@ def write_to_online_store(
-            inputs: Optional the dictionary object to be written
+            inputs: Optionally, a dictionary object (or list of dictionaries) to be written
             allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry.
         """
-        # TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
+        feature_view_dict = {
+            fv.name: fv
+            for fv in self.list_all_feature_views(allow_registry_cache)
+        }
         try:
-            feature_view: FeatureView = self.get_stream_feature_view(
-                feature_view_name, allow_registry_cache=allow_registry_cache
-            )
-        except FeatureViewNotFoundException:
-            feature_view = self.get_feature_view(
-                feature_view_name, allow_registry_cache=allow_registry_cache
-            )
+            feature_view = feature_view_dict[feature_view_name]
+        except KeyError:
+            # a dict lookup raises KeyError; surface the Feast-specific error
+            raise FeatureViewNotFoundException(feature_view_name, self.project)
         if df is not None and inputs is not None:
             raise ValueError("Both df and inputs cannot be provided at the same time.")
         if df is None and inputs is not None:
-            if isinstance(inputs, dict):
+            if isinstance(inputs, (dict, list)):
                 try:
                     df = pd.DataFrame(inputs)
                 except Exception as _:
                     raise DataFrameSerializationError
@@ -1465,8 +1495,20 @@ def write_to_online_store(
                 pass
             else:
                 raise ValueError("inputs must be a dictionary or a pandas DataFrame.")
+        if df is not None and inputs is None:
+            if isinstance(df, (dict, list)):
+                try:
+                    df = pd.DataFrame(df)
+                except Exception as _:
+                    raise DataFrameSerializationError
+
         provider = self._get_provider()
-        provider.ingest_df(feature_view, df)
+        if isinstance(feature_view, OnDemandFeatureView):
+            # TODO: add projection mapping
+            projection_mapping = {}
+            provider.ingest_df(feature_view, df, projection_mapping)
+        else:
+            provider.ingest_df(feature_view, df)
 
     def write_to_offline_store(
         self,
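A minimal usage sketch for the reworked `write_to_online_store` entry point above. The repo path, view names, and column values are hypothetical; the point is that both a regular feature view and an ODFV declared with `write_to_online_store=True` now resolve through `list_all_feature_views()`, and that a list of dicts is accepted and coerced into a DataFrame before ingestion.

    # Sketch only: assumes a local repo with views like the ones in the tests below.
    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical repo

    # Regular feature view: a list of dicts is coerced to a DataFrame.
    store.write_to_online_store(
        feature_view_name="driver_hourly_stats",
        df=[{"driver_id": 1001, "conv_rate": 0.25, "acc_rate": 0.25}],
    )

    # An ODFV with write_to_online_store=True takes the OnDemandFeatureView
    # branch and is ingested with an (empty, for now) projection mapping.
    store.write_to_online_store(
        feature_view_name="python_stored_writes_feature_view",
        df=[{"driver_id": 1001, "counter": 0}],
    )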
diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py
index b9fb9b694d2..38586913767 100644
--- a/sdk/python/feast/inference.py
+++ b/sdk/python/feast/inference.py
@@ -209,6 +209,7 @@ def _infer_features_and_entities(
         fv, join_keys, run_inference_for_features, config
     )
 
+    entity_columns = []
    columns_to_exclude = {
        fv.batch_source.timestamp_field,
        fv.batch_source.created_timestamp_column,
@@ -218,6 +219,7 @@ def _infer_features_and_entities(
            columns_to_exclude.remove(mapped_col)
            columns_to_exclude.add(original_col)
 
+    # fetch the canonical column names and types from the batch source
    table_column_names_and_types = fv.batch_source.get_table_column_names_and_types(
        config
    )
@@ -233,9 +235,9 @@ def _infer_features_and_entities(
                ),
            )
            if field.name not in [
-                entity_column.name for entity_column in fv.entity_columns
+                entity_column.name for entity_column in entity_columns
            ]:
-                fv.entity_columns.append(field)
+                entity_columns.append(field)
        elif not re.match(
            "^__|__$", col_name
        ):  # double underscores often signal an internal-use column
@@ -256,6 +258,8 @@ def _infer_features_and_entities(
            if field.name not in [feature.name for feature in fv.features]:
                fv.features.append(field)
 
+    fv.entity_columns = entity_columns
+
 
 def _infer_on_demand_features_and_entities(
     fv: OnDemandFeatureView,
@@ -282,18 +286,18 @@ def _infer_on_demand_features_and_entities(
        batch_source = getattr(source_feature_view, "batch_source")
        batch_field_mapping = getattr(batch_source or None, "field_mapping")
-        if batch_field_mapping:
-            for (
-                original_col,
-                mapped_col,
-            ) in batch_field_mapping.items():
-                if mapped_col in columns_to_exclude:
-                    columns_to_exclude.remove(mapped_col)
-                    columns_to_exclude.add(original_col)
+        for (
+            original_col,
+            mapped_col,
+        ) in (batch_field_mapping or {}).items():
+            if mapped_col in columns_to_exclude:
+                columns_to_exclude.remove(mapped_col)
+                columns_to_exclude.add(original_col)
+
+        table_column_names_and_types = batch_source.get_table_column_names_and_types(
+            config
+        )
 
-        table_column_names_and_types = (
-            batch_source.get_table_column_names_and_types(config)
-        )
        for col_name, col_datatype in table_column_names_and_types:
            if col_name in columns_to_exclude:
                continue
@@ -301,7 +306,9 @@ def _infer_on_demand_features_and_entities(
                field = Field(
                    name=col_name,
                    dtype=from_value_type(
-                        batch_source.source_datatype_to_feast_value_type()(col_datatype)
+                        batch_source.source_datatype_to_feast_value_type()(
+                            col_datatype
+                        )
                    ),
                )
                if field.name not in [
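To make the field-mapping handling in `_infer_on_demand_features_and_entities` concrete: `get_table_column_names_and_types` reports source-side column names, so when a source column is renamed via `field_mapping`, the exclusion set has to track the original name rather than the mapped one. A self-contained sketch with made-up column names:

    # Hypothetical mapping: source column "ts" is exposed to Feast as "event_timestamp".
    field_mapping = {"ts": "event_timestamp"}
    columns_to_exclude = {"event_timestamp", "created"}

    for original_col, mapped_col in (field_mapping or {}).items():
        if mapped_col in columns_to_exclude:
            columns_to_exclude.remove(mapped_col)
            columns_to_exclude.add(original_col)

    # The source-side name is now excluded instead of the mapped name.
    assert columns_to_exclude == {"ts", "created"}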
diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py
index c3c3048a896..b0c67bcf15e 100644
--- a/sdk/python/feast/infra/passthrough_provider.py
+++ b/sdk/python/feast/infra/passthrough_provider.py
@@ -5,7 +5,7 @@
 import pyarrow as pa
 from tqdm import tqdm
 
-from feast import importer
+from feast import OnDemandFeatureView, importer
 from feast.batch_feature_view import BatchFeatureView
 from feast.data_source import DataSource
 from feast.entity import Entity
@@ -276,23 +276,38 @@ def ingest_df(
         self,
         feature_view: FeatureView,
         df: pd.DataFrame,
+        field_mapping: Optional[Dict] = None,
     ):
         table = pa.Table.from_pandas(df)
-
-        if feature_view.batch_source.field_mapping is not None:
-            table = _run_pyarrow_field_mapping(
-                table, feature_view.batch_source.field_mapping
-            )
+        if isinstance(feature_view, OnDemandFeatureView):
+            table = _run_pyarrow_field_mapping(table, field_mapping or {})
+            join_keys = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in feature_view.entity_columns
+            }
+            rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)
+
+            self.online_write_batch(
+                self.repo_config, feature_view, rows_to_write, progress=None
+            )
+        else:
+            # Note: field_mapping maps column names in the data source to
+            # feature names in the feature view; it applies only to feature
+            # columns, not entity or timestamp columns.
+            if feature_view.batch_source.field_mapping is not None:
+                table = _run_pyarrow_field_mapping(
+                    table, feature_view.batch_source.field_mapping
+                )
 
-        join_keys = {
-            entity.name: entity.dtype.to_value_type()
-            for entity in feature_view.entity_columns
-        }
-        rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)
+            join_keys = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in feature_view.entity_columns
+            }
+            rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)
 
-        self.online_write_batch(
-            self.repo_config, feature_view, rows_to_write, progress=None
-        )
+            self.online_write_batch(
+                self.repo_config, feature_view, rows_to_write, progress=None
+            )
 
     def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table):
         if feature_view.batch_source.field_mapping is not None:
diff --git a/sdk/python/feast/types.py b/sdk/python/feast/types.py
index 4b07c58d19e..b934f12a868 100644
--- a/sdk/python/feast/types.py
+++ b/sdk/python/feast/types.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from datetime import datetime, timezone
 from enum import Enum
 from typing import Dict, Union
 
+import pyarrow
+
 from feast.value_type import ValueType
 
 PRIMITIVE_FEAST_TYPES_TO_VALUE_TYPES = {
@@ -30,6 +33,10 @@
 }
 
 
+def _utc_now() -> datetime:
+    return datetime.now(tz=timezone.utc)
+
+
 class ComplexFeastType(ABC):
     """
     A ComplexFeastType represents a structured type that is recognized by Feast.
@@ -103,7 +110,6 @@ def __hash__(self):
 Float64 = PrimitiveFeastType.FLOAT64
 UnixTimestamp = PrimitiveFeastType.UNIX_TIMESTAMP
 
-
 SUPPORTED_BASE_TYPES = [
     Invalid,
     String,
@@ -159,7 +165,6 @@ def __str__(self):
 
 FeastType = Union[ComplexFeastType, PrimitiveFeastType]
 
-
 VALUE_TYPES_TO_FEAST_TYPES: Dict["ValueType", FeastType] = {
     ValueType.UNKNOWN: Invalid,
     ValueType.BYTES: Bytes,
@@ -180,6 +185,33 @@ def __str__(self):
     ValueType.UNIX_TIMESTAMP_LIST: Array(UnixTimestamp),
 }
 
+FEAST_TYPES_TO_PYARROW_TYPES = {
+    String: pyarrow.string(),
+    Bool: pyarrow.bool_(),
+    Int32: pyarrow.int32(),
+    Int64: pyarrow.int64(),
+    Float32: pyarrow.float32(),
+    Float64: pyarrow.float64(),
+    # Note: datetime only supports microseconds https://github.com/python/cpython/blob/3.8/Lib/datetime.py#L1559
+    UnixTimestamp: pyarrow.timestamp("us", tz=_utc_now().tzname()),
+}
+
+
+def from_feast_to_pyarrow_type(feast_type: FeastType) -> pyarrow.DataType:
+    """
+    Converts a Feast type to a PyArrow type.
+
+    Args:
+        feast_type: The Feast type to be converted.
+
+    Raises:
+        ValueError: The conversion could not be performed.
+    """
+    if feast_type in FEAST_TYPES_TO_PYARROW_TYPES:
+        return FEAST_TYPES_TO_PYARROW_TYPES[feast_type]
+
+    raise ValueError(f"Could not convert Feast type {feast_type} to PyArrow type.")
+
 
 def from_value_type(
     value_type: ValueType,
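A short usage sketch for the new `from_feast_to_pyarrow_type` helper, mirroring how the utils change below uses it to back-fill declared ODFV outputs that are absent from an incoming table. Only the primitive types in `FEAST_TYPES_TO_PYARROW_TYPES` convert; anything else raises `ValueError`.

    import pyarrow

    from feast.types import Float64, UnixTimestamp, from_feast_to_pyarrow_type

    assert from_feast_to_pyarrow_type(Float64) == pyarrow.float64()

    # Build an all-null column for a declared output that is missing from the
    # input, as _convert_arrow_odfv_to_proto does.
    num_rows = 3
    null_column = pyarrow.array(
        [None] * num_rows, type=from_feast_to_pyarrow_type(UnixTimestamp)
    )
    assert null_column.null_count == num_rows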
+ """ + if feast_type in FEAST_TYPES_TO_PYARROW_TYPES: + return FEAST_TYPES_TO_PYARROW_TYPES[feast_type] + + raise ValueError(f"Could not convert Feast type {feast_type} to PyArrow type.") + def from_value_type( value_type: ValueType, diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 8a9f1fadae8..dd375597627 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -43,6 +43,7 @@ from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.type_map import python_values_to_proto_values +from feast.types import from_feast_to_pyarrow_type from feast.value_type import ValueType from feast.version import get_version @@ -230,6 +231,18 @@ def _convert_arrow_to_proto( table: Union[pyarrow.Table, pyarrow.RecordBatch], feature_view: "FeatureView", join_keys: Dict[str, ValueType], +) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: + # This is a workaround for isinstance(feature_view, OnDemandFeatureView), which triggers a circular import + if getattr(feature_view, "source_request_sources", None): + return _convert_arrow_odfv_to_proto(table, feature_view, join_keys) + else: + return _convert_arrow_fv_to_proto(table, feature_view, join_keys) + + +def _convert_arrow_fv_to_proto( + table: Union[pyarrow.Table, pyarrow.RecordBatch], + feature_view: "FeatureView", + join_keys: Dict[str, ValueType], ) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: # Avoid ChunkedArrays which guarantees `zero_copy_only` available. if isinstance(table, pyarrow.Table): @@ -287,6 +300,76 @@ def _convert_arrow_to_proto( return list(zip(entity_keys, features, event_timestamps, created_timestamps)) +def _convert_arrow_odfv_to_proto( + table: Union[pyarrow.Table, pyarrow.RecordBatch], + feature_view: "FeatureView", + join_keys: Dict[str, ValueType], +) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: + # Avoid ChunkedArrays which guarantees `zero_copy_only` available. 
@@ -287,6 +300,76 @@
     return list(zip(entity_keys, features, event_timestamps, created_timestamps))
 
 
+def _convert_arrow_odfv_to_proto(
+    table: Union[pyarrow.Table, pyarrow.RecordBatch],
+    feature_view: "FeatureView",
+    join_keys: Dict[str, ValueType],
+) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
+    # Avoid ChunkedArrays which guarantees `zero_copy_only` available.
+    if isinstance(table, pyarrow.Table):
+        table = table.to_batches()[0]
+
+    columns = [
+        (field.name, field.dtype.to_value_type()) for field in feature_view.features
+    ] + list(join_keys.items())
+
+    proto_values_by_column = {
+        column: python_values_to_proto_values(
+            table.column(column).to_numpy(zero_copy_only=False), value_type
+        )
+        for column, value_type in columns
+        if column in table.column_names
+    }
+    # Add on-demand output features that are missing from the input table
+    for feature in feature_view.features:
+        if (
+            feature.name in [c[0] for c in columns]
+            and feature.name not in proto_values_by_column
+        ):
+            # initialize the missing column as all nulls
+            null_column = pyarrow.array(
+                [None] * table.num_rows,
+                type=from_feast_to_pyarrow_type(feature.dtype),
+            )
+            updated_table = pyarrow.RecordBatch.from_arrays(
+                table.columns + [null_column],
+                schema=table.schema.append(
+                    pyarrow.field(feature.name, null_column.type)
+                ),
+            )
+            proto_values_by_column[feature.name] = python_values_to_proto_values(
+                updated_table.column(feature.name).to_numpy(zero_copy_only=False),
+                feature.dtype.to_value_type(),
+            )
+
+    entity_keys = [
+        EntityKeyProto(
+            join_keys=join_keys,
+            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
+        )
+        for idx in range(table.num_rows)
+    ]
+
+    # Serialize the features per row
+    feature_dict = {
+        feature.name: proto_values_by_column[feature.name]
+        for feature in feature_view.features
+    }
+    features = [dict(zip(feature_dict, row)) for row in zip(*feature_dict.values())]
+
+    # We need to artificially add event_timestamps and created_timestamps
+    event_timestamps = []
+    timestamp_values = pd.to_datetime([_utc_now() for _ in range(table.num_rows)])
+
+    for val in timestamp_values:
+        event_timestamps.append(_coerce_datetime(val))
+
+    # created timestamps mirror the event timestamps
+    created_timestamps = event_timestamps
+
+    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
+
+
 def _validate_entity_values(join_key_values: Dict[str, List[ValueProto]]):
     set_of_row_lengths = {len(v) for v in join_key_values.values()}
     if len(set_of_row_lengths) > 1:
@@ -931,6 +1014,16 @@ def _prepare_entities_to_read_from_online_store(
 
     num_rows = _validate_entity_values(entity_proto_values)
 
+    odfv_entities = []
+    request_source_keys = []
+    for on_demand_feature_view in requested_on_demand_feature_views:
+        odfv_entities.extend(getattr(on_demand_feature_view, "entities", []))
+        for request_source in on_demand_feature_view.source_request_sources.values():
+            for column in request_source.schema:
+                request_source_keys.append(column.name)
+
+    join_keys_set.update(set(odfv_entities))
+
     join_key_values: Dict[str, List[ValueProto]] = {}
     request_data_features: Dict[str, List[ValueProto]] = {}
     # Entity rows may be either entities or request data.
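The per-row serialization in `_convert_arrow_odfv_to_proto` relies on a zip transpose: column-major lists of proto values are flipped into one dict per row. A tiny standalone illustration (plain ints stand in for `ValueProto` objects):

    feature_dict = {
        "output1": [100, 200],  # column-major: one list of values per feature
        "output2": [102, 202],
    }

    # zip(*values) yields one tuple per row; zipping it with the keys rebuilds
    # a per-row dict, matching the `features` list built above.
    features = [dict(zip(feature_dict, row)) for row in zip(*feature_dict.values())]

    assert features == [
        {"output1": 100, "output2": 102},
        {"output1": 200, "output2": 202},
    ]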
@@ -943,7 +1037,11 @@
             join_key = join_key_or_entity_name
         else:
             try:
-                join_key = entity_name_to_join_key_map[join_key_or_entity_name]
+                if join_key_or_entity_name in request_source_keys:
+                    join_key = entity_name_to_join_key_map[join_key_or_entity_name]
+                else:
+                    # not an entity (e.g. an ODFV source feature); skip it
+                    continue
             except KeyError:
                 raise EntityNotFoundException(join_key_or_entity_name, project)
         else:
diff --git a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
index cc48295b206..0d48a4aa248 100644
--- a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
+++ b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
@@ -11,7 +11,7 @@
 from feast.entity import Entity
 from feast.feast_object import ALL_RESOURCE_TYPES
 from feast.feature_store import FeatureStore
-from feast.feature_view import DUMMY_ENTITY_ID, FeatureView
+from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_NAME, FeatureView
 from feast.field import Field
 from feast.infra.offline_stores.file_source import FileSource
 from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
@@ -347,7 +347,7 @@ def test_apply_entities_and_feature_views(test_feature_store):
     "test_feature_store",
     [lazy_fixture("feature_store_with_local_registry")],
 )
-def test_apply_dummuy_entity_and_feature_view_columns(test_feature_store):
+def test_apply_dummy_entity_and_feature_view_columns(test_feature_store):
     assert isinstance(test_feature_store, FeatureStore)
     # Create Feature Views
     batch_source = FileSource(
@@ -359,14 +359,25 @@ def test_apply_dummy_entity_and_feature_view_columns(test_feature_store):
 
     e1 = Entity(name="fs1_my_entity_1", description="something")
 
-    fv = FeatureView(
-        name="my_feature_view_no_entity",
+    fv_with_entity = FeatureView(
+        name="my_feature_view_with_entity",
         schema=[
             Field(name="fs1_my_feature_1", dtype=Int64),
             Field(name="fs1_my_feature_2", dtype=String),
             Field(name="fs1_my_feature_3", dtype=Array(String)),
             Field(name="fs1_my_feature_4", dtype=Array(Bytes)),
-            Field(name="fs1_my_entity_2", dtype=Int64),
+            Field(name="fs1_my_entity_1", dtype=Int64),
+        ],
+        entities=[e1],
+        tags={"team": "matchmaking"},
+        source=batch_source,
+        ttl=timedelta(minutes=5),
+    )
+
+    fv_no_entity = FeatureView(
+        name="my_feature_view_no_entity",
+        schema=[
+            Field(name="fs1_my_feature_1", dtype=Int64),
         ],
         entities=[],
         tags={"team": "matchmaking"},
@@ -375,16 +386,24 @@ def test_apply_dummy_entity_and_feature_view_columns(test_feature_store):
     )
 
     # Check that the entity_columns are empty before applying
-    assert fv.entity_columns == []
+    assert fv_no_entity.entities == [DUMMY_ENTITY_NAME]
+    assert fv_no_entity.entity_columns == []
+    assert fv_with_entity.entity_columns[0].name == e1.name
 
     # Register Feature View
-    test_feature_store.apply([fv, e1])
-    fv_actual = test_feature_store.get_feature_view("my_feature_view_no_entity")
+    test_feature_store.apply([e1, fv_no_entity, fv_with_entity])
+    fv_from_registry = test_feature_store.get_feature_view(
+        "my_feature_view_no_entity"
+    )
 
     # Note that after the apply() the feature_view serializes the Dummy Entity ID
-    assert fv.entity_columns[0].name == DUMMY_ENTITY_ID
-    assert fv_actual.entity_columns[0].name == DUMMY_ENTITY_ID
+    assert fv_no_entity.entity_columns[0].name == DUMMY_ENTITY_ID
+    assert fv_from_registry.entity_columns[0].name == DUMMY_ENTITY_ID
+    assert fv_from_registry.entities == []
+    assert fv_no_entity.entities == [DUMMY_ENTITY_NAME]
+    assert fv_with_entity.entity_columns[0].name == e1.name
+    assert fv_with_entity.entities == [e1.name]
 
     test_feature_store.teardown()
diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py
index 6073891aba3..4b30bd6be99 100644
--- a/sdk/python/tests/unit/test_on_demand_feature_view.py
+++ b/sdk/python/tests/unit/test_on_demand_feature_view.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import datetime
 from typing import Any, Dict, List
 
 import pandas as pd
@@ -50,6 +50,15 @@ def python_native_udf(features_dict: Dict[str, Any]) -> Dict[str, Any]:
     return output_dict
 
 
+def python_writes_test_udf(features_dict: Dict[str, Any]) -> Dict[str, Any]:
+    output_dict: Dict[str, Any] = {
+        "output1": features_dict["feature1"] + 100,
+        "output2": features_dict["feature2"] + 101,
+        "output3": datetime.datetime.now(),
+    }
+    return output_dict
+
+
 @pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated")
 def test_hash():
     file_source = FileSource(name="my-file-source", path="test.parquet")
@@ -261,3 +270,89 @@ def test_from_proto_backwards_compatible_udf():
         reserialized_proto.feature_transformation.udf_string
         == on_demand_feature_view.feature_transformation.udf_string
     )
+
+
+def test_on_demand_feature_view_writes_protos():
+    file_source = FileSource(name="my-file-source", path="test.parquet")
+    feature_view = FeatureView(
+        name="my-feature-view",
+        entities=[],
+        schema=[
+            Field(name="feature1", dtype=Float32),
+            Field(name="feature2", dtype=Float32),
+        ],
+        source=file_source,
+    )
+    sources = [feature_view]
+    on_demand_feature_view = OnDemandFeatureView(
+        name="my-on-demand-feature-view",
+        sources=sources,
+        schema=[
+            Field(name="output1", dtype=Float32),
+            Field(name="output2", dtype=Float32),
+        ],
+        feature_transformation=PandasTransformation(
+            udf=udf1, udf_string="udf1 source code"
+        ),
+        write_to_online_store=True,
+    )
+
+    proto = on_demand_feature_view.to_proto()
+    reserialized_proto = OnDemandFeatureView.from_proto(proto)
+
+    assert on_demand_feature_view.write_to_online_store
+    assert proto.spec.write_to_online_store
+    assert reserialized_proto.write_to_online_store
+
+    proto.spec.write_to_online_store = False
+    reserialized_proto = OnDemandFeatureView.from_proto(proto)
+    assert not reserialized_proto.write_to_online_store
+
+
+def test_on_demand_feature_view_stored_writes():
+    file_source = FileSource(name="my-file-source", path="test.parquet")
+    feature_view = FeatureView(
+        name="my-feature-view",
+        entities=[],
+        schema=[
+            Field(name="feature1", dtype=Float32),
+            Field(name="feature2", dtype=Float32),
+        ],
+        source=file_source,
+    )
+    sources = [feature_view]
+
+    on_demand_feature_view = OnDemandFeatureView(
+        name="my-on-demand-feature-view",
+        sources=sources,
+        schema=[
+            Field(name="output1", dtype=Float32),
+            Field(name="output2", dtype=Float32),
+        ],
+        feature_transformation=PythonTransformation(
+            udf=python_writes_test_udf, udf_string="python native udf source code"
+        ),
+        description="testing on demand feature view stored writes",
+        mode="python",
+        write_to_online_store=True,
+    )
+
+    transformed_output = on_demand_feature_view.transform_dict(
+        {
+            "feature1": 0,
+            "feature2": 1,
+        }
+    )
+    expected_output = {"feature1": 0, "feature2": 1, "output1": 100, "output2": 102}
+    keys_to_validate = [
+        "feature1",
+        "feature2",
+        "output1",
"output2", + ] + for k in keys_to_validate: + assert transformed_output[k] == expected_output[k] + + assert transformed_output["output3"] is not None and isinstance( + transformed_output["output3"], datetime.datetime + ) diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py index ff7ad494caf..635b23b80cd 100644 --- a/sdk/python/tests/unit/test_on_demand_python_transformation.py +++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py @@ -20,7 +20,16 @@ from feast.field import Field from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig from feast.on_demand_feature_view import on_demand_feature_view -from feast.types import Array, Bool, Float32, Float64, Int64, String +from feast.types import ( + Array, + Bool, + Float32, + Float64, + Int64, + String, + UnixTimestamp, + _utc_now, +) class TestOnDemandPythonTransformation(unittest.TestCase): @@ -371,15 +380,15 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]: self.store.apply( [driver, driver_stats_source, driver_stats_fv, python_view] ) - self.store.write_to_online_store( - feature_view_name="driver_hourly_stats", df=driver_df - ) - fv_applied = self.store.get_feature_view("driver_hourly_stats") assert fv_applied.entities == [driver.name] # Note here that after apply() is called, the entity_columns are populated with the join_key assert fv_applied.entity_columns[0].name == driver.join_key + self.store.write_to_online_store( + feature_view_name="driver_hourly_stats", df=driver_df + ) + def test_python_transformation_returning_all_data_types(self): entity_rows = [ { @@ -484,8 +493,7 @@ def test_invalid_python_transformation_raises_type_error_on_apply(): schema=[Field(name="driver_name_lower", dtype=String)], mode="python", ) - def python_view(inputs: dict[str, Any]) -> dict[str, Any]: - return {"driver_name_lower": []} + def python_view(inputs: dict[str, Any]) -> dict[str, Any]: return {"driver_name_lower": []} with pytest.raises( TypeError, @@ -494,3 +502,206 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]: ), ): store.apply([request_source, python_view]) + + +class TestOnDemandTransformationsWithWrites(unittest.TestCase): + def test_stored_writes(self): + with tempfile.TemporaryDirectory() as data_dir: + self.store = FeatureStore( + config=RepoConfig( + project="test_on_demand_python_transformation", + registry=os.path.join(data_dir, "registry.db"), + provider="local", + entity_key_serialization_version=2, + online_store=SqliteOnlineStoreConfig( + path=os.path.join(data_dir, "online.db") + ), + ) + ) + + # Generate test data. 
+            end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
+            start_date = end_date - timedelta(days=15)
+
+            driver_entities = [1001, 1002, 1003, 1004, 1005]
+            driver_df = create_driver_hourly_stats_df(
+                driver_entities, start_date, end_date
+            )
+            driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
+            driver_df.to_parquet(
+                path=driver_stats_path, allow_truncated_timestamps=True
+            )
+
+            driver = Entity(name="driver", join_keys=["driver_id"])
+
+            driver_stats_source = FileSource(
+                name="driver_hourly_stats_source",
+                path=driver_stats_path,
+                timestamp_field="event_timestamp",
+                created_timestamp_column="created",
+            )
+            input_request_source = RequestSource(
+                name="counter_source",
+                schema=[
+                    Field(name="counter", dtype=Int64),
+                    Field(name="input_datetime", dtype=UnixTimestamp),
+                ],
+            )
+
+            driver_stats_fv = FeatureView(
+                name="driver_hourly_stats",
+                entities=[driver],
+                ttl=timedelta(days=0),
+                schema=[
+                    Field(name="conv_rate", dtype=Float32),
+                    Field(name="acc_rate", dtype=Float32),
+                    Field(name="avg_daily_trips", dtype=Int64),
+                ],
+                online=True,
+                source=driver_stats_source,
+            )
+            assert driver_stats_fv.entities == [driver.name]
+            assert driver_stats_fv.entity_columns == []
+
+            @on_demand_feature_view(
+                entities=[driver],
+                sources=[
+                    driver_stats_fv[["conv_rate", "acc_rate"]],
+                    input_request_source,
+                ],
+                schema=[
+                    Field(name="conv_rate_plus_acc", dtype=Float64),
+                    Field(name="current_datetime", dtype=UnixTimestamp),
+                    Field(name="counter", dtype=Int64),
+                    Field(name="input_datetime", dtype=UnixTimestamp),
+                ],
+                mode="python",
+                write_to_online_store=True,
+            )
+            def python_stored_writes_feature_view(
+                inputs: dict[str, Any],
+            ) -> dict[str, Any]:
+                output: dict[str, Any] = {
+                    "conv_rate_plus_acc": [
+                        conv_rate + acc_rate
+                        for conv_rate, acc_rate in zip(
+                            inputs["conv_rate"], inputs["acc_rate"]
+                        )
+                    ],
+                    "current_datetime": [datetime.now() for _ in inputs["conv_rate"]],
+                    "counter": [c + 1 for c in inputs["counter"]],
+                    "input_datetime": [d for d in inputs["input_datetime"]],
+                }
+                return output
+
+            assert python_stored_writes_feature_view.entities == [driver.name]
+            assert python_stored_writes_feature_view.entity_columns == []
+
+            self.store.apply(
+                [
+                    driver,
+                    driver_stats_source,
+                    driver_stats_fv,
+                    python_stored_writes_feature_view,
+                ]
+            )
+            fv_applied = self.store.get_feature_view("driver_hourly_stats")
+            odfv_applied = self.store.get_on_demand_feature_view(
+                "python_stored_writes_feature_view"
+            )
+
+            assert fv_applied.entities == [driver.name]
+            assert odfv_applied.entities == [driver.name]
+
+            # Note: after apply(), the ODFV's entity_columns are populated with
+            # the join_key, while the regular FeatureView's are still empty here
+            assert fv_applied.entity_columns == []
+            assert odfv_applied.entity_columns[0].name == driver.join_key
+
+            assert len(self.store.list_all_feature_views()) == 2
+            assert len(self.store.list_feature_views()) == 1
+            assert len(self.store.list_on_demand_feature_views()) == 1
+            assert len(self.store.list_stream_feature_views()) == 0
+            assert (
+                driver_stats_fv.entity_columns
+                == self.store.get_feature_view("driver_hourly_stats").entity_columns
+            )
+            assert (
+                python_stored_writes_feature_view.entity_columns
+                == self.store.get_on_demand_feature_view(
+                    "python_stored_writes_feature_view"
+                ).entity_columns
+            )
+
+            current_datetime = _utc_now()
+            fv_entity_rows_to_write = [
+                {
+                    "driver_id": 1001,
+                    "conv_rate": 0.25,
+                    "acc_rate": 0.25,
+                    "avg_daily_trips": 2,
+                    "event_timestamp": current_datetime,
+                    "created": current_datetime,
+                }
+            ]
+            odfv_entity_rows_to_write = [
+                {
+                    "driver_id": 1001,
+                    "counter": 0,
+                    "input_datetime": current_datetime,
+                }
+            ]
+            fv_entity_rows_to_read = [
+                {
+                    "driver_id": 1001,
+                }
+            ]
+            # Note that here we shouldn't have to pass the request-source features
+            # for reading, because they should already have been written to the
+            # online store
+            odfv_entity_rows_to_read = [
+                {
+                    "driver_id": 1001,
+                    "conv_rate": 0.25,
+                    "acc_rate": 0.25,
+                    "counter": 0,
+                    "input_datetime": current_datetime,
+                }
+            ]
+            print("storing fv features")
+            self.store.write_to_online_store(
+                feature_view_name="driver_hourly_stats",
+                df=fv_entity_rows_to_write,
+            )
+            print("reading fv features")
+            online_python_response = self.store.get_online_features(
+                entity_rows=fv_entity_rows_to_read,
+                features=[
+                    "driver_hourly_stats:conv_rate",
+                    "driver_hourly_stats:acc_rate",
+                    "driver_hourly_stats:avg_daily_trips",
+                ],
+            ).to_dict()
+            print(online_python_response)
+            print("storing odfv features")
+            self.store.write_to_online_store(
+                feature_view_name="python_stored_writes_feature_view",
+                df=odfv_entity_rows_to_write,
+            )
+            print("reading odfv features")
+            online_python_response = self.store.get_online_features(
+                entity_rows=odfv_entity_rows_to_read,
+                features=[
+                    "python_stored_writes_feature_view:conv_rate_plus_acc",
+                    "python_stored_writes_feature_view:current_datetime",
+                    "python_stored_writes_feature_view:counter",
+                    "python_stored_writes_feature_view:input_datetime",
+                ],
+            ).to_dict()
+            print(online_python_response)
+            assert sorted(list(online_python_response.keys())) == sorted(
+                [
+                    "driver_id",
+                    "conv_rate_plus_acc",
+                    "counter",
+                    "current_datetime",
+                    "input_datetime",
+                ]
+            )