Skip to content

Commit 8ca28d2

Browse files
committed
implement GraphFramesConf using ConfigBuilder
1 parent a863a3a commit 8ca28d2

4 files changed

Lines changed: 181 additions & 48 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package org.apache.spark.sql.graphframes
2+
3+
import org.apache.spark.internal.config.ConfigEntry
4+
import org.apache.spark.sql.SparkSession
5+
import org.apache.spark.sql.internal.SQLConf
6+
import org.apache.spark.storage.StorageLevel
7+
8+
object GraphFramesConf {
9+
private val CONNECTED_COMPONENTS_ALGORITHM =
10+
SQLConf
11+
.buildConf("spark.graphframes.connectedComponents.algorithm")
12+
.doc(""" Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms
13+
| - "graphframes": Uses alternating large star and small star iterations proposed in
14+
| [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]]
15+
| with skewed join optimization.
16+
| - "graphx": Converts the graph to a GraphX graph and then uses the connected components
17+
| implementation in GraphX.
18+
| @see org.graphframes.lib.ConnectedComponents.supportedAlgorithms""".stripMargin)
19+
.version("0.9.0")
20+
.stringConf
21+
.createOptional
22+
23+
private val CONNECTED_COMPONENTS_BROADCAST_THRESHOLD =
24+
SQLConf
25+
.buildConf("spark.graphframes.connectedComponents.broadcastthreshold")
26+
.doc(""" Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
27+
| degree is greater than this threshold at some iteration, its component assignment will be
28+
| collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise,
29+
| the assignment propagation is done by a normal Spark join. This parameter is only used when
30+
| the algorithm is set to "graphframes".""".stripMargin)
31+
.version("0.9.0")
32+
.intConf
33+
.createOptional
34+
35+
private val CONNECTED_COMPONENTS_CHECKPOINT_INTERVAL =
36+
SQLConf
37+
.buildConf("spark.graphframes.connectedComponents.checkpointinterval")
38+
.doc(""" Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing
39+
| regularly helps recover from failures, clean shuffle files, shorten the lineage of the
40+
| computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the
41+
| complexity of plan optimization would grow exponentially without checkpointing. Hence,
42+
| disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint
43+
| data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix
44+
| "connected-components". If the checkpoint directory is not set, this throws a
45+
| `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is
46+
| only used when the algorithm is set to "graphframes". Its default value might change in the
47+
| future.
48+
| @see `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc""".stripMargin)
49+
.version("0.9.0")
50+
.intConf
51+
.createOptional
52+
53+
private val CONNECTED_COMPONENTS_INTERMEDIATE_STORAGE_LEVEL =
54+
SQLConf
55+
.buildConf("spark.graphframes.connectedComponents.intermediatestoragelevel")
56+
.doc("Sets storage level for intermediate datasets that require multiple passes (default: ``MEMORY_AND_DISK``).")
57+
.version("0.9.0")
58+
.stringConf
59+
.createOptional
60+
61+
private def get(entry: ConfigEntry[_]): Option[String] = {
62+
try {
63+
Option(SparkSession.getActiveSession.get.conf.get(entry.key))
64+
} catch {
65+
case _: NoSuchElementException => None
66+
}
67+
}
68+
69+
def getConnectedComponentsAlgorithm: Option[String] = {
70+
get(CONNECTED_COMPONENTS_ALGORITHM) match {
71+
case Some(threshold) => Some(threshold.toLowerCase)
72+
case _ => None
73+
}
74+
}
75+
76+
def getConnectedComponentsBroadcastThreshold: Option[Int] = {
77+
get(CONNECTED_COMPONENTS_BROADCAST_THRESHOLD) match {
78+
case Some(threshold) => Some(threshold.toInt)
79+
case _ => None
80+
}
81+
}
82+
83+
def getConnectedComponentsCheckpointInterval: Option[Int] = {
84+
get(CONNECTED_COMPONENTS_CHECKPOINT_INTERVAL) match {
85+
case Some(interval) => Some(interval.toInt)
86+
case _ => None
87+
}
88+
}
89+
90+
def getConnectedComponentsStorageLevel: Option[StorageLevel] = {
91+
get(CONNECTED_COMPONENTS_INTERMEDIATE_STORAGE_LEVEL) match {
92+
case Some(level) => Some(StorageLevel.fromString(level.toUpperCase))
93+
case _ => None
94+
}
95+
}
96+
}

src/main/scala/org/graphframes/lib/ConnectedComponents.scala

Lines changed: 12 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ import org.apache.hadoop.fs.Path
2121
import org.apache.spark.sql.Column
2222
import org.apache.spark.sql.DataFrame
2323
import org.apache.spark.sql.functions._
24+
import org.apache.spark.sql.graphframes.GraphFramesConf
2425
import org.apache.spark.sql.types.DecimalType
2526
import org.apache.spark.storage.StorageLevel
2627
import org.graphframes.GraphFrame
2728
import org.graphframes.Logging
2829
import org.graphframes.WithAlgorithmChoice
30+
import org.graphframes.WithBroadcastThreshold
2931
import org.graphframes.WithCheckpointInterval
32+
import org.graphframes.WithIntermediateStorageLevel
3033
import org.graphframes.WithMaxIter
3134

3235
import java.io.IOException
@@ -47,56 +50,17 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
4750
with Logging
4851
with WithAlgorithmChoice
4952
with WithCheckpointInterval
53+
with WithBroadcastThreshold
54+
with WithIntermediateStorageLevel
5055
with WithMaxIter {
5156

52-
private var broadcastThreshold: Int = 1000000
53-
setAlgorithm(ALGO_GRAPHFRAMES)
54-
55-
/**
56-
* Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
57-
* degree is greater than this threshold at some iteration, its component assignment will be
58-
* collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise,
59-
* the assignment propagation is done by a normal Spark join. This parameter is only used when
60-
* the algorithm is set to "graphframes".
61-
*/
62-
def setBroadcastThreshold(value: Int): this.type = {
63-
require(value >= 0, s"Broadcast threshold must be non-negative but got $value.")
64-
broadcastThreshold = value
65-
this
66-
}
67-
68-
// python-friendly setter
69-
private[graphframes] def setBroadcastThreshold(value: java.lang.Integer): this.type = {
70-
setBroadcastThreshold(value.toInt)
71-
}
72-
73-
/**
74-
* Gets broadcast threshold in propagating component assignment.
75-
* @see
76-
* [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
77-
*/
78-
def getBroadcastThreshold: Int = broadcastThreshold
79-
80-
// python-friendly setter
81-
private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type = {
82-
setCheckpointInterval(value.toInt)
83-
}
84-
85-
private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
86-
87-
/**
88-
* Sets storage level for intermediate datasets that require multiple passes (default:
89-
* ``MEMORY_AND_DISK``).
90-
*/
91-
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
92-
intermediateStorageLevel = value
93-
this
94-
}
95-
96-
/**
97-
* Gets storage level for intermediate datasets that require multiple passes.
98-
*/
99-
def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel
57+
setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(algorithm))
58+
setCheckpointInterval(
59+
GraphFramesConf.getConnectedComponentsCheckpointInterval.getOrElse(checkpointInterval))
60+
setBroadcastThreshold(
61+
GraphFramesConf.getConnectedComponentsBroadcastThreshold.getOrElse(broadcastThreshold))
62+
setIntermediateStorageLevel(
63+
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
10064

10165
/**
10266
* Runs the algorithm.

src/main/scala/org/graphframes/mixins.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package org.graphframes
22

3+
import org.apache.spark.storage.StorageLevel
4+
35
private[graphframes] trait WithAlgorithmChoice {
46
protected val ALGO_GRAPHX = "graphx"
57
protected val ALGO_GRAPHFRAMES = "graphframes"
@@ -49,12 +51,66 @@ private[graphframes] trait WithCheckpointInterval extends Logging {
4951
this
5052
}
5153

54+
// python-friendly setter
55+
private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type = {
56+
setCheckpointInterval(value.toInt)
57+
}
58+
5259
/**
5360
* Gets checkpoint interval.
5461
*/
5562
def getCheckpointInterval: Int = checkpointInterval
5663
}
5764

65+
private[graphframes] trait WithBroadcastThreshold extends Logging {
66+
protected var broadcastThreshold: Int = 1000000
67+
68+
/**
69+
* Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
70+
* degree is greater than this threshold at some iteration, its component assignment will be
71+
* collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise,
72+
* the assignment propagation is done by a normal Spark join. This parameter is only used when
73+
* the algorithm is set to "graphframes".
74+
*/
75+
def setBroadcastThreshold(value: Int): this.type = {
76+
require(value >= 0, s"Broadcast threshold must be non-negative but got $value.")
77+
broadcastThreshold = value
78+
this
79+
}
80+
81+
// python-friendly setter
82+
private[graphframes] def setBroadcastThreshold(value: java.lang.Integer): this.type = {
83+
setBroadcastThreshold(value.toInt)
84+
}
85+
86+
/**
87+
* Gets broadcast threshold in propagating component assignment.
88+
* @see
89+
* [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
90+
*/
91+
def getBroadcastThreshold: Int = broadcastThreshold
92+
}
93+
94+
private[graphframes] trait WithIntermediateStorageLevel extends Logging {
95+
96+
protected var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
97+
98+
/**
99+
* Sets storage level for intermediate datasets that require multiple passes (default:
100+
* ``MEMORY_AND_DISK``).
101+
*/
102+
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
103+
intermediateStorageLevel = value
104+
this
105+
}
106+
107+
/**
108+
* Gets storage level for intermediate datasets that require multiple passes.
109+
*/
110+
def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel
111+
112+
}
113+
58114
private[graphframes] trait WithMaxIter {
59115
protected var maxIter: Option[Int] = None
60116

src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,23 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
264264
assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize)
265265
}
266266

267+
test("set configuration from spark conf") {
268+
spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX")
269+
assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx")
270+
271+
spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000")
272+
assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000)
273+
274+
spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5")
275+
assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5)
276+
277+
spark.conf.set(
278+
"spark.graphframes.connectedComponents.intermediatestoragelevel",
279+
"memory_only")
280+
assert(
281+
Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY)
282+
}
283+
267284
private def assertComponents[T: ClassTag: TypeTag](
268285
actual: DataFrame,
269286
expected: Set[Set[T]]): Unit = {

0 commit comments

Comments
 (0)