Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 3 additions & 34 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval

import java.io.IOException
import java.math.BigDecimal
Expand All @@ -43,7 +44,8 @@ import java.util.UUID
class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Logging
with WithAlgorithmChoice {
with WithAlgorithmChoice
with WithCheckpointInterval {

private var broadcastThreshold: Int = 1000000
setAlgorithm(ALGO_GRAPHFRAMES)
Expand Down Expand Up @@ -73,44 +75,11 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
*/
def getBroadcastThreshold: Int = broadcastThreshold

private var checkpointInterval: Int = 2

/**
 * Configures how often, in iterations, intermediate results are checkpointed (default: 2).
 *
 * Regular checkpointing helps to recover from failures, cleans shuffle files, shortens the
 * lineage of the computation graph, and reduces the complexity of plan optimization. As of
 * Spark 2.0 that complexity would grow exponentially without checkpointing, so disabling
 * checkpointing or using an interval longer than the default is not recommended; such values
 * are accepted but trigger a warning.
 *
 * Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix
 * "connected-components". If the checkpoint directory is not set, a `java.io.IOException` is
 * thrown. A nonpositive value disables checkpointing. This parameter is only used when the
 * algorithm is set to "graphframes". Its default value might change in the future.
 *
 * @param value
 *   checkpoint interval in iterations; nonpositive to disable checkpointing
 * @return
 *   this instance, to allow call chaining
 * @see
 *   `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc
 */
def setCheckpointInterval(value: Int): this.type = {
  // Only 1 and 2 are accepted silently; everything else is stored but warned about.
  val withinRecommendedRange = value == 1 || value == 2
  if (!withinRecommendedRange) {
    logWarn(
      s"Set checkpointInterval to $value. This would blow up the query plan and hang the " +
        "driver for large graphs.")
  }
  this.checkpointInterval = value
  this
}

// Python-friendly overload: accepts a boxed Integer, unboxes it and delegates to the Int setter.
private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type =
  setCheckpointInterval(value.intValue())

/**
 * Returns the checkpoint interval currently configured, in iterations.
 *
 * @return
 *   the value last passed to the setter, or 2 if it was never changed
 * @see
 *   [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]]
 */
def getCheckpointInterval: Int = this.checkpointInterval

private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK

/**
Expand Down
115 changes: 111 additions & 4 deletions src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.struct
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame._
import org.graphframes.Logging

import java.io.IOException
import scala.util.control.Breaks.break
Expand Down Expand Up @@ -78,13 +80,17 @@ import scala.util.control.Breaks.breakable
* <a href="https://doi.org/10.1145/1807167.1807184"> Malewicz et al., Pregel: a system for
* large-scale graph processing. </a>
*/
class Pregel(val graph: GraphFrame) {
class Pregel(val graph: GraphFrame) extends Logging {

private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)]

private var maxIter: Int = 10
private var checkpointInterval = 2
private var earlyStopping = false
private var stopIfAllNonActiveVertices = false
private var skipMessagesFromNonActiveVertices = false
private var initialActiveVertexExpression = lit(true)
private var updateActiveVertexExpression = lit(true)

private val sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)]
private var aggMsgsCol: Column = null
Expand Down Expand Up @@ -132,6 +138,80 @@ class Pregel(val graph: GraphFrame) {
this
}

/**
 * Enables or disables early termination once every vertex is non-active (default: false).
 *
 * When enabled, Pregel checks after each superstep whether any vertex still has its active
 * flag set and, if none has, stops before reaching `maxIter`. This is the "voting to halt"
 * mechanism described in Malewicz, Grzegorz, et al. "Pregel: a system for large-scale graph
 * processing." Proceedings of the 2010 ACM SIGMOD International Conference on Management of
 * data. 2010. PageRank is a good example of such an activity check:
 *   - after each iteration, a vertex whose rank changed by less than the tolerance is marked
 *     as non-active;
 *   - once all vertices are non-active, the iterations stop.
 *
 * @param value
 *   whether Pregel should stop early when all vertices are non-active
 * @return
 *   this instance, to allow call chaining
 */
def setStopIfAllNonActiveVertices(value: Boolean): this.type = {
  this.stopIfAllNonActiveVertices = value
  this
}

/**
 * Sets the expression that initializes the per-vertex active flag (default: `lit(true)`).
 *
 * The expression is evaluated against the vertex DataFrame at the start of the run and stored
 * in the column named by `Pregel.ACTIVE_FLAG_COL`. In most cases the default, which marks
 * every vertex as active, works fine. A custom expression makes sense for algorithms such as
 * multiple-landmark shortest paths:
 *   - only the landmarks should be active initially, because only these vertices start with
 *     non-null distances; all other vertices have null distances, so there is no reason to
 *     mark them as active.
 *
 * @param expression
 *   the initial expression used to create the active-flag vertex column
 * @return
 *   this instance, to allow call chaining
 */
def setInitialActiveVertexExpression(expression: Column): this.type = {
  this.initialActiveVertexExpression = expression
  this
}

/**
 * Sets the expression that recomputes the per-vertex active flag after each superstep
 * (default: `lit(true)`).
 *
 * The expression is evaluated on the vertices joined with their aggregated messages, and its
 * result replaces the active-flag column. For PageRank, for example, such an expression may
 * look like `abs(old_rank - new_rank) >= tolerance`.
 *
 * @param expression
 *   the expression used after each superstep to update the active-flag vertex column
 * @return
 *   this instance, to allow call chaining
 */
def setUpdateActiveVertexExpression(expression: Column): this.type = {
  this.updateActiveVertexExpression = expression
  this
}

/**
 * Enables or disables skipping of message generation for non-active vertices
 * (default: false).
 *
 * When enabled, triplets in which neither the source nor the destination vertex is active are
 * dropped before messages are computed, which can significantly reduce the number of
 * generated messages. Shortest Paths is an example: there is no reason to pass on distances
 * from vertices whose distances did not change in the latest iteration.
 *
 * Be careful: for algorithms such as Label Propagation, messages must still be generated for
 * vertices that are not active; otherwise the algorithm returns an incorrect result!
 *
 * @param value
 *   whether Pregel should skip generating messages for non-active vertices
 * @return
 *   this instance, to allow call chaining
 */
def setSkipMessagesFromNonActiveVertices(value: Boolean): this.type = {
  this.skipMessagesFromNonActiveVertices = value
  this
}

/**
* Defines an additional vertex column at the start of run and how to update it in each
* iteration.
Expand Down Expand Up @@ -250,7 +330,10 @@ class Pregel(val graph: GraphFrame) {
updateExpr.as(colName)
}

var currentVertices = graph.vertices.select((col("*") :: initVertexCols): _*)
var currentVertices = graph.vertices.select(
(Seq(
col("*"),
initialActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*)
var vertexUpdateColDF: DataFrame = null

val edges = graph.edges
Expand All @@ -272,19 +355,27 @@ class Pregel(val graph: GraphFrame) {

breakable {
while (iteration <= maxIter) {
val tripletsDF = currentVertices
logInfo(s"start Pregel iteration $iteration / $maxIter")
var tripletsDF = currentVertices
.select(struct(col("*")).as(SRC))
.join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC))
.join(
currentVertices.select(struct(col("*")).as(DST)),
Pregel.edge(DST) === Pregel.dst(ID))

// Optional pruning: keep only triplets in which the source or the destination vertex is
// flagged active. Triplets whose endpoints are both non-active are dropped here and
// therefore produce no messages in this superstep.
if (skipMessagesFromNonActiveVertices) {
tripletsDF = tripletsDF.filter(
Pregel.src(Pregel.ACTIVE_FLAG_COL) || Pregel.dst(Pregel.ACTIVE_FLAG_COL))
}

val msgDF: DataFrame = tripletsDF
.select(explode(array(sendMsgsColList: _*)).as("msg"))
.select(col("msg.id"), col("msg.msg").as(Pregel.MSG_COL_NAME))
.filter(Pregel.msg.isNotNull)

if (earlyStopping && msgDF.isEmpty) {
logInfo(
s"there are no more non-null messages; Pregel stops earlier at iteration $iteration")
if (vertexUpdateColDF != null) {
vertexUpdateColDF.unpersist()
}
Expand All @@ -297,7 +388,10 @@ class Pregel(val graph: GraphFrame) {

val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer")

var newVertexUpdateColDF = verticesWithMsg.select((col(ID) :: updateVertexCols): _*)
var newVertexUpdateColDF = verticesWithMsg.select(
(Seq(
col(ID),
updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*)

if (shouldCheckpoint && iteration % checkpointInterval == 0) {
// do checkpoint, use lazy checkpoint because later we will materialize this DF.
Expand All @@ -314,6 +408,14 @@ class Pregel(val graph: GraphFrame) {

currentVertices = graph.vertices.join(vertexUpdateColDF, ID)

// Vote-to-halt check: leave the loop once no vertex carries a true active flag. The option is
// tested first so that the emptiness check is only evaluated when the feature is enabled.
if (stopIfAllNonActiveVertices &&
  currentVertices.filter(col(Pregel.ACTIVE_FLAG_COL)).isEmpty) {
  logInfo(s"all the vertices are non-active; Pregel stops earlier at iteration $iteration")
  break()
}

iteration += 1
}
}
Expand All @@ -335,6 +437,11 @@ object Pregel extends Serializable {
*/
val MSG_COL_NAME = "_pregel_msg_"

/**
 * Name of the boolean vertex column that holds the active flag. Pregel adds this column to
 * the vertices at the start of a run and rewrites it after every superstep; it is the column
 * inspected when skipping messages from non-active vertices and when checking whether all
 * vertices are non-active.
 */
val ACTIVE_FLAG_COL = "_pregel_is_active"

/**
* References the message column in aggregating messages and updating additional vertex columns.
*
Expand Down
17 changes: 15 additions & 2 deletions src/main/scala/org/graphframes/lib/ShortestPaths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.graphframes.GraphFrame.quote
import org.graphframes.GraphFramesUnreachableException
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval

import java.util
import scala.jdk.CollectionConverters._
Expand All @@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._
*/
class ShortestPaths private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice {
with WithAlgorithmChoice
with WithCheckpointInterval {
import org.graphframes.lib.ShortestPaths._

private var lmarks: Option[Seq[Any]] = None
Expand All @@ -79,7 +81,7 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
val lmarksChecked = check(lmarks, "landmarks")
algorithm match {
case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked)
case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked)
case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval)
case _ => throw new GraphFramesUnreachableException()
}
}
Expand Down Expand Up @@ -125,6 +127,7 @@ private object ShortestPaths extends Logging {
private def runInGraphFrames(
graph: GraphFrame,
landmarks: Seq[Any],
checkpointInterval: Int,
isDirected: Boolean = true): DataFrame = {
logWarn("The GraphFrames based implementation is slow and considered experimental!")
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType
Expand Down Expand Up @@ -192,6 +195,12 @@ private object ShortestPaths extends Logging {
val srcDistanceCol = Pregel.src(DISTANCE_ID)
val dstDistanceCol = Pregel.dst(DISTANCE_ID)

// Initial active-vertex col expression: only landmarks
val initialActiveVerticesExpr = col(GraphFrame.ID).isInCollection(landmarks)

// Mark a vertex as active only if its distance changed
val updateActiveVierticesExpr = isDistanceImprovedWithMessage(Pregel.msg, col(DISTANCE_ID))

// Overall:
// 1. Initialize distances
// 2. If new message can improve distances send it
Expand All @@ -208,6 +217,10 @@ private object ShortestPaths extends Logging {
incrementDistances(dstDistanceCol)))
.aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg)))
.setEarlyStopping(true)
.setInitialActiveVertexExpression(initialActiveVerticesExpr)
.setUpdateActiveVertexExpression(updateActiveVierticesExpr)
.setStopIfAllNonActiveVertices(true)
.setSkipMessagesFromNonActiveVertices(true)

// Experimental feature
if (isDirected) {
Expand Down
38 changes: 38 additions & 0 deletions src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ private[graphframes] trait WithAlgorithmChoice {
protected var algorithm: String = ALGO_GRAPHX
val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)

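/**
 * Set an algorithm to use. Supported algorithms are "graphx" and "graphframes".
 *
 * @param value
 *   name of the algorithm to use; must be one of `supportedAlgorithms`
 * @return
 *   this instance, to allow call chaining
 */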
def setAlgorithm(value: String): this.type = {
require(
supportedAlgorithms.contains(value),
Expand All @@ -16,3 +22,35 @@ private[graphframes] trait WithAlgorithmChoice {

def getAlgorithm: String = algorithm
}

/**
 * Mixin that gives an algorithm a configurable checkpoint interval together with its
 * setter and getter.
 */
private[graphframes] trait WithCheckpointInterval extends Logging {
  protected var checkpointInterval: Int = 2

  /**
   * Configures how often, in iterations, intermediate results are checkpointed (default: 2).
   *
   * Regular checkpointing helps to recover from failures, cleans shuffle files, shortens the
   * lineage of the computation graph, and reduces the complexity of plan optimization. As of
   * Spark 2.0 that complexity would grow exponentially without checkpointing, so disabling
   * checkpointing or using an interval longer than the default is not recommended; such
   * values are accepted but trigger a warning.
   *
   * Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with
   * prefix of the algorithm name. If the checkpoint directory is not set, a
   * `java.io.IOException` is thrown. A nonpositive value disables checkpointing. This
   * parameter is only used when the algorithm is set to "graphframes". Its default value
   * might change in the future.
   *
   * @param value
   *   checkpoint interval in iterations; nonpositive to disable checkpointing
   * @return
   *   this instance, to allow call chaining
   * @see
   *   `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc
   */
  def setCheckpointInterval(value: Int): this.type = {
    // Only 1 and 2 are accepted silently; everything else is stored but warned about.
    val withinRecommendedRange = value == 1 || value == 2
    if (!withinRecommendedRange) {
      logWarn(
        s"Set checkpointInterval to $value. This would blow up the query plan and hang the " +
          "driver for large graphs.")
    }
    this.checkpointInterval = value
    this
  }

  /**
   * Returns the checkpoint interval currently configured, in iterations.
   *
   * @return
   *   the value last passed to `setCheckpointInterval`, or 2 if it was never changed
   */
  def getCheckpointInterval: Int = this.checkpointInterval
}