From 184f45a28ee7a92414894d2088da668f4da40fe1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 12 Sep 2025 06:58:01 +0200 Subject: [PATCH] Performance improvement in GraphX CDLP --- .../org/graphframes/ldbc/TestLDBCCases.scala | 4 --- .../graphx/lib/LabelPropagation.scala | 27 ++++++------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala index 00d23cb1f..adaa014e5 100644 --- a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala +++ b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala @@ -117,10 +117,6 @@ class TestLDBCCases extends SparkFunSuite with GraphFrameTestSparkContext { Seq("graphx", "graphframes").foreach { algo => test(s"test undirected CDLP with LDBC for algo ${algo}") { - // I have no idea how to write it so it will work - if (scala.util.Properties.versionNumberString.startsWith("2.12") && algo == "graphx") { - cancel("CDLP implementations are broken in 2.12, see #571") - } val testCase = ldbcTestCDLPUndirected val cdlpResults = testCase._1.labelPropagation.setAlgorithm(algo).maxIter(testCase._3).run() assert(cdlpResults.count() == testCase._1.vertices.count()) diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala index 327fc6f64..351a8a068 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala @@ -19,8 +19,6 @@ package org.apache.spark.graphframes.graphx.lib import org.apache.spark.graphframes.graphx._ import scala.annotation.nowarn -import scala.collection.Map -import scala.collection.mutable import scala.reflect.ClassTag /** Label Propagation algorithm. */ @@ -53,26 +51,17 @@ object LabelPropagation { require(maxSteps > 0, s"Maximum of steps must be greater than 0, but got ${maxSteps}") val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { - Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) - } - def mergeMessage( - count1: Map[VertexId, Long], - count2: Map[VertexId, Long]): Map[VertexId, Long] = { - // Mimics the optimization of breakOut, not present in Scala 2.13, while working in 2.12 - val map = mutable.Map[VertexId, Long]() - (count1.keySet ++ count2.keySet).foreach { i => - val count1Val = count1.getOrElse(i, 0L) - val count2Val = count2.getOrElse(i, 0L) - map.put(i, count1Val + count2Val) - } - map + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Vector[Long])] = { + Iterator((e.srcId, Vector(e.dstAttr)), (e.dstId, Vector(e.srcAttr))) } + def mergeMessage(left: Vector[Long], right: Vector[Long]): Vector[Long] = left ++ right + @nowarn - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { - if (message.isEmpty) attr else message.maxBy(_._2)._1 + def vertexProgram(vid: VertexId, attr: Long, message: Vector[Long]): Long = { + if (message.isEmpty) attr + else message.groupBy(f => f).map(f => (f._1, f._2.size)).maxBy(f => (f._2, -f._1))._1 } - val initialMessage = Map[VertexId, Long]() + val initialMessage = Vector[Long]() Pregel(lpaGraph, initialMessage, maxIterations = maxSteps)( vprog = vertexProgram, sendMsg = sendMessage,