Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix formatting, linting
Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko committed Mar 21, 2024
commit fcfe3053b4cb45f1e2d61f4122ac36e84b2a3832
Original file line number Diff line number Diff line change
Expand Up @@ -162,26 +162,32 @@ def read_fv(feature_view, feature_refs, full_feature_names):

if full_feature_names:
fv_table = fv_table.rename(
{f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
{
f"{full_name_prefix}__{feature}": feature
for feature in feature_refs
}
)

feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs]
feature_refs = [
f"{full_name_prefix}__{feature}" for feature in feature_refs
]

return (
fv_table,
feature_view.batch_source.timestamp_field,
feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns},
feature_view.projection.join_key_map
or {e.name: e.name for e in feature_view.entity_columns},
feature_refs,
feature_view.ttl
feature_view.ttl,
)

res = point_in_time_join(
entity_table=entity_table,
feature_tables=[
feature_tables=[
read_fv(feature_view, feature_refs, full_feature_names)
for feature_view in feature_views
],
event_timestamp_col=event_timestamp_col
event_timestamp_col=event_timestamp_col,
)

return IbisRetrievalJob(
Expand Down Expand Up @@ -217,8 +223,8 @@ def pull_all_from_table_or_query(
table = table.select(*fields)

# TODO get rid of this fix
if '__log_date' in table.columns:
table = table.drop('__log_date')
if "__log_date" in table.columns:
table = table.drop("__log_date")

table = table.filter(
ibis.and_(
Expand Down Expand Up @@ -255,7 +261,7 @@ def write_logged_features(
else:
kwargs = {}

#TODO always write to directory
# TODO always write to directory
table.to_parquet(
f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
)
Expand Down Expand Up @@ -346,9 +352,9 @@ def metadata(self) -> Optional[RetrievalMetadata]:
def point_in_time_join(
entity_table: Table,
feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
event_timestamp_col = 'event_timestamp'
event_timestamp_col="event_timestamp",
):
#TODO handle ttl
# TODO handle ttl
all_entities = [event_timestamp_col]
for feature_table, timestamp_field, join_key_map, _, _ in feature_tables:
all_entities.extend(join_key_map.values())
Expand All @@ -362,16 +368,25 @@ def point_in_time_join(

acc_table = entity_table

for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables:
predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()]
for (
feature_table,
timestamp_field,
join_key_map,
feature_refs,
ttl,
) in feature_tables:
predicates = [
feature_table[k] == entity_table[v] for k, v in join_key_map.items()
]

predicates.append(
feature_table[timestamp_field] <= entity_table[event_timestamp_col],
)

if ttl:
predicates.append(
feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl)
feature_table[timestamp_field]
>= entity_table[event_timestamp_col] - ibis.literal(ttl)
)

feature_table = feature_table.inner_join(
Expand All @@ -386,7 +401,9 @@ def point_in_time_join(
.mutate(rn=ibis.row_number())
)

feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn")
feature_table = feature_table.filter(
feature_table["rn"] == ibis.literal(0)
).drop("rn")

select_cols = ["entity_row_id"]
select_cols.extend(feature_refs)
Expand All @@ -401,6 +418,6 @@ def point_in_time_join(

acc_table = acc_table.drop(s.endswith("_yyyy"))

acc_table = acc_table.drop('entity_row_id')
acc_table = acc_table.drop("entity_row_id")

return acc_table
return acc_table
98 changes: 74 additions & 24 deletions sdk/python/tests/unit/infra/offline_stores/test_ibis.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,105 @@
from datetime import datetime, timedelta
from typing import Dict, List, Tuple

import ibis
import pyarrow as pa
from typing import List, Tuple, Dict
from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join
from pprint import pprint

from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
point_in_time_join,
)


def pa_datetime(year, month, day):
return pa.scalar(datetime(year, month, day), type=pa.timestamp('s', tz='UTC'))
return pa.scalar(datetime(year, month, day), type=pa.timestamp("s", tz="UTC"))


def customer_table():
return pa.Table.from_arrays(
arrays=[
pa.array([1, 1, 2]),
pa.array([pa_datetime(2024, 1, 1),pa_datetime(2024, 1, 2),pa_datetime(2024, 1, 1)])
pa.array(
[
pa_datetime(2024, 1, 1),
pa_datetime(2024, 1, 2),
pa_datetime(2024, 1, 1),
]
),
],
names=['customer_id', 'event_timestamp']
names=["customer_id", "event_timestamp"],
)


def features_table_1():
return pa.Table.from_arrays(
arrays=[
pa.array([1, 1, 1, 2]),
pa.array([pa_datetime(2023, 12, 31), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 3), pa_datetime(2023, 1, 3)]),
pa.array([11, 22, 33, 22])
],
names=['customer_id', 'event_timestamp', 'feature1']
pa.array(
[
pa_datetime(2023, 12, 31),
pa_datetime(2024, 1, 2),
pa_datetime(2024, 1, 3),
pa_datetime(2023, 1, 3),
]
),
pa.array([11, 22, 33, 22]),
],
names=["customer_id", "event_timestamp", "feature1"],
)


def point_in_time_join_brute(
entity_table: pa.Table,
feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]],
event_timestamp_col = 'event_timestamp'
event_timestamp_col="event_timestamp",
):
ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names]

from operator import itemgetter

ret = entity_table.to_pydict()
batch_dict = entity_table.to_pydict()

for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]):
for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables:
for (
feature_table,
timestamp_key,
join_key_map,
feature_refs,
ttl,
) in feature_tables:
if i == 0:
ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f not in join_key_map.values() and f != timestamp_key])
ret_fields.extend(
[
feature_table.schema.field(f)
for f in feature_table.schema.names
if f not in join_key_map.values() and f != timestamp_key
]
)

def check_equality(ft_dict, batch_dict, x, y):
return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()])
return all(
[ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()]
)

ft_dict = feature_table.to_pydict()
found_matches = [
(j, ft_dict[timestamp_key][j]) for j in range(entity_table.num_rows)
if check_equality(ft_dict, batch_dict, j, i) and
ft_dict[timestamp_key][j] <= row_timestmap and
ft_dict[timestamp_key][j] >= row_timestmap - ttl
(j, ft_dict[timestamp_key][j])
for j in range(entity_table.num_rows)
if check_equality(ft_dict, batch_dict, j, i)
and ft_dict[timestamp_key][j] <= row_timestmap
and ft_dict[timestamp_key][j] >= row_timestmap - ttl
]

index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None
index_found = (
max(found_matches, key=itemgetter(1))[0] if found_matches else None
)
for col in ft_dict.keys():
if col not in feature_refs:
continue

if col not in ret:
ret[col] = []

if index_found is not None:
ret[col].append(ft_dict[col][index_found])
else:
Expand All @@ -74,15 +112,27 @@ def test_point_in_time_join():
expected = point_in_time_join_brute(
customer_table(),
feature_tables=[
(features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
]
(
features_table_1(),
"event_timestamp",
{"customer_id": "customer_id"},
["feature1"],
timedelta(days=10),
)
],
)

actual = point_in_time_join(
ibis.memtable(customer_table()),
feature_tables=[
(ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
]
(
ibis.memtable(features_table_1()),
"event_timestamp",
{"customer_id": "customer_id"},
["feature1"],
timedelta(days=10),
)
],
).to_pyarrow()

assert actual.equals(expected)