Skip to content

Commit f3b3b20

Browse files
authored
allow useLabelsAsComponents to be set locally on the ConnectedComponents instance (#632)
* allow useLabelsAsComponents to be set locally on the ConnectedComponents instance * allow useLabelsAsComponents to be set locally on the ConnectedComponents instance * typo
1 parent 191c355 commit f3b3b20

13 files changed

Lines changed: 86 additions & 35 deletions

File tree

graphframes-connect/src/main/protobuf/graphframes.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ message ConnectedComponents {
6363
string algorithm = 1;
6464
int32 checkpoint_interval = 2;
6565
int32 broadcast_threshold = 3;
66+
bool use_labels_as_components = 4;
6667
}
6768

6869
message DropIsolatedVertices {}

graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ object GraphFramesConnectUtils {
102102
.setAlgorithm(cc.getAlgorithm)
103103
.setCheckpointInterval(cc.getCheckpointInterval)
104104
.setBroadcastThreshold(cc.getBroadcastThreshold)
105+
.setUseLabelsAsComponents(cc.getUseLabelsAsComponents)
105106
.run()
106107
}
107108
case MethodCase.DROP_ISOLATED_VERTICES => {

python/graphframes/classic/graphframe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,14 @@ def connectedComponents(
206206
algorithm: str = "graphframes",
207207
checkpointInterval: int = 2,
208208
broadcastThreshold: int = 1000000,
209+
useLabelsAsComponents: bool = False,
209210
) -> DataFrame:
210211
jdf = (
211212
self._jvm_graph.connectedComponents()
212213
.setAlgorithm(algorithm)
213214
.setCheckpointInterval(checkpointInterval)
214215
.setBroadcastThreshold(broadcastThreshold)
216+
.setUseLabelsAsComponents(useLabelsAsComponents)
215217
.run()
216218
)
217219
return DataFrame(jdf, self._spark)

python/graphframes/connect/graphframe_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def connectedComponents(
504504
algorithm: str = "graphframes",
505505
checkpointInterval: int = 2,
506506
broadcastThreshold: int = 1000000,
507+
useLabelsAsComponents: bool = False,
507508
) -> DataFrame:
508509
class ConnectedComponents(LogicalPlan):
509510
def __init__(
@@ -513,13 +514,15 @@ def __init__(
513514
algorithm: str,
514515
checkpoint_interval: int,
515516
broadcast_threshold: int,
517+
use_labels_as_components: bool,
516518
) -> None:
517519
super().__init__(None)
518520
self.v = v
519521
self.e = e
520522
self.algorithm = algorithm
521523
self.checkpoint_interval = checkpoint_interval
522524
self.broadcast_threshold = broadcast_threshold
525+
self.use_labels_as_components = use_labels_as_components
523526

524527
def plan(self, session: SparkConnectClient) -> proto.Relation:
525528
graphframes_api_call = GraphFrameConnect._get_pb_api_message(
@@ -530,6 +533,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation:
530533
algorithm=self.algorithm,
531534
checkpoint_interval=self.checkpoint_interval,
532535
broadcast_threshold=self.broadcast_threshold,
536+
use_labels_as_components=self.use_labels_as_components,
533537
)
534538
)
535539
plan = self._create_proto_relation()
@@ -543,6 +547,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation:
543547
algorithm,
544548
checkpointInterval,
545549
broadcastThreshold,
550+
useLabelsAsComponents,
546551
),
547552
self._spark,
548553
)

python/graphframes/connect/proto/graphframes_pb2.py

Lines changed: 30 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/graphframes/connect/proto/graphframes_pb2.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,26 @@ class BFS(_message.Message):
147147
) -> None: ...
148148

149149
class ConnectedComponents(_message.Message):
150-
__slots__ = ("algorithm", "checkpoint_interval", "broadcast_threshold")
150+
__slots__ = (
151+
"algorithm",
152+
"checkpoint_interval",
153+
"broadcast_threshold",
154+
"use_labels_as_components",
155+
)
151156
ALGORITHM_FIELD_NUMBER: _ClassVar[int]
152157
CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int]
153158
BROADCAST_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
159+
USE_LABELS_AS_COMPONENTS_FIELD_NUMBER: _ClassVar[int]
154160
algorithm: str
155161
checkpoint_interval: int
156162
broadcast_threshold: int
163+
use_labels_as_components: bool
157164
def __init__(
158165
self,
159166
algorithm: _Optional[str] = ...,
160167
checkpoint_interval: _Optional[int] = ...,
161168
broadcast_threshold: _Optional[int] = ...,
169+
use_labels_as_components: bool = ...,
162170
) -> None: ...
163171

164172
class DropIsolatedVertices(_message.Message):

python/graphframes/graphframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def connectedComponents(
267267
algorithm: str = "graphframes",
268268
checkpointInterval: int = 2,
269269
broadcastThreshold: int = 1000000,
270+
useLabelsAsComponents: bool = False,
270271
) -> DataFrame:
271272
"""
272273
Computes the connected components of the graph.
@@ -278,13 +279,16 @@ def connectedComponents(
278279
:param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2)
279280
:param broadcastThreshold: broadcast threshold in propagating component assignments
280281
(default: 1000000)
282+
:param useLabelsAsComponents: if True, uses the vertex labels as components, otherwise will
283+
use longs
281284
282285
:return: DataFrame with new vertices column "component"
283286
"""
284287
return self._impl.connectedComponents(
285288
algorithm=algorithm,
286289
checkpointInterval=checkpointInterval,
287290
broadcastThreshold=broadcastThreshold,
291+
useLabelsAsComponents=useLabelsAsComponents,
288292
)
289293

290294
def labelPropagation(self, maxIter: int) -> DataFrame:

python/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def is_remote() -> bool:
1919
return False
20-
20+
2121
spark_major_version = __version__[:1]
2222
scala_version = os.environ.get("SCALA_VERSION", "2.12" if __version__ < "4" else "2.13")
2323

python/tests/test_graphframes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def test_connected_components_friends(examples, spark):
333333
g.connectedComponents(checkpointInterval=0),
334334
g.connectedComponents(checkpointInterval=10),
335335
g.connectedComponents(algorithm="graphx"),
336+
g.connectedComponents(useLabelsAsComponents=True),
336337
]
337338
for c in comps_tests:
338339
assert c.groupBy("component").count().count() == 2

src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConf.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ object GraphFramesConf {
1414
|""".stripMargin)
1515
.version("0.9.0")
1616
.booleanConf
17-
.createWithDefault(true)
17+
.createOptional
1818

1919
private val CONNECTED_COMPONENTS_ALGORITHM =
2020
SQLConf
@@ -104,5 +104,8 @@ object GraphFramesConf {
104104
}
105105
}
106106

107-
def getUseLabelsAsComponents: Boolean = get(USE_LABELS_AS_COMPONENTS).get.toBoolean
107+
def getUseLabelsAsComponents: Option[Boolean] = get(USE_LABELS_AS_COMPONENTS) match {
108+
case Some(use) => Some(use.toBoolean)
109+
case _ => None
110+
}
108111
}

0 commit comments

Comments
 (0)