99from psycopg2 import sql
1010from psycopg2 .extras import execute_values
1111from psycopg2 .pool import SimpleConnectionPool
12-
1312from feast import Entity
1413from feast .feature_view import FeatureView
1514from feast .infra .key_encoding_utils import serialize_entity_key
2120from feast .repo_config import RepoConfig
2221from feast .usage import log_exceptions_and_usage
2322
24- # Search query template to find the top k items that are closest to the given embedding
25- # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
26- SEARCH_QUERY_TEMPLATE = """
27- SELECT feature_name, value, event_ts FROM {table_name}
28- WHERE feature_name = '{feature_name}'
29- ORDER BY value <-> %s
30- LIMIT %s;
31- """
32-
3323
3424class PostgreSQLOnlineStoreConfig (PostgreSQLConfig ):
3525 type : Literal ["postgres" ] = "postgres"
3626
27+ # Whether to enable the pgvector extension for vector similarity search
28+ pgvector_enabled : Optional [bool ] = False
29+
30+ # If pgvector is enabled, the length of the vector field
31+ vector_len : Optional [int ] = 512
32+
3733
3834class PostgreSQLOnlineStore (OnlineStore ):
3935 _conn : Optional [psycopg2 ._psycopg .connection ] = None
@@ -77,11 +73,15 @@ def online_write_batch(
7773 created_ts = _to_naive_utc (created_ts )
7874
7975 for feature_name , val in values .items ():
76+ if config .online_config ["pgvector_enabled" ]:
77+ val = str (val .float_list_val .val )
78+ else :
79+ val = val .SerializeToString ()
8080 insert_values .append (
8181 (
8282 entity_key_bin ,
8383 feature_name ,
84- val . SerializeToString () ,
84+ val ,
8585 timestamp ,
8686 created_ts ,
8787 )
@@ -221,14 +221,17 @@ def update(
221221
222222 for table in tables_to_keep :
223223 table_name = _table_id (project , table )
224+ value_type = "BYTEA"
225+ if config .online_config ["pgvector_enabled" ]:
226+ value_type = f'vector({ config .online_config ["vector_len" ]} )'
224227 cur .execute (
225228 sql .SQL (
226229 """
227230 CREATE TABLE IF NOT EXISTS {}
228231 (
229232 entity_key BYTEA,
230233 feature_name TEXT,
231- value BYTEA ,
234+ value {} ,
232235 event_ts TIMESTAMPTZ,
233236 created_ts TIMESTAMPTZ,
234237 PRIMARY KEY(entity_key, feature_name)
@@ -237,6 +240,7 @@ def update(
237240 """
238241 ).format (
239242 sql .Identifier (table_name ),
243+ sql .SQL (value_type ),
240244 sql .Identifier (f"{ table_name } _ek" ),
241245 sql .Identifier (table_name ),
242246 )
@@ -267,7 +271,7 @@ def retrieve_online_documents(
267271 requested_feature : str ,
268272 embedding : List [float ],
269273 top_k : int ,
270- ) -> List [Tuple [Optional [datetime ], Optional [ValueProto ]]]:
274+ ) -> List [Tuple [Optional [datetime ], Optional [ ValueProto ], Optional [ValueProto ]]]:
271275 """
272276
273277 Args:
@@ -280,25 +284,50 @@ def retrieve_online_documents(
280284 List of tuples containing the event timestamp and the document feature
281285
282286 """
287+ project = config .project
283288
284289 # Convert the embedding to a string to be used in postgres vector search
285- query_embedding_str = f"' [{ ',' .join (str (el ) for el in embedding )} ]' "
290+ query_embedding_str = f"[{ ',' .join (str (el ) for el in embedding )} ]"
286291
287- result : List [Tuple [Optional [datetime ], Optional [ValueProto ]]] = []
292+ result : List [Tuple [Optional [datetime ], Optional [ValueProto ], Optional [ ValueProto ] ]] = []
288293 with self ._get_conn (config ) as conn , conn .cursor () as cur :
294+ table_name = _table_id (project , table )
295+
296+ # Search query template to find the top k items that are closest to the given embedding
297+ # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
289298 cur .execute (
290- SEARCH_QUERY_TEMPLATE .format (
291- table_name = table , feature_name = requested_feature
299+ sql .SQL (
300+ """
301+ SELECT
302+ entity_key,
303+ feature_name,
304+ value,
305+ value <-> %s as distance,
306+ event_ts FROM {table_name}
307+ WHERE feature_name = {feature_name}
308+ ORDER BY distance
309+ LIMIT {top_k};
310+ """
311+ ).format (
312+ table_name = sql .Identifier (table_name ),
313+ feature_name = sql .Literal (requested_feature ),
314+ top_k = sql .Literal (top_k )
292315 ),
293- (query_embedding_str , top_k ),
316+ (query_embedding_str ,),
294317 )
295318 rows = cur .fetchall ()
296319
297- for feature_name , value , event_ts in rows :
298- val = ValueProto ()
299- val .ParseFromString (value )
320+ for entity_key , feature_name , value , distance , event_ts in rows :
321+
322+ # TODO Deserialize entity_key to return the entity in response
323+ entity_key_proto = EntityKeyProto ()
324+ entity_key_proto_bin = bytes (entity_key )
325+
326+ # TODO Convert to List[float] for value type proto
327+ feature_value_proto = ValueProto (string_val = value )
300328
301- result .append ((event_ts , val ))
329+ distance_value_proto = ValueProto (float_val = distance )
330+ result .append ((event_ts , feature_value_proto , distance_value_proto ))
302331
303332 return result
304333
0 commit comments