Skip to content

Commit c06699a

Browse files
authored
Support high-dimensional tensor in derivations (#1172)
1 parent 53780f8 commit c06699a

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

feathr-impl/src/main/scala/com/linkedin/feathr/offline/transformation/FDSConversionUtils.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ private[offline] object FDSConversionUtils {
4949
// 1D sparse tensor
5050
case targetType: StructType if targetType.fields.size == 2 =>
5151
convertRawValueTo1DFDSSparseTensorRow(rawFeatureValue, targetType)
52-
// 1D dense tensor
53-
case targetType: ArrayType if !targetType.elementType.isInstanceOf[ArrayType] =>
52+
// dense tensor
53+
case targetType: ArrayType =>
5454
convertRawValueTo1DFDSDenseTensorRow(rawFeatureValue, targetType)
5555
case otherType =>
5656
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Converting ${rawFeatureValue} to FDS Tensor type " +
@@ -279,6 +279,8 @@ private[offline] object FDSConversionUtils {
279279
case _: FloatType =>
280280
// If it's FloatType, then we know it's autoTz rules.
281281
convertRawValueTo1DFDSDenseTensorRowAutoTz(rawFeatureValue)
282+
case _: ArrayType =>
283+
rawFeatureValue.asInstanceOf[Array[_]]
282284
case _ =>
283285
convertRawValueTo1DFDSDenseTensorRowTz(rawFeatureValue)
284286
}

feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,94 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
566566
setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
567567
}
568568

569+
/*
570+
* Test features with fdsExtract.
571+
*/
572+
@Test
573+
def testFeaturesWithFdsExtract: Unit = {
574+
val df = runLocalFeatureJoinForTest(
575+
joinConfigAsString =
576+
"""
577+
| features: {
578+
| key: a_id
579+
| featureList: ["featureWithNullDerived"]
580+
| }
581+
""".stripMargin,
582+
featureDefAsString =
583+
"""
584+
| anchors: {
585+
| anchor1: {
586+
| source: "anchorAndDerivations/nullValueSource.avro.json"
587+
| key.sqlExpr: mId
588+
| features: {
589+
| featureWithNull {
590+
| def.sqlExpr: "FDSExtract(denseValue)"
591+
| type:{
592+
| type: TENSOR
593+
| tensorCategory: DENSE
594+
| shape: [2,5]
595+
| dimensionType: [INT, INT]
596+
| valType: STRING
597+
| }
598+
| }
599+
| }
600+
| }
601+
|}
602+
|derivations: {
603+
|featureWithNullDerived:{
604+
| key: ["id"]
605+
| inputs:
606+
| {
607+
| fv: {key: ["id"], feature: featureWithNull}
608+
| }
609+
| definition.sqlExpr: "coalesce(fv, ARRAY(ARRAY(\"aa\", \"bb\", \"cc\", \"dd\", \"ee\"), ARRAY(\"UNK\", \"UNK\", \"UNK\", \"UNK\", \"UNK\")))"
610+
| type:
611+
| {
612+
| type: TENSOR
613+
| tensorCategory: DENSE
614+
| shape: [2,5]
615+
| dimensionType: [INT, INT]
616+
| valType: STRING
617+
| }
618+
|}
619+
|}
620+
""".stripMargin,
621+
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv")
622+
623+
val selectedColumns = Seq("a_id", "featureWithNullDerived")
624+
val filteredDf = df.data.select(selectedColumns.head, selectedColumns.tail: _*)
625+
626+
val expectedDf = ss.createDataFrame(
627+
ss.sparkContext.parallelize(
628+
Seq(
629+
Row(
630+
// a_id
631+
"1",
632+
// featureWithNull
633+
mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("a", "a", "a", "a", "a"))),
634+
),
635+
Row(
636+
// a_id
637+
"2",
638+
// f3eatureWithNull
639+
mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("UNK", "UNK", "UNK", "UNK", "UNK")))
640+
),
641+
Row(
642+
// a_id
643+
"3",
644+
// featureWithNull
645+
mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("a", "a", "a", "a", "a")),
646+
)))),
647+
StructType(
648+
List(
649+
StructField("a_id", StringType, true),
650+
StructField("featureWithNull", ArrayType(ArrayType(StringType, true), true), true)
651+
)))
652+
653+
def cmpFunc(row: Row): String = row.get(0).toString
654+
655+
FeathrTestUtils.assertDataFrameApproximatelyEquals(filteredDf, expectedDf, cmpFunc)
656+
}
569657

570658
/*
571659
* Test features with null values.

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
version=1.0.4-rc5
1+
version=1.0.4-rc6
22
SONATYPE_AUTOMATIC_RELEASE=true
33
POM_ARTIFACT_ID=feathr_2.12

0 commit comments

Comments
 (0)