@@ -11,8 +11,9 @@ import com.linkedin.feathr.offline.job.FeatureTransformation.FEATURE_NAME_PREFIX
1111import com .linkedin .feathr .offline .join .algorithms .{SeqJoinExplodedJoinKeyColumnAppender , SequentialJoinConditionBuilder , SparkJoinWithJoinCondition }
1212import com .linkedin .feathr .offline .logical .FeatureGroups
1313import com .linkedin .feathr .offline .mvel .plugins .FeathrExpressionExecutionContext
14+ import com .linkedin .feathr .offline .util .FeathrUtils
1415import com .linkedin .feathr .offline .{TestFeathr , TestUtils }
15- import org .apache .spark .SparkException
16+ import org .apache .spark .{ SparkConf , SparkContext , SparkException }
1617import org .apache .spark .sql .functions .{when => _ , _ }
1718import org .apache .spark .sql .types ._
1819import org .apache .spark .sql .{AnalysisException , DataFrame , Row , SparkSession }
@@ -985,6 +986,12 @@ class TestSequentialJoinAsDerivation extends TestFeathr with MockitoSugar {
985986 val mockDerivationFunction = mock[SeqJoinDerivationFunction ]
986987 val mockBaseTaggedDependency = mock[BaseTaggedDependency ]
987988 val mockTaggedDependency = mock[TaggedDependency ]
989+ val mockSparkContext = mock[SparkContext ]
990+ val mockSparkConf = mock[SparkConf ]
991+ when(mockSparkContext.getConf).thenReturn(mockSparkConf)
992+ when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
993+ when(mockSparkConf.get(s " ${FeathrUtils .FEATHR_PARAMS_PREFIX }${FeathrUtils .ADD_DEFAULT_COL_FOR_MISSING_DATA }" , " false" ))
994+ .thenReturn(" false" )
988995 // mock derivation function
989996 when(mockDerivedFeature.derivation.asInstanceOf [SeqJoinDerivationFunction ]).thenReturn(mockDerivationFunction)
990997 when(mockDerivedFeature.producedFeatureNames).thenReturn(Seq (" seqJoinFeature" ))
@@ -1059,6 +1066,13 @@ class TestSequentialJoinAsDerivation extends TestFeathr with MockitoSugar {
10591066 val mockDerivationFunction = mock[SeqJoinDerivationFunction ]
10601067 val mockBaseTaggedDependency = mock[BaseTaggedDependency ]
10611068 val mockTaggedDependency = mock[TaggedDependency ]
1069+ val mockSparkConf = mock[SparkConf ]
1070+ val mockSparkContext = mock[SparkContext ]
1071+ when(mockSparkSession.sparkContext).thenReturn(mockSparkContext)
1072+ when(mockSparkContext.getConf).thenReturn(mockSparkConf)
1073+ when(mockSparkConf.get(s " ${FeathrUtils .FEATHR_PARAMS_PREFIX }${FeathrUtils .ADD_DEFAULT_COL_FOR_MISSING_DATA }" ,
1074+ " false" ))
1075+ .thenReturn(" false" )
10621076 // mock derivation function
10631077 when(mockDerivedFeature.derivation.asInstanceOf [SeqJoinDerivationFunction ]).thenReturn(mockDerivationFunction)
10641078 when(mockDerivedFeature.producedFeatureNames).thenReturn(Seq (" seqJoinFeature" ))
0 commit comments