@@ -30,6 +30,7 @@ import org.graphframes.WithAlgorithmChoice
3030import org .graphframes .WithBroadcastThreshold
3131import org .graphframes .WithCheckpointInterval
3232import org .graphframes .WithIntermediateStorageLevel
33+ import org .graphframes .WithLocalCheckpoints
3334import org .graphframes .WithMaxIter
3435import org .graphframes .WithUseLabelsAsComponents
3536
@@ -54,7 +55,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
5455 with WithBroadcastThreshold
5556 with WithIntermediateStorageLevel
5657 with WithUseLabelsAsComponents
57- with WithMaxIter {
58+ with WithMaxIter
59+ with WithLocalCheckpoints {
5860
5961 setAlgorithm(GraphFramesConf .getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES ))
6062 setCheckpointInterval(
@@ -65,6 +67,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
6567 GraphFramesConf .getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
6668 setUseLabelsAsComponents(
6769 GraphFramesConf .getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))
70+ setUseLocalCheckpoints(GraphFramesConf .getUseLocalCheckpoints.getOrElse(useLocalCheckpoints))
6871
6972 /**
7073 * Runs the algorithm.
@@ -77,7 +80,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
7780 checkpointInterval = checkpointInterval,
7881 intermediateStorageLevel = intermediateStorageLevel,
7982 useLabelsAsComponents = useLabelsAsComponents,
80- maxIter = maxIter)
83+ maxIter = maxIter,
84+ useLocalCheckpoints = useLocalCheckpoints)
8185 }
8286}
8387
@@ -190,7 +194,8 @@ object ConnectedComponents extends Logging {
190194 checkpointInterval : Int ,
191195 intermediateStorageLevel : StorageLevel ,
192196 useLabelsAsComponents : Boolean ,
193- maxIter : Option [Int ]): DataFrame = {
197+ maxIter : Option [Int ],
198+ useLocalCheckpoints : Boolean ): DataFrame = {
194199 if (runInGraphX) {
195200 return runGraphX(graph, maxIter.getOrElse(Int .MaxValue ))
196201 }
@@ -208,7 +213,8 @@ object ConnectedComponents extends Logging {
208213 logInfo(s " $logPrefix Start connected components with run ID $runId. " )
209214
210215 val shouldCheckpoint = checkpointInterval > 0
211- val checkpointDir : Option [String ] = if (shouldCheckpoint) {
216+ val checkpointDir : Option [String ] = if (useLocalCheckpoints) { None }
217+ else if (shouldCheckpoint) {
212218 val dir = sc.getCheckpointDir
213219 .map { d =>
214220 new Path (d, s " $CHECKPOINT_NAME_PREFIX- $runId" ).toString
@@ -297,19 +303,23 @@ object ConnectedComponents extends Logging {
297303
298304 // checkpointing
299305 if (shouldCheckpoint && (iteration % checkpointInterval == 0 )) {
300- // TODO: remove this after DataFrame.checkpoint is implemented
301- val out = s " ${checkpointDir.get}/ $iteration"
302- ee.write.parquet(out)
303- // may hit S3 eventually consistent issue
304- ee = spark.read.parquet(out)
305-
306- // remove previous checkpoint
307- if (iteration > checkpointInterval) {
308- val path = new Path (s " ${checkpointDir.get}/ ${iteration - checkpointInterval}" )
309- path.getFileSystem(sc.hadoopConfiguration).delete(path, true )
310- }
306+ if (useLocalCheckpoints) {
307+ ee = ee.localCheckpoint(eager = true )
308+ } else {
309+ // TODO: remove this after DataFrame.checkpoint is implemented
310+ val out = s " ${checkpointDir.get}/ $iteration"
311+ ee.write.parquet(out)
312+ // may hit S3 eventually consistent issue
313+ ee = spark.read.parquet(out)
314+
315+ // remove previous checkpoint
316+ if (iteration > checkpointInterval) {
317+ val path = new Path (s " ${checkpointDir.get}/ ${iteration - checkpointInterval}" )
318+ path.getFileSystem(sc.hadoopConfiguration).delete(path, true )
319+ }
311320
312- System .gc() // hint Spark to clean shuffle directories
321+ System .gc() // hint Spark to clean shuffle directories
322+ }
313323 }
314324
315325 ee.persist(intermediateStorageLevel)
0 commit comments