Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions graphframes-connect/src/main/protobuf/graphframes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ message ConnectedComponents {
string algorithm = 1;
int32 checkpoint_interval = 2;
int32 broadcast_threshold = 3;
bool use_labels_as_components = 4;
}

message DropIsolatedVertices {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ object GraphFramesConnectUtils {
.setAlgorithm(cc.getAlgorithm)
.setCheckpointInterval(cc.getCheckpointInterval)
.setBroadcastThreshold(cc.getBroadcastThreshold)
.setUseLabelsAsComponents(cc.getUseLabelsAsComponents)
.run()
}
case MethodCase.DROP_ISOLATED_VERTICES => {
Expand Down
2 changes: 2 additions & 0 deletions python/graphframes/classic/graphframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,14 @@ def connectedComponents(
algorithm: str = "graphframes",
checkpointInterval: int = 2,
broadcastThreshold: int = 1000000,
useLabelsAsComponents: bool = False,
) -> DataFrame:
jdf = (
self._jvm_graph.connectedComponents()
.setAlgorithm(algorithm)
.setCheckpointInterval(checkpointInterval)
.setBroadcastThreshold(broadcastThreshold)
.setUseLabelsAsComponents(useLabelsAsComponents)
.run()
)
return DataFrame(jdf, self._spark)
Expand Down
5 changes: 5 additions & 0 deletions python/graphframes/connect/graphframe_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def connectedComponents(
algorithm: str = "graphframes",
checkpointInterval: int = 2,
broadcastThreshold: int = 1000000,
useLabelsAsComponents: bool = False,
) -> DataFrame:
class ConnectedComponents(LogicalPlan):
def __init__(
Expand All @@ -513,13 +514,15 @@ def __init__(
algorithm: str,
checkpoint_interval: int,
broadcast_threshold: int,
use_labels_as_components: bool,
) -> None:
super().__init__(None)
self.v = v
self.e = e
self.algorithm = algorithm
self.checkpoint_interval = checkpoint_interval
self.broadcast_threshold = broadcast_threshold
self.use_labels_as_components = use_labels_as_components

def plan(self, session: SparkConnectClient) -> proto.Relation:
graphframes_api_call = GraphFrameConnect._get_pb_api_message(
Expand All @@ -530,6 +533,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation:
algorithm=self.algorithm,
checkpoint_interval=self.checkpoint_interval,
broadcast_threshold=self.broadcast_threshold,
use_labels_as_components=self.use_labels_as_components,
)
)
plan = self._create_proto_relation()
Expand All @@ -543,6 +547,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation:
algorithm,
checkpointInterval,
broadcastThreshold,
useLabelsAsComponents,
),
self._spark,
)
Expand Down
60 changes: 30 additions & 30 deletions python/graphframes/connect/proto/graphframes_pb2.py

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion python/graphframes/connect/proto/graphframes_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,26 @@ class BFS(_message.Message):
) -> None: ...

class ConnectedComponents(_message.Message):
__slots__ = ("algorithm", "checkpoint_interval", "broadcast_threshold")
__slots__ = (
"algorithm",
"checkpoint_interval",
"broadcast_threshold",
"use_labels_as_components",
)
ALGORITHM_FIELD_NUMBER: _ClassVar[int]
CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int]
BROADCAST_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
USE_LABELS_AS_COMPONENTS_FIELD_NUMBER: _ClassVar[int]
algorithm: str
checkpoint_interval: int
broadcast_threshold: int
use_labels_as_components: bool
def __init__(
self,
algorithm: _Optional[str] = ...,
checkpoint_interval: _Optional[int] = ...,
broadcast_threshold: _Optional[int] = ...,
use_labels_as_components: bool = ...,
) -> None: ...

class DropIsolatedVertices(_message.Message):
Expand Down
4 changes: 4 additions & 0 deletions python/graphframes/graphframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def connectedComponents(
algorithm: str = "graphframes",
checkpointInterval: int = 2,
broadcastThreshold: int = 1000000,
useLabelsAsComponents: bool = False,
) -> DataFrame:
"""
Computes the connected components of the graph.
Expand All @@ -278,13 +279,16 @@ def connectedComponents(
:param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2)
:param broadcastThreshold: broadcast threshold in propagating component assignments
(default: 1000000)
:param useLabelsAsComponents: if True, use the original vertex labels as component
identifiers; otherwise use long values (default: False)

:return: DataFrame with new vertices column "component"
"""
return self._impl.connectedComponents(
algorithm=algorithm,
checkpointInterval=checkpointInterval,
broadcastThreshold=broadcastThreshold,
useLabelsAsComponents=useLabelsAsComponents,
)

def labelPropagation(self, maxIter: int) -> DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def is_remote() -> bool:
return False

# Take everything before the first "." so multi-digit major versions are kept whole
# (a one-character slice would turn "10.0.0" into "1").
spark_major_version = __version__.split(".")[0]
# Compare the major version numerically: a plain string comparison would order "10..." before "4".
scala_version = os.environ.get(
    "SCALA_VERSION", "2.12" if int(spark_major_version) < 4 else "2.13"
)

Expand Down
1 change: 1 addition & 0 deletions python/tests/test_graphframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def test_connected_components_friends(examples, spark):
g.connectedComponents(checkpointInterval=0),
g.connectedComponents(checkpointInterval=10),
g.connectedComponents(algorithm="graphx"),
g.connectedComponents(useLabelsAsComponents=True),
]
for c in comps_tests:
assert c.groupBy("component").count().count() == 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object GraphFramesConf {
|""".stripMargin)
.version("0.9.0")
.booleanConf
.createWithDefault(true)
.createOptional

private val CONNECTED_COMPONENTS_ALGORITHM =
SQLConf
Expand Down Expand Up @@ -104,5 +104,8 @@ object GraphFramesConf {
}
}

def getUseLabelsAsComponents: Boolean = get(USE_LABELS_AS_COMPONENTS).get.toBoolean
/**
 * Reads the `USE_LABELS_AS_COMPONENTS` setting as an optional boolean.
 *
 * @return
 *   `Some` with the value converted via `toBoolean` when the lookup yields a value, `None`
 *   when it does not
 */
def getUseLabelsAsComponents: Option[Boolean] =
  get(USE_LABELS_AS_COMPONENTS).map(_.toBoolean)
}
8 changes: 7 additions & 1 deletion src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.graphframes.WithBroadcastThreshold
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithMaxIter
import org.graphframes.WithUseLabelsAsComponents

import java.io.IOException
import java.math.BigDecimal
Expand All @@ -52,6 +53,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
with WithCheckpointInterval
with WithBroadcastThreshold
with WithIntermediateStorageLevel
with WithUseLabelsAsComponents
with WithMaxIter {

setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
Expand All @@ -61,6 +63,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
GraphFramesConf.getConnectedComponentsBroadcastThreshold.getOrElse(broadcastThreshold))
setIntermediateStorageLevel(
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
setUseLabelsAsComponents(
GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))

/**
* Runs the algorithm.
Expand All @@ -72,6 +76,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
broadcastThreshold = broadcastThreshold,
checkpointInterval = checkpointInterval,
intermediateStorageLevel = intermediateStorageLevel,
useLabelsAsComponents = useLabelsAsComponents,
maxIter = maxIter)
}
}
Expand Down Expand Up @@ -184,6 +189,7 @@ object ConnectedComponents extends Logging {
broadcastThreshold: Int,
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel,
useLabelsAsComponents: Boolean,
maxIter: Option[Int]): DataFrame = {
if (runInGraphX) {
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
Expand Down Expand Up @@ -340,7 +346,7 @@ object ConnectedComponents extends Logging {
vv(ATTR),
when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT),
col(ATTR + "." + ID).as(ID))
val output = if (graph.hasIntegralIdType || !GraphFramesConf.getUseLabelsAsComponents) {
val output = if (graph.hasIntegralIdType || !useLabelsAsComponents) {
indexedLabel
.select(col(s"$ATTR.*"), col(COMPONENT))
.persist(intermediateStorageLevel)
Expand Down
19 changes: 19 additions & 0 deletions src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,22 @@ private[graphframes] trait WithMaxIter {
this
}
}

/**
 * Mixin that adds a `useLabelsAsComponents` flag together with a chainable setter and a getter.
 */
private[graphframes] trait WithUseLabelsAsComponents {
// Mutable backing field for the flag; it starts out as false.
protected var useLabelsAsComponents: Boolean = false

/**
 * Sets whether to use vertex labels as component identifiers (default: false).
 *
 * NOTE(review): the previous wording said labels are used "instead of computing connected
 * components". The flag appears to affect only how components are identified in the
 * output, not whether they are computed -- confirm against the algorithm implementation.
 *
 * @param value
 *   the new flag value
 * @return
 *   this instance, so that setter calls can be chained
 */
def setUseLabelsAsComponents(value: Boolean): this.type = {
useLabelsAsComponents = value
this
}

/**
 * Gets whether to use vertex labels as component identifiers.
 *
 * @return
 *   the current value of the flag
 */
def getUseLabelsAsComponents: Boolean = useLabelsAsComponents
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
assert(cc.getAlgorithm === "graphframes")
assert(cc.getBroadcastThreshold === 1000000)
assert(cc.getCheckpointInterval === 2)
assert(!cc.getUseLabelsAsComponents)
}

test("empty graph") {
Expand Down