From a390a0b2a63c8a78d72c8e2aa3d86d7a41644836 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 17 Sep 2025 09:16:43 +0200 Subject: [PATCH 1/3] improvements in GF-based algorithms performance --- .../main/scala/org/graphframes/Logging.scala | 4 + .../graphframes/lib/ConnectedComponents.scala | 164 +++--- .../graphframes/lib/LabelPropagation.scala | 9 +- .../scala/org/graphframes/lib/PageRank.scala | 17 +- .../lib/ParallelPersonalizedPageRank.scala | 16 +- .../scala/org/graphframes/lib/Pregel.scala | 64 ++- .../org/graphframes/lib/SVDPlusPlus.scala | 10 +- .../org/graphframes/lib/ShortestPaths.scala | 4 +- .../lib/StronglyConnectedComponents.scala | 8 +- .../main/scala/org/graphframes/mixins.scala | 21 +- .../lib/ConnectedComponentsSuite.scala | 483 +++++++++--------- 11 files changed, 432 insertions(+), 368 deletions(-) diff --git a/core/src/main/scala/org/graphframes/Logging.scala b/core/src/main/scala/org/graphframes/Logging.scala index 594178ba3..63054f9d6 100644 --- a/core/src/main/scala/org/graphframes/Logging.scala +++ b/core/src/main/scala/org/graphframes/Logging.scala @@ -40,4 +40,8 @@ private[org] trait Logging { protected def logTrace(s: => String): Unit = { if (logger.isTraceEnabled) logger.trace(s) } + + protected def resultIsPersistent(): Unit = { + logWarn("Returned DataFrame is persistent and materialized!") + } } diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 6d90ed346..6f81904b6 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -17,7 +17,6 @@ package org.graphframes.lib -import org.apache.hadoop.fs.Path import org.apache.spark.graphframes.graphx import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame @@ -94,7 +93,6 @@ object ConnectedComponents extends Logging { private val ORIG_ID = "orig_id" private val MIN_NBR = "min_nbr" private val CNT = "cnt" - private val CHECKPOINT_NAME_PREFIX = "connected-components" /** * Returns the symmetric directed graph of the graph specified by input edges. @@ -141,20 +139,36 @@ object ConnectedComponents extends Logging { * - `min_nbr`, the min vertex ID among itself and its neighbors * - `cnt`, the total number of neighbors */ - private def minNbrs(ee: DataFrame): DataFrame = { - symmetrize(ee) - .groupBy(SRC) - .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) - .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) + private def minNbrs( + ee: DataFrame, + computeCount: Boolean, + includeSelf: Boolean, + doSymmetrize: Boolean): DataFrame = { + val ee2 = if (doSymmetrize) { + symmetrize(ee) + } else { + ee + } + val res = if (computeCount) { + ee2 + .groupBy(SRC) + .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) + } else { + ee2 + .groupBy(SRC) + .agg(min(col(DST)).as(MIN_NBR)) + } + if (includeSelf) { + res.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) + } else { res } } private def minValue(x: Column, y: Column): Column = { when(x < y, x).otherwise(y) } - private def maxValue(x: Column, y: Column): Column = { + private def maxValue(x: Column, y: Column): Column = when(x > y, x).otherwise(y) - } /** * Performs a possibly skewed join between edges and current component assignments. The skew @@ -207,40 +221,30 @@ object ConnectedComponents extends Logging { } val spark = graph.spark - val sc = spark.sparkContext // Store original AQE setting val originalAQE = spark.conf.get("spark.sql.adaptive.enabled") try { - spark.conf.set("spark.sql.adaptive.enabled", "false") + val shouldDoSkewedJoin = broadcastThreshold != -1 + + if (shouldDoSkewedJoin) { + spark.conf.set("spark.sql.adaptive.enabled", "false") + } val runId = UUID.randomUUID().toString.takeRight(8) val logPrefix = s"[CC $runId]" logInfo(s"$logPrefix Start connected components with run ID $runId.") val shouldCheckpoint = checkpointInterval > 0 - val checkpointDir: Option[String] = if (useLocalCheckpoints) { None } - else if (shouldCheckpoint) { - val dir = sc.getCheckpointDir - .map { d => - new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString - } - .getOrElse { - // Spark-Connect workaround - spark.conf.getOption("spark.checkpoint.dir") match { - case Some(d) => new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString - case None => - throw new IOException( - "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + - "or by specifying the conf 'spark.checkpoint.dir'.") - } - } - logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") - Some(dir) - } else { - logInfo( - s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.") - None + if (shouldCheckpoint && !useLocalCheckpoints && spark.sparkContext.getCheckpointDir.isEmpty) { + // Spark-Connect workaround + spark.sparkContext.setCheckpointDir(spark.conf.getOption("spark.checkpoint.dir") match { + case Some(d) => d + case None => + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + + "or by specifying the conf 'spark.checkpoint.dir'.") + }) } logInfo(s"$logPrefix Preparing the graph for connected component computation ...") @@ -253,29 +257,24 @@ object ConnectedComponents extends Logging { var iteration = 1 def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { - // Taking the sum in DecimalType to preserve precision. - // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. - // It should be able to handle 200 billion edges without overflow. - val (minNbrSum, cnt) = minNbrsDF - .select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")) - .rdd - .map { r => - (r.getAs[BigDecimal](0), r.getLong(1)) - } - .first() - if (cnt != 0L && minNbrSum == null) { - throw new ArithmeticException(s""" - |The total sum of edge src IDs is used to determine convergence during iterations. - |However, the total sum at iteration $iteration exceeded 30 digits (1e30), - |which should happen only if the graph contains more than 200 billion edges. - |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. - """.stripMargin) + if (shouldDoSkewedJoin) { + minNbrsDF + .select(sum(col(MIN_NBR).cast(DecimalType(38, 0))), count("*")) + .first() + .getDecimal(0) + } else { + minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(38, 0)))).first().getDecimal(0) } - minNbrSum } // compute min neighbors (including self-min) - var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr - .persist(intermediateStorageLevel) + var minNbrs1: DataFrame = + minNbrs( + ee, + computeCount = shouldDoSkewedJoin, + includeSelf = true, + doSymmetrize = true + ) // src >= min_nbr + .persist(intermediateStorageLevel) var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1) @@ -284,23 +283,37 @@ object ConnectedComponents extends Logging { var currRoundPersistedDFs = Seq[DataFrame]() // large-star step // connect all strictly larger neighbors to the min neighbor (including self) - ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) - .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst + ee = { + if (shouldDoSkewedJoin) { + skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) + + } else { + ee.join(minNbrs1, SRC) + } + }.select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst .distinct() .persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ ee // small-star step // compute min neighbors (excluding self-min) - val minNbrs2 = ee - .groupBy(col(SRC)) - .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr + val minNbrs2 = minNbrs( + ee, + computeCount = shouldDoSkewedJoin, + includeSelf = false, + doSymmetrize = false) .persist(intermediateStorageLevel) currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 // connect all smaller neighbors to the min neighbor - ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) - .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst + ee = { + if (shouldDoSkewedJoin) { + skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) + } else { + ee.join(minNbrs2, SRC) + } + }.select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst .filter(col(SRC) =!= col(DST)) // src < dst // connect self to the min neighbor ee = ee @@ -312,26 +325,21 @@ object ConnectedComponents extends Logging { if (useLocalCheckpoints) { ee = ee.localCheckpoint(eager = true) } else { - // TODO: remove this after DataFrame.checkpoint is implemented - val out = s"${checkpointDir.get}/$iteration" - ee.write.parquet(out) - // may hit S3 eventually consistent issue - ee = spark.read.parquet(out) - - // remove previous checkpoint - if (iteration > checkpointInterval) { - val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") - path.getFileSystem(sc.hadoopConfiguration).delete(path, true) - } - - System.gc() // hint Spark to clean shuffle directories + ee = ee.checkpoint(eager = true) } + currRoundPersistedDFs = currRoundPersistedDFs :+ ee + } else { + // Checkpointing includes persist under the hood, no needs to do it again + ee = ee.persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ ee } - ee.persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ ee - - minNbrs1 = minNbrs(ee) // src >= min_nbr + minNbrs1 = minNbrs( + ee, + computeCount = shouldDoSkewedJoin, + includeSelf = true, + doSymmetrize = true + ) // src >= min_nbr .persist(intermediateStorageLevel) currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 @@ -386,7 +394,7 @@ object ConnectedComponents extends Logging { persisted_df.unpersist() } - logWarn("The DataFrame returned by ConnectedComponents is persisted and loaded.") + resultIsPersistent() output } finally { diff --git a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala index ef1e123d8..2b6efd7da 100644 --- a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.MapType import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame +import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice import org.graphframes.WithCheckpointInterval import org.graphframes.WithLocalCheckpoints @@ -49,11 +50,12 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) with WithAlgorithmChoice with WithCheckpointInterval with WithMaxIter - with WithLocalCheckpoints { + with WithLocalCheckpoints + with Logging { def run(): DataFrame = { val maxIterChecked = check(maxIter, "maxIter") - algorithm match { + val res = algorithm match { case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked) case "graphframes" => LabelPropagation.runInGraphFrames( @@ -62,6 +64,8 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) checkpointInterval, useLocalCheckpoints = useLocalCheckpoints) } + resultIsPersistent() + res } } @@ -127,5 +131,4 @@ private object LabelPropagation { } private val LABEL_ID = "label" - } diff --git a/core/src/main/scala/org/graphframes/lib/PageRank.scala b/core/src/main/scala/org/graphframes/lib/PageRank.scala index b0c851264..22ea38ce9 100644 --- a/core/src/main/scala/org/graphframes/lib/PageRank.scala +++ b/core/src/main/scala/org/graphframes/lib/PageRank.scala @@ -19,6 +19,7 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx.{lib => graphxlib} import org.graphframes.GraphFrame +import org.graphframes.Logging /** * PageRank algorithm implementation. There are two implementations of PageRank. @@ -63,7 +64,9 @@ import org.graphframes.GraphFrame * The resulting edges DataFrame contains one additional column: * - weight (`DoubleType`): the normalized weight of this edge after running PageRank */ -class PageRank private[graphframes] (private val graph: GraphFrame) extends Arguments { +class PageRank private[graphframes] (private val graph: GraphFrame) + extends Arguments + with Logging { private var tol: Option[Double] = None private var resetProb: Option[Double] = Some(0.15) @@ -93,13 +96,15 @@ class PageRank private[graphframes] (private val graph: GraphFrame) extends Argu } def run(): GraphFrame = { - tol match { + val res = tol match { case Some(t) => assert(maxIter.isEmpty, "You cannot specify maxIter() and tol() at the same time.") PageRank.runUntilConvergence(graph, t, resetProb.get, srcId) case None => PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId) } + resultIsPersistent() + res } } @@ -129,7 +134,13 @@ private object PageRank { val longSrcId = srcId.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.PageRank.runWithOptions(graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId) - GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) + val res = GraphXConversions + .fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) + .persist() + res.vertices.count() + res.edges.count() + gx.unpersist() + res } /** diff --git a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 7c8fea43a..2b5db23cd 100644 --- a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -19,6 +19,7 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx.{lib => graphxlib} import org.graphframes.GraphFrame +import org.graphframes.Logging import org.graphframes.WithMaxIter /** @@ -54,7 +55,8 @@ import org.graphframes.WithMaxIter */ class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame) extends Arguments - with WithMaxIter { + with WithMaxIter + with Logging { private var resetProb: Option[Double] = Some(0.15) private var srcIds: Array[Any] = Array() @@ -74,7 +76,9 @@ class ParallelPersonalizedPageRank private[graphframes] (private val graph: Grap def run(): GraphFrame = { require(maxIter != None, "Max number of iterations maxIter() must be provided") require(srcIds.nonEmpty, "Source vertices Ids sourceIds() must be provided") - ParallelPersonalizedPageRank.run(graph, maxIter.get, resetProb.get, srcIds) + val res = ParallelPersonalizedPageRank.run(graph, maxIter.get, resetProb.get, srcIds) + resultIsPersistent() + res } } @@ -114,6 +118,12 @@ private object ParallelPersonalizedPageRank { maxIter, resetProb, longSrcIds) - GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANKS), edgeNames = Seq(WEIGHT)) + val gf = GraphXConversions + .fromGraphX(graph, gx, vertexNames = Seq(PAGERANKS), edgeNames = Seq(WEIGHT)) + .persist() + gf.vertices.count() + gf.edges.count() + gx.unpersist() + gf } } diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index e5926d53f..fed8b1687 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -331,13 +331,18 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { updateExpr.as(colName) } + var lastRoundPersistent: scala.collection.mutable.Queue[DataFrame] = + scala.collection.mutable.Queue[DataFrame]() + var currentVertices = graph.vertices.select( (Seq( col("*"), initialActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) - var vertexUpdateColDF: DataFrame = null val edges = graph.edges + .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) + .repartition(col("edge_src"), col("edge_dst")) + .persist() var iteration = 1 @@ -357,12 +362,15 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { breakable { while (iteration <= maxIter) { logInfo(s"start Pregel iteration $iteration / $maxIter") + val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]() + currRoundPersistent.enqueue(currentVertices.persist()) var tripletsDF = currentVertices .select(struct(col("*")).as(SRC)) - .join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC)) + .join(edges, Pregel.src(ID) === col("edge_src")) .join( currentVertices.select(struct(col("*")).as(DST)), - Pregel.edge(DST) === Pregel.dst(ID)) + col("edge_dst") === Pregel.dst(ID)) + .drop(col("edge_src"), col("edge_dst")) if (skipMessagesFromNonActiveVertices) { tripletsDF = tripletsDF.filter( @@ -377,9 +385,10 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { if (earlyStopping && msgDF.isEmpty) { logInfo( s"there are no more non-null messages; Pregel stops earlier at iteration $iteration") - if (vertexUpdateColDF != null) { - vertexUpdateColDF.unpersist() + while (lastRoundPersistent.nonEmpty) { + lastRoundPersistent.dequeue().unpersist() } + lastRoundPersistent = currRoundPersistent break() } @@ -389,43 +398,58 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer") - var newVertexUpdateColDF = verticesWithMsg.select( + currentVertices = verticesWithMsg.select( (Seq( col(ID), updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) if (shouldCheckpoint && iteration % checkpointInterval == 0) { if (useLocalCheckpoints) { - newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false) + currentVertices = currentVertices.localCheckpoint(eager = false) + currRoundPersistent.enqueue(currentVertices) } else { - // do checkpoint, use lazy checkpoint because later we will materialize this DF. - newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) - // TODO: remove last checkpoint file. + currentVertices = currentVertices.checkpoint(eager = false) + currRoundPersistent.enqueue(currentVertices) } + } else { + // checkpointing do persistence and we do not need to do it again + currRoundPersistent.enqueue(currentVertices.persist()) } - newVertexUpdateColDF.cache() - newVertexUpdateColDF.count() // materialize it - - if (vertexUpdateColDF != null) { - vertexUpdateColDF.unpersist() - } - vertexUpdateColDF = newVertexUpdateColDF - - currentVertices = graph.vertices.join(vertexUpdateColDF, ID) if (stopIfAllNonActiveVertices) { if (currentVertices.filter(col(Pregel.ACTIVE_FLAG_COL)).isEmpty) { logInfo( s"all the verties are non-active; Pregel stops earlier at iteration $iteration") + while (lastRoundPersistent.nonEmpty) { + lastRoundPersistent.dequeue().unpersist() + } + lastRoundPersistent = currRoundPersistent break() } } + if (!earlyStopping && !stopIfAllNonActiveVertices) { + // we need to call materialize + currentVertices.count() + } + + while (lastRoundPersistent.nonEmpty) { + lastRoundPersistent.dequeue().unpersist() + } + lastRoundPersistent = currRoundPersistent + iteration += 1 } } - currentVertices + val res = currentVertices.persist() + res.count() + while (lastRoundPersistent.nonEmpty) { + lastRoundPersistent.dequeue().unpersist() + } + edges.unpersist() + System.gc() + res } } diff --git a/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index 093f8f0d8..97b445678 100644 --- a/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.graphframes.GraphFrame import org.graphframes.GraphFramesUnreachableException +import org.graphframes.Logging import org.graphframes.WithMaxIter /** @@ -42,7 +43,8 @@ import org.graphframes.WithMaxIter */ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments - with WithMaxIter { + with WithMaxIter + with Logging { private var _rank: Int = 10 private var _minVal: Double = 0.0 private var _maxVal: Double = 5.0 @@ -101,6 +103,7 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) val (df, l) = SVDPlusPlus.run(graph, conf) _loss = Some(l) + resultIsPersistent() df } @@ -122,7 +125,10 @@ object SVDPlusPlus { graph, gx, vertexNames = Seq(COLUMN1, COLUMN2, COLUMN3, COLUMN4)) - (gf.vertices, res) + val vertices = gf.vertices.persist() + vertices.count() + gx.unpersist() + (vertices, res) } /** diff --git a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala index bd4df94e4..189a1c516 100644 --- a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -80,7 +80,7 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) def run(): DataFrame = { val lmarksChecked = check(lmarks, "landmarks") - algorithm match { + val res = algorithm match { case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) case ALGO_GRAPHFRAMES => runInGraphFrames( @@ -90,6 +90,8 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) useLocalCheckpoints = useLocalCheckpoints) case _ => throw new GraphFramesUnreachableException() } + resultIsPersistent() + res } } diff --git a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index 2ee678f74..fbcd6242a 100644 --- a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -21,6 +21,7 @@ import org.apache.spark.graphframes.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame +import org.graphframes.Logging import org.graphframes.WithMaxIter /** @@ -32,10 +33,13 @@ import org.graphframes.WithMaxIter */ class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments - with WithMaxIter { + with WithMaxIter + with Logging { def run(): DataFrame = { - StronglyConnectedComponents.run(graph, check(maxIter, "maxIter")) + val res = StronglyConnectedComponents.run(graph, check(maxIter, "maxIter")) + resultIsPersistent() + res } } diff --git a/core/src/main/scala/org/graphframes/mixins.scala b/core/src/main/scala/org/graphframes/mixins.scala index d7c2f0e07..7fefc25df 100644 --- a/core/src/main/scala/org/graphframes/mixins.scala +++ b/core/src/main/scala/org/graphframes/mixins.scala @@ -66,14 +66,23 @@ private[graphframes] trait WithBroadcastThreshold extends Logging { protected var broadcastThreshold: Int = 1000000 /** - * Sets broadcast threshold in propagating component assignments (default: 1000000). If a node - * degree is greater than this threshold at some iteration, its component assignment will be - * collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise, - * the assignment propagation is done by a normal Spark join. This parameter is only used when - * the algorithm is set to "graphframes". + * Sets a broadcast threshold in propagating component assignments (default: 1,000,000). If a + * node degree is greater than this threshold at some iteration, its component assignment will + * be collected and then broadcasted back to propagate the assignment to its neighbors. + * Otherwise, the assignment propagation is done by a normal Spark join. This parameter is only + * used when the algorithm is set to "graphframes". If the value is -1, then the skewness + * problem is left to the Apache Spark AQE optimizer. + * + * **WARNING** using a broadcast threshold is non-free! Under the hood it is calling an action, + * and if a broadcast threshold is set, then AQE is disabled to avoid wrong results! If your + * graph does not contain gigantic components, it is strongly recommended to set this value to + * -1. On benchmarks setting it to -1 gains about x5 better results in performance. + * + * **WARNING** the current default value is 1,000,000. It is left for backward compatibility + * only. In the future versions it may be set to -1 as more reasonable for the most real-world + * cases (e.g., the data deduplication problem). */ def setBroadcastThreshold(value: Int): this.type = { - require(value >= 0, s"Broadcast threshold must be non-negative but got $value.") broadcastThreshold = value this } diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index e91d8bf1c..4e22c70e6 100644 --- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -28,210 +28,232 @@ import org.graphframes.GraphFrame._ import org.graphframes._ import org.graphframes.examples.Graphs -import java.io.IOException import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { + test("default params") { + val g = Graphs.empty[Int] + val cc = g.connectedComponents + assert(cc.getAlgorithm === "graphframes") + assert(cc.getBroadcastThreshold === 1000000) + assert(cc.getCheckpointInterval === 2) + assert(!cc.getUseLabelsAsComponents) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - Seq(true, false).foreach(useLocalCheckpoint => { - test(s"default params${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val g = Graphs.empty[Int] - val cc = g.connectedComponents - assert(cc.getAlgorithm === "graphframes") - assert(cc.getBroadcastThreshold === 1000000) - assert(cc.getCheckpointInterval === 2) - assert(!cc.getUseLabelsAsComponents) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } - - test(s"empty graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) { - val components = empty.connectedComponents.run() - assert(components.count() === 0L) - } - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test("using labels as components") { + spark.conf.set("spark.graphframes.useLabelsAsComponents", "true") + val vertices = + spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet + assertComponents(components, expected) + components.unpersist() + spark.conf.set("spark.graphframes.useLabelsAsComponents", "false") + } - test(s"single vertex${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender") - // Create an empty dataframe with the proper columns. - val e = spark - .createDataFrame(List((0L, 0L, 1L))) - .toDF("src", "dst", "test") - .filter("src > 10") - val g = GraphFrame(v, e) - val comps = ConnectedComponents.run(g) - TestUtils.testSchemaInvariants(g, comps) - TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType) - assert(comps.count() === 1) - assert( - comps.select("id", "component", "vattr", "gender").collect() - === Seq(Row(0L, 0L, "a", "b"))) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test("don't using labels as components") { + val vertices = + spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + assert(components.schema("component").dataType == LongType) + components.unpersist() + } - test(s"disconnected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val n = 5L - val vertices = spark.range(n).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = (0L until n).map(Set(_)).toSet + test("friends graph with different broadcast thresholds") { + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + for ((algorithm, broadcastThreshold) <- + Seq( + ("graphx", 1000000), + ("graphframes", 100000), + ("graphframes", 1), + ("graphframes", -1))) { + val components = friends.connectedComponents + .setAlgorithm(algorithm) + .setBroadcastThreshold(broadcastThreshold) + .run() assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + components.unpersist() } + } - test( - s"using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - spark.conf.set("spark.graphframes.useLabelsAsComponents", "true") - val vertices = - spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLabelsAsComponents", "false") - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + Seq(true, false).foreach(useSkewedJoin => { + Seq(true, false).foreach(useLocalCheckpoint => { + val testPostfixName = s"${if (useLocalCheckpoint) " with local checkpoint" + else ""}${if (useSkewedJoin) ", skewed join" else ", AQE join"}" + val broadcastThreshold = if (useSkewedJoin) 1000000 else -1 + + test(s"empty graph$testPostfixName") { + for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) { + val components = empty.connectedComponents + .setBroadcastThreshold(broadcastThreshold) + .setUseLocalCheckpoints(useLocalCheckpoint) + .run() + assert(components.count() === 0L) + components.unpersist() + } + } - test( - s"don't using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val vertices = - spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - assert(components.schema("component").dataType == LongType) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"single vertex$testPostfixName") { + val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender") + // Create an empty dataframe with the proper columns. + val e = spark + .createDataFrame(List((0L, 0L, 1L))) + .toDF("src", "dst", "test") + .filter("src > 10") + val g = GraphFrame(v, e) + val comps = g.connectedComponents + .setBroadcastThreshold(broadcastThreshold) + .setUseLocalCheckpoints(useLocalCheckpoint) + .run() + TestUtils.testSchemaInvariants(g, comps) + TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType) + assert(comps.count() === 1) + assert( + comps.select("id", "component", "vattr", "gender").collect() + === Seq(Row(0L, 0L, "a", "b"))) + comps.unpersist() + } - test(s"two connected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B") - val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") - val g = GraphFrame(v, e) - val comps = g.connectedComponents.run() - TestUtils.testSchemaInvariants(g, comps) - assert(comps.count() === 2) - val vxs = comps.sort("id").select("id", "component", "A", "B").collect() - assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"disconnected vertices$testPostfixName") { + val n = 5L + val vertices = spark.range(n).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = (0L until n).map(Set(_)).toSet + assertComponents(components, expected) + components.unpersist() + } - test(s"chain graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val n = 5L - val g = Graphs.chain(5L) - val components = g.connectedComponents.run() - val expected = Set((0L until n).toSet) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"two connected vertices$testPostfixName") { + val v = + spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B") + val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") + val g = GraphFrame(v, e) + val comps = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + TestUtils.testSchemaInvariants(g, comps) + assert(comps.count() === 2) + val vxs = comps.sort("id").select("id", "component", "A", "B").collect() + assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs) + comps.unpersist() + } - test(s"star graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val n = 5L - val g = Graphs.star(5L) - val components = g.connectedComponents.run() - val expected = Set((0L to n).toSet) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"chain graph$testPostfixName") { + val n = 5L + val g = Graphs.chain(5L) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = Set((0L until n).toSet) + assertComponents(components, expected) + components.unpersist() + } - test(s"two blobs${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val n = 5L - val g = Graphs.twoBlobs(n.toInt) - val components = g.connectedComponents.run() - val expected = Set((0L until 2 * n).toSet) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"star graph$testPostfixName") { + val n = 5L + val g = Graphs.star(5L) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = Set((0L to n).toSet) + assertComponents(components, expected) + components.unpersist() + } - test(s"two components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val vertices = spark.range(6L).toDF(ID) - val edges = spark - .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L)) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"two blobs$testPostfixName") { + val n = 5L + val g = Graphs.twoBlobs(n.toInt) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = Set((0L until 2 * n).toSet) + assertComponents(components, expected) + components.unpersist() + } - test( - s"one component, differing edge directions${if (useLocalCheckpoint) " with local checkpoint" - else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val vertices = spark.range(5L).toDF(ID) - val edges = spark - .createDataFrame( - Seq( - // 0 -> 4 -> 3 <- 2 -> 1 - (0L, 4L), - (4L, 3L), - (2L, 3L), - (2L, 1L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set((0L to 4L).toSet) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"two components$testPostfixName") { + val vertices = spark.range(6L).toDF(ID) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L)) + assertComponents(components, expected) + components.unpersist() + } - test( - s"two components and two dangling vertices${if (useLocalCheckpoint) " with local checkpoint" - else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val vertices = spark.range(8L).toDF(ID) - val edges = spark - .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L)) - assertComponents(components, expected) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"one component, differing edge directions$testPostfixName") { + val vertices = spark.range(5L).toDF(ID) + val edges = spark + .createDataFrame( + Seq( + // 0 -> 4 -> 3 <- 2 -> 1 + (0L, 4L), + (4L, 3L), + (2L, 3L), + (2L, 1L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + val expected = Set((0L to 4L).toSet) + assertComponents(components, expected) + components.unpersist() + } - test(s"friends graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val friends = Graphs.friends - val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - for ((algorithm, broadcastThreshold) <- - Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) { - val components = friends.connectedComponents - .setAlgorithm(algorithm) + test(s"two components and two dangling vertices$testPostfixName") { + val vertices = spark.range(8L).toDF(ID) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) .setBroadcastThreshold(broadcastThreshold) .run() + val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L)) assertComponents(components, expected) + components.unpersist() } - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } - test(s"really large long IDs${if (useLocalCheckpoint) " with local checkpoint" else ""}") { - spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) - val max = Long.MaxValue - val chain = examples.Graphs.chain(10L) - val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID)) - val edges = chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST)) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - assert(components.count() === 10L) - assert(components.groupBy("component").count().count() === 1L) - spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") - } + test(s"really large long IDs$testPostfixName") { + val max = Long.MaxValue + val chain = examples.Graphs.chain(10L) + val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID)) + val edges = + chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST)) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents + .setUseLocalCheckpoints(useLocalCheckpoint) + .setBroadcastThreshold(broadcastThreshold) + .run() + assert(components.count() === 10L) + assert(components.groupBy("component").count().count() === 1L) + components.unpersist() + } + }) }) test("set configuration from spark conf") { @@ -255,83 +277,44 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon spark.conf.unset("spark.graphframes.connectedComponents.intermediatestoragelevel") } - test("checkpoint interval") { + test("intermediate storage level") { val friends = Graphs.friends val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - val cc = new ConnectedComponents(friends) - assert( - cc.getCheckpointInterval === 2, - s"Default checkpoint interval should be 2, but got ${cc.getCheckpointInterval}.") - - val checkpointDir = sc.getCheckpointDir - assert(checkpointDir.nonEmpty) - - sc.setCheckpointDir(null) - withClue( - "Should throw an IOException if sc.getCheckpointDir is empty " + - "and checkpointInterval is positive.") { - intercept[IOException] { - cc.run() - } - } + val cc = friends.connectedComponents + assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) - // Checks whether the input DataFrame is from some checkpoint data. - // TODO: The implemetnation is a little hacky. - def isFromCheckpoint(df: DataFrame): Boolean = { - df.queryExecution.logical.toString().toLowerCase.contains("parquet") + for (storageLevel <- Seq( + StorageLevel.DISK_ONLY, + StorageLevel.MEMORY_ONLY, + StorageLevel.NONE)) { + val components = cc + .setIntermediateStorageLevel(storageLevel) + .run() + assertComponents(components, expected) + components.unpersist() + () } - - val components0 = cc.setCheckpointInterval(0).run() - assertComponents(components0, expected) - assert( - !isFromCheckpoint(components0), - "The result shouldn't depend on checkpoint data if checkpointing is disabled.") - - sc.setCheckpointDir(checkpointDir.get) - - val components1 = cc.setCheckpointInterval(1).run() - assertComponents(components1, expected) - assert( - isFromCheckpoint(components1), - "The result should depend on checkpoint data if checkpoint interval is 1.") - - val components10 = cc.setCheckpointInterval(10).run() - assertComponents(components10, expected) - assert( - !isFromCheckpoint(components10), - "The result shouldn't depend on checkpoint data if converged before first checkpoint.") } - test("intermediate storage level") { - // disabling adaptive query execution helps assertComponents - val enabled = spark.conf.getOption("spark.sql.adaptive.enabled") - try { - spark.conf.set("spark.sql.adaptive.enabled", value = false) - - val friends = Graphs.friends - val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - - val cc = friends.connectedComponents - assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) - - for (storageLevel <- Seq( - StorageLevel.DISK_ONLY, - StorageLevel.MEMORY_ONLY, - StorageLevel.NONE)) { - // TODO: it is not trivial to confirm the actual storage level used - val components = cc - .setIntermediateStorageLevel(storageLevel) - .run() - assertComponents(components, expected) - } - } finally { - // restoring earlier conf - if (enabled.isDefined) { - spark.conf.set("spark.sql.adaptive.enabled", value = enabled.get) - } else { - spark.conf.unset("spark.sql.adaptive.enabled") - } + test("intermediate storage level without skewedJoin") { + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + + val cc = friends.connectedComponents + assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) + + for (storageLevel <- Seq( + StorageLevel.DISK_ONLY, + StorageLevel.MEMORY_ONLY, + StorageLevel.NONE)) { + val components = cc + .setIntermediateStorageLevel(storageLevel) + .setBroadcastThreshold(-1) + .run() + assertComponents(components, expected) + components.unpersist() + () } } From cd7f60f26f4eb5531773e471c357558117cc744d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 17 Sep 2025 10:02:46 +0200 Subject: [PATCH 2/3] fixes --- .../org/graphframes/lib/ConnectedComponents.scala | 11 +++-------- core/src/main/scala/org/graphframes/lib/Pregel.scala | 12 ++++++------ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 6f81904b6..77db635ca 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -119,11 +119,8 @@ object ConnectedComponents extends Logging { * `dst`. */ private def prepare(graph: GraphFrame): GraphFrame = { - // TODO: This assignment job might fail if the graph is skewed. val vertices = graph.indexedVertices .select(col(LONG_ID).as(ID), col(ATTR)) - // TODO: confirm the contract for a graph and decide whether we need distinct here - // .distinct() val edges = graph.indexedEdges .select(col(LONG_SRC).as(SRC), col(LONG_DST).as(DST)) val orderedEdges = edges @@ -327,13 +324,11 @@ object ConnectedComponents extends Logging { } else { ee = ee.checkpoint(eager = true) } - currRoundPersistedDFs = currRoundPersistedDFs :+ ee - } else { - // Checkpointing includes persist under the hood, no needs to do it again - ee = ee.persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ ee } + ee = ee.persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ ee + minNbrs1 = minNbrs( ee, computeCount = shouldDoSkewedJoin, diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index fed8b1687..db6b00665 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -334,10 +334,11 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { var lastRoundPersistent: scala.collection.mutable.Queue[DataFrame] = scala.collection.mutable.Queue[DataFrame]() + val initialAttributes = graph.vertices.columns.map(col).toSeq + var currentVertices = graph.vertices.select( - (Seq( - col("*"), - initialActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) + ((initialAttributes :+ initialActiveVertexExpression.alias( + Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) @@ -399,9 +400,8 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer") currentVertices = verticesWithMsg.select( - (Seq( - col(ID), - updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) + ((initialAttributes :+ updateActiveVertexExpression.alias( + Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) if (shouldCheckpoint && iteration % checkpointInterval == 0) { if (useLocalCheckpoints) { From 807f286ee93ae07d992322eefca37ff956fc43fd Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 17 Sep 2025 14:17:00 +0200 Subject: [PATCH 3/3] StorageLevel.NONE is not working for AQE; most probably the reason is there is no checkpointing... But I'm not sure. Anyway, I don't see a blocker even if it is not working. The defaults are MEMORY_ONLY, so why does it matter? --- .../lib/ConnectedComponentsSuite.scala | 57 ++++++++----------- 1 file changed, 23 insertions(+), 34 deletions(-) diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 4e22c70e6..dbe097cab 100644 --- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -277,46 +277,35 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon spark.conf.unset("spark.graphframes.connectedComponents.intermediatestoragelevel") } - test("intermediate storage level") { - val friends = Graphs.friends - val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - - val cc = friends.connectedComponents - assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) - - for (storageLevel <- Seq( - StorageLevel.DISK_ONLY, - StorageLevel.MEMORY_ONLY, - StorageLevel.NONE)) { - val components = cc - .setIntermediateStorageLevel(storageLevel) - .run() - assertComponents(components, expected) - components.unpersist() - () - } - } - - test("intermediate storage level without skewedJoin") { - val friends = Graphs.friends - val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_ONLY, StorageLevel.NONE).foreach( + storageLevel => { + test(s"intermediate storage level $storageLevel") { + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + + val components = + friends.connectedComponents.setIntermediateStorageLevel(storageLevel).run() + assertComponents(components, expected) + components.unpersist() + () + } + }) - val cc = friends.connectedComponents - assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) + Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_ONLY).foreach(storageLevel => { + test(s"intermediate storage level without skewedJoin $storageLevel") { + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - for (storageLevel <- Seq( - StorageLevel.DISK_ONLY, - StorageLevel.MEMORY_ONLY, - StorageLevel.NONE)) { - val components = cc - .setIntermediateStorageLevel(storageLevel) - .setBroadcastThreshold(-1) - .run() + val components = + friends.connectedComponents + .setIntermediateStorageLevel(storageLevel) + .setBroadcastThreshold(-1) + .run() assertComponents(components, expected) components.unpersist() () } - } + }) test("not leaking cached data") { val priorCachedDFsSize = spark.sparkContext.getPersistentRDDs.size