-
Notifications
You must be signed in to change notification settings - Fork 268
Optimization: Add pruning for connected component 2 phase algorithm #846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,14 +110,19 @@ private[graphframes] object TwoPhase extends Logging { | |
| } | ||
|
|
||
| /** | ||
| * Computes the sum of all `min_nbr` values in the given DataFrame, cast to DecimalType(38, 0) | ||
| * for high precision. Used to detect convergence between iterations. | ||
| * Computes the sum of all `min_nbr` values and the undirected edge count of the given DataFrame | ||
| * in a single Spark job. The sum is cast to DecimalType(38, 0) for high precision. Used to | ||
| * detect convergence between iterations and to check graph sparsity for the pruning | ||
| * optimization. The edge count is derived as `sum(cnt) / 2` since each undirected edge appears | ||
| * once in each direction in the symmetrized graph. | ||
| */ | ||
| private def calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { | ||
| minNbrsDF | ||
| .select(sum(col(MIN_NBR).cast(DecimalType(38, 0)))) | ||
| private def calcMinNbrSum(minNbrsDF: DataFrame): (BigDecimal, Long) = { | ||
| val row = minNbrsDF | ||
| .select( | ||
| sum(col(MIN_NBR).cast(DecimalType(38, 0))), | ||
| coalesce((sum(col(CNT)) / 2).cast("long"), lit(0L))) | ||
| .first() | ||
| .getAs[BigDecimal](0) | ||
| (row.getAs[BigDecimal](0), row.getLong(1)) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -172,6 +177,70 @@ private[graphframes] object TwoPhase extends Logging { | |
| GraphFrame.skewedJoin(edges, minNbrsDF, SRC, hubs, logPrefix) | ||
| } | ||
|
|
||
| /** | ||
| * Prunes leaf nodes (vertices with out-degree 0 and in-degree 1) from the graph to create a | ||
| * smaller shrunken graph. Returns Some((vertices, edges, nodeCount)) if the shrunken graph is | ||
| * significantly smaller than the original (by shrinkageThreshold), None otherwise. | ||
| */ | ||
| private[graphframes] def pruneLeafNodes( | ||
| edges: DataFrame, | ||
| intermediateStorageLevel: StorageLevel, | ||
| numNodes: Long, | ||
| shrinkageThreshold: Double): Option[(DataFrame, DataFrame, Long)] = { | ||
|
|
||
| // vertices whose indegree > 1 | ||
| val v1 = edges | ||
| .groupBy(DST) | ||
| .agg(count("*").as(CNT)) | ||
| .filter(col(CNT) > 1) | ||
| .select(col(DST).as(ID)) | ||
|
|
||
| // vertices whose outdegree > 0 or indegree > 1 | ||
| val newVV = edges | ||
| .select(col(SRC).as(ID)) | ||
| .union(v1) | ||
| .distinct() | ||
| .persist(intermediateStorageLevel) | ||
| val newVVCnt = newVV.count() | ||
|
|
||
| if (newVVCnt * shrinkageThreshold < numNodes) { | ||
| val newEE = edges | ||
| .join(newVV.withColumnRenamed(ID, DST), DST) | ||
| .persist(intermediateStorageLevel) | ||
| Some((newVV, newEE, newVVCnt)) | ||
| } else { | ||
| newVV.unpersist(blocking = false) | ||
| None | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Given the vertices and converged edges of the shrunken graph, joins back to reconstruct the | ||
| * converged edges of the original graph. | ||
| */ | ||
| private[graphframes] def joinBack( | ||
| vertices: DataFrame, | ||
| edges: DataFrame, | ||
| edgesBeforePruning: DataFrame): DataFrame = { | ||
|
|
||
| val cc = vertices | ||
| .as("vertices") | ||
| .join(edges.as("edges"), col(s"vertices.$ID") === col(s"edges.$DST"), "left_outer") | ||
| .select( | ||
| when(col(s"edges.$SRC").isNull, col(s"vertices.$ID")) | ||
| .otherwise(col(s"edges.$SRC")) | ||
| .as(SRC), | ||
| col(s"vertices.$ID").as(DST)) | ||
|
|
||
| cc.as("cc") | ||
| .join( | ||
| edgesBeforePruning.as("edgesBeforePruning"), | ||
| col(s"cc.$DST") === col(s"edgesBeforePruning.$SRC")) | ||
| .select(col(s"cc.$SRC"), col(s"edgesBeforePruning.$DST")) | ||
| .union(cc) | ||
| .distinct() | ||
| } | ||
|
|
||
| /** | ||
| * Runs the two-phase label propagation connected components algorithm. | ||
| */ | ||
|
|
@@ -182,7 +251,10 @@ private[graphframes] object TwoPhase extends Logging { | |
| intermediateStorageLevel: StorageLevel, | ||
| useLabelsAsComponents: Boolean, | ||
| useLocalCheckpoints: Boolean, | ||
| isGraphPrepared: Boolean): DataFrame = { | ||
| isGraphPrepared: Boolean, | ||
| optStartIter: Int = 2, | ||
| sparsityThreshold: Double = 2.0, | ||
| shrinkageThreshold: Double = 2.0): DataFrame = { | ||
|
|
||
| val spark = graph.spark | ||
| val sc = spark.sparkContext | ||
|
|
@@ -224,14 +296,21 @@ private[graphframes] object TwoPhase extends Logging { | |
| val vv = g.vertices | ||
| var ee = g.edges.persist(intermediateStorageLevel) // src < dst | ||
| logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.") | ||
| var numNodes = vv.count() | ||
| logInfo(s"$logPrefix Found $numNodes nodes after preparation.") | ||
|
|
||
| var converged = false | ||
| var iteration = 1 | ||
| var isOptimized = false | ||
| var triedToOptimize = false | ||
| var shouldKeepCheckpoint = false | ||
| var edgesBeforePruning: DataFrame = null | ||
| var shrunkenGraphNodes: DataFrame = null | ||
|
|
||
| var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr | ||
| .persist(intermediateStorageLevel) | ||
|
|
||
| var prevSum: BigDecimal = calcMinNbrSum(minNbrs1) | ||
| var (prevSum, _) = calcMinNbrSum(minNbrs1) | ||
|
|
||
| var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) | ||
| while (!converged) { | ||
|
|
@@ -273,7 +352,12 @@ private[graphframes] object TwoPhase extends Logging { | |
|
|
||
| if (iteration > checkpointInterval) { | ||
| val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") | ||
| path.getFileSystem(sc.hadoopConfiguration).delete(path, true) | ||
| // keep the checkpoint when edgesBeforePruning points to it | ||
| if (!shouldKeepCheckpoint) { | ||
| path.getFileSystem(sc.hadoopConfiguration).delete(path, true) | ||
| } else { | ||
| shouldKeepCheckpoint = false | ||
| } | ||
| } | ||
|
|
||
| System.gc() | ||
|
|
@@ -288,8 +372,44 @@ private[graphframes] object TwoPhase extends Logging { | |
| currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 | ||
|
|
||
| // test convergence | ||
| val currSum = calcMinNbrSum(minNbrs1) | ||
| val (currSum, edgeCnt) = calcMinNbrSum(minNbrs1) | ||
| logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.") | ||
|
|
||
| // Pruning Node Optimization: construct a new small graph with fewer nodes, | ||
| // and find connected components of the shrunken graph, then join back to get the | ||
| // connected components of the original graph. | ||
|
|
||
| // If the graph becomes sparse and current iteration >= $optStartIter, we start to | ||
| // try such optimization. However, the optimization is only performed if the shrunken | ||
| // graph is much smaller than the original graph, otherwise we do not perform it ( | ||
| // in this case, the only additional cost is to determine the size of shrunken graph). | ||
| // In current implementation, we only try such optimization one time and it is | ||
| // performed at most one time. So the additional cost is bounded. | ||
|
|
||
| // According to such heuristic rule, we can determine when and whether we should | ||
| // perform the optimization. For the sparse graphs (defined by sparsityThreshold), | ||
| // we will try such optimization at the end of $optStartIter iteration (default is 2). | ||
| // For the dense graph, its edges will be pruned at each large/small star join iteration, | ||
| // and we will try the optimization once the graph becomes sparse. | ||
| if ((edgeCnt < sparsityThreshold * numNodes) && (edgeCnt > 0) | ||
| && (iteration >= optStartIter) && (!triedToOptimize)) { | ||
| edgesBeforePruning = ee | ||
| pruneLeafNodes(ee, intermediateStorageLevel, numNodes, shrinkageThreshold) match { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we update the |
||
| case Some(r) => | ||
| shrunkenGraphNodes = r._1 | ||
| ee = r._2 | ||
| currRoundPersistedDFs = currRoundPersistedDFs :+ ee | ||
| numNodes = r._3 | ||
| isOptimized = true | ||
| shouldKeepCheckpoint = true | ||
| logInfo(s"$logPrefix Pruning node optimization performed in iteration $iteration.") | ||
| logInfo(s"$logPrefix Shrunken graph node count: $numNodes.") | ||
| case None => | ||
| logInfo(s"$logPrefix Pruning node optimization not performed.") | ||
| } | ||
| triedToOptimize = true | ||
| } | ||
|
|
||
| if (currSum == prevSum) { | ||
| converged = true | ||
| } else { | ||
|
|
@@ -303,6 +423,10 @@ private[graphframes] object TwoPhase extends Logging { | |
| iteration += 1 | ||
| } | ||
|
|
||
| if (isOptimized) { | ||
| ee = joinBack(shrunkenGraphNodes, ee, edgesBeforePruning) | ||
| } | ||
|
|
||
| logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.") | ||
| logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.") | ||
|
|
||
|
|
@@ -314,6 +438,9 @@ private[graphframes] object TwoPhase extends Logging { | |
| for (persistedDF <- lastRoundPersistedDFs) { | ||
| persistedDF.unpersist() | ||
| } | ||
| if (shrunkenGraphNodes != null) { | ||
| shrunkenGraphNodes.unpersist() | ||
| } | ||
|
|
||
| resultIsPersistent() | ||
|
|
||
|
|
@@ -353,11 +480,11 @@ private[graphframes] object TwoPhase extends Logging { | |
|
|
||
| var minNbrs1: DataFrame = symmetrize(ee) | ||
| .groupBy(SRC) | ||
| .agg(min(col(DST)).as(MIN_NBR)) | ||
| .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) | ||
| .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) | ||
| .persist(intermediateStorageLevel) | ||
|
|
||
| var prevSum: BigDecimal = calcMinNbrSum(minNbrs1) | ||
| var (prevSum, _) = calcMinNbrSum(minNbrs1) | ||
|
|
||
| var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) | ||
| while (!converged) { | ||
|
|
@@ -404,13 +531,13 @@ private[graphframes] object TwoPhase extends Logging { | |
|
|
||
| minNbrs1 = symmetrize(ee) | ||
| .groupBy(SRC) | ||
| .agg(min(col(DST)).as(MIN_NBR)) | ||
| .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) | ||
| .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) | ||
| .persist(intermediateStorageLevel) | ||
| currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 | ||
|
|
||
| // test convergence | ||
| val currSum = calcMinNbrSum(minNbrs1) | ||
| val (currSum, _) = calcMinNbrSum(minNbrs1) | ||
| logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.") | ||
| if (currSum == prevSum) { | ||
| converged = true | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.