Skip to content

Commit 04e89fd

Browse files
rakeshkashyap123Rakesh Kashyap Hanasoge Padmanabha
andauthored
Fix bug in SWA with missing feature data (#1171)
* Fix bug in SWA with missing feature data * remove unwanted code * Address feedback and version bump --------- Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <rkashyap@rkashyap-mn3.linkedin.biz>
1 parent c06699a commit 04e89fd

File tree

3 files changed

+161
-8
lines changed

3 files changed

+161
-8
lines changed

feathr-impl/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
231231

232232
val (joinedDF, header) = doJoinObsAndFeatures(joinConfig, jobContext, obsData)
233233
(joinedDF, header, Map(SuppressedExceptionHandlerUtils.MISSING_DATA_EXCEPTION
234-
-> SuppressedExceptionHandlerUtils.missingFeatures.mkString))
234+
-> SuppressedExceptionHandlerUtils.missingFeatures.mkString(", ")))
235235
}
236236

237237
/**

feathr-impl/src/main/scala/com/linkedin/feathr/offline/swa/SlidingWindowAggregationJoiner.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ private[offline] class SlidingWindowAggregationJoiner(
177177
} else if (originalSourceDf.isEmpty && shouldAddDefaultColForMissingData) { // If add default col for missing data flag features
178178
// flag is set and there is a data related error, an empty dataframe will be returned.
179179
res.map(emptyFeatures.add)
180-
val exceptionMsg = emptyFeatures.mkString
181180
log.warn(s"Missing data for features ${emptyFeatures}. Default values will be populated for this column.")
182181
SuppressedExceptionHandlerUtils.missingFeatures ++= emptyFeatures
183182
anchors.map(anchor => (anchor, originalSourceDf))
@@ -299,7 +298,7 @@ private[offline] class SlidingWindowAggregationJoiner(
299298
substituteDefaults(withFDSFeatureDF, defaults.keys.filter(joinedFeatures.contains).toSeq, defaults, userSpecifiedTypesConfig, ss)
300299

301300
allInferredFeatureTypes ++= inferredTypes
302-
contextDF = standardizeFeatureColumnNames(origContextObsColumns, withFeatureContextDF, joinedFeatures, keyTags.map(keyTagList))
301+
contextDF = standardizeFeatureColumnNames(ss, origContextObsColumns, withFeatureContextDF, joinedFeatures, keyTags.map(keyTagList))
303302
if (shouldCheckPoint(ss)) {
304303
// checkpoint complicated dataframe for each stage to avoid Spark failure
305304
contextDF = contextDF.checkpoint(true)
@@ -325,13 +324,19 @@ private[offline] class SlidingWindowAggregationJoiner(
325324
* @return
326325
*/
327326
def standardizeFeatureColumnNames(
327+
ss: SparkSession,
328328
origContextObsColumns: Seq[String],
329329
withSWAFeatureDF: DataFrame,
330330
featureNames: Seq[String],
331331
keyTags: Seq[String]): DataFrame = {
332332
val inputColumnSize = origContextObsColumns.size
333333
val outputColumnNum = withSWAFeatureDF.columns.size
334-
if (outputColumnNum != inputColumnSize + featureNames.size) {
334+
val shouldAddDefaultColForMissingData = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf,
335+
FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
336+
337+
// Do not perform this check if shouldAddDefaultColForMissingData is true as we add the null values to all SWA features at once,
338+
// and do not care for the SWA groupings.
339+
if (!shouldAddDefaultColForMissingData && (outputColumnNum != inputColumnSize + featureNames.size)) {
335340
throw new FeathrIllegalStateException(
336341
s"Number of columns (${outputColumnNum}) in the dataframe returned by " +
337342
s"sliding window aggregation does not equal to number of columns in the observation data (${inputColumnSize}) " +

feathr-impl/src/test/scala/com/linkedin/feathr/offline/SlidingWindowAggIntegTest.scala

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import com.linkedin.feathr.offline.transformation.MultiLevelAggregationTransform
55
import com.linkedin.feathr.offline.util.FeathrUtils
66
import com.linkedin.feathr.offline.util.FeathrUtils.{FILTER_NULLS, SKIP_MISSING_FEATURE, setFeathrJobParam}
77
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
8-
import org.apache.spark.sql.functions._
98
import org.apache.spark.sql.types.{LongType, StructField, StructType}
109
import org.apache.spark.sql.{DataFrame, Row}
1110
import org.testng.Assert._
@@ -19,10 +18,7 @@ import scala.collection.mutable
1918

2019

2120
import org.apache.spark.sql.SparkSession
22-
import org.apache.spark.sql.functions._
2321
import org.apache.spark.sql.types.{StringType, TimestampType}
24-
25-
import scala.concurrent.duration._
2622
class SlidingWindowAggIntegTest extends FeathrIntegTest {
2723

2824
def getDf(): DataFrame = {
@@ -259,6 +255,157 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
259255
assertEquals(row1f1f1, TestUtils.build1dSparseTensorFDSRow(Array("f1t1"), Array(12.0f)))
260256
}
261257

258+
/**
259+
* test SWA with lateralview parameters and ADD_DEFAULT_COL_FOR_MISSING_DATA flag set
260+
*/
261+
@Test
262+
def testLocalAnchorSWATestWithDataMissingFlagSet: Unit = {
263+
setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "true")
264+
val df = runLocalFeatureJoinForTest(
265+
joinConfigAsString =
266+
"""
267+
| settings: {
268+
| observationDataTimeSettings: {
269+
| absoluteTimeRange: {
270+
| startTime: "2018-05-01"
271+
| endTime: "2018-05-03"
272+
| timeFormat: "yyyy-MM-dd"
273+
| }
274+
| }
275+
| joinTimeSettings: {
276+
| timestampColumn: {
277+
| def: timestamp
278+
| format: "yyyy-MM-dd"
279+
| }
280+
| }
281+
|}
282+
|
283+
|features: [
284+
| {
285+
| key: [x],
286+
| featureList: ["f1", "f1Sum", "f2", "f1f1"]
287+
| },
288+
| {
289+
| key: [x, y]
290+
| featureList: ["f3", "f4"]
291+
| }
292+
|]
293+
""".stripMargin,
294+
featureDefAsString =
295+
"""
296+
|sources: {
297+
| ptSource: {
298+
| type: "PASSTHROUGH"
299+
| }
300+
| swaSource: {
301+
| location: { path: "missingData/localSWAAnchorTestFeatureData/daily" }
302+
| timePartitionPattern: "yyyy/MM/dd"
303+
| timeWindowParameters: {
304+
| timestampColumn: "timestamp"
305+
| timestampColumnFormat: "yyyy-MM-dd"
306+
| }
307+
| }
308+
|}
309+
|
310+
|anchors: {
311+
| ptAnchor: {
312+
| source: "ptSource"
313+
| key: "x"
314+
| features: {
315+
| f1f1: {
316+
| def: "([$.term:$.value] in passthroughFeatures if $.name == 'f1f1')"
317+
| }
318+
| }
319+
| }
320+
| swaAnchor: {
321+
| source: "swaSource"
322+
| key: "substring(x, 0)"
323+
| lateralViewParameters: {
324+
| lateralViewDef: explode(features)
325+
| lateralViewItemAlias: feature
326+
| }
327+
| features: {
328+
| f1: {
329+
| def: "feature.col.value"
330+
| filter: "feature.col.name = 'f1'"
331+
| aggregation: SUM
332+
| groupBy: "feature.col.term"
333+
| window: 3d
334+
| }
335+
| }
336+
| }
337+
|
338+
| swaAnchor2: {
339+
| source: "swaSource"
340+
| key: "x"
341+
| lateralViewParameters: {
342+
| lateralViewDef: explode(features)
343+
| lateralViewItemAlias: feature
344+
| }
345+
| features: {
346+
| f1Sum: {
347+
| def: "feature.col.value"
348+
| filter: "feature.col.name = 'f1'"
349+
| aggregation: SUM
350+
| groupBy: "feature.col.term"
351+
| window: 3d
352+
| }
353+
| }
354+
| }
355+
| swaAnchorWithKeyExtractor: {
356+
| source: "swaSource"
357+
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor"
358+
| features: {
359+
| f3: {
360+
| def: "aggregationWindow"
361+
| aggregation: SUM
362+
| window: 3d
363+
| }
364+
| }
365+
| }
366+
| swaAnchorWithKeyExtractor2: {
367+
| source: "swaSource"
368+
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor"
369+
| features: {
370+
| f4: {
371+
| def: "aggregationWindow"
372+
| aggregation: SUM
373+
| window: 3d
374+
| }
375+
| }
376+
| }
377+
| swaAnchorWithKeyExtractor3: {
378+
| source: "swaSource"
379+
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor2"
380+
| lateralViewParameters: {
381+
| lateralViewDef: explode(features)
382+
| lateralViewItemAlias: feature
383+
| }
384+
| features: {
385+
| f2: {
386+
| def: "feature.col.value"
387+
| filter: "feature.col.name = 'f2'"
388+
| aggregation: SUM
389+
| groupBy: "feature.col.term"
390+
| window: 3d
391+
| }
392+
| }
393+
| }
394+
|}
395+
""".stripMargin,
396+
"slidingWindowAgg/localAnchorTestObsData.avro.json").data
397+
df.show()
398+
399+
// validate output in name term value format
400+
val featureList = df.collect().sortBy(row => if (row.get(0) != null) row.getAs[String]("x") else "null")
401+
val row0 = featureList(0)
402+
val row0f1 = row0.getAs[Row]("f1")
403+
assertEquals(row0f1, null)
404+
val row0f2 = row0.getAs[Row]("f2")
405+
assertEquals(row0f2, null)
406+
setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
407+
}
408+
262409
/**
263410
* test SWA with lateralview parameters
264411
*/
@@ -869,6 +1016,7 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
8691016
| timestampColumnFormat: "yyyy-MM-dd"
8701017
| }
8711018
| }
1019+
|
8721020
|}
8731021
|
8741022
|anchors: {

0 commit comments

Comments
 (0)