Skip to content

Commit 42af82e

Browse files
committed
allow useLabelsAsComponents to be set locally on the ConnectedComponents instance
1 parent 191c355 commit 42af82e

11 files changed

Lines changed: 95 additions & 47 deletions

File tree

graphframes-connect/src/main/protobuf/graphframes.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ message ConnectedComponents {
6363
string algorithm = 1;
6464
int32 checkpoint_interval = 2;
6565
int32 broadcast_threshold = 3;
66+
bool use_labels_as_components = 4;
6667
}
6768

6869
message DropIsolatedVertices {}

python/graphframes/classic/graphframe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,14 @@ def connectedComponents(
206206
algorithm: str = "graphframes",
207207
checkpointInterval: int = 2,
208208
broadcastThreshold: int = 1000000,
209+
useLabelsAsComponents: bool = False,
209210
) -> DataFrame:
210211
jdf = (
211212
self._jvm_graph.connectedComponents()
212213
.setAlgorithm(algorithm)
213214
.setCheckpointInterval(checkpointInterval)
214215
.setBroadcastThreshold(broadcastThreshold)
216+
.setUseLabelsAsComponents(useLabelsAsComponents)
215217
.run()
216218
)
217219
return DataFrame(jdf, self._spark)

python/graphframes/connect/graphframe_client.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def connectedComponents(
504504
algorithm: str = "graphframes",
505505
checkpointInterval: int = 2,
506506
broadcastThreshold: int = 1000000,
507+
useLabelsAsComponents: bool = False,
507508
) -> DataFrame:
508509
class ConnectedComponents(LogicalPlan):
509510
def __init__(
@@ -513,28 +514,29 @@ def __init__(
513514
algorithm: str,
514515
checkpoint_interval: int,
515516
broadcast_threshold: int,
517+
use_labels_as_components: bool,
516518
) -> None:
517519
super().__init__(None)
518520
self.v = v
519521
self.e = e
520522
self.algorithm = algorithm
521523
self.checkpoint_interval = checkpoint_interval
522524
self.broadcast_threshold = broadcast_threshold
525+
self.useLabelsAsComponents = use_labels_as_components
523526

524-
def plan(self, session: SparkConnectClient) -> proto.Relation:
525-
graphframes_api_call = GraphFrameConnect._get_pb_api_message(
526-
self.v, self.e, session
527-
)
528-
graphframes_api_call.connected_components.CopyFrom(
529-
pb.ConnectedComponents(
530-
algorithm=self.algorithm,
531-
checkpoint_interval=self.checkpoint_interval,
532-
broadcast_threshold=self.broadcast_threshold,
533-
)
527+
def plan(self, session: SparkConnectClient) -> proto.Relation:
528+
graphframes_api_call = GraphFrameConnect._get_pb_api_message(self.v, self.e, session)
529+
graphframes_api_call.connected_components.CopyFrom(
530+
pb.ConnectedComponents(
531+
algorithm=self.algorithm,
532+
checkpoint_interval=self.checkpoint_interval,
533+
broadcast_threshold=self.broadcast_threshold,
534+
use_labels_as_components=self.use_labels_as_components,
534535
)
535-
plan = self._create_proto_relation()
536-
plan.extension.Pack(graphframes_api_call)
537-
return plan
536+
)
537+
plan = self._create_proto_relation()
538+
plan.extension.Pack(graphframes_api_call)
539+
return plan
538540

539541
return _dataframe_from_plan(
540542
ConnectedComponents(
@@ -543,6 +545,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation:
543545
algorithm,
544546
checkpointInterval,
545547
broadcastThreshold,
548+
useLabelsAsComponents,
546549
),
547550
self._spark,
548551
)

python/graphframes/connect/proto/graphframes_pb2.py

Lines changed: 30 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/graphframes/connect/proto/graphframes_pb2.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,26 @@ class BFS(_message.Message):
147147
) -> None: ...
148148

149149
class ConnectedComponents(_message.Message):
150-
__slots__ = ("algorithm", "checkpoint_interval", "broadcast_threshold")
150+
__slots__ = (
151+
"algorithm",
152+
"checkpoint_interval",
153+
"broadcast_threshold",
154+
"use_labels_as_components",
155+
)
151156
ALGORITHM_FIELD_NUMBER: _ClassVar[int]
152157
CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int]
153158
BROADCAST_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
159+
USE_LABELS_AS_COMPONENTS_FIELD_NUMBER: _ClassVar[int]
154160
algorithm: str
155161
checkpoint_interval: int
156162
broadcast_threshold: int
163+
use_labels_as_components: bool
157164
def __init__(
158165
self,
159166
algorithm: _Optional[str] = ...,
160167
checkpoint_interval: _Optional[int] = ...,
161168
broadcast_threshold: _Optional[int] = ...,
169+
use_labels_as_components: bool = ...,
162170
) -> None: ...
163171

164172
class DropIsolatedVertices(_message.Message):

python/graphframes/graphframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def connectedComponents(
267267
algorithm: str = "graphframes",
268268
checkpointInterval: int = 2,
269269
broadcastThreshold: int = 1000000,
270+
useLabelsAsComponents: bool = False,
270271
) -> DataFrame:
271272
"""
272273
Computes the connected components of the graph.
@@ -278,13 +279,16 @@ def connectedComponents(
278279
:param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2)
279280
:param broadcastThreshold: broadcast threshold in propagating component assignments
280281
(default: 1000000)
282+
:param useLabelsAsComponents: if True, uses the vertex labels as components, otherwise will
283+
use longs
281284
282285
:return: DataFrame with new vertices column "component"
283286
"""
284287
return self._impl.connectedComponents(
285288
algorithm=algorithm,
286289
checkpointInterval=checkpointInterval,
287290
broadcastThreshold=broadcastThreshold,
291+
useLabelsAsComponents=useLabelsAsComponents,
288292
)
289293

290294
def labelPropagation(self, maxIter: int) -> DataFrame:

python/tests/test_graphframes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def test_connected_components_friends(examples, spark):
333333
g.connectedComponents(checkpointInterval=0),
334334
g.connectedComponents(checkpointInterval=10),
335335
g.connectedComponents(algorithm="graphx"),
336+
g.connectedComponents(useLabelsAsComponents=True),
336337
]
337338
for c in comps_tests:
338339
assert c.groupBy("component").count().count() == 2

src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ object GraphFramesConf {
1414
|""".stripMargin)
1515
.version("0.9.0")
1616
.booleanConf
17-
.createWithDefault(true)
17+
.createOptional
1818

1919
private val CONNECTED_COMPONENTS_ALGORITHM =
2020
SQLConf
@@ -104,5 +104,8 @@ object GraphFramesConf {
104104
}
105105
}
106106

107-
def getUseLabelsAsComponents: Boolean = get(USE_LABELS_AS_COMPONENTS).get.toBoolean
107+
def getUseLabelsAsComponents: Option[Boolean] = get(USE_LABELS_AS_COMPONENTS) match {
108+
case Some(use) => Some(use.toBoolean)
109+
case _ => None
110+
}
108111
}

src/main/scala/org/graphframes/lib/ConnectedComponents.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.graphframes.WithBroadcastThreshold
3131
import org.graphframes.WithCheckpointInterval
3232
import org.graphframes.WithIntermediateStorageLevel
3333
import org.graphframes.WithMaxIter
34+
import org.graphframes.WithUseLabelsAsComponents
3435

3536
import java.io.IOException
3637
import java.math.BigDecimal
@@ -52,6 +53,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
5253
with WithCheckpointInterval
5354
with WithBroadcastThreshold
5455
with WithIntermediateStorageLevel
56+
with WithUseLabelsAsComponents
5557
with WithMaxIter {
5658

5759
setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
@@ -61,6 +63,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
6163
GraphFramesConf.getConnectedComponentsBroadcastThreshold.getOrElse(broadcastThreshold))
6264
setIntermediateStorageLevel(
6365
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
66+
setUseLabelsAsComponents(
67+
GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))
6468

6569
/**
6670
* Runs the algorithm.
@@ -72,6 +76,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
7276
broadcastThreshold = broadcastThreshold,
7377
checkpointInterval = checkpointInterval,
7478
intermediateStorageLevel = intermediateStorageLevel,
79+
useLabelsAsComponents = useLabelsAsComponents,
7580
maxIter = maxIter)
7681
}
7782
}
@@ -184,6 +189,7 @@ object ConnectedComponents extends Logging {
184189
broadcastThreshold: Int,
185190
checkpointInterval: Int,
186191
intermediateStorageLevel: StorageLevel,
192+
useLabelsAsComponents: Boolean,
187193
maxIter: Option[Int]): DataFrame = {
188194
if (runInGraphX) {
189195
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
@@ -340,7 +346,7 @@ object ConnectedComponents extends Logging {
340346
vv(ATTR),
341347
when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT),
342348
col(ATTR + "." + ID).as(ID))
343-
val output = if (graph.hasIntegralIdType || !GraphFramesConf.getUseLabelsAsComponents) {
349+
val output = if (graph.hasIntegralIdType || !useLabelsAsComponents) {
344350
indexedLabel
345351
.select(col(s"$ATTR.*"), col(COMPONENT))
346352
.persist(intermediateStorageLevel)

src/main/scala/org/graphframes/mixins.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,22 @@ private[graphframes] trait WithMaxIter {
122122
this
123123
}
124124
}
125+
126+
private[graphframes] trait WithUseLabelsAsComponents {
127+
protected var useLabelsAsComponents: Boolean = false
128+
129+
/**
130+
* Sets whether to use vertex labels as component identifiers (default: false). When true,
131+
* vertex labels will be used as component identifiers instead of computing connected
132+
* components.
133+
*/
134+
def setUseLabelsAsComponents(value: Boolean): this.type = {
135+
useLabelsAsComponents = value
136+
this
137+
}
138+
139+
/**
140+
* Gets whether to use vertex labels as component identifiers.
141+
*/
142+
def getUseLabelsAsComponents: Boolean = useLabelsAsComponents
143+
}

0 commit comments

Comments
 (0)