
Commit 26e0b3e

implement spark materialization engine

Signed-off-by: niklasvm <niklasvm@gmail.com>
1 parent b4ef834 commit 26e0b3e

3 files changed

Lines changed: 381 additions & 0 deletions
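
This commit adds a Spark-based batch materialization engine: for each feature view, source rows are pulled through the Spark offline store, bucketed into fixed-size batches, and written to the online store from Spark executors. Once registered in repo_config.py (see below), the engine is selected through the batch_engine block of feature_store.yaml and driven by the usual materialize call. A minimal usage sketch (illustrative, not part of this commit; the repo path, dates, and batch size are placeholders):

# Assumes a feature repo whose feature_store.yaml uses the Spark offline
# store and selects this engine via:
#
#   batch_engine:
#       type: spark
#       batch_size: 10000
from datetime import datetime, timedelta

from feast import FeatureStore

store = FeatureStore(repo_path=".")
# Each feature view in the window below is materialized as one Spark job.
store.materialize(
    start_date=datetime.utcnow() - timedelta(days=1),
    end_date=datetime.utcnow(),
)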

sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
import tempfile
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union

import dill
import pyarrow
from pyspark.sql import DataFrame
from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.materialization.batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationJobStatus,
    MaterializationTask,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
    SparkRetrievalJob,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
    _convert_arrow_to_proto,
    _get_column_names,
    _run_pyarrow_field_mapping,
)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch materialization engine config for the Spark engine."""

    type: Literal["spark"] = "spark"
    """Type selector."""

    batch_size: int
    """Number of rows to write to the online store per batch."""


@dataclass
class SparkMaterializationJob(MaterializationJob):
    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._job_id: str = job_id
        self._status: MaterializationJobStatus = status
        self._error: Optional[BaseException] = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        return None


class SparkMaterializationEngine(BatchMaterializationEngine):
    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ):
        # Nothing to set up.
        pass

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ):
        # Nothing to tear down.
        pass

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: SparkOfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        if not isinstance(offline_store, SparkOfflineStore):
            raise TypeError(
                "SparkMaterializationEngine is only compatible with the SparkOfflineStore"
            )
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        return [
            self._materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def _materialize_one(
        self,
        registry: BaseRegistry,
        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        job_id = f"{feature_view.name}-{start_date}-{end_date}"

        try:
            offline_job: SparkRetrievalJob = (
                self.offline_store.pull_latest_from_table_or_query(
                    config=self.repo_config,
                    data_source=feature_view.batch_source,
                    join_key_columns=join_key_columns,
                    feature_name_columns=feature_name_columns,
                    timestamp_field=timestamp_field,
                    created_timestamp_column=created_timestamp_column,
                    start_date=start_date,
                    end_date=end_date,
                )
            )

            # Serialize the feature view as a proto so it can be shipped to executors.
            feature_view_proto = feature_view.to_proto().SerializeToString()

            # Serialize repo_config to disk; it is used to instantiate the online
            # store inside each executor task.
            repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
            with open(repo_config_file, "wb") as f:
                dill.dump(self.repo_config, f)

            # Split the data into batches.
            spark_df = offline_job.to_spark_df()
            batch_size = self.repo_config.batch_engine.batch_size
            batched_spark_df, batch_column_alias = add_batch_column(
                spark_df,
                join_key_columns=join_key_columns,
                timestamp_field=timestamp_field,
                batch_size=batch_size,
            )

            schema = [
                f"{x} {y}"
                for x, y in batched_spark_df.dtypes + [("success_flag", "string")]
            ]
            schema_ddl = ", ".join(schema)
            result = batched_spark_df.groupBy(batch_column_alias).applyInPandas(
                lambda x: _process_by_pandas_batch(
                    x,
                    feature_view_proto=feature_view_proto,
                    repo_config_file=repo_config_file,
                ),
                schema=schema_ddl,
            )
            # applyInPandas is lazy; collect() forces the writes to execute.
            result.collect()

            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
            )
        except BaseException as e:
            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
            )


def add_batch_column(
    spark_df: DataFrame, join_key_columns, timestamp_field, batch_size
):
    """
    Generates a batch column for a data frame.
    """
    spark_session = spark_df.sparkSession

    # Generate a unique name for the temp view.
    view_name = f"{uuid.uuid4()}".replace("-", "")

    row_number_index_alias = f"{view_name}_row_index"
    batch_column_alias = f"{view_name}_batch"
    original_columns_snippet = ", ".join(spark_df.columns)

    # Assign each row an increasing id, then bucket rows into batches of at
    # most `batch_size` rows.
    spark_df.createOrReplaceTempView(view_name)
    batched_spark_df = spark_session.sql(
        f"""
        with add_index as (
            select
                {original_columns_snippet},
                monotonically_increasing_id() as {row_number_index_alias}
            from {view_name}
        )
        select
            {original_columns_snippet},
            floor({row_number_index_alias}/{batch_size}) as {batch_column_alias}
        from add_index
        """
    )

    return batched_spark_df, batch_column_alias


def _process_by_pandas_batch(pdf, feature_view_proto, repo_config_file):
    # Deserialize the feature view from its proto representation.
    proto = FeatureViewProto()
    proto.ParseFromString(feature_view_proto)
    feature_view = FeatureView.from_proto(proto)

    # Load the repo config written by the driver.
    with open(repo_config_file, "rb") as f:
        repo_config = dill.load(f)

    provider = PassthroughProvider(repo_config)
    online_store = provider.online_store

    table = pyarrow.Table.from_pandas(pdf)

    if feature_view.batch_source.field_mapping is not None:
        table = _run_pyarrow_field_mapping(
            table, feature_view.batch_source.field_mapping
        )

    join_key_to_value_type = {
        entity.name: entity.dtype.to_value_type()
        for entity in feature_view.entity_columns
    }

    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
    online_store.online_write_batch(
        repo_config,
        feature_view,
        rows_to_write,
        lambda x: None,
    )
    pdf["success_flag"] = "SUCCESS"

    return pdf
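
The heart of the engine is the pairing of add_batch_column with groupBy(...).applyInPandas(...): rows are bucketed by floor(monotonically_increasing_id() / batch_size), and each bucket reaches an executor as a pandas DataFrame, where the online-store write happens. A self-contained sketch of that pattern with a toy dataframe and a stand-in writer (not Feast's actual write path):

# Standalone sketch of the batching pattern used by the engine above.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(100).withColumnRenamed("id", "driver_id")

batch_size = 10
# monotonically_increasing_id() is consecutive within a partition and its
# ranges are disjoint across partitions, so each bucket holds at most
# `batch_size` rows.
batched = df.withColumn(
    "batch", F.floor(F.monotonically_increasing_id() / batch_size)
)


def write_batch(pdf):
    # Stand-in for _process_by_pandas_batch: write the rows, flag success.
    pdf["success_flag"] = "SUCCESS"
    return pdf


schema_ddl = ", ".join(
    f"{name} {dtype}"
    for name, dtype in batched.dtypes + [("success_flag", "string")]
)
# applyInPandas is lazy; collect() forces every batch to be processed.
batched.groupBy("batch").applyInPandas(write_batch, schema=schema_ddl).collect()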

sdk/python/feast/repo_config.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@
     "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine",
     "lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine",
     "bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine",
+    "spark": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine",
 }

 ONLINE_STORE_CLASS_FOR_TYPE = {
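
The one-line change above registers the engine under the "spark" key, mapping it to a fully qualified class path that Feast resolves by dynamic import when batch_engine.type is set to spark. A generic sketch of that resolution pattern (importlib-based; Feast's own helper may differ):

import importlib


def load_class(qualified_name: str):
    # Split "package.module.ClassName" and import the class dynamically.
    module_path, class_name = qualified_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)


engine_cls = load_class(
    "feast.infra.materialization.contrib.spark."
    "spark_materialization_engine.SparkMaterializationEngine"
)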
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
from datetime import timedelta

import pytest

from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
    SparkDataSourceCreator,
)
from feast.types import Float32
from tests.data.data_creator import create_basic_driver_dataset
from tests.integration.feature_repos.integration_test_repo_config import (
    IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.repo_configuration import (
    construct_test_environment,
)
from tests.utils.e2e_test_validation import validate_offline_online_store_consistency


@pytest.mark.integration
def test_spark_materialization_consistency():
    spark_config = IntegrationTestRepoConfig(
        provider="local",
        online_store={
            "type": "redis",
            # "path": "data/online_store.db"
        },
        offline_store_creator=SparkDataSourceCreator,
        batch_engine={"type": "spark", "batch_size": 10},
    )
    spark_environment = construct_test_environment(
        spark_config, None, entity_key_serialization_version=1
    )

    df = create_basic_driver_dataset()

    # # generate a large data set
    # now = datetime.utcnow().replace(microsecond=0, second=0, minute=0)
    # ts = pd.Timestamp(now).round("ms")
    # size = 10000
    # driver_id = np.array(list(range(size)))
    # value = np.array([round(np.random.uniform(size=1)[0], 3) for x in list(range(size))])
    # ts_1 = [ts - timedelta(hours=4) for x in range(size)]
    # created_ts = np.repeat(ts, repeats=size)
    # df = pd.DataFrame({
    #     "driver_id": driver_id,
    #     "value": value,
    #     "ts_1": ts_1,
    #     "created_ts": created_ts,
    # })
    # print(df)

    ds = spark_environment.data_source_creator.create_data_source(
        df,
        spark_environment.feature_store.project,
        field_mapping={"ts_1": "ts"},
    )

    fs = spark_environment.feature_store
    driver = Entity(
        name="driver_id",
        join_keys=["driver_id"],
    )

    driver_stats_fv = FeatureView(
        name="driver_hourly_stats",
        entities=[driver],
        ttl=timedelta(weeks=52),
        schema=[Field(name="value", dtype=Float32)],
        source=ds,
    )

    try:
        fs.apply([driver, driver_stats_fv])

        print(df)

        # Materialization is run in two steps, and we use a timestamp from the
        # generated dataframe as the split point.
        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)

        print(f"Split datetime: {split_dt}")

        validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
    finally:
        fs.teardown()


if __name__ == "__main__":
    test_spark_materialization_consistency()
