diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 1713c58da..1b1148915 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -21,16 +21,16 @@ import java.io.IOException import java.math.BigDecimal import java.util.UUID -import org.graphframes.{GraphFrame, Logging} +import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.{FileSystem, Path} +import org.graphframes.{GraphFrame, Logging} import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel /** - * Connected components algorithm. + * Connected Components algorithm. * * Computes the connected component membership of each vertex and returns a DataFrame of vertex * information with each vertex assigned a component ID. @@ -105,7 +105,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) * Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing * regularly helps recover from failures, clean shuffle files, shorten the lineage of the * computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the - * complexity of plan optimization would grow exponentially without checkpointing. Hence + * complexity of plan optimization would grow exponentially without checkpointing. Hence, * disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint * data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix * "connected-components". If the checkpoint directory is not set, this throws a @@ -171,7 +171,6 @@ object ConnectedComponents extends Logging { import org.graphframes.GraphFrame._ private val COMPONENT = "component" - private val ORIG_ID = "orig_id" private val MIN_NBR = "min_nbr" private val CNT = "cnt" private val CHECKPOINT_NAME_PREFIX = "connected-components" @@ -183,7 +182,7 @@ object ConnectedComponents extends Logging { * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: * "graphframes" and "graphx". */ - val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) + private val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) /** * Returns the symmetric directed graph of the graph specified by input edges. @@ -331,8 +330,7 @@ object ConnectedComponents extends Logging { val g = prepare(graph) val vv = g.vertices var ee = g.edges.persist(intermediateStorageLevel) // src < dst - val numEdges = ee.count() - logInfo(s"$logPrefix Found $numEdges edges after preparation.") + logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.") var converged = false var iteration = 1 @@ -426,11 +424,7 @@ object ConnectedComponents extends Logging { prevSum = currSum } - // materialize all persisted DataFrames in current round, - // then we can unpersist last round persisted DataFrames. - for (persisted_df <- currRoundPersistedDFs) { - persisted_df.count() // materialize it. - } + // clean up persisted DFs for (persisted_df <- lastRoundPersistedDFs) { persisted_df.unpersist() } @@ -441,9 +435,23 @@ object ConnectedComponents extends Logging { logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.") logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.") - vv.join(ee, vv(ID) === ee(DST), "left_outer") + val output = vv + .join(ee, vv(ID) === ee(DST), "left_outer") .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT)) .select(col(s"$ATTR.*"), col(COMPONENT)) + .persist(intermediateStorageLevel) + + // An action must be performed on the DataFrame for the cache to load + output.count() + + // clean up persisted DFs + for (persisted_df <- lastRoundPersistedDFs) { + persisted_df.unpersist() + } + + logWarn("The DataFrame returned by ConnectedComponents is persisted and loaded.") + + output } finally { // Restore original AQE setting spark.conf.set("spark.sql.adaptive.enabled", originalAQE) diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index f55ae4edd..32221311b 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -253,6 +253,17 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon } } + test("not leaking cached data") { + val priorCachedDFsSize = spark.sparkContext.getPersistentRDDs.size + + val cc = Graphs.friends.connectedComponents + val components = cc.run() + + components.unpersist(blocking = true) + + assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize) + } + private def assertComponents[T: ClassTag: TypeTag]( actual: DataFrame, expected: Set[Set[T]]): Unit = {