Skip to content

Commit d3253c3

Browse files
AlexEijssen, adchia, and alex.eijssen
authored
feat: Adding saved dataset capabilities for Postgres (#3070)
* feat: Adding saved dataset capabilities for Postgres Signed-off-by: Danny Chiao <danny@tecton.ai> Signed-off-by: alex.eijssen <alex.eijssen@energyessentials.nl> Co-authored-by: Danny Chiao <danny@tecton.ai> Co-authored-by: alex.eijssen <alex.eijssen@energyessentials.nl>
1 parent 36747aa commit d3253c3

File tree

5 files changed

+177
-52
lines changed

5 files changed

+177
-52
lines changed

Makefile

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,24 +164,47 @@ test-python-universal-athena:
164164
not s3_registry" \
165165
sdk/python/tests
166166

167-
168-
169-
test-python-universal-postgres:
167+
test-python-universal-postgres-offline:
170168
PYTHONPATH='.' \
171169
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
172170
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
173171
FEAST_USAGE=False \
174172
IS_TEST=True \
175-
python -m pytest -x --integration \
176-
-k "not test_historical_retrieval_fails_on_validation and \
177-
not test_historical_retrieval_with_validation and \
173+
python -m pytest -n 8 --integration \
174+
-k "not test_historical_retrieval_with_validation and \
178175
not test_historical_features_persisting and \
179-
not test_historical_retrieval_fails_on_validation and \
180-
not test_universal_cli and \
181-
not test_go_feature_server and \
182-
not test_feature_logging and \
183-
not test_universal_types" \
184-
sdk/python/tests
176+
not test_universal_cli and \
177+
not test_go_feature_server and \
178+
not test_feature_logging and \
179+
not test_reorder_columns and \
180+
not test_logged_features_validation and \
181+
not test_lambda_materialization_consistency and \
182+
not test_offline_write and \
183+
not test_push_features_to_offline_store and \
184+
not gcs_registry and \
185+
not s3_registry and \
186+
not test_universal_types" \
187+
sdk/python/tests
188+
189+
test-python-universal-postgres-online:
190+
PYTHONPATH='.' \
191+
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
192+
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
193+
FEAST_USAGE=False \
194+
IS_TEST=True \
195+
python -m pytest -n 8 --integration \
196+
-k "not test_universal_cli and \
197+
not test_go_feature_server and \
198+
not test_feature_logging and \
199+
not test_reorder_columns and \
200+
not test_logged_features_validation and \
201+
not test_lambda_materialization_consistency and \
202+
not test_offline_write and \
203+
not test_push_features_to_offline_store and \
204+
not gcs_registry and \
205+
not s3_registry and \
206+
not test_universal_types" \
207+
sdk/python/tests
185208

186209
test-python-universal-cassandra:
187210
PYTHONPATH='.' \

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Callable,
77
ContextManager,
8+
Dict,
89
Iterator,
910
KeysView,
1011
List,
@@ -13,6 +14,7 @@
1314
Union,
1415
)
1516

17+
import numpy as np
1618
import pandas as pd
1719
import pyarrow as pa
1820
from jinja2 import BaseLoader, Environment
@@ -24,6 +26,9 @@
2426
from feast.errors import InvalidEntityType
2527
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
2628
from feast.infra.offline_stores import offline_utils
29+
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
30+
SavedDatasetPostgreSQLStorage,
31+
)
2732
from feast.infra.offline_stores.offline_store import (
2833
OfflineStore,
2934
RetrievalJob,
@@ -112,24 +117,24 @@ def get_historical_features(
112117
project: str,
113118
full_feature_names: bool = False,
114119
) -> RetrievalJob:
120+
121+
entity_schema = _get_entity_schema(entity_df, config)
122+
123+
entity_df_event_timestamp_col = (
124+
offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
125+
)
126+
127+
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
128+
entity_df,
129+
entity_df_event_timestamp_col,
130+
config,
131+
)
132+
115133
@contextlib.contextmanager
116134
def query_generator() -> Iterator[str]:
117-
table_name = None
118-
if isinstance(entity_df, pd.DataFrame):
119-
table_name = offline_utils.get_temp_entity_table_name()
120-
entity_schema = df_to_postgres_table(
121-
config.offline_store, entity_df, table_name
122-
)
123-
df_query = table_name
124-
elif isinstance(entity_df, str):
125-
df_query = f"({entity_df}) AS sub"
126-
entity_schema = get_query_schema(config.offline_store, df_query)
127-
else:
128-
raise TypeError(entity_df)
129-
130-
entity_df_event_timestamp_col = (
131-
offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
132-
)
135+
table_name = offline_utils.get_temp_entity_table_name()
136+
137+
_upload_entity_df(config, entity_df, table_name)
133138

134139
expected_join_keys = offline_utils.get_expected_join_keys(
135140
project, feature_views, registry
@@ -139,13 +144,6 @@ def query_generator() -> Iterator[str]:
139144
entity_schema, expected_join_keys, entity_df_event_timestamp_col
140145
)
141146

142-
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
143-
entity_df,
144-
entity_df_event_timestamp_col,
145-
config,
146-
df_query,
147-
)
148-
149147
query_context = offline_utils.get_feature_view_query_context(
150148
feature_refs,
151149
feature_views,
@@ -165,7 +163,7 @@ def query_generator() -> Iterator[str]:
165163
try:
166164
yield build_point_in_time_query(
167165
query_context_dict,
168-
left_table_query_string=df_query,
166+
left_table_query_string=table_name,
169167
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
170168
entity_df_columns=entity_schema.keys(),
171169
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
@@ -189,6 +187,12 @@ def query_generator() -> Iterator[str]:
189187
on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
190188
feature_refs, project, registry
191189
),
190+
metadata=RetrievalMetadata(
191+
features=feature_refs,
192+
keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}),
193+
min_event_timestamp=entity_df_event_timestamp_range[0],
194+
max_event_timestamp=entity_df_event_timestamp_range[1],
195+
),
192196
)
193197

194198
@staticmethod
@@ -294,14 +298,19 @@ def metadata(self) -> Optional[RetrievalMetadata]:
294298
return self._metadata
295299

296300
def persist(self, storage: SavedDatasetStorage):
297-
pass
301+
assert isinstance(storage, SavedDatasetPostgreSQLStorage)
302+
303+
df_to_postgres_table(
304+
config=self.config.offline_store,
305+
df=self.to_df(),
306+
table_name=storage.postgres_options._table,
307+
)
298308

299309

300310
def _get_entity_df_event_timestamp_range(
301311
entity_df: Union[pd.DataFrame, str],
302312
entity_df_event_timestamp_col: str,
303313
config: RepoConfig,
304-
table_name: str,
305314
) -> Tuple[datetime, datetime]:
306315
if isinstance(entity_df, pd.DataFrame):
307316
entity_df_event_timestamp = entity_df.loc[
@@ -312,15 +321,15 @@ def _get_entity_df_event_timestamp_range(
312321
entity_df_event_timestamp, utc=True
313322
)
314323
entity_df_event_timestamp_range = (
315-
entity_df_event_timestamp.min(),
316-
entity_df_event_timestamp.max(),
324+
entity_df_event_timestamp.min().to_pydatetime(),
325+
entity_df_event_timestamp.max().to_pydatetime(),
317326
)
318327
elif isinstance(entity_df, str):
319328
# If the entity_df is a string (SQL query), determine range
320329
# from table
321330
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
322331
cur.execute(
323-
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM {table_name}"
332+
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
324333
),
325334
res = cur.fetchone()
326335
entity_df_event_timestamp_range = (res[0], res[1])
@@ -374,6 +383,34 @@ def build_point_in_time_query(
374383
return query
375384

376385

386+
def _upload_entity_df(
387+
config: RepoConfig, entity_df: Union[pd.DataFrame, str], table_name: str
388+
):
389+
if isinstance(entity_df, pd.DataFrame):
390+
# If the entity_df is a pandas dataframe, upload it to Postgres
391+
df_to_postgres_table(config.offline_store, entity_df, table_name)
392+
elif isinstance(entity_df, str):
393+
# If the entity_df is a string (SQL query), create a Postgres table out of it
394+
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
395+
cur.execute(f"CREATE TABLE {table_name} AS ({entity_df})")
396+
else:
397+
raise InvalidEntityType(type(entity_df))
398+
399+
400+
def _get_entity_schema(
401+
entity_df: Union[pd.DataFrame, str],
402+
config: RepoConfig,
403+
) -> Dict[str, np.dtype]:
404+
if isinstance(entity_df, pd.DataFrame):
405+
return dict(zip(entity_df.columns, entity_df.dtypes))
406+
407+
elif isinstance(entity_df, str):
408+
df_query = f"({entity_df}) AS sub"
409+
return get_query_schema(config.offline_store, df_query)
410+
else:
411+
raise InvalidEntityType(type(entity_df))
412+
413+
377414
# Copied from the Feast Redshift offline store implementation
378415
# Note: Keep this in sync with sdk/python/feast/infra/offline_stores/redshift.py:
379416
# MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,42 @@
11
import json
22
from typing import Callable, Dict, Iterable, Optional, Tuple
33

4+
from typeguard import typechecked
5+
46
from feast.data_source import DataSource
7+
from feast.errors import DataSourceNoNameException
58
from feast.infra.utils.postgres.connection_utils import _get_conn
69
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
10+
from feast.protos.feast.core.SavedDataset_pb2 import (
11+
SavedDatasetStorage as SavedDatasetStorageProto,
12+
)
713
from feast.repo_config import RepoConfig
14+
from feast.saved_dataset import SavedDatasetStorage
815
from feast.type_map import pg_type_code_to_pg_type, pg_type_to_feast_value_type
916
from feast.value_type import ValueType
1017

1118

19+
@typechecked
1220
class PostgreSQLSource(DataSource):
1321
def __init__(
1422
self,
15-
name: str,
16-
query: str,
23+
name: Optional[str] = None,
24+
query: Optional[str] = None,
25+
table: Optional[str] = None,
1726
timestamp_field: Optional[str] = "",
1827
created_timestamp_column: Optional[str] = "",
1928
field_mapping: Optional[Dict[str, str]] = None,
2029
description: Optional[str] = "",
2130
tags: Optional[Dict[str, str]] = None,
2231
owner: Optional[str] = "",
2332
):
24-
self._postgres_options = PostgreSQLOptions(name=name, query=query)
33+
self._postgres_options = PostgreSQLOptions(name=name, query=query, table=table)
34+
35+
# If no name, use the table as the default name.
36+
if name is None and table is None:
37+
raise DataSourceNoNameException()
38+
name = name or table
39+
assert name
2540

2641
super().__init__(
2742
name=name,
@@ -55,9 +70,11 @@ def from_proto(data_source: DataSourceProto):
5570
assert data_source.HasField("custom_options")
5671

5772
postgres_options = json.loads(data_source.custom_options.configuration)
73+
5874
return PostgreSQLSource(
5975
name=postgres_options["name"],
6076
query=postgres_options["query"],
77+
table=postgres_options["table"],
6178
field_mapping=dict(data_source.field_mapping),
6279
timestamp_field=data_source.timestamp_field,
6380
created_timestamp_column=data_source.created_timestamp_column,
@@ -102,26 +119,60 @@ def get_table_column_names_and_types(
102119
)
103120

104121
def get_table_query_string(self) -> str:
105-
return f"({self._postgres_options._query})"
122+
123+
if self._postgres_options._table:
124+
return f"{self._postgres_options._table}"
125+
else:
126+
return f"({self._postgres_options._query})"
106127

107128

108129
class PostgreSQLOptions:
109-
def __init__(self, name: str, query: Optional[str]):
110-
self._name = name
111-
self._query = query
130+
def __init__(
131+
self,
132+
name: Optional[str],
133+
query: Optional[str],
134+
table: Optional[str],
135+
):
136+
self._name = name or ""
137+
self._query = query or ""
138+
self._table = table or ""
112139

113140
@classmethod
114141
def from_proto(cls, postgres_options_proto: DataSourceProto.CustomSourceOptions):
115142
config = json.loads(postgres_options_proto.configuration.decode("utf8"))
116-
postgres_options = cls(name=config["name"], query=config["query"])
143+
postgres_options = cls(
144+
name=config["name"], query=config["query"], table=config["table"]
145+
)
117146

118147
return postgres_options
119148

120149
def to_proto(self) -> DataSourceProto.CustomSourceOptions:
121150
postgres_options_proto = DataSourceProto.CustomSourceOptions(
122151
configuration=json.dumps(
123-
{"name": self._name, "query": self._query}
152+
{"name": self._name, "query": self._query, "table": self._table}
124153
).encode()
125154
)
126-
127155
return postgres_options_proto
156+
157+
158+
class SavedDatasetPostgreSQLStorage(SavedDatasetStorage):
159+
_proto_attr_name = "custom_storage"
160+
161+
postgres_options: PostgreSQLOptions
162+
163+
def __init__(self, table_ref: str):
164+
self.postgres_options = PostgreSQLOptions(
165+
table=table_ref, name=None, query=None
166+
)
167+
168+
@staticmethod
169+
def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
170+
return SavedDatasetPostgreSQLStorage(
171+
table_ref=PostgreSQLOptions.from_proto(storage_proto.custom_storage)._table
172+
)
173+
174+
def to_proto(self) -> SavedDatasetStorageProto:
175+
return SavedDatasetStorageProto(custom_storage=self.postgres_options.to_proto())
176+
177+
def to_data_source(self) -> DataSource:
178+
return PostgreSQLSource(table=self.postgres_options._table)
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
22
PostgreSQLDataSourceCreator,
33
)
4+
from tests.integration.feature_repos.repo_configuration import REDIS_CONFIG
5+
from tests.integration.feature_repos.universal.online_store.redis import (
6+
RedisOnlineStoreCreator,
7+
)
48

59
AVAILABLE_OFFLINE_STORES = [("local", PostgreSQLDataSourceCreator)]
610

7-
AVAILABLE_ONLINE_STORES = {"postgres": (None, PostgreSQLDataSourceCreator)}
11+
AVAILABLE_ONLINE_STORES = {"redis": (REDIS_CONFIG, RedisOnlineStoreCreator)}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
2+
PostgreSQLDataSourceCreator,
3+
)
4+
from tests.integration.feature_repos.integration_test_repo_config import (
5+
IntegrationTestRepoConfig,
6+
)
7+
8+
FULL_REPO_CONFIGS = [
9+
IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
10+
]

0 commit comments

Comments (0)