@@ -167,21 +167,15 @@ def _materialize_one(
             )
         )

-        # serialize feature view using proto
-        feature_view_proto = feature_view.to_proto().SerializeToString()
-
-        # serialize repo_config to disk. Will be used to instantiate the online store
-        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
-        with open(repo_config_file, "wb") as f:
-            dill.dump(self.repo_config, f)
+        spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
+            feature_view=feature_view, repo_config=self.repo_config
+        )

         # split data into batches
         spark_df = offline_job.to_spark_df()
         batch_size = self.repo_config.batch_engine.batch_size
         batched_spark_df, batch_column_alias = _add_batch_column(
             spark_df,
-            join_key_columns=join_key_columns,
-            timestamp_field=timestamp_field,
             batch_size=batch_size,
         )
@@ -191,11 +185,7 @@ def _materialize_one(
         ]
         schema_ddl = ", ".join(schema)
         result = batched_spark_df.groupBy(batch_column_alias).applyInPandas(
-            lambda x: _process_by_pandas_batch(
-                x,
-                feature_view_proto=feature_view_proto,
-                repo_config_file=repo_config_file,
-            ),
+            lambda x: _process_by_pandas_batch(x, spark_serialized_artifacts),
             schema=schema_ddl,
         )
         result.collect()
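
The hunk above leans on Spark's `groupBy(...).applyInPandas(func, schema)` API: each group (here, one batch id) is handed to a plain Python function as a pandas DataFrame on an executor. A minimal, self-contained sketch of that pattern, with hypothetical names standing in for the engine's batches:

```python
from pyspark.sql import SparkSession
import pandas as pd

spark = SparkSession.builder.master("local[*]").getOrCreate()

# toy frame with a precomputed batch column, mirroring batched_spark_df
df = spark.createDataFrame(
    [(0, "a"), (0, "b"), (1, "c")], schema="batch_id int, value string"
)

def write_batch(pdf: pd.DataFrame) -> pd.DataFrame:
    # stand-in for _process_by_pandas_batch: this body runs on an executor
    # and receives one whole batch as a pandas DataFrame
    return pd.DataFrame(
        {"batch_id": [int(pdf["batch_id"].iloc[0])], "rows": [len(pdf)]}
    )

result = df.groupBy("batch_id").applyInPandas(
    write_batch, schema="batch_id int, rows int"
)
result.collect()  # like the engine, collect() just forces the writes to run
```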
@@ -209,9 +199,7 @@ def _materialize_one(
         )


-def _add_batch_column(
-    spark_df: DataFrame, join_key_columns, timestamp_field, batch_size
-):
+def _add_batch_column(spark_df: DataFrame, batch_size):
     """
     Generates a batch column for a data frame
     """
@@ -244,19 +232,48 @@ def _add_batch_column(
     return batched_spark_df, batch_column_alias


-def _process_by_pandas_batch(pdf, feature_view_proto, repo_config_file):
+@dataclass
+class _SparkSerializedArtifacts:
+    """Class to assist with serializing unpicklable artifacts to the spark workers"""

-    # unserialize
-    proto = FeatureViewProto()
-    proto.ParseFromString(feature_view_proto)
-    feature_view = FeatureView.from_proto(proto)
+    feature_view_proto: bytes
+    repo_config_file: str

-    # load
-    with open(repo_config_file, "rb") as f:
-        repo_config = dill.load(f)
+    @classmethod
+    def serialize(cls, feature_view, repo_config):

-    provider = PassthroughProvider(repo_config)
-    online_store = provider.online_store
+        # serialize the feature view to proto
+        feature_view_proto = feature_view.to_proto().SerializeToString()
+
+        # serialize repo_config to disk. Will be used to instantiate the online store
+        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
+        with open(repo_config_file, "wb") as f:
+            dill.dump(repo_config, f)
+
+        return _SparkSerializedArtifacts(
+            feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
+        )
+
+    def unserialize(self):
+        # unserialize the feature view from proto
+        proto = FeatureViewProto()
+        proto.ParseFromString(self.feature_view_proto)
+        feature_view = FeatureView.from_proto(proto)
+
+        # load the repo_config from disk
+        with open(self.repo_config_file, "rb") as f:
+            repo_config = dill.load(f)
+
+        provider = PassthroughProvider(repo_config)
+        online_store = provider.online_store
+        return feature_view, online_store, repo_config
+
+
+def _process_by_pandas_batch(
+    pdf, spark_serialized_artifacts: _SparkSerializedArtifacts
+):
+    """Load pandas df to online store"""
+    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

     table = pyarrow.Table.from_pandas(pdf)

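
Worth noting why `_SparkSerializedArtifacts` uses two mechanisms: protobuf messages carry their own wire format (`SerializeToString`/`ParseFromString`) and need no pickling, while `RepoConfig` may hold members plain pickle rejects, so it goes through dill. A hedged, self-contained round-trip demo of the same pattern with stand-in objects (the `Timestamp` message and dict are illustrative, not Feast types):

```python
import tempfile

import dill
from google.protobuf.timestamp_pb2 import Timestamp  # any generated message works

# driver side: proto -> bytes, config -> dill file
msg = Timestamp(seconds=42)
wire = msg.SerializeToString()
config = {"online_store": {"type": "sqlite"}}  # stand-in for RepoConfig
path = tempfile.NamedTemporaryFile(delete=False).name
with open(path, "wb") as f:
    dill.dump(config, f)

# worker side: bytes -> proto, dill file -> config
restored = Timestamp()
restored.ParseFromString(wire)
with open(path, "rb") as f:
    assert dill.load(f) == config
assert restored.seconds == 42
```

One caveat on the design: handing workers a temp-file *path* assumes executors can read the driver's filesystem, which holds in local mode but not on an arbitrary cluster.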