Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,23 @@ def from_proto(
== "user_defined_function"
and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
!= ""
and on_demand_feature_view_proto.spec.mode == "pandas"
):
transformation = PandasTransformation.from_proto(
on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
)
elif (
on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
"transformation"
)
== "user_defined_function"
and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
!= ""
and on_demand_feature_view_proto.spec.mode == "python"
):
transformation = PythonTransformation.from_proto(
on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
)
elif (
on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
"transformation"
Expand Down
Binary file added sdk/python/tests/data/driver_hourly_stats.parquet
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this file necessary? looks like it should be auto-generated during tests in a temp folder

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh good catch I didn't mean to check this in

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary files in a tests folder are scary these days lol 😄

Binary file not shown.
14 changes: 14 additions & 0 deletions sdk/python/tests/example_repos/example_feature_repo_1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import timedelta

import pandas as pd

from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Int64, String

# Note that file source paths are not validated, so there doesn't actually need to be any data
Expand Down Expand Up @@ -99,6 +102,17 @@
)


@on_demand_feature_view(
sources=[customer_profile],
schema=[Field(name="on_demand_age", dtype=Int64)],
mode="pandas",
)
def customer_profile_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame:
outputs = pd.DataFrame()
outputs["on_demand_age"] = inputs["age"] + 1
return outputs


all_drivers_feature_service = FeatureService(
name="driver_locations_service",
features=[driver_locations],
Expand Down
11 changes: 11 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def test_online() -> None:

assert "trips" in result

result = store.get_online_features(
features=["customer_profile_pandas_odfv:on_demand_age"],
entity_rows=[{"driver_id": 1, "customer_id": "5"}],
full_feature_names=False,
).to_dict()

assert "on_demand_age" in result
assert result["driver_id"] == [1]
assert result["customer_id"] == ["5"]
assert result["on_demand_age"] == [4]

# invalid table reference
with pytest.raises(FeatureViewNotFoundException):
store.get_online_features(
Expand Down
93 changes: 93 additions & 0 deletions sdk/python/tests/unit/test_on_demand_pandas_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
import tempfile
from datetime import datetime, timedelta

import pandas as pd

from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.field import Field
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64


def test_pandas_transformation():
with tempfile.TemporaryDirectory() as data_dir:
store = FeatureStore(
config=RepoConfig(
project="test_on_demand_python_transformation",
registry=os.path.join(data_dir, "registry.db"),
provider="local",
entity_key_serialization_version=2,
online_store=SqliteOnlineStoreConfig(
path=os.path.join(data_dir, "online.db")
),
)
)

# Generate test data.
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)

driver_entities = [1001, 1002, 1003, 1004, 1005]
driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)

driver = Entity(name="driver", join_keys=["driver_id"])

driver_stats_source = FileSource(
name="driver_hourly_stats_source",
path=driver_stats_path,
timestamp_field="event_timestamp",
created_timestamp_column="created",
)

driver_stats_fv = FeatureView(
name="driver_hourly_stats",
entities=[driver],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)

@on_demand_feature_view(
sources=[driver_stats_fv],
schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
mode="pandas",
)
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
return df

store.apply([driver, driver_stats_source, driver_stats_fv, pandas_view])

entity_rows = [
{
"driver_id": 1001,
}
]
store.write_to_online_store(
feature_view_name="driver_hourly_stats", df=driver_df
)

online_response = store.get_online_features(
entity_rows=entity_rows,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"pandas_view:conv_rate_plus_acc",
],
).to_df()

assert online_response["conv_rate_plus_acc"].equals(
online_response["conv_rate"] + online_response["acc_rate"]
)
139 changes: 139 additions & 0 deletions sdk/python/tests/unit/test_on_demand_python_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from typing import Any, Dict

import pandas as pd

from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.field import Field
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64


class TestOnDemandPythonTransformation(unittest.TestCase):
def setUp(self):
with tempfile.TemporaryDirectory() as data_dir:
self.store = FeatureStore(
config=RepoConfig(
project="test_on_demand_python_transformation",
registry=os.path.join(data_dir, "registry.db"),
provider="local",
entity_key_serialization_version=2,
online_store=SqliteOnlineStoreConfig(
path=os.path.join(data_dir, "online.db")
),
)
)

# Generate test data.
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)

driver_entities = [1001, 1002, 1003, 1004, 1005]
driver_df = create_driver_hourly_stats_df(
driver_entities, start_date, end_date
)
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
driver_df.to_parquet(
path=driver_stats_path, allow_truncated_timestamps=True
)

driver = Entity(name="driver", join_keys=["driver_id"])

driver_stats_source = FileSource(
name="driver_hourly_stats_source",
path=driver_stats_path,
timestamp_field="event_timestamp",
created_timestamp_column="created",
)

driver_stats_fv = FeatureView(
name="driver_hourly_stats",
entities=[driver],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)

@on_demand_feature_view(
sources=[driver_stats_fv],
schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)],
mode="pandas",
)
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["conv_rate_plus_acc_pandas"] = (
inputs["conv_rate"] + inputs["acc_rate"]
)
return df

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
mode="python",
)
def python_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a little confused about the required signature here. Are these functions supposed to accept a dict of lists (looks like that in this test) and apply the udf for all entities at once? I thought from the previous PR that the goal was to have a udf that would be applied to individual entities...

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also alter the tests so that more than one entity is passed? this will probably fail in such a case as only first entity is processed. If we are sticking with this signature, udf should look something like this:

return {
                'conv_rate_plus_acc_python': [
                    conv_rate + acc_rate
                    for conv_rate, acc_rate in zip(inputs['conv_rate'], inputs['acc_rate'])
                ]
            }

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you look at _infer_features_dict you'll see it expects a dict of lists. I added an explicit test that shows this will result in a type failure when running the apply operations. We can add singleton execution as a follow up but this is sufficient to highlight the currently supported behavior and then we can cut a release.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@franciscojavierarceo got it, good... that's probably more efficient anyway. no rush, but in that case it will probably be a good idea to change type annotations for relevant functions to Dict[str, List[Any]].

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually originally had that setup but I received a ton of type failures from that which is why I did it this way.

Let me address both of those as folllowups. I want to merge this and cut a release.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made an issue here #4075, will close it later.

output: Dict[str, Any] = {"conv_rate_plus_acc_python": []}
output["conv_rate_plus_acc_python"].append(
inputs["conv_rate"][0] + inputs["acc_rate"][0]
)
return output

self.store.apply(
[driver, driver_stats_source, driver_stats_fv, pandas_view, python_view]
)
self.store.write_to_online_store(
feature_view_name="driver_hourly_stats", df=driver_df
)

def test_python_pandas_parity(self):
entity_rows = [
{
"driver_id": 1001,
}
]

online_python_response = self.store.get_online_features(
entity_rows=entity_rows,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"python_view:conv_rate_plus_acc_python",
],
).to_dict()

online_pandas_response = self.store.get_online_features(
entity_rows=entity_rows,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"pandas_view:conv_rate_plus_acc_pandas",
],
).to_df()

assert len(online_python_response) == 4
assert all(
key in online_python_response.keys()
for key in [
"driver_id",
"acc_rate",
"conv_rate",
"conv_rate_plus_acc_python",
]
)
assert len(online_python_response["conv_rate_plus_acc_python"]) == 1
assert (
online_python_response["conv_rate_plus_acc_python"][0]
== online_pandas_response["conv_rate_plus_acc_pandas"][0]
== online_python_response["conv_rate"][0]
+ online_python_response["acc_rate"][0]
)