diff --git a/docs/reference/beta-on-demand-feature-view.md b/docs/reference/beta-on-demand-feature-view.md index efb7023d567..684bc0ac4b2 100644 --- a/docs/reference/beta-on-demand-feature-view.md +++ b/docs/reference/beta-on-demand-feature-view.md @@ -69,6 +69,42 @@ def driver_aggregated_stats(inputs): Aggregated columns are automatically named using the pattern `{function}_{column}` (e.g., `sum_trips`, `mean_rating`). +### Using `input_schema` with Aggregations + +When the input data is not already stored as a feature view, use `input_schema` instead of `sources` to describe the fields that will be passed at request time. Feast will create an internal `RequestSource` automatically. + +```python +from datetime import timedelta +from feast import Field, on_demand_feature_view +from feast.aggregation import Aggregation +from feast.types import Float64, Int64 + +@on_demand_feature_view( + input_schema=[ + Field(name="txn_amount", dtype=Float64), + ], + schema=[ + Field(name="txn_count", dtype=Int64), + Field(name="total_txn_amount", dtype=Float64), + Field(name="avg_txn_amount", dtype=Float64), + ], + aggregations=[ + Aggregation(column="txn_amount", function="count", name="txn_count", + time_window=timedelta(days=30)), + Aggregation(column="txn_amount", function="sum", name="total_txn_amount", + time_window=timedelta(days=30)), + Aggregation(column="txn_amount", function="mean", name="avg_txn_amount", + time_window=timedelta(days=30)), + ], + entities=[user], +) +def user_transaction_stats(inputs): + # Aggregations replace the transformation function — no body needed. + pass +``` + +`input_schema` also accepts fields that are not aggregation columns — for example, thresholds, currency codes, or other contextual values passed at request time that your UDF needs but that are not stored as features. The reverse is not optional: when `input_schema` is provided, every aggregation `column` must be one of its fields, otherwise constructing the view raises a `ValueError`. 
+ ## Example See [https://github.com/feast-dev/on-demand-feature-views-demo](https://github.com/feast-dev/on-demand-feature-views-demo) for an example on how to use on demand feature views. diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index 6b8009f16cd..198c33675fa 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -134,6 +134,7 @@ class OnDemandFeatureView(BaseFeatureView): """ _TRACK_METRICS_TAG = "feast:track_metrics" + _INPUT_SCHEMA_SOURCE_PREFIX = "__input_schema__" name: str entities: Optional[List[str]] @@ -158,7 +159,8 @@ def __init__( # noqa: C901 name: str, entities: Optional[List[Entity]] = None, schema: Optional[List[Field]] = None, - sources: List[OnDemandSourceType], + sources: Optional[List[OnDemandSourceType]] = None, + input_schema: Optional[List[Field]] = None, udf: Optional[FunctionType] = None, udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, @@ -183,6 +185,11 @@ def __init__( # noqa: C901 sources: A map from input source names to the actual input sources, which may be feature views, or request data sources. These sources serve as inputs to the udf, which will refer to them by name. + input_schema (optional): A list of Fields describing data that is accepted as input + but not stored directly as features — e.g. aggregation columns, normalization + parameters, thresholds, or other contextual values passed at request time. + When provided, sources is not required — an internal RequestSource will be + created automatically. udf: The user defined transformation function, which must take pandas dataframes as inputs. 
udf_string: The source code version of the udf (for diffing and displaying in Web UI) @@ -214,15 +221,44 @@ def __init__( # noqa: C901 self.version = version schema = schema or [] self.entities = [e.name for e in entities] if entities else [DUMMY_ENTITY_NAME] - self.sources = sources + self.input_schema = input_schema self.mode = mode.lower() self.udf = udf self.udf_string = udf_string self.source_feature_view_projections: dict[str, FeatureViewProjection] = {} self.source_request_sources: dict[str, RequestSource] = {} + self._input_schema_sentinel: Optional[RequestSource] = None + + # Strip any existing sentinel from sources (handles __copy__ round-trip) + effective_sources: List[OnDemandSourceType] = [ + s + for s in (sources or []) + if not ( + isinstance(s, RequestSource) + and s.name.startswith(self._INPUT_SCHEMA_SOURCE_PREFIX) + ) + ] + + if input_schema is not None: + # Automatically create an internal RequestSource from input_schema. + # It is kept on a private attribute and registered in source_request_sources + # under its prefixed name, but is deliberately left out of self.sources. + self._input_schema_sentinel = RequestSource( + name=f"{self._INPUT_SCHEMA_SOURCE_PREFIX}{name}", + schema=input_schema, + ) + self.source_request_sources[self._input_schema_sentinel.name] = ( + self._input_schema_sentinel + ) + elif not effective_sources: + raise ValueError( + "Either 'sources' or 'input_schema' must be provided for OnDemandFeatureView." 
+ ) + + self.sources = effective_sources # Process each source with explicit type handling - for odfv_source in sources: + for odfv_source in effective_sources: self._add_source_to_collections(odfv_source) features: List[Field] = [] @@ -274,6 +310,20 @@ def __init__( # noqa: C901 self.track_metrics = track_metrics self.aggregations = aggregations or [] + if input_schema is not None and self.aggregations: + input_field_names = {f.name for f in input_schema} + unknown = [ + agg.column + for agg in self.aggregations + if agg.column and agg.column not in input_field_names + ] + if unknown: + raise ValueError( + f"Aggregation column(s) {unknown} not found in input_schema " + f"for OnDemandFeatureView '{name}'. " + f"Available fields: {sorted(input_field_names)}" + ) + def _add_source_to_collections(self, odfv_source: OnDemandSourceType) -> None: """ Add a source to the appropriate collection with explicit type checking. @@ -328,6 +378,7 @@ def __copy__(self): schema=self.features, sources=list(self.source_feature_view_projections.values()) + list(self.source_request_sources.values()), + input_schema=self.input_schema, feature_transformation=self.feature_transformation, mode=self.mode, description=self.description, @@ -337,6 +388,7 @@ def __copy__(self): singleton=self.singleton, version=self.version, track_metrics=self.track_metrics, + aggregations=self.aggregations, ) fv.entities = self.entities fv.features = self.features @@ -536,6 +588,14 @@ def to_proto(self) -> OnDemandFeatureViewProto: request_data_source=request_sources.to_proto() ) + # Serialize the input_schema sentinel so that from_proto() can reconstruct + # input_schema correctly. __init__ also registers it in source_request_sources + # under this same name, so the entry written here reuses that key. 
+ if self._input_schema_sentinel is not None: + sources[self._input_schema_sentinel.name] = OnDemandSource( + request_data_source=self._input_schema_sentinel.to_proto() + ) + feature_transformation = transformation_to_proto(self.feature_transformation) tags = dict(self.tags) if self.tags else {} @@ -559,7 +619,7 @@ def to_proto(self) -> OnDemandFeatureViewProto: owner=self.owner, write_to_online_store=self.write_to_online_store, singleton=self.singleton or False, - aggregations=self.aggregations, + aggregations=[agg.to_proto() for agg in self.aggregations], version=self.version, ) return OnDemandFeatureViewProto(spec=spec, meta=meta) @@ -585,6 +645,18 @@ def from_proto( on_demand_feature_view_proto, skip_udf=skip_udf ) + # Detect and strip input_schema sentinel from sources + input_schema: Optional[List[Field]] = None + sources_without_sentinel: List[OnDemandSourceType] = [] + for source in sources: + if isinstance(source, RequestSource) and source.name.startswith( + cls._INPUT_SCHEMA_SOURCE_PREFIX + ): + input_schema = source.schema + else: + sources_without_sentinel.append(source) + sources = sources_without_sentinel + # Parse transformation from proto (skip UDF deserialization if requested) transformation = cls._parse_transformation_from_proto( on_demand_feature_view_proto, skip_udf=skip_udf @@ -607,6 +679,7 @@ def from_proto( name=on_demand_feature_view_proto.spec.name, schema=cls._parse_features_from_proto(on_demand_feature_view_proto), sources=cast(List[OnDemandSourceType], sources), + input_schema=input_schema, feature_transformation=transformation, mode=on_demand_feature_view_proto.spec.mode or "pandas", description=on_demand_feature_view_proto.spec.description, @@ -817,6 +890,10 @@ def get_request_data_schema(self) -> dict[str, ValueType]: raise TypeError( f"Request source schema is not correct type: ${str(type(request_source.schema))}" ) + # Include fields from the input_schema sentinel (stored privately) + if self._input_schema_sentinel is not None: + 
for field in self._input_schema_sentinel.schema: + schema[field.name] = field.dtype.to_value_type() return schema def _get_projected_feature_name(self, feature: str) -> str: @@ -1092,7 +1169,7 @@ def _is_array_type(self, dtype) -> bool: """Check if the dtype represents an array type.""" # Use proper type checking instead of string comparison dtype_str = str(dtype) - return "Array" in dtype_str or "List" in dtype_str + return "Array" in dtype_str or "List" in dtype_str or "Set" in dtype_str def _construct_random_input( self, singleton: bool = False @@ -1137,6 +1214,13 @@ def _construct_random_input( sample_value = sample_values.get(value_type, default_value) feature_dict[field.name] = sample_value + # Add input_schema fields, read from the private sentinel attribute + if self._input_schema_sentinel is not None: + for field in self._input_schema_sentinel.schema: + value_type = field.dtype.to_value_type() + sample_value = sample_values.get(value_type, default_value) + feature_dict[field.name] = sample_value + return feature_dict def _get_sample_values_by_type(self) -> dict[ValueType, list[Any]]: @@ -1224,13 +1308,17 @@ def on_demand_feature_view( name: Optional[str] = None, entities: Optional[List[Entity]] = None, schema: list[Field], - sources: list[ - Union[ - FeatureView, - RequestSource, - FeatureViewProjection, + sources: Optional[ + list[ + Union[ + FeatureView, + RequestSource, + FeatureViewProjection, + ] ] - ], + ] = None, + input_schema: Optional[list[Field]] = None, + aggregations: Optional[List[Aggregation]] = None, mode: str = "pandas", description: str = "", tags: Optional[dict[str, str]] = None, @@ -1252,6 +1340,10 @@ def on_demand_feature_view( sources: A map from input source names to the actual input sources, which may be feature views, or request data sources. These sources serve as inputs to the udf, which will refer to them by name. 
+ input_schema (optional): A list of Fields describing data that is accepted as input + but not stored directly as features — e.g. aggregation columns, normalization + parameters, thresholds, or other contextual values passed at request time. + When provided, sources is not required. mode: The mode of execution (e.g,. Pandas or Python Native) description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. @@ -1279,6 +1371,7 @@ def decorator(user_function): on_demand_feature_view_obj = OnDemandFeatureView( name=name if name is not None else user_function.__name__, sources=sources, + input_schema=input_schema, schema=schema, mode=mode, description=description, @@ -1288,6 +1381,7 @@ def decorator(user_function): entities=entities, singleton=singleton, track_metrics=track_metrics, + aggregations=aggregations, udf=user_function, udf_string=udf_string, version=version, diff --git a/sdk/python/tests/unit/test_on_demand_feature_view_input_schema.py b/sdk/python/tests/unit/test_on_demand_feature_view_input_schema.py new file mode 100644 index 00000000000..fd69762ef08 --- /dev/null +++ b/sdk/python/tests/unit/test_on_demand_feature_view_input_schema.py @@ -0,0 +1,167 @@ +# Copyright 2025 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for OnDemandFeatureView input_schema support.""" + +import copy +from datetime import timedelta + +import pandas as pd +import pytest + +from feast import Entity, Field +from feast.aggregation import Aggregation +from feast.on_demand_feature_view import OnDemandFeatureView, on_demand_feature_view +from feast.types import Float64, Int64 +from feast.value_type import ValueType + +user = Entity(name="user", join_keys=["user_id"], value_type=ValueType.INT64) + + +def test_decorator_with_input_schema(): + """The @on_demand_feature_view decorator supports input_schema without sources.""" + + @on_demand_feature_view( + input_schema=[ + Field(name="txn_amount", dtype=Float64), + ], + schema=[ + Field(name="txn_count", dtype=Int64), + Field(name="total_txn_amount", dtype=Float64), + Field(name="avg_txn_amount", dtype=Float64), + ], + aggregations=[ + Aggregation( + column="txn_amount", + function="count", + name="txn_count", + time_window=timedelta(days=30), + ), + Aggregation( + column="txn_amount", + function="sum", + name="total_txn_amount", + time_window=timedelta(days=30), + ), + Aggregation( + column="txn_amount", + function="mean", + name="avg_txn_amount", + time_window=timedelta(days=30), + ), + ], + entities=[user], + ) + def compute_txn_stats(df: pd.DataFrame) -> pd.DataFrame: + return df + + assert isinstance(compute_txn_stats, OnDemandFeatureView) + assert compute_txn_stats.name == "compute_txn_stats" + assert compute_txn_stats.input_schema == [Field(name="txn_amount", dtype=Float64)] + assert len(compute_txn_stats.aggregations) == 3 + assert len(compute_txn_stats.features) == 3 + + # The internal sentinel RequestSource should be present + sentinel_name = ( + f"{OnDemandFeatureView._INPUT_SCHEMA_SOURCE_PREFIX}compute_txn_stats" + ) + assert sentinel_name in compute_txn_stats.source_request_sources + + # sources (user-visible) should be empty + assert compute_txn_stats.sources == [] + + +def test_aggregation_aliases(): + """Aggregation name and 
time_window params work correctly.""" + agg = Aggregation( + column="txn_amount", + function="sum", + name="total_txn_amount", + time_window=timedelta(days=30), + ) + assert agg.name == "total_txn_amount" + assert agg.time_window == timedelta(days=30) + + +def test_input_schema_proto_roundtrip(): + """An ODFV with input_schema survives a to_proto / from_proto round-trip.""" + + @on_demand_feature_view( + input_schema=[ + Field(name="txn_amount", dtype=Float64), + ], + schema=[ + Field(name="total_txn_amount", dtype=Float64), + ], + aggregations=[ + Aggregation( + column="txn_amount", + function="sum", + name="total_txn_amount", + time_window=timedelta(days=30), + ), + ], + entities=[user], + ) + def txn_view(df: pd.DataFrame) -> pd.DataFrame: + return df + + proto = txn_view.to_proto() + restored = OnDemandFeatureView.from_proto(proto) + + assert restored.input_schema == txn_view.input_schema + assert restored.aggregations == txn_view.aggregations + sentinel_name = f"{OnDemandFeatureView._INPUT_SCHEMA_SOURCE_PREFIX}txn_view" + assert sentinel_name in restored.source_request_sources + + +def test_input_schema_copy(): + """__copy__ preserves input_schema and aggregations.""" + + @on_demand_feature_view( + input_schema=[ + Field(name="txn_amount", dtype=Float64), + ], + schema=[ + Field(name="total_txn_amount", dtype=Float64), + ], + aggregations=[ + Aggregation(column="txn_amount", function="sum", name="total_txn_amount"), + ], + entities=[user], + ) + def copy_view(df: pd.DataFrame) -> pd.DataFrame: + return df + + cloned = copy.copy(copy_view) + assert cloned.input_schema == copy_view.input_schema + assert cloned.aggregations == copy_view.aggregations + sentinel_name = f"{OnDemandFeatureView._INPUT_SCHEMA_SOURCE_PREFIX}copy_view" + assert sentinel_name in cloned.source_request_sources + + +def test_sources_required_without_input_schema(): + """Constructor raises if neither sources nor input_schema is provided.""" + with pytest.raises( + (ValueError), + ): + + def 
dummy(df): + return df + + OnDemandFeatureView( + name="bad_view", + schema=[Field(name="out", dtype=Float64)], + udf=dummy, + )