From ec21e1ef5d31b8c65b9821033da88bd44f5105cc Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 20 Jul 2025 15:53:32 +0200 Subject: [PATCH 1/2] LocalCheckpoints + docs --- .../sql/graphframes/GraphFramesConf.scala | 14 + .../graphframes/lib/ConnectedComponents.scala | 32 +- .../scala/org/graphframes/lib/Pregel.scala | 14 +- .../lib/ConnectedComponentsSuite.scala | 362 ++++++++++-------- .../org/graphframes/lib/PregelSuite.scala | 222 ++++++----- docs/configurations.md | 137 +++++++ docs/index.md | 7 +- docs/quick-start.md | 13 +- 8 files changed, 506 insertions(+), 295 deletions(-) create mode 100644 docs/configurations.md diff --git a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala index 4b016fe99..cc8f6764f 100644 --- a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala +++ b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala @@ -6,6 +6,18 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel object GraphFramesConf { + private val USE_LOCAL_CHECKPOINTS = + SQLConf + .buildConf("spark.graphframes.useLocalCheckpoints") + .doc(""" Tells the connected components algorithm to use local checkpoints (default: "false"). + | If set to "true", iterative algorithm will use the checkpointing mechanism to the persistent storage. + | Local checkpoints are faster but can make the whole job less prone to errors. + | @note This option may become default "true" in the future. + |""".stripMargin) + .version("0.9.3") + .booleanConf + .createWithDefault(false) + private val USE_LABELS_AS_COMPONENTS = SQLConf .buildConf("spark.graphframes.useLabelsAsComponents") @@ -108,4 +120,6 @@ object GraphFramesConf { case Some(use) => Some(use.toBoolean) case _ => None } + + def getUseLocalCheckpoints: Boolean = get(USE_LOCAL_CHECKPOINTS).get.toBoolean } diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 32823f57e..ca9239e68 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -207,8 +207,10 @@ object ConnectedComponents extends Logging { val logPrefix = s"[CC $runId]" logInfo(s"$logPrefix Start connected components with run ID $runId.") + val shouldUseLocalCheckpoints = GraphFramesConf.getUseLocalCheckpoints val shouldCheckpoint = checkpointInterval > 0 - val checkpointDir: Option[String] = if (shouldCheckpoint) { + val checkpointDir: Option[String] = if (shouldUseLocalCheckpoints) { None } + else if (shouldCheckpoint) { val dir = sc.getCheckpointDir .map { d => new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString @@ -297,19 +299,23 @@ object ConnectedComponents extends Logging { // checkpointing if (shouldCheckpoint && (iteration % checkpointInterval == 0)) { - // TODO: remove this after DataFrame.checkpoint is implemented - val out = s"${checkpointDir.get}/$iteration" - ee.write.parquet(out) - // may hit S3 eventually consistent issue - ee = spark.read.parquet(out) - - // remove previous checkpoint - if (iteration > checkpointInterval) { - val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") - path.getFileSystem(sc.hadoopConfiguration).delete(path, true) - } + if (shouldUseLocalCheckpoints) { + ee = ee.localCheckpoint(eager = true) + } else { + // TODO: remove this after DataFrame.checkpoint is implemented + val out = s"${checkpointDir.get}/$iteration" + ee.write.parquet(out) + // may hit S3 eventually consistent issue + ee = spark.read.parquet(out) + + // remove previous checkpoint + if (iteration > checkpointInterval) { + val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") + path.getFileSystem(sc.hadoopConfiguration).delete(path, true) + } - System.gc() // hint Spark to clean shuffle directories + System.gc() // hint Spark to clean shuffle directories + } } ee.persist(intermediateStorageLevel) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index fab3699ba..23c82dc8a 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.explode import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.graphframes.GraphFramesConf import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ import org.graphframes.Logging @@ -341,8 +342,9 @@ class Pregel(val graph: GraphFrame) extends Logging { var iteration = 1 val shouldCheckpoint = checkpointInterval > 0 + val shouldUseLocalCheckpoint = GraphFramesConf.getUseLocalCheckpoints - if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) { + if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !shouldUseLocalCheckpoint) { // Spark Connect workaround graph.spark.conf.getOption("spark.checkpoint.dir") match { case Some(d) => graph.spark.sparkContext.setCheckpointDir(d) @@ -394,9 +396,13 @@ class Pregel(val graph: GraphFrame) extends Logging { updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) if (shouldCheckpoint && iteration % checkpointInterval == 0) { - // do checkpoint, use lazy checkpoint because later we will materialize this DF. - newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) - // TODO: remove last checkpoint file. + if (shouldUseLocalCheckpoint) { + newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false) + } else { + // do checkpoint, use lazy checkpoint because later we will materialize this DF. + newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) + // TODO: remove last checkpoint file. + } } newVertexUpdateColDF.cache() newVertexUpdateColDF.count() // materialize it diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 12a4327e4..847a22ee8 100644 --- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -34,164 +34,225 @@ import scala.reflect.runtime.universe.TypeTag class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { - test("default params") { - val g = Graphs.empty[Int] - val cc = g.connectedComponents - assert(cc.getAlgorithm === "graphframes") - assert(cc.getBroadcastThreshold === 1000000) - assert(cc.getCheckpointInterval === 2) - assert(!cc.getUseLabelsAsComponents) - } - - test("empty graph") { - for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) { - val components = empty.connectedComponents.run() - assert(components.count() === 0L) + Seq(true, false).foreach(useLocalCheckpoint => { + test(s"default params${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val g = Graphs.empty[Int] + val cc = g.connectedComponents + assert(cc.getAlgorithm === "graphframes") + assert(cc.getBroadcastThreshold === 1000000) + assert(cc.getCheckpointInterval === 2) + assert(!cc.getUseLabelsAsComponents) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") } - } - test("single vertex") { - val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender") - // Create an empty dataframe with the proper columns. - val e = spark - .createDataFrame(List((0L, 0L, 1L))) - .toDF("src", "dst", "test") - .filter("src > 10") - val g = GraphFrame(v, e) - val comps = ConnectedComponents.run(g) - TestUtils.testSchemaInvariants(g, comps) - TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType) - assert(comps.count() === 1) - assert( - comps.select("id", "component", "vattr", "gender").collect() - === Seq(Row(0L, 0L, "a", "b"))) - } + test(s"empty graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + for (empty <- Seq(Graphs.empty[Int], Graphs.empty[Long], Graphs.empty[String])) { + val components = empty.connectedComponents.run() + assert(components.count() === 0L) + } + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("disconnected vertices") { - val n = 5L - val vertices = spark.range(n).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = (0L until n).map(Set(_)).toSet - assertComponents(components, expected) - } + test(s"single vertex${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender") + // Create an empty dataframe with the proper columns. + val e = spark + .createDataFrame(List((0L, 0L, 1L))) + .toDF("src", "dst", "test") + .filter("src > 10") + val g = GraphFrame(v, e) + val comps = ConnectedComponents.run(g) + TestUtils.testSchemaInvariants(g, comps) + TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType) + assert(comps.count() === 1) + assert( + comps.select("id", "component", "vattr", "gender").collect() + === Seq(Row(0L, 0L, "a", "b"))) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("using labels as components") { - val vertices = spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet - assertComponents(components, expected) - } + test(s"disconnected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5L + val vertices = spark.range(n).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = (0L until n).map(Set(_)).toSet + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("don't using labels as components") { - spark.conf.set("spark.graphframes.useLabelsAsComponents", "false") - val vertices = spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) - val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - assert(components.schema("component").dataType == LongType) - spark.conf.set("spark.graphframes.useLabelsAsComponents", "true") - } + test( + s"using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + spark.conf.set("spark.graphframes.useLabelsAsComponents", "true") + val vertices = + spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = Seq("a", "b", "c", "d", "e").map(Set(_)).toSet + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLabelsAsComponents", "false") + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("two connected vertices") { - val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B") - val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") - val g = GraphFrame(v, e) - val comps = g.connectedComponents.run() - TestUtils.testSchemaInvariants(g, comps) - assert(comps.count() === 2) - val vxs = comps.sort("id").select("id", "component", "A", "B").collect() - assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs) - } + test( + s"don't using labels as components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val vertices = + spark.createDataFrame(Seq("a", "b", "c", "d", "e").map(Tuple1.apply)).toDF(ID) + val edges = spark.createDataFrame(Seq.empty[(String, String)]).toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + assert(components.schema("component").dataType == LongType) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("chain graph") { - val n = 5L - val g = Graphs.chain(5L) - val components = g.connectedComponents.run() - val expected = Set((0L until n).toSet) - assertComponents(components, expected) - } + test(s"two connected vertices${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B") + val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") + val g = GraphFrame(v, e) + val comps = g.connectedComponents.run() + TestUtils.testSchemaInvariants(g, comps) + assert(comps.count() === 2) + val vxs = comps.sort("id").select("id", "component", "A", "B").collect() + assert(List(Row(0L, 0L, "a0", "b0"), Row(1L, 0L, "a1", "b1")) === vxs) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("star graph") { - val n = 5L - val g = Graphs.star(5L) - val components = g.connectedComponents.run() - val expected = Set((0L to n).toSet) - assertComponents(components, expected) - } + test(s"chain graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5L + val g = Graphs.chain(5L) + val components = g.connectedComponents.run() + val expected = Set((0L until n).toSet) + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("two blobs") { - val n = 5L - val g = Graphs.twoBlobs(n.toInt) - val components = g.connectedComponents.run() - val expected = Set((0L until 2 * n).toSet) - assertComponents(components, expected) - } + test(s"star graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5L + val g = Graphs.star(5L) + val components = g.connectedComponents.run() + val expected = Set((0L to n).toSet) + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("two components") { - val vertices = spark.range(6L).toDF(ID) - val edges = spark - .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L)) - assertComponents(components, expected) - } + test(s"two blobs${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5L + val g = Graphs.twoBlobs(n.toInt) + val components = g.connectedComponents.run() + val expected = Set((0L until 2 * n).toSet) + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("one component, differing edge directions") { - val vertices = spark.range(5L).toDF(ID) - val edges = spark - .createDataFrame( - Seq( - // 0 -> 4 -> 3 <- 2 -> 1 - (0L, 4L), - (4L, 3L), - (2L, 3L), - (2L, 1L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set((0L to 4L).toSet) - assertComponents(components, expected) - } + test(s"two components${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val vertices = spark.range(6L).toDF(ID) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L)) + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("two components and two dangling vertices") { - val vertices = spark.range(8L).toDF(ID) - val edges = spark - .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) - .toDF(SRC, DST) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L)) - assertComponents(components, expected) - } + test( + s"one component, differing edge directions${if (useLocalCheckpoint) " with local checkpoint" + else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val vertices = spark.range(5L).toDF(ID) + val edges = spark + .createDataFrame( + Seq( + // 0 -> 4 -> 3 <- 2 -> 1 + (0L, 4L), + (4L, 3L), + (2L, 3L), + (2L, 1L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = Set((0L to 4L).toSet) + assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } - test("friends graph") { - val friends = Graphs.friends - val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) - for ((algorithm, broadcastThreshold) <- - Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) { - val components = friends.connectedComponents - .setAlgorithm(algorithm) - .setBroadcastThreshold(broadcastThreshold) - .run() + test( + s"two components and two dangling vertices${if (useLocalCheckpoint) " with local checkpoint" + else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val vertices = spark.range(8L).toDF(ID) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L)) assertComponents(components, expected) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") } - } - test("really large long IDs") { - val max = Long.MaxValue - val chain = examples.Graphs.chain(10L) - val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID)) - val edges = chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST)) - val g = GraphFrame(vertices, edges) - val components = g.connectedComponents.run() - assert(components.count() === 10L) - assert(components.groupBy("component").count().count() === 1L) + test(s"friends graph${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + for ((algorithm, broadcastThreshold) <- + Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) { + val components = friends.connectedComponents + .setAlgorithm(algorithm) + .setBroadcastThreshold(broadcastThreshold) + .run() + assertComponents(components, expected) + } + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } + + test(s"really large long IDs${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val max = Long.MaxValue + val chain = examples.Graphs.chain(10L) + val vertices = chain.vertices.select((lit(max) - col(ID)).as(ID)) + val edges = chain.edges.select((lit(max) - col(SRC)).as(SRC), (lit(max) - col(DST)).as(DST)) + val g = GraphFrame(vertices, edges) + val components = g.connectedComponents.run() + assert(components.count() === 10L) + assert(components.groupBy("component").count().count() === 1L) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } + }) + + test("set configuration from spark conf") { + spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX") + assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx") + + spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000") + assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000) + + spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5") + assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5) + + spark.conf + .set("spark.graphframes.connectedComponents.intermediatestoragelevel", "memory_only") + assert( + Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY) + + spark.conf.unset("spark.graphframes.connectedComponents.algorithm") + spark.conf.unset("spark.graphframes.connectedComponents.broadcastthreshold") + spark.conf.unset("spark.graphframes.connectedComponents.checkpointinterval") + spark.conf.unset("spark.graphframes.connectedComponents.intermediatestoragelevel") } test("checkpoint interval") { @@ -285,23 +346,6 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize) } - test("set configuration from spark conf") { - spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX") - assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx") - - spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000") - assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000) - - spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5") - assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5) - - spark.conf.set( - "spark.graphframes.connectedComponents.intermediatestoragelevel", - "memory_only") - assert( - Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY) - } - private def assertComponents[T: ClassTag: TypeTag]( actual: DataFrame, expected: Set[Set[T]]): Unit = { diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index f099a9443..6efc0d247 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -25,110 +25,122 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { import sqlImplicits._ - test("page rank") { - val edges = Seq( - (0L, 1L), - (1L, 2L), - (2L, 4L), - (2L, 0L), - (3L, 4L), // 3 has no in-links - (4L, 0L), - (4L, 2L)).toDF("src", "dst").cache() - val vertices = GraphFrame.fromEdges(edges).outDegrees.cache() - val numVertices = vertices.count() - val graph = GraphFrame(vertices, edges) - - val alpha = 0.15 - // NOTE: This version doesn't handle nodes with no out-links. - val ranks = graph.pregel - .setMaxIter(5) - .withVertexColumn( - "rank", - lit(1.0 / numVertices), - coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices) - .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) - .aggMsgs(sum(Pregel.msg)) - .run() - - val result = ranks - .sort(col("id")) - .select("rank") - .as[Double] - .collect() - assert(result.sum === 1.0 +- 1e-6) - val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197) - result.zip(expected).foreach { case (r, e) => - assert(r === e +- 1e-3) + Seq(true, false).foreach(useLocalCheckpoint => { + test(s"page rank${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val edges = Seq( + (0L, 1L), + (1L, 2L), + (2L, 4L), + (2L, 0L), + (3L, 4L), // 3 has no in-links + (4L, 0L), + (4L, 2L)).toDF("src", "dst").cache() + val vertices = GraphFrame.fromEdges(edges).outDegrees.cache() + val numVertices = vertices.count() + val graph = GraphFrame(vertices, edges) + + val alpha = 0.15 + // NOTE: This version doesn't handle nodes with no out-links. + val ranks = graph.pregel + .setMaxIter(5) + .withVertexColumn( + "rank", + lit(1.0 / numVertices), + coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices) + .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) + .aggMsgs(sum(Pregel.msg)) + .run() + + val result = ranks + .sort(col("id")) + .select("rank") + .as[Double] + .collect() + assert(result.sum === 1.0 +- 1e-6) + val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197) + result.zip(expected).foreach { case (r, e) => + assert(r === e +- 1e-3) + } + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") } - } - - test("chain propagation") { - val n = 5 - val verDF = (1 to n).toDF("id").repartition(3) - val edgeDF = (1 until n) - .map(x => (x, x + 1)) - .toDF("src", "dst") - .repartition(3) - - val graph = GraphFrame(verDF, edgeDF) - - val resultDF = graph.pregel - .setMaxIter(n - 1) - .withVertexColumn( - "value", - when(col("id") === lit(1), lit(1)).otherwise(lit(0)), - when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) - .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) - .aggMsgs(max(Pregel.msg)) - .run() - - assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) - } - - test("reverse chain propagation") { - val n = 5 - val verDF = (1 to n).toDF("id").repartition(3) - val edgeDF = (1 until n) - .map(x => (x + 1, x)) - .toDF("src", "dst") - .repartition(3) - - val graph = GraphFrame(verDF, edgeDF) - - val resultDF = graph.pregel - .setMaxIter(n - 1) - .withVertexColumn( - "value", - when(col("id") === lit(1), lit(1)).otherwise(lit(0)), - when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) - .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value"))) - .aggMsgs(max(Pregel.msg)) - .run() - - assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) - } - - test("chain propagation with termination") { - val n = 5 - val verDF = (1 to n).toDF("id").repartition(3) - val edgeDF = (1 until n) - .map(x => (x, x + 1)) - .toDF("src", "dst") - .repartition(3) - - val graph = GraphFrame(verDF, edgeDF) - - val resultDF = graph.pregel - .setMaxIter(1000) - .setEarlyStopping(true) - .withVertexColumn( - "value", - when(col("id") === lit(1), lit(1)).otherwise(lit(0)), - when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) - .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) - .aggMsgs(max(Pregel.msg)) - .run() - - assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) - } + + test(s"chain propagation${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + val resultDF = graph.pregel + .setMaxIter(n - 1) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } + + test( + s"reverse chain propagation${if (useLocalCheckpoint) " with local checkpoint" else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x + 1, x)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + val resultDF = graph.pregel + .setMaxIter(n - 1) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value"))) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } + + test(s"chain propagation with termination${if (useLocalCheckpoint) " with local checkpoint" + else ""}") { + spark.conf.set("spark.graphframes.useLocalCheckpoints", useLocalCheckpoint.toString) + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + val resultDF = graph.pregel + .setMaxIter(1000) + .setEarlyStopping(true) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + spark.conf.set("spark.graphframes.useLocalCheckpoints", "false") + } + }) } diff --git a/docs/configurations.md b/docs/configurations.md new file mode 100644 index 000000000..2ae4ea67f --- /dev/null +++ b/docs/configurations.md @@ -0,0 +1,137 @@ +--- +layout: global +displayTitle: GraphFrames Configurations +title: Configurations +description: GraphFrames GRAPHFRAMES_VERSION configurations documentation +--- + +* Table of contents +{:toc} + +# GraphFrames Configurations + +GraphFrames provides several configuration options that can be used to tune the behavior of algorithms and operations. This page documents all available configurations, their descriptions, default values, and usage examples. + +## Configuration Table + +The following table lists all available GraphFrames configurations: + +| Configuration | Description | Default Value | Since Version | +|---------------|-------------|---------------|---------------| +| `spark.graphframes.useLocalCheckpoints` | Tells the connected components algorithm to use local checkpoints. If set to "true", iterative algorithm will use the checkpointing mechanism to the persistent storage. Local checkpoints are faster but can make the whole job less prone to errors. | `false` | 0.9.3 | +| `spark.graphframes.useLabelsAsComponents` | Tells the connected components algorithm to use labels as components in the output DataFrame. If set to "false", randomly generated labels with the data type LONG will returned. | Optional (default: `true`) | 0.9.0 | +| `spark.graphframes.connectedComponents.algorithm` | Sets the connected components algorithm to use. Supported algorithms:
- "graphframes": Uses alternating large star and small star iterations proposed in [Connected Components in MapReduce and Beyond](http://dx.doi.org/10.1145/2670979.2670997) with skewed join optimization.
- "graphx": Converts the graph to a GraphX graph and then uses the connected components implementation in GraphX. | Optional (default: `graphframes`) | 0.9.0 | +| `spark.graphframes.connectedComponents.broadcastthreshold` | Sets broadcast threshold in propagating component assignments. If a node degree is greater than this threshold at some iteration, its component assignment will be collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise, the assignment propagation is done by a normal Spark join. This parameter is only used when the algorithm is set to "graphframes". | Optional (default: `1000000`) | 0.9.0 | +| `spark.graphframes.connectedComponents.checkpointinterval` | Sets checkpoint interval in terms of number of iterations. Checkpointing regularly helps recover from failures, clean shuffle files, shorten the lineage of the computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the complexity of plan optimization would grow exponentially without checkpointing. Hence, disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix "connected-components". If the checkpoint directory is not set, this throws a `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is only used when the algorithm is set to "graphframes". | Optional (default: `2`) | 0.9.0 | +| `spark.graphframes.connectedComponents.intermediatestoragelevel` | Sets storage level for intermediate datasets that require multiple passes. | Optional (default: `MEMORY_AND_DISK`) | 0.9.0 | + +## Setting Configurations + +GraphFrames configurations can be set in several ways: + +### Spark Configuration + +You can set configurations when creating a SparkSession: + +
+ +
+{% highlight scala %} +import org.apache.spark.sql.SparkSession + +val spark = SparkSession.builder() + .appName("GraphFrames Example") + .config("spark.graphframes.connectedComponents.algorithm", "graphframes") + .config("spark.graphframes.connectedComponents.checkpointinterval", 3) + .getOrCreate() +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.sql import SparkSession + +spark = SparkSession.builder \ + .appName("GraphFrames Example") \ + .config("spark.graphframes.connectedComponents.algorithm", "graphframes") \ + .config("spark.graphframes.connectedComponents.checkpointinterval", 3) \ + .getOrCreate() +{% endhighlight %} +
+ +
+ +### Runtime Configuration + +You can also set configurations at runtime: + +
+ +
+{% highlight scala %} +spark.conf.set("spark.graphframes.connectedComponents.algorithm", "graphframes") +spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", 3) +{% endhighlight %} +
+ +
+{% highlight python %} +spark.conf.set("spark.graphframes.connectedComponents.algorithm", "graphframes") +spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", 3) +{% endhighlight %} +
+ +
+ +## Example: Connected Components with Custom Configurations + +This example shows how to run the Connected Components algorithm with custom configurations: + +
+ +
+{% highlight scala %} +import org.graphframes.GraphFrame +import org.graphframes.examples + +// Get example graph +val g = examples.Graphs.friends + +// Set configurations +spark.conf.set("spark.graphframes.connectedComponents.algorithm", "graphframes") +spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", 3) +spark.conf.set("spark.graphframes.useLocalCheckpoints", true) + +// Run connected components with custom configurations +val result = g.connectedComponents.run() +result.show() +{% endhighlight %} +
+ +
+{% highlight python %} +from graphframes.examples import Graphs + +# Get example graph +g = Graphs(spark).friends() + +# Set configurations +spark.conf.set("spark.graphframes.connectedComponents.algorithm", "graphframes") +spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", 3) +spark.conf.set("spark.graphframes.useLocalCheckpoints", "true") + +# Run connected components with custom configurations +result = g.connectedComponents() +result.show() +{% endhighlight %} +
+ +
+ +## Notes on Configuration Usage + +- **Checkpoint Directory**: For configurations related to checkpointing, make sure to set a checkpoint directory using `spark.sparkContext.setCheckpointDir("path/to/checkpoint/dir")` before running algorithms that use checkpointing. + +- **Storage Levels**: When setting the `spark.graphframes.connectedComponents.intermediatestoragelevel` configuration, use one of the following values: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc. + +- **Algorithm Selection**: The choice of algorithm for connected components can significantly impact performance. The "graphframes" algorithm is generally more scalable for large graphs, while the "graphx" algorithm may be faster for smaller graphs. \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index b9d8917bc..f0784f665 100644 --- a/docs/index.md +++ b/docs/index.md @@ -36,7 +36,7 @@ is willing to support this effort by reviewing the relevant pull requests.** # Downloading -Get GraphFrames from the [Spark Packages website](http://spark-packages.org/package/graphframes/graphframes). +Get GraphFrames from the [Maven Central](https://central.sonatype.com/namespace/io.graphframes). This documentation is for GraphFrames version {{site.GRAPHFRAMES_VERSION}}. GraphFrames depends on Apache Spark, which is available for download from the [Apache Spark website](http://spark.apache.org). @@ -44,10 +44,10 @@ GraphFrames depends on Apache Spark, which is available for download from the GraphFrames should be compatible with any platform which runs Spark. Refer to the [Apache Spark documentation](http://spark.apache.org/docs/latest) for more information. -GraphFrames is compatible with Spark 1.6+. However, later versions of Spark include major improvements +GraphFrames is compatible with Spark 3.4+. However, later versions of Spark include major improvements to DataFrames, so GraphFrames may be more efficient when running on more recent Spark versions. -GraphFrames is tested with Java 8, Python 2 and 3, and running against Spark 2.2+ (Scala 2.11). +GraphFrames is tested with Java 8, 11 and 17, Python 3, Spark 3.5 and Spark 4.0 (Scala 2.12 / Scala 2.13). # Applications, the Apache Spark shell, and clusters @@ -64,6 +64,7 @@ GraphFrames supplied as a package. * [GraphFrames User Guide](user-guide.html): detailed overview of GraphFrames in all supported languages (Scala, Java, Python) * [Motif Finding Tutorial](motif-tutorial.html): learn to perform pattern recognition with GraphFrames using a technique called network motif finding over the knowledge graph for the `stackexchange.com` subdomain [data dump](https://archive.org/details/stackexchange) +* [GraphFrames Configurations](configurations.html): detailed information about GraphFrames configurations, their descriptions, and usage examples **API Docs:** diff --git a/docs/quick-start.md b/docs/quick-start.md index ae28d9c94..c7562b1ab 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -18,10 +18,6 @@ If you are new to using Apache Spark, refer to the [Apache Spark Documentation](http://spark.apache.org/docs/latest/index.html) and its [Quick-Start Guide](http://spark.apache.org/docs/latest/quick-start.html) for more information. -If you are new to using [Spark packages](http://spark-packages.org/package/graphframes/graphframes), you can find more information -in the [Spark User Guide on using the interactive shell](http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell). -You just need to make sure your Spark shell session has the package as a dependency. - The following example shows how to run the Spark shell with the GraphFrames package. We use the `--packages` argument to download the graphframes package and any dependencies automatically. @@ -30,7 +26,7 @@ We use the `--packages` argument to download the graphframes package and any dep
{% highlight bash %} -$ ./bin/spark-shell --packages graphframes:graphframes:0.8.4-spark3.5-s_2.12 +$ ./bin/spark-shell --packages io.graphframes:graphframes-spark3_2.12:0.9.2 {% endhighlight %}
@@ -38,18 +34,13 @@ $ ./bin/spark-shell --packages graphframes:graphframes:0.8.4-spark3.5-s_2.12
{% highlight bash %} -$ ./bin/pyspark --packages graphframes:graphframes:0.8.4-spark3.5-s_2.12 +$ ./bin/pyspark --packages io.graphframes:graphframes-spark3_2.12:0.9.2 {% endhighlight %}
-The above examples of running the Spark shell with GraphFrames use a specific version of the GraphFrames -package. To use a different version, just change the last part of the `--packages` argument; -for example, to run with version `0.1.0-spark1.6`, pass the argument -`--packages graphframes:graphframes:0.1.0-spark1.6`. - # Start using GraphFrames The following example shows how to create a GraphFrame, query it, and run the PageRank algorithm. From 0c40979898b285ace02bedcc1aae14bf2b452c73 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 21 Jul 2025 12:27:29 +0200 Subject: [PATCH 2/2] useLocalCheckpoints as API argument --- .../sql/graphframes/GraphFramesConf.scala | 4 +-- .../graphframes/lib/ConnectedComponents.scala | 16 +++++---- .../graphframes/lib/LabelPropagation.scala | 14 ++++++-- .../scala/org/graphframes/lib/Pregel.scala | 9 +++-- .../org/graphframes/lib/ShortestPaths.scala | 16 +++++++-- .../main/scala/org/graphframes/mixins.scala | 34 +++++++++++++++++++ 6 files changed, 74 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala index cc8f6764f..70f2f0ca3 100644 --- a/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala +++ b/core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala @@ -16,7 +16,7 @@ object GraphFramesConf { |""".stripMargin) .version("0.9.3") .booleanConf - .createWithDefault(false) + .createOptional private val USE_LABELS_AS_COMPONENTS = SQLConf @@ -121,5 +121,5 @@ object GraphFramesConf { case _ => None } - def getUseLocalCheckpoints: Boolean = get(USE_LOCAL_CHECKPOINTS).get.toBoolean + def getUseLocalCheckpoints: Option[Boolean] = get(USE_LOCAL_CHECKPOINTS).map(_.toBoolean) } diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index ca9239e68..fa7c7c335 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -30,6 +30,7 @@ import org.graphframes.WithAlgorithmChoice import org.graphframes.WithBroadcastThreshold import org.graphframes.WithCheckpointInterval import org.graphframes.WithIntermediateStorageLevel +import org.graphframes.WithLocalCheckpoints import org.graphframes.WithMaxIter import org.graphframes.WithUseLabelsAsComponents @@ -54,7 +55,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) with WithBroadcastThreshold with WithIntermediateStorageLevel with WithUseLabelsAsComponents - with WithMaxIter { + with WithMaxIter + with WithLocalCheckpoints { setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES)) setCheckpointInterval( @@ -65,6 +67,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel)) setUseLabelsAsComponents( GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents)) + setUseLocalCheckpoints(GraphFramesConf.getUseLocalCheckpoints.getOrElse(useLocalCheckpoints)) /** * Runs the algorithm. @@ -77,7 +80,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) checkpointInterval = checkpointInterval, intermediateStorageLevel = intermediateStorageLevel, useLabelsAsComponents = useLabelsAsComponents, - maxIter = maxIter) + maxIter = maxIter, + useLocalCheckpoints = useLocalCheckpoints) } } @@ -190,7 +194,8 @@ object ConnectedComponents extends Logging { checkpointInterval: Int, intermediateStorageLevel: StorageLevel, useLabelsAsComponents: Boolean, - maxIter: Option[Int]): DataFrame = { + maxIter: Option[Int], + useLocalCheckpoints: Boolean): DataFrame = { if (runInGraphX) { return runGraphX(graph, maxIter.getOrElse(Int.MaxValue)) } @@ -207,9 +212,8 @@ object ConnectedComponents extends Logging { val logPrefix = s"[CC $runId]" logInfo(s"$logPrefix Start connected components with run ID $runId.") - val shouldUseLocalCheckpoints = GraphFramesConf.getUseLocalCheckpoints val shouldCheckpoint = checkpointInterval > 0 - val checkpointDir: Option[String] = if (shouldUseLocalCheckpoints) { None } + val checkpointDir: Option[String] = if (useLocalCheckpoints) { None } else if (shouldCheckpoint) { val dir = sc.getCheckpointDir .map { d => @@ -299,7 +303,7 @@ object ConnectedComponents extends Logging { // checkpointing if (shouldCheckpoint && (iteration % checkpointInterval == 0)) { - if (shouldUseLocalCheckpoints) { + if (useLocalCheckpoints) { ee = ee.localCheckpoint(eager = true) } else { // TODO: remove this after DataFrame.checkpoint is implemented diff --git a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala index 3b083c39a..aa8877e00 100644 --- a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._ import org.graphframes.GraphFrame import org.graphframes.WithAlgorithmChoice import org.graphframes.WithCheckpointInterval +import org.graphframes.WithLocalCheckpoints import org.graphframes.WithMaxIter /** @@ -44,14 +45,19 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments with WithAlgorithmChoice with WithCheckpointInterval - with WithMaxIter { + with WithMaxIter + with WithLocalCheckpoints { def run(): DataFrame = { val maxIterChecked = check(maxIter, "maxIter") algorithm match { case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked) case "graphframes" => - LabelPropagation.runInGraphFrames(graph, maxIterChecked, checkpointInterval) + LabelPropagation.runInGraphFrames( + graph, + maxIterChecked, + checkpointInterval, + useLocalCheckpoints = useLocalCheckpoints) } } } @@ -74,7 +80,8 @@ private object LabelPropagation { graph: GraphFrame, maxIter: Int, checkpointInterval: Int, - isDirected: Boolean = true): DataFrame = { + isDirected: Boolean = true, + useLocalCheckpoints: Boolean): DataFrame = { // Overall: // - Initial labels - IDs // - Active vertex col (halt voting) - did the label changed? @@ -88,6 +95,7 @@ private object LabelPropagation { .setCheckpointInterval(checkpointInterval) .setSkipMessagesFromNonActiveVertices(false) .setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg)) + .setUseLocalCheckpoints(useLocalCheckpoints) if (isDirected) { pregel = pregel.sendMsgToDst(col(LABEL_ID)) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 23c82dc8a..e5926d53f 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.explode import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.struct -import org.apache.spark.sql.graphframes.GraphFramesConf import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ import org.graphframes.Logging +import org.graphframes.WithLocalCheckpoints import java.io.IOException import scala.util.control.Breaks.break @@ -81,7 +81,7 @@ import scala.util.control.Breaks.breakable * Malewicz et al., Pregel: a system for * large-scale graph processing. */ -class Pregel(val graph: GraphFrame) extends Logging { +class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints { private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)] @@ -342,9 +342,8 @@ class Pregel(val graph: GraphFrame) extends Logging { var iteration = 1 val shouldCheckpoint = checkpointInterval > 0 - val shouldUseLocalCheckpoint = GraphFramesConf.getUseLocalCheckpoints - if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !shouldUseLocalCheckpoint) { + if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !useLocalCheckpoints) { // Spark Connect workaround graph.spark.conf.getOption("spark.checkpoint.dir") match { case Some(d) => graph.spark.sparkContext.setCheckpointDir(d) @@ -396,7 +395,7 @@ class Pregel(val graph: GraphFrame) extends Logging { updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*) if (shouldCheckpoint && iteration % checkpointInterval == 0) { - if (shouldUseLocalCheckpoint) { + if (useLocalCheckpoints) { newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false) } else { // do checkpoint, use lazy checkpoint because later we will materialize this DF. diff --git a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 1e1b4f18d..50a543cc8 100644 --- a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -38,6 +38,7 @@ import org.graphframes.GraphFramesUnreachableException import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice import org.graphframes.WithCheckpointInterval +import org.graphframes.WithLocalCheckpoints import java.util import scala.jdk.CollectionConverters._ @@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments with WithAlgorithmChoice - with WithCheckpointInterval { + with WithCheckpointInterval + with WithLocalCheckpoints { import org.graphframes.lib.ShortestPaths._ private var lmarks: Option[Seq[Any]] = None @@ -79,7 +81,12 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) val lmarksChecked = check(lmarks, "landmarks") algorithm match { case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) - case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval) + case ALGO_GRAPHFRAMES => + runInGraphFrames( + graph, + lmarksChecked, + checkpointInterval, + useLocalCheckpoints = useLocalCheckpoints) case _ => throw new GraphFramesUnreachableException() } } @@ -109,7 +116,8 @@ private object ShortestPaths extends Logging { graph: GraphFrame, landmarks: Seq[Any], checkpointInterval: Int, - isDirected: Boolean = true): DataFrame = { + isDirected: Boolean = true, + useLocalCheckpoints: Boolean): DataFrame = { logWarn("The GraphFrames based implementation is slow and considered experimental!") val vertexType = graph.vertices.schema(GraphFrame.ID).dataType @@ -202,6 +210,8 @@ private object ShortestPaths extends Logging { .setUpdateActiveVertexExpression(updateActiveVierticesExpr) .setStopIfAllNonActiveVertices(true) .setSkipMessagesFromNonActiveVertices(true) + .setCheckpointInterval(checkpointInterval) + .setUseLocalCheckpoints(useLocalCheckpoints) // Experimental feature if (isDirected) { diff --git a/core/src/main/scala/org/graphframes/mixins.scala b/core/src/main/scala/org/graphframes/mixins.scala index 56c8e3531..d7c2f0e07 100644 --- a/core/src/main/scala/org/graphframes/mixins.scala +++ b/core/src/main/scala/org/graphframes/mixins.scala @@ -141,3 +141,37 @@ private[graphframes] trait WithUseLabelsAsComponents { */ def getUseLabelsAsComponents: Boolean = useLabelsAsComponents } + +/** + * Provides support for local checkpoints in Spark computations. + * + * Local checkpoints offer a faster alternative to regular checkpoints as they don't require + * configuration of checkpointDir in persistent storage (like HDFS or S3). While being more + * performant, local checkpoints are less reliable since they don't survive node failures and the + * data is not persisted across multiple nodes. + */ +private[graphframes] trait WithLocalCheckpoints { + protected var useLocalCheckpoints: Boolean = false + + /** + * Sets whether to use local checkpoints instead of regular checkpoints (default: false). Local + * checkpoints are faster but less reliable as they don't survive node failures. + * + * @param value + * true to use local checkpoints, false for regular checkpoints + * @return + * this instance + */ + def setUseLocalCheckpoints(value: Boolean): this.type = { + useLocalCheckpoints = value + this + } + + /** + * Gets whether local checkpoints are being used instead of regular checkpoints. + * + * @return + * true if local checkpoints are enabled, false otherwise + */ + def getUseLocalCheckpoints: Boolean = useLocalCheckpoints +}