Skip to content

Commit 431e7ac

Browse files
feat: localCheckpoints + docs (#662)
* LocalCheckpoints + docs * useLocalCheckpoints as API argument
1 parent 102db5e commit 431e7ac

11 files changed

Lines changed: 568 additions & 302 deletions

File tree

core/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@ import org.apache.spark.sql.internal.SQLConf
66
import org.apache.spark.storage.StorageLevel
77

88
object GraphFramesConf {
9+
private val USE_LOCAL_CHECKPOINTS =
10+
SQLConf
11+
.buildConf("spark.graphframes.useLocalCheckpoints")
12+
.doc(""" Tells the connected components algorithm to use local checkpoints (default: "false").
13+
| If set to "true", the iterative algorithm will use Spark local checkpoints instead of checkpointing to the persistent storage.
14+
| Local checkpoints are faster but can make the whole job less reliable, as they do not survive node failures.
15+
| @note This option may become default "true" in the future.
16+
|""".stripMargin)
17+
.version("0.9.3")
18+
.booleanConf
19+
.createOptional
20+
921
private val USE_LABELS_AS_COMPONENTS =
1022
SQLConf
1123
.buildConf("spark.graphframes.useLabelsAsComponents")
@@ -108,4 +120,6 @@ object GraphFramesConf {
108120
case Some(use) => Some(use.toBoolean)
109121
case _ => None
110122
}
123+
124+
def getUseLocalCheckpoints: Option[Boolean] = get(USE_LOCAL_CHECKPOINTS).map(_.toBoolean)
111125
}

core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.graphframes.WithAlgorithmChoice
3030
import org.graphframes.WithBroadcastThreshold
3131
import org.graphframes.WithCheckpointInterval
3232
import org.graphframes.WithIntermediateStorageLevel
33+
import org.graphframes.WithLocalCheckpoints
3334
import org.graphframes.WithMaxIter
3435
import org.graphframes.WithUseLabelsAsComponents
3536

@@ -54,7 +55,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
5455
with WithBroadcastThreshold
5556
with WithIntermediateStorageLevel
5657
with WithUseLabelsAsComponents
57-
with WithMaxIter {
58+
with WithMaxIter
59+
with WithLocalCheckpoints {
5860

5961
setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
6062
setCheckpointInterval(
@@ -65,6 +67,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
6567
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
6668
setUseLabelsAsComponents(
6769
GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))
70+
setUseLocalCheckpoints(GraphFramesConf.getUseLocalCheckpoints.getOrElse(useLocalCheckpoints))
6871

6972
/**
7073
* Runs the algorithm.
@@ -77,7 +80,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
7780
checkpointInterval = checkpointInterval,
7881
intermediateStorageLevel = intermediateStorageLevel,
7982
useLabelsAsComponents = useLabelsAsComponents,
80-
maxIter = maxIter)
83+
maxIter = maxIter,
84+
useLocalCheckpoints = useLocalCheckpoints)
8185
}
8286
}
8387

@@ -190,7 +194,8 @@ object ConnectedComponents extends Logging {
190194
checkpointInterval: Int,
191195
intermediateStorageLevel: StorageLevel,
192196
useLabelsAsComponents: Boolean,
193-
maxIter: Option[Int]): DataFrame = {
197+
maxIter: Option[Int],
198+
useLocalCheckpoints: Boolean): DataFrame = {
194199
if (runInGraphX) {
195200
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
196201
}
@@ -208,7 +213,8 @@ object ConnectedComponents extends Logging {
208213
logInfo(s"$logPrefix Start connected components with run ID $runId.")
209214

210215
val shouldCheckpoint = checkpointInterval > 0
211-
val checkpointDir: Option[String] = if (shouldCheckpoint) {
216+
val checkpointDir: Option[String] = if (useLocalCheckpoints) { None }
217+
else if (shouldCheckpoint) {
212218
val dir = sc.getCheckpointDir
213219
.map { d =>
214220
new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
@@ -297,19 +303,23 @@ object ConnectedComponents extends Logging {
297303

298304
// checkpointing
299305
if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
300-
// TODO: remove this after DataFrame.checkpoint is implemented
301-
val out = s"${checkpointDir.get}/$iteration"
302-
ee.write.parquet(out)
303-
// may hit S3 eventually consistent issue
304-
ee = spark.read.parquet(out)
305-
306-
// remove previous checkpoint
307-
if (iteration > checkpointInterval) {
308-
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
309-
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
310-
}
306+
if (useLocalCheckpoints) {
307+
ee = ee.localCheckpoint(eager = true)
308+
} else {
309+
// TODO: remove this after DataFrame.checkpoint is implemented
310+
val out = s"${checkpointDir.get}/$iteration"
311+
ee.write.parquet(out)
312+
// may hit S3 eventually consistent issue
313+
ee = spark.read.parquet(out)
314+
315+
// remove previous checkpoint
316+
if (iteration > checkpointInterval) {
317+
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
318+
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
319+
}
311320

312-
System.gc() // hint Spark to clean shuffle directories
321+
System.gc() // hint Spark to clean shuffle directories
322+
}
313323
}
314324

315325
ee.persist(intermediateStorageLevel)

core/src/main/scala/org/graphframes/lib/LabelPropagation.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
2424
import org.graphframes.GraphFrame
2525
import org.graphframes.WithAlgorithmChoice
2626
import org.graphframes.WithCheckpointInterval
27+
import org.graphframes.WithLocalCheckpoints
2728
import org.graphframes.WithMaxIter
2829

2930
/**
@@ -44,14 +45,19 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
4445
extends Arguments
4546
with WithAlgorithmChoice
4647
with WithCheckpointInterval
47-
with WithMaxIter {
48+
with WithMaxIter
49+
with WithLocalCheckpoints {
4850

4951
def run(): DataFrame = {
5052
val maxIterChecked = check(maxIter, "maxIter")
5153
algorithm match {
5254
case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked)
5355
case "graphframes" =>
54-
LabelPropagation.runInGraphFrames(graph, maxIterChecked, checkpointInterval)
56+
LabelPropagation.runInGraphFrames(
57+
graph,
58+
maxIterChecked,
59+
checkpointInterval,
60+
useLocalCheckpoints = useLocalCheckpoints)
5561
}
5662
}
5763
}
@@ -74,7 +80,8 @@ private object LabelPropagation {
7480
graph: GraphFrame,
7581
maxIter: Int,
7682
checkpointInterval: Int,
77-
isDirected: Boolean = true): DataFrame = {
83+
isDirected: Boolean = true,
84+
useLocalCheckpoints: Boolean): DataFrame = {
7885
// Overall:
7986
// - Initial labels - IDs
8087
// - Active vertex col (halt voting) - did the label change?
@@ -88,6 +95,7 @@ private object LabelPropagation {
8895
.setCheckpointInterval(checkpointInterval)
8996
.setSkipMessagesFromNonActiveVertices(false)
9097
.setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg))
98+
.setUseLocalCheckpoints(useLocalCheckpoints)
9199

92100
if (isDirected) {
93101
pregel = pregel.sendMsgToDst(col(LABEL_ID))

core/src/main/scala/org/graphframes/lib/Pregel.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.struct
2727
import org.graphframes.GraphFrame
2828
import org.graphframes.GraphFrame._
2929
import org.graphframes.Logging
30+
import org.graphframes.WithLocalCheckpoints
3031

3132
import java.io.IOException
3233
import scala.util.control.Breaks.break
@@ -80,7 +81,7 @@ import scala.util.control.Breaks.breakable
8081
* <a href="https://doi.org/10.1145/1807167.1807184"> Malewicz et al., Pregel: a system for
8182
* large-scale graph processing. </a>
8283
*/
83-
class Pregel(val graph: GraphFrame) extends Logging {
84+
class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
8485

8586
private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)]
8687

@@ -342,7 +343,7 @@ class Pregel(val graph: GraphFrame) extends Logging {
342343

343344
val shouldCheckpoint = checkpointInterval > 0
344345

345-
if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) {
346+
if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !useLocalCheckpoints) {
346347
// Spark Connect workaround
347348
graph.spark.conf.getOption("spark.checkpoint.dir") match {
348349
case Some(d) => graph.spark.sparkContext.setCheckpointDir(d)
@@ -394,9 +395,13 @@ class Pregel(val graph: GraphFrame) extends Logging {
394395
updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*)
395396

396397
if (shouldCheckpoint && iteration % checkpointInterval == 0) {
397-
// do checkpoint, use lazy checkpoint because later we will materialize this DF.
398-
newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
399-
// TODO: remove last checkpoint file.
398+
if (useLocalCheckpoints) {
399+
newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false)
400+
} else {
401+
// do checkpoint, use lazy checkpoint because later we will materialize this DF.
402+
newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
403+
// TODO: remove last checkpoint file.
404+
}
400405
}
401406
newVertexUpdateColDF.cache()
402407
newVertexUpdateColDF.count() // materialize it

core/src/main/scala/org/graphframes/lib/ShortestPaths.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.graphframes.GraphFramesUnreachableException
3838
import org.graphframes.Logging
3939
import org.graphframes.WithAlgorithmChoice
4040
import org.graphframes.WithCheckpointInterval
41+
import org.graphframes.WithLocalCheckpoints
4142

4243
import java.util
4344
import scala.jdk.CollectionConverters._
@@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._
5455
class ShortestPaths private[graphframes] (private val graph: GraphFrame)
5556
extends Arguments
5657
with WithAlgorithmChoice
57-
with WithCheckpointInterval {
58+
with WithCheckpointInterval
59+
with WithLocalCheckpoints {
5860
import org.graphframes.lib.ShortestPaths._
5961

6062
private var lmarks: Option[Seq[Any]] = None
@@ -79,7 +81,12 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
7981
val lmarksChecked = check(lmarks, "landmarks")
8082
algorithm match {
8183
case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked)
82-
case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval)
84+
case ALGO_GRAPHFRAMES =>
85+
runInGraphFrames(
86+
graph,
87+
lmarksChecked,
88+
checkpointInterval,
89+
useLocalCheckpoints = useLocalCheckpoints)
8390
case _ => throw new GraphFramesUnreachableException()
8491
}
8592
}
@@ -109,7 +116,8 @@ private object ShortestPaths extends Logging {
109116
graph: GraphFrame,
110117
landmarks: Seq[Any],
111118
checkpointInterval: Int,
112-
isDirected: Boolean = true): DataFrame = {
119+
isDirected: Boolean = true,
120+
useLocalCheckpoints: Boolean): DataFrame = {
113121
logWarn("The GraphFrames based implementation is slow and considered experimental!")
114122
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType
115123

@@ -202,6 +210,8 @@ private object ShortestPaths extends Logging {
202210
.setUpdateActiveVertexExpression(updateActiveVierticesExpr)
203211
.setStopIfAllNonActiveVertices(true)
204212
.setSkipMessagesFromNonActiveVertices(true)
213+
.setCheckpointInterval(checkpointInterval)
214+
.setUseLocalCheckpoints(useLocalCheckpoints)
205215

206216
// Experimental feature
207217
if (isDirected) {

core/src/main/scala/org/graphframes/mixins.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,37 @@ private[graphframes] trait WithUseLabelsAsComponents {
141141
*/
142142
def getUseLabelsAsComponents: Boolean = useLabelsAsComponents
143143
}
144+
145+
/**
146+
* Provides support for local checkpoints in Spark computations.
147+
*
148+
* Local checkpoints offer a faster alternative to regular checkpoints as they don't require
149+
* configuration of checkpointDir in persistent storage (like HDFS or S3). While being more
150+
* performant, local checkpoints are less reliable since they don't survive node failures and the
151+
* data is not persisted across multiple nodes.
152+
*/
153+
private[graphframes] trait WithLocalCheckpoints {
154+
protected var useLocalCheckpoints: Boolean = false
155+
156+
/**
157+
* Sets whether to use local checkpoints instead of regular checkpoints (default: false). Local
158+
* checkpoints are faster but less reliable as they don't survive node failures.
159+
*
160+
* @param value
161+
* true to use local checkpoints, false for regular checkpoints
162+
* @return
163+
* this instance
164+
*/
165+
def setUseLocalCheckpoints(value: Boolean): this.type = {
166+
useLocalCheckpoints = value
167+
this
168+
}
169+
170+
/**
171+
* Gets whether local checkpoints are being used instead of regular checkpoints.
172+
*
173+
* @return
174+
* true if local checkpoints are enabled, false otherwise
175+
*/
176+
def getUseLocalCheckpoints: Boolean = useLocalCheckpoints
177+
}

0 commit comments

Comments
 (0)