diff --git a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala
index 4b016fe99..70f2f0ca3 100644
--- a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala
+++ b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala
@@ -6,6 +6,18 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
object GraphFramesConf {
+ private val USE_LOCAL_CHECKPOINTS =
+ SQLConf
+ .buildConf("spark.graphframes.useLocalCheckpoints")
+ .doc(""" Tells the connected components algorithm to use local checkpoints (default: "false").
+ | If set to "true", the iterative algorithm uses local checkpoints instead of persistent storage.
+ | Local checkpoints are faster but make the whole job less fault-tolerant (more prone to errors).
+ | @note This option may become default "true" in the future.
+ |""".stripMargin)
+ .version("0.9.3")
+ .booleanConf
+ .createOptional // no default here: callers supply their own fallback via getOrElse
+
private val USE_LABELS_AS_COMPONENTS =
SQLConf
.buildConf("spark.graphframes.useLabelsAsComponents")
@@ -108,4 +120,6 @@ object GraphFramesConf {
case Some(use) => Some(use.toBoolean)
case _ => None
}
+
+ def getUseLocalCheckpoints: Option[Boolean] = get(USE_LOCAL_CHECKPOINTS).map(_.toBoolean)
}
diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
index 32823f57e..fa7c7c335 100644
--- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
+++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -30,6 +30,7 @@ import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithBroadcastThreshold
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
+import org.graphframes.WithLocalCheckpoints
import org.graphframes.WithMaxIter
import org.graphframes.WithUseLabelsAsComponents
@@ -54,7 +55,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
with WithBroadcastThreshold
with WithIntermediateStorageLevel
with WithUseLabelsAsComponents
- with WithMaxIter {
+ with WithMaxIter
+ with WithLocalCheckpoints {
setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
setCheckpointInterval(
@@ -65,6 +67,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
setUseLabelsAsComponents(
GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))
+ setUseLocalCheckpoints(GraphFramesConf.getUseLocalCheckpoints.getOrElse(useLocalCheckpoints))
/**
* Runs the algorithm.
@@ -77,7 +80,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
checkpointInterval = checkpointInterval,
intermediateStorageLevel = intermediateStorageLevel,
useLabelsAsComponents = useLabelsAsComponents,
- maxIter = maxIter)
+ maxIter = maxIter,
+ useLocalCheckpoints = useLocalCheckpoints)
}
}
@@ -190,7 +194,8 @@ object ConnectedComponents extends Logging {
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel,
useLabelsAsComponents: Boolean,
- maxIter: Option[Int]): DataFrame = {
+ maxIter: Option[Int],
+ useLocalCheckpoints: Boolean): DataFrame = {
if (runInGraphX) {
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
}
@@ -208,7 +213,8 @@ object ConnectedComponents extends Logging {
logInfo(s"$logPrefix Start connected components with run ID $runId.")
val shouldCheckpoint = checkpointInterval > 0
- val checkpointDir: Option[String] = if (shouldCheckpoint) {
+ val checkpointDir: Option[String] = if (useLocalCheckpoints) { None }
+ else if (shouldCheckpoint) {
val dir = sc.getCheckpointDir
.map { d =>
new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
@@ -297,19 +303,23 @@ object ConnectedComponents extends Logging {
// checkpointing
if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
- // TODO: remove this after DataFrame.checkpoint is implemented
- val out = s"${checkpointDir.get}/$iteration"
- ee.write.parquet(out)
- // may hit S3 eventually consistent issue
- ee = spark.read.parquet(out)
-
- // remove previous checkpoint
- if (iteration > checkpointInterval) {
- val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
- path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
- }
+ if (useLocalCheckpoints) {
+ ee = ee.localCheckpoint(eager = true)
+ } else {
+ // TODO: remove this after DataFrame.checkpoint is implemented
+ val out = s"${checkpointDir.get}/$iteration"
+ ee.write.parquet(out)
+ // may hit S3 eventually consistent issue
+ ee = spark.read.parquet(out)
+
+ // remove previous checkpoint
+ if (iteration > checkpointInterval) {
+ val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
+ path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
+ }
- System.gc() // hint Spark to clean shuffle directories
+ System.gc() // hint Spark to clean shuffle directories
+ }
}
ee.persist(intermediateStorageLevel)
diff --git a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
index 3b083c39a..aa8877e00 100644
--- a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
+++ b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
import org.graphframes.GraphFrame
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
+import org.graphframes.WithLocalCheckpoints
import org.graphframes.WithMaxIter
/**
@@ -44,14 +45,19 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice
with WithCheckpointInterval
- with WithMaxIter {
+ with WithMaxIter
+ with WithLocalCheckpoints {
def run(): DataFrame = {
val maxIterChecked = check(maxIter, "maxIter")
algorithm match {
case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked)
case "graphframes" =>
- LabelPropagation.runInGraphFrames(graph, maxIterChecked, checkpointInterval)
+ LabelPropagation.runInGraphFrames(
+ graph,
+ maxIterChecked,
+ checkpointInterval,
+ useLocalCheckpoints = useLocalCheckpoints)
}
}
}
@@ -74,7 +80,8 @@ private object LabelPropagation {
graph: GraphFrame,
maxIter: Int,
checkpointInterval: Int,
- isDirected: Boolean = true): DataFrame = {
+ isDirected: Boolean = true,
+ useLocalCheckpoints: Boolean): DataFrame = {
// Overall:
// - Initial labels - IDs
// - Active vertex col (halt voting) - did the label changed?
@@ -88,6 +95,7 @@ private object LabelPropagation {
.setCheckpointInterval(checkpointInterval)
.setSkipMessagesFromNonActiveVertices(false)
.setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg))
+ .setUseLocalCheckpoints(useLocalCheckpoints)
if (isDirected) {
pregel = pregel.sendMsgToDst(col(LABEL_ID))
diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala
index fab3699ba..e5926d53f 100644
--- a/core/src/main/scala/org/graphframes/lib/Pregel.scala
+++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.struct
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame._
import org.graphframes.Logging
+import org.graphframes.WithLocalCheckpoints
import java.io.IOException
import scala.util.control.Breaks.break
@@ -80,7 +81,7 @@ import scala.util.control.Breaks.breakable
* Malewicz et al., Pregel: a system for
* large-scale graph processing.
*/
-class Pregel(val graph: GraphFrame) extends Logging {
+class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)]
@@ -342,7 +343,7 @@ class Pregel(val graph: GraphFrame) extends Logging {
val shouldCheckpoint = checkpointInterval > 0
- if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) {
+ if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !useLocalCheckpoints) {
// Spark Connect workaround
graph.spark.conf.getOption("spark.checkpoint.dir") match {
case Some(d) => graph.spark.sparkContext.setCheckpointDir(d)
@@ -394,9 +395,13 @@ class Pregel(val graph: GraphFrame) extends Logging {
updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*)
if (shouldCheckpoint && iteration % checkpointInterval == 0) {
- // do checkpoint, use lazy checkpoint because later we will materialize this DF.
- newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
- // TODO: remove last checkpoint file.
+ if (useLocalCheckpoints) {
+ newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false)
+ } else {
+ // do checkpoint, use lazy checkpoint because later we will materialize this DF.
+ newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
+ // TODO: remove last checkpoint file.
+ }
}
newVertexUpdateColDF.cache()
newVertexUpdateColDF.count() // materialize it
diff --git a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
index 1e1b4f18d..50a543cc8 100644
--- a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
+++ b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
@@ -38,6 +38,7 @@ import org.graphframes.GraphFramesUnreachableException
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
+import org.graphframes.WithLocalCheckpoints
import java.util
import scala.jdk.CollectionConverters._
@@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._
class ShortestPaths private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice
- with WithCheckpointInterval {
+ with WithCheckpointInterval
+ with WithLocalCheckpoints {
import org.graphframes.lib.ShortestPaths._
private var lmarks: Option[Seq[Any]] = None
@@ -79,7 +81,12 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
val lmarksChecked = check(lmarks, "landmarks")
algorithm match {
case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked)
- case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval)
+ case ALGO_GRAPHFRAMES =>
+ runInGraphFrames(
+ graph,
+ lmarksChecked,
+ checkpointInterval,
+ useLocalCheckpoints = useLocalCheckpoints)
case _ => throw new GraphFramesUnreachableException()
}
}
@@ -109,7 +116,8 @@ private object ShortestPaths extends Logging {
graph: GraphFrame,
landmarks: Seq[Any],
checkpointInterval: Int,
- isDirected: Boolean = true): DataFrame = {
+ isDirected: Boolean = true,
+ useLocalCheckpoints: Boolean): DataFrame = {
logWarn("The GraphFrames based implementation is slow and considered experimental!")
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType
@@ -202,6 +210,8 @@ private object ShortestPaths extends Logging {
.setUpdateActiveVertexExpression(updateActiveVierticesExpr)
.setStopIfAllNonActiveVertices(true)
.setSkipMessagesFromNonActiveVertices(true)
+ .setCheckpointInterval(checkpointInterval)
+ .setUseLocalCheckpoints(useLocalCheckpoints)
// Experimental feature
if (isDirected) {
diff --git a/core/src/main/scala/org/graphframes/mixins.scala b/core/src/main/scala/org/graphframes/mixins.scala
index 56c8e3531..d7c2f0e07 100644
--- a/core/src/main/scala/org/graphframes/mixins.scala
+++ b/core/src/main/scala/org/graphframes/mixins.scala
@@ -141,3 +141,37 @@ private[graphframes] trait WithUseLabelsAsComponents {
*/
def getUseLabelsAsComponents: Boolean = useLabelsAsComponents
}
+
+/**
+ * Mixin that adds a `useLocalCheckpoints` flag (default: false) plus its setter and getter.
+ *
+ * The trait only stores the flag; the algorithm mixing it in decides how to act on it. Local
+ * checkpoints are a faster alternative to regular checkpoints because no checkpointDir in
+ * persistent storage (like HDFS or S3) has to be configured, but they are less reliable: they
+ * don't survive node failures since the data is not persisted across multiple nodes.
+ */
+private[graphframes] trait WithLocalCheckpoints {
+ protected var useLocalCheckpoints: Boolean = false
+
+ /**
+ * Sets whether to use local checkpoints instead of regular checkpoints (default: false). Local
+ * checkpoints are faster but less reliable as they don't survive node failures.
+ *
+ * @param value
+ * true to use local checkpoints, false for regular checkpoints
+ * @return
+ * this instance, so that setter calls can be chained
+ */
+ def setUseLocalCheckpoints(value: Boolean): this.type = {
+ useLocalCheckpoints = value
+ this
+ }
+
+ /**
+ * Gets the current value of the flag, i.e. the default or the last value passed to the setter.
+ *
+ * @return
+ * true if local checkpoints are enabled, false otherwise
+ */
+ def getUseLocalCheckpoints: Boolean = useLocalCheckpoints
+}
diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
index 12a4327e4..847a22ee8 100644
--- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
+++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
@@ -34,164 +34,225 @@ import scala.reflect.runtime.universe.TypeTag
class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
- test("default params") {
- val g = Graphs.empty[Int]
- val cc = g.connectedComponents
- assert(cc.getAlgorithm === "graphframes")
- assert(cc.getBroadcastThreshold === 1000000)
- assert(cc.getCheckpointInterval === 2)
- assert(!cc.getUseLabelsAsComponents)
- }
-
- test("empty graph") {
- for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) {
- val components = empty.connectedComponents.run()
- assert(components.count() === 0L)
+ Seq(true, false).foreach(useLocalCheckpoint => {
+ test(s"default params${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val cc = Graphs.empty[Int].connectedComponents
+ assert(cc.getAlgorithm === "graphframes")
+ assert(cc.getBroadcastThreshold === 1000000)
+ assert(cc.getCheckpointInterval === 2)
+ assert(!cc.getUseLabelsAsComponents)
+ assert(cc.getUseLocalCheckpoints === useLocalCheckpoint) // conf must reach the algorithm
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- }
- test("single vertex") {
- val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender")
- // Create an empty dataframe with the proper columns.
- val e = spark
- .createDataFrame(List((0L, 0L, 1L)))
- .toDF("src", "dst", "test")
- .filter("src > 10")
- val g = GraphFrame(v, e)
- val comps = ConnectedComponents.run(g)
- TestUtils.testSchemaInvariants(g, comps)
- TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType)
- assert(comps.count() === 1)
- assert(
- comps.select("id", "component", "vattr", "gender").collect()
- === Seq(Row(0L, 0L, "a", "b")))
- }
+ test(s"empty graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) {
+ val components = empty.connectedComponents.run()
+ assert(components.count() === 0L)
+ }
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("disconnected vertices") {
- val n = 5L
- val vertices = spark.range(n).toDF(ID)
- val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- val expected = (0L until n).map(Set(_)).toSet
- assertComponents(components, expected)
- }
+ test(s"single vertex${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender")
+ // Create an empty dataframe with the proper columns.
+ val e = spark
+ .createDataFrame(List((0L, 0L, 1L)))
+ .toDF("src", "dst", "test")
+ .filter("src > 10")
+ val g = GraphFrame(v, e)
+ val comps = ConnectedComponents.run(g)
+ TestUtils.testSchemaInvariants(g, comps)
+ TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType)
+ assert(comps.count() === 1)
+ assert(
+ comps.select("id", "component", "vattr", "gender").collect()
+ === Seq(Row(0L, 0L, "a", "b")))
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("using labels as components") {
- val vertices = spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID)
- val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet
- assertComponents(components, expected)
- }
+ test(s"disconnected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5L
+ val vertices = spark.range(n).toDF(ID)
+ val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ val expected = (0L until n).map(Set(_)).toSet
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("don't using labels as components") {
- spark.conf.set("spark.graphframes.useLabelsAsComponents", "false")
- val vertices = spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID)
- val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- assert(components.schema("component").dataType == LongType)
- spark.conf.set("spark.graphframes.useLabelsAsComponents", "true")
- }
+ test(
+ s"using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ spark.conf.set("spark.graphframes.useLabelsAsComponents", "true")
+ val vertices =
+ spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID)
+ val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLabelsAsComponents", "false")
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("two connected vertices") {
- val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B")
- val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B")
- val g = GraphFrame(v, e)
- val comps = g.connectedComponents.run()
- TestUtils.testSchemaInvariants(g, comps)
- assert(comps.count() === 2)
- val vxs = comps.sort("id").select("id", "component", "A", "B").collect()
- assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs)
- }
+ test(
+ s"don't using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val vertices =
+ spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID)
+ val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ assert(components.schema("component").dataType == LongType)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("chain graph") {
- val n = 5L
- val g = Graphs.chain(5L)
- val components = g.connectedComponents.run()
- val expected = Set((0L until n).toSet)
- assertComponents(components, expected)
- }
+ test(s"two connected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B")
+ val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B")
+ val g = GraphFrame(v, e)
+ val comps = g.connectedComponents.run()
+ TestUtils.testSchemaInvariants(g, comps)
+ assert(comps.count() === 2)
+ val vxs = comps.sort("id").select("id", "component", "A", "B").collect()
+ assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("star graph") {
- val n = 5L
- val g = Graphs.star(5L)
- val components = g.connectedComponents.run()
- val expected = Set((0L to n).toSet)
- assertComponents(components, expected)
- }
+ test(s"chain graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5L
+ val g = Graphs.chain(n) // same n as in `expected`, so the two cannot drift apart
+ val components = g.connectedComponents.run()
+ val expected = Set((0L until n).toSet)
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("two blobs") {
- val n = 5L
- val g = Graphs.twoBlobs(n.toInt)
- val components = g.connectedComponents.run()
- val expected = Set((0L until 2 * n).toSet)
- assertComponents(components, expected)
- }
+ test(s"star graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5L
+ val g = Graphs.star(n) // same n as in `expected`, so the two cannot drift apart
+ val components = g.connectedComponents.run()
+ val expected = Set((0L to n).toSet)
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("two components") {
- val vertices = spark.range(6L).toDF(ID)
- val edges = spark
- .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
- .toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L))
- assertComponents(components, expected)
- }
+ test(s"two blobs${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5L
+ val g = Graphs.twoBlobs(n.toInt)
+ val components = g.connectedComponents.run()
+ val expected = Set((0L until 2 * n).toSet)
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("one component, differing edge directions") {
- val vertices = spark.range(5L).toDF(ID)
- val edges = spark
- .createDataFrame(
- Seq(
- // 0 -> 4 -> 3 <- 2 -> 1
- (0L, 4L),
- (4L, 3L),
- (2L, 3L),
- (2L, 1L)))
- .toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- val expected = Set((0L to 4L).toSet)
- assertComponents(components, expected)
- }
+ test(s"two components${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val vertices = spark.range(6L).toDF(ID)
+ val edges = spark
+ .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
+ .toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L))
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("two components and two dangling vertices") {
- val vertices = spark.range(8L).toDF(ID)
- val edges = spark
- .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
- .toDF(SRC, DST)
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L))
- assertComponents(components, expected)
- }
+ test(
+ s"one component, differing edge directions${if (useLocalCheckpoint) " with local checkpoint"
+ else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val vertices = spark.range(5L).toDF(ID)
+ val edges = spark
+ .createDataFrame(
+ Seq(
+ // 0 -> 4 -> 3 <- 2 -> 1
+ (0L, 4L),
+ (4L, 3L),
+ (2L, 3L),
+ (2L, 1L)))
+ .toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ val expected = Set((0L to 4L).toSet)
+ assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- test("friends graph") {
- val friends = Graphs.friends
- val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
- for ((algorithm, broadcastThreshold) <-
- Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) {
- val components = friends.connectedComponents
- .setAlgorithm(algorithm)
- .setBroadcastThreshold(broadcastThreshold)
- .run()
+ test(
+ s"two components and two dangling vertices${if (useLocalCheckpoint) " with local checkpoint"
+ else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val vertices = spark.range(8L).toDF(ID)
+ val edges = spark
+ .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
+ .toDF(SRC, DST)
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L))
assertComponents(components, expected)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
}
- }
- test("really large long IDs") {
- val max = Long.MaxValue
- val chain = examples.Graphs.chain(10L)
- val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID))
- val edges = chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST))
- val g = GraphFrame(vertices, edges)
- val components = g.connectedComponents.run()
- assert(components.count() === 10L)
- assert(components.groupBy("component").count().count() === 1L)
+ test(s"friends graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val friends = Graphs.friends
+ val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
+ for ((algorithm, broadcastThreshold) <-
+ Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) {
+ val components = friends.connectedComponents
+ .setAlgorithm(algorithm)
+ .setBroadcastThreshold(broadcastThreshold)
+ .run()
+ assertComponents(components, expected)
+ }
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
+
+ test(s"really large long IDs${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val max = Long.MaxValue
+ val chain = examples.Graphs.chain(10L)
+ val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID))
+ val edges = chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST))
+ val g = GraphFrame(vertices, edges)
+ val components = g.connectedComponents.run()
+ assert(components.count() === 10L)
+ assert(components.groupBy("component").count().count() === 1L)
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
+ })
+
+ test("set configuration from spark conf") {
+ spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX")
+ assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx")
+
+ spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000")
+ assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000)
+
+ spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5")
+ assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5)
+
+ spark.conf
+ .set("spark.graphframes.connectedComponents.intermediatestoragelevel", "memory_only")
+ assert(
+ Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY)
+
+ spark.conf.unset("spark.graphframes.connectedComponents.algorithm")
+ spark.conf.unset("spark.graphframes.connectedComponents.broadcastthreshold")
+ spark.conf.unset("spark.graphframes.connectedComponents.checkpointinterval")
+ spark.conf.unset("spark.graphframes.connectedComponents.intermediatestoragelevel")
}
test("checkpoint interval") {
@@ -285,23 +346,6 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize)
}
- test("set configuration from spark conf") {
- spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX")
- assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx")
-
- spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000")
- assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000)
-
- spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5")
- assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5)
-
- spark.conf.set(
- "spark.graphframes.connectedComponents.intermediatestoragelevel",
- "memory_only")
- assert(
- Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY)
- }
-
private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {
diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala
index d471c0d3d..67b2c5a46 100644
--- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala
+++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala
@@ -25,112 +25,124 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
import sqlImplicits._
- test("page rank") {
- val edges = Seq(
- (0L, 1L),
- (1L, 2L),
- (2L, 4L),
- (2L, 0L),
- (3L, 4L), // 3 has no in-links
- (4L, 0L),
- (4L, 2L)).toDF("src", "dst").cache()
- val vertices = GraphFrame.fromEdges(edges).outDegrees.cache()
- val numVertices = vertices.count()
- val graph = GraphFrame(vertices, edges)
-
- val alpha = 0.15
- // NOTE: This version doesn't handle nodes with no out-links.
- val ranks = graph.pregel
- .setMaxIter(5)
- .withVertexColumn(
- "rank",
- lit(1.0 / numVertices),
- coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices)
- .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
- .aggMsgs(sum(Pregel.msg))
- .run()
-
- val result = ranks
- .sort(col("id"))
- .select("rank")
- .as[Double]
- .collect()
- assert(result.sum === 1.0 +- 1e-6)
- val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197)
- result.zip(expected).foreach { case (r, e) =>
- assert(r === e +- 1e-3)
+ Seq(true, false).foreach(useLocalCheckpoint => {
+ test(s"page rank${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val edges = Seq(
+ (0L, 1L),
+ (1L, 2L),
+ (2L, 4L),
+ (2L, 0L),
+ (3L, 4L), // 3 has no in-links
+ (4L, 0L),
+ (4L, 2L)).toDF("src", "dst").cache()
+ val vertices = GraphFrame.fromEdges(edges).outDegrees.cache()
+ val numVertices = vertices.count()
+ val graph = GraphFrame(vertices, edges)
+
+ val alpha = 0.15
+ // NOTE: This version doesn't handle nodes with no out-links.
+ val ranks = graph.pregel
+ .setMaxIter(5)
+ .withVertexColumn(
+ "rank",
+ lit(1.0 / numVertices),
+ coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices)
+ .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
+ .aggMsgs(sum(Pregel.msg))
+ .run()
+
+ val result = ranks
+ .sort(col("id"))
+ .select("rank")
+ .as[Double]
+ .collect()
+ assert(result.sum === 1.0 +- 1e-6)
+ val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197)
+ result.zip(expected).foreach { case (r, e) =>
+ assert(r === e +- 1e-3)
+ }
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
}
- }
-
- test("chain propagation") {
- val n = 5
- val verDF = (1 to n).toDF("id").repartition(3)
- val edgeDF = (1 until n)
- .map(x => (x, x + 1))
- .toDF("src", "dst")
- .repartition(3)
-
- val graph = GraphFrame(verDF, edgeDF)
- val resultDF = graph.pregel
- .setMaxIter(n - 1)
- .withVertexColumn(
- "value",
- when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
- when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
- .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
- .aggMsgs(max(Pregel.msg))
- .run()
-
- assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
- }
-
- test("reverse chain propagation") {
- val n = 5
- val verDF = (1 to n).toDF("id").repartition(3)
- val edgeDF = (1 until n)
- .map(x => (x + 1, x))
- .toDF("src", "dst")
- .repartition(3)
-
- val graph = GraphFrame(verDF, edgeDF)
-
- val resultDF = graph.pregel
- .setMaxIter(n - 1)
- .withVertexColumn(
- "value",
- when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
- when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
- .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value")))
- .aggMsgs(max(Pregel.msg))
- .run()
-
- assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
- }
-
- test("chain propagation with termination") {
- val n = 5
- val verDF = (1 to n).toDF("id").repartition(3)
- val edgeDF = (1 until n)
- .map(x => (x, x + 1))
- .toDF("src", "dst")
- .repartition(3)
+ test(s"chain propagation${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5
+ val verDF = (1 to n).toDF("id").repartition(3)
+ val edgeDF = (1 until n)
+ .map(x => (x, x + 1))
+ .toDF("src", "dst")
+ .repartition(3)
+
+ val graph = GraphFrame(verDF, edgeDF)
+
+ val resultDF = graph.pregel
+ .setMaxIter(n - 1)
+ .withVertexColumn(
+ "value",
+ when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
+ when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
+ .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
+ .aggMsgs(max(Pregel.msg))
+ .run()
+
+ assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- val graph = GraphFrame(verDF, edgeDF)
+ test(
+ s"reverse chain propagation${if (useLocalCheckpoint) " with local checkpoint" else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5
+ val verDF = (1 to n).toDF("id").repartition(3)
+ val edgeDF = (1 until n)
+ .map(x => (x + 1, x))
+ .toDF("src", "dst")
+ .repartition(3)
+
+ val graph = GraphFrame(verDF, edgeDF)
+
+ val resultDF = graph.pregel
+ .setMaxIter(n - 1)
+ .withVertexColumn(
+ "value",
+ when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
+ when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
+ .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value")))
+ .aggMsgs(max(Pregel.msg))
+ .run()
+
+ assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
- val resultDF = graph.pregel
- .setMaxIter(1000)
- .setEarlyStopping(true)
- .withVertexColumn(
- "value",
- when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
- when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
- .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
- .aggMsgs(max(Pregel.msg))
- .run()
-
- assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
- }
+ test(s"chain propagation with termination${if (useLocalCheckpoint) " with local checkpoint"
+ else ""}") {
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString)
+ val n = 5
+ val verDF = (1 to n).toDF("id").repartition(3)
+ val edgeDF = (1 until n)
+ .map(x => (x, x + 1))
+ .toDF("src", "dst")
+ .repartition(3)
+
+ val graph = GraphFrame(verDF, edgeDF)
+
+ val resultDF = graph.pregel
+ .setMaxIter(1000)
+ .setEarlyStopping(true)
+ .withVertexColumn(
+ "value",
+ when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
+ when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
+ .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
+ .aggMsgs(max(Pregel.msg))
+ .run()
+
+ assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
+ spark.conf.set("spark.graphframes.useLocalCheckpoints", "false")
+ }
+ })
test("new vertex column is based on the nullable column") {
val verDF = Seq(1L, 2L, 3L, 4L)
diff --git a/docs/configurations.md b/docs/configurations.md
new file mode 100644
index 000000000..2ae4ea67f
--- /dev/null
+++ b/docs/configurations.md
@@ -0,0 +1,137 @@
+---
+layout: global
+displayTitle: GraphFrames Configurations
+title: Configurations
+description: GraphFrames GRAPHFRAMES_VERSION configurations documentation
+---
+
+* Table of contents
+{:toc}
+
+# GraphFrames Configurations
+
+GraphFrames provides several configuration options that can be used to tune the behavior of algorithms and operations. This page documents all available configurations, their descriptions, default values, and usage examples.
+
+## Configuration Table
+
+The following table lists all available GraphFrames configurations:
+
+| Configuration | Description | Default Value | Since Version |
+|---------------|-------------|---------------|---------------|
+| `spark.graphframes.useLocalCheckpoints` | Tells iterative algorithms (such as connected components and Pregel) to use local checkpoints. If set to "false", the algorithms checkpoint to persistent storage; if set to "true", they use local checkpoints instead. Local checkpoints are faster but make the whole job less fault-tolerant, because locally checkpointed data can be lost when an executor fails. | `false` | 0.9.3 |
+| `spark.graphframes.useLabelsAsComponents` | Tells the connected components algorithm to use labels as components in the output DataFrame. If set to "false", randomly generated labels with the data type LONG will be returned. | Optional (default: `true`) | 0.9.0 |
+| `spark.graphframes.connectedComponents.algorithm` | Sets the connected components algorithm to use. Supported algorithms:<br>- "graphframes": Uses alternating large star and small star iterations proposed in [Connected Components in MapReduce and Beyond](http://dx.doi.org/10.1145/2670979.2670997) with skewed join optimization.<br>- "graphx": Converts the graph to a GraphX graph and then uses the connected components implementation in GraphX. | Optional (default: `graphframes`) | 0.9.0 |
+| `spark.graphframes.connectedComponents.broadcastthreshold` | Sets broadcast threshold in propagating component assignments. If a node degree is greater than this threshold at some iteration, its component assignment will be collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise, the assignment propagation is done by a normal Spark join. This parameter is only used when the algorithm is set to "graphframes". | Optional (default: `1000000`) | 0.9.0 |
+| `spark.graphframes.connectedComponents.checkpointinterval` | Sets checkpoint interval in terms of number of iterations. Checkpointing regularly helps recover from failures, clean shuffle files, shorten the lineage of the computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the complexity of plan optimization would grow exponentially without checkpointing. Hence, disabling checkpointing or setting a longer-than-default checkpoint interval is not recommended. Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix "connected-components". If the checkpoint directory is not set, this throws a `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is only used when the algorithm is set to "graphframes". | Optional (default: `2`) | 0.9.0 |
+| `spark.graphframes.connectedComponents.intermediatestoragelevel` | Sets storage level for intermediate datasets that require multiple passes. | Optional (default: `MEMORY_AND_DISK`) | 0.9.0 |
+
+## Setting Configurations
+
+GraphFrames configurations can be set in several ways:
+
+### Spark Configuration
+
+You can set configurations when creating a SparkSession:
+
+