Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/scala/org/graphframes/Logging.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,8 @@ private[org] trait Logging {
protected def logTrace(s: => String): Unit = {
if (logger.isTraceEnabled) logger.trace(s)
}

protected def resultIsPersistent(): Unit = {
logWarn("Returned DataFrame is persistent and materialized!")
}
}
161 changes: 82 additions & 79 deletions core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.graphframes.lib

import org.apache.hadoop.fs.Path
import org.apache.spark.graphframes.graphx
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
Expand Down Expand Up @@ -94,7 +93,6 @@ object ConnectedComponents extends Logging {
private val ORIG_ID = "orig_id"
private val MIN_NBR = "min_nbr"
private val CNT = "cnt"
private val CHECKPOINT_NAME_PREFIX = "connected-components"

/**
* Returns the symmetric directed graph of the graph specified by input edges.
Expand All @@ -121,11 +119,8 @@ object ConnectedComponents extends Logging {
* `dst`.
*/
private def prepare(graph: GraphFrame): GraphFrame = {
// TODO: This assignment job might fail if the graph is skewed.
val vertices = graph.indexedVertices
.select(col(LONG_ID).as(ID), col(ATTR))
// TODO: confirm the contract for a graph and decide whether we need distinct here
// .distinct()
val edges = graph.indexedEdges
.select(col(LONG_SRC).as(SRC), col(LONG_DST).as(DST))
val orderedEdges = edges
Expand All @@ -141,20 +136,36 @@ object ConnectedComponents extends Logging {
* - `min_nbr`, the min vertex ID among itself and its neighbors
* - `cnt`, the total number of neighbors
*/
private def minNbrs(ee: DataFrame): DataFrame = {
symmetrize(ee)
.groupBy(SRC)
.agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
private def minNbrs(
ee: DataFrame,
computeCount: Boolean,
includeSelf: Boolean,
doSymmetrize: Boolean): DataFrame = {
val ee2 = if (doSymmetrize) {
symmetrize(ee)
} else {
ee
}
val res = if (computeCount) {
ee2
.groupBy(SRC)
.agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
} else {
ee2
.groupBy(SRC)
.agg(min(col(DST)).as(MIN_NBR))
}
if (includeSelf) {
res.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
} else { res }
}

private def minValue(x: Column, y: Column): Column = {
when(x < y, x).otherwise(y)
}

private def maxValue(x: Column, y: Column): Column = {
private def maxValue(x: Column, y: Column): Column =
when(x > y, x).otherwise(y)
}

/**
* Performs a possibly skewed join between edges and current component assignments. The skew
Expand Down Expand Up @@ -207,40 +218,30 @@ object ConnectedComponents extends Logging {
}

val spark = graph.spark
val sc = spark.sparkContext
// Store original AQE setting
val originalAQE = spark.conf.get("spark.sql.adaptive.enabled")

try {
spark.conf.set("spark.sql.adaptive.enabled", "false")
val shouldDoSkewedJoin = broadcastThreshold != -1

if (shouldDoSkewedJoin) {
spark.conf.set("spark.sql.adaptive.enabled", "false")
}

val runId = UUID.randomUUID().toString.takeRight(8)
val logPrefix = s"[CC $runId]"
logInfo(s"$logPrefix Start connected components with run ID $runId.")

val shouldCheckpoint = checkpointInterval > 0
val checkpointDir: Option[String] = if (useLocalCheckpoints) { None }
else if (shouldCheckpoint) {
val dir = sc.getCheckpointDir
.map { d =>
new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
}
.getOrElse {
// Spark-Connect workaround
spark.conf.getOption("spark.checkpoint.dir") match {
case Some(d) => new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
case None =>
throw new IOException(
"Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" +
"or by specifying the conf 'spark.checkpoint.dir'.")
}
}
logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.")
Some(dir)
} else {
logInfo(
s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.")
None
if (shouldCheckpoint && !useLocalCheckpoints && spark.sparkContext.getCheckpointDir.isEmpty) {
// Spark-Connect workaround
spark.sparkContext.setCheckpointDir(spark.conf.getOption("spark.checkpoint.dir") match {
case Some(d) => d
case None =>
throw new IOException(
"Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" +
"or by specifying the conf 'spark.checkpoint.dir'.")
})
}

logInfo(s"$logPrefix Preparing the graph for connected component computation ...")
Expand All @@ -253,29 +254,24 @@ object ConnectedComponents extends Logging {
var iteration = 1

def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = {
// Taking the sum in DecimalType to preserve precision.
// We use 20 digits for long values and Spark SQL will add 10 digits for the sum.
// It should be able to handle 200 billion edges without overflow.
val (minNbrSum, cnt) = minNbrsDF
.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*"))
.rdd
.map { r =>
(r.getAs[BigDecimal](0), r.getLong(1))
}
.first()
if (cnt != 0L && minNbrSum == null) {
throw new ArithmeticException(s"""
|The total sum of edge src IDs is used to determine convergence during iterations.
|However, the total sum at iteration $iteration exceeded 30 digits (1e30),
|which should happen only if the graph contains more than 200 billion edges.
|If not, please file a bug report at https://github.com/graphframes/graphframes/issues.
""".stripMargin)
if (shouldDoSkewedJoin) {
minNbrsDF
.select(sum(col(MIN_NBR).cast(DecimalType(38, 0))), count("*"))
.first()
.getDecimal(0)
} else {
minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(38, 0)))).first().getDecimal(0)
}
minNbrSum
}
// compute min neighbors (including self-min)
var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr
.persist(intermediateStorageLevel)
var minNbrs1: DataFrame =
minNbrs(
ee,
computeCount = shouldDoSkewedJoin,
includeSelf = true,
doSymmetrize = true
) // src >= min_nbr
.persist(intermediateStorageLevel)

var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1)

Expand All @@ -284,23 +280,37 @@ object ConnectedComponents extends Logging {
var currRoundPersistedDFs = Seq[DataFrame]()
// large-star step
// connect all strictly larger neighbors to the min neighbor (including self)
ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
.select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
ee = {
if (shouldDoSkewedJoin) {
skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)

} else {
ee.join(minNbrs1, SRC)
}
}.select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
.distinct()
.persist(intermediateStorageLevel)

currRoundPersistedDFs = currRoundPersistedDFs :+ ee

// small-star step
// compute min neighbors (excluding self-min)
val minNbrs2 = ee
.groupBy(col(SRC))
.agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
val minNbrs2 = minNbrs(
ee,
computeCount = shouldDoSkewedJoin,
includeSelf = false,
doSymmetrize = false)
.persist(intermediateStorageLevel)
currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2

// connect all smaller neighbors to the min neighbor
ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
.select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
ee = {
if (shouldDoSkewedJoin) {
skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
} else {
ee.join(minNbrs2, SRC)
}
}.select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
.filter(col(SRC) =!= col(DST)) // src < dst
// connect self to the min neighbor
ee = ee
Expand All @@ -312,26 +322,19 @@ object ConnectedComponents extends Logging {
if (useLocalCheckpoints) {
ee = ee.localCheckpoint(eager = true)
} else {
// TODO: remove this after DataFrame.checkpoint is implemented
val out = s"${checkpointDir.get}/$iteration"
ee.write.parquet(out)
// may hit S3 eventually consistent issue
ee = spark.read.parquet(out)

// remove previous checkpoint
if (iteration > checkpointInterval) {
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
}

System.gc() // hint Spark to clean shuffle directories
ee = ee.checkpoint(eager = true)
}
}

ee.persist(intermediateStorageLevel)
ee = ee.persist(intermediateStorageLevel)
currRoundPersistedDFs = currRoundPersistedDFs :+ ee

minNbrs1 = minNbrs(ee) // src >= min_nbr
minNbrs1 = minNbrs(
ee,
computeCount = shouldDoSkewedJoin,
includeSelf = true,
doSymmetrize = true
) // src >= min_nbr
.persist(intermediateStorageLevel)
currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1

Expand Down Expand Up @@ -386,7 +389,7 @@ object ConnectedComponents extends Logging {
persisted_df.unpersist()
}

logWarn("The DataFrame returned by ConnectedComponents is persisted and loaded.")
resultIsPersistent()

output
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.MapType
import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithLocalCheckpoints
Expand All @@ -49,11 +50,12 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
with WithAlgorithmChoice
with WithCheckpointInterval
with WithMaxIter
with WithLocalCheckpoints {
with WithLocalCheckpoints
with Logging {

def run(): DataFrame = {
val maxIterChecked = check(maxIter, "maxIter")
algorithm match {
val res = algorithm match {
case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked)
case "graphframes" =>
LabelPropagation.runInGraphFrames(
Expand All @@ -62,6 +64,8 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
checkpointInterval,
useLocalCheckpoints = useLocalCheckpoints)
}
resultIsPersistent()
res
}
}

Expand Down Expand Up @@ -127,5 +131,4 @@ private object LabelPropagation {
}

private val LABEL_ID = "label"

}
17 changes: 14 additions & 3 deletions core/src/main/scala/org/graphframes/lib/PageRank.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.graphframes.lib

import org.apache.spark.graphframes.graphx.{lib => graphxlib}
import org.graphframes.GraphFrame
import org.graphframes.Logging

/**
* PageRank algorithm implementation. There are two implementations of PageRank.
Expand Down Expand Up @@ -63,7 +64,9 @@ import org.graphframes.GraphFrame
* The resulting edges DataFrame contains one additional column:
* - weight (`DoubleType`): the normalized weight of this edge after running PageRank
*/
class PageRank private[graphframes] (private val graph: GraphFrame) extends Arguments {
class PageRank private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Logging {

private var tol: Option[Double] = None
private var resetProb: Option[Double] = Some(0.15)
Expand Down Expand Up @@ -93,13 +96,15 @@ class PageRank private[graphframes] (private val graph: GraphFrame) extends Argu
}

def run(): GraphFrame = {
tol match {
val res = tol match {
case Some(t) =>
assert(maxIter.isEmpty, "You cannot specify maxIter() and tol() at the same time.")
PageRank.runUntilConvergence(graph, t, resetProb.get, srcId)
case None =>
PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId)
}
resultIsPersistent()
res
}
}

Expand Down Expand Up @@ -129,7 +134,13 @@ private object PageRank {
val longSrcId = srcId.map(GraphXConversions.integralId(graph, _))
val gx =
graphxlib.PageRank.runWithOptions(graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId)
GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT))
val res = GraphXConversions
.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT))
.persist()
res.vertices.count()
res.edges.count()
gx.unpersist()
res
}

/**
Expand Down
Loading