1616from collections import Counter , OrderedDict , defaultdict
1717from datetime import datetime , timedelta
1818from pathlib import Path
19- from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
19+ from typing import Any , Dict , Iterable , List , Optional , Set , Tuple , Union , cast
2020
2121import pandas as pd
2222from colorama import Fore , Style
2323from tqdm import tqdm
2424
2525from feast import feature_server , utils
26+ from feast .data_source import RequestDataSource
2627from feast .entity import Entity
2728from feast .errors import (
2829 EntityNotFoundException ,
2930 FeatureNameCollisionError ,
3031 FeatureViewNotFoundException ,
32+ RequestDataNotFoundInEntityDfException ,
33+ RequestDataNotFoundInEntityRowsException ,
3134)
3235from feast .feature_service import FeatureService
3336from feast .feature_table import FeatureTable
@@ -402,7 +405,7 @@ def apply(
402405 view .infer_features_from_batch_source (self .config )
403406
404407 for odfv in odfvs_to_update :
405- odfv .infer_features_from_batch_source ( self . config )
408+ odfv .infer_features ( )
406409
407410 if len (views_to_update ) + len (entities_to_update ) + len (
408411 services_to_update
@@ -545,10 +548,26 @@ def get_historical_features(
545548 # TODO(achal): _group_feature_refs returns the on demand feature views, but it's no passed into the provider.
546549 # This is a weird interface quirk - we should revisit the `get_historical_features` to
547550 # pass in the on demand feature views as well.
548- fvs , _ = _group_feature_refs (
551+ fvs , odfvs = _group_feature_refs (
549552 _feature_refs , all_feature_views , all_on_demand_feature_views
550553 )
551554 feature_views = list (view for view , _ in fvs )
555+ on_demand_feature_views = list (view for view , _ in odfvs )
556+
557+ # Check that the right request data is present in the entity_df
558+ if type (entity_df ) == pd .DataFrame :
559+ entity_pd_df = cast (pd .DataFrame , entity_df )
560+ for odfv in on_demand_feature_views :
561+ odfv_inputs = odfv .inputs .values ()
562+ for odfv_input in odfv_inputs :
563+ if type (odfv_input ) == RequestDataSource :
564+ request_data_source = cast (RequestDataSource , odfv_input )
565+ for feature_name in request_data_source .schema .keys ():
566+ if feature_name not in entity_pd_df .columns :
567+ raise RequestDataNotFoundInEntityDfException (
568+ feature_name = feature_name ,
569+ feature_view_name = odfv .name ,
570+ )
552571
553572 _validate_feature_refs (_feature_refs , full_feature_names )
554573
@@ -789,7 +808,7 @@ def get_online_features(
789808 )
790809
791810 _validate_feature_refs (_feature_refs , full_feature_names )
792- grouped_refs , _ = _group_feature_refs (
811+ grouped_refs , grouped_odfv_refs = _group_feature_refs (
793812 _feature_refs , all_feature_views , all_on_demand_feature_views
794813 )
795814 feature_views = list (view for view , _ in grouped_refs )
@@ -805,28 +824,61 @@ def get_online_features(
805824 for entity in entities :
806825 entity_name_to_join_key_map [entity .name ] = entity .join_key
807826
827+ needed_request_data_features = self ._get_needed_request_data_features (
828+ grouped_odfv_refs
829+ )
830+
808831 join_key_rows = []
832+ request_data_features : Dict [str , List [Any ]] = {}
833+ # Entity rows may be either entities or request data.
809834 for row in entity_rows :
810835 join_key_row = {}
811836 for entity_name , entity_value in row .items ():
837+ # Found request data
838+ if entity_name in needed_request_data_features :
839+ if entity_name not in request_data_features :
840+ request_data_features [entity_name ] = []
841+ request_data_features [entity_name ].append (entity_value )
842+ continue
812843 try :
813844 join_key = entity_name_to_join_key_map [entity_name ]
814845 except KeyError :
815846 raise EntityNotFoundException (entity_name , self .project )
816847 join_key_row [join_key ] = entity_value
817848 if entityless_case :
818849 join_key_row [DUMMY_ENTITY_ID ] = DUMMY_ENTITY_VAL
819- join_key_rows .append (join_key_row )
850+ if len (join_key_row ) > 0 :
851+ # May be empty if this entity row was request data
852+ join_key_rows .append (join_key_row )
853+
854+ if len (needed_request_data_features ) != len (request_data_features .keys ()):
855+ raise RequestDataNotFoundInEntityRowsException (
856+ feature_names = needed_request_data_features
857+ )
820858
821859 entity_row_proto_list = _infer_online_entity_rows (join_key_rows )
822860
823- union_of_entity_keys = []
861+ union_of_entity_keys : List [ EntityKeyProto ] = []
824862 result_rows : List [GetOnlineFeaturesResponse .FieldValues ] = []
825863
826864 for entity_row_proto in entity_row_proto_list :
865+ # Create a list of entity keys to filter down for each feature view at lookup time.
827866 union_of_entity_keys .append (_entity_row_to_key (entity_row_proto ))
867+ # Also create entity values to append to the result
828868 result_rows .append (_entity_row_to_field_values (entity_row_proto ))
829869
870+ # Add more feature values to the existing result rows for the request data features
871+ for feature_name , feature_values in request_data_features .items ():
872+ for row_idx , feature_value in enumerate (feature_values ):
873+ result_row = result_rows [row_idx ]
874+ result_row .fields [feature_name ].CopyFrom (
875+ python_value_to_proto_value (feature_value )
876+ )
877+ result_row .statuses [
878+ feature_name
879+ ] = GetOnlineFeaturesResponse .FieldStatus .PRESENT
880+
881+ # Note: each "table" is a feature view
830882 for table , requested_features in grouped_refs :
831883 entity_keys = _get_table_entity_keys (
832884 table , union_of_entity_keys , entity_name_to_join_key_map
@@ -837,6 +889,7 @@ def get_online_features(
837889 entity_keys = entity_keys ,
838890 requested_features = requested_features ,
839891 )
892+ # Each row is a set of features for a given entity key
840893 for row_idx , read_row in enumerate (read_rows ):
841894 row_ts , feature_data = read_row
842895 result_row = result_rows [row_idx ]
@@ -873,6 +926,18 @@ def get_online_features(
873926 _feature_refs , full_feature_names , initial_response , result_rows
874927 )
875928
929+ def _get_needed_request_data_features (self , grouped_odfv_refs ) -> Set [str ]:
930+ needed_request_data_features = set ()
931+ for odfv_to_feature_names in grouped_odfv_refs :
932+ odfv , requested_feature_names = odfv_to_feature_names
933+ odfv_inputs = odfv .inputs .values ()
934+ for odfv_input in odfv_inputs :
935+ if type (odfv_input ) == RequestDataSource :
936+ request_data_source = cast (RequestDataSource , odfv_input )
937+ for feature_name in request_data_source .schema .keys ():
938+ needed_request_data_features .add (feature_name )
939+ return needed_request_data_features
940+
876941 def _augment_response_with_on_demand_transforms (
877942 self ,
878943 feature_refs : List [str ],
0 commit comments