@@ -75,10 +75,7 @@ def online_write_batch(
7575
7676 for feature_name , val in values .items ():
7777 vector_val = None
78- if (
79- "pgvector_enabled" in config .online_store
80- and config .online_store .pgvector_enabled
81- ):
78+ if config .online_store .pgvector_enabled :
8279 vector_val = get_list_val_str (val )
8380 insert_values .append (
8481 (
@@ -226,10 +223,7 @@ def update(
226223
227224 for table in tables_to_keep :
228225 table_name = _table_id (project , table )
229- if (
230- "pgvector_enabled" in config .online_store
231- and config .online_store .pgvector_enabled
232- ):
226+ if config .online_store .pgvector_enabled :
233227 vector_value_type = f"vector({ config .online_store .vector_len } )"
234228 else :
235229 # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
@@ -282,7 +276,14 @@ def retrieve_online_documents(
282276 requested_feature : str ,
283277 embedding : List [float ],
284278 top_k : int ,
285- ) -> List [Tuple [Optional [datetime ], Optional [ValueProto ], Optional [ValueProto ]]]:
279+ ) -> List [
280+ Tuple [
281+ Optional [datetime ],
282+ Optional [ValueProto ],
283+ Optional [ValueProto ],
284+ Optional [ValueProto ],
285+ ]
286+ ]:
286287 """
287288
288289 Args:
@@ -297,10 +298,7 @@ def retrieve_online_documents(
297298 """
298299 project = config .project
299300
300- if (
301- "pgvector_enabled" not in config .online_store
302- or not config .online_store .pgvector_enabled
303- ):
301+ if not config .online_store .pgvector_enabled :
304302 raise ValueError (
305303 "pgvector is not enabled in the online store configuration"
306304 )
@@ -309,7 +307,12 @@ def retrieve_online_documents(
309307 query_embedding_str = f"[{ ',' .join (str (el ) for el in embedding )} ]"
310308
311309 result : List [
312- Tuple [Optional [datetime ], Optional [ValueProto ], Optional [ValueProto ]]
310+ Tuple [
311+ Optional [datetime ],
312+ Optional [ValueProto ],
313+ Optional [ValueProto ],
314+ Optional [ValueProto ],
315+ ]
313316 ] = []
314317 with self ._get_conn (config ) as conn , conn .cursor () as cur :
315318 table_name = _table_id (project , table )
@@ -322,6 +325,7 @@ def retrieve_online_documents(
322325 SELECT
323326 entity_key,
324327 feature_name,
328+ value,
325329 vector_value,
326330 vector_value <-> %s as distance,
327331 event_ts FROM {table_name}
@@ -338,16 +342,31 @@ def retrieve_online_documents(
338342 )
339343 rows = cur .fetchall ()
340344
341- for entity_key , feature_name , vector_value , distance , event_ts in rows :
345+ for (
346+ entity_key ,
347+ feature_name ,
348+ value ,
349+ vector_value ,
350+ distance ,
351+ event_ts ,
352+ ) in rows :
342353 # TODO Deserialize entity_key to return the entity in response
343354 # entity_key_proto = EntityKeyProto()
344355 # entity_key_proto_bin = bytes(entity_key)
345356
346- # TODO Convert to List[float] for value type proto
347- feature_value_proto = ValueProto ( string_val = vector_value )
357+ feature_value_proto = ValueProto ()
358+ feature_value_proto . ParseFromString ( bytes ( value ) )
348359
360+ vector_value_proto = ValueProto (string_val = vector_value )
349361 distance_value_proto = ValueProto (float_val = distance )
350- result .append ((event_ts , feature_value_proto , distance_value_proto ))
362+ result .append (
363+ (
364+ event_ts ,
365+ feature_value_proto ,
366+ vector_value_proto ,
367+ distance_value_proto ,
368+ )
369+ )
351370
352371 return result
353372
0 commit comments