22import logging
33from collections import defaultdict
44from datetime import datetime
5- from typing import Any , Callable , Dict , List , Literal , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Literal , Optional , Sequence , Tuple
66
77import psycopg2
88import pytz
1212
1313from feast import Entity
1414from feast .feature_view import FeatureView
15- from feast .infra .key_encoding_utils import get_val_str , serialize_entity_key
15+ from feast .infra .key_encoding_utils import get_list_val_str , serialize_entity_key
1616from feast .infra .online_stores .online_store import OnlineStore
1717from feast .infra .utils .postgres .connection_utils import _get_conn , _get_connection_pool
1818from feast .infra .utils .postgres .postgres_config import ConnectionType , PostgreSQLConfig
@@ -74,19 +74,18 @@ def online_write_batch(
7474 created_ts = _to_naive_utc (created_ts )
7575
7676 for feature_name , val in values .items ():
77- val_str : Union [ str , bytes ]
77+ vector_val = None
7878 if (
79- "pgvector_enabled" in config .online_config
80- and config .online_config [ " pgvector_enabled" ]
79+ "pgvector_enabled" in config .online_store
80+ and config .online_store . pgvector_enabled
8181 ):
82- val_str = get_val_str (val )
83- else :
84- val_str = val .SerializeToString ()
82+ vector_val = get_list_val_str (val )
8583 insert_values .append (
8684 (
8785 entity_key_bin ,
8886 feature_name ,
89- val_str ,
87+ val .SerializeToString (),
88+ vector_val ,
9089 timestamp ,
9190 created_ts ,
9291 )
@@ -100,11 +99,12 @@ def online_write_batch(
10099 sql .SQL (
101100 """
102101 INSERT INTO {}
103- (entity_key, feature_name, value, event_ts, created_ts)
102+ (entity_key, feature_name, value, vector_value, event_ts, created_ts)
104103 VALUES %s
105104 ON CONFLICT (entity_key, feature_name) DO
106105 UPDATE SET
107106 value = EXCLUDED.value,
107+ vector_value = EXCLUDED.vector_value,
108108 event_ts = EXCLUDED.event_ts,
109109 created_ts = EXCLUDED.created_ts;
110110 """ ,
@@ -226,20 +226,23 @@ def update(
226226
227227 for table in tables_to_keep :
228228 table_name = _table_id (project , table )
229- value_type = "BYTEA"
230229 if (
231- "pgvector_enabled" in config .online_config
232- and config .online_config [ " pgvector_enabled" ]
230+ "pgvector_enabled" in config .online_store
231+ and config .online_store . pgvector_enabled
233232 ):
234- value_type = f'vector({ config .online_config ["vector_len" ]} )'
233+ vector_value_type = f"vector({ config .online_store .vector_len } )"
234+ else :
235+ # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
236+ vector_value_type = "BYTEA"
235237 cur .execute (
236238 sql .SQL (
237239 """
238240 CREATE TABLE IF NOT EXISTS {}
239241 (
240242 entity_key BYTEA,
241243 feature_name TEXT,
242- value {},
244+ value BYTEA,
245+ vector_value {} NULL,
243246 event_ts TIMESTAMPTZ,
244247 created_ts TIMESTAMPTZ,
245248 PRIMARY KEY(entity_key, feature_name)
@@ -248,7 +251,7 @@ def update(
248251 """
249252 ).format (
250253 sql .Identifier (table_name ),
251- sql .SQL (value_type ),
254+ sql .SQL (vector_value_type ),
252255 sql .Identifier (f"{ table_name } _ek" ),
253256 sql .Identifier (table_name ),
254257 )
@@ -294,6 +297,14 @@ def retrieve_online_documents(
294297 """
295298 project = config .project
296299
300+ if (
301+ "pgvector_enabled" not in config .online_store
302+ or not config .online_store .pgvector_enabled
303+ ):
304+ raise ValueError (
305+ "pgvector is not enabled in the online store configuration"
306+ )
307+
297308 # Convert the embedding to a string to be used in postgres vector search
298309 query_embedding_str = f"[{ ',' .join (str (el ) for el in embedding )} ]"
299310
@@ -311,8 +322,8 @@ def retrieve_online_documents(
311322 SELECT
312323 entity_key,
313324 feature_name,
314- value ,
315- value <-> %s as distance,
325+ vector_value ,
326+ vector_value <-> %s as distance,
316327 event_ts FROM {table_name}
317328 WHERE feature_name = {feature_name}
318329 ORDER BY distance
@@ -327,13 +338,13 @@ def retrieve_online_documents(
327338 )
328339 rows = cur .fetchall ()
329340
330- for entity_key , feature_name , value , distance , event_ts in rows :
341+ for entity_key , feature_name , vector_value , distance , event_ts in rows :
331342 # TODO Deserialize entity_key to return the entity in response
332343 # entity_key_proto = EntityKeyProto()
333344 # entity_key_proto_bin = bytes(entity_key)
334345
335346 # TODO Convert to List[float] for value type proto
336- feature_value_proto = ValueProto (string_val = value )
347+ feature_value_proto = ValueProto (string_val = vector_value )
337348
338349 distance_value_proto = ValueProto (float_val = distance )
339350 result .append ((event_ts , feature_value_proto , distance_value_proto ))
0 commit comments