package org.graphframes.catalyst

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.MapType
import org.graphframes.GraphFramesUnreachableException

/** Column-level entry points for the Catalyst expressions defined in this file. */
private[graphframes] object GraphFramesFunctions {

  /**
   * Wraps [[KeyWithMaxValue]] as a `Column`.
   *
   * @param mapCol
   *   a map-typed column
   * @return
   *   a column holding the key that is paired with the largest value of `mapCol`
   */
  def keyWithMaxValue(mapCol: Column): Column = new Column(KeyWithMaxValue(mapCol.expr))
}

/**
 * Catalyst expression that returns the key of a map entry with the largest value. When several
 * entries share the largest value, the largest key among them wins, so the result does not
 * depend on the entry order inside the map.
 *
 * Evaluation reads keys as `Long` and values as `Int`, the same element types the expression
 * has always assumed. NOTE(review): other key/value types are not validated here; confirm that
 * callers only pass `map<bigint, int>`.
 *
 * Code generation is left entirely to [[CodegenFallback]], which calls back into `eval`; the
 * expression therefore does not provide its own `doGenCode`.
 *
 * @param child
 *   the map-typed input expression
 */
private[graphframes] case class KeyWithMaxValue(child: Expression)
    extends UnaryExpression
    with CodegenFallback {

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  /** The result is one of the map's keys, so the result type is the map's key type. */
  override def dataType: DataType = child.dataType match {
    case t: MapType => t.keyType
    case _ => throw new GraphFramesUnreachableException()
  }

  /** An empty input map has no key to return and evaluates to SQL NULL. */
  override def nullable: Boolean = true

  override protected def nullSafeEval(input: Any): Any = input match {
    // Catalyst hands map values to expressions as MapData (two parallel arrays),
    // never as a scala.collection.Map.
    case map: MapData =>
      val size = map.numElements()
      if (size == 0) {
        // `null` is Catalyst's in-memory representation of SQL NULL.
        null
      } else {
        val keys = map.keyArray()
        val counts = map.valueArray()
        // Order entries by (value, key): the largest value wins, ties go to the largest key.
        (0 until size).iterator.map(i => (counts.getInt(i), keys.getLong(i))).max._2
      }
    case _ => throw new GraphFramesUnreachableException()
  }
}
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.graphframes.GraphFrame
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithMaxIter
import org.graphframes.catalyst.GraphFramesFunctions

/**
 * Run static Label Propagation for detecting communities in networks.
 *
 * `run()` dispatches on the configured `algorithm`: `"graphx"` delegates to the GraphX
 * implementation, `"graphframes"` runs the DataFrame-based Pregel implementation and
 * additionally honours `checkpointInterval`. Both variants return the vertices with an
 * additional `label` column.
 */
class LabelPropagation private[graphframes] (private val graph: GraphFrame)
    extends Arguments
    with WithAlgorithmChoice
    with WithCheckpointInterval
    with WithMaxIter {

  def run(): DataFrame = {
    val maxIterChecked = check(maxIter, "maxIter")
    // NOTE(review): only these two names are handled; presumably WithAlgorithmChoice rejects
    // any other value before this point - confirm, otherwise this raises a MatchError.
    algorithm match {
      case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked)
      case "graphframes" =>
        LabelPropagation.runInGraphFrames(graph, maxIterChecked, checkpointInterval)
    }
  }
}

private object LabelPropagation {

  private val LABEL_ID = "label"

  /** Runs the GraphX implementation and converts the result back to a vertex DataFrame. */
  private def runInGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
    val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices
  }

  /**
   * Runs label propagation on top of the GraphFrames Pregel API.
   *
   * Overall:
   *   - Initial labels: the vertex IDs.
   *   - Active vertex column (halt voting): did the label change?
   *   - Choosing a new label: the most frequent label across neighbours (tie-breaking is
   *     deterministic).
   *
   * @param graph
   *   the graph to run on
   * @param maxIter
   *   maximum number of Pregel iterations
   * @param checkpointInterval
   *   checkpoint interval forwarded to Pregel
   * @param isDirected
   *   when `true` labels travel only from source to destination; when `false` they are sent in
   *   both directions along every edge
   * @return
   *   the DataFrame produced by the Pregel run, with the `label` vertex column
   */
  private def runInGraphFrames(
      graph: GraphFrame,
      maxIter: Int,
      checkpointInterval: Int,
      isDirected: Boolean = true): DataFrame = {
    val currentLabel = col(LABEL_ID)

    // A vertex that received no messages has a null message; it keeps its current label
    // instead of having the label overwritten with null.
    val newLabel =
      coalesce(GraphFramesFunctions.keyWithMaxValue(Pregel.msg), currentLabel)

    // Fold the incoming labels into a map label -> number of occurrences.
    //  - typedLit (not lit) is required to build a map-typed literal.
    //  - acc(label) looks a Column key up in the map; the missing-key null becomes count 1.
    //  - the previous entry for the label is filtered out before map_concat, which rejects
    //    duplicate keys under the default map-key dedup policy.
    // NOTE(review): the accumulator is map<bigint, int>, i.e. this assumes Long vertex IDs.
    val labelCounts = reduce(
      collect_list(Pregel.msg),
      typedLit(Map.empty[Long, Int]),
      (acc, label) =>
        map_concat(
          map_filter(acc, (key, _) => key =!= label),
          map(label, coalesce(acc(label) + lit(1), lit(1)))))

    val configured = graph.pregel
      .withVertexColumn(LABEL_ID, col(GraphFrame.ID).alias(LABEL_ID), newLabel)
      .setMaxIter(maxIter)
      .setStopIfAllNonActiveVertices(true)
      .setEarlyStopping(false)
      .setCheckpointInterval(checkpointInterval)
      .setSkipMessagesFromNonActiveVertices(false)
      .setUpdateActiveVertexExpression(currentLabel =!= newLabel)
      // Message expressions are evaluated on triplets, so the label has to be read from the
      // sending endpoint of the edge.
      .sendMsgToDst(Pregel.src(LABEL_ID))

    val withMessages =
      if (isDirected) configured
      else configured.sendMsgToSrc(Pregel.dst(LABEL_ID))

    withMessages.aggMsgs(labelCounts).run()
  }
}
ldbcTestCDLPUndirected val cdlpResults = testCase._1.labelPropagation.maxIter(testCase._3).run()