Skip to content

Commit d666263

Browse files
EXPEbdodlashuchu
authored and committed
feat: Support for nested timestamp fields in Spark Offline store (feast-dev#4740)
1 parent 716de2e commit d666263

7 files changed

Lines changed: 236 additions & 16 deletions

File tree

sdk/python/feast/data_source.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class DataSource(ABC):
176176
was created, used for deduplicating rows.
177177
field_mapping (optional): A dictionary mapping of column names in this data
178178
source to feature names in a feature table or view. Only used for feature
179-
columns, not entity or timestamp columns.
179+
columns and timestamp columns, not entity columns.
180180
description (optional) A human-readable description.
181181
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
182182
owner (optional): The owner of the data source, typically the email of the primary
@@ -463,9 +463,11 @@ def from_proto(data_source: DataSourceProto):
463463
description=data_source.description,
464464
tags=dict(data_source.tags),
465465
owner=data_source.owner,
466-
batch_source=DataSource.from_proto(data_source.batch_source)
467-
if data_source.batch_source
468-
else None,
466+
batch_source=(
467+
DataSource.from_proto(data_source.batch_source)
468+
if data_source.batch_source
469+
else None
470+
),
469471
)
470472

471473
def to_proto(self) -> DataSourceProto:
@@ -643,9 +645,11 @@ def from_proto(data_source: DataSourceProto):
643645
description=data_source.description,
644646
tags=dict(data_source.tags),
645647
owner=data_source.owner,
646-
batch_source=DataSource.from_proto(data_source.batch_source)
647-
if data_source.batch_source
648-
else None,
648+
batch_source=(
649+
DataSource.from_proto(data_source.batch_source)
650+
if data_source.batch_source
651+
else None
652+
),
649653
)
650654

651655
@staticmethod

sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti
240240
) = spark_serialized_artifacts.unserialize()
241241

242242
if feature_view.batch_source.field_mapping is not None:
243+
# Spark offline store does the field mapping in pull_latest_from_table_or_query() call
244+
# This may be needed in future if this materialization engine supports other offline stores
243245
table = _run_pyarrow_field_mapping(
244246
table, feature_view.batch_source.field_mapping
245247
)

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from feast.repo_config import FeastConfigBaseModel, RepoConfig
3434
from feast.saved_dataset import SavedDatasetStorage
3535
from feast.type_map import spark_schema_to_np_dtypes
36+
from feast.utils import _get_fields_with_aliases
3637

3738
# Make sure spark warning doesn't raise more than once.
3839
warnings.simplefilter("once", RuntimeWarning)
@@ -90,16 +91,22 @@ def pull_latest_from_table_or_query(
9091
if created_timestamp_column:
9192
timestamps.append(created_timestamp_column)
9293
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
93-
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
94+
(fields_with_aliases, aliases) = _get_fields_with_aliases(
95+
fields=join_key_columns + feature_name_columns + timestamps,
96+
field_mappings=data_source.field_mapping,
97+
)
98+
99+
fields_as_string = ", ".join(fields_with_aliases)
100+
aliases_as_string = ", ".join(aliases)
94101

95102
start_date_str = _format_datetime(start_date)
96103
end_date_str = _format_datetime(end_date)
97104
query = f"""
98105
SELECT
99-
{field_string}
106+
{aliases_as_string}
100107
{f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
101108
FROM (
102-
SELECT {field_string},
109+
SELECT {fields_as_string},
103110
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
104111
FROM {from_expression} t1
105112
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
@@ -279,14 +286,19 @@ def pull_all_from_table_or_query(
279286
spark_session = get_spark_session_or_start_new_with_repoconfig(
280287
store_config=config.offline_store
281288
)
289+
(fields_with_aliases, aliases) = _get_fields_with_aliases(
290+
fields=join_key_columns + feature_name_columns + [timestamp_field],
291+
field_mappings=data_source.field_mapping,
292+
)
293+
294+
fields_with_alias_string = ", ".join(fields_with_aliases)
282295

283-
fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
284296
from_expression = data_source.get_table_query_string()
285297
start_date = start_date.astimezone(tz=timezone.utc)
286298
end_date = end_date.astimezone(tz=timezone.utc)
287299

288300
query = f"""
289-
SELECT {fields}
301+
SELECT {fields_with_alias_string}
290302
FROM {from_expression}
291303
WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
292304
"""

sdk/python/feast/utils.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def _get_requested_feature_views_to_features_dict(
106106
on_demand_feature_views: List["OnDemandFeatureView"],
107107
) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
108108
"""Create a dict of FeatureView -> List[Feature] for all requested features.
109-
Set full_feature_names to True to have feature names prefixed by their feature view name."""
109+
Set full_feature_names to True to have feature names prefixed by their feature view name.
110+
"""
110111

111112
feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list)
112113
on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = (
@@ -212,6 +213,28 @@ def _run_pyarrow_field_mapping(
212213
return table
213214

214215

216+
def _get_fields_with_aliases(
217+
fields: List[str],
218+
field_mappings: Dict[str, str],
219+
) -> Tuple[List[str], List[str]]:
220+
"""
221+
Get a list of fields with aliases based on the field mappings.
222+
"""
223+
for field in fields:
224+
if "." in field and field not in field_mappings:
225+
raise ValueError(
226+
f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields."
227+
)
228+
fields_with_aliases = [
229+
f"{field} AS {field_mappings[field]}" if field in field_mappings else field
230+
for field in fields
231+
]
232+
aliases = [
233+
field_mappings[field] if field in field_mappings else field for field in fields
234+
]
235+
return (fields_with_aliases, aliases)
236+
237+
215238
def _coerce_datetime(ts):
216239
"""
217240
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
@@ -781,9 +804,11 @@ def _populate_response_from_feature_data(
781804
"""
782805
# Add the feature names to the response.
783806
requested_feature_refs = [
784-
f"{table.projection.name_to_use()}__{feature_name}"
785-
if full_feature_names
786-
else feature_name
807+
(
808+
f"{table.projection.name_to_use()}__{feature_name}"
809+
if full_feature_names
810+
else feature_name
811+
)
787812
for feature_name in requested_features
788813
]
789814
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)

sdk/python/tests/integration/materialization/contrib/spark/test_spark.py renamed to sdk/python/tests/integration/materialization/contrib/spark/test_spark_materialization_engine.py

File renamed without changes.

sdk/python/tests/integration/registration/test_universal_registry.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,3 +1828,51 @@ def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registr
18281828

18291829
updated_test_registry.teardown()
18301830
test_registry.teardown()
1831+
1832+
1833+
@pytest.mark.integration
1834+
def test_commit_for_read_only_user():
1835+
fd, registry_path = mkstemp()
1836+
registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
1837+
write_registry = Registry("project", registry_config, None)
1838+
1839+
entity = Entity(
1840+
name="driver_car_id",
1841+
description="Car driver id",
1842+
tags={"team": "matchmaking"},
1843+
)
1844+
1845+
project = "project"
1846+
1847+
# Register Entity without committing
1848+
write_registry.apply_entity(entity, project, commit=False)
1849+
assert write_registry.cached_registry_proto
1850+
project_obj = write_registry.cached_registry_proto.projects[0]
1851+
assert project == Project.from_proto(project_obj).name
1852+
assert_project(project, write_registry, True)
1853+
1854+
# Retrieving the entity should still succeed
1855+
entities = write_registry.list_entities(project, allow_cache=True, tags=entity.tags)
1856+
entity = entities[0]
1857+
assert (
1858+
len(entities) == 1
1859+
and entity.name == "driver_car_id"
1860+
and entity.description == "Car driver id"
1861+
and "team" in entity.tags
1862+
and entity.tags["team"] == "matchmaking"
1863+
)
1864+
1865+
# commit from the original registry
1866+
write_registry.commit()
1867+
1868+
# Reconstruct the new registry in order to read the newly written store
1869+
with mock.patch.object(
1870+
Registry,
1871+
"commit",
1872+
side_effect=Exception("Read only users are not allowed to commit"),
1873+
):
1874+
read_registry = Registry("project", registry_config, None)
1875+
entities = read_registry.list_entities(project, tags=entity.tags)
1876+
assert len(entities) == 1
1877+
1878+
write_registry.teardown()
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from datetime import datetime
2+
from unittest.mock import MagicMock, patch
3+
4+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
5+
SparkOfflineStore,
6+
SparkOfflineStoreConfig,
7+
)
8+
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
9+
SparkSource,
10+
)
11+
from feast.infra.offline_stores.offline_store import RetrievalJob
12+
from feast.repo_config import RepoConfig
13+
14+
15+
@patch(
16+
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
17+
)
18+
def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
19+
mock_spark_session = MagicMock()
20+
mock_get_spark_session.return_value = mock_spark_session
21+
22+
test_repo_config = RepoConfig(
23+
project="test_project",
24+
registry="test_registry",
25+
provider="local",
26+
offline_store=SparkOfflineStoreConfig(type="spark"),
27+
)
28+
29+
test_data_source = SparkSource(
30+
name="test_nested_batch_source",
31+
description="test_nested_batch_source",
32+
table="offline_store_database_name.offline_store_table_name",
33+
timestamp_field="nested_timestamp",
34+
field_mapping={
35+
"event_header.event_published_datetime_utc": "nested_timestamp",
36+
},
37+
)
38+
39+
# Define the parameters for the method
40+
join_key_columns = ["key1", "key2"]
41+
feature_name_columns = ["feature1", "feature2"]
42+
timestamp_field = "event_header.event_published_datetime_utc"
43+
created_timestamp_column = "created_timestamp"
44+
start_date = datetime(2021, 1, 1)
45+
end_date = datetime(2021, 1, 2)
46+
47+
# Call the method
48+
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
49+
config=test_repo_config,
50+
data_source=test_data_source,
51+
join_key_columns=join_key_columns,
52+
feature_name_columns=feature_name_columns,
53+
timestamp_field=timestamp_field,
54+
created_timestamp_column=created_timestamp_column,
55+
start_date=start_date,
56+
end_date=end_date,
57+
)
58+
59+
expected_query = """SELECT
60+
key1, key2, feature1, feature2, nested_timestamp, created_timestamp
61+
62+
FROM (
63+
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
64+
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
65+
FROM `offline_store_database_name`.`offline_store_table_name` t1
66+
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
67+
) t2
68+
WHERE feast_row_ = 1""" # noqa: W293
69+
70+
assert isinstance(retrieval_job, RetrievalJob)
71+
assert retrieval_job.query.strip() == expected_query.strip()
72+
73+
74+
@patch(
75+
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
76+
)
77+
def test_pull_latest_from_table_without_nested_timestamp_or_query(
78+
mock_get_spark_session,
79+
):
80+
mock_spark_session = MagicMock()
81+
mock_get_spark_session.return_value = mock_spark_session
82+
83+
test_repo_config = RepoConfig(
84+
project="test_project",
85+
registry="test_registry",
86+
provider="local",
87+
offline_store=SparkOfflineStoreConfig(type="spark"),
88+
)
89+
90+
test_data_source = SparkSource(
91+
name="test_batch_source",
92+
description="test_nested_batch_source",
93+
table="offline_store_database_name.offline_store_table_name",
94+
timestamp_field="event_published_datetime_utc",
95+
)
96+
97+
# Define the parameters for the method
98+
join_key_columns = ["key1", "key2"]
99+
feature_name_columns = ["feature1", "feature2"]
100+
timestamp_field = "event_published_datetime_utc"
101+
created_timestamp_column = "created_timestamp"
102+
start_date = datetime(2021, 1, 1)
103+
end_date = datetime(2021, 1, 2)
104+
105+
# Call the method
106+
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
107+
config=test_repo_config,
108+
data_source=test_data_source,
109+
join_key_columns=join_key_columns,
110+
feature_name_columns=feature_name_columns,
111+
timestamp_field=timestamp_field,
112+
created_timestamp_column=created_timestamp_column,
113+
start_date=start_date,
114+
end_date=end_date,
115+
)
116+
117+
expected_query = """SELECT
118+
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp
119+
120+
FROM (
121+
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
122+
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
123+
FROM `offline_store_database_name`.`offline_store_table_name` t1
124+
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
125+
) t2
126+
WHERE feast_row_ = 1""" # noqa: W293
127+
128+
assert isinstance(retrieval_job, RetrievalJob)
129+
assert retrieval_job.query.strip() == expected_query.strip()

0 commit comments

Comments
 (0)