Skip to content

Commit c239d22

Browse files
Authored by rakeshkashyap123 (Rakesh Kashyap Hanasoge Padmanabha) and jaymo001
Add default column for missing features (#1158)
* Add default column for missing features
* Fix failing test
* Fix SWA sparksession issue
* Address comments
* Add comment
* Bump version

Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <rkashyap@rkashyap-mn3.linkedin.biz>
Co-authored-by: Jinghui Mo <jmo@linkedin.com>
1 parent f0cb5d2 commit c239d22

7 files changed

Lines changed: 244 additions & 37 deletions

File tree

feathr-impl/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package com.linkedin.feathr.offline.client
22

33
import com.linkedin.feathr.common.exception._
4-
import com.linkedin.feathr.common.{FeatureInfo, Header, InternalApi, JoiningFeatureParams, RichConfig, TaggedFeatureName}
4+
import com.linkedin.feathr.common.{FeatureInfo, FeatureTypeConfig, Header, InternalApi, JoiningFeatureParams, RichConfig, TaggedFeatureName}
55
import com.linkedin.feathr.offline.config.sources.FeatureGroupsUpdater
66
import com.linkedin.feathr.offline.config.{FeathrConfig, FeathrConfigLoader, FeatureGroupsGenerator, FeatureJoinConfig}
77
import com.linkedin.feathr.offline.generation.{DataFrameFeatureGenerator, FeatureGenKeyTagAnalyzer, StreamingFeatureGenerator}
@@ -12,8 +12,10 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
1212
import com.linkedin.feathr.offline.source.DataSource
1313
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
1414
import com.linkedin.feathr.offline.swa.SWAHandler
15+
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
1516
import com.linkedin.feathr.offline.util._
1617
import org.apache.logging.log4j.LogManager
18+
import org.apache.spark.sql.functions.lit
1719
import org.apache.spark.sql.internal.SQLConf
1820
import org.apache.spark.sql.{DataFrame, SparkSession}
1921

@@ -310,6 +312,8 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
310312

311313
var logicalPlan = logicalPlanner.getLogicalPlan(updatedFeatureGroups, keyTaggedFeatures)
312314
val shouldSkipFeature = FeathrUtils.getFeathrJobParam(sparkSession.sparkContext.getConf, FeathrUtils.SKIP_MISSING_FEATURE).toBoolean
315+
val shouldAddDefault = FeathrUtils.getFeathrJobParam(sparkSession.sparkContext.getConf, FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
316+
var leftRenamed = left
313317
val featureToPathsMap = (for {
314318
requiredFeature <- logicalPlan.allRequiredFeatures
315319
featureAnchorWithSource <- allAnchoredFeatures.get(requiredFeature.getFeatureName)
@@ -323,6 +327,8 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
323327
val featureGroupsWithoutInvalidFeatures = FeatureGroupsUpdater()
324328
.getUpdatedFeatureGroupsWithoutInvalidPaths(featureToPathsMap, updatedFeatureGroups, featurePathsTest._2)
325329
logicalPlanner.getLogicalPlan(featureGroupsWithoutInvalidFeatures, keyTaggedFeatures)
330+
} else if (shouldAddDefault) {
331+
// dont throw error if this flag is set, the missing data will be handled at a later step.
326332
} else {
327333
throw new FeathrInputDataException(
328334
ErrorLabel.FEATHR_USER_ERROR,
@@ -358,7 +364,6 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
358364
val renameFeatures = conflictsAutoCorrectionSetting.get.renameFeatureList
359365
val suffix = conflictsAutoCorrectionSetting.get.suffix
360366
log.warn(s"Found conflicted field names: ${conflictFeatureNames}. Will auto correct them by applying suffix: ${suffix}")
361-
var leftRenamed = left
362367
conflictFeatureNames.foreach(name => {
363368
leftRenamed = leftRenamed.withColumnRenamed(name, name+'_'+suffix)
364369
})
@@ -369,7 +374,8 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
369374
s"Failed to apply auto correction to solve conflicts. Still got conflicts after applying provided suffix ${suffix} for fields: " +
370375
s"${conflictFeatureNames}. Please provide another suffix or solve conflicts manually.")
371376
}
372-
val (df, header) = joiner.joinFeaturesAsDF(sparkSession, joinConfig, updatedFeatureGroups, keyTaggedFeatures, leftRenamed, rowBloomFilterThreshold)
377+
val (df, header) = joiner.joinFeaturesAsDF(sparkSession, joinConfig, updatedFeatureGroups, keyTaggedFeatures, leftRenamed,
378+
rowBloomFilterThreshold)
373379
if(renameFeatures) {
374380
log.warn(s"Suffix :${suffix} is applied into feature names: ${conflictFeatureNames} to avoid conflicts in outputs")
375381
renameFeatureNames(df, header, conflictFeatureNames, suffix)

feathr-impl/src/main/scala/com/linkedin/feathr/offline/join/workflow/AnchoredFeatureJoinStep.scala

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package com.linkedin.feathr.offline.join.workflow
22

33
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureJoinException}
4-
import com.linkedin.feathr.common.{ErasedEntityTaggedFeature, FeatureTypeConfig}
4+
import com.linkedin.feathr.common.{ErasedEntityTaggedFeature, FeatureTypeConfig, FeatureTypes}
55
import com.linkedin.feathr.offline
66
import com.linkedin.feathr.offline.FeatureDataFrame
77
import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource
@@ -12,15 +12,16 @@ import com.linkedin.feathr.offline.job.KeyedTransformedResult
1212
import com.linkedin.feathr.offline.join._
1313
import com.linkedin.feathr.offline.join.algorithms._
1414
import com.linkedin.feathr.offline.join.util.FrequentItemEstimatorFactory
15+
import com.linkedin.feathr.offline.logical.{LogicalPlan, MultiStageJoinPlan}
1516
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
1617
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
1718
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
1819
import com.linkedin.feathr.offline.transformation.DataFrameExt._
19-
import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils}
20+
import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils, FeaturizedDatasetUtils}
2021
import com.linkedin.feathr.offline.util.FeathrUtils.shouldCheckPoint
2122
import org.apache.logging.log4j.LogManager
2223
import org.apache.spark.sql.{DataFrame, SparkSession}
23-
import org.apache.spark.sql.functions.lit
24+
import org.apache.spark.sql.functions.{col, lit}
2425

2526
/**
2627
* An abstract class provides default implementation of anchored feature join step
@@ -39,8 +40,64 @@ private[offline] class AnchoredFeatureJoinStep(
3940
extends FeatureJoinStep[AnchorJoinStepInput, DataFrameJoinStepOutput] {
4041
@transient lazy val log = LogManager.getLogger(getClass.getName)
4142

43+
/**
44+
* When the add.default.col.for.missing.data flag is turned, some features could be skipped because of missing data.
45+
* For such anchored features, we will add a feature column with a configured default value (if present in the feature anchor) or
46+
* a null value column.
47+
* @param sparkSession spark session
48+
* @param dataframe the original observation dataframe
49+
* @param logicalPlan logical plan generated using the join config
50+
* @param missingFeatures Map of missing feature names to the corresponding featureAnchorWithSource object.
51+
* @return Dataframe with the missing feature columns added
52+
*/
53+
def substituteDefaultsForDataMissingFeatures(sparkSession: SparkSession, dataframe: DataFrame, logicalPlan: MultiStageJoinPlan,
54+
missingFeatures: Map[String, FeatureAnchorWithSource]): DataFrame = {
55+
// Create a map of feature name to corresponding defaults. If a feature does not have default value, it would be missing
56+
// from this map and we would add a default column of nulls for those features.
57+
val defaults = missingFeatures.flatMap(s => s._2.featureAnchor.defaults)
58+
59+
// Create a map of feature to their feature type if configured.
60+
val featureTypes = missingFeatures
61+
.map(x => Some(x._2.featureAnchor.featureTypeConfigs))
62+
.foldLeft(Map.empty[String, FeatureTypeConfig])((a, b) => a ++ b.getOrElse(Map.empty[String, FeatureTypeConfig]))
63+
64+
// We try to guess the column data type from the configured feature type. If feature type is not present, we will default to
65+
// default feathr behavior of returning a map column of string to float.
66+
val obsDfWithDefaultNullColumn = missingFeatures.keys.foldLeft(dataframe) { (observationDF, featureName) =>
67+
val featureColumnType = if (featureTypes.contains(featureName)) {
68+
featureTypes(featureName).getFeatureType match {
69+
case FeatureTypes.NUMERIC => "float"
70+
case FeatureTypes.BOOLEAN => "boolean"
71+
case FeatureTypes.DENSE_VECTOR => "array<float>"
72+
case FeatureTypes.CATEGORICAL => "string"
73+
case FeatureTypes.CATEGORICAL_SET => "array<string>"
74+
case FeatureTypes.TERM_VECTOR => "map<string,float>"
75+
case FeatureTypes.UNSPECIFIED => "map<string,float>"
76+
case _ => "map<string,float>"
77+
}
78+
} else { // feature type is not configured
79+
"map<string,float>"
80+
}
81+
observationDF.withColumn(DataFrameColName.genFeatureColumnName(FEATURE_NAME_PREFIX + featureName), lit(null).cast(featureColumnType))
82+
}
83+
84+
val dataframeWithDefaults = substituteDefaults(obsDfWithDefaultNullColumn, missingFeatures.keys.toSeq, defaults, featureTypes,
85+
sparkSession, (s: String) => s"${FEATURE_NAME_PREFIX}$s")
86+
87+
// We want to duplicate this column with the correct feathr supported feature name which is required for further processing.
88+
// For example, if feature name is abc and the corresponding key is x, the column name would be __feathr_feature_abc_x.
89+
// This column will be dropped after all the joins are done.
90+
missingFeatures.keys.foldLeft(dataframeWithDefaults) { (dataframeWithDefaults, featureName) =>
91+
val keyTags = logicalPlan.joinStages.filter(kv => kv._2.contains(featureName)).head._1
92+
val keyStr = keyTags.map(logicalPlan.keyTagIntsToStrings).toList
93+
dataframeWithDefaults.withColumn(DataFrameColName.genFeatureColumnName(FEATURE_NAME_PREFIX + featureName, Some(keyStr)),
94+
col(DataFrameColName.genFeatureColumnName(FEATURE_NAME_PREFIX + featureName)))
95+
}
96+
}
97+
4298
/**
4399
* Join anchored features to the observation passed as part of the input context.
100+
*
44101
* @param features Non-window aggregation, basic anchored features.
45102
* @param input input context for this step.
46103
* @param ctx environment variable that contains join job execution context.
@@ -49,10 +106,22 @@ private[offline] class AnchoredFeatureJoinStep(
49106
override def joinFeatures(features: Seq[ErasedEntityTaggedFeature], input: AnchorJoinStepInput)(
50107
implicit ctx: JoinExecutionContext): FeatureDataFrameOutput = {
51108
val AnchorJoinStepInput(observationDF, anchorDFMap) = input
109+
val shouldAddDefault = FeathrUtils.getFeathrJobParam(ctx.sparkSession.sparkContext.getConf,
110+
FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
111+
val withMissingFeaturesSubstituted = if (shouldAddDefault) {
112+
val missingFeatures = features.map(x => x.getFeatureName).filter(x => {
113+
val containsFeature: Seq[Boolean] = anchorDFMap.map(y => y._1.selectedFeatures.contains(x)).toSeq
114+
containsFeature.contains(false)
115+
})
116+
val missingAnchoredFeatures = ctx.featureGroups.allAnchoredFeatures.filter(featureName => missingFeatures.contains(featureName._1))
117+
substituteDefaultsForDataMissingFeatures(ctx.sparkSession, observationDF, ctx.logicalPlan,
118+
missingAnchoredFeatures)
119+
}else observationDF
120+
52121
val allAnchoredFeatures: Map[String, FeatureAnchorWithSource] = ctx.featureGroups.allAnchoredFeatures
53122
val joinStages = ctx.logicalPlan.joinStages
54123
val joinOutput = joinStages
55-
.foldLeft(FeatureDataFrame(observationDF, Map.empty[String, FeatureTypeConfig]))((accFeatureDataFrame, joinStage) => {
124+
.foldLeft(FeatureDataFrame(withMissingFeaturesSubstituted, Map.empty[String, FeatureTypeConfig]))((accFeatureDataFrame, joinStage) => {
56125
val (keyTags: Seq[Int], featureNames: Seq[String]) = joinStage
57126
val FeatureDataFrame(contextDF, inferredFeatureTypeMap) = accFeatureDataFrame
58127
// map feature name to its transformed dataframe and the join key of the dataframe

feathr-impl/src/main/scala/com/linkedin/feathr/offline/transformation/AnchorToDataSourceMapper.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
3737
requiredFeatureAnchors: Seq[FeatureAnchorWithSource],
3838
failOnMissingPartition: Boolean): Map[FeatureAnchorWithSource, Option[DataSourceAccessor]] = {
3939
val shouldSkipFeature = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.SKIP_MISSING_FEATURE).toBoolean
40+
val shouldAddDefaultCol = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
41+
4042
// get a Map from each source to a list of all anchors based on this source
4143
val sourceToAnchor = requiredFeatureAnchors
4244
.map(anchor => (anchor.source, anchor))
@@ -74,15 +76,15 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
7476

7577
// If it is a nonTime based source, we will load the dataframe at runtime, however this is too late to decide if the
7678
// feature should be skipped. So, we will try to take the first row here, and see if it succeeds.
77-
if (dataSource.isInstanceOf[NonTimeBasedDataSourceAccessor] && shouldSkipFeature) {
79+
if (dataSource.isInstanceOf[NonTimeBasedDataSourceAccessor] && (shouldSkipFeature || shouldAddDefaultCol)) {
7880
if (dataSource.get().take(1).isEmpty) None else {
7981
Some(dataSource)
8082
}
8183
} else {
8284
Some(dataSource)
8385
}
8486
} catch {
85-
case e: Exception => if (shouldSkipFeature) None else throw e
87+
case e: Exception => if (shouldSkipFeature || shouldAddDefaultCol) None else throw e
8688
}
8789

8890
anchorsWithDate.map(anchor => (anchor, timeSeriesSource))

feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala

Lines changed: 149 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ import com.linkedin.feathr.offline.generation.SparkIOUtils
77
import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager
88
import com.linkedin.feathr.offline.source.dataloader.{AvroJsonDataLoader, CsvDataLoader}
99
import com.linkedin.feathr.offline.util.FeathrTestUtils
10-
import com.linkedin.feathr.offline.util.FeathrUtils.{SKIP_MISSING_FEATURE, setFeathrJobParam}
10+
import com.linkedin.feathr.offline.util.FeathrUtils.{ADD_DEFAULT_COL_FOR_MISSING_DATA, SKIP_MISSING_FEATURE, setFeathrJobParam}
1111
import org.apache.spark.sql.Row
1212
import org.apache.spark.sql.functions.col
1313
import org.apache.spark.sql.types._
14-
import org.testng.Assert.assertTrue
14+
import org.testng.Assert.{assertEquals, assertTrue}
1515
import org.testng.annotations.{BeforeClass, Test}
1616

1717
import scala.collection.mutable
@@ -386,6 +386,153 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
386386
setFeathrJobParam(SKIP_MISSING_FEATURE, "false")
387387
}
388388

389+
/*
390+
* Test skipping combination of anchored, derived and swa features. Also, test it with different default value types.
391+
*/
392+
@Test
393+
def testAddDefaultForMissingAnchoredFeatures: Unit = {
394+
setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "true")
395+
val df = runLocalFeatureJoinForTest(
396+
joinConfigAsString =
397+
"""
398+
|settings: {
399+
| joinTimeSettings: {
400+
| timestampColumn: {
401+
| def: "timestamp"
402+
| format: "yyyy-MM-dd"
403+
| }
404+
| simulateTimeDelay: 1d
405+
| }
406+
|}
407+
|
408+
| features: {
409+
| key: a_id
410+
| featureList: ["featureWithNull", "derived_featureWithNull", "featureWithNull2", "featureWithNull3", "featureWithNull4",
411+
| "featureWithNull5", "derived_featureWithNull2", "featureWithNull6", "featureWithNull7", "derived_featureWithNull7"
412+
| "aEmbedding", "memberEmbeddingAutoTZ"]
413+
| }
414+
""".stripMargin,
415+
featureDefAsString =
416+
"""
417+
| sources: {
418+
| swaSource: {
419+
| location: { path: "generaion/daily" }
420+
| timePartitionPattern: "yyyy/MM/dd"
421+
| timeWindowParameters: {
422+
| timestampColumn: "timestamp"
423+
| timestampColumnFormat: "yyyy-MM-dd"
424+
| }
425+
| }
426+
| swaSource1: {
427+
| location: { path: "generation/daily" }
428+
| timePartitionPattern: "yyyy/MM/dd"
429+
| timeWindowParameters: {
430+
| timestampColumn: "timestamp"
431+
| timestampColumnFormat: "yyyy-MM-dd"
432+
| }
433+
| }
434+
|}
435+
|
436+
| anchors: {
437+
| anchor1: {
438+
| source: "anchorAndDerivations/nullVaueSource.avro.json"
439+
| key: "toUpperCaseExt(mId)"
440+
| features: {
441+
| featureWithNull: {
442+
| def: "isPresent(value) ? toNumeric(value) : 0"
443+
| type: NUMERIC
444+
| default: -1
445+
| }
446+
| featureWithNull3: {
447+
| def: "isPresent(value) ? toNumeric(value) : 0"
448+
| type: CATEGORICAL
449+
| default: "null"
450+
| }
451+
| featureWithNull7: {
452+
| def: "isPresent(value) ? toNumeric(value) : 0"
453+
| }
454+
| featureWithNull4: {
455+
| def: "isPresent(value) ? toNumeric(value) : 0"
456+
| type: TERM_VECTOR
457+
| default: {}
458+
| }
459+
| featureWithNull6: {
460+
| def: "isPresent(value) ? toNumeric(value) : 0"
461+
| type: DENSE_VECTOR
462+
| default: [1, 2, 3]
463+
| }
464+
| featureWithNull5: {
465+
| def: "isPresent(value) ? toNumeric(value) : 0"
466+
| default: 1
467+
| }
468+
| }
469+
| }
470+
|
471+
| anchor2: {
472+
| source: "anchorAndDerivations/nullValueSource.avro.json"
473+
| key: "toUpperCaseExt(mId)"
474+
| features: {
475+
| featureWithNull2: "isPresent(value) ? toNumeric(value) : 0"
476+
| }
477+
| }
478+
| swaAnchor: {
479+
| source: "swaSource"
480+
| key: "x"
481+
| features: {
482+
| aEmbedding: {
483+
| def: "embedding"
484+
| aggregation: LATEST
485+
| window: 3d
486+
| default: 2
487+
| }
488+
| }
489+
| }
490+
| swaAnchor1: {
491+
| source: "swaSource1"
492+
| key: "x"
493+
| features: {
494+
| memberEmbeddingAutoTZ: {
495+
| def: "embedding"
496+
| aggregation: LATEST
497+
| window: 3d
498+
| type: {
499+
| type: TENSOR
500+
| tensorCategory: SPARSE
501+
| dimensionType: [INT]
502+
| valType: FLOAT
503+
| }
504+
| }
505+
| }
506+
| }
507+
|}
508+
|derivations: {
509+
|
510+
| derived_featureWithNull: "featureWithNull * 2"
511+
| derived_featureWithNull2: "featureWithNull2 * 2"
512+
| derived_featureWithNull7: "featureWithNull7 * 2"
513+
|}
514+
""".stripMargin,
515+
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv")
516+
517+
df.data.show()
518+
val featureList = df.data.collect().sortBy(row => if (row.get(0) != null) row.getAs[String]("a_id") else "null")
519+
assertEquals(featureList(0).getAs[Row]("aEmbedding"),
520+
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(2.0f))))
521+
assertEquals(featureList(0).getAs[Row]("featureWithNull3"), "null")
522+
assertEquals(featureList(0).getAs[Row]("featureWithNull5"), mutable.Map("" -> 1.0f))
523+
assertEquals(featureList(0).getAs[Row]("featureWithNull7"), null)
524+
assertEquals(featureList(0).getAs[Row]("featureWithNull"),-1.0f)
525+
assertEquals(featureList(0).getAs[Row]("featureWithNull4"),Map())
526+
assertEquals(featureList(0).getAs[Row]("featureWithNull2"),1.0f)
527+
assertEquals(featureList(0).getAs[Row]("derived_featureWithNull"),
528+
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(-2.0f))))
529+
assertEquals(featureList(0).getAs[Row]("derived_featureWithNull7"),
530+
Row(mutable.WrappedArray.make(Array()), mutable.WrappedArray.empty))
531+
assertEquals(featureList(0).getAs[Row]("derived_featureWithNull2"),
532+
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(2.0f))))
533+
setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
534+
}
535+
389536
/*
390537
* Test features with fdsExtract.
391538
*/

0 commit comments

Comments (0)