1- import json
21import os
32import time
43import uuid
@@ -265,7 +264,7 @@ def _stage_file(self, file_path: str, job_id: str) -> str:
265264 return blob_uri_str
266265
267266 def dataproc_submit (
268- self , job_params : SparkJobParameters
267+ self , job_params : SparkJobParameters , extra_properties : Dict [ str , str ]
269268 ) -> Tuple [Job , Callable [[], Job ], Callable [[], None ]]:
270269 local_job_id = str (uuid .uuid4 ())
271270 main_file_uri = self ._stage_file (job_params .get_main_file_path (), local_job_id )
@@ -280,18 +279,22 @@ def dataproc_submit(
280279 job_config ["labels" ][self .JOB_HASH_LABEL_KEY ] = job_params .get_job_hash ()
281280
282281 if job_params .get_class_name ():
282+ properties = {
283+ "spark.yarn.user.classpath.first" : "true" ,
284+ "spark.executor.instances" : self .executor_instances ,
285+ "spark.executor.cores" : self .executor_cores ,
286+ "spark.executor.memory" : self .executor_memory ,
287+ }
288+
289+ properties .update (extra_properties )
290+
283291 job_config .update (
284292 {
285293 "spark_job" : {
286294 "jar_file_uris" : [main_file_uri ] + self .EXTERNAL_JARS ,
287295 "main_class" : job_params .get_class_name (),
288296 "args" : job_params .get_arguments (),
289- "properties" : {
290- "spark.yarn.user.classpath.first" : "true" ,
291- "spark.executor.instances" : self .executor_instances ,
292- "spark.executor.cores" : self .executor_cores ,
293- "spark.executor.memory" : self .executor_memory ,
294- },
297+ "properties" : properties ,
295298 }
296299 }
297300 )
@@ -302,6 +305,7 @@ def dataproc_submit(
302305 "main_python_file_uri" : main_file_uri ,
303306 "jar_file_uris" : self .EXTERNAL_JARS ,
304307 "args" : job_params .get_arguments (),
308+ "properties" : extra_properties if extra_properties else {},
305309 }
306310 }
307311 )
@@ -332,21 +336,23 @@ def dataproc_cancel(self, job_id):
def historical_feature_retrieval(
    self, job_params: RetrievalJobParameters
) -> RetrievalJob:
    """Submit a historical feature retrieval job to Dataproc.

    The retrieval output URI is attached to the Dataproc job as the
    ``dev.feast.outputuri`` property so the destination can later be
    recovered from the job metadata alone.
    """
    destination = job_params.get_destination_path()
    # NOTE(review): assumes get_destination_path() is a pure getter — confirm.
    job, refresh_fn, cancel_fn = self.dataproc_submit(
        job_params, {"dev.feast.outputuri": destination}
    )
    return DataprocRetrievalJob(job, refresh_fn, cancel_fn, destination)
339345
def offline_to_online_ingestion(
    self, ingestion_job_params: BatchIngestionJobParameters
) -> BatchIngestionJob:
    """Submit a batch (offline-to-online) ingestion job to Dataproc.

    No extra Spark properties are required for batch ingestion, so an
    empty mapping is passed through to ``dataproc_submit``.
    """
    submission = self.dataproc_submit(ingestion_job_params, {})
    job, refresh_fn, cancel_fn = submission
    return DataprocBatchIngestionJob(job, refresh_fn, cancel_fn)
345351
def start_stream_to_online_ingestion(
    self, ingestion_job_params: StreamIngestionJobParameters
) -> StreamIngestionJob:
    """Submit a streaming ingestion job to Dataproc.

    The job hash from the parameters is carried into the returned
    ``DataprocStreamingIngestionJob`` so duplicate streaming jobs can be
    detected later. No extra Spark properties are needed.
    """
    job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params, {})
    return DataprocStreamingIngestionJob(
        job, refresh_fn, cancel_fn, ingestion_job_params.get_job_hash()
    )
@@ -368,7 +374,7 @@ def _dataproc_job_to_spark_job(self, job: Job) -> SparkJob:
368374 cancel_fn = partial (self .dataproc_cancel , job_id )
369375
370376 if job_type == SparkJobType .HISTORICAL_RETRIEVAL .name .lower ():
371- output_path = json . loads ( job .pyspark_job .args [ - 1 ])[ "path" ]
377+ output_path = job .pyspark_job .properties . get ( "dev.feast.outputuri" )
372378 return DataprocRetrievalJob (job , refresh_fn , cancel_fn , output_path )
373379
374380 if job_type == SparkJobType .BATCH_INGESTION .name .lower ():
0 commit comments