Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix
Signed-off-by: Kevin Zhang <kzhang@tecton.ai>
  • Loading branch information
kevjumba committed Jul 29, 2022
commit aa6fa79b8f64c396f59ed711f7e9f1271ea3cbdb
3 changes: 3 additions & 0 deletions docs/how-to-guides/adding-or-reusing-tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def test_historical_features(environment, universal_data_sources, full_feature_n

The key fixtures are `environment` and `universal_data_sources`, both of which are defined in the `feature_repos` directories. By default, these fixtures pull in a standard dataset with driver and customer entities, certain feature views, and feature values. Because the test takes `environment` as a parameter, it is automatically parametrized across the other offline / online store combinations.

## Debugging Test Failures


## Writing a new test or reusing existing tests

### To add a new test to an existing test file
Expand Down
19 changes: 9 additions & 10 deletions sdk/python/tests/unit/local_feast_tests/test_e2e_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from tests.utils.feature_store_test_functions import validate_online_features


@pytest.mark.integration
def test_e2e_local() -> None:
"""
Tests the end-to-end workflow of apply, materialize, and online retrieval.
Expand All @@ -40,15 +39,15 @@ def test_e2e_local() -> None:
global_stats_path = os.path.join(data_dir, "global_stats.parquet")
global_df.to_parquet(path=global_stats_path, allow_truncated_timestamps=True)

with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as store:
_test_materialize_and_online_retrieval(
runner, store, start_date, end_date, driver_df
)
with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as store:
_test_materialize_and_online_retrieval(
runner, store, start_date, end_date, driver_df
)

with runner.local_repo(
get_example_repo("example_feature_repo_version_0_19.py")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,68 +36,68 @@ def test_apply_stream_feature_view(simple_dataset_1) -> None:
global_stats_path = os.path.join(data_dir, "global_stats.parquet")
global_df.to_parquet(path=global_stats_path, allow_truncated_timestamps=True)

with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as fs, prep_file_source(
df=simple_dataset_1, timestamp_field="ts_1"
) as file_source:
entity = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
kafka_bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=file_source,
watermark_delay_threshold=timedelta(days=1),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field",
function="max",
time_window=timedelta(days=1),
),
Aggregation(
column="dummy_field2",
function="count",
time_window=timedelta(days=24),
),
],
timestamp_field="event_timestamp",
mode="spark",
source=stream_source,
tags={},
)
def simple_sfv(df):
return df

fs.apply([entity, simple_sfv])

stream_feature_views = fs.list_stream_feature_views()
assert len(stream_feature_views) == 1
assert stream_feature_views[0] == simple_sfv

features = fs.get_online_features(
features=["simple_sfv:dummy_field"],
entity_rows=[{"test_key": 1001}],
).to_dict(include_event_timestamps=True)

assert "test_key" in features
assert features["test_key"] == [1001]
assert "dummy_field" in features
assert features["dummy_field"] == [None]
with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as fs, prep_file_source(
df=simple_dataset_1, timestamp_field="ts_1"
) as file_source:
entity = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
kafka_bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=file_source,
watermark_delay_threshold=timedelta(days=1),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field",
function="max",
time_window=timedelta(days=1),
),
Aggregation(
column="dummy_field2",
function="count",
time_window=timedelta(days=24),
),
],
timestamp_field="event_timestamp",
mode="spark",
source=stream_source,
tags={},
)
def simple_sfv(df):
return df

fs.apply([entity, simple_sfv])

stream_feature_views = fs.list_stream_feature_views()
assert len(stream_feature_views) == 1
assert stream_feature_views[0] == simple_sfv

features = fs.get_online_features(
features=["simple_sfv:dummy_field"],
entity_rows=[{"test_key": 1001}],
).to_dict(include_event_timestamps=True)

assert "test_key" in features
assert features["test_key"] == [1001]
assert "dummy_field" in features
assert features["dummy_field"] == [None]


def test_stream_feature_view_udf(simple_dataset_1) -> None:
Expand All @@ -119,71 +119,71 @@ def test_stream_feature_view_udf(simple_dataset_1) -> None:
global_stats_path = os.path.join(data_dir, "global_stats.parquet")
global_df.to_parquet(path=global_stats_path, allow_truncated_timestamps=True)

with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as fs, prep_file_source(
df=simple_dataset_1, timestamp_field="ts_1"
) as file_source:
entity = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
kafka_bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=file_source,
watermark_delay_threshold=timedelta(days=1),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field",
function="max",
time_window=timedelta(days=1),
),
Aggregation(
column="dummy_field2",
function="count",
time_window=timedelta(days=24),
),
],
timestamp_field="event_timestamp",
mode="spark",
source=stream_source,
tags={},
)
def pandas_view(pandas_df):
import pandas as pd

assert type(pandas_df) == pd.DataFrame
df = pandas_df.transform(lambda x: x + 10, axis=1)
df.insert(2, "C", [20.2, 230.0, 34.0], True)
return df
with runner.local_repo(
get_example_repo("example_feature_repo_2.py")
.replace("%PARQUET_PATH%", driver_stats_path)
.replace("%PARQUET_PATH_GLOBAL%", global_stats_path),
"file",
) as fs, prep_file_source(
df=simple_dataset_1, timestamp_field="ts_1"
) as file_source:
entity = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
kafka_bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=file_source,
watermark_delay_threshold=timedelta(days=1),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field",
function="max",
time_window=timedelta(days=1),
),
Aggregation(
column="dummy_field2",
function="count",
time_window=timedelta(days=24),
),
],
timestamp_field="event_timestamp",
mode="spark",
source=stream_source,
tags={},
)
def pandas_view(pandas_df):
import pandas as pd

assert type(pandas_df) == pd.DataFrame
df = pandas_df.transform(lambda x: x + 10, axis=1)
df.insert(2, "C", [20.2, 230.0, 34.0], True)
return df

import pandas as pd
import pandas as pd

fs.apply([entity, pandas_view])
fs.apply([entity, pandas_view])

stream_feature_views = fs.list_stream_feature_views()
assert len(stream_feature_views) == 1
assert stream_feature_views[0] == pandas_view
stream_feature_views = fs.list_stream_feature_views()
assert len(stream_feature_views) == 1
assert stream_feature_views[0] == pandas_view

sfv = stream_feature_views[0]
sfv = stream_feature_views[0]

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
new_df = sfv.udf(df)
expected_df = pd.DataFrame(
{"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
)
assert new_df.equals(expected_df)
df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
new_df = sfv.udf(df)
expected_df = pd.DataFrame(
{"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
)
assert new_df.equals(expected_df)