Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
351 changes: 175 additions & 176 deletions sdk/python/tests/unit/test_on_demand_python_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,202 +45,201 @@

class TestOnDemandPythonTransformation(unittest.TestCase):
def setUp(self):
with tempfile.TemporaryDirectory() as data_dir:
self.store = FeatureStore(
config=RepoConfig(
project="test_on_demand_python_transformation",
registry=os.path.join(data_dir, "registry.db"),
provider="local",
entity_key_serialization_version=3,
online_store=SqliteOnlineStoreConfig(
path=os.path.join(data_dir, "online.db")
),
)
self.data_dir = tempfile.mkdtemp()
data_dir = self.data_dir
self.store = FeatureStore(
config=RepoConfig(
project="test_on_demand_python_transformation",
registry=os.path.join(data_dir, "registry.db"),
provider="local",
entity_key_serialization_version=3,
online_store=SqliteOnlineStoreConfig(
path=os.path.join(data_dir, "online.db")
),
)
)

# Generate test data.
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)
# Generate test data.
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)

driver_entities = [1001, 1002, 1003, 1004, 1005]
driver_df = create_driver_hourly_stats_df(
driver_entities, start_date, end_date
)
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
driver_df.to_parquet(
path=driver_stats_path, allow_truncated_timestamps=True
)
driver_entities = [1001, 1002, 1003, 1004, 1005]
driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)

driver = Entity(
name="driver", join_keys=["driver_id"], value_type=ValueType.INT64
)
driver = Entity(
name="driver", join_keys=["driver_id"], value_type=ValueType.INT64
)

driver_stats_source = FileSource(
name="driver_hourly_stats_source",
path=driver_stats_path,
timestamp_field="event_timestamp",
created_timestamp_column="created",
)
input_request_source = RequestSource(
name="counter_source",
schema=[
Field(name="counter", dtype=Int64),
Field(name="input_datetime", dtype=UnixTimestamp),
],
)
driver_stats_source = FileSource(
name="driver_hourly_stats_source",
path=driver_stats_path,
timestamp_field="event_timestamp",
created_timestamp_column="created",
)
input_request_source = RequestSource(
name="counter_source",
schema=[
Field(name="counter", dtype=Int64),
Field(name="input_datetime", dtype=UnixTimestamp),
],
)

driver_stats_fv = FeatureView(
name="driver_hourly_stats",
entities=[driver],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)
driver_stats_fv = FeatureView(
name="driver_hourly_stats",
entities=[driver],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)

driver_stats_entity_less_fv = FeatureView(
name="driver_hourly_stats_no_entity",
entities=[],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)
driver_stats_entity_less_fv = FeatureView(
name="driver_hourly_stats_no_entity",
entities=[],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)

@on_demand_feature_view(
sources=[driver_stats_fv],
schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)],
mode="pandas",
)
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["conv_rate_plus_acc_pandas"] = (
inputs["conv_rate"] + inputs["acc_rate"]
)
return df
@on_demand_feature_view(
sources=[driver_stats_fv],
schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)],
mode="pandas",
)
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["conv_rate_plus_acc_pandas"] = inputs["conv_rate"] + inputs["acc_rate"]
return df

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
mode="python",
)
def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_acc_python": conv_rate + acc_rate
@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
mode="python",
)
def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_acc_python": conv_rate + acc_rate
for conv_rate, acc_rate in zip(inputs["conv_rate"], inputs["acc_rate"])
}
return output

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[
Field(name="conv_rate_plus_val1_python", dtype=Float64),
Field(name="conv_rate_plus_val2_python", dtype=Float64),
],
mode="python",
)
def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_val1_python": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
}
return output

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[
Field(name="conv_rate_plus_val1_python", dtype=Float64),
Field(name="conv_rate_plus_val2_python", dtype=Float64),
],
mode="python",
)
def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_val1_python": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
],
"conv_rate_plus_val2_python": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
],
}
return output

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[
Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64),
Field(
name="conv_rate_plus_acc_python_singleton_array",
dtype=Array(Float64),
),
"conv_rate_plus_val2_python": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
],
mode="python",
singleton=True,
}
return output

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[
Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64),
Field(
name="conv_rate_plus_acc_python_singleton_array",
dtype=Array(Float64),
),
],
mode="python",
singleton=True,
)
def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf"))
output["conv_rate_plus_acc_python_singleton"] = (
inputs["conv_rate"] + inputs["acc_rate"]
)
def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]:
output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf"))
output["conv_rate_plus_acc_python_singleton"] = (
inputs["conv_rate"] + inputs["acc_rate"]
)
output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3]
return output
output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3]
return output

@on_demand_feature_view(
sources=[
driver_stats_fv[["conv_rate", "acc_rate"]],
input_request_source,
],
schema=[
Field(name="conv_rate_plus_acc", dtype=Float64),
Field(name="current_datetime", dtype=UnixTimestamp),
Field(name="counter", dtype=Int64),
Field(name="input_datetime", dtype=UnixTimestamp),
@on_demand_feature_view(
sources=[
driver_stats_fv[["conv_rate", "acc_rate"]],
input_request_source,
],
schema=[
Field(name="conv_rate_plus_acc", dtype=Float64),
Field(name="current_datetime", dtype=UnixTimestamp),
Field(name="counter", dtype=Int64),
Field(name="input_datetime", dtype=UnixTimestamp),
],
mode="python",
write_to_online_store=True,
)
def python_stored_writes_feature_view(
inputs: dict[str, Any],
) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_acc": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
],
mode="python",
write_to_online_store=True,
)
def python_stored_writes_feature_view(
inputs: dict[str, Any],
) -> dict[str, Any]:
output: dict[str, Any] = {
"conv_rate_plus_acc": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
],
"current_datetime": [datetime.now() for _ in inputs["conv_rate"]],
"counter": [c + 1 for c in inputs["counter"]],
"input_datetime": [d for d in inputs["input_datetime"]],
}
return output
"current_datetime": [datetime.now() for _ in inputs["conv_rate"]],
"counter": [c + 1 for c in inputs["counter"]],
"input_datetime": [d for d in inputs["input_datetime"]],
}
return output

self.store.apply(
[
driver,
driver_stats_source,
driver_stats_fv,
pandas_view,
python_view,
python_singleton_view,
python_demo_view,
driver_stats_entity_less_fv,
python_stored_writes_feature_view,
]
)
self.store.write_to_online_store(
feature_view_name="driver_hourly_stats", df=driver_df
)
assert driver_stats_fv.entity_columns == [
Field(name=driver.join_key, dtype=from_value_type(driver.value_type))
self.store.apply(
[
driver,
driver_stats_source,
driver_stats_fv,
pandas_view,
python_view,
python_singleton_view,
python_demo_view,
driver_stats_entity_less_fv,
python_stored_writes_feature_view,
]
assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD]
)
self.store.write_to_online_store(
feature_view_name="driver_hourly_stats", df=driver_df
)
assert driver_stats_fv.entity_columns == [
Field(name=driver.join_key, dtype=from_value_type(driver.value_type))
]
assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD]

assert len(self.store.list_all_feature_views()) == 7
assert len(self.store.list_feature_views()) == 2
assert len(self.store.list_on_demand_feature_views()) == 5
assert len(self.store.list_stream_feature_views()) == 0
assert len(self.store.list_all_feature_views()) == 7
assert len(self.store.list_feature_views()) == 2
assert len(self.store.list_on_demand_feature_views()) == 5
assert len(self.store.list_stream_feature_views()) == 0

def tearDown(self):
import shutil

if hasattr(self, "data_dir"):
shutil.rmtree(self.data_dir, ignore_errors=True)

def test_setup(self):
pass
Expand Down
Loading