Skip to content
Merged
Prev Previous commit
Next Next commit
fix test
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Apr 18, 2025
commit 8e02747c4c60db1a68312cef540d3d17c0986585
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/compute_engines/spark/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def execute(self, context: ExecutionContext) -> DAGValue:
feature_df: DataFrame = feature_value.data

entity_df = context.entity_df
if not entity_df:
if entity_df is None:
return DAGValue(
data=feature_df,
format=DAGFormat.SPARK,
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def pull_all_from_table_or_query(
+ [timestamp_field]
)
timestamp_filter = get_timestamp_filter_sql(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Major change: the timestamp filter is obtained from get_timestamp_filter_sql.

start_date, end_date, timestamp_field
start_date, end_date, timestamp_field, quote_fields=False
)
query = f"""
SELECT {field_string}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,12 @@ def pull_all_from_table_or_query(
)[:-3]

timestamp_filter = get_timestamp_filter_sql(
start_date_str, end_date_str, timestamp_field, date_partition_column
start_date_str,
end_date_str,
timestamp_field,
date_partition_column,
cast_style="raw",
quote_fields=False,
)

query = f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,16 @@ def pull_all_from_table_or_query(
field_string = ", ".join(
join_key_columns + feature_name_columns + [timestamp_field]
)

start_date_normalized = normalize_timestamp(start_date) if start_date else None
end_date_normalized = normalize_timestamp(end_date) if end_date else None
start_date_normalized = (
f"`{normalize_timestamp(start_date)}`" if start_date else None
)
end_date_normalized = f"`{normalize_timestamp(end_date)}`" if end_date else None
timestamp_filter = get_timestamp_filter_sql(
start_date_normalized, end_date_normalized, timestamp_field
start_date_normalized,
end_date_normalized,
timestamp_field,
cast_style="raw",
quote_fields=False,
)

query = f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def pull_all_from_table_or_query(
timestamp_field,
tz=timezone.utc,
cast_style="timestamptz",
date_time_separator=" ", # backwards compatibility but inconsistent with other offline stores
)

query = f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def pull_all_from_table_or_query(

from_expression = data_source.get_table_query_string()
timestamp_filter = get_timestamp_filter_sql(
start_date, end_date, timestamp_field, tz=timezone.utc
start_date, end_date, timestamp_field, tz=timezone.utc, quote_fields=False
)

query = f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def pull_all_from_table_or_query(
)

timestamp_filter = get_timestamp_filter_sql(
start_date, end_date, timestamp_field
start_date, end_date, timestamp_field, quote_fields=False
)
query = f"""
SELECT {field_string}
Expand Down
33 changes: 25 additions & 8 deletions sdk/python/feast/infra/offline_stores/offline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def get_timestamp_filter_sql(
date_partition_column: Optional[str] = None,
tz: Optional[timezone] = None,
cast_style: Literal["timestamp_func", "timestamptz", "raw"] = "timestamp_func",
date_time_separator: str = "T",
quote_fields: bool = True,
) -> str:
"""
Returns SQL filter condition (no WHERE) with flexible timestamp casting.
Expand All @@ -289,21 +291,33 @@ def get_timestamp_filter_sql(
- "timestamp_func": TIMESTAMP('...') → Snowflake, BigQuery, Athena
- "timestamptz": '...'::timestamptz → PostgreSQL
- "raw": '...' → no cast, string only
date_time_separator: separator for datetime strings (default is "T")
(e.g. "2023-10-01T00:00:00" or "2023-10-01 00:00:00")
quote_fields: whether to quote the timestamp and partition column names

Returns:
SQL filter string without WHERE
"""

def quote_column_if_needed(column: Optional[str]) -> Optional[str]:
if not column or not quote_fields:
return column
return f'"{column}"'

def format_casted_ts(val: Union[str, datetime]) -> str:
if isinstance(val, datetime):
if tz:
val = val.astimezone(tz)
val_str = val.isoformat(sep=date_time_separator)
else:
val_str = val

if cast_style == "timestamp_func":
return f"TIMESTAMP('{val}')"
return f"TIMESTAMP('{val_str}')"
elif cast_style == "timestamptz":
return f"'{val}'::timestamptz"
return f"'{val_str}'::timestamptz"
else:
return f"'{val}'"
return f"'{val_str}'"

def format_date(val: Union[str, datetime]) -> str:
if isinstance(val, datetime):
Expand All @@ -312,23 +326,26 @@ def format_date(val: Union[str, datetime]) -> str:
return val.strftime("%Y-%m-%d")
return val

ts_field = quote_column_if_needed(timestamp_field)
dp_field = quote_column_if_needed(date_partition_column)

filters = []

# Timestamp filters
if start_date and end_date:
filters.append(
f'"{timestamp_field}" BETWEEN {format_casted_ts(start_date)} AND {format_casted_ts(end_date)}'
f"{ts_field} BETWEEN {format_casted_ts(start_date)} AND {format_casted_ts(end_date)}"
)
elif start_date:
filters.append(f"{timestamp_field} >= {format_casted_ts(start_date)}")
filters.append(f"{ts_field} >= {format_casted_ts(start_date)}")
elif end_date:
filters.append(f"{timestamp_field} <= {format_casted_ts(end_date)}")
filters.append(f"{ts_field} <= {format_casted_ts(end_date)}")

# Partition pruning
if date_partition_column:
if start_date:
filters.append(f"{date_partition_column} >= '{format_date(start_date)}'")
filters.append(f"{dp_field} >= '{format_date(start_date)}'")
if end_date:
filters.append(f"{date_partition_column} <= '{format_date(end_date)}'")
filters.append(f"{dp_field} <= '{format_date(end_date)}'")

return " AND ".join(filters) if filters else ""
Loading