88from logging .config import dictConfig
99from typing import Any , Dict , List , NamedTuple , Optional
1010
11+ import numpy as np
12+ import pandas as pd
1113from pyspark .sql import DataFrame , SparkSession , Window
12- from pyspark .sql .functions import col , expr , monotonically_increasing_id , row_number
13- from pyspark .sql .types import LongType
14+ from pyspark .sql .functions import (
15+ col ,
16+ expr ,
17+ monotonically_increasing_id ,
18+ row_number ,
19+ struct ,
20+ )
21+ from pyspark .sql .pandas .functions import PandasUDFType , pandas_udf
22+ from pyspark .sql .types import BooleanType , LongType
1423
# Canonical column aliases used when normalizing entity and feature
# dataframes to a common schema before the point-in-time join.
EVENT_TIMESTAMP_ALIAS = "event_timestamp"
CREATED_TIMESTAMP_ALIAS = "created_timestamp"
@@ -548,19 +557,50 @@ class SchemaError(Exception):
548557 pass
549558
550559
def _make_time_filter_pandas_udf(
    spark: SparkSession,
    entity_pandas: pd.DataFrame,
    feature_table: FeatureTable,
    entity_event_timestamp_column: str,
):
    """Build a scalar pandas UDF that flags feature rows joinable to some entity row.

    The entity dataframe is broadcast to the executors once; for each batch of
    feature rows the UDF returns a boolean Series that is True when at least one
    entity row with matching join keys has an event timestamp in the window
    ``[feature_ts, feature_ts + max_age]``.

    Args:
        spark: Active session, used only to reach the SparkContext broadcast API.
        entity_pandas: Entity rows collected to the driver as a pandas frame.
        feature_table: Supplies the join key names (``entity_names``) and
            ``max_age`` — assumed to be in seconds, see the ``between`` call
            below (TODO confirm against FeatureTable definition).
        entity_event_timestamp_column: Timestamp column name in ``entity_pandas``.

    Returns:
        A pandas UDF accepting a ``struct(entity_names..., event_timestamp)``
        column and producing a BooleanType column.
    """
    # Ship the (renamed) entity frame to executors once instead of per task.
    entity_br = spark.sparkContext.broadcast(
        entity_pandas.rename(
            columns={entity_event_timestamp_column: EVENT_TIMESTAMP_ALIAS}
        )
    )
    entity_names = feature_table.entity_names
    max_age = feature_table.max_age

    # Spark 3 type-hint style pandas UDF: `PandasUDFType.SCALAR` as a decorator
    # argument is deprecated since Spark 3.0. A struct-typed input column
    # arrives as a pd.DataFrame batch, matching the pd.DataFrame type hint.
    @pandas_udf(BooleanType())
    def within_time_boundaries(features: pd.DataFrame) -> pd.Series:
        # Tag each incoming row so the 1:N merge below can be collapsed back
        # to exactly one boolean per input row; assign() keeps the batch that
        # Spark handed us unmutated.
        features = features.assign(_row_id=np.arange(len(features)))
        merged = features.merge(
            entity_br.value,
            how="left",
            on=entity_names,
            suffixes=("_feature", "_entity"),
        )
        # Entity timestamp minus feature timestamp. Rows with no entity match
        # get NaT here; total_seconds() turns that into NaN, which between()
        # evaluates to False — exactly what we want for unmatched rows.
        merged["distance"] = (
            merged[f"{EVENT_TIMESTAMP_ALIAS}_entity"]
            - merged[f"{EVENT_TIMESTAMP_ALIAS}_feature"]
        )
        merged["within"] = merged["distance"].dt.total_seconds().between(0, max_age)

        # max() over booleans per group == any(); grouping on the integer
        # _row_id sorts ascending, which restores the original batch order.
        return merged.groupby("_row_id")["within"].max()

    return within_time_boundaries
592+
593+
551594def _filter_feature_table_by_time_range (
595+ spark : SparkSession ,
552596 feature_table_df : DataFrame ,
553597 feature_table : FeatureTable ,
554598 feature_event_timestamp_column : str ,
555- entity_df : DataFrame ,
599+ entity_pandas : pd . DataFrame ,
556600 entity_event_timestamp_column : str ,
557601):
558- entity_max_timestamp = entity_df .agg (
559- {entity_event_timestamp_column : "max" }
560- ).collect ()[0 ][0 ]
561- entity_min_timestamp = entity_df .agg (
562- {entity_event_timestamp_column : "min" }
563- ).collect ()[0 ][0 ]
602+ entity_max_timestamp = entity_pandas [entity_event_timestamp_column ].max ()
603+ entity_min_timestamp = entity_pandas [entity_event_timestamp_column ].min ()
564604
565605 feature_table_timestamp_filter = (
566606 col (feature_event_timestamp_column ).between (
@@ -573,6 +613,18 @@ def _filter_feature_table_by_time_range(
573613
574614 time_range_filtered_df = feature_table_df .filter (feature_table_timestamp_filter )
575615
616+ if feature_table .max_age :
617+ within_time_boundaries_udf = _make_time_filter_pandas_udf (
618+ spark , entity_pandas , feature_table , entity_event_timestamp_column
619+ )
620+
621+ time_range_filtered_df = time_range_filtered_df .withColumn (
622+ "within_time_boundaries" ,
623+ within_time_boundaries_udf (
624+ struct (feature_table .entity_names + [feature_event_timestamp_column ])
625+ ),
626+ ).filter ("within_time_boundaries = true" )
627+
576628 return time_range_filtered_df
577629
578630
@@ -755,12 +807,15 @@ def retrieve_historical_features(
755807 f"{ expected_entity .name } ({ expected_entity .spark_type } ) is not present in the entity dataframe."
756808 )
757809
810+ entity_pandas = entity_df .toPandas ()
811+
758812 feature_table_dfs = [
759813 _filter_feature_table_by_time_range (
814+ spark ,
760815 feature_table_df ,
761816 feature_table ,
762817 feature_table_source .event_timestamp_column ,
763- entity_df ,
818+ entity_pandas ,
764819 entity_source .event_timestamp_column ,
765820 )
766821 for feature_table_df , feature_table , feature_table_source in zip (
0 commit comments