Skip to content
Prev Previous commit
Next Next commit
Fix
Signed-off-by: Kevin Zhang <kzhang@tecton.ai>
  • Loading branch information
kevjumba committed Apr 21, 2022
commit 6f01fd3c963ace4dc5934fd5103db1df97fd1d1a
7 changes: 3 additions & 4 deletions sdk/python/feast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource

from .base_feature_view import BaseFeatureView
from .data_source import (
KafkaSource,
KinesisSource,
Expand All @@ -21,10 +22,9 @@
from .feature_view import FeatureView
from .field import Field
from .on_demand_feature_view import OnDemandFeatureView
from .base_feature_view import BaseFeatureView
from .stream_feature_view import StreamFeatureView
from .repo_config import RepoConfig
from .request_feature_view import RequestFeatureView
from .stream_feature_view import StreamFeatureView
from .value_type import ValueType

logging.basicConfig(
Expand All @@ -40,8 +40,7 @@
pass

__all__ = [
"BaseFeatureView"
"Entity",
"BaseFeatureView" "Entity",
"KafkaSource",
"KinesisSource",
"Feature",
Expand Down
103 changes: 76 additions & 27 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import dill
import pandas as pd

from feast.batch_feature_view import BatchFeatureView
from feast.stream_feature_view import StreamFeatureView
from feast.base_feature_view import BaseFeatureView
from feast.batch_feature_view import BatchFeatureView
from feast.data_source import RequestSource
from feast.errors import RegistryInferenceFailure, SpecifiedFeaturesNotPresentError
from feast.feature import Feature
Expand All @@ -27,6 +26,7 @@
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
UserDefinedFunction as UserDefinedFunctionProto,
)
from feast.stream_feature_view import StreamFeatureView
from feast.type_map import (
feast_value_type_to_pandas_type,
python_type_to_feast_value_type,
Expand Down Expand Up @@ -136,8 +136,18 @@ def __init__(
),
DeprecationWarning,
)
for source in inputs.values():
_sources.append(source)
for _, source in inputs.items():
if isinstance(source, FeatureView):
_sources.append(feature_view_to_batch_feature_view(source))
elif isinstance(source, FeatureViewProjection):
_sources.append(BatchFeatureView(
name=source.name,
schema=source.features,
))
elif isinstance(source, RequestSource):
_sources.append(source)
else:
raise ValueError("input can only accept FeatureView, FeatureViewProjection, or RequestSource")
_udf = udf

if args:
Expand Down Expand Up @@ -172,8 +182,18 @@ def __init__(
)
if len(args) >= 3:
_inputs = args[2]
for source in _inputs.values():
_sources.append(source)
for _, source in _inputs.items():
if isinstance(source, FeatureView):
_sources.append(feature_view_to_batch_feature_view(source))
elif isinstance(source, FeatureViewProjection):
_sources.append(BatchFeatureView(
name=source.name,
schema=source.features,
))
elif isinstance(source, RequestSource):
_sources.append(source)
else:
raise ValueError("input can only accept FeatureView, FeatureViewProjection, or RequestSource")
warnings.warn(
(
"The `inputs` parameter is being deprecated. Please use `sources` instead. "
Expand All @@ -199,8 +219,6 @@ def __init__(
tags=tags,
owner=owner,
)
print("Asdf")
print(_sources)
assert _sources is not None
self.source_feature_view_projections: Dict[str, FeatureViewProjection] = {}
self.source_request_sources: Dict[str, RequestSource] = {}
Expand Down Expand Up @@ -228,7 +246,8 @@ def __copy__(self):
name=self.name,
schema=self.features,
sources=list(
**self.source_feature_view_projections.values(), **self.source_request_sources.values(),
**self.source_feature_view_projections.values(),
**self.source_request_sources.values(),
),
udf=self.udf,
description=self.description,
Expand Down Expand Up @@ -308,25 +327,21 @@ def from_proto(cls, on_demand_feature_view_proto: OnDemandFeatureViewProto):
A OnDemandFeatureView object based on the on-demand feature view protobuf.
"""
sources = []
for (
_,
on_demand_source,
) in on_demand_feature_view_proto.spec.sources.items():
for (_, on_demand_source,) in on_demand_feature_view_proto.spec.sources.items():
if on_demand_source.WhichOneof("source") == "feature_view":
sources.append(
FeatureView.from_proto(
on_demand_source.feature_view
).projection)
FeatureView.from_proto(on_demand_source.feature_view).projection
)
elif on_demand_source.WhichOneof("source") == "feature_view_projection":
sources.append(
FeatureViewProjection.from_proto(
on_demand_source.feature_view_projection
))
on_demand_source.feature_view_projection
)
)
else:
sources.append(
RequestSource.from_proto(
on_demand_source.request_data_source
))
RequestSource.from_proto(on_demand_source.request_data_source)
)
on_demand_feature_view_obj = cls(
name=on_demand_feature_view_proto.spec.name,
schema=[
Expand Down Expand Up @@ -484,7 +499,9 @@ def get_requested_odfvs(feature_refs, project, registry):
def on_demand_feature_view(
*args,
features: Optional[List[Feature]] = None,
sources: Optional[List[Union[BatchFeatureView, StreamFeatureView, RequestSource]]] = None,
sources: Optional[
List[Union[BatchFeatureView, StreamFeatureView, RequestSource]]
] = None,
inputs: Optional[Dict[str, Union[FeatureView, RequestSource]]] = None,
schema: Optional[List[Field]] = None,
description: str = "",
Expand Down Expand Up @@ -537,8 +554,18 @@ def on_demand_feature_view(
),
DeprecationWarning,
)
for source in inputs.values():
_sources.append(source)
for _, source in inputs.items():
if isinstance(source, FeatureView):
_sources.append(feature_view_to_batch_feature_view(source))
elif isinstance(source, FeatureViewProjection):
_sources.append(BatchFeatureView(
name=source.name,
schema=source.features,
))
elif isinstance(source, RequestSource):
_sources.append(source)
else:
raise ValueError("input can only accept FeatureView, FeatureViewProjection, or RequestSource")

if args:
warnings.warn(
Expand Down Expand Up @@ -570,9 +597,19 @@ def on_demand_feature_view(
)
if len(args) >= 2:
_inputs = args[1]
for source in _inputs.values():
_sources.append(source)
warnings.warn(
for _, source in _inputs.items():
if isinstance(source, FeatureView):
_sources.append(feature_view_to_batch_feature_view(source))
elif isinstance(source, FeatureViewProjection):
_sources.append(BatchFeatureView(
name=source.name,
schema=source.features,
))
elif isinstance(source, RequestSource):
_sources.append(source)
else:
raise ValueError("input can only accept FeatureView, FeatureViewProjection, or RequestSource")
warnings.warn(
(
"The `inputs` parameter is being deprecated. Please use `sources` instead. "
"Feast 0.21 and onwards will not support the `inputs` parameter."
Expand All @@ -599,3 +636,15 @@ def decorator(user_function):
return on_demand_feature_view_obj

return decorator

def feature_view_to_batch_feature_view(fv: FeatureView) -> BatchFeatureView:
return BatchFeatureView(
name=fv.name,
entities=fv.entities,
ttl=fv.ttl,
tags=fv.tags,
online=fv.online,
owner=fv.owner,
schema=fv.schema,
source=fv.source,
)
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
conv_rate_plus_100_feature_view,
create_conv_rate_request_source,
create_customer_daily_profile_feature_view,
create_driver_hourly_stats_batch_feature_view,
create_driver_hourly_stats_feature_view,
create_driver_hourly_stats_base_feature_view,
create_field_mapping_feature_view,
create_global_stats_feature_view,
create_location_stats_feature_view,
Expand Down Expand Up @@ -312,7 +312,9 @@ def construct_universal_feature_views(
data_sources: UniversalDataSources, with_odfv: bool = True,
) -> UniversalFeatureViews:
driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver)
driver_hourly_stats_base_feature_view = create_driver_hourly_stats_base_feature_view(data_sources.driver)
driver_hourly_stats_base_feature_view = create_driver_hourly_stats_batch_feature_view(
data_sources.driver
)
return UniversalFeatureViews(
customer=create_customer_daily_profile_feature_view(data_sources.customer),
global_fv=create_global_stats_feature_view(data_sources.global_ds),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from feast.data_source import DataSource, RequestSource
from feast.types import Array, FeastType, Float32, Float64, Int32
from sdk.python.feast.batch_feature_view import BatchFeatureView
from tests.integration.feature_repos.universal.entities import location


Expand Down Expand Up @@ -150,8 +151,9 @@ def create_item_embeddings_feature_view(source, infer_features: bool = False):
)
return item_embeddings_feature_view

def create_item_embeddings_base_feature_view(source, infer_features: bool = False):
item_embeddings_feature_view = BaseFeatureView(

def create_item_embeddings_batch_feature_view(source, infer_features: bool = False):
item_embeddings_feature_view = BatchFeatureView(
name="item_embeddings",
entities=["item"],
schema=None
Expand All @@ -165,6 +167,7 @@ def create_item_embeddings_base_feature_view(source, infer_features: bool = Fals
)
return item_embeddings_feature_view


def create_driver_hourly_stats_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = FeatureView(
name="driver_stats",
Expand All @@ -181,8 +184,9 @@ def create_driver_hourly_stats_feature_view(source, infer_features: bool = False
)
return driver_stats_feature_view

def create_driver_hourly_stats_base_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = BaseFeatureView(

def create_driver_hourly_stats_batch_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = BatchFeatureView(
name="driver_stats",
entities=["driver"],
schema=None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from tests.integration.feature_repos.universal.feature_views import (
conv_rate_plus_100_feature_view,
create_conv_rate_request_source,
create_driver_hourly_stats_base_feature_view,
create_item_embeddings_base_feature_view,
create_driver_hourly_stats_batch_feature_view,
create_item_embeddings_batch_feature_view,
create_similarity_request_source,
similarity_feature_view,
)
Expand All @@ -26,7 +26,9 @@ def test_infer_odfv_features(environment, universal_data_sources, infer_features

(entities, datasets, data_sources) = universal_data_sources

driver_hourly_stats = create_driver_hourly_stats_base_feature_view(data_sources.driver)
driver_hourly_stats = create_driver_hourly_stats_batch_feature_view(
data_sources.driver
)
request_source = create_conv_rate_request_source()
driver_odfv = conv_rate_plus_100_feature_view(
{"driver": driver_hourly_stats, "input_request": request_source},
Expand Down Expand Up @@ -59,7 +61,7 @@ def test_infer_odfv_list_features(environment, infer_features, tmp_path):
timestamp_field="event_timestamp",
created_timestamp_column="created",
)
item_feature_view = create_item_embeddings_base_feature_view(fake_items_src)
item_feature_view = create_item_embeddings_batch_feature_view(fake_items_src)
sim_odfv = similarity_feature_view(
[item_feature_view, create_similarity_request_source()],
infer_features=infer_features,
Expand All @@ -78,7 +80,9 @@ def test_infer_odfv_features_with_error(environment, universal_data_sources):
(entities, datasets, data_sources) = universal_data_sources

features = [Field(name="conv_rate_plus_200", dtype=Float64)]
driver_hourly_stats = create_driver_hourly_stats_base_feature_view(data_sources.driver)
driver_hourly_stats = create_driver_hourly_stats_batch_feature_view(
data_sources.driver
)
request_source = create_conv_rate_request_source()
driver_odfv = conv_rate_plus_100_feature_view(
{"driver": driver_hourly_stats, "input_request": request_source},
Expand Down