diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 7b1cb6f1d..f94274082 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice +import org.graphframes.WithCheckpointInterval import java.io.IOException import java.math.BigDecimal @@ -43,7 +44,8 @@ import java.util.UUID class ConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments with Logging - with WithAlgorithmChoice { + with WithAlgorithmChoice + with WithCheckpointInterval { private var broadcastThreshold: Int = 1000000 setAlgorithm(ALGO_GRAPHFRAMES) @@ -73,44 +75,11 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) */ def getBroadcastThreshold: Int = broadcastThreshold - private var checkpointInterval: Int = 2 - - /** - * Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing - * regularly helps recover from failures, clean shuffle files, shorten the lineage of the - * computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the - * complexity of plan optimization would grow exponentially without checkpointing. Hence, - * disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint - * data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix - * "connected-components". If the checkpoint directory is not set, this throws a - * `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is - * only used when the algorithm is set to "graphframes". Its default value might change in the - * future. 
- * @see - * `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc - */ - def setCheckpointInterval(value: Int): this.type = { - if (value <= 0 || value > 2) { - logWarn( - s"Set checkpointInterval to $value. This would blow up the query plan and hang the " + - "driver for large graphs.") - } - checkpointInterval = value - this - } - // python-friendly setter private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type = { setCheckpointInterval(value.toInt) } - /** - * Gets checkpoint interval. - * @see - * [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]] - */ - def getCheckpointInterval: Int = checkpointInterval - private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK /** diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index ab34494d2..fab3699ba 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -22,9 +22,11 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.array import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.struct import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ +import org.graphframes.Logging import java.io.IOException import scala.util.control.Breaks.break @@ -78,13 +80,17 @@ import scala.util.control.Breaks.breakable * Malewicz et al., Pregel: a system for * large-scale graph processing. 
*/ -class Pregel(val graph: GraphFrame) { +class Pregel(val graph: GraphFrame) extends Logging { private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)] private var maxIter: Int = 10 private var checkpointInterval = 2 private var earlyStopping = false + private var stopIfAllNonActiveVertices = false + private var skipMessagesFromNonActiveVertices = false + private var initialActiveVertexExpression = lit(true) + private var updateActiveVertexExpression = lit(true) private val sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)] private var aggMsgsCol: Column = null @@ -132,6 +138,80 @@ class Pregel(val graph: GraphFrame) { this } + /** + * Sets whether Pregel should stop early once all the vertices are marked as non-active. + * + * This feature allows Pregel to terminate before reaching maxIter by checking whether active + * vertices remain. A good example of such a check is PageRank: (see Malewicz, Grzegorz, et al. + * "Pregel: a system for large-scale graph processing." Proceedings of the 2010 ACM SIGMOD + * International Conference on Management of data. 2010., a part about voting to halt) + * - after each iteration we check whether the change in rank is less than tolerance and if so, + * we can mark the vertex as non-active + * - if all the vertices are non-active, we stop iterating + * @param value + * whether Pregel should stop early when no active vertices are left + * @return this Pregel instance + */ + def setStopIfAllNonActiveVertices(value: Boolean): this.type = { + stopIfAllNonActiveVertices = value + this + } + + /** + * Set the initial expression for the active/non-active flag per vertex. + * + * In most cases the default expression (true for all the vertices) should work fine. + * For some cases it makes sense to set a custom expression. 
A good example is + * the multiple-landmarks shortest-paths algorithm: + * - the only initially active vertices in that case should be landmarks, because only these + * vertices initially have non-null distances but all the other vertices have null distances + * and there is no reason to mark them active initially. + * @param expression + * an initial expression that will be used to create an active-flag vertex column + * @return this Pregel instance + */ + def setInitialActiveVertexExpression(expression: Column): this.type = { + initialActiveVertexExpression = expression + this + } + + /** + * Set an expression that will be used after each superstep to update the active-flag vertex + * column. + * + * An example is PageRank algorithm: in that case such an expression may look like abs(old_rank - + * new_rank) >= tolerance + * + * @param expression + * an expression that will be used after each superstep to update the active-flag vertex + * column + * @return this Pregel instance + */ + def setUpdateActiveVertexExpression(expression: Column): this.type = { + updateActiveVertexExpression = expression + this + } + + /** + * If true, no messages are generated for edges whose source and destination are both non-active. + * + * For example, for Shortest Paths, there is no reason to pass distances from vertices whose + * distances did not change at the latest iteration. This can significantly reduce the + * amount of generated messages. + * + * Be careful, for algorithms like Label Propagation or Pregel, even if the vertex is not + * active, we still need to generate messages, otherwise the algorithm will return an incorrect + * result! + * + * @param value + * whether Pregel should skip generating messages between non-active vertices. + * @return this Pregel instance + */ + def setSkipMessagesFromNonActiveVertices(value: Boolean): this.type = { + skipMessagesFromNonActiveVertices = value + this + } + /** * Defines an additional vertex column at the start of run and how to update it in each * iteration. 
@@ -250,7 +330,10 @@ class Pregel(val graph: GraphFrame) { updateExpr.as(colName) } - var currentVertices = graph.vertices.select((col("*") :: initVertexCols): _*) + var currentVertices = graph.vertices.select( + (Seq( + col("*"), + initialActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) var vertexUpdateColDF: DataFrame = null val edges = graph.edges @@ -272,19 +355,27 @@ class Pregel(val graph: GraphFrame) { breakable { while (iteration <= maxIter) { - val tripletsDF = currentVertices + logInfo(s"start Pregel iteration $iteration / $maxIter") + var tripletsDF = currentVertices .select(struct(col("*")).as(SRC)) .join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC)) .join( currentVertices.select(struct(col("*")).as(DST)), Pregel.edge(DST) === Pregel.dst(ID)) + if (skipMessagesFromNonActiveVertices) { + tripletsDF = tripletsDF.filter( + Pregel.src(Pregel.ACTIVE_FLAG_COL) || Pregel.dst(Pregel.ACTIVE_FLAG_COL)) + } + val msgDF: DataFrame = tripletsDF .select(explode(array(sendMsgsColList: _*)).as("msg")) .select(col("msg.id"), col("msg.msg").as(Pregel.MSG_COL_NAME)) .filter(Pregel.msg.isNotNull) if (earlyStopping && msgDF.isEmpty) { + logInfo( + s"there are no more non-null messages; Pregel stops earlier at iteration $iteration") if (vertexUpdateColDF != null) { vertexUpdateColDF.unpersist() } @@ -297,7 +388,10 @@ class Pregel(val graph: GraphFrame) { val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer") - var newVertexUpdateColDF = verticesWithMsg.select((col(ID) :: updateVertexCols): _*) + var newVertexUpdateColDF = verticesWithMsg.select( + (Seq( + col(ID), + updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) if (shouldCheckpoint && iteration % checkpointInterval == 0) { // do checkpoint, use lazy checkpoint because later we will materialize this DF. 
@@ -314,6 +408,14 @@ class Pregel(val graph: GraphFrame) { currentVertices = graph.vertices.join(vertexUpdateColDF, ID) + if (stopIfAllNonActiveVertices) { + if (currentVertices.filter(col(Pregel.ACTIVE_FLAG_COL)).isEmpty) { + logInfo( + s"all the verties are non-active; Pregel stops earlier at iteration $iteration") + break() + } + } + iteration += 1 } } @@ -335,6 +437,11 @@ object Pregel extends Serializable { */ val MSG_COL_NAME = "_pregel_msg_" + /** + * A constant column name for active vertex flag. + */ + val ACTIVE_FLAG_COL = "_pregel_is_active" + /** * References the message column in aggregating messages and updating additional vertex columns. * diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 302be36eb..c8e0d8169 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -39,6 +39,7 @@ import org.graphframes.GraphFrame.quote import org.graphframes.GraphFramesUnreachableException import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice +import org.graphframes.WithCheckpointInterval import java.util import scala.jdk.CollectionConverters._ @@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._ */ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments - with WithAlgorithmChoice { + with WithAlgorithmChoice + with WithCheckpointInterval { import org.graphframes.lib.ShortestPaths._ private var lmarks: Option[Seq[Any]] = None @@ -79,7 +81,7 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) val lmarksChecked = check(lmarks, "landmarks") algorithm match { case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) - case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked) + case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval) case _ => throw new GraphFramesUnreachableException() } } @@ -125,6 +127,7 
@@ private object ShortestPaths extends Logging { private def runInGraphFrames( graph: GraphFrame, landmarks: Seq[Any], + checkpointInterval: Int, isDirected: Boolean = true): DataFrame = { logWarn("The GraphFrames based implementation is slow and considered experimental!") val vertexType = graph.vertices.schema(GraphFrame.ID).dataType @@ -192,6 +195,12 @@ private object ShortestPaths extends Logging { val srcDistanceCol = Pregel.src(DISTANCE_ID) val dstDistanceCol = Pregel.dst(DISTANCE_ID) + // Initial active-vertex col expression: only landmarks + val initialActiveVerticesExpr = col(GraphFrame.ID).isInCollection(landmarks) + + // Mark vertex as active only in case its distance changed + val updateActiveVierticesExpr = isDistanceImprovedWithMessage(Pregel.msg, col(DISTANCE_ID)) + // Overall: // 1. Initialize distances // 2. If new message can improve distances send it @@ -208,6 +217,10 @@ private object ShortestPaths extends Logging { incrementDistances(dstDistanceCol))) .aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg))) .setEarlyStopping(true) + .setInitialActiveVertexExpression(initialActiveVerticesExpr) + .setUpdateActiveVertexExpression(updateActiveVierticesExpr) + .setStopIfAllNonActiveVertices(true) + .setSkipMessagesFromNonActiveVertices(true) // Experimental feature if (isDirected) { diff --git a/src/main/scala/org/graphframes/mixins.scala b/src/main/scala/org/graphframes/mixins.scala index 88a348bb0..aa8b22afc 100644 --- a/src/main/scala/org/graphframes/mixins.scala +++ b/src/main/scala/org/graphframes/mixins.scala @@ -6,6 +6,12 @@ private[graphframes] trait WithAlgorithmChoice { protected var algorithm: String = ALGO_GRAPHX val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) + /** + * Set an algorithm to use. Supported algorithms are "graphx" and "graphframes". 
+ * + * @param value + * @return + */ def setAlgorithm(value: String): this.type = { require( supportedAlgorithms.contains(value), @@ -16,3 +22,35 @@ private[graphframes] trait WithAlgorithmChoice { def getAlgorithm: String = algorithm } + +private[graphframes] trait WithCheckpointInterval extends Logging { + protected var checkpointInterval: Int = 2 + + /** + * Sets the checkpoint interval in terms of number of iterations (default: 2). Checkpointing + * regularly helps recover from failures, clean shuffle files, shorten the lineage of the + * computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the + * complexity of plan optimization would grow exponentially without checkpointing. Hence, + * disabling or setting longer-than-default checkpoint intervals is not recommended. Checkpoint + * data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with a prefix of the + * algorithm name. If the checkpoint directory is not set, this throws a `java.io.IOException`. + * Set a nonpositive value to disable checkpointing. This parameter is only used when the + * algorithm is set to "graphframes". Its default value might change in the future. + * @see + * `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc + */ + def setCheckpointInterval(value: Int): this.type = { + if (value <= 0 || value > 2) { + logWarn( + s"Set checkpointInterval to $value. This would blow up the query plan and hang the " + + "driver for large graphs.") + } + checkpointInterval = value + this + } + + /** + * Gets the checkpoint interval. + */ + def getCheckpointInterval: Int = checkpointInterval +}