From 2bca985678fba457b7cee31da4d22fddb5b33126 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 22 Mar 2025 20:48:35 +0100 Subject: [PATCH 1/6] Add an experimental implementation of SP in GF --- .gitignore | 5 + .../org/graphframes/lib/ShortestPaths.scala | 121 +++++++++++++++++- .../graphframes/lib/ShortestPathsSuite.scala | 53 ++++++++ 3 files changed, 175 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 4d0a174e7..8554743bd 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ project/boot/ project/plugins/project/ .bsp .metals +metals.sbt +.bloop # intellij .idea/ @@ -25,6 +27,9 @@ project/plugins/project/ # VSCode .vscode +# Helix +.helix + # Mac *.DS_Store diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 0d0b52407..d8cde196a 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -24,10 +24,12 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.api.java.UDF1 -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.{col, udf, map, lit, when, map_zip_with, reduce, map_values, transform_values, collect_list} import org.apache.spark.sql.types.{IntegerType, MapType} import org.graphframes.GraphFrame +import org.apache.spark.annotation.Experimental +import org.graphframes.Logging /** * Computes shortest paths from every vertex to the given set of landmark vertices. Note that this @@ -39,7 +41,10 @@ import org.graphframes.GraphFrame * shortest-path distance to each reachable landmark vertex. 
*/ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments { + import org.graphframes.lib.ShortestPaths._ + private var lmarks: Option[Seq[Any]] = None + private var algorithm: String = ALGO_GRAPHX /** * The list of landmark vertex ids. Shortest paths will be computed to each landmark. @@ -57,14 +62,25 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends landmarks(value.asScala.toSeq) } + def setAlgorithm(value: String): this.type = { + require( + supportedAlgorithms.contains(value), + s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.)") + algorithm = value + this + } + def run(): DataFrame = { - ShortestPaths.run(graph, check(lmarks, "landmarks")) + algorithm match { + case ALGO_GRAPHX => runInGraphX(graph, check(lmarks, "landmarks")) + case ALGO_GRAPHFRAMES => runInGraphFrames(graph, check(lmarks, "landmarks")) + } } } -private object ShortestPaths { +private object ShortestPaths extends Logging { - private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { + private def runInGraphX(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { val idType = graph.vertices.schema(GraphFrame.ID).dataType val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap val gx = graphxlib.ShortestPaths @@ -94,6 +110,103 @@ private object ShortestPaths { g.vertices.select(cols.toSeq: _*) } + @Experimental + private def runInGraphFrames( + graph: GraphFrame, + landmarks: Seq[Any], + isDirected: Boolean = true): DataFrame = { + logWarn("The GraphFrames based implementation is slow and considered experimental!") + val vertexType = graph.vertices.schema(GraphFrame.ID).dataType + + // For landmark vertices the initial distance to itself is set to 0 + def initDistancesMap(vertexId: Column): Column = { + var initCol = when(vertexId === lit(landmarks.head), map(lit(landmarks.head), lit(0))) + for (lmark <- landmarks.tail) { + initCol = 
initCol.when(vertexId === lit(lmark), map(lit(lmark), lit(0))) + } + initCol + } + + // Concatenations of two distance maps: + // If one map is null just take another. + // In case both maps are not null: + // - iterate over keys + // - if value in the left map is null or greater than value from the right map take right one + // else take left one + def concatMaps(distancesLeft: Column, distancesRight: Column): Column = + when(distancesLeft.isNull, distancesRight) + .when(distancesRight.isNull, distancesLeft) + .otherwise(map_zip_with( + distancesLeft, + distancesRight, + (_, leftDistance, rightDistance) => { + when(leftDistance.isNull || (leftDistance > rightDistance), rightDistance) + .otherwise(leftDistance) + })) + + // If distance is null, result of d + 1 will be null too + def incrementDistances(distancesMap: Column): Column = + transform_values(distancesMap, (_, distance) => distance + lit(1)) + + // Takes an array of distance maps and reduce them with concatMaps + def aggregateArrayOfDistanceMaps(arrayCol: Column): Column = + reduce(arrayCol, lit(null).cast(MapType(vertexType, IntegerType)), concatMaps) + + // Checks that a sended distances map can change the destination distances. + // Evaluation would be "true" in case in the new distances map + // for one of keys present a non null value but in the old distances map it is null + // or new distance is less than old one. + def isDistanceImprovedWithMessage(newMap: Column, oldMap: Column): Column = reduce( + map_values( + map_zip_with( + newMap, + oldMap, + (_, newDistance, rightDistance) => + (newDistance.isNotNull && rightDistance.isNull) || (newDistance < rightDistance))), + lit(false), + (left, right) => left || right) + + // Overall: + // 1. Initialize distances + // 2. If new message can improve distances send it + // 3. 
Collect and aggregate messages + val pregel = graph.pregel + // TODO: set maxIter to Int.MaxValue and earlyStopping = true after merging #550 + .setMaxIter(15) + .withVertexColumn( + DISTANCE_ID, + when(col(GraphFrame.ID).isInCollection(landmarks), initDistancesMap(col(GraphFrame.ID))) + .otherwise(map().cast(MapType(vertexType, IntegerType))), + concatMaps(col(DISTANCE_ID), Pregel.msg)) + .sendMsgToSrc( + when( + isDistanceImprovedWithMessage( + incrementDistances(Pregel.dst(DISTANCE_ID)), + Pregel.src(DISTANCE_ID)), + incrementDistances(Pregel.dst(DISTANCE_ID)))) + .aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg))) + + // Experimental feature + if (isDirected) { + pregel.run() + } else { + // For consider edges as undireceted, + // it is enough to send messages in both directions + pregel + .sendMsgToDst( + when( + isDistanceImprovedWithMessage( + incrementDistances(Pregel.src(DISTANCE_ID)), + Pregel.dst(DISTANCE_ID)), + incrementDistances(Pregel.src(DISTANCE_ID)))) + .run() + } + + } + private val DISTANCE_ID = "distances" + private val ALGO_GRAPHX = "graphx" + private val ALGO_GRAPHFRAMES = "graphframes" + val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) } diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index 69899ed9a..c82c24ba6 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -57,6 +57,39 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(results.toSet === shortestPaths) } + test("Simple test with GraphFrames") { + val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)) + .flatMap { case e => + Seq(e, e.swap) + } + .map { case (src, dst) => (src.toLong, dst.toLong) } + val edges = spark.createDataFrame(edgeSeq).toDF("src", "dst") + val graph = GraphFrame.fromEdges(edges) + + // 
Ground truth + val shortestPaths = Set( + (1, Map(1 -> 0, 4 -> 2)), + (2, Map(1 -> 1, 4 -> 2)), + (3, Map(1 -> 2, 4 -> 1)), + (4, Map(1 -> 2, 4 -> 0)), + (5, Map(1 -> 1, 4 -> 1)), + (6, Map(1 -> 3, 4 -> 1))) + + val landmarks = Seq(1, 4).map(_.toLong) + val v2 = graph.shortestPaths.landmarks(landmarks).setAlgorithm("graphframes").run() + + TestUtils.testSchemaInvariants(graph, v2) + TestUtils.checkColumnType( + v2.schema, + "distances", + DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, true)) + val newVs = v2.select("id", "distances").collect().toSeq + val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) => + (id, spMap) + } + assert(results.toSet === shortestPaths) + } + test("friends graph") { val friends = examples.Graphs.friends val v = friends.shortestPaths.landmarks(Seq("a", "d")).run() @@ -78,4 +111,24 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(results === expected) } + test("friends graph with GraphFrames") { + val friends = examples.Graphs.friends + val v = friends.shortestPaths.landmarks(Seq("a", "d")).setAlgorithm("graphframes").run() + val expected = Set[(String, Map[String, Int])]( + ("a", Map("a" -> 0, "d" -> 2)), + ("b", Map.empty), + ("c", Map.empty), + ("d", Map("a" -> 1, "d" -> 0)), + ("e", Map("a" -> 2, "d" -> 1)), + ("f", Map.empty), + ("g", Map.empty)) + val results = v + .select("id", "distances") + .collect() + .map { case Row(id: String, spMap: Map[String, Int] @unchecked) => + (id, spMap) + } + .toSet + assert(results === expected) + } } From 0c5c5397ee87e5e10c231fa6bd7cf40cfae688b1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Mar 2025 17:19:35 +0100 Subject: [PATCH 2/6] From comments --- .../graphframes/lib/ConnectedComponents.scala | 51 ++------------- .../org/graphframes/lib/ShortestPaths.scala | 62 +++++++++---------- src/main/scala/org/graphframes/mixins.scala | 18 ++++++ 3 files changed, 54 insertions(+), 77 deletions(-) 
create mode 100644 src/main/scala/org/graphframes/mixins.scala diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 1713c58da..486000b41 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel +import org.graphframes.WithAlgorithmChoice /** * Connected components algorithm. @@ -40,7 +41,8 @@ import org.apache.spark.storage.StorageLevel */ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments - with Logging { + with Logging + with WithAlgorithmChoice { import org.graphframes.lib.ConnectedComponents._ @@ -71,34 +73,6 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) */ def getBroadcastThreshold: Int = broadcastThreshold - private var algorithm: String = ALGO_GRAPHFRAMES - - /** - * Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms - * are: - * - "graphframes": Uses alternating large star and small star iterations proposed in - * [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]] - * with skewed join optimization. - * - "graphx": Converts the graph to a GraphX graph and then uses the connected components - * implementation in GraphX. - * @see - * [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]] - */ - def setAlgorithm(value: String): this.type = { - require( - supportedAlgorithms.contains(value), - s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.") - algorithm = value - this - } - - /** - * Gets the connected component algorithm to use. 
- * @see - * [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]. - */ - def getAlgorithm: String = algorithm - private var checkpointInterval: Int = 2 /** @@ -159,7 +133,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) def run(): DataFrame = { ConnectedComponents.run( graph, - algorithm = algorithm, + runInGraphX = algorithm == ALGO_GRAPHX, broadcastThreshold = broadcastThreshold, checkpointInterval = checkpointInterval, intermediateStorageLevel = intermediateStorageLevel) @@ -176,15 +150,6 @@ object ConnectedComponents extends Logging { private val CNT = "cnt" private val CHECKPOINT_NAME_PREFIX = "connected-components" - private val ALGO_GRAPHX = "graphx" - private val ALGO_GRAPHFRAMES = "graphframes" - - /** - * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: - * "graphframes" and "graphx". - */ - val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) - /** * Returns the symmetric directed graph of the graph specified by input edges. 
* @param ee @@ -279,15 +244,11 @@ object ConnectedComponents extends Logging { private def run( graph: GraphFrame, - algorithm: String, + runInGraphX: Boolean, broadcastThreshold: Int, checkpointInterval: Int, intermediateStorageLevel: StorageLevel): DataFrame = { - require( - supportedAlgorithms.contains(algorithm), - s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.") - - if (algorithm == ALGO_GRAPHX) { + if (runInGraphX) { return runGraphX(graph) } diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index d8cde196a..58dec7f4e 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.functions.{col, udf, map, lit, when, map_zip_with, r import org.apache.spark.sql.types.{IntegerType, MapType} import org.graphframes.GraphFrame -import org.apache.spark.annotation.Experimental import org.graphframes.Logging +import org.graphframes.WithAlgorithmChoice /** * Computes shortest paths from every vertex to the given set of landmark vertices. Note that this @@ -40,11 +40,12 @@ import org.graphframes.Logging * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing the * shortest-path distance to each reachable landmark vertex. */ -class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments { +class ShortestPaths private[graphframes] (private val graph: GraphFrame) + extends Arguments + with WithAlgorithmChoice { import org.graphframes.lib.ShortestPaths._ private var lmarks: Option[Seq[Any]] = None - private var algorithm: String = ALGO_GRAPHX /** * The list of landmark vertex ids. Shortest paths will be computed to each landmark. @@ -59,21 +60,14 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends * The list of landmark vertex ids. 
Shortest paths will be computed to each landmark. */ def landmarks(value: util.ArrayList[Any]): this.type = { - landmarks(value.asScala.toSeq) - } - - def setAlgorithm(value: String): this.type = { - require( - supportedAlgorithms.contains(value), - s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.)") - algorithm = value - this + landmarks(value.asScala) } def run(): DataFrame = { + val lmarksChecked = check(lmarks, "landmarks") algorithm match { - case ALGO_GRAPHX => runInGraphX(graph, check(lmarks, "landmarks")) - case ALGO_GRAPHFRAMES => runInGraphFrames(graph, check(lmarks, "landmarks")) + case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) + case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked) } } } @@ -110,7 +104,6 @@ private object ShortestPaths extends Logging { g.vertices.select(cols.toSeq: _*) } - @Experimental private def runInGraphFrames( graph: GraphFrame, landmarks: Seq[Any], @@ -119,6 +112,17 @@ private object ShortestPaths extends Logging { val vertexType = graph.vertices.schema(GraphFrame.ID).dataType // For landmark vertices the initial distance to itself is set to 0 + // Example: graph with vertices a, b, c, d; landmarks = (c, d) + // we should init the following: + // (a, Map()), (b, Map()), (c, Map(c -> 0)), (d, Map(d -> 0)) + // + // Inside the following function it is done by applying multiple case-when + // because we know exactly that only one landmark could be equal to the nodeId. 
+ // For example, for vertex c it will be: + // when(id == "a", Map(a -> 0)) + // .when(id == "b", Map(b -> 0)) + // .when(id == "c", Map(c -> 0)) --> this one is the only true + // .when(id == "d", Map(d -> 0)) def initDistancesMap(vertexId: Column): Column = { var initCol = when(vertexId === lit(landmarks.head), map(lit(landmarks.head), lit(0))) for (lmark <- landmarks.tail) { @@ -152,9 +156,9 @@ private object ShortestPaths extends Logging { def aggregateArrayOfDistanceMaps(arrayCol: Column): Column = reduce(arrayCol, lit(null).cast(MapType(vertexType, IntegerType)), concatMaps) - // Checks that a sended distances map can change the destination distances. + // Checks that a sent distances map can change the destination distances. // Evaluation would be "true" in case in the new distances map - // for one of keys present a non null value but in the old distances map it is null + // for one of keys present a non-null value but in the old distances map it is null // or new distance is less than old one. def isDistanceImprovedWithMessage(newMap: Column, oldMap: Column): Column = reduce( map_values( @@ -166,6 +170,9 @@ private object ShortestPaths extends Logging { lit(false), (left, right) => left || right) + val srcDistanceCol = Pregel.src(DISTANCE_ID) + val dstDistanceCol = Pregel.dst(DISTANCE_ID) + // Overall: // 1. Initialize distances // 2. 
If new message can improve distances send it @@ -178,35 +185,26 @@ private object ShortestPaths extends Logging { when(col(GraphFrame.ID).isInCollection(landmarks), initDistancesMap(col(GraphFrame.ID))) .otherwise(map().cast(MapType(vertexType, IntegerType))), concatMaps(col(DISTANCE_ID), Pregel.msg)) - .sendMsgToSrc( - when( - isDistanceImprovedWithMessage( - incrementDistances(Pregel.dst(DISTANCE_ID)), - Pregel.src(DISTANCE_ID)), - incrementDistances(Pregel.dst(DISTANCE_ID)))) + .sendMsgToSrc(when( + isDistanceImprovedWithMessage(incrementDistances(dstDistanceCol), srcDistanceCol), + incrementDistances(dstDistanceCol))) .aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg))) // Experimental feature if (isDirected) { pregel.run() } else { - // For consider edges as undireceted, + // To treat edges as undirected, // it is enough to send messages in both directions pregel .sendMsgToDst( when( - isDistanceImprovedWithMessage( - incrementDistances(Pregel.src(DISTANCE_ID)), - Pregel.dst(DISTANCE_ID)), - incrementDistances(Pregel.src(DISTANCE_ID)))) + isDistanceImprovedWithMessage(incrementDistances(srcDistanceCol), dstDistanceCol), + incrementDistances(srcDistanceCol))) .run() } } private val DISTANCE_ID = "distances" - private val ALGO_GRAPHX = "graphx" - private val ALGO_GRAPHFRAMES = "graphframes" - - val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) } diff --git a/src/main/scala/org/graphframes/mixins.scala b/src/main/scala/org/graphframes/mixins.scala new file mode 100644 index 000000000..88a348bb0 --- /dev/null +++ b/src/main/scala/org/graphframes/mixins.scala @@ -0,0 +1,18 @@ +package org.graphframes + +private[graphframes] trait WithAlgorithmChoice { + protected val ALGO_GRAPHX = "graphx" + protected val ALGO_GRAPHFRAMES = "graphframes" + protected var algorithm: String = ALGO_GRAPHX + val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) + + def setAlgorithm(value: String): this.type = { + 
require( + supportedAlgorithms.contains(value), + s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.") + algorithm = value + this + } + + def getAlgorithm: String = algorithm +} From 106d450e62b9adc5128ddedf89072b77d30d7101 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Mar 2025 17:31:13 +0100 Subject: [PATCH 3/6] Revert removing toSeq and fix head --- src/main/scala/org/graphframes/lib/ShortestPaths.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 58dec7f4e..04290c582 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -60,7 +60,7 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) * The list of landmark vertex ids. Shortest paths will be computed to each landmark. */ def landmarks(value: util.ArrayList[Any]): this.type = { - landmarks(value.asScala) + landmarks(value.asScala.toSeq) } def run(): DataFrame = { @@ -124,7 +124,8 @@ private object ShortestPaths extends Logging { // .when(id == "c", Map(c -> 0)) --> this one is the only true // .when(id == "d", Map(d -> 0)) def initDistancesMap(vertexId: Column): Column = { - var initCol = when(vertexId === lit(landmarks.head), map(lit(landmarks.head), lit(0))) + val firstLmarkCol = lit(landmarks.head) + var initCol = when(vertexId === firstLmarkCol, map(firstLmarkCol, lit(0))) for (lmark <- landmarks.tail) { initCol = initCol.when(vertexId === lit(lmark), map(lit(lmark), lit(0))) } From f1a3d1c8726080c7b760c8cb109aeae52e9478da Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Mar 2025 17:41:06 +0100 Subject: [PATCH 4/6] Fix failing tests --- src/main/scala/org/graphframes/lib/ConnectedComponents.scala | 1 + .../scala/org/graphframes/lib/ConnectedComponentsSuite.scala | 3 +++ 2 files changed, 4 insertions(+) diff 
--git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 486000b41..468832bd3 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -47,6 +47,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) import org.graphframes.lib.ConnectedComponents._ private var broadcastThreshold: Int = 1000000 + setAlgorithm(ALGO_GRAPHFRAMES) /** * Sets broadcast threshold in propagating component assignments (default: 1000000). If a node diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index f55ae4edd..5cd5074fc 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -186,6 +186,8 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon assert(checkpointDir.nonEmpty) sc.setCheckpointDir(null) + val oldCheckpointDir = spark.conf.get("spark.checkpoint.dir") + spark.conf.unset("spark.checkpoint.dir") withClue( "Should throw an IOException if sc.getCheckpointDir is empty " + "and checkpointInterval is positive.") { @@ -193,6 +195,7 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon cc.run() } } + spark.conf.set("spark.checkpoint.dir", oldCheckpointDir) // Checks whether the input DataFrame is from some checkpoint data. // TODO: The implemetnation is a little hacky. 
From f4eebe8ddce3607565ae364dc4566c970507aafb Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Mar 2025 17:59:38 +0100 Subject: [PATCH 5/6] Fix failing tests --- .../scala/org/graphframes/lib/ConnectedComponentsSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 5cd5074fc..f55ae4edd 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -186,8 +186,6 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon assert(checkpointDir.nonEmpty) sc.setCheckpointDir(null) - val oldCheckpointDir = spark.conf.get("spark.checkpoint.dir") - spark.conf.unset("spark.checkpoint.dir") withClue( "Should throw an IOException if sc.getCheckpointDir is empty " + "and checkpointInterval is positive.") { @@ -195,7 +193,6 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon cc.run() } } - spark.conf.set("spark.checkpoint.dir", oldCheckpointDir) // Checks whether the input DataFrame is from some checkpoint data. // TODO: The implemetnation is a little hacky. From e91d14f7a4d764189ab33e3aade247399559ab0a Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Tue, 1 Apr 2025 18:10:26 +0200 Subject: [PATCH 6/6] Add early stopping --- src/main/scala/org/graphframes/lib/ShortestPaths.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 04290c582..687236f7e 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -179,8 +179,7 @@ private object ShortestPaths extends Logging { // 2. If new message can improve distances send it // 3. 
Collect and aggregate messages val pregel = graph.pregel - // TODO: set maxIter to Int.MaxValue and earlyStopping = true after merging #550 - .setMaxIter(15) + .setMaxIter(Int.MaxValue) // That is how the GraphX implementation works .withVertexColumn( DISTANCE_ID, when(col(GraphFrame.ID).isInCollection(landmarks), initDistancesMap(col(GraphFrame.ID))) @@ -190,6 +189,7 @@ private object ShortestPaths extends Logging { isDistanceImprovedWithMessage(incrementDistances(dstDistanceCol), srcDistanceCol), incrementDistances(dstDistanceCol))) .aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg))) + .setEarlyStopping(true) // Experimental feature if (isDirected) {