diff --git a/benchmarks/src/main/scala/org/graphframes/benchmarks/ConnectedComponentsBenchmark.scala b/benchmarks/src/main/scala/org/graphframes/benchmarks/ConnectedComponentsBenchmark.scala index 9d1946b9b..1eebfa72e 100644 --- a/benchmarks/src/main/scala/org/graphframes/benchmarks/ConnectedComponentsBenchmark.scala +++ b/benchmarks/src/main/scala/org/graphframes/benchmarks/ConnectedComponentsBenchmark.scala @@ -35,7 +35,7 @@ class ConnectedComponentsBenchmark extends LDBCBenchmarkBase { } else { graph.connectedComponents .setUseLocalCheckpoints(true) - .setAlgorithm("graphframes") + .setAlgorithm(algorithm) .setBroadcastThreshold(broadcastThreshold.toInt) .setUseLocalCheckpoints(useLocalCheckpoints.toBoolean) .run() diff --git a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala index 70f2f0ca3..a53eeeaec 100644 --- a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala +++ b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala @@ -32,9 +32,11 @@ object GraphFramesConf { SQLConf .buildConf("spark.graphframes.connectedComponents.algorithm") .doc(""" Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms - | - "graphframes": Uses alternating large star and small star iterations proposed in + | - "two_phase": Uses alternating large star and small star iterations proposed in | [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]] - | with skewed join optimization. + | - "randomized_contraction": Uses randomized algorithm proposed in + | [[https://arxiv.org/pdf/1802.09478 In-database connected component analysis]] + | - "graphframes": Deprecated alias for "two_phase" | - "graphx": Converts the graph to a GraphX graph and then uses the connected components | implementation in GraphX. 
| @see org.graphframes.lib.ConnectedComponents.supportedAlgorithms""".stripMargin) diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index ebedf8fbf..458542758 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -17,17 +17,13 @@ package org.graphframes.lib -import org.apache.hadoop.fs.Path import org.apache.spark.graphframes.graphx -import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.* import org.apache.spark.sql.graphframes.GraphFramesConf -import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame +import org.graphframes.GraphFramesUnreachableException import org.graphframes.Logging -import org.graphframes.WithAlgorithmChoice import org.graphframes.WithBroadcastThreshold import org.graphframes.WithCheckpointInterval import org.graphframes.WithIntermediateStorageLevel @@ -35,10 +31,6 @@ import org.graphframes.WithLocalCheckpoints import org.graphframes.WithMaxIter import org.graphframes.WithUseLabelsAsComponents -import java.io.IOException -import java.math.BigDecimal -import java.util.UUID - /** * Connected Components algorithm. 
* @@ -51,7 +43,6 @@ import java.util.UUID class ConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments with Logging - with WithAlgorithmChoice with WithCheckpointInterval with WithBroadcastThreshold with WithIntermediateStorageLevel @@ -59,7 +50,13 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) with WithMaxIter with WithLocalCheckpoints { - setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES)) + import ConnectedComponents._ + + private var algorithm: String = GraphFramesConf.getConnectedComponentsAlgorithm + .getOrElse(ALGO_TWO_PHASE) + + private var isGraphPrepared: Boolean = false + setCheckpointInterval( GraphFramesConf.getConnectedComponentsCheckpointInterval.getOrElse(checkpointInterval)) setBroadcastThreshold( @@ -71,322 +68,155 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) setUseLocalCheckpoints(GraphFramesConf.getUseLocalCheckpoints.getOrElse(useLocalCheckpoints)) /** - * Runs the algorithm. - */ - def run(): DataFrame = { - ConnectedComponents.run( - graph, - runInGraphX = algorithm == ALGO_GRAPHX, - broadcastThreshold = broadcastThreshold, - checkpointInterval = checkpointInterval, - intermediateStorageLevel = intermediateStorageLevel, - useLabelsAsComponents = useLabelsAsComponents, - maxIter = maxIter, - useLocalCheckpoints = useLocalCheckpoints) - } -} - -object ConnectedComponents extends Logging { - - import org.graphframes.GraphFrame.* - - private[graphframes] val COMPONENT = "component" - private[graphframes] val ORIG_ID = "orig_id" - private val MIN_NBR = "min_nbr" - private val CNT = "cnt" - private val CHECKPOINT_NAME_PREFIX = "connected-components" - - /** - * Returns the symmetric directed graph of the graph specified by input edges. - * @param ee - * non-bidirectional edges + * Sets the algorithm to use for computing connected components. 
/**
 * Sets the algorithm to use for computing connected components.
 *
 * The name is matched case-insensitively. Supported values:
 *   - [[ConnectedComponents.ALGO_GRAPHX]]: use the GraphX implementation
 *   - [[ConnectedComponents.ALGO_TWO_PHASE]]: use the two-phase label propagation
 *     implementation
 *   - [[ConnectedComponents.ALGO_RANDOMIZED_CONTRACTION]]: use the randomized contraction
 *     implementation
 *   - [[ConnectedComponents.ALGO_GRAPHFRAMES]]: deprecated alias; a warning is logged and
 *     [[ConnectedComponents.ALGO_TWO_PHASE]] is stored instead
 *
 * @param value
 *   name of the algorithm
 * @return
 *   this instance, so that setters can be chained
 * @throws IllegalArgumentException
 *   if `value` is `null` or is not one of the supported names
 */
def setAlgorithm(value: String): this.type = {
  // Locale.ROOT keeps the comparison independent of the JVM default locale: with a Turkish
  // default locale, "RANDOMIZED_CONTRACTION".toLowerCase does not produce an ASCII 'i' and the
  // name would be rejected. Option(...) routes a null argument to the regular "unsupported"
  // error below instead of a NullPointerException.
  Option(value).map(_.toLowerCase(java.util.Locale.ROOT)) match {
    case Some(name @ (ALGO_GRAPHX | ALGO_TWO_PHASE | ALGO_RANDOMIZED_CONTRACTION)) =>
      algorithm = name
    case Some(ALGO_GRAPHFRAMES) =>
      logWarn(
        s"Algorithm '$ALGO_GRAPHFRAMES' is deprecated and will be removed in a future release. " +
          s"Using '$ALGO_TWO_PHASE' instead.")
      algorithm = ALGO_TWO_PHASE
    case _ =>
      throw new IllegalArgumentException(
        s"Unsupported algorithm: '$value'. " +
          s"Supported values are: $ALGO_GRAPHX, $ALGO_TWO_PHASE, " +
          s"$ALGO_RANDOMIZED_CONTRACTION, $ALGO_GRAPHFRAMES (deprecated).")
  }
  this
}
/**
 * Returns the name of the algorithm currently selected for computing connected components.
 *
 * @return
 *   the stored algorithm name
 */
def getAlgorithm: String = this.algorithm
/**
 * !! WARNING: INTERNAL API -- FOR VERY EXPERIENCED USERS ONLY !!
 *
 * Declares whether the input graph has already been prepared by the caller. The flag is stored
 * as-is and handed to the selected implementation when the algorithm runs; nothing in this
 * setter inspects or validates the graph. The default is `false`, which lets the algorithm
 * prepare the graph itself and is the safe choice.
 *
 * Only pass `true` if you have replicated the exact preparation logic of the algorithm you are
 * going to run. Preparation differs between algorithms, and a graph that was prepared
 * incorrectly produces wrong results without any error.
 *
 * A warning is logged on every call, whatever the value.
 *
 * @param value
 *   true if the graph is already prepared, false otherwise (default: false)
 * @return
 *   this instance, so that setters can be chained
 */
def setIsGraphPrepared(value: Boolean): this.type = {
  val warning =
    "INTERNAL API ONLY WAS CALLED. This is an internal option for advanced users who fully " +
      "understand graph preparation internals. Misuse will produce silently wrong results."
  logWarn(warning)
  isGraphPrepared = value
  this
}
/**
 * Runs the selected connected components implementation on the graph.
 *
 * Dispatch:
 *   - [[ConnectedComponents.ALGO_GRAPHX]]: `ConnectedComponents.runGraphX`, with `maxIter`
 *     defaulting to `Int.MaxValue` when unset
 *   - [[ConnectedComponents.ALGO_TWO_PHASE]] (and its deprecated alias
 *     [[ConnectedComponents.ALGO_GRAPHFRAMES]]): `TwoPhase.runAQE` when `broadcastThreshold`
 *     is `-1`, otherwise `TwoPhase.run`
 *   - [[ConnectedComponents.ALGO_RANDOMIZED_CONTRACTION]]: `RandomizedContraction.run`
 *
 * @return
 *   the DataFrame produced by the selected implementation
 */
def run(): DataFrame = {
  // Names set through setAlgorithm are already canonical, but the initial value is read
  // directly from GraphFramesConf and never passes through the setter. Normalise here so that
  // a configured "graphframes" alias (or a differently-cased name) is still dispatched.
  // NOTE(review): GraphFramesConf may validate the value itself -- not visible from this file.
  val selected = algorithm.toLowerCase(java.util.Locale.ROOT)
  selected match {
    case ALGO_GRAPHX =>
      ConnectedComponents.runGraphX(
        graph,
        maxIter.getOrElse(Int.MaxValue),
        intermediateStorageLevel)
    case ALGO_TWO_PHASE | ALGO_GRAPHFRAMES =>
      // A threshold of -1 selects the variant that relies on AQE instead of skewed joins.
      if (broadcastThreshold == -1) {
        TwoPhase.runAQE(
          graph,
          checkpointInterval = checkpointInterval,
          intermediateStorageLevel = intermediateStorageLevel,
          useLabelsAsComponents = useLabelsAsComponents,
          useLocalCheckpoints = useLocalCheckpoints,
          isGraphPrepared = isGraphPrepared)
      } else {
        TwoPhase.run(
          graph,
          broadcastThreshold = broadcastThreshold,
          checkpointInterval = checkpointInterval,
          intermediateStorageLevel = intermediateStorageLevel,
          useLabelsAsComponents = useLabelsAsComponents,
          useLocalCheckpoints = useLocalCheckpoints,
          isGraphPrepared = isGraphPrepared)
      }
    case ALGO_RANDOMIZED_CONTRACTION =>
      RandomizedContraction.run(
        graph,
        useLabelsAsComponents = useLabelsAsComponents,
        intermediateStorageLevel = intermediateStorageLevel,
        useLocalCheckpoints = useLocalCheckpoints,
        checkpointInterval = checkpointInterval,
        isGraphPrepared = isGraphPrepared)
    // Values coming from the setter are checked there; only an unsupported initial
    // configuration value can end up here.
    case _ => throw new GraphFramesUnreachableException()
  }
}
/**
 * Builds a new [[ConnectedComponents]] for `graph` and runs it with that instance's initial
 * settings.
 *
 * NOTE(review): the result depends only on the `graph` argument, not on the receiver. Before
 * this change the method was defined on the companion object -- confirm that exposing it as an
 * instance method is intended.
 *
 * @param graph
 *   graph to compute connected components for
 * @return
 *   the DataFrame returned by `run()` of the new instance
 */
@deprecated("use graph.connectedComponents instead", "0.11.0")
def run(graph: GraphFrame): DataFrame = new ConnectedComponents(graph).run()
Please set it first using sc.setCheckpointDir()" + - "or by specifying the conf 'spark.checkpoint.dir'.") - } - } - logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") - Some(dir) - } else { - logInfo( - s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.") - None - } - - logInfo(s"$logPrefix Preparing the graph for connected component computation ...") - val g = prepare(graph) - val vv = g.vertices - var ee = g.edges.persist(intermediateStorageLevel) // src < dst - logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.") - - var converged = false - var iteration = 1 - - def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { - // Taking the sum in DecimalType to preserve precision. - // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. - // It should be able to handle 200 billion edges without overflow. - val (minNbrSum, cnt) = minNbrsDF - .select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")) - .rdd - .map { r => - (r.getAs[BigDecimal](0), r.getLong(1)) - } - .first() - if (cnt != 0L && minNbrSum == null) { - throw new ArithmeticException(s""" - |The total sum of edge src IDs is used to determine convergence during iterations. - |However, the total sum at iteration $iteration exceeded 30 digits (1e30), - |which should happen only if the graph contains more than 200 billion edges. - |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. 
- """.stripMargin) - } - minNbrSum - } - // compute min neighbors (including self-min) - var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr - .persist(intermediateStorageLevel) - - var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1) - - var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) - while (!converged) { - var currRoundPersistedDFs = Seq[DataFrame]() - // large-star step - // connect all strictly larger neighbors to the min neighbor (including self) - ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) - .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst - .distinct() - .persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ ee - - // small-star step - // compute min neighbors (excluding self-min) - val minNbrs2 = ee - .groupBy(col(SRC)) - .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr - .persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 - - // connect all smaller neighbors to the min neighbor - ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) - .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst - .filter(col(SRC) =!= col(DST)) // src < dst - // connect self to the min neighbor - ee = ee - .union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst - .distinct() - - // checkpointing - if (shouldCheckpoint && (iteration % checkpointInterval == 0)) { - if (useLocalCheckpoints) { - ee = ee.localCheckpoint(eager = true) - } else { - // TODO: remove this after DataFrame.checkpoint is implemented - val out = s"${checkpointDir.get}/$iteration" - ee.write.parquet(out) - // may hit S3 eventually consistent issue - ee = spark.read.parquet(out) - - // remove previous checkpoint - if (iteration > checkpointInterval) { - val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") - path.getFileSystem(sc.hadoopConfiguration).delete(path, true) - } - - System.gc() // hint Spark to clean shuffle 
/**
 * Companion of the [[ConnectedComponents]] class: shared column names, the algorithm
 * identifiers accepted by `setAlgorithm`, and the GraphX-backed runner.
 */
object ConnectedComponents extends Logging {

  /** Name of the column that carries the component assignment. */
  private[graphframes] val COMPONENT = "component"

  /** Name of the column that carries an original vertex id. */
  private[graphframes] val ORIG_ID = "orig_id"

  /** Identifier of the GraphX implementation. */
  val ALGO_GRAPHX = "graphx"

  /** Identifier of the two-phase label propagation implementation. */
  val ALGO_TWO_PHASE = "two_phase"

  /** Identifier of the randomized contraction implementation. */
  val ALGO_RANDOMIZED_CONTRACTION = "randomized_contraction"

  /**
   * Legacy identifier kept for backward compatibility.
   *
   * @deprecated
   *   Use [[ALGO_TWO_PHASE]] instead.
   */
  val ALGO_GRAPHFRAMES = "graphframes"

  /**
   * Runs the GraphX connected components implementation and returns the vertices with a
   * [[COMPONENT]] column.
   *
   * The returned DataFrame is persisted with `intermediateStorageLevel` and materialised by a
   * `count()` before the GraphX-side data is unpersisted.
   *
   * @param graph
   *   input graph
   * @param maxIter
   *   maximum number of iterations handed to GraphX
   * @param intermediateStorageLevel
   *   storage level used to persist the result
   * @return
   *   the persisted vertex DataFrame with component assignments
   */
  private[graphframes] def runGraphX(
      graph: GraphFrame,
      maxIter: Int,
      intermediateStorageLevel: StorageLevel): DataFrame = {
    val topology = graph.cachedTopologyGraphX
    val labelled = graphx.lib.ConnectedComponents.run(topology, maxIter)
    val converted = GraphXConversions.fromGraphX(graph, labelled, vertexNames = Seq(COMPONENT))
    val vertices = converted.vertices.persist(intermediateStorageLevel)

    // An action is needed to load the cache before the GraphX inputs are released.
    val _ = vertices.count()
    topology.unpersist()
    labelled.unpersist()

    resultIsPersistent()

    vertices
  }
}
useLocalCheckpoints: Boolean, + checkpointInterval: Int, isGraphPrepared: Boolean): DataFrame = { val spark = inputGraph.vertices.sparkSession val sc = spark.sparkContext @@ -140,15 +142,26 @@ private[graphframes] object RandomizedContraction extends Logging with Serializa // save ref to unpersist val oldEdges = edges - edges = edges2 - .alias("e") - .join( - ccRepresentatives.alias("r2"), - col(s"e.$DST") === col("r2.v") && - col(s"e.$SRC") =!= col("r2.rep")) - .select(col(s"e.$SRC").alias(SRC), col("r2.rep").alias(DST)) - .distinct() - .persist(intermediateStorageLevel) + edges = { + val te = edges2 + .alias("e") + .join( + ccRepresentatives.alias("r2"), + col(s"e.$DST") === col("r2.v") && + col(s"e.$SRC") =!= col("r2.rep")) + .select(col(s"e.$SRC").alias(SRC), col("r2.rep").alias(DST)) + .distinct() + + if ((iter > 0) && (iter % checkpointInterval == 0)) { + if (useLocalCheckpoints) { + te.localCheckpoint() + } else { + te.checkpoint() + } + } else { + te.persist(intermediateStorageLevel) + } + } graphSize = edges.count() oldEdges.unpersist() @@ -160,6 +173,8 @@ private[graphframes] object RandomizedContraction extends Logging with Serializa var accA = 1L var accB = 0L + edges.unpersist(true) + while (iter > 1) { iter -= 1 val poppedA = stackA.pop() @@ -185,6 +200,8 @@ private[graphframes] object RandomizedContraction extends Logging with Serializa .persist(intermediateStorageLevel) result.write.mode("overwrite").parquet(ccRepsR) + + result.unpersist() val oldPath = new Path(ccRepsR1) val fs = oldPath.getFileSystem(sc.hadoopConfiguration) @@ -236,7 +253,6 @@ private[graphframes] object RandomizedContraction extends Logging with Serializa outputComponents.count() // clean-up - edges.unpersist() val chDirPath = new Path(checkpointDir) val fs = chDirPath.getFileSystem(sc.hadoopConfiguration) if (fs.exists(chDirPath)) { @@ -245,6 +261,8 @@ private[graphframes] object RandomizedContraction extends Logging with Serializa outputComponents } finally { + // to be 100% 
sure; + edges.unpersist() val dereg = functionRegistry.dropFunction(FunctionIdentifier("_axpb")) if (!dereg) { logWarn( diff --git a/core/src/main/scala/org/graphframes/lib/TwoPhase.scala b/core/src/main/scala/org/graphframes/lib/TwoPhase.scala new file mode 100644 index 000000000..7ce9e5872 --- /dev/null +++ b/core/src/main/scala/org/graphframes/lib/TwoPhase.scala @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.graphframes.lib + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.* +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.storage.StorageLevel +import org.graphframes.GraphFrame +import org.graphframes.GraphFrame.ATTR +import org.graphframes.GraphFrame.DST +import org.graphframes.GraphFrame.ID +import org.graphframes.GraphFrame.LONG_DST +import org.graphframes.GraphFrame.LONG_ID +import org.graphframes.GraphFrame.LONG_SRC +import org.graphframes.GraphFrame.SRC +import org.graphframes.Logging + +import java.io.IOException +import java.math.BigDecimal +import java.util.UUID + +/** + * Two-phase label propagation implementation of connected components. 
+ * + * This is the primary GraphFrames-native implementation. It iteratively applies large-star and + * small-star steps until convergence, using checkpointing to manage query plan growth. + */ +private[graphframes] object TwoPhase extends Logging { + + private val CHECKPOINT_NAME_PREFIX = "connected-components" + private val MIN_NBR = "min_nbr" + private val CNT = "cnt" + + /** + * Returns the symmetric directed graph of the graph specified by input edges: every input row is + * emitted twice, once with its `SRC`/`DST` values as given and once with the two values swapped. + * + * @param ee + * non-bidirectional edges; only the `SRC` and `DST` columns are read + * @return + * a DataFrame with exactly two columns, `SRC` and `DST` + */ + private def symmetrize(ee: DataFrame): DataFrame = { + val EDGE = "_edge" + ee.select(explode( + array(struct(col(SRC), col(DST)), struct(col(DST).as(SRC), col(SRC).as(DST)))).as(EDGE)) + .select(col(s"$EDGE.$SRC").as(SRC), col(s"$EDGE.$DST").as(DST)) + } + + /** + * Prepares the input graph for computing connected components by: + * - de-duplicating vertices and assigning unique long IDs to each, + * - changing edge directions to have increasing long IDs from src to dst, + * - de-duplicating edges and removing self-loops. + * + * In the returned GraphFrame, the vertex DataFrame has two columns: + * - column `id` stores a long ID assigned to the vertex, + * - column `attr` stores the original vertex attributes. + * + * The edge DataFrame has two columns: + * - column `src` stores the long ID of the source vertex, + * - column `dst` stores the long ID of the destination vertex, where we always have `src` < + * `dst`.
/**
 * Builds the working graph for the algorithm from the indexed vertices and edges of `graph`.
 *
 * Vertices keep two columns: the long id (as `ID`) and `ATTR`. Edges are reduced to long
 * `SRC`/`DST` pairs, self-loops are dropped, each pair is reordered so the smaller id is the
 * source, and duplicates are removed.
 */
private def prepare(graph: GraphFrame): GraphFrame = {
  val indexedVertices = graph.indexedVertices.select(col(LONG_ID).as(ID), col(ATTR))
  val smallerEnd = minValue(col(SRC), col(DST)).as(SRC)
  val largerEnd = maxValue(col(SRC), col(DST)).as(DST)
  val canonicalEdges = graph.indexedEdges
    .select(col(LONG_SRC).as(SRC), col(LONG_DST).as(DST))
    .filter(col(SRC) =!= col(DST))
    .select(smallerEnd, largerEnd)
    .distinct()
  GraphFrame(indexedVertices, canonicalEdges)
}

/**
 * For every vertex that appears in `ee`, computes a row with three columns:
 *   - `src`: the vertex id
 *   - `min_nbr`: the smallest id among the vertex itself and its neighbors
 *   - `cnt`: the number of neighbor rows of the vertex in the symmetrized edge set
 */
private def minNbrs(ee: DataFrame): DataFrame = {
  val perVertex = symmetrize(ee)
    .groupBy(SRC)
    .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
  // Fold the vertex's own id into the minimum.
  perVertex.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
}

/** Column expression that yields `x` when `x < y` and `y` otherwise. */
private def minValue(x: Column, y: Column): Column =
  when(x < y, x).otherwise(y)

/** Column expression that yields `x` when `x > y` and `y` otherwise. */
private def maxValue(x: Column, y: Column): Column =
  when(x > y, x).otherwise(y)

/**
 * Sums the `min_nbr` column of `minNbrsDF` after casting it to `DecimalType(38, 0)`. The
 * iteration loop compares this value between rounds to detect convergence.
 */
private def calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = {
  val total = sum(col(MIN_NBR).cast(DecimalType(38, 0)))
  val firstRow = minNbrsDF.select(total).first()
  firstRow.getAs[BigDecimal](0)
}

/**
 * Produces the result DataFrame: the original vertex attributes plus a component column.
 *
 * Vertices are left-joined with the final edges on `vertex id == edge dst`; a vertex without a
 * matching edge is its own component, otherwise the edge's `src` is the component. When
 * `useLabelsAsComponents` is set and the graph's ids are not integral, the component is
 * replaced by the minimum original id within that component.
 */
private def buildOutput(
    graph: GraphFrame,
    vv: DataFrame,
    ee: DataFrame,
    useLabelsAsComponents: Boolean): DataFrame = {
  val component = ConnectedComponents.COMPONENT
  val origId = ConnectedComponents.ORIG_ID

  val componentOf = when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(component)
  val indexedLabel = vv
    .join(ee, vv(ID) === ee(DST), "left_outer")
    .select(vv(ATTR), componentOf, col(s"$ATTR.$ID").as(ID))

  if (useLabelsAsComponents && !graph.hasIntegralIdType) {
    val representative = indexedLabel
      .groupBy(col(component))
      .agg(min(col(ID)).as(origId))
      .select(col(component), col(origId))
    indexedLabel
      .join(representative, component)
      .select(col(s"$ATTR.*"), col(origId).as(component))
  } else {
    indexedLabel.select(col(s"$ATTR.*"), col(component))
  }
}

/**
 * Joins `edges` with `minNbrsDF` on `src` through `GraphFrame.skewedJoin`. Vertices whose `cnt`
 * exceeds `broadcastThreshold` are collected to the driver and passed on as the set of hub keys.
 */
private def skewedJoin(
    edges: DataFrame,
    minNbrsDF: DataFrame,
    broadcastThreshold: Int,
    logPrefix: String): DataFrame = {
  import edges.sparkSession.implicits._
  val frequentKeys = minNbrsDF
    .filter(col(CNT) > broadcastThreshold)
    .select(SRC)
    .as[Long]
    .collect()
    .toSet
  GraphFrame.skewedJoin(edges, minNbrsDF, SRC, frequentKeys, logPrefix)
}
/**
 * Runs the two-phase label propagation connected components algorithm.
 *
 * Adaptive Query Execution is switched off on the session for the duration of the call and the
 * previous setting is restored afterwards, also on failure. Large-star and small-star steps
 * are alternated until the sum of `min_nbr` values stops changing between rounds.
 *
 * @param graph
 *   input graph; used as-is when `isGraphPrepared` is true, otherwise passed through `prepare`
 * @param broadcastThreshold
 *   vertices with more neighbor rows than this are treated as hubs by the skewed join
 * @param checkpointInterval
 *   checkpoint the edges every this many iterations; values `<= 0` disable checkpointing
 * @param intermediateStorageLevel
 *   storage level for intermediate DataFrames and for the returned result
 * @param useLabelsAsComponents
 *   forwarded to `buildOutput` to choose how component labels are reported
 * @param useLocalCheckpoints
 *   use `localCheckpoint` instead of writing parquet files under the checkpoint directory
 * @param isGraphPrepared
 *   skip the internal preparation step
 * @return
 *   the persisted and materialised result DataFrame
 * @throws IOException
 *   if file-based checkpointing is requested and no checkpoint directory is configured
 */
private[graphframes] def run(
    graph: GraphFrame,
    broadcastThreshold: Int,
    checkpointInterval: Int,
    intermediateStorageLevel: StorageLevel,
    useLabelsAsComponents: Boolean,
    useLocalCheckpoints: Boolean,
    isGraphPrepared: Boolean): DataFrame = {

  val spark = graph.spark
  val sc = spark.sparkContext
  // Remembered so the session setting can be put back in `finally`.
  val originalAQE = spark.conf.get("spark.sql.adaptive.enabled")

  try {
    spark.conf.set("spark.sql.adaptive.enabled", "false")

    val runId = UUID.randomUUID().toString.takeRight(8)
    val logPrefix = s"[CC $runId]"
    logInfo(s"$logPrefix Start connected components with run ID $runId.")

    val shouldCheckpoint = checkpointInterval > 0
    val checkpointDir: Option[String] =
      if (useLocalCheckpoints) {
        None
      } else if (shouldCheckpoint) {
        // Prefer the SparkContext checkpoint directory and fall back to the session conf.
        val dir = sc.getCheckpointDir
          .orElse(spark.conf.getOption("spark.checkpoint.dir"))
          .map(d => new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString)
          .getOrElse {
            throw new IOException(
              "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir() " +
                "or by specifying the conf 'spark.checkpoint.dir'.")
          }
        logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.")
        Some(dir)
      } else {
        logInfo(
          s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.")
        None
      }

    logInfo(s"$logPrefix Preparing the graph for connected component computation ...")
    val g = if (isGraphPrepared) graph else prepare(graph)
    val vv = g.vertices
    var ee = g.edges.persist(intermediateStorageLevel) // src < dst
    logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.")

    var converged = false
    var iteration = 1

    // compute min neighbors (including self-min)
    var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr
      .persist(intermediateStorageLevel)

    var prevSum: BigDecimal = calcMinNbrSum(minNbrs1)

    // DataFrames persisted in the previous round; released once the next round is materialised.
    var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
    while (!converged) {
      var currRoundPersistedDFs = Seq[DataFrame]()

      // large-star step
      // connect all strictly larger neighbors to the min neighbor (including self)
      ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
        .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
        .distinct()
        .persist(intermediateStorageLevel)
      currRoundPersistedDFs = currRoundPersistedDFs :+ ee

      // small-star step
      // compute min neighbors (excluding self-min)
      val minNbrs2 = ee
        .groupBy(col(SRC))
        .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
        .persist(intermediateStorageLevel)
      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2

      // connect all smaller neighbors to the min neighbor
      ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
        .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
        .filter(col(SRC) =!= col(DST)) // src < dst
      // connect self to the min neighbor
      ee = ee
        .union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
        .distinct()

      // checkpointing
      if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
        if (useLocalCheckpoints) {
          ee = ee.localCheckpoint(eager = true)
        } else {
          // checkpointDir is always defined here: local checkpoints are off and
          // shouldCheckpoint is true.
          val out = s"${checkpointDir.get}/$iteration"
          ee.write.parquet(out)
          ee = spark.read.parquet(out)

          // remove previous checkpoint
          if (iteration > checkpointInterval) {
            val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
            path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
          }

          System.gc()
        }
      }

      ee.persist(intermediateStorageLevel)
      currRoundPersistedDFs = currRoundPersistedDFs :+ ee

      minNbrs1 = minNbrs(ee) // src >= min_nbr
        .persist(intermediateStorageLevel)
      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1

      // test convergence
      val currSum = calcMinNbrSum(minNbrs1)
      logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")
      if (currSum == prevSum) {
        converged = true
      } else {
        prevSum = currSum
      }

      // clean up DataFrames persisted in the previous round
      for (persistedDF <- lastRoundPersistedDFs) {
        persistedDF.unpersist()
      }
      lastRoundPersistedDFs = currRoundPersistedDFs
      iteration += 1
    }

    logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")
    logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")

    val output = buildOutput(graph, vv, ee, useLabelsAsComponents)
      .persist(intermediateStorageLevel)

    // An action must be performed on the DataFrame for the cache to load.
    output.count()

    for (persistedDF <- lastRoundPersistedDFs) {
      persistedDF.unpersist()
    }

    resultIsPersistent()

    output
  } finally {
    // Restore the original AQE setting.
    spark.conf.set("spark.sql.adaptive.enabled", originalAQE)
  }
}
+   */
+  private[graphframes] def runAQE(
+      graph: GraphFrame,
+      checkpointInterval: Int,
+      intermediateStorageLevel: StorageLevel,
+      useLabelsAsComponents: Boolean,
+      useLocalCheckpoints: Boolean,
+      isGraphPrepared: Boolean): DataFrame = {
+
+    val runId = UUID.randomUUID().toString.takeRight(8)
+    val logPrefix = s"[CC $runId]"
+    logInfo(s"$logPrefix Start connected components with run ID $runId.")
+
+    val shouldCheckpoint = checkpointInterval > 0 // <= 0 disables checkpointing
+
+    logInfo(s"$logPrefix Preparing the graph for connected component computation ...")
+    val g = if (isGraphPrepared) graph else prepare(graph)
+    val vv = g.vertices
+    var ee = g.edges.persist(intermediateStorageLevel) // src < dst
+    logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.")
+
+    var converged = false
+    var iteration = 1
+
+    var minNbrs1: DataFrame = symmetrize(ee)
+      .groupBy(SRC)
+      .agg(min(col(DST)).as(MIN_NBR))
+      .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) // presumably src >= min_nbr — confirm
+      .persist(intermediateStorageLevel)
+
+    var prevSum: BigDecimal = calcMinNbrSum(minNbrs1)
+
+    var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
+    while (!converged) {
+      var currRoundPersistedDFs = Seq[DataFrame]()
+
+      // large-star step (plain join on SRC; no manual skew handling in this variant)
+      // connect all strictly larger neighbors to the min neighbor (including self)
+      ee = ee
+        .join(minNbrs1, SRC)
+        .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
+        .distinct()
+        .persist(intermediateStorageLevel)
+      currRoundPersistedDFs = currRoundPersistedDFs :+ ee
+
+      // small-star step
+      // compute min neighbors (excluding self-min)
+      val minNbrs2 = ee
+        .groupBy(col(SRC))
+        .agg(min(col(DST)).as(MIN_NBR)) // src > min_nbr
+        .persist(intermediateStorageLevel)
+      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2
+
+      // connect all smaller neighbors to the min neighbor
+      ee = ee
+        .join(minNbrs2, SRC)
+        .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
+        .filter(col(SRC) =!= col(DST)) // src < dst
+      // connect self to the min neighbor
+      ee = ee
+        .union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
+        .distinct()
+
+      // checkpointing: every checkpointInterval-th iteration, via Spark's own checkpoint API
+      if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
+        if (useLocalCheckpoints) {
+          ee = ee.localCheckpoint(eager = true)
+        } else {
+          ee = ee.checkpoint(eager = true) // needs a checkpoint directory set on the SparkContext
+        }
+      }
+
+      ee.persist(intermediateStorageLevel)
+      currRoundPersistedDFs = currRoundPersistedDFs :+ ee
+
+      minNbrs1 = symmetrize(ee)
+        .groupBy(SRC)
+        .agg(min(col(DST)).as(MIN_NBR))
+        .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
+        .persist(intermediateStorageLevel)
+      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1
+
+      // test convergence: done when calcMinNbrSum equals the previous round's value
+      val currSum = calcMinNbrSum(minNbrs1)
+      logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")
+      if (currSum == prevSum) {
+        converged = true
+      } else {
+        prevSum = currSum
+      }
+
+      for (persistedDF <- lastRoundPersistedDFs) { // release the previous round's caches
+        persistedDF.unpersist()
+      }
+      lastRoundPersistedDFs = currRoundPersistedDFs
+      iteration += 1
+    }
+
+    logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")
+    logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")
+
+    val output = buildOutput(graph, vv, ee, useLabelsAsComponents)
+      .persist(intermediateStorageLevel)
+
+    output.count() // action: evaluates the persisted output before its inputs are released
+
+    for (persistedDF <- lastRoundPersistedDFs) {
+      persistedDF.unpersist()
+    }
+
+    resultIsPersistent()
+
+    output
+  }
+}
diff --git a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala index c2d569c7f..89ab7e34c 100644 --- a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala +++ b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala @@ -211,10 +211,15 @@ class TestLDBCCases extends SparkFunSuite with GraphFrameTestSparkContext { expectedComponents) } - Seq("graphframes", "graphx").foreach { algo => + Seq("two_phase", "graphx", "randomized_contraction").foreach { algo => test(s"test
undirected WCC with LDBC for impl ${algo}") { val testCase = ldbcTestWCCUndirected - val ccResults = testCase._1.connectedComponents.setAlgorithm(algo).run() + var cc = testCase._1.connectedComponents.setAlgorithm(algo) + if (algo == "randomized_contraction") { + // RC is randomized by its nature, so use the original labels as component IDs; + cc = cc.setUseLabelsAsComponents(true) + } + val ccResults = cc.run() assert(ccResults.count() == testCase._1.vertices.count()) assert( ccResults diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 8bec4b061..b2db69d76 100644 --- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -35,7 +35,8 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon test("default params") { val g = Graphs.empty[Int] val cc = g.connectedComponents - assert(cc.getAlgorithm === "graphframes") + // That is OK!
It is just a name-change + assert(cc.getAlgorithm === "two_phase") assert(cc.getBroadcastThreshold === 1000000) assert(cc.getCheckpointInterval === 2) assert(!cc.getUseLabelsAsComponents) diff --git a/core/src/test/scala/org/graphframes/lib/RandomizedContractionSuite.scala b/core/src/test/scala/org/graphframes/lib/RandomizedContractionSuite.scala index da7c14391..a54567988 100644 --- a/core/src/test/scala/org/graphframes/lib/RandomizedContractionSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/RandomizedContractionSuite.scala @@ -18,6 +18,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 0L) assertFunctionRegistryClean() @@ -32,6 +34,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 1L) assert(components.select("id", "component").collect().toSet === Set(Row(0L, 0L))) @@ -47,6 +51,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 2L) val compValues = components.select("id", "component").collect() @@ -62,6 +68,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() 
=== n) assert(components.select("component").distinct().count() === 1L) @@ -78,6 +86,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === n) assert(components.select("component").distinct().count() === n) @@ -94,6 +104,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 6L) val compGroups = @@ -112,6 +124,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 8L) val compCounts = @@ -130,6 +144,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = true, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 4L) val compIds = components.select("component").collect().map(_.getString(0)).toSet @@ -148,6 +164,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = true, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 4L) val compIds = components.select("component").collect().map(_.getLong(0)).toSet @@ -164,6 +182,8 @@ class RandomizedContractionSuite 
extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) components.count() @@ -173,19 +193,21 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC } test("RandomizedContraction: no memory leaks") { - val priorCachedCount = spark.sparkContext.getPersistentRDDs.size + val priorCached = spark.sparkContext.getPersistentRDDs val graph = Graphs.chain(10L) val components = RandomizedContraction.run( inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = false, + checkpointInterval = 1, isGraphPrepared = false) components.count() components.unpersist() - val postCachedCount = spark.sparkContext.getPersistentRDDs.size - assert(postCachedCount === priorCachedCount) + val postCached = spark.sparkContext.getPersistentRDDs + assert(postCached.size === priorCached.size) assertFunctionRegistryClean() } @@ -201,6 +223,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 10L) assert(components.select("component").distinct().count() === 1L) @@ -217,6 +241,8 @@ class RandomizedContractionSuite extends SparkFunSuite with GraphFrameTestSparkC inputGraph = graph, useLabelsAsComponents = false, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + useLocalCheckpoints = true, + checkpointInterval = 1, isGraphPrepared = false) assert(components.count() === 5L) assert(components.select("component").distinct().count() === 1L) diff --git a/docs/src/04-user-guide/05-traversals.md b/docs/src/04-user-guide/05-traversals.md index c831a5d81..e7f6a81cb 100644 --- 
a/docs/src/04-user-guide/05-traversals.md +++ b/docs/src/04-user-guide/05-traversals.md @@ -11,7 +11,7 @@ See [Wikipedia](https://en.wikipedia.org/wiki/Shortest_path_problem) for a backg **NOTE** -*Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!* +_Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!_ --- @@ -120,17 +120,17 @@ paths.show() Computes the connected component membership of each vertex and returns a graph with each vertex assigned a component ID. -See [Wikipedia](https://en.wikipedia.org/wiki/Connected_component_(graph_theory)) for the background. +See [Wikipedia](https://en.wikipedia.org/wiki/Connected_component_%28graph_theory%29) for the background. --- **NOTE:** -*With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use the API `connectedComponents.setUseLocalCheckpoints(true)`. While `localCheckpoints` provides better performance they are not as reliable as the persistent checkpointing.* +_With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use the API `connectedComponents.setUseLocalCheckpoints(true)`.
While `localCheckpoints` provides better performance they are not as reliable as the persistent checkpointing._ **NOTE** -*Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!* +_Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!_ --- @@ -162,43 +162,96 @@ val result = g.connectedComponents.setUseLocalCheckpoints(true).run() result.select("id", "component").orderBy("component").show() ``` +### Algorithms + +GraphFrames provides three algorithm implementations, selectable via the `algorithm` argument: + +#### `graphx` + +A naive Pregel-based implementation backed by Apache Spark GraphX. It propagates the minimum vertex ID along edges one hop per iteration, so convergence requires a number of iterations equal to the diameter of the graph. While it may be slightly faster on very small graphs, it has poor convergence complexity on large or wide graphs and requires significantly more memory due to less efficient RDD serialization and its triplets-based nature. Use this only as a fallback or for compatibility. + +#### `two_phase` (default) + +A DataFrame-native implementation based on the large-star / small-star label propagation approach described in: + +> Kiveris, Raimondas, et al. _"Connected components in MapReduce and beyond."_ Proceedings of the ACM Symposium on Cloud Computing. 2014. https://dl.acm.org/doi/abs/10.1145/2670979.2670997 + +This algorithm has much better convergence complexity than `graphx` and requires significantly less memory thanks to efficient Tungsten serialization. It is the recommended default for most workloads. + +Component IDs produced by `two_phase` are stable `Long` values. For graphs whose vertex IDs are already integral types (`Long`, `Int`, `Short`, `Byte`), the component ID will be the minimum original vertex ID within the component. 
For `String`-typed (or other non-integral) vertex IDs, the component ID will be a random `Long` unless `use_labels_as_components=True` is set (see below). + +This algorithm has two internal join modes — see [AQE-broadcast mode](#aqe-broadcast-mode) below for details. + +#### `randomized_contraction` + +A DataFrame-native implementation based on randomized graph contraction, described in: + +> Bögeholz, Harald, Michael Brand, and Radu-Alexandru Todor. _"In-database connected component analysis."_ 2020 IEEE 36th International Conference on Data Engineering (ICDE). IEEE, 2020. + +This algorithm iteratively contracts the graph using random linear functions until no edges remain, then reconstructs component identifiers in a reverse pass. It has similar convergence characteristics to `two_phase` (AQE mode) and performs comparably on benchmarks — slightly worse than `two_phase` with AQE, but significantly better than `two_phase` with manual skewed joins. + +Unlike `two_phase`, `randomized_contraction` **always** produces random `Long` component IDs regardless of the input vertex ID type, unless `use_labels_as_components=True` is set. + +#### Deprecation notice + +The algorithm name `graphframes` is a deprecated alias for `two_phase` and will be removed in a future release. Replace any usage of `setAlgorithm("graphframes")` with `setAlgorithm("two_phase")`. 
+ +#### Performance summary + +| Algorithm | Convergence complexity | Memory usage | Component ID type | +| ------------------------- | ------------------------------------------ | ---------------- | -------------------------------------------------- | +| `graphx` | O(diameter) iterations | High (RDD-based) | Min vertex ID in component | +| `two_phase` (skewed join) | Fast | Low (DataFrame) | Min original ID (integral) or random Long (String) | +| `two_phase` (AQE, `-1`) | Fast, ~5x faster than skewed join | Low (DataFrame) | Min original ID (integral) or random Long (String) | +| `randomized_contraction` | Fast, slightly slower than `two_phase` AQE | Low (DataFrame) | Always random Long | + ### Arguments - `algorithm` -Possible values are `graphx` and `graphframes`. GraphX-based implementation is a naive Pregel one-by-one. While it may be slightly faster on small-medium sized graphs, it has a much bigger convergence complexity and requires much more memory due to less efficient RDD serialization. GraphFrame-based implementation is based on the ideas from the [Kiveris, Raimondas, et al. "Connected components in mapreduce and beyond." Proceedings of the ACM Symposium on Cloud Computing. 2014.](https://dl.acm.org/doi/abs/10.1145/2670979.2670997). This implementation has much better convergence complexity as well as requires less amount of memory. +Selects the algorithm. Supported values: `graphx`, `two_phase` (default), `randomized_contraction`. The value `graphframes` is a deprecated alias for `two_phase`. - `maxIter` -For `graphx` only. Limit the maximal amount of Pregel iterations. By default it is infinity (`Integer.maxValue`). It is recommended do not change this value. If the algorithm stucks, it is a problem of the graph, not algorithm. +For `graphx` only. Limits the maximum number of Pregel iterations. Default is `Integer.MAX_VALUE` (unlimited). It is generally not recommended to change this value. - `checkpoint_interval` -For `graphframes` only. 
To avoid exponential growing of the Spark' Logical Plan, DataFrame lineage and query optimization time, it is required to do checkpointing periodically. While checkpoint itself is not free, it is still recommended to set this value to something less than `5`. +For `two_phase` and `randomized_contraction`. To avoid exponential growth of the Spark logical plan, DataFrame lineage, and query optimization time, checkpointing is performed periodically. It is recommended to keep this value at `2` or below. - `broadcast_threshold` -For `graphframes` only. See [this section](05-traversals.md#aqe-broadcast-mode) for details. +For `two_phase` only. See [AQE-broadcast mode](#aqe-broadcast-mode) below for details. - `use_labels_as_components` -For `graphframes` only. In the case, when the type of the input graph vertices is not one of `Long`, `Int`, `Short`, `Byte`, output labels (components) are a random `Long` numbers by default. By providing `use_labels_as_components=True` user can ask GraphFrames to use original vertex labels for output components. In that case, the minimal value of all original IDs will be used for each of found components. This operation is not free and require an additional `groupBy` + `agg` + `join`. +For `two_phase` and `randomized_contraction`. By default, component IDs are `Long` values. For `two_phase` with integral vertex ID types, the component ID is the minimum original vertex ID in the component. For `String`-typed vertices (or any non-integral type), and always for `randomized_contraction`, the component ID is a random `Long`. By setting `use_labels_as_components=True`, GraphFrames will instead use the minimum original vertex label as the component ID. This requires an additional `groupBy` + `agg` + `join` and is not free. - `use_local_checkpoints` -For `graphframes` only. By default, GraphFrames uses persistent checkpoints. They are realiable and reduce the errors rate. 
The downside of the persistent checkpoints is that they are requiride to set up a `checkpointDir` in persistent storage like `S3` or `HDFS`. By providing `use_local_checkpoints=True`, user can say GraphFrames to use local disks of Spark' executurs for checkpointing. Local checkpoints are faster, but they are less reliable: if the executur lost, for example, is taking by the higher priority job, checkpoints will be lost and the whole job fails. +For `two_phase` and `randomized_contraction`. By default, GraphFrames uses persistent checkpoints, which are reliable but require a `checkpointDir` to be configured in persistent storage (e.g. S3 or HDFS). Setting `use_local_checkpoints=True` uses the local disks of Spark executors instead. Local checkpoints are faster but less reliable: if an executor is lost, the checkpoint is lost and the job will fail. - `storage_level` -The level of storage for intermediate results and the output `DataFrame` with components. By default it is memory and disk deserialized as a good balance between performance and reliability. For very big graphs and out-of-core scenarious, using `DISK_ONLY` may be faster. +The storage level for intermediate datasets and the output DataFrame. Default is `MEMORY_AND_DISK`. For very large graphs or out-of-core scenarios, `DISK_ONLY` may be preferable. ### AQE-broadcast mode -*Starting from 0.10.0* +_Starting from 0.10.0_ + +For `two_phase` only. During iterations, this algorithm can produce edges with highly skewed degree distributions, where some vertices have very high degree. In earlier versions of GraphFrames, this was handled by manually broadcasting high-degree nodes. However, this manual broadcasting is incompatible with Apache Spark Adaptive Query Execution (AQE), which is why AQE was previously disabled for Connected Components. + +From GraphFrames 0.10+, you can disable manual broadcasting and instead rely on AQE to handle skewness automatically. 
To enable this mode, pass `-1` as the `broadcast_threshold` (or call `setBroadcastThreshold(-1)`). Based on benchmarks, this mode provides approximately **5x speed-up** over the manual skewed-join mode. It is possible that in a future release, `-1` will become the default value for `broadcast_threshold`. + +### Advanced: skipping graph preparation -For `graphframes` algorithm only. During iterations, this algorithm can generate new edges that may tend to high skewness in joins and aggregates, because some vertices are having a very-high degree. In previous versions of GraphFrames this issue was addressed by manual broadcasting very high-degree nodes. Unfortunately, Apache Spark Adaptive Quey Execution optimization fails on such a case and that was the reason shy AQE was disabled for Connected Components. +_For advanced JVM users only._ -In the new versions of GraphFrames (0.10+) there is a way to disable manual broadcasting, enable AQE and allow it to handle skewnewss. To enable this mode, pass `-1` to the `setBroadcastThreshold`. Based on benchmarks, this mode provides about 5x speed-up. It is possible, that in the future releases, the default value of the `broadcastThreshold` will be changed to `-1`. +Internally, both `two_phase` and `randomized_contraction` perform a graph preparation step before running the algorithm (re-indexing vertices, symmetrizing edges, removing self-loops, etc.). The preparation steps differ between the two algorithms and are **not interchangeable**. + +If you have already performed the exact preparation steps yourself and fully understand what they entail for the specific algorithm you are using, you can skip the internal preparation by calling `setIsGraphPrepared(true)`. This is an internal API intended only for users who have studied the source code of the algorithm in detail. + +**Warning:** Incorrect use of this flag will produce silently wrong results with no error or warning at runtime. 
### Strongly connected components @@ -211,7 +264,7 @@ See [Wikipedia](https://en.wikipedia.org/wiki/Strongly_connected_component) for **NOTE** -*Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!* +_Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!_ --- @@ -250,6 +303,7 @@ Triangle count computes the number of triangles passing through each vertex. A t ### Performance and Use Cases Counting triangles is a fundamental task in network analysis: + - **Clustering Coefficient**: It is used to compute the local and global clustering coefficients, which measure the degree to which nodes in a graph tend to cluster together. - **Community Detection**: A high density of triangles often indicates the presence of a tightly knit community or "clique." - **Spam and Fraud Detection**: In social networks and financial transactions, unusual triangle patterns can help identify botnets or money-laundering rings. @@ -263,11 +317,11 @@ The core logic of the algorithm is based on **neighborhood intersection**. For e GraphFrames provides two implementations with different performance characteristics: - **Exact**: This is the default algorithm. It computes the precise intersection of adjacency lists. - - **Pros**: 100% accuracy. - - **Cons**: Extremely memory-intensive. For high-degree nodes (hubs), collecting and intersecting large neighbor sets can lead to Out-of-Memory (OOM) errors or severe skew. + - **Pros**: 100% accuracy. + - **Cons**: Extremely memory-intensive. For high-degree nodes (hubs), collecting and intersecting large neighbor sets can lead to Out-of-Memory (OOM) errors or severe skew. - **Approximate** (Starting from Spark 4.1): This version uses **DataSketches (Theta sketches)** to estimate the size of the intersection. - - **Pros**: Highly scalable. 
It uses a fixed-size probabilistic structure to represent neighborhoods, dramatically reducing memory overhead and execution time. - - **Cons**: Provides an estimate rather than an exact count. + - **Pros**: Highly scalable. It uses a fixed-size probabilistic structure to represent neighborhoods, dramatically reducing memory overhead and execution time. + - **Cons**: Provides an estimate rather than an exact count. ### Selection Guide @@ -309,11 +363,11 @@ the [Rocha–Thatte cycle detection algorithm](https://en.wikipedia.org/wiki/Roc **NOTE** -*Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!* +_Be aware, that returned `DataFrame` is persistent and should be unpersisted manually after processing to avoid memory leaks!_ **WARNING:** -- *This algorithm collects the full sequences and may require a lot of cluste memory for power-law graphs* +- _This algorithm collects the full sequences and may require a lot of cluster memory for power-law graphs_ --- diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 5179eb2a3..d74ec29d3 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -525,7 +525,8 @@ def connectedComponents( See Scala documentation for more details. :param algorithm: connected components algorithm to use (default: "graphframes") - Supported algorithms are "graphframes" and "graphx". + Supported algorithms are "two_phase", "randomized_contraction", + "graphframes" (deprecated alias for "two_phase") and "graphx". :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) :param broadcastThreshold: broadcast threshold in propagating component assignments (default: 1000000). Passing -1 disable manual broadcasting and