From b8072ed09a4ef339ce7b5141245a07b476fcdd7c Mon Sep 17 00:00:00 2001 From: TomSteenbergen Date: Tue, 2 Jul 2024 11:28:07 +0200 Subject: [PATCH 1/4] Add async retrieval for postgres Signed-off-by: TomSteenbergen --- .../infra/online_stores/contrib/postgres.py | 260 +++++++++++------- .../infra/utils/postgres/connection_utils.py | 27 +- .../online_store/test_universal_online.py | 2 +- 3 files changed, 188 insertions(+), 101 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 8715f0f65bb..3f18d26ca6e 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import ( Any, + AsyncGenerator, Callable, Dict, Generator, @@ -12,18 +13,24 @@ Optional, Sequence, Tuple, + Union, ) import pytz -from psycopg import sql +from psycopg import AsyncConnection, sql from psycopg.connection import Connection -from psycopg_pool import ConnectionPool +from psycopg_pool import AsyncConnectionPool, ConnectionPool from feast import Entity from feast.feature_view import FeatureView from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool +from feast.infra.utils.postgres.connection_utils import ( + _get_conn, + _get_conn_async, + _get_connection_pool, + _get_connection_pool_async, +) from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto @@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore): _conn: Optional[Connection] = None _conn_pool: Optional[ConnectionPool] = None + _conn_async: Optional[AsyncConnection] = None + _conn_pool_async: Optional[AsyncConnectionPool] = None + @contextlib.contextmanager def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: assert config.online_store.type == "postgres" @@ -67,15 +77,33 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: self._conn = _get_conn(config.online_store) yield self._conn + @contextlib.asynccontextmanager + async def _get_conn_async( + self, config: RepoConfig + ) -> AsyncGenerator[AsyncConnection, Any]: + if config.online_store.conn_type == ConnectionType.pool: + if not self._conn_pool_async: + self._conn_pool_async = await _get_connection_pool_async( + config.online_store + ) + await self._conn_pool_async.open() + connection = await self._conn_pool_async.getconn() + yield connection + await self._conn_pool_async.putconn(connection) + else: + if not self._conn_async: + self._conn_async = await _get_conn_async(config.online_store) + yield self._conn_async + def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - batch_size: int = 5000, + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + batch_size: int = 5000, ) -> None: # Format insert values insert_values = [] @@ -129,85 +157,123 @@ def online_write_batch( progress(len(cur_batch)) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + keys = self._prepare_keys(config, entity_keys) + query, params = self._construct_query_and_params( + config, table, keys, requested_features + ) - project = config.project with self._get_conn(config) as conn, conn.cursor() as cur: - # Collecting all the keys to a list allows us to make fewer round trips - # to PostgreSQL - keys = [] - for entity_key in entity_keys: - keys.append( - serialize_entity_key( - entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, - ) - ) + cur.execute(query, params) + rows = cur.fetchall() - if not requested_features: - cur.execute( - sql.SQL( - """ - SELECT entity_key, feature_name, value, event_ts - FROM {} WHERE entity_key = ANY(%s); - """ - ).format( - sql.Identifier(_table_id(project, table)), - ), - (keys,), - ) - else: - cur.execute( - sql.SQL( - """ - SELECT entity_key, feature_name, value, event_ts - FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s); - """ - ).format( - sql.Identifier(_table_id(project, table)), - ), - (keys, requested_features), - ) + return self._process_rows(keys, rows) - rows = cur.fetchall() + async def online_read_async( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + keys = self._prepare_keys(config, entity_keys) + query, params = self._construct_query_and_params( + config, table, keys, requested_features + ) + + async with self._get_conn_async(config) as conn: + async with conn.cursor() as cur: + await cur.execute(query, params) + rows = await cur.fetchall() + + return self._process_rows(keys, rows) + + @staticmethod + def _construct_query_and_params( + config: RepoConfig, + table: FeatureView, + keys: List[bytes], + requested_features: Optional[List[str]] = None, + ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]: + """Construct the SQL query based on the given parameters.""" + if requested_features: + query = sql.SQL( + """ + SELECT entity_key, feature_name, value, event_ts + FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s); + """ + ).format( + sql.Identifier(_table_id(config.project, table)), + ) + params = (keys, requested_features) + else: + query = sql.SQL( + """ + SELECT entity_key, feature_name, value, event_ts + FROM {} WHERE entity_key = ANY(%s); + """ + ).format( + sql.Identifier(_table_id(config.project, table)), + ) + params = (keys, []) + return query, params + + @staticmethod + def _prepare_keys( + config: RepoConfig, entity_keys: List[EntityKeyProto] + ) -> List[bytes]: + """Prepare all keys in a list to make fewer round trips to the database.""" + return [ + serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + for entity_key in entity_keys + ] - # Since we don't know the order returned from PostgreSQL we'll need - # to construct a dict to be able to quickly look up the correct row - # when we iterate through the keys since they are in the correct order - values_dict = defaultdict(list) - for row in rows if rows is not None else []: - values_dict[ - row[0] if isinstance(row[0], bytes) else row[0].tobytes() - ].append(row[1:]) - - for key in keys: - if key in values_dict: - value = values_dict[key] - res = {} - for feature_name, value_bin, event_ts in value: - val = ValueProto() - val.ParseFromString(bytes(value_bin)) - res[feature_name] = val - result.append((event_ts, res)) - else: - result.append((None, None)) + @staticmethod + def _process_rows( + keys: List[bytes], rows: List[Tuple] + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """Transform the retrieved rows in the desired output. + + PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict` + is created to quickly look up the correct row using the keys, since these are + actually in the correct order. + """ + values_dict = defaultdict(list) + for row in rows if rows is not None else []: + values_dict[ + row[0] if isinstance(row[0], bytes) else row[0].tobytes() + ].append(row[1:]) + result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] + for key in keys: + if key in values_dict: + value = values_dict[key] + res = {} + for feature_name, value_bin, event_ts in value: + val = ValueProto() + val.ParseFromString(bytes(value_bin)) + res[feature_name] = val + result.append((event_ts, res)) + else: + result.append((None, None)) return result def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): project = config.project schema_name = config.online_store.db_schema or config.online_store.user @@ -269,10 +335,10 @@ def update( conn.commit() def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): project = config.project try: @@ -285,13 +351,13 @@ def teardown( raise def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -367,12 +433,12 @@ def retrieve_online_documents( rows = cur.fetchall() for ( - entity_key, - feature_name, - value, - vector_value, - distance, - event_ts, + entity_key, + feature_name, + value, + vector_value, + distance, + event_ts, ) in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() diff --git a/sdk/python/feast/infra/utils/postgres/connection_utils.py b/sdk/python/feast/infra/utils/postgres/connection_utils.py index e0599019b96..8fdd926100a 100644 --- a/sdk/python/feast/infra/utils/postgres/connection_utils.py +++ b/sdk/python/feast/infra/utils/postgres/connection_utils.py @@ -4,8 +4,8 @@ import pandas as pd import psycopg import pyarrow as pa -from psycopg.connection import Connection -from psycopg_pool import ConnectionPool +from psycopg import AsyncConnection, Connection +from psycopg_pool import AsyncConnectionPool, ConnectionPool from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig from feast.type_map import arrow_to_pg_type @@ -21,6 +21,16 @@ def _get_conn(config: PostgreSQLConfig) -> Connection: return conn +async def _get_conn_async(config: PostgreSQLConfig) -> AsyncConnection: + """Get a psycopg `AsyncConnection`.""" + conn = await psycopg.AsyncConnection.connect( + conninfo=_get_conninfo(config), + keepalives_idle=config.keepalives_idle, + **_get_conn_kwargs(config), + ) + return conn + + def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool: """Get a psycopg `ConnectionPool`.""" return ConnectionPool( @@ -32,6 +42,17 @@ def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool: ) +async def _get_connection_pool_async(config: PostgreSQLConfig) -> AsyncConnectionPool: + """Get a psycopg `AsyncConnectionPool`.""" + return AsyncConnectionPool( + conninfo=_get_conninfo(config), + min_size=config.min_conn, + max_size=config.max_conn, + open=False, + kwargs=_get_conn_kwargs(config), + ) + + def _get_conninfo(config: PostgreSQLConfig) -> str: """Get the `conninfo` argument required for connection objects.""" return ( @@ -67,7 +88,7 @@ def _df_to_create_table_sql(entity_df, table_name) -> str: def df_to_postgres_table( - config: PostgreSQLConfig, df: pd.DataFrame, table_name: str + config: PostgreSQLConfig, df: pd.DataFrame, table_name: str ) -> Dict[str, np.dtype]: """ Create a table for the data frame, insert all the values, and return the table schema diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 38656b90a9c..2ffe869ef50 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -488,7 +488,7 @@ def test_online_retrieval_with_event_timestamps(environment, universal_data_sour @pytest.mark.integration -@pytest.mark.universal_online_stores(only=["redis", "dynamodb"]) +@pytest.mark.universal_online_stores(only=["redis", "dynamodb", "postgres"]) def test_async_online_retrieval_with_event_timestamps( environment, universal_data_sources ): From 1d5719a1edf05bbcfa96ab6d8170e566c61cacf1 Mon Sep 17 00:00:00 2001 From: TomSteenbergen Date: Tue, 2 Jul 2024 11:29:18 +0200 Subject: [PATCH 2/4] Format Signed-off-by: TomSteenbergen --- .../infra/online_stores/contrib/postgres.py | 98 +++++++++---------- .../infra/utils/postgres/connection_utils.py | 2 +- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 3f18d26ca6e..a9c4f3ee05f 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -79,7 +79,7 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: @contextlib.asynccontextmanager async def _get_conn_async( - self, config: RepoConfig + self, config: RepoConfig ) -> AsyncGenerator[AsyncConnection, Any]: if config.online_store.conn_type == ConnectionType.pool: if not self._conn_pool_async: @@ -96,14 +96,14 @@ async def _get_conn_async( yield self._conn_async def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - batch_size: int = 5000, + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + batch_size: int = 5000, ) -> None: # Format insert values insert_values = [] @@ -157,11 +157,11 @@ def online_write_batch( progress(len(cur_batch)) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: keys = self._prepare_keys(config, entity_keys) query, params = self._construct_query_and_params( @@ -175,11 +175,11 @@ def online_read( return self._process_rows(keys, rows) async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: keys = self._prepare_keys(config, entity_keys) query, params = self._construct_query_and_params( @@ -195,10 +195,10 @@ async def online_read_async( @staticmethod def _construct_query_and_params( - config: RepoConfig, - table: FeatureView, - keys: List[bytes], - requested_features: Optional[List[str]] = None, + config: RepoConfig, + table: FeatureView, + keys: List[bytes], + requested_features: Optional[List[str]] = None, ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]: """Construct the SQL query based on the given parameters.""" if requested_features: @@ -225,7 +225,7 @@ def _construct_query_and_params( @staticmethod def _prepare_keys( - config: RepoConfig, entity_keys: List[EntityKeyProto] + config: RepoConfig, entity_keys: List[EntityKeyProto] ) -> List[bytes]: """Prepare all keys in a list to make fewer round trips to the database.""" return [ @@ -238,7 +238,7 @@ def _prepare_keys( @staticmethod def _process_rows( - keys: List[bytes], rows: List[Tuple] + keys: List[bytes], rows: List[Tuple] ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: """Transform the retrieved rows in the desired output. @@ -267,13 +267,13 @@ def _process_rows( return result def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): project = config.project schema_name = config.online_store.db_schema or config.online_store.user @@ -335,10 +335,10 @@ def update( conn.commit() def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): project = config.project try: @@ -351,13 +351,13 @@ def teardown( raise def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -433,12 +433,12 @@ def retrieve_online_documents( rows = cur.fetchall() for ( - entity_key, - feature_name, - value, - vector_value, - distance, - event_ts, + entity_key, + feature_name, + value, + vector_value, + distance, + event_ts, ) in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() diff --git a/sdk/python/feast/infra/utils/postgres/connection_utils.py b/sdk/python/feast/infra/utils/postgres/connection_utils.py index 8fdd926100a..7b37ea981f4 100644 --- a/sdk/python/feast/infra/utils/postgres/connection_utils.py +++ b/sdk/python/feast/infra/utils/postgres/connection_utils.py @@ -88,7 +88,7 @@ def _df_to_create_table_sql(entity_df, table_name) -> str: def df_to_postgres_table( - config: PostgreSQLConfig, df: pd.DataFrame, table_name: str + config: PostgreSQLConfig, df: pd.DataFrame, table_name: str ) -> Dict[str, np.dtype]: """ Create a table for the data frame, insert all the values, and return the table schema From 84594f77e97f8eede2dbf0ccd0f965a3ee1aeb00 Mon Sep 17 00:00:00 2001 From: TomSteenbergen Date: Tue, 2 Jul 2024 11:44:26 +0200 Subject: [PATCH 3/4] Update _prepare_keys method Signed-off-by: TomSteenbergen --- sdk/python/feast/infra/online_stores/contrib/postgres.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index a9c4f3ee05f..640cdd7a290 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -163,7 +163,7 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - keys = self._prepare_keys(config, entity_keys) + keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) query, params = self._construct_query_and_params( config, table, keys, requested_features ) @@ -181,7 +181,7 @@ async def online_read_async( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - keys = self._prepare_keys(config, entity_keys) + keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) query, params = self._construct_query_and_params( config, table, keys, requested_features ) @@ -225,13 +225,13 @@ def _construct_query_and_params( @staticmethod def _prepare_keys( - config: RepoConfig, entity_keys: List[EntityKeyProto] + entity_keys: List[EntityKeyProto], entity_key_seriaization_version: int ) -> List[bytes]: """Prepare all keys in a list to make fewer round trips to the database.""" return [ serialize_entity_key( entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, + entity_key_serialization_version=entity_key_serialization_version, ) for entity_key in entity_keys ] From ac297ca827ae7d8463d0c837258e81728f38f2ef Mon Sep 17 00:00:00 2001 From: TomSteenbergen Date: Tue, 2 Jul 2024 11:48:00 +0200 Subject: [PATCH 4/4] Fix typo Signed-off-by: TomSteenbergen --- sdk/python/feast/infra/online_stores/contrib/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 640cdd7a290..48499840e0b 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -225,7 +225,7 @@ def _construct_query_and_params( @staticmethod def _prepare_keys( - entity_keys: List[EntityKeyProto], entity_key_seriaization_version: int + entity_keys: List[EntityKeyProto], entity_key_serialization_version: int ) -> List[bytes]: """Prepare all keys in a list to make fewer round trips to the database.""" return [