Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
adding unit test
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
  • Loading branch information
franciscojavierarceo committed Sep 21, 2024
commit 54ca3768548c0b632ef3b35b1228412f2afe25f7
122 changes: 116 additions & 6 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import warnings
from types import FunctionType
from typing import Any, Optional, Union, get_type_hints
from typing import Any, List, Optional, Union, get_type_hints

import dill
import pandas as pd
Expand All @@ -12,8 +12,9 @@

from feast.base_feature_view import BaseFeatureView
from feast.data_source import RequestSource
from feast.entity import Entity
from feast.errors import RegistryInferenceFailure, SpecifiedFeaturesNotPresentError
from feast.feature_view import FeatureView
from feast.feature_view import DUMMY_ENTITY_NAME, FeatureView
from feast.feature_view_projection import FeatureViewProjection
from feast.field import Field, from_value_type
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
Expand Down Expand Up @@ -61,7 +62,8 @@ class OnDemandFeatureView(BaseFeatureView):
"""

name: str
features: list[Field]
entities: Optional[List[str]]
features: List[Field]
source_feature_view_projections: dict[str, FeatureViewProjection]
source_request_sources: dict[str, RequestSource]
feature_transformation: Union[
Expand All @@ -71,13 +73,15 @@ class OnDemandFeatureView(BaseFeatureView):
description: str
tags: dict[str, str]
owner: str
write_to_online_store: bool

def __init__( # noqa: C901
self,
*,
name: str,
schema: list[Field],
sources: list[
entities: Optional[List[Entity]] = None,
schema: Optional[List[Field]] = None,
sources: List[
Union[
FeatureView,
RequestSource,
Expand All @@ -93,12 +97,14 @@ def __init__( # noqa: C901
description: str = "",
tags: Optional[dict[str, str]] = None,
owner: str = "",
write_to_online_store: bool = False,
):
"""
Creates an OnDemandFeatureView object.

Args:
name: The unique name of the on demand feature view.
entities (optional): The list of names of entities that this feature view is associated with.
schema: The list of features in the output of the on demand feature view, after
the transformation has been applied.
sources: A map from input source names to the actual input sources, which may be
Expand All @@ -113,6 +119,8 @@ def __init__( # noqa: C901
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the on demand feature view, typically the email
of the primary maintainer.
write_to_online_store (optional): A boolean that indicates whether to write the on demand feature view to
the online store for faster retrieval.
"""
super().__init__(
name=name,
Expand All @@ -122,6 +130,8 @@ def __init__( # noqa: C901
owner=owner,
)

schema = schema or []
self.entities = [e.name for e in entities] if entities else [DUMMY_ENTITY_NAME]
self.mode = mode.lower()

if self.mode not in {"python", "pandas", "substrait"}:
Expand Down Expand Up @@ -152,12 +162,48 @@ def __init__( # noqa: C901
self.source_request_sources[odfv_source.name] = odfv_source
elif isinstance(odfv_source, FeatureViewProjection):
self.source_feature_view_projections[odfv_source.name] = odfv_source

else:
self.source_feature_view_projections[odfv_source.name] = (
odfv_source.projection
)

features: List[Field] = []
self.entity_columns = []

join_keys: List[str] = []
if entities:
for entity in entities:
join_keys.append(entity.join_key)
# Ensure that entities have unique join keys.
if len(set(join_keys)) < len(join_keys):
raise ValueError(
"A feature view should not have entities that share a join key."
)

for field in schema:
if field.name in join_keys:
self.entity_columns.append(field)

# Confirm that the inferred type matches the specified entity type, if it exists.
matching_entities = (
[e for e in entities if e.join_key == field.name]
if entities
else []
)
assert len(matching_entities) == 1
entity = matching_entities[0]
if entity.value_type != ValueType.UNKNOWN:
if from_value_type(entity.value_type) != field.dtype:
raise ValueError(
f"Entity {entity.name} has type {entity.value_type}, which does not match the inferred type {field.dtype}."
)
else:
features.append(field)

self.features = features
self.feature_transformation = feature_transformation
self.write_to_online_store = write_to_online_store

@property
def proto_class(self) -> type[OnDemandFeatureViewProto]:
Expand All @@ -174,8 +220,13 @@ def __copy__(self):
description=self.description,
tags=self.tags,
owner=self.owner,
write_to_online_store=self.write_to_online_store,
)
fv.entities = self.entities
fv.features = self.features
fv.projection = copy.copy(self.projection)
fv.entity_columns = copy.copy(self.entity_columns)

return fv

def __eq__(self, other):
Expand All @@ -193,11 +244,36 @@ def __eq__(self, other):
or self.source_request_sources != other.source_request_sources
or self.mode != other.mode
or self.feature_transformation != other.feature_transformation
or self.write_to_online_store != other.write_to_online_store
or sorted(self.entity_columns) != sorted(other.entity_columns)
):
return False

return True

@property
def join_keys(self) -> List[str]:
"""Returns a list of all the join keys."""
return [entity.name for entity in self.entity_columns]

@property
def schema(self) -> List[Field]:
return list(set(self.entity_columns + self.features))

def ensure_valid(self):
"""
Validates the state of this feature view locally.

Raises:
ValueError: The On Demand feature view does not have an entity when trying to use write_to_online_store.
"""
super().ensure_valid()

if self.write_to_online_store and not self.entities:
raise ValueError(
"On Demand Feature views require an entity if write_to_online_store=True"
)

def __hash__(self):
return super().__hash__()

Expand All @@ -216,7 +292,7 @@ def to_proto(self) -> OnDemandFeatureViewProto:
sources = {}
for source_name, fv_projection in self.source_feature_view_projections.items():
sources[source_name] = OnDemandSource(
feature_view_projection=fv_projection.to_proto()
feature_view_projection=fv_projection.to_proto(),
)
for (
source_name,
Expand All @@ -239,13 +315,18 @@ def to_proto(self) -> OnDemandFeatureViewProto:
)
spec = OnDemandFeatureViewSpec(
name=self.name,
entities=self.entities if self.entities else None,
entity_columns=[
field.to_proto() for field in self.entity_columns if self.entity_columns
],
features=[feature.to_proto() for feature in self.features],
sources=sources,
feature_transformation=feature_transformation,
mode=self.mode,
description=self.description,
tags=self.tags,
owner=self.owner,
write_to_online_store=self.write_to_online_store,
)

return OnDemandFeatureViewProto(spec=spec, meta=meta)
Expand Down Expand Up @@ -335,6 +416,24 @@ def from_proto(
else:
raise ValueError("At least one transformation type needs to be provided")

if hasattr(on_demand_feature_view_proto.spec, "write_to_online_store"):
write_to_online_store = (
on_demand_feature_view_proto.spec.write_to_online_store
)
else:
write_to_online_store = False
if hasattr(on_demand_feature_view_proto.spec, "entities"):
entities = on_demand_feature_view_proto.spec.entities
else:
entities = None
if hasattr(on_demand_feature_view_proto.spec, "entity_columns"):
entity_columns = [
Field.from_proto(field_proto)
for field_proto in on_demand_feature_view_proto.spec.entity_columns
]
else:
entity_columns = []

on_demand_feature_view_obj = cls(
name=on_demand_feature_view_proto.spec.name,
schema=[
Expand All @@ -350,8 +449,12 @@ def from_proto(
description=on_demand_feature_view_proto.spec.description,
tags=dict(on_demand_feature_view_proto.spec.tags),
owner=on_demand_feature_view_proto.spec.owner,
write_to_online_store=write_to_online_store,
)

on_demand_feature_view_obj.entities = list(entities)
on_demand_feature_view_obj.entity_columns = entity_columns

# FeatureViewProjections are not saved in the OnDemandFeatureView proto.
# Create the default projection.
on_demand_feature_view_obj.projection = FeatureViewProjection.from_definition(
Expand Down Expand Up @@ -595,6 +698,7 @@ def get_requested_odfvs(

def on_demand_feature_view(
*,
entities: Optional[List[Entity]] = None,
schema: list[Field],
sources: list[
Union[
Expand All @@ -607,11 +711,13 @@ def on_demand_feature_view(
description: str = "",
tags: Optional[dict[str, str]] = None,
owner: str = "",
write_to_online_store: bool = False,
):
"""
Creates an OnDemandFeatureView object with the given user function as udf.

Args:
entities (Optional): The list of names of entities that this feature view is associated with.
schema: The list of features in the output of the on demand feature view, after
the transformation has been applied.
sources: A map from input source names to the actual input sources, which may be
Expand All @@ -622,6 +728,8 @@ def on_demand_feature_view(
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the on demand feature view, typically the email
of the primary maintainer.
write_to_online_store (optional): A boolean that indicates whether to write the on demand feature view to
the online store for faster retrieval.
"""

def mainify(obj) -> None:
Expand Down Expand Up @@ -664,6 +772,8 @@ def decorator(user_function):
description=description,
tags=tags,
owner=owner,
write_to_online_store=write_to_online_store,
entities=entities,
)
functools.update_wrapper(
wrapper=on_demand_feature_view_obj, wrapped=user_function
Expand Down
19 changes: 19 additions & 0 deletions sdk/python/tests/unit/test_feature_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ def test_hash():


# TODO(felixwang9817): Add tests for proto conversion.
def test_proto_conversion():
file_source = FileSource(name="my-file-source", path="test.parquet")
feature_view_1 = FeatureView(
name="my-feature-view",
entities=[],
schema=[
Field(name="feature1", dtype=Float32),
Field(name="feature2", dtype=Float32),
],
source=file_source,
)

feature_view_proto = feature_view_1.to_proto()
assert (
feature_view_proto.spec.name == "my-feature-view" and
feature_view_proto.spec.batch_source.file_options.uri == "test.parquet" and
feature_view_proto.spec.batch_source.name == "my-file-source" and
feature_view_proto.spec.batch_source.type == 1
)
# TODO(felixwang9817): Add tests for field mapping logic.


Expand Down