Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
checking in progress...trying to fix tests
Signed-off-by: Francisco Javier Arceo <franciscojavierarceo@users.noreply.github.com>
  • Loading branch information
franciscojavierarceo committed Apr 3, 2024
commit 066880f5fb4db64afe2198ef708401ce4fda0471
15 changes: 15 additions & 0 deletions sdk/python/tests/example_repos/example_feature_repo_1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datetime import timedelta

import pandas as pd

from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource
from feast.types import Float32, Int64, String
from feast.on_demand_feature_view import on_demand_feature_view

# Note that file source paths are not validated, so there doesn't actually need to be any data
# at the paths for these file sources. Since these paths are effectively fake, this example
Expand Down Expand Up @@ -98,6 +101,18 @@
tags={},
)

@on_demand_feature_view(
    sources=[customer_driver_combined_source],
    schema=[Field(name="on_demand_feature", dtype=Int64)],
    mode="pandas",
)
def customer_driver_combined_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame:
    """Pandas-mode on-demand view producing ``on_demand_feature = trips + 1``.

    Args:
        inputs: Frame handed in by the on-demand machinery; only its
            ``trips`` column is read.

    Returns:
        A new frame with the single column ``on_demand_feature``, indexed
        like ``inputs``.
    """
    # NOTE(review): `customer_driver_combined_source` is named like a data
    # source rather than a feature view -- confirm it is a valid on-demand
    # source that actually exposes a `trips` column.
    incremented_trips = inputs["trips"] + 1
    return pd.DataFrame({"on_demand_feature": incremented_trips})


all_drivers_feature_service = FeatureService(
name="driver_locations_service",
Expand Down
7 changes: 7 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def test_online() -> None:

assert "trips" in result

result = store.get_online_features(
features=["customer_driver_combined_pandas_odfv:on_demand_feature"],
entity_rows=[{"driver_id": 0, "customer_id": 0}],
full_feature_names=False,
).to_dict()
print(result)
assert 1 == 2
# invalid table reference
with pytest.raises(FeatureViewNotFoundException):
store.get_online_features(
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/tests/unit/test_on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def test_python_native_transformation_mode():
}
) == {"feature1": 0, "feature2": 1, "output1": 100, "output2": 102}

# TODO: add test_get_online_features_on_demand() covering on-demand feature
# views served through FeatureStore.get_online_features.


@pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated")
def test_from_proto_backwards_compatible_udf():
Expand Down
117 changes: 117 additions & 0 deletions sdk/python/tests/unit/test_on_demand_python_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
import tempfile
from datetime import datetime, timedelta

import pandas as pd

from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.field import Field
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64
from typing import Dict, Any


def test_python_pandas_parity():
    """Check that pandas- and python-mode on-demand views agree online.

    The test builds a throwaway local feature store, registers one regular
    feature view plus two on-demand views that both compute
    ``conv_rate + acc_rate`` (one in ``pandas`` mode, one in ``python`` mode),
    loads generated driver statistics into the online store, and then asserts
    that both on-demand views return the same value for the same entity and
    that this value equals the sum of the underlying features.
    """
    # Local import: only needed for the tolerant float comparison below.
    import math

    with tempfile.TemporaryDirectory() as data_dir:
        store = FeatureStore(
            config=RepoConfig(
                project="test_on_demand_python_transformation",
                registry=os.path.join(data_dir, "registry.db"),
                provider="local",
                entity_key_serialization_version=2,
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(data_dir, "online.db")
                ),
            )
        )

        # Generate test data.
        end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
        start_date = end_date - timedelta(days=15)

        driver_entities = [1001, 1002, 1003, 1004, 1005]
        driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
        driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
        driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)

        driver = Entity(name="driver", join_keys=["driver_id"])

        driver_stats_source = FileSource(
            name="driver_hourly_stats_source",
            path=driver_stats_path,
            timestamp_field="event_timestamp",
            created_timestamp_column="created",
        )

        driver_stats_fv = FeatureView(
            name="driver_hourly_stats",
            entities=[driver],
            ttl=timedelta(days=1),
            schema=[
                Field(name="conv_rate", dtype=Float32),
                Field(name="acc_rate", dtype=Float32),
                Field(name="avg_daily_trips", dtype=Int64),
            ],
            online=True,
            source=driver_stats_source,
        )

        @on_demand_feature_view(
            sources=[driver_stats_fv],
            schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
            mode="pandas",
        )
        def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
            df = pd.DataFrame()
            df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
            return df

        @on_demand_feature_view(
            sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
            schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
            mode="python",
        )
        def python_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
            # NOTE(review): assumes python-mode inputs arrive as one list per
            # feature (a plain `+` would then concatenate the lists instead of
            # adding values) -- confirm against the Feast version in use.
            return {
                "conv_rate_plus_acc_python": [
                    conv_rate + acc_rate
                    for conv_rate, acc_rate in zip(
                        inputs["conv_rate"], inputs["acc_rate"]
                    )
                ]
            }

        store.apply(
            [driver, driver_stats_source, driver_stats_fv, pandas_view, python_view]
        )
        # Without this the online store is empty and every feature is null.
        store.write_to_online_store(
            feature_view_name="driver_hourly_stats", df=driver_df
        )

        # Online retrieval takes a list of {join key -> value} rows.
        entity_rows = [{"driver_id": 1001}]

        # The two modes are queried separately so each response exercises a
        # single transformation path.
        python_response = store.get_online_features(
            entity_rows=entity_rows,
            features=[
                "driver_hourly_stats:conv_rate",
                "driver_hourly_stats:acc_rate",
                "python_view:conv_rate_plus_acc_python",
            ],
        ).to_dict()
        pandas_response = store.get_online_features(
            entity_rows=entity_rows,
            features=[
                "driver_hourly_stats:conv_rate",
                "driver_hourly_stats:acc_rate",
                "pandas_view:conv_rate_plus_acc",
            ],
        ).to_dict()

        assert len(python_response["conv_rate_plus_acc_python"]) == 1
        assert len(pandas_response["conv_rate_plus_acc"]) == 1

        python_value = python_response["conv_rate_plus_acc_python"][0]
        pandas_value = pandas_response["conv_rate_plus_acc"][0]
        expected = python_response["conv_rate"][0] + python_response["acc_rate"][0]

        # The inputs are Float32 features, so compare with a small relative
        # tolerance rather than demanding bit-exact float equality.
        assert math.isclose(python_value, pandas_value, rel_tol=1e-6)
        assert math.isclose(python_value, expected, rel_tol=1e-6)