Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
entity query as cte
Signed-off-by: Blake <blaketastic2@gmail.com>
  • Loading branch information
blaketastic2 committed Mar 31, 2025
commit 41e3a5f5662e721f47971a224b8a2b00aca60987
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
from dataclasses import asdict
from datetime import datetime, timezone
from enum import Enum
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -48,8 +49,16 @@
from .postgres_source import PostgreSQLSource


class EntitySelectMode(Enum):
""" Use a temporary table to store the entity DataFrame or SQL query when querying feature data """
temp_table = "temp_table"
""" Use the entity SQL query directly when querying feature data """
embed_query = "embed_query"


class PostgreSQLOfflineStoreConfig(PostgreSQLConfig):
type: Literal["postgres"] = "postgres"
entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table


class PostgreSQLOfflineStore(OfflineStore):
Expand Down Expand Up @@ -134,7 +143,17 @@ def get_historical_features(
def query_generator() -> Iterator[str]:
table_name = offline_utils.get_temp_entity_table_name()

_upload_entity_df(config, entity_df, table_name)
# If using CTE and entity_df is a SQL query, we don't need a table
if config.offline_store.entity_select_mode == EntitySelectMode.embed_query:
if isinstance(entity_df, str):
left_table_query_string = entity_df
else:
raise ValueError(
f"Invalid entity select mode: {config.offline_store.entity_select_mode} cannot be used with entity_df as a DataFrame"
)
else:
left_table_query_string = table_name
_upload_entity_df(config, entity_df, table_name)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
Expand Down Expand Up @@ -163,14 +182,19 @@ def query_generator() -> Iterator[str]:
try:
yield build_point_in_time_query(
query_context_dict,
left_table_query_string=table_name,
left_table_query_string=left_table_query_string,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
entity_select_mode=config.offline_store.entity_select_mode,
)
finally:
if table_name:
# Only cleanup if we created a table
if (
config.offline_store.entity_select_mode
== EntitySelectMode.temp_table
):
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(
sql.SQL(
Expand Down Expand Up @@ -362,6 +386,7 @@ def build_point_in_time_query(
entity_df_columns: KeysView[str],
query_template: str,
full_feature_names: bool = False,
entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table,
) -> str:
"""Build point-in-time query between each feature view table and the entity dataframe for PostgreSQL"""
template = Environment(loader=BaseLoader()).from_string(source=query_template)
Expand Down Expand Up @@ -389,6 +414,7 @@ def build_point_in_time_query(
"featureviews": feature_view_query_contexts,
"full_feature_names": full_feature_names,
"final_output_feature_names": final_output_feature_names,
"entity_select_mode": entity_select_mode,
}

query = template.render(template_context)
Expand Down Expand Up @@ -429,11 +455,19 @@ def _get_entity_schema(
# https://github.com/feast-dev/feast/blob/master/sdk/python/feast/infra/offline_stores/redshift.py

MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
WITH entity_query AS (
{% if entity_select_mode == EntitySelectMode.embed_query %}
{{ left_table_query_string }}
{% else %}
SELECT * FROM {{ left_table_query_string }}
{% endif %}
),

/*
Compute a deterministic hash for the `left_table_query_string` that will be used throughout
all the logic as the field to GROUP BY the data
*/
WITH entity_dataframe AS (
entity_dataframe AS (
SELECT *,
{{entity_df_event_timestamp_col}} AS entity_timestamp
{% for featureview in featureviews %}
Expand All @@ -448,7 +482,7 @@ def _get_entity_schema(
,CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR) AS "{{featureview.name}}__entity_row_unique_id"
{% endif %}
{% endfor %}
FROM {{ left_table_query_string }}
FROM entity_query
),

{% for featureview in featureviews %}
Expand Down