Commit f609cfb

switch to using foreachPartition
Signed-off-by: niklasvm <niklasvm@gmail.com>
1 parent b42352e commit f609cfb
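
For context, this commit replaces a groupBy/applyInPandas batching scheme with Spark's foreachPartition, which invokes a function once per partition on the executors. A minimal, self-contained sketch of the pattern (a hypothetical example with a local SparkSession and a stand-in write function, not the Feast code itself):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "val"])

def write_partition(rows):
    # `rows` is an iterator of pyspark.sql.Row objects; this body runs
    # once per partition on the executors, not on the driver.
    batch = [row.asDict() for row in rows]
    if not batch:
        return  # empty partitions are possible and must be tolerated
    print(f"writing {len(batch)} rows")  # stand-in for an online-store write

df.foreachPartition(write_partition)

Unlike applyInPandas, foreachPartition is an action that returns nothing, so the caller no longer needs to declare an output schema or call collect() just to force execution.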

1 file changed: 18 additions & 59 deletions

sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py

@@ -1,12 +1,11 @@
 import tempfile
-import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Callable, List, Literal, Optional, Sequence, Union

 import dill
+import pandas as pd
 import pyarrow
-from pyspark.sql import DataFrame
 from tqdm import tqdm

 from feast.batch_feature_view import BatchFeatureView
@@ -171,25 +170,11 @@ def _materialize_one(
                 feature_view=feature_view, repo_config=self.repo_config
             )

-            # split data into batches
             spark_df = offline_job.to_spark_df()
-            batch_size = self.repo_config.batch_engine.batch_size
-            batched_spark_df, batch_column_alias = _add_batch_column(
-                spark_df,
-                batch_size=batch_size,
+            spark_df.foreachPartition(
+                lambda x: _process_by_partition(x, spark_serialized_artifacts)
             )

-            schema = [
-                f"{x} {y}"
-                for x, y in batched_spark_df.dtypes + [("success_flag", "string")]
-            ]
-            schema_ddl = ", ".join(schema)
-            result = batched_spark_df.groupBy(batch_column_alias).applyInPandas(
-                lambda x: _process_by_pandas_batch(x, spark_serialized_artifacts),
-                schema=schema_ddl,
-            )
-            result.collect()
-
             return SparkMaterializationJob(
                 job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
             )
@@ -199,39 +184,6 @@ def _materialize_one(
             )


-def _add_batch_column(spark_df: DataFrame, batch_size):
-    """
-    Generates a batch column for a data frame
-    """
-    spark_session = spark_df.sparkSession
-
-    # generate a unique name for the view
-    view_name = f"{uuid.uuid4()}".replace("-", "")
-
-    row_number_index_alias = f"{view_name}_row_index"
-    batch_column_alias = f"{view_name}_batch"
-    original_columns_snippet = ", ".join(spark_df.columns)
-
-    # generate batch
-    spark_df.createOrReplaceTempView(view_name)
-    batched_spark_df = spark_session.sql(
-        f"""
-        with add_index as (
-            select
-                {original_columns_snippet},
-                monotonically_increasing_id() as {row_number_index_alias}
-            from {view_name}
-        )
-        select
-            {original_columns_snippet},
-            floor({(row_number_index_alias)}/{batch_size}) as {batch_column_alias}
-        from add_index
-        """
-    )
-
-    return batched_spark_df, batch_column_alias
-
-
 @dataclass
 class _SparkSerializedArtifacts:
     """Class to assist with serializing unpicklable artifacts to the spark workers"""
@@ -269,13 +221,23 @@ def unserialize(self):
         return feature_view, online_store, repo_config


-def _process_by_pandas_batch(
-    pdf, spark_serialized_artifacts: _SparkSerializedArtifacts
-):
+def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
     """Load pandas df to online store"""
-    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

-    table = pyarrow.Table.from_pandas(pdf)
+    # convert to pyarrow table
+    dicts = []
+    for row in rows:
+        dicts.append(row.asDict())
+
+    df = pd.DataFrame.from_records(dicts)
+    if df.shape[0] == 0:
+        print("Skipping")
+        return
+
+    table = pyarrow.Table.from_pandas(df)
+
+    # unserialize artifacts
+    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

     if feature_view.batch_source.field_mapping is not None:
         table = _run_pyarrow_field_mapping(
@@ -294,6 +256,3 @@ def _process_by_pandas_batch(
         rows_to_write,
         lambda x: None,
     )
-    pdf["success_flag"] = "SUCCESS"
-
-    return pdf
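
For reference, the conversion chain inside the new _process_by_partition (iterator of Rows -> list of dicts -> pandas DataFrame -> Arrow table) can be exercised without a Spark cluster, since pyspark.sql.Row is a plain class. A hedged, standalone sketch of that logic (rows_to_arrow is a name introduced here for illustration):

import pandas as pd
import pyarrow
from pyspark.sql import Row  # Row needs no SparkSession

def rows_to_arrow(rows):
    # Mirrors the commit's per-partition conversion logic.
    dicts = [row.asDict() for row in rows]
    df = pd.DataFrame.from_records(dicts)
    if df.shape[0] == 0:
        return None  # empty partition: nothing to write
    return pyarrow.Table.from_pandas(df)

table = rows_to_arrow(iter([Row(id=1, val="a"), Row(id=2, val="b")]))
print(table.schema)  # id: int64, val: string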
