Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 141 additions & 14 deletions core/src/main/scala/org/graphframes/lib/TwoPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,19 @@ private[graphframes] object TwoPhase extends Logging {
}

/**
* Computes the sum of all `min_nbr` values in the given DataFrame, cast to DecimalType(38, 0)
* for high precision. Used to detect convergence between iterations.
* Computes the sum of all `min_nbr` values and the undirected edge count of the given DataFrame
* in a single Spark job. The sum is cast to DecimalType(38, 0) for high precision. Used to
* detect convergence between iterations and to check graph sparsity for the pruning
* optimization. The edge count is derived as `sum(cnt) / 2` since each undirected edge appears
* once in each direction in the symmetrized graph.
*/
private def calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = {
minNbrsDF
.select(sum(col(MIN_NBR).cast(DecimalType(38, 0))))
private def calcMinNbrSum(minNbrsDF: DataFrame): (BigDecimal, Long) = {
val row = minNbrsDF
.select(
sum(col(MIN_NBR).cast(DecimalType(38, 0))),
coalesce((sum(col(CNT)) / 2).cast("long"), lit(0L)))
.first()
.getAs[BigDecimal](0)
(row.getAs[BigDecimal](0), row.getLong(1))
Comment thread
WeichenXu123 marked this conversation as resolved.
}

/**
Expand Down Expand Up @@ -172,6 +177,70 @@ private[graphframes] object TwoPhase extends Logging {
GraphFrame.skewedJoin(edges, minNbrsDF, SRC, hubs, logPrefix)
}

/**
 * Builds a smaller graph by dropping leaf vertices. A vertex is kept when it is the source of at
 * least one edge or the destination of more than one edge; only edges whose destination is a
 * kept vertex are retained.
 *
 * @param edges
 *   edge DataFrame with SRC and DST columns
 * @param intermediateStorageLevel
 *   storage level used to persist the kept vertices and the retained edges
 * @param numNodes
 *   number of vertices of the graph before pruning
 * @param shrinkageThreshold
 *   pruning is accepted only when `keptVertexCount * shrinkageThreshold < numNodes`
 * @return
 *   Some((keptVertices, retainedEdges, keptVertexCount)) with both DataFrames persisted when the
 *   shrinkage condition holds; None otherwise, after the kept-vertex DataFrame is unpersisted
 */
private[graphframes] def pruneLeafNodes(
    edges: DataFrame,
    intermediateStorageLevel: StorageLevel,
    numNodes: Long,
    shrinkageThreshold: Double): Option[(DataFrame, DataFrame, Long)] = {

  // Destinations that are reached by more than one edge.
  val inDegrees = edges.groupBy(DST).agg(count("*").as(CNT))
  val multiInVertices = inDegrees.where(col(CNT) > 1).select(col(DST).alias(ID))

  // Every source vertex, plus the multi-in-degree destinations found above.
  val sourceVertices = edges.select(col(SRC).alias(ID))
  val keptVertices = sourceVertices
    .union(multiInVertices)
    .distinct()
    .persist(intermediateStorageLevel)
  val keptCount = keptVertices.count()

  val shrinksEnough = keptCount * shrinkageThreshold < numNodes
  if (!shrinksEnough) {
    // Not worth it: release the vertex set that was materialized only to measure its size.
    keptVertices.unpersist(blocking = false)
    None
  } else {
    val keptAsDst = keptVertices.withColumnRenamed(ID, DST)
    val keptEdges = edges.join(keptAsDst, DST).persist(intermediateStorageLevel)
    Some((keptVertices, keptEdges, keptCount))
  }
}

/**
 * Given the vertices and converged edges of the shrunken graph, reconstructs the converged edges
 * of the graph as it was before pruning.
 *
 * @param vertices
 *   vertices of the shrunken graph (ID column)
 * @param edges
 *   converged edges of the shrunken graph (SRC, DST columns)
 * @param edgesBeforePruning
 *   edges of the graph captured right before pruning (SRC, DST columns)
 * @return
 *   distinct (SRC, DST) rows: one row per shrunken-graph vertex, plus one row per pre-pruning
 *   edge whose source is a shrunken-graph vertex, re-pointed at that vertex's SRC
 */
private[graphframes] def joinBack(
    vertices: DataFrame,
    edges: DataFrame,
    edgesBeforePruning: DataFrame): DataFrame = {

  val shrunkenVertices = vertices.as("vertices")
  val shrunkenEdges = edges.as("edges")
  val vertexId = col(s"vertices.$ID")

  // Pair every shrunken-graph vertex with the SRC of the edge pointing at it; a vertex without
  // an incoming edge is paired with itself.
  val cc = shrunkenVertices
    .join(shrunkenEdges, vertexId === col(s"edges.$DST"), "left_outer")
    .select(coalesce(col(s"edges.$SRC"), vertexId).as(SRC), vertexId.as(DST))

  // Follow the pre-pruning edges leaving each vertex so their destinations inherit its SRC.
  val originalEdges = edgesBeforePruning.as("edgesBeforePruning")
  val reattached = cc
    .as("cc")
    .join(originalEdges, col(s"cc.$DST") === col(s"edgesBeforePruning.$SRC"))
    .select(col(s"cc.$SRC"), col(s"edgesBeforePruning.$DST"))

  reattached.union(cc).distinct()
}

/**
* Runs the two-phase label propagation connected components algorithm.
*/
Expand All @@ -182,7 +251,10 @@ private[graphframes] object TwoPhase extends Logging {
intermediateStorageLevel: StorageLevel,
useLabelsAsComponents: Boolean,
useLocalCheckpoints: Boolean,
isGraphPrepared: Boolean): DataFrame = {
isGraphPrepared: Boolean,
optStartIter: Int = 2,
sparsityThreshold: Double = 2.0,
shrinkageThreshold: Double = 2.0): DataFrame = {

val spark = graph.spark
val sc = spark.sparkContext
Expand Down Expand Up @@ -224,14 +296,21 @@ private[graphframes] object TwoPhase extends Logging {
val vv = g.vertices
var ee = g.edges.persist(intermediateStorageLevel) // src < dst
logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.")
var numNodes = vv.count()
logInfo(s"$logPrefix Found $numNodes nodes after preparation.")

var converged = false
var iteration = 1
var isOptimized = false
var triedToOptimize = false
var shouldKeepCheckpoint = false
var edgesBeforePruning: DataFrame = null
var shrunkenGraphNodes: DataFrame = null

var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr
.persist(intermediateStorageLevel)

var prevSum: BigDecimal = calcMinNbrSum(minNbrs1)
var (prevSum, _) = calcMinNbrSum(minNbrs1)

var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
while (!converged) {
Expand Down Expand Up @@ -273,7 +352,12 @@ private[graphframes] object TwoPhase extends Logging {

if (iteration > checkpointInterval) {
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
// keep the checkpoint when edgesBeforePruning points to it
if (!shouldKeepCheckpoint) {
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
} else {
shouldKeepCheckpoint = false
}
}

System.gc()
Expand All @@ -288,8 +372,44 @@ private[graphframes] object TwoPhase extends Logging {
currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1

// test convergence
val currSum = calcMinNbrSum(minNbrs1)
val (currSum, edgeCnt) = calcMinNbrSum(minNbrs1)
logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")

// Pruning Node Optimization: construct a new small graph with fewer nodes,
// and find connected components of the shrunken graph, then join back to get the
// connected components of the original graph.

// If the graph becomes sparse and current iteration >= $optStartIter, we start to
// try such optimization. However, the optimization is only performed if the shrunken
// graph is much smaller than the original graph, otherwise we do not perform it (
// in this case, the only additional cost is to determine the size of shrunken graph).
// In current implementation, we only try such optimization one time and it is
// performed at most one time. So the additional cost is bounded.

// According to such heuristic rule, we can determine when and whether we should
// perform the optimization. For the sparse graphs (defined by sparsityThreshold),
// we will try such optimization at the end of $optStartIter iteration (default is 2).
// For the dense graph, its edges will be pruned at each large/small star join iteration,
// and we will try the optimization once the graph becomes sparse.
if ((edgeCnt < sparsityThreshold * numNodes) && (edgeCnt > 0)
&& (iteration >= optStartIter) && (!triedToOptimize)) {
edgesBeforePruning = ee
pruneLeafNodes(ee, intermediateStorageLevel, numNodes, shrinkageThreshold) match {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update currSum here? If we shrink the graph and convergence would happen in the next iteration, we won't detect that the algorithm has converged: prevSum in the next iteration will be the currSum from the iteration in which the shrinking happened, but that sum was computed before shrinking and will always be bigger.

case Some(r) =>
shrunkenGraphNodes = r._1
ee = r._2
currRoundPersistedDFs = currRoundPersistedDFs :+ ee
numNodes = r._3
isOptimized = true
shouldKeepCheckpoint = true
logInfo(s"$logPrefix Pruning node optimization performed in iteration $iteration.")
logInfo(s"$logPrefix Shrunken graph node count: $numNodes.")
case None =>
logInfo(s"$logPrefix Pruning node optimization not performed.")
}
triedToOptimize = true
}

if (currSum == prevSum) {
converged = true
} else {
Expand All @@ -303,6 +423,10 @@ private[graphframes] object TwoPhase extends Logging {
iteration += 1
}

if (isOptimized) {
ee = joinBack(shrunkenGraphNodes, ee, edgesBeforePruning)
}

logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")
logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")

Expand All @@ -314,6 +438,9 @@ private[graphframes] object TwoPhase extends Logging {
for (persistedDF <- lastRoundPersistedDFs) {
persistedDF.unpersist()
}
if (shrunkenGraphNodes != null) {
shrunkenGraphNodes.unpersist()
}

resultIsPersistent()

Expand Down Expand Up @@ -353,11 +480,11 @@ private[graphframes] object TwoPhase extends Logging {

var minNbrs1: DataFrame = symmetrize(ee)
.groupBy(SRC)
.agg(min(col(DST)).as(MIN_NBR))
.agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
.persist(intermediateStorageLevel)

var prevSum: BigDecimal = calcMinNbrSum(minNbrs1)
var (prevSum, _) = calcMinNbrSum(minNbrs1)

var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
while (!converged) {
Expand Down Expand Up @@ -404,13 +531,13 @@ private[graphframes] object TwoPhase extends Logging {

minNbrs1 = symmetrize(ee)
.groupBy(SRC)
.agg(min(col(DST)).as(MIN_NBR))
.agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
.persist(intermediateStorageLevel)
currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1

// test convergence
val currSum = calcMinNbrSum(minNbrs1)
val (currSum, _) = calcMinNbrSum(minNbrs1)
logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")
if (currSum == prevSum) {
converged = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext {

// Fixture graph shared by the pruning node optimization tests.
var verticesOpt: DataFrame = _
var edgesOpt: DataFrame = _

/** Creates the 7-vertex, 6-edge fixture graph after the parent suite has been set up. */
override def beforeAll(): Unit = {
  super.beforeAll()
  val edgePairs: Seq[(Long, Long)] =
    Seq(0L -> 1L, 0L -> 2L, 0L -> 3L, 0L -> 4L, 1L -> 2L, 1L -> 5L)
  verticesOpt = spark.range(7L).toDF(ID)
  edgesOpt = spark.createDataFrame(edgePairs).toDF(SRC, DST)
}

test("default params") {
val g = Graphs.empty[Int]
val cc = g.connectedComponents
Expand Down Expand Up @@ -319,6 +332,47 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize)
}

// Verifies that pruneLeafNodes keeps vertices {0, 1, 2} of the fixture graph and only the
// edges between them.
test("prune process for pruning nodes optimization") {
  val intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK
  val shrinkageThreshold = 2.0
  val pruned = TwoPhase.pruneLeafNodes(
    edgesOpt,
    intermediateStorageLevel,
    verticesOpt.count(),
    shrinkageThreshold)

  // Fail with a readable message instead of a MatchError when pruning is unexpectedly skipped.
  val (prunedVertices, prunedEdges, prunedCount) =
    pruned.getOrElse(fail("expected pruneLeafNodes to shrink the graph, but it returned None"))

  try {
    val expectedV = Set(Row(0L), Row(1L), Row(2L))
    val expectedE = Set(Row(0L, 1L), Row(1L, 2L), Row(0L, 2L))

    assert(prunedVertices.collect().toSet == expectedV)
    assert(prunedEdges.select(SRC, DST).collect().toSet == expectedE)
    assert(prunedCount == expectedV.size)
  } finally {
    // Release the persisted results even when an assertion above fails.
    prunedVertices.unpersist()
    prunedEdges.unpersist()
  }
}

// Verifies that pruning is skipped when the shrunken graph is not small enough.
test("shrinkage condition for pruning nodes optimization") {
  // kept vertex count = 3, original vertex count = 7, shrinkageThreshold = 4:
  // 3 * 4 is not smaller than 7, so the optimization must not be performed.
  val originalVertexCount = verticesOpt.count()
  val outcome = TwoPhase.pruneLeafNodes(
    edgesOpt,
    StorageLevel.MEMORY_AND_DISK,
    originalVertexCount,
    shrinkageThreshold = 4.0)
  assert(outcome.isEmpty)
}

// Verifies that joinBack re-attaches the pre-pruning edges to the converged shrunken graph.
test("join back for pruning node optimization") {
  val shrunkenVertices = spark.range(3L).toDF(ID)
  val convergedEdges = spark.createDataFrame(Seq((0L, 1L), (0L, 2L))).toDF(SRC, DST)

  val rebuilt = TwoPhase.joinBack(shrunkenVertices, convergedEdges, edgesOpt).collect().toSet

  // Vertices 0 through 5 must all end up paired with source 0.
  val expected = (0L to 5L).map(dst => Row(0L, dst)).toSet
  assert(rebuilt == expected)
}

private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {
Expand Down
Loading