2323 Dict ,
2424 Iterable ,
2525 List ,
26+ Mapping ,
2627 NamedTuple ,
2728 Optional ,
29+ Sequence ,
2830 Set ,
2931 Tuple ,
3032 Union ,
7274 GetOnlineFeaturesResponse ,
7375)
7476from feast .protos .feast .types .EntityKey_pb2 import EntityKey as EntityKeyProto
75- from feast .protos .feast .types .Value_pb2 import Value
77+ from feast .protos .feast .types .Value_pb2 import RepeatedValue , Value
7678from feast .registry import Registry
7779from feast .repo_config import RepoConfig , load_repo_config
7880from feast .request_feature_view import RequestFeatureView
@@ -267,14 +269,18 @@ def _list_feature_views(
267269 return feature_views
268270
269271 @log_exceptions_and_usage
270- def list_on_demand_feature_views (self ) -> List [OnDemandFeatureView ]:
272+ def list_on_demand_feature_views (
273+ self , allow_cache : bool = False
274+ ) -> List [OnDemandFeatureView ]:
271275 """
272276 Retrieves the list of on demand feature views from the registry.
273277
274278 Returns:
275279 A list of on demand feature views.
276280 """
277- return self ._registry .list_on_demand_feature_views (self .project )
281+ return self ._registry .list_on_demand_feature_views (
282+ self .project , allow_cache = allow_cache
283+ )
278284
279285 @log_exceptions_and_usage
280286 def get_entity (self , name : str ) -> Entity :
@@ -1067,6 +1073,30 @@ def get_online_features(
10671073 ... )
10681074 >>> online_response_dict = online_response.to_dict()
10691075 """
1076+ columnar : Dict [str , List [Any ]] = {k : [] for k in entity_rows [0 ].keys ()}
1077+ for entity_row in entity_rows :
1078+ for key , value in entity_row .items ():
1079+ try :
1080+ columnar [key ].append (value )
1081+ except KeyError as e :
1082+ raise ValueError ("All entity_rows must have the same keys." ) from e
1083+
1084+ return self ._get_online_features (
1085+ features = features ,
1086+ entity_values = columnar ,
1087+ full_feature_names = full_feature_names ,
1088+ native_entity_values = True ,
1089+ )
1090+
1091+ def _get_online_features (
1092+ self ,
1093+ features : Union [List [str ], FeatureService ],
1094+ entity_values : Mapping [
1095+ str , Union [Sequence [Any ], Sequence [Value ], RepeatedValue ]
1096+ ],
1097+ full_feature_names : bool = False ,
1098+ native_entity_values : bool = True ,
1099+ ):
10701100 _feature_refs = self ._get_features (features , allow_cache = True )
10711101 (
10721102 requested_feature_views ,
@@ -1076,6 +1106,29 @@ def get_online_features(
10761106 features = features , allow_cache = True , hide_dummy_entity = False
10771107 )
10781108
1109+ entity_name_to_join_key_map , entity_type_map = self ._get_entity_maps (
1110+ requested_feature_views
1111+ )
1112+
1113+ # Extract Sequence from RepeatedValue Protobuf.
1114+ entity_value_lists : Dict [str , Union [List [Any ], List [Value ]]] = {
1115+ k : list (v ) if isinstance (v , Sequence ) else list (v .val )
1116+ for k , v in entity_values .items ()
1117+ }
1118+
1119+ entity_proto_values : Dict [str , List [Value ]]
1120+ if native_entity_values :
1121+ # Convert values to Protobuf once.
1122+ entity_proto_values = {
1123+ k : python_values_to_proto_values (
1124+ v , entity_type_map .get (k , ValueType .UNKNOWN )
1125+ )
1126+ for k , v in entity_value_lists .items ()
1127+ }
1128+ else :
1129+ entity_proto_values = entity_value_lists
1130+
1131+ num_rows = _validate_entity_values (entity_proto_values )
10791132 _validate_feature_refs (_feature_refs , full_feature_names )
10801133 (
10811134 grouped_refs ,
@@ -1101,111 +1154,72 @@ def get_online_features(
11011154 }
11021155
11031156 feature_views = list (view for view , _ in grouped_refs )
1104- entityless_case = DUMMY_ENTITY_NAME in [
1105- entity_name
1106- for feature_view in feature_views
1107- for entity_name in feature_view .entities
1108- ]
1109-
1110- provider = self ._get_provider ()
1111- entities = self ._list_entities (allow_cache = True , hide_dummy_entity = False )
1112- entity_name_to_join_key_map : Dict [str , str ] = {}
1113- join_key_to_entity_type_map : Dict [str , ValueType ] = {}
1114- for entity in entities :
1115- entity_name_to_join_key_map [entity .name ] = entity .join_key
1116- join_key_to_entity_type_map [entity .join_key ] = entity .value_type
1117- for feature_view in requested_feature_views :
1118- for entity_name in feature_view .entities :
1119- entity = self ._registry .get_entity (
1120- entity_name , self .project , allow_cache = True
1121- )
1122- # User directly uses join_key as the entity reference in the entity_rows for the
1123- # entity mapping case.
1124- entity_name = feature_view .projection .join_key_map .get (
1125- entity .join_key , entity .name
1126- )
1127- join_key = feature_view .projection .join_key_map .get (
1128- entity .join_key , entity .join_key
1129- )
1130- entity_name_to_join_key_map [entity_name ] = join_key
1131- join_key_to_entity_type_map [join_key ] = entity .value_type
11321157
11331158 needed_request_data , needed_request_fv_features = self .get_needed_request_data (
11341159 grouped_odfv_refs , grouped_request_fv_refs
11351160 )
11361161
1137- join_key_rows = []
1138- request_data_features : Dict [str , List [Any ]] = defaultdict ( list )
1162+ join_key_values : Dict [ str , List [ Value ]] = {}
1163+ request_data_features : Dict [str , List [Value ]] = {}
11391164 # Entity rows may be either entities or request data.
1140- for row in entity_rows :
1141- join_key_row = {}
1142- for entity_name , entity_value in row .items ():
1143- # Found request data
1144- if (
1145- entity_name in needed_request_data
1146- or entity_name in needed_request_fv_features
1147- ):
1148- if entity_name in needed_request_fv_features :
1149- # If the data was requested as a feature then
1150- # make sure it appears in the result.
1151- requested_result_row_names .add (entity_name )
1152- request_data_features [entity_name ].append (entity_value )
1153- else :
1154- try :
1155- join_key = entity_name_to_join_key_map [entity_name ]
1156- except KeyError :
1157- raise EntityNotFoundException (entity_name , self .project )
1158- # All join keys should be returned in the result.
1159- requested_result_row_names .add (join_key )
1160- join_key_row [join_key ] = entity_value
1161- if entityless_case :
1162- join_key_row [DUMMY_ENTITY_ID ] = DUMMY_ENTITY_VAL
1163- if len (join_key_row ) > 0 :
1164- # May be empty if this entity row was request data
1165- join_key_rows .append (join_key_row )
1165+ for entity_name , values in entity_proto_values .items ():
1166+ # Found request data
1167+ if (
1168+ entity_name in needed_request_data
1169+ or entity_name in needed_request_fv_features
1170+ ):
1171+ if entity_name in needed_request_fv_features :
1172+ # If the data was requested as a feature then
1173+ # make sure it appears in the result.
1174+ requested_result_row_names .add (entity_name )
1175+ request_data_features [entity_name ] = values
1176+ else :
1177+ try :
1178+ join_key = entity_name_to_join_key_map [entity_name ]
1179+ except KeyError :
1180+ raise EntityNotFoundException (entity_name , self .project )
1181+ # All join keys should be returned in the result.
1182+ requested_result_row_names .add (join_key )
1183+ join_key_values [join_key ] = values
11661184
11671185 self .ensure_request_data_values_exist (
11681186 needed_request_data , needed_request_fv_features , request_data_features
11691187 )
11701188
1171- # Convert join_key_rows from rowise to columnar.
1172- join_key_python_values : Dict [str , List [Value ]] = defaultdict (list )
1173- for join_key_row in join_key_rows :
1174- for join_key , value in join_key_row .items ():
1175- join_key_python_values [join_key ].append (value )
1176-
1177- # Convert all join key values to Protobuf Values
1178- join_key_proto_values = {
1179- k : python_values_to_proto_values (v , join_key_to_entity_type_map [k ])
1180- for k , v in join_key_python_values .items ()
1181- }
1182-
1183- # Populate online features response proto with join keys
1189+ # Populate online features response proto with join keys and request data features
11841190 online_features_response = GetOnlineFeaturesResponse (
1185- results = [
1186- GetOnlineFeaturesResponse .FeatureVector ()
1187- for _ in range (len (entity_rows ))
1188- ]
1191+ results = [GetOnlineFeaturesResponse .FeatureVector () for _ in range (num_rows )]
11891192 )
1190- for key , values in join_key_proto_values .items ():
1191- online_features_response .metadata .feature_names .val .append (key )
1192- for row_idx , result_row in enumerate (online_features_response .results ):
1193- result_row .values .append (values [row_idx ])
1194- result_row .statuses .append (FieldStatus .PRESENT )
1195- result_row .event_timestamps .append (Timestamp ())
1193+ self ._populate_result_rows_from_columnar (
1194+ online_features_response = online_features_response ,
1195+ data = dict (** join_key_values , ** request_data_features ),
1196+ )
1197+
1198+ # Add the Entityless case after populating result rows to avoid having to remove
1199+ # it later.
1200+ entityless_case = DUMMY_ENTITY_NAME in [
1201+ entity_name
1202+ for feature_view in feature_views
1203+ for entity_name in feature_view .entities
1204+ ]
1205+ if entityless_case :
1206+ join_key_values [DUMMY_ENTITY_ID ] = python_values_to_proto_values (
1207+ [DUMMY_ENTITY_VAL ] * num_rows , DUMMY_ENTITY .value_type
1208+ )
11961209
11971210 # Initialize the set of EntityKeyProtos once and reuse them for each FeatureView
11981211 # to avoid initialization overhead.
1199- entity_keys = [EntityKeyProto () for _ in range (len (join_key_rows ))]
1212+ entity_keys = [EntityKeyProto () for _ in range (num_rows )]
1213+ provider = self ._get_provider ()
12001214 for table , requested_features in grouped_refs :
12011215 # Get the correct set of entity values with the correct join keys.
1202- entity_values = self ._get_table_entity_values (
1203- table , entity_name_to_join_key_map , join_key_proto_values ,
1216+ table_entity_values = self ._get_table_entity_values (
1217+ table , entity_name_to_join_key_map , join_key_values ,
12041218 )
12051219
12061220 # Set the EntityKeyProtos inplace.
12071221 self ._set_table_entity_keys (
1208- entity_values , entity_keys ,
1222+ table_entity_values , entity_keys ,
12091223 )
12101224
12111225 # Populate the result_rows with the Features from the OnlineStore inplace.
@@ -1218,10 +1232,6 @@ def get_online_features(
12181232 table ,
12191233 )
12201234
1221- self ._populate_request_data_features (
1222- online_features_response , request_data_features
1223- )
1224-
12251235 if grouped_odfv_refs :
12261236 self ._augment_response_with_on_demand_transforms (
12271237 online_features_response ,
@@ -1235,6 +1245,50 @@ def get_online_features(
12351245 )
12361246 return OnlineResponse (online_features_response )
12371247
1248+ @staticmethod
1249+ def _get_columnar_entity_values (
1250+ rowise : Optional [List [Dict [str , Any ]]], columnar : Optional [Dict [str , List [Any ]]]
1251+ ) -> Dict [str , List [Any ]]:
1252+ if (rowise is None and columnar is None ) or (
1253+ rowise is not None and columnar is not None
1254+ ):
1255+ raise ValueError (
1256+ "Exactly one of `columnar_entity_values` and `rowise_entity_values` must be set."
1257+ )
1258+
1259+ if rowise is not None :
1260+ # Convert entity_rows from rowise to columnar.
1261+ res = defaultdict (list )
1262+ for entity_row in rowise :
1263+ for key , value in entity_row .items ():
1264+ res [key ].append (value )
1265+ return res
1266+ return cast (Dict [str , List [Any ]], columnar )
1267+
1268+ def _get_entity_maps (self , feature_views ):
1269+ entities = self ._list_entities (allow_cache = True , hide_dummy_entity = False )
1270+ entity_name_to_join_key_map : Dict [str , str ] = {}
1271+ entity_type_map : Dict [str , ValueType ] = {}
1272+ for entity in entities :
1273+ entity_name_to_join_key_map [entity .name ] = entity .join_key
1274+ entity_type_map [entity .name ] = entity .value_type
1275+ for feature_view in feature_views :
1276+ for entity_name in feature_view .entities :
1277+ entity = self ._registry .get_entity (
1278+ entity_name , self .project , allow_cache = True
1279+ )
1280+ # User directly uses join_key as the entity reference in the entity_rows for the
1281+ # entity mapping case.
1282+ entity_name = feature_view .projection .join_key_map .get (
1283+ entity .join_key , entity .name
1284+ )
1285+ join_key = feature_view .projection .join_key_map .get (
1286+ entity .join_key , entity .join_key
1287+ )
1288+ entity_name_to_join_key_map [entity_name ] = join_key
1289+ entity_type_map [join_key ] = entity .value_type
1290+ return entity_name_to_join_key_map , entity_type_map
1291+
12381292 @staticmethod
12391293 def _get_table_entity_values (
12401294 table : FeatureView ,
@@ -1275,23 +1329,21 @@ def _set_table_entity_keys(
12751329 entity_key .entity_values .extend (next (rowise_values ))
12761330
12771331 @staticmethod
1278- def _populate_request_data_features (
1332+ def _populate_result_rows_from_columnar (
12791333 online_features_response : GetOnlineFeaturesResponse ,
1280- request_data_features : Dict [str , List [Any ]],
1334+ data : Dict [str , List [Value ]],
12811335 ):
1282- # Add more feature values to the existing result rows for the request data features
1283- for feature_name , feature_values in request_data_features .items ():
1284- proto_values = python_values_to_proto_values (
1285- feature_values , ValueType .UNKNOWN
1286- )
1336+ timestamp = Timestamp () # Only initialize this timestamp once.
1337+ # Add more values to the existing result rows
1338+ for feature_name , feature_values in data .items ():
12871339
12881340 online_features_response .metadata .feature_names .val .append (feature_name )
12891341
1290- for row_idx , proto_value in enumerate (proto_values ):
1342+ for row_idx , proto_value in enumerate (feature_values ):
12911343 result_row = online_features_response .results [row_idx ]
12921344 result_row .values .append (proto_value )
12931345 result_row .statuses .append (FieldStatus .PRESENT )
1294- result_row .event_timestamps .append (Timestamp () )
1346+ result_row .event_timestamps .append (timestamp )
12951347
12961348 @staticmethod
12971349 def get_needed_request_data (
@@ -1567,6 +1619,13 @@ def serve_transformations(self, port: int) -> None:
15671619 transformation_server .start_server (self , port )
15681620
15691621
1622+ def _validate_entity_values (join_key_values : Dict [str , List [Value ]]):
1623+ set_of_row_lengths = {len (v ) for v in join_key_values .values ()}
1624+ if len (set_of_row_lengths ) > 1 :
1625+ raise ValueError ("All entity rows must have the same columns." )
1626+ return set_of_row_lengths .pop ()
1627+
1628+
15701629def _validate_feature_refs (feature_refs : List [str ], full_feature_names : bool = False ):
15711630 collided_feature_refs = []
15721631
0 commit comments