Skip to content

Commit 76451a3

Browse files
pyalex and khorshuheng
authored
Optimize historical retrieval by filtering rows only within timestamp boundaries from entity dataframe (#87)
* prefilter feature rows in historical retrieval

* revert requirements split

* revert spark update

Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
Co-authored-by: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com>
1 parent bbea9f1 commit 76451a3

File tree

4 files changed

+70
-10
lines changed

4 files changed

+70
-10
lines changed

.github/workflows/pr.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ jobs:
4949
runs-on: [ubuntu-latest]
5050
needs: lint-python
5151
container: gcr.io/kf-feast/feast-ci:latest
52+
env:
53+
PYSPARK_PYTHON: python3.7
5254
steps:
5355
- uses: actions/checkout@v2
5456
- name: Install python

infra/docker/spark/Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ ADD https://repo1.maven.org/maven2/com/google/cloud/spark/spark-bigquery-with-de
3131
# Fix arrow issue for jdk-11
3232
RUN mkdir -p /opt/spark/conf
3333
RUN echo 'spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"' >> $SPARK_HOME/conf/spark-defaults.conf
34+
RUN echo 'spark.driver.extraJavaOptions="-Dcom.google.cloud.spark.bigquery.repackaged.io.netty.tryReflectionSetAccessible=true"' >> $SPARK_HOME/conf/spark-defaults.conf
3435
RUN echo 'spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"' >> $SPARK_HOME/conf/spark-defaults.conf
36+
RUN echo 'spark.executor.extraJavaOptions="-Dcom.google.cloud.spark.bigquery.repackaged.io.netty.tryReflectionSetAccessible=true"' >> $SPARK_HOME/conf/spark-defaults.conf
3537

3638
# For logging to /dev/termination-log
3739
RUN mkdir -p /dev

python/feast_spark/pyspark/historical_feature_retrieval_job.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,18 @@
88
from logging.config import dictConfig
99
from typing import Any, Dict, List, NamedTuple, Optional
1010

11+
import numpy as np
12+
import pandas as pd
1113
from pyspark.sql import DataFrame, SparkSession, Window
12-
from pyspark.sql.functions import col, expr, monotonically_increasing_id, row_number
13-
from pyspark.sql.types import LongType
14+
from pyspark.sql.functions import (
15+
col,
16+
expr,
17+
monotonically_increasing_id,
18+
row_number,
19+
struct,
20+
)
21+
from pyspark.sql.pandas.functions import PandasUDFType, pandas_udf
22+
from pyspark.sql.types import BooleanType, LongType
1423

1524
EVENT_TIMESTAMP_ALIAS = "event_timestamp"
1625
CREATED_TIMESTAMP_ALIAS = "created_timestamp"
@@ -548,19 +557,50 @@ class SchemaError(Exception):
548557
pass
549558

550559

560+
def _make_time_filter_pandas_udf(
561+
spark: SparkSession,
562+
entity_pandas: pd.DataFrame,
563+
feature_table: FeatureTable,
564+
entity_event_timestamp_column: str,
565+
):
566+
entity_br = spark.sparkContext.broadcast(
567+
entity_pandas.rename(
568+
columns={entity_event_timestamp_column: EVENT_TIMESTAMP_ALIAS}
569+
)
570+
)
571+
entity_names = feature_table.entity_names
572+
max_age = feature_table.max_age
573+
574+
@pandas_udf(BooleanType(), PandasUDFType.SCALAR)
575+
def within_time_boundaries(features: pd.DataFrame) -> pd.Series:
576+
features["_row_id"] = np.arange(len(features))
577+
merged = features.merge(
578+
entity_br.value,
579+
how="left",
580+
on=entity_names,
581+
suffixes=("_feature", "_entity"),
582+
)
583+
merged["distance"] = (
584+
merged[f"{EVENT_TIMESTAMP_ALIAS}_entity"]
585+
- merged[f"{EVENT_TIMESTAMP_ALIAS}_feature"]
586+
)
587+
merged["within"] = merged["distance"].dt.total_seconds().between(0, max_age)
588+
589+
return merged.groupby(["_row_id"]).max()["within"]
590+
591+
return within_time_boundaries
592+
593+
551594
def _filter_feature_table_by_time_range(
595+
spark: SparkSession,
552596
feature_table_df: DataFrame,
553597
feature_table: FeatureTable,
554598
feature_event_timestamp_column: str,
555-
entity_df: DataFrame,
599+
entity_pandas: pd.DataFrame,
556600
entity_event_timestamp_column: str,
557601
):
558-
entity_max_timestamp = entity_df.agg(
559-
{entity_event_timestamp_column: "max"}
560-
).collect()[0][0]
561-
entity_min_timestamp = entity_df.agg(
562-
{entity_event_timestamp_column: "min"}
563-
).collect()[0][0]
602+
entity_max_timestamp = entity_pandas[entity_event_timestamp_column].max()
603+
entity_min_timestamp = entity_pandas[entity_event_timestamp_column].min()
564604

565605
feature_table_timestamp_filter = (
566606
col(feature_event_timestamp_column).between(
@@ -573,6 +613,18 @@ def _filter_feature_table_by_time_range(
573613

574614
time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter)
575615

616+
if feature_table.max_age:
617+
within_time_boundaries_udf = _make_time_filter_pandas_udf(
618+
spark, entity_pandas, feature_table, entity_event_timestamp_column
619+
)
620+
621+
time_range_filtered_df = time_range_filtered_df.withColumn(
622+
"within_time_boundaries",
623+
within_time_boundaries_udf(
624+
struct(feature_table.entity_names + [feature_event_timestamp_column])
625+
),
626+
).filter("within_time_boundaries = true")
627+
576628
return time_range_filtered_df
577629

578630

@@ -755,12 +807,15 @@ def retrieve_historical_features(
755807
f"{expected_entity.name} ({expected_entity.spark_type}) is not present in the entity dataframe."
756808
)
757809

810+
entity_pandas = entity_df.toPandas()
811+
758812
feature_table_dfs = [
759813
_filter_feature_table_by_time_range(
814+
spark,
760815
feature_table_df,
761816
feature_table,
762817
feature_table_source.event_timestamp_column,
763-
entity_df,
818+
entity_pandas,
764819
entity_source.event_timestamp_column,
765820
)
766821
for feature_table_df, feature_table, feature_table_source in zip(

python/tests/test_historical_feature_retrieval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def test_historical_feature_retrieval(spark: SparkSession):
601601
"name": "bookings",
602602
"entities": [{"name": "driver_id", "type": "int32"}],
603603
"features": [{"name": "completed_bookings", "type": "int32"}],
604+
"max_age": 365 * 86400,
604605
}
605606
transaction_table = {
606607
"name": "transactions",

0 commit comments

Comments (0)