44from datetime import datetime
55from typing import Any , Callable , Dict , List , Literal , Optional , Sequence , Tuple
66
7+ import numpy as np
78import psycopg2
89import pytz
910from psycopg2 import sql
1213
1314from feast import Entity
1415from feast .feature_view import FeatureView
16+ from feast .feature import Feature
1517from feast .infra .key_encoding_utils import serialize_entity_key
1618from feast .infra .online_stores .online_store import OnlineStore
19+ from feast .infra .online_stores .document_store import DocumentStore , DocumentStoreIndexConfig
1720from feast .infra .utils .postgres .connection_utils import _get_conn , _get_connection_pool
1821from feast .infra .utils .postgres .postgres_config import ConnectionType , PostgreSQLConfig
1922from feast .protos .feast .types .EntityKey_pb2 import EntityKey as EntityKeyProto
@@ -46,13 +49,13 @@ def _get_conn(self, config: RepoConfig):
4649
4750 @log_exceptions_and_usage (online_store = "postgres" )
4851 def online_write_batch (
49- self ,
50- config : RepoConfig ,
51- table : FeatureView ,
52- data : List [
53- Tuple [EntityKeyProto , Dict [str , ValueProto ], datetime , Optional [datetime ]]
54- ],
55- progress : Optional [Callable [[int ], Any ]],
52+ self ,
53+ config : RepoConfig ,
54+ table : FeatureView ,
55+ data : List [
56+ Tuple [EntityKeyProto , Dict [str , ValueProto ], datetime , Optional [datetime ]]
57+ ],
58+ progress : Optional [Callable [[int ], Any ]],
5659 ) -> None :
5760 project = config .project
5861
@@ -80,7 +83,7 @@ def online_write_batch(
8083 # Control the batch so that we can update the progress
8184 batch_size = 5000
8285 for i in range (0 , len (insert_values ), batch_size ):
83- cur_batch = insert_values [i : i + batch_size ]
86+ cur_batch = insert_values [i : i + batch_size ]
8487 execute_values (
8588 cur ,
8689 sql .SQL (
@@ -104,11 +107,11 @@ def online_write_batch(
104107
105108 @log_exceptions_and_usage (online_store = "postgres" )
106109 def online_read (
107- self ,
108- config : RepoConfig ,
109- table : FeatureView ,
110- entity_keys : List [EntityKeyProto ],
111- requested_features : Optional [List [str ]] = None ,
110+ self ,
111+ config : RepoConfig ,
112+ table : FeatureView ,
113+ entity_keys : List [EntityKeyProto ],
114+ requested_features : Optional [List [str ]] = None ,
112115 ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
113116 result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
114117
@@ -175,13 +178,13 @@ def online_read(
175178
176179 @log_exceptions_and_usage (online_store = "postgres" )
177180 def update (
178- self ,
179- config : RepoConfig ,
180- tables_to_delete : Sequence [FeatureView ],
181- tables_to_keep : Sequence [FeatureView ],
182- entities_to_delete : Sequence [Entity ],
183- entities_to_keep : Sequence [Entity ],
184- partial : bool ,
181+ self ,
182+ config : RepoConfig ,
183+ tables_to_delete : Sequence [FeatureView ],
184+ tables_to_keep : Sequence [FeatureView ],
185+ entities_to_delete : Sequence [Entity ],
186+ entities_to_keep : Sequence [Entity ],
187+ partial : bool ,
185188 ):
186189 project = config .project
187190 schema_name = config .online_store .db_schema or config .online_store .user
@@ -236,10 +239,10 @@ def update(
236239 conn .commit ()
237240
238241 def teardown (
239- self ,
240- config : RepoConfig ,
241- tables : Sequence [FeatureView ],
242- entities : Sequence [Entity ],
242+ self ,
243+ config : RepoConfig ,
244+ tables : Sequence [FeatureView ],
245+ entities : Sequence [Entity ],
243246 ):
244247 project = config .project
245248 try :
@@ -273,3 +276,75 @@ def _to_naive_utc(ts: datetime):
273276 return ts
274277 else :
275278 return ts .astimezone (pytz .utc ).replace (tzinfo = None )
279+
280+
# Search query template: return the rows of one feature whose stored value is
# closest to a query embedding, nearest first, e.g.
#   SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
# ``{table_name}`` is filled in through ``psycopg2.sql`` composition so it is
# quoted as an identifier; the feature name, the embedding and the row limit
# are bound query parameters (in that order) and are never spliced into the text.
SEARCH_QUERY_TEMPLATE = """
SELECT entity_key, feature_name, value, event_ts FROM {table_name}
WHERE feature_name = %s
ORDER BY value <-> %s
LIMIT %s;
"""

# Create-index query template: builds an index of the requested type.
# All three placeholders are substituted as plain text, so they must only
# ever be filled from trusted configuration values.
CREATE_INDEX_QUERY_TEMPLATE = """
CREATE INDEX ON {table_name} USING {index_type} (embedding {embeding_type});
"""


class PostgresDocumentStoreConfig(DocumentStoreIndexConfig):
    """Index configuration whose ``type`` discriminator is ``"postgres"``."""

    type: Literal["postgres"] = "postgres"


class PostgresDocumentStore(PostgreSQLOnlineStore, DocumentStore):
    """Document store that runs similarity search on the PostgreSQL online store."""

    def online_search(
        self,
        config: RepoConfig,
        table: FeatureView,
        requested_feature: str,
        embedding: np.ndarray,
        top_k: int,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Return the ``top_k`` stored values of a feature closest to ``embedding``.

        Args:
            config: Repo configuration; supplies the project and the connection.
            table: Feature view whose online table is searched.
            requested_feature: Name of the feature to search.
            embedding: Query vector; flattened and sent as a ``[x,y,...]`` literal.
            top_k: Maximum number of rows to return.

        Returns:
            One ``(event_ts, {feature_name: value})`` tuple per matching row, in
            the order produced by the query (closest first).
        """
        # NOTE(review): assumes the online table of a feature view is named
        # "<project>_<feature view name>" -- confirm against the table-creation
        # code in update(). The original formatted the FeatureView object itself
        # into the statement, which can never be a valid table name.
        table_name = f"{config.project}_{table.name}"
        query = sql.SQL(SEARCH_QUERY_TEMPLATE).format(
            table_name=sql.Identifier(table_name)
        )

        # psycopg2 has no default adapter for numpy arrays, so send the vector
        # in its text form ('[1.0,2.0,3.0]') and let the server parse it.
        components = np.asarray(embedding, dtype=float).ravel().tolist()
        vector_literal = "[" + ",".join(repr(c) for c in components) + "]"

        with self._get_conn(config) as conn, conn.cursor() as cur:
            cur.execute(query, (requested_feature, vector_literal, top_k))
            rows = cur.fetchall()

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
        # Columns: entity key, feature name, serialized value, event timestamp.
        # The entity key is selected but is not part of the returned tuples.
        for _entity_key, feature_name, raw_value, event_ts in rows:
            val = ValueProto()
            # bytea columns can come back as memoryview; protobuf wants bytes.
            val.ParseFromString(bytes(raw_value))
            result.append((event_ts, {feature_name: val}))

        return result

    def create_index(
        self,
        config: RepoConfig,
        index: str,
        index_config: DocumentStoreIndexConfig,
    ) -> None:
        """Create an index of type ``index`` using ``index_config.embedding_type``.

        Args:
            config: Repo configuration; supplies the project and the connection.
            index: Index type placed after ``USING`` in the statement.
            index_config: Supplies ``embedding_type``, placed after the column name.
        """
        # NOTE(review): the statement targets a table named after the project
        # and a column called ``embedding``, whereas online_search() orders by
        # the ``value`` column of a per-feature-view table -- confirm which
        # table/column is intended.
        statement = CREATE_INDEX_QUERY_TEMPLATE.format(
            table_name=config.project,
            index_type=index,
            embeding_type=index_config.embedding_type,
        )
        with self._get_conn(config) as conn:
            with conn.cursor() as cur:
                cur.execute(statement)
            # Commit explicitly, as update() does, so the DDL is persisted.
            conn.commit()
0 commit comments