-
Notifications
You must be signed in to change notification settings - Fork 268
Fix connected component #454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -313,22 +313,49 @@ object ConnectedComponents extends Logging { | |
|
|
||
| var converged = false | ||
| var iteration = 1 | ||
| var prevSum: BigDecimal = null | ||
|
|
||
| def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { | ||
| // Taking the sum in DecimalType to preserve precision. | ||
| // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. | ||
| // It should be able to handle 200 billion edges without overflow. | ||
| val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd | ||
| .map { r => | ||
| (r.getAs[BigDecimal](0), r.getLong(1)) | ||
| }.first() | ||
| if (cnt != 0L && minNbrSum == null) { | ||
| throw new ArithmeticException( | ||
| s""" | ||
| |The total sum of edge src IDs is used to determine convergence during iterations. | ||
| |However, the total sum at iteration $iteration exceeded 30 digits (1e30), | ||
| |which should happen only if the graph contains more than 200 billion edges. | ||
| |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. | ||
| """.stripMargin) | ||
| } | ||
| minNbrSum | ||
| } | ||
| // compute min neighbors (including self-min) | ||
| var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr | ||
| .persist(intermediateStorageLevel) | ||
|
|
||
| var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1) | ||
|
|
||
| var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) | ||
| while (!converged) { | ||
| var currRoundPersistedDFs = Seq[DataFrame]() | ||
| // large-star step | ||
| // compute min neighbors (including self-min) | ||
| val minNbrs1 = minNbrs(ee) // src >= min_nbr | ||
| .persist(intermediateStorageLevel) | ||
| // connect all strictly larger neighbors to the min neighbor (including self) | ||
| ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) | ||
| .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst | ||
| .distinct() | ||
| .persist(intermediateStorageLevel) | ||
| currRoundPersistedDFs = currRoundPersistedDFs :+ ee | ||
|
|
||
| // small-star step | ||
| // compute min neighbors (excluding self-min) | ||
| val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr | ||
| .persist(intermediateStorageLevel) | ||
| currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 | ||
|
|
||
| // connect all smaller neighbors to the min neighbor | ||
| ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) | ||
| .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst | ||
|
|
@@ -355,25 +382,14 @@ object ConnectedComponents extends Logging { | |
| } | ||
|
|
||
| ee.persist(intermediateStorageLevel) | ||
| currRoundPersistedDFs = currRoundPersistedDFs :+ ee | ||
|
|
||
| // test convergence | ||
| minNbrs1 = minNbrs(ee) // src >= min_nbr | ||
| .persist(intermediateStorageLevel) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Within the termination of the while loop, should we also explicitly call unpersist() in order to remove the cache references to these pseudo-mutable vars that represent an immutable cache state?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will revisit and fix the persist/unpersist operations |
||
| currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 | ||
|
|
||
| // Taking the sum in DecimalType to preserve precision. | ||
| // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. | ||
| // It should be able to handle 200 billion edges without overflow. | ||
| val (currSum, cnt) = ee.select(sum(col(SRC).cast(DecimalType(20, 0))), count("*")).rdd | ||
| .map { r => | ||
| (r.getAs[BigDecimal](0), r.getLong(1)) | ||
| }.first() | ||
| if (cnt != 0L && currSum == null) { | ||
| throw new ArithmeticException( | ||
| s""" | ||
| |The total sum of edge src IDs is used to determine convergence during iterations. | ||
| |However, the total sum at iteration $iteration exceeded 30 digits (1e30), | ||
| |which should happen only if the graph contains more than 200 billion edges. | ||
| |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. | ||
| """.stripMargin) | ||
| } | ||
| // test convergence | ||
| val currSum = _calcMinNbrSum(minNbrs1) | ||
| logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.") | ||
| if (currSum == prevSum) { | ||
| // This also covers the case when cnt = 0 and currSum is null, which means no edges. | ||
|
|
@@ -382,6 +398,15 @@ object ConnectedComponents extends Logging { | |
| prevSum = currSum | ||
| } | ||
|
|
||
| // materialize all persisted DataFrames in current round, | ||
| // then we can unpersist last round persisted DataFrames. | ||
| for (persisted_df <- currRoundPersistedDFs) { | ||
| persisted_df.count() // materialize it. | ||
| } | ||
| for (persisted_df <- lastRoundPersistedDFs) { | ||
| persisted_df.unpersist() | ||
| } | ||
| lastRoundPersistedDFs = currRoundPersistedDFs | ||
| iteration += 1 | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to do the first minimization vector connection reduction outside of the while loop. +1