@@ -1,12 +1,11 @@
 import tempfile
-import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Callable, List, Literal, Optional, Sequence, Union
 
 import dill
+import pandas as pd
 import pyarrow
-from pyspark.sql import DataFrame
 from tqdm import tqdm
 
 from feast.batch_feature_view import BatchFeatureView
@@ -171,25 +170,11 @@ def _materialize_one(
             feature_view=feature_view, repo_config=self.repo_config
         )
 
-        # split data into batches
         spark_df = offline_job.to_spark_df()
-        batch_size = self.repo_config.batch_engine.batch_size
-        batched_spark_df, batch_column_alias = _add_batch_column(
-            spark_df,
-            batch_size=batch_size,
+        spark_df.foreachPartition(
+            lambda x: _process_by_partition(x, spark_serialized_artifacts)
         )
 
-        schema = [
-            f"{x} {y}"
-            for x, y in batched_spark_df.dtypes + [("success_flag", "string")]
-        ]
-        schema_ddl = ", ".join(schema)
-        result = batched_spark_df.groupBy(batch_column_alias).applyInPandas(
-            lambda x: _process_by_pandas_batch(x, spark_serialized_artifacts),
-            schema=schema_ddl,
-        )
-        result.collect()
-
         return SparkMaterializationJob(
             job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
         )
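Note: `foreachPartition` replaces the explicit batch-splitting plus `applyInPandas` path. It is an action that runs the supplied function on the executors, handing it each partition's rows as an iterator, and returns nothing to the driver. A minimal sketch of the pattern (the session, columns, and `write_rows` helper are illustrative stand-ins, not code from this PR):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

    def write_rows(rows):
        # `rows` is an iterator of pyspark.sql.Row, consumed on the executor
        for row in rows:
            print(row.asDict())

    df.foreachPartition(write_rows)  # action: triggers the job, returns None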
@@ -199,39 +184,6 @@ def _materialize_one(
         )
 
 
-def _add_batch_column(spark_df: DataFrame, batch_size):
-    """
-    Generates a batch column for a data frame
-    """
-    spark_session = spark_df.sparkSession
-
-    # generate a unique name for the view
-    view_name = f"{uuid.uuid4()}".replace("-", "")
-
-    row_number_index_alias = f"{view_name}_row_index"
-    batch_column_alias = f"{view_name}_batch"
-    original_columns_snippet = ", ".join(spark_df.columns)
-
-    # generate batch
-    spark_df.createOrReplaceTempView(view_name)
-    batched_spark_df = spark_session.sql(
-        f"""
-        with add_index as (
-            select
-                {original_columns_snippet},
-                monotonically_increasing_id() as {row_number_index_alias}
-            from {view_name}
-        )
-        select
-            {original_columns_snippet},
-            floor({row_number_index_alias}/{batch_size}) as {batch_column_alias}
-        from add_index
-        """
-    )
-
-    return batched_spark_df, batch_column_alias
-
-
 @dataclass
 class _SparkSerializedArtifacts:
     """Class to assist with serializing unpicklable artifacts to the spark workers"""
@@ -269,13 +221,23 @@ def unserialize(self):
         return feature_view, online_store, repo_config
 
 
-def _process_by_pandas_batch(
-    pdf, spark_serialized_artifacts: _SparkSerializedArtifacts
-):
+def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
     """Load pandas df to online store"""
-    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
 
-    table = pyarrow.Table.from_pandas(pdf)
+    # convert to pyarrow table
+    dicts = []
+    for row in rows:
+        dicts.append(row.asDict())
+
+    df = pd.DataFrame.from_records(dicts)
+    if df.shape[0] == 0:
+        print("Skipping")
+        return
+
+    table = pyarrow.Table.from_pandas(df)
+
+    # unserialize artifacts
+    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
 
     if feature_view.batch_source.field_mapping is not None:
         table = _run_pyarrow_field_mapping(
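`_process_by_partition` now receives a plain iterator of `Row` objects rather than a pandas DataFrame, hence the explicit conversion and the empty-partition guard before building the Arrow table. The conversion step in isolation (the sample rows are illustrative):

    import pandas as pd
    import pyarrow
    from pyspark.sql import Row

    rows = iter([Row(id=1, value="a"), Row(id=2, value="b")])
    df = pd.DataFrame.from_records(r.asDict() for r in rows)
    if not df.empty:  # empty partitions produce an empty frame
        table = pyarrow.Table.from_pandas(df)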
@@ -294,6 +256,3 @@ def _process_by_pandas_batch(
         rows_to_write,
         lambda x: None,
     )
-    pdf["success_flag"] = "SUCCESS"
-
-    return pdf
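With `foreachPartition` there is no result DataFrame to collect, so the per-batch `success_flag` column and the return value are dropped; a failure in any partition now surfaces as a failed Spark task rather than a flag row. A small sketch of that contract (assuming a local session):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    # foreachPartition is an action and returns None; there are no flags to collect
    assert spark.range(4).foreachPartition(lambda rows: None) is None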