Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package com.linkedin.feathr.offline.derived

import com.linkedin.feathr.{common, offline}
import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig}
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrException}
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig}
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.plugins.{FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor}
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SeqJoinDerivationFunction}
import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinAsDerivation, SparkUdfDerivation}
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction}
import com.linkedin.feathr.offline.derived.strategies._
import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuilder, SparkJoinWithJoinCondition}
import com.linkedin.feathr.offline.logical.FeatureGroups
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark
import com.linkedin.feathr.{common, offline}
import org.apache.log4j.Logger
import org.apache.spark.sql.{DataFrame, SparkSession}

Expand Down Expand Up @@ -45,6 +45,9 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS
case h: FeatureDerivationFunctionSpark =>
val resultDF = derivationStrategies.customDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, h, mvelContext)
convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF)
case s: SQLFeatureDerivationFunction =>
val resultDF = derivationStrategies.sqlDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, s, mvelContext)
convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF)
case x: FeatureDerivationFunction =>
// We should do the FDS conversion inside the rowBasedDerivationStrategy here. The result of rowBasedDerivationStrategy
// can be NTV FeatureValue or TensorData-based Feature. NTV FeatureValue has fixed FDS schema. However, TensorData
Expand Down Expand Up @@ -118,8 +121,8 @@ private[offline] object DerivedFeatureEvaluator {
val defaultStrategies = strategies.DerivationStrategies(
new SparkUdfDerivation(),
new RowBasedDerivation(featureGroups.allTypeConfigs, mvelContext),
new SequentialJoinAsDerivation(ss, featureGroups, SparkJoinWithJoinCondition(SequentialJoinConditionBuilder), dataPathHandlers)
)
new SequentialJoinAsDerivation(ss, featureGroups, SparkJoinWithJoinCondition(SequentialJoinConditionBuilder), dataPathHandlers),
new SqlDerivationSpark())
new DerivedFeatureEvaluator(defaultStrategies, mvelContext)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.linkedin.feathr.offline.derived.strategies

import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureDerivationFunctionBase}
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.DerivedFeature
import com.linkedin.feathr.offline.derived.functions.{SQLFeatureDerivationFunction, SeqJoinDerivationFunction}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark
import org.apache.spark.sql.DataFrame
Expand Down Expand Up @@ -41,10 +41,17 @@ private[offline] trait RowBasedDerivationStrategy extends DerivationStrategy[Fea
*/
private[offline] trait SequentialJoinDerivationStrategy extends DerivationStrategy[SeqJoinDerivationFunction]

/**
* Implementation should define how a SQL-expression based derivation is evaluated.
*/
private[offline] trait SqlDerivationSparkStrategy extends DerivationStrategy[SQLFeatureDerivationFunction]

/**
* This case class holds the implementations of supported strategies.
*/
private[offline] case class DerivationStrategies(
customDerivationSparkStrategy: SparkUdfDerivationStrategy,
rowBasedDerivationStrategy: RowBasedDerivationStrategy,
sequentialJoinDerivationStrategy: SequentialJoinDerivationStrategy)
sequentialJoinDerivationStrategy: SequentialJoinDerivationStrategy,
sqlDerivationSparkStrategy: SqlDerivationSparkStrategy) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.linkedin.feathr.offline.derived.strategies

import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransformationException}
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.derived.DerivedFeature
import com.linkedin.feathr.offline.derived.functions.SQLFeatureDerivationFunction
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.JavaConverters._

/**
* This class executes SQL-expression based derived feature.
*/
class SqlDerivationSpark extends SqlDerivationSparkStrategy {


/**
* Rewrite sqlExpression for a derived feature, e.g, replace the feature name/argument name with Frame internal dataframe column name
* @param deriveFeature derived feature definition
* @param keyTag list of tags represented by integer
* @param keyTagId2StringMap Map from the tag integer id to the string tag
* @return Rewritten SQL expression
*/
private[offline] def rewriteDerivedFeatureExpression(
deriveFeature: DerivedFeature,
keyTag: Seq[Int],
keyTagId2StringMap: Seq[String]): String = {
if (!deriveFeature.derivation.isInstanceOf[SQLFeatureDerivationFunction]) {
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_ERROR, "Should not rewrite derived feature expression for non-SQLDerivedFeatures")
}
val sqlDerivation = deriveFeature.derivation.asInstanceOf[SQLFeatureDerivationFunction]
val deriveExpr = sqlDerivation.getExpression()
val parameterNames: Seq[String] = sqlDerivation.getParameterNames().getOrElse(Seq[String]())
val consumedFeatureNames = deriveFeature.consumedFeatureNames.zipWithIndex.map {
case (consumeFeatureName, index) =>
// begin of string, or other char except number and alphabet
// val featureStartPattern = """(^|[^a-zA-Z0-9])"""
// end of string, or other char except number and alphabet
// val featureEndPattern = """($|[^a-zA-Z0-9])"""
val namePattern = if (parameterNames.isEmpty) consumeFeatureName.getFeatureName else parameterNames(index)
// getBinding.map(keyTag.get) resolves the call tags
val newName =
if (!consumeFeatureName.getBinding.isEmpty // Passthrough features do not have keyTag
// Feature generation code path does not create columns with tags.
// The check ensures we do not run into IndexOutOfBoundsException when keyTag & keyTagId2StringMap are empty.
&& keyTag.nonEmpty
&& keyTagId2StringMap.nonEmpty) {
DataFrameColName.genFeatureColumnName(
consumeFeatureName.getFeatureName,
Some(consumeFeatureName.getBinding.asScala.map(keyTag(_)).map(keyTagId2StringMap)))
} else {
DataFrameColName.genFeatureColumnName(consumeFeatureName.getFeatureName)
}
(namePattern, newName)
}.toMap

// replace all feature name to column names
// featureName is consist of numAlphabetic
val ss: SparkSession = SparkSession.builder().getOrCreate()
val dependencyFeatures = ss.sessionState.sqlParser.parseExpression(deriveExpr).references.map(_.name).toSeq
// \w is [a-zA-Z0-9_], not inclusion of _ and exclusion of -, as - is ambiguous, e.g, a-b could be a feature name or feature a minus feature b
val rewrittenExpr = dependencyFeatures.foldLeft(deriveExpr)((acc, ca) => {
// in scala \W does not work as ^\w
// "a+B+1".replaceAll("([^\w])B([^\w])", "$1abc$2" = A+abc+1
// "a+B".replaceAll("([^\w])B$", "$1abc" = a+abc
// "B+1".replaceAll("^B([^\w])", "abc$1" = abc+1
// "B".replaceAll("^B$", "abc" = abc
val newVal = consumedFeatureNames.getOrElse(ca, ca)
val patterns = Seq("([^\\w])" + ca + "([^\\w])", "([^\\w])" + ca + "$", "^" + ca + "([^\\w])", "^" + ca + "$")
val replacements = Seq("$1" + newVal + "$2", "$1" + newVal, newVal + "$1", newVal)
val replacedExpr = patterns
.zip(replacements)
.toMap
.foldLeft(acc)((orig, pairs) => {
orig.replaceAll(pairs._1, pairs._2)
})
replacedExpr
})
rewrittenExpr
}

/**
* Apply the derivation strategy.
*
* @param keyTags keyTags for the derived feature.
* @param keyTagList integer keyTag to string keyTag map.
* @param df input DataFrame.
* @param derivedFeature Derived feature metadata.
* @param derivationFunction Derivation function to evaluate the derived feature
* @return output DataFrame with derived feature.
*/
override def apply(keyTags: Seq[Int],
keyTagList: Seq[String],
df: DataFrame,
derivedFeature: DerivedFeature,
derivationFunction: SQLFeatureDerivationFunction,
mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = {
// sql expression based derived feature needs rewrite, e.g, replace the feature names with feature column names in the dataframe
val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList)
val tags = Some(keyTags.map(keyTagList).toList)
val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags)
df.withColumn(featureColumnName, expr(rewrittenExpr))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.linkedin.feathr.common.{Header, JoiningFeatureParams, TaggedFeatureNa
import com.linkedin.feathr.offline
import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource.{getDefaultValues, getFeatureTypes}
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinDerivationStrategy, SparkUdfDerivation}
import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinDerivationStrategy, SparkUdfDerivation, SqlDerivationSpark}
import com.linkedin.feathr.offline.derived.{DerivedFeature, DerivedFeatureEvaluator}
import com.linkedin.feathr.offline.evaluator.DerivedFeatureGenStage
import com.linkedin.feathr.offline.job.{FeatureGenSpec, FeatureTransformation}
Expand Down Expand Up @@ -133,5 +133,7 @@ private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan
ErrorLabel.FEATHR_ERROR,
s"Feature Generation does not support Sequential Join features : ${derivedFeature.producedFeatureNames.head}")
}
}), mvelContext)
},
new SqlDerivationSpark()
), mvelContext)
}
146 changes: 146 additions & 0 deletions src/test/scala/com/linkedin/feathr/offline/DerivationsIntegTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package com.linkedin.feathr.offline

import com.linkedin.feathr.offline.util.FeathrTestUtils.assertDataFrameApproximatelyEquals
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.testng.annotations.Test

class DerivationsIntegTest extends FeathrIntegTest {

/**
* Test multi-key derived feature and multi-tagged feature.
* This test covers the following:-
* -> sql based custom extractor
*/
@Test
def testMultiKeyDerivedFeatureDFWithSQL: Unit = {
val df = runLocalFeatureJoinForTest(
joinConfigAsString = """
| features: [ {
| key: ["concat('',viewer)", viewee]
| featureList: [ "foo_square_distance_sql"]
| } ,
| {
| key: [viewee, viewer]
| featureList: [ "foo_square_distance_sql"]
| },
| {
| key: [viewee, viewer]
| featureList: [ "square_fooFeature_sql"]
| }
| ]
""".stripMargin,
featureDefAsString = """
| anchors: {
| anchor1: {
| source: anchorAndDerivations/derivations/anchor6-source.csv
| key.sqlExpr: [sourceId, destId]
| features: {
| fooFeature: {
| def.sqlExpr: cast(source as int)
| type: NUMERIC
| }
| }
| }
| }
| derivations: {
|
| square_fooFeature_sql: {
| key: [m1, m2]
| inputs: {
| a: { key: [m1, m2], feature: fooFeature }
| }
| definition.sqlExpr: "a * a"
| }
| foo_square_distance_sql: {
| key: [m1, m2]
| inputs: {
| a1: { key: [m1, m2], feature: square_fooFeature_sql }
| a2: { key: [m2, m1], feature: square_fooFeature_sql }
| }
| definition.sqlExpr: "a1 - a2"
| }
| }
""".stripMargin,
observationDataPath = "anchorAndDerivations/derivations/test2-observations.csv")

val expectedDf = ss.createDataFrame(
ss.sparkContext.parallelize(
Seq(
Row(
// viewer
"1",
// viewee
"3",
// label
"1.0",
// square_fooFeature_sql
4.0f,
// viewee_viewer__foo_square_distance_sql
-21.0f,
// concat____viewer__viewee__foo_square_distance_sql
21.0f),
Row(
// viewer
"2",
// viewee
"1",
// label
"-1.0",
// square_fooFeature_sql
9.0f,
// viewee_viewer__foo_square_distance_sql
-27.0f,
// concat____viewer__viewee__foo_square_distance_sql
27.0f),
Row(
// viewer
"3",
// viewee
"6",
// label
"1.0",
// square_fooFeature_sql
null,
// viewee_viewer__foo_square_distance_sql
null,
// concat____viewer__viewee__foo_square_distance_sql
null),
Row(
// viewer
"3",
// viewee
"5",
// label
"-1.0",
// square_fooFeature_sql
null,
// viewee_viewer__foo_square_distance_sql
null,
// concat____viewer__viewee__foo_square_distance_sql
null),
Row(
// viewer
"5",
// viewee
"10",
// label
"1.0",
// square_fooFeature_sql
null,
// viewee_viewer__foo_square_distance_sql
null,
// concat____viewer__viewee__foo_square_distance_sql
null))),
StructType(
List(
StructField("viewer", StringType, true),
StructField("viewee", StringType, true),
StructField("label", StringType, true),
StructField("square_fooFeature_sql", FloatType, true),
StructField("viewee_viewer__foo_square_distance_sql", FloatType, true),
StructField("concat____viewer__viewee__foo_square_distance_sql", FloatType, true))))
def cmpFunc(row: Row): String = if (row.get(0) != null) row.get(0).toString else "null"
assertDataFrameApproximatelyEquals(df.data, expectedDf, cmpFunc)
}
}