Skip to content

Commit 3980e0c

Browse files
authored
feat: Rewrite ibis point-in-time-join w/o feast abstractions (feast-dev#4023)
* feat: refactor ibis point-in-time-join

  Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

* fix formatting, linting

  Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

---------

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
1 parent afd52b8 commit 3980e0c

5 files changed

Lines changed: 295 additions & 118 deletions

File tree

sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py

Lines changed: 132 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range(
7272

7373
return entity_df_event_timestamp_range
7474

75-
@staticmethod
76-
def _get_historical_features_one(
77-
feature_view: FeatureView,
78-
entity_table: Table,
79-
feature_refs: List[str],
80-
full_feature_names: bool,
81-
timestamp_range: Tuple,
82-
acc_table: Table,
83-
event_timestamp_col: str,
84-
) -> Table:
85-
fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)
86-
87-
for old_name, new_name in feature_view.batch_source.field_mapping.items():
88-
if old_name in fv_table.columns:
89-
fv_table = fv_table.rename({new_name: old_name})
90-
91-
timestamp_field = feature_view.batch_source.timestamp_field
92-
93-
# TODO mutate only if tz-naive
94-
fv_table = fv_table.mutate(
95-
**{
96-
timestamp_field: fv_table[timestamp_field].cast(
97-
dt.Timestamp(timezone="UTC")
98-
)
99-
}
100-
)
101-
102-
full_name_prefix = feature_view.projection.name_alias or feature_view.name
103-
104-
feature_refs = [
105-
fr.split(":")[1]
106-
for fr in feature_refs
107-
if fr.startswith(f"{full_name_prefix}:")
108-
]
109-
110-
timestamp_range_start_minus_ttl = (
111-
timestamp_range[0] - feature_view.ttl
112-
if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0)
113-
else timestamp_range[0]
114-
)
115-
116-
timestamp_range_start_minus_ttl = ibis.literal(
117-
timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f")
118-
).cast(dt.Timestamp(timezone="UTC"))
119-
120-
timestamp_range_end = ibis.literal(
121-
timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f")
122-
).cast(dt.Timestamp(timezone="UTC"))
123-
124-
fv_table = fv_table.filter(
125-
ibis.and_(
126-
fv_table[timestamp_field] <= timestamp_range_end,
127-
fv_table[timestamp_field] >= timestamp_range_start_minus_ttl,
128-
)
129-
)
130-
131-
# join_key_map = feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}
132-
# predicates = [fv_table[k] == entity_table[v] for k, v in join_key_map.items()]
133-
134-
if feature_view.projection.join_key_map:
135-
predicates = [
136-
fv_table[k] == entity_table[v]
137-
for k, v in feature_view.projection.join_key_map.items()
138-
]
139-
else:
140-
predicates = [
141-
fv_table[e.name] == entity_table[e.name]
142-
for e in feature_view.entity_columns
143-
]
144-
145-
predicates.append(
146-
fv_table[timestamp_field] <= entity_table[event_timestamp_col]
147-
)
148-
149-
fv_table = fv_table.inner_join(
150-
entity_table, predicates, lname="", rname="{name}_y"
151-
)
152-
153-
fv_table = (
154-
fv_table.group_by(by="entity_row_id")
155-
.order_by(ibis.desc(fv_table[timestamp_field]))
156-
.mutate(rn=ibis.row_number())
157-
)
158-
159-
fv_table = fv_table.filter(fv_table["rn"] == ibis.literal(0))
160-
161-
select_cols = ["entity_row_id"]
162-
select_cols.extend(feature_refs)
163-
fv_table = fv_table.select(select_cols)
164-
165-
if full_feature_names:
166-
fv_table = fv_table.rename(
167-
{f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
168-
)
169-
170-
acc_table = acc_table.left_join(
171-
fv_table,
172-
predicates=[fv_table.entity_row_id == acc_table.entity_row_id],
173-
lname="",
174-
rname="{name}_yyyy",
175-
)
176-
177-
acc_table = acc_table.drop(s.endswith("_yyyy"))
178-
179-
return acc_table
180-
18175
@staticmethod
18276
def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
18377
entity_df_event_timestamp = entity_df.loc[
@@ -228,30 +122,73 @@ def get_historical_features(
228122
entity_schema=entity_schema,
229123
)
230124

125+
# TODO get range with ibis
231126
timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range(
232127
entity_df, event_timestamp_col
233128
)
129+
234130
entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col)
235131

236132
entity_table = ibis.memtable(entity_df)
237133
entity_table = IbisOfflineStore._generate_row_id(
238134
entity_table, feature_views, event_timestamp_col
239135
)
240136

241-
res: Table = entity_table
137+
def read_fv(feature_view, feature_refs, full_feature_names):
138+
fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)
242139

243-
for fv in feature_views:
244-
res = IbisOfflineStore._get_historical_features_one(
245-
fv,
246-
entity_table,
140+
for old_name, new_name in feature_view.batch_source.field_mapping.items():
141+
if old_name in fv_table.columns:
142+
fv_table = fv_table.rename({new_name: old_name})
143+
144+
timestamp_field = feature_view.batch_source.timestamp_field
145+
146+
# TODO mutate only if tz-naive
147+
fv_table = fv_table.mutate(
148+
**{
149+
timestamp_field: fv_table[timestamp_field].cast(
150+
dt.Timestamp(timezone="UTC")
151+
)
152+
}
153+
)
154+
155+
full_name_prefix = feature_view.projection.name_alias or feature_view.name
156+
157+
feature_refs = [
158+
fr.split(":")[1]
159+
for fr in feature_refs
160+
if fr.startswith(f"{full_name_prefix}:")
161+
]
162+
163+
if full_feature_names:
164+
fv_table = fv_table.rename(
165+
{
166+
f"{full_name_prefix}__{feature}": feature
167+
for feature in feature_refs
168+
}
169+
)
170+
171+
feature_refs = [
172+
f"{full_name_prefix}__{feature}" for feature in feature_refs
173+
]
174+
175+
return (
176+
fv_table,
177+
feature_view.batch_source.timestamp_field,
178+
feature_view.projection.join_key_map
179+
or {e.name: e.name for e in feature_view.entity_columns},
247180
feature_refs,
248-
full_feature_names,
249-
timestamp_range,
250-
res,
251-
event_timestamp_col,
181+
feature_view.ttl,
252182
)
253183

254-
res = res.drop("entity_row_id")
184+
res = point_in_time_join(
185+
entity_table=entity_table,
186+
feature_tables=[
187+
read_fv(feature_view, feature_refs, full_feature_names)
188+
for feature_view in feature_views
189+
],
190+
event_timestamp_col=event_timestamp_col,
191+
)
255192

256193
return IbisRetrievalJob(
257194
res,
@@ -285,6 +222,10 @@ def pull_all_from_table_or_query(
285222

286223
table = table.select(*fields)
287224

225+
# TODO get rid of this fix
226+
if "__log_date" in table.columns:
227+
table = table.drop("__log_date")
228+
288229
table = table.filter(
289230
ibis.and_(
290231
table[timestamp_field] >= ibis.literal(start_date),
@@ -320,6 +261,7 @@ def write_logged_features(
320261
else:
321262
kwargs = {}
322263

264+
# TODO always write to directory
323265
table.to_parquet(
324266
f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
325267
)
@@ -405,3 +347,77 @@ def persist(
405347
@property
406348
def metadata(self) -> Optional[RetrievalMetadata]:
407349
return self._metadata
350+
351+
352+
def point_in_time_join(
353+
entity_table: Table,
354+
feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
355+
event_timestamp_col="event_timestamp",
356+
):
357+
# TODO handle ttl
358+
all_entities = [event_timestamp_col]
359+
for feature_table, timestamp_field, join_key_map, _, _ in feature_tables:
360+
all_entities.extend(join_key_map.values())
361+
362+
r = ibis.literal("")
363+
364+
for e in set(all_entities):
365+
r = r.concat(entity_table[e].cast("string")) # type: ignore
366+
367+
entity_table = entity_table.mutate(entity_row_id=r)
368+
369+
acc_table = entity_table
370+
371+
for (
372+
feature_table,
373+
timestamp_field,
374+
join_key_map,
375+
feature_refs,
376+
ttl,
377+
) in feature_tables:
378+
predicates = [
379+
feature_table[k] == entity_table[v] for k, v in join_key_map.items()
380+
]
381+
382+
predicates.append(
383+
feature_table[timestamp_field] <= entity_table[event_timestamp_col],
384+
)
385+
386+
if ttl:
387+
predicates.append(
388+
feature_table[timestamp_field]
389+
>= entity_table[event_timestamp_col] - ibis.literal(ttl)
390+
)
391+
392+
feature_table = feature_table.inner_join(
393+
entity_table, predicates, lname="", rname="{name}_y"
394+
)
395+
396+
feature_table = feature_table.drop(s.endswith("_y"))
397+
398+
feature_table = (
399+
feature_table.group_by(by="entity_row_id")
400+
.order_by(ibis.desc(feature_table[timestamp_field]))
401+
.mutate(rn=ibis.row_number())
402+
)
403+
404+
feature_table = feature_table.filter(
405+
feature_table["rn"] == ibis.literal(0)
406+
).drop("rn")
407+
408+
select_cols = ["entity_row_id"]
409+
select_cols.extend(feature_refs)
410+
feature_table = feature_table.select(select_cols)
411+
412+
acc_table = acc_table.left_join(
413+
feature_table,
414+
predicates=[feature_table.entity_row_id == acc_table.entity_row_id],
415+
lname="",
416+
rname="{name}_yyyy",
417+
)
418+
419+
acc_table = acc_table.drop(s.endswith("_yyyy"))
420+
421+
acc_table = acc_table.drop("entity_row_id")
422+
423+
return acc_table

sdk/python/requirements/py3.10-ci-requirements.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ docker==7.0.0
164164
# testcontainers
165165
docutils==0.19
166166
# via sphinx
167+
duckdb==0.10.1
168+
# via
169+
# duckdb-engine
170+
# ibis-framework
171+
duckdb-engine==0.11.2
172+
# via ibis-framework
167173
entrypoints==0.4
168174
# via altair
169175
exceptiongroup==1.2.0
@@ -310,7 +316,7 @@ httpx==0.27.0
310316
# via
311317
# feast (setup.py)
312318
# jupyterlab
313-
ibis-framework==8.0.0
319+
ibis-framework[duckdb]==8.0.0
314320
# via
315321
# feast (setup.py)
316322
# ibis-substrait
@@ -848,8 +854,13 @@ sphinxcontrib-serializinghtml==1.1.10
848854
# via sphinx
849855
sqlalchemy[mypy]==1.4.52
850856
# via
857+
# duckdb-engine
851858
# feast (setup.py)
859+
# ibis-framework
852860
# sqlalchemy
861+
# sqlalchemy-views
862+
sqlalchemy-views==0.3.2
863+
# via ibis-framework
853864
sqlalchemy2-stubs==0.0.2a38
854865
# via sqlalchemy
855866
sqlglot==20.11.0

sdk/python/requirements/py3.9-ci-requirements.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ docker==7.0.0
164164
# testcontainers
165165
docutils==0.19
166166
# via sphinx
167+
duckdb==0.10.1
168+
# via
169+
# duckdb-engine
170+
# ibis-framework
171+
duckdb-engine==0.11.2
172+
# via ibis-framework
167173
entrypoints==0.4
168174
# via altair
169175
exceptiongroup==1.2.0
@@ -310,7 +316,7 @@ httpx==0.27.0
310316
# via
311317
# feast (setup.py)
312318
# jupyterlab
313-
ibis-framework==8.0.0
319+
ibis-framework[duckdb]==8.0.0
314320
# via
315321
# feast (setup.py)
316322
# ibis-substrait
@@ -858,8 +864,13 @@ sphinxcontrib-serializinghtml==1.1.10
858864
# via sphinx
859865
sqlalchemy[mypy]==1.4.52
860866
# via
867+
# duckdb-engine
861868
# feast (setup.py)
869+
# ibis-framework
862870
# sqlalchemy
871+
# sqlalchemy-views
872+
sqlalchemy-views==0.3.2
873+
# via ibis-framework
863874
sqlalchemy2-stubs==0.0.2a38
864875
# via sqlalchemy
865876
sqlglot==20.11.0

0 commit comments

Comments (0)