From 90e6d34d0e8a8a24b7d59d83d30127af5de35ba5 Mon Sep 17 00:00:00 2001 From: Jinghui Mo Date: Tue, 1 Nov 2022 15:51:41 -0400 Subject: [PATCH] Fix sql-based derived feature --- .../derived/DerivedFeatureEvaluator.scala | 19 ++- .../strategies/DerivationStrategies.scala | 11 +- .../strategies/SqlDerivationSpark.scala | 107 +++++++++++++ .../DataFrameFeatureGenerator.scala | 6 +- .../feathr/offline/DerivationsIntegTest.scala | 146 ++++++++++++++++++ 5 files changed, 277 insertions(+), 12 deletions(-) create mode 100644 src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala create mode 100644 src/test/scala/com/linkedin/feathr/offline/DerivationsIntegTest.scala diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala b/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala index ff16ebe18..59dd8ea8e 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala @@ -1,19 +1,19 @@ package com.linkedin.feathr.offline.derived -import com.linkedin.feathr.{common, offline} -import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig} import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrException} -import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame} +import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig} import com.linkedin.feathr.offline.client.DataFrameColName import com.linkedin.feathr.offline.client.plugins.{FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor} -import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SeqJoinDerivationFunction} -import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinAsDerivation, SparkUdfDerivation} +import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction} +import com.linkedin.feathr.offline.derived.strategies._ import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuilder, SparkJoinWithJoinCondition} import com.linkedin.feathr.offline.logical.FeatureGroups import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext -import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils import com.linkedin.feathr.offline.source.accessor.DataPathHandler +import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils +import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame} import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark +import com.linkedin.feathr.{common, offline} import org.apache.log4j.Logger import org.apache.spark.sql.{DataFrame, SparkSession} @@ -45,6 +45,9 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS case h: FeatureDerivationFunctionSpark => val resultDF = derivationStrategies.customDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, h, mvelContext) convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF) + case s: SQLFeatureDerivationFunction => + val resultDF = derivationStrategies.sqlDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, s, mvelContext) + convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF) case x: FeatureDerivationFunction => // We should do the FDS conversion inside the rowBasedDerivationStrategy here. The result of rowBasedDerivationStrategy // can be NTV FeatureValue or TensorData-based Feature. NTV FeatureValue has fixed FDS schema. However, TensorData @@ -118,8 +121,8 @@ private[offline] object DerivedFeatureEvaluator { val defaultStrategies = strategies.DerivationStrategies( new SparkUdfDerivation(), new RowBasedDerivation(featureGroups.allTypeConfigs, mvelContext), - new SequentialJoinAsDerivation(ss, featureGroups, SparkJoinWithJoinCondition(SequentialJoinConditionBuilder), dataPathHandlers) - ) + new SequentialJoinAsDerivation(ss, featureGroups, SparkJoinWithJoinCondition(SequentialJoinConditionBuilder), dataPathHandlers), + new SqlDerivationSpark()) new DerivedFeatureEvaluator(defaultStrategies, mvelContext) } diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala index e54d68f59..13fbec9c7 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala @@ -1,8 +1,8 @@ package com.linkedin.feathr.offline.derived.strategies import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureDerivationFunctionBase} -import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction import com.linkedin.feathr.offline.derived.DerivedFeature +import com.linkedin.feathr.offline.derived.functions.{SQLFeatureDerivationFunction, SeqJoinDerivationFunction} import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark import org.apache.spark.sql.DataFrame @@ -41,10 +41,17 @@ private[offline] trait RowBasedDerivationStrategy extends DerivationStrategy[Fea */ private[offline] trait SequentialJoinDerivationStrategy extends DerivationStrategy[SeqJoinDerivationFunction] +/** + * Implementation should define how a SQL-expression based derivation is evaluated. + */ +private[offline] trait SqlDerivationSparkStrategy extends DerivationStrategy[SQLFeatureDerivationFunction] + /** * This case class holds the implementations of supported strategies. */ private[offline] case class DerivationStrategies( customDerivationSparkStrategy: SparkUdfDerivationStrategy, rowBasedDerivationStrategy: RowBasedDerivationStrategy, - sequentialJoinDerivationStrategy: SequentialJoinDerivationStrategy) + sequentialJoinDerivationStrategy: SequentialJoinDerivationStrategy, + sqlDerivationSparkStrategy: SqlDerivationSparkStrategy) { +} diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala new file mode 100644 index 000000000..3afa0a6af --- /dev/null +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala @@ -0,0 +1,107 @@ +package com.linkedin.feathr.offline.derived.strategies + +import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransformationException} +import com.linkedin.feathr.offline.client.DataFrameColName +import com.linkedin.feathr.offline.derived.DerivedFeature +import com.linkedin.feathr.offline.derived.functions.SQLFeatureDerivationFunction +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.{DataFrame, SparkSession} + +import scala.collection.JavaConverters._ + +/** + * This class executes SQL-expression based derived feature. + */ +class SqlDerivationSpark extends SqlDerivationSparkStrategy { + + + /** + * Rewrite sqlExpression for a derived feature, e.g, replace the feature name/argument name with Frame internal dataframe column name + * @param deriveFeature derived feature definition + * @param keyTag list of tags represented by integer + * @param keyTagId2StringMap Map from the tag integer id to the string tag + * @return Rewritten SQL expression + */ + private[offline] def rewriteDerivedFeatureExpression( + deriveFeature: DerivedFeature, + keyTag: Seq[Int], + keyTagId2StringMap: Seq[String]): String = { + if (!deriveFeature.derivation.isInstanceOf[SQLFeatureDerivationFunction]) { + throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_ERROR, "Should not rewrite derived feature expression for non-SQLDerivedFeatures") + } + val sqlDerivation = deriveFeature.derivation.asInstanceOf[SQLFeatureDerivationFunction] + val deriveExpr = sqlDerivation.getExpression() + val parameterNames: Seq[String] = sqlDerivation.getParameterNames().getOrElse(Seq[String]()) + val consumedFeatureNames = deriveFeature.consumedFeatureNames.zipWithIndex.map { + case (consumeFeatureName, index) => + // begin of string, or other char except number and alphabet + // val featureStartPattern = """(^|[^a-zA-Z0-9])""" + // end of string, or other char except number and alphabet + // val featureEndPattern = """($|[^a-zA-Z0-9])""" + val namePattern = if (parameterNames.isEmpty) consumeFeatureName.getFeatureName else parameterNames(index) + // getBinding.map(keyTag.get) resolves the call tags + val newName = + if (!consumeFeatureName.getBinding.isEmpty // Passthrough features do not have keyTag + // Feature generation code path does not create columns with tags. + // The check ensures we do not run into IndexOutOfBoundsException when keyTag & keyTagId2StringMap are empty. + && keyTag.nonEmpty + && keyTagId2StringMap.nonEmpty) { + DataFrameColName.genFeatureColumnName( + consumeFeatureName.getFeatureName, + Some(consumeFeatureName.getBinding.asScala.map(keyTag(_)).map(keyTagId2StringMap))) + } else { + DataFrameColName.genFeatureColumnName(consumeFeatureName.getFeatureName) + } + (namePattern, newName) + }.toMap + + // replace all feature name to column names + // featureName is consist of numAlphabetic + val ss: SparkSession = SparkSession.builder().getOrCreate() + val dependencyFeatures = ss.sessionState.sqlParser.parseExpression(deriveExpr).references.map(_.name).toSeq + // \w is [a-zA-Z0-9_], not inclusion of _ and exclusion of -, as - is ambiguous, e.g, a-b could be a feature name or feature a minus feature b + val rewrittenExpr = dependencyFeatures.foldLeft(deriveExpr)((acc, ca) => { + // in scala \W does not work as ^\w + // "a+B+1".replaceAll("([^\w])B([^\w])", "$1abc$2" = A+abc+1 + // "a+B".replaceAll("([^\w])B$", "$1abc" = a+abc + // "B+1".replaceAll("^B([^\w])", "abc$1" = abc+1 + // "B".replaceAll("^B$", "abc" = abc + val newVal = consumedFeatureNames.getOrElse(ca, ca) + val patterns = Seq("([^\\w])" + ca + "([^\\w])", "([^\\w])" + ca + "$", "^" + ca + "([^\\w])", "^" + ca + "$") + val replacements = Seq("$1" + newVal + "$2", "$1" + newVal, newVal + "$1", newVal) + val replacedExpr = patterns + .zip(replacements) + .toMap + .foldLeft(acc)((orig, pairs) => { + orig.replaceAll(pairs._1, pairs._2) + }) + replacedExpr + }) + rewrittenExpr + } + + /** + * Apply the derivation strategy. + * + * @param keyTags keyTags for the derived feature. + * @param keyTagList integer keyTag to string keyTag map. + * @param df input DataFrame. + * @param derivedFeature Derived feature metadata. + * @param derivationFunction Derivation function to evaluate the derived feature + * @return output DataFrame with derived feature. + */ + override def apply(keyTags: Seq[Int], + keyTagList: Seq[String], + df: DataFrame, + derivedFeature: DerivedFeature, + derivationFunction: SQLFeatureDerivationFunction, + mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { + // sql expression based derived feature needs rewrite, e.g, replace the feature names with feature column names in the dataframe + val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList) + val tags = Some(keyTags.map(keyTagList).toList) + val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags) + df.withColumn(featureColumnName, expr(rewrittenExpr)) + } + +} diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala b/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala index 310c3931e..57f4def55 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala @@ -5,7 +5,7 @@ import com.linkedin.feathr.common.{Header, JoiningFeatureParams, TaggedFeatureNa import com.linkedin.feathr.offline import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource.{getDefaultValues, getFeatureTypes} import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction -import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinDerivationStrategy, SparkUdfDerivation} +import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinDerivationStrategy, SparkUdfDerivation, SqlDerivationSpark} import com.linkedin.feathr.offline.derived.{DerivedFeature, DerivedFeatureEvaluator} import com.linkedin.feathr.offline.evaluator.DerivedFeatureGenStage import com.linkedin.feathr.offline.job.{FeatureGenSpec, FeatureTransformation} @@ -133,5 +133,7 @@ private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan ErrorLabel.FEATHR_ERROR, s"Feature Generation does not support Sequential Join features : ${derivedFeature.producedFeatureNames.head}") } - }), mvelContext) + }, + new SqlDerivationSpark() + ), mvelContext) } diff --git a/src/test/scala/com/linkedin/feathr/offline/DerivationsIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/DerivationsIntegTest.scala new file mode 100644 index 000000000..94e92e06d --- /dev/null +++ b/src/test/scala/com/linkedin/feathr/offline/DerivationsIntegTest.scala @@ -0,0 +1,146 @@ +package com.linkedin.feathr.offline + +import com.linkedin.feathr.offline.util.FeathrTestUtils.assertDataFrameApproximatelyEquals +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.testng.annotations.Test + +class DerivationsIntegTest extends FeathrIntegTest { + + /** + * Test multi-key derived feature and multi-tagged feature. + * This test covers the following:- + * -> sql based custom extractor + */ + @Test + def testMultiKeyDerivedFeatureDFWithSQL: Unit = { + val df = runLocalFeatureJoinForTest( + joinConfigAsString = """ + | features: [ { + | key: ["concat('',viewer)", viewee] + | featureList: [ "foo_square_distance_sql"] + | } , + | { + | key: [viewee, viewer] + | featureList: [ "foo_square_distance_sql"] + | }, + | { + | key: [viewee, viewer] + | featureList: [ "square_fooFeature_sql"] + | } + | ] + """.stripMargin, + featureDefAsString = """ + | anchors: { + | anchor1: { + | source: anchorAndDerivations/derivations/anchor6-source.csv + | key.sqlExpr: [sourceId, destId] + | features: { + | fooFeature: { + | def.sqlExpr: cast(source as int) + | type: NUMERIC + | } + | } + | } + | } + | derivations: { + | + | square_fooFeature_sql: { + | key: [m1, m2] + | inputs: { + | a: { key: [m1, m2], feature: fooFeature } + | } + | definition.sqlExpr: "a * a" + | } + | foo_square_distance_sql: { + | key: [m1, m2] + | inputs: { + | a1: { key: [m1, m2], feature: square_fooFeature_sql } + | a2: { key: [m2, m1], feature: square_fooFeature_sql } + | } + | definition.sqlExpr: "a1 - a2" + | } + | } + """.stripMargin, + observationDataPath = "anchorAndDerivations/derivations/test2-observations.csv") + + val expectedDf = ss.createDataFrame( + ss.sparkContext.parallelize( + Seq( + Row( + // viewer + "1", + // viewee + "3", + // label + "1.0", + // square_fooFeature_sql + 4.0f, + // viewee_viewer__foo_square_distance_sql + -21.0f, + // concat____viewer__viewee__foo_square_distance_sql + 21.0f), + Row( + // viewer + "2", + // viewee + "1", + // label + "-1.0", + // square_fooFeature_sql + 9.0f, + // viewee_viewer__foo_square_distance_sql + -27.0f, + // concat____viewer__viewee__foo_square_distance_sql + 27.0f), + Row( + // viewer + "3", + // viewee + "6", + // label + "1.0", + // square_fooFeature_sql + null, + // viewee_viewer__foo_square_distance_sql + null, + // concat____viewer__viewee__foo_square_distance_sql + null), + Row( + // viewer + "3", + // viewee + "5", + // label + "-1.0", + // square_fooFeature_sql + null, + // viewee_viewer__foo_square_distance_sql + null, + // concat____viewer__viewee__foo_square_distance_sql + null), + Row( + // viewer + "5", + // viewee + "10", + // label + "1.0", + // square_fooFeature_sql + null, + // viewee_viewer__foo_square_distance_sql + null, + // concat____viewer__viewee__foo_square_distance_sql + null))), + StructType( + List( + StructField("viewer", StringType, true), + StructField("viewee", StringType, true), + StructField("label", StringType, true), + StructField("square_fooFeature_sql", FloatType, true), + StructField("viewee_viewer__foo_square_distance_sql", FloatType, true), + StructField("concat____viewer__viewee__foo_square_distance_sql", FloatType, true)))) + def cmpFunc(row: Row): String = if (row.get(0) != null) row.get(0).toString else "null" + assertDataFrameApproximatelyEquals(df.data, expectedDf, cmpFunc) + } +}