Skip to content

Commit 7281d09

Browse files
feat: add Rocha–Thatte cycle detection algorithm (#700)
* Add Rocha–Thatte cycle detection algorithm * fix 2.12 failing tests * merge main + fix docs generation * fix bad merge + remove unnecessary distinct * fix bad merge
1 parent 4e894e4 commit 7281d09

6 files changed

Lines changed: 249 additions & 14 deletions

File tree

build.sbt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ lazy val commonSetting = Seq(
121121
ScalacOptions.warnUnusedPrivates,
122122
ScalacOptions.warnNumericWiden,
123123
ScalacOptions.privateWarnNumericWiden,
124+
ScalacOptions.warnUnusedNoWarn,
125+
ScalacOptions.privateWarnUnusedNoWarn,
124126
))
125127

126128
lazy val graphx = (project in file("graphx"))

core/src/main/scala/org/graphframes/GraphFrame.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,25 @@ class GraphFrame private (
637637
}
638638
}
639639

640+
/**
641+
* Find all cycles in the graph. An implementation of the Rocha–Thatte cycle detection
642+
* algorithm.
643+
*
644+
* Rocha, Rodrigo Caetano, and Bhalchandra D. Thatte. "Distributed cycle detection in
645+
* large-scale sparse graphs." Proceedings of Simpósio Brasileiro de Pesquisa Operacional
646+
* (SBPO’15) (2015): 1-11.
647+
*
648+
* Returns a DataFrame with ID and cycles. IDs are not unique if there are multiple cycles
649+
* starting from the same ID. For the cycle 1 -> 2 -> 3 -> 1, every vertex on it reports the
650+
* same cycle rotated to start at itself: 1 -> [1, 2, 3, 1], 2 -> [2, 3, 1, 2], 3 -> [3, 1, 2, 3]
651+
*
652+
* Deduplication of cycles should be done by the user!
653+
*
654+
* @return
655+
* an instance of DetectingCycles initialized with the current context
656+
*/
657+
def detectingCycles: DetectingCycles = new DetectingCycles(this)
658+
640659
// ========= Motif finding (private) =========
641660

642661
/**

core/src/main/scala/org/graphframes/lib/AggregateMessages.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ import org.graphframes.WithIntermediateStorageLevel
4747
* - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function)
4848
*
4949
* Note: If you use this operation to write an iterative algorithm, you may want to use
50-
* [[AggregateMessages$.getCachedDataFrame getCachedDataFrame()]] as a workaround for caching
51-
* issues.
50+
* `org.apache.spark.sql.Dataset.checkpoint` or `org.apache.spark.sql.Dataset.localCheckpoint` as
51+
* a workaround for caching issues.
5252
*
5353
* @example
5454
* We can use this function to compute the in-degree of each vertex
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package org.graphframes.lib
2+
3+
import org.apache.spark.sql.DataFrame
4+
import org.apache.spark.sql.functions._
5+
import org.apache.spark.sql.types.ArrayType
6+
import org.apache.spark.storage.StorageLevel
7+
import org.graphframes.GraphFrame
8+
import org.graphframes.Logging
9+
import org.graphframes.WithCheckpointInterval
10+
import org.graphframes.WithIntermediateStorageLevel
11+
import org.graphframes.WithLocalCheckpoints
12+
13+
/**
 * Runner for cycle detection on a [[GraphFrame]].
 *
 * The heavy lifting is delegated to `DetectingCycles.run` in the companion object; this class
 * supplies the checkpointing / storage-level settings mixed in from the `With*` traits and
 * reshapes the raw output into one row per detected cycle.
 */
class DetectingCycles private[graphframes] (private val graph: GraphFrame)
    extends Arguments
    with Serializable
    with Logging
    with WithIntermediateStorageLevel
    with WithLocalCheckpoints
    with WithCheckpointInterval {
  import DetectingCycles._

  /**
   * Detects cycles and returns them in exploded form.
   *
   * @return
   *   a DataFrame persisted with `intermediateStorageLevel`, with the vertex id column and a
   *   `found_cycles` column that holds a single cycle per row; a vertex id is repeated once for
   *   every cycle reported for it, and vertices without cycles are absent
   */
  def run(): DataFrame = {
    val pregelOutput = DetectingCycles.run(
      graph,
      useLocalCheckpoints,
      checkpointInterval,
      intermediateStorageLevel)

    // Drop empty inner sequences first, then drop vertices left with no cycles at all.
    val nonEmptyCycles =
      filter(col(foundSeqCol), cycle => size(cycle) > lit(0)).alias(foundSeqCol)
    val verticesWithCycles = pregelOutput
      .select(col(GraphFrame.ID), nonEmptyCycles)
      .filter(size(col(foundSeqCol)) > lit(0))

    // from vid -> [[cycle1, cycle2, ...]]
    // to vid -> [cycle1], vid -> [cycle2], ...
    val oneCyclePerRow = verticesWithCycles
      .select(col(GraphFrame.ID), explode(col(foundSeqCol)).alias(foundSeqCol))
      .persist(intermediateStorageLevel)

    // Run an action so the persisted result is computed before its input is released.
    oneCyclePerRow.count()
    // NOTE(review): helper defined outside this file; presumably reports to the user that the
    // returned DataFrame is persisted -- confirm.
    resultIsPersistent()
    pregelOutput.unpersist()
    oneCyclePerRow
  }
}
44+
45+
object DetectingCycles {
  // Vertex column with the partial paths a vertex received in the previous iteration.
  private val storedSeqCol: String = "sequences"
  // Vertex column accumulating the closed cycles found for a vertex; also the name of the
  // cycle column in the public result.
  val foundSeqCol: String = "found_cycles"

  /**
   * Runs the message-passing cycle detection on `graph` using the Pregel API.
   *
   * Every vertex starts with the single sequence `[id]`. In each iteration a vertex sends its
   * stored sequences along its outgoing edges. On the receiving side:
   *   - sequences whose first element is the receiver's id are closed (the id is appended) and
   *     added to `found_cycles`;
   *   - sequences that already contain the receiver's id anywhere are dropped;
   *   - all remaining sequences get the receiver's id appended and are stored for the next
   *     iteration.
   *
   * @param graph
   *   input graph; only the vertex id and the edge src/dst columns are used
   * @param useLocalCheckpoints
   *   forwarded to `Pregel.setUseLocalCheckpoints`
   * @param checkpointInterval
   *   forwarded to `Pregel.setCheckpointInterval`
   * @param intermediateStorageLevel
   *   forwarded to `Pregel.setIntermediateStorageLevel`
   * @return
   *   the DataFrame produced by `Pregel.run()`; it carries the `sequences` and `found_cycles`
   *   vertex columns declared here, where `found_cycles` is an array of cycles per vertex
   */
  def run(
      graph: GraphFrame,
      useLocalCheckpoints: Boolean,
      checkpointInterval: Int,
      intermediateStorageLevel: StorageLevel): DataFrame = {
    // Only ids and edge endpoints are referenced below, so drop every other column up front.
    val preparedGraph = GraphFrame(
      graph.vertices.select(GraphFrame.ID),
      graph.edges.select(GraphFrame.SRC, GraphFrame.DST))

    val vertexDT = preparedGraph.vertices.schema(GraphFrame.ID).dataType
    // Type shared by the stored sequences, the found cycles and the messages.
    val sequencesDT = ArrayType(ArrayType(vertexDT))

    // Each vertex stores sequences from the previous iteration, initial is just Array(Array(ID))
    val initSequences = array(array(col(GraphFrame.ID)))
    // Each vertex stores all the found cycles; starts as an empty, correctly typed array
    val foundSequences = array().cast(sequencesDT)
    // Message is simply the stored sequences of the source vertex, or a typed null when the
    // source has nothing stored.
    val sentMessages = when(size(Pregel.src(storedSeqCol)) =!= lit(0), Pregel.src(storedSeqCol))
      .otherwise(lit(null).cast(sequencesDT))
    // Every message is itself an array of sequences, so the aggregate is the concatenation of
    // all incoming arrays. No filtering happens here: closed and repeated sequences are
    // handled by `updateFound` and `updateSequences` below.
    val aggregatedMessages = flatten(collect_list(Pregel.msg))
    // Update found sequences: a received sequence that starts at the current vertex ID has
    // returned to its origin, so close it by appending the ID and record it as a cycle.
    // array_union keeps the already recorded cycles and avoids adding duplicates.
    val updateFound = when(Pregel.msg.isNull, col(foundSeqCol)).otherwise(
      array_union(
        col(foundSeqCol),
        transform(
          filter(Pregel.msg, x => try_element_at(x, lit(1)) === col(GraphFrame.ID)),
          x => array_append(x, col(GraphFrame.ID)))))
    // Update stored sequences: drop every received sequence that already contains the current
    // vertex ID (it was either just recorded as a cycle above or would revisit this vertex)
    // and extend the remaining ones with the current vertex ID.
    val updateSequences = transform(
      filter(Pregel.msg, x => !array_contains(x, col(GraphFrame.ID))),
      x => array_append(x, col(GraphFrame.ID)))

    preparedGraph.pregel
      .setCheckpointInterval(checkpointInterval)
      .setUseLocalCheckpoints(useLocalCheckpoints)
      .setIntermediateStorageLevel(intermediateStorageLevel)
      .setEarlyStopping(false)
      .setSkipMessagesFromNonActiveVertices(true)
      .setInitialActiveVertexExpression(lit(true))
      .sendMsgToDst(sentMessages)
      // A vertex stays active only while it received something and still has paths to extend.
      .setUpdateActiveVertexExpression(Pregel.msg.isNotNull && (size(updateSequences) > lit(0)))
      .withVertexColumn(storedSeqCol, initSequences, updateSequences)
      .withVertexColumn(foundSeqCol, foundSequences, updateFound)
      .aggMsgs(aggregatedMessages)
      .run()
  }
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package org.graphframes.lib
2+
3+
import org.graphframes.GraphFrame
4+
import org.graphframes.GraphFrameTestSparkContext
5+
import org.graphframes.SparkFunSuite
6+
7+
import scala.annotation.nowarn
8+
import scala.collection.mutable
9+
10+
/**
 * Tests for [[DetectingCycles]]: a single 3-cycle, an acyclic chain, and several cycles
 * sharing one source vertex.
 */
class DetectingCyclesSuite extends SparkFunSuite with GraphFrameTestSparkContext {

  // All tests share the same five attributed vertices; only the edge list differs.
  private def graphWithEdges(edges: Seq[(Long, Long)]): GraphFrame = {
    val vertices = spark
      .createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e")))
      .toDF("id", "attr")
    val edgesDF = spark.createDataFrame(edges).toDF("src", "dst")
    GraphFrame(vertices, edgesDF)
  }

  // Collects the found cycles in the order given by the sort columns.
  @nowarn
  private def cyclesSortedBy(
      res: org.apache.spark.sql.DataFrame,
      firstCol: String,
      otherCols: String*) =
    res
      .sort(firstCol, otherCols: _*)
      .select(DetectingCycles.foundSeqCol)
      .collect()
      .map(_.getAs[mutable.WrappedArray[Long]](0))

  test("test detecting cycles") {
    val graph = graphWithEdges(Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 4L), (2L, 5L)))
    val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(res.count() == 3)

    val collected = cyclesSortedBy(res, GraphFrame.ID)
    // The single 3-cycle is reported once per member vertex, rotated to start at it.
    val expected = Seq(Seq(1, 2, 3, 1), Seq(2, 3, 1, 2), Seq(3, 1, 2, 3))
    expected.zipWithIndex.foreach { case (cycle, idx) =>
      assert(collected(idx) == cycle)
    }
  }

  test("test no cycles") {
    val graph = graphWithEdges(Seq((1L, 2L), (2L, 3L), (3L, 4L), (4L, 5L)))
    val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(res.count() == 0)
  }

  test("test multiple cycles from one source") {
    val graph =
      graphWithEdges(Seq((1L, 2L), (2L, 1L), (1L, 3L), (3L, 1L), (2L, 5L), (5L, 1L)))
    val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(res.count() == 7)

    val collected = cyclesSortedBy(res, GraphFrame.ID, DetectingCycles.foundSeqCol)
    val expected = Seq(
      Seq(1, 2, 1),
      Seq(1, 2, 5, 1),
      Seq(1, 3, 1),
      Seq(2, 1, 2),
      Seq(2, 5, 1, 2),
      Seq(3, 1, 3),
      Seq(5, 1, 2, 5))
    expected.zipWithIndex.foreach { case (cycle, idx) =>
      assert(collected(idx) == cycle)
    }
  }
}

docs/src/04-user-guide/05-traversals.md

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
## Shortest paths
44

5-
Computes shortest paths from each vertex to the given set of landmark vertices, where landmarks are specified by the vertex ID. Note that this takes an edge direction into account.
5+
Computes shortest paths from each vertex to the given set of landmark vertices, where landmarks are specified by the
6+
vertex ID. Note that this takes an edge direction into account.
67

78
See [Wikipedia](https://en.wikipedia.org/wiki/Shortest_path_problem) for a background.
89

@@ -34,7 +35,8 @@ results.select("id", "distances").show()
3435

3536
## Breadth-first search (BFS)
3637

37-
Breadth-first search (BFS) finds the shortest path(s) from one vertex (or a set of vertices) to another vertex (or a set of vertices). The beginning and end vertices are specified as Spark DataFrame expressions.
38+
Breadth-first search (BFS) finds the shortest path(s) from one vertex (or a set of vertices) to another vertex (or a set
39+
of vertices). The beginning and end vertices are specified as Spark DataFrame expressions.
3840

3941
See [Wikipedia on BFS](https://en.wikipedia.org/wiki/Breadth-first_search) for more background.
4042

@@ -54,16 +56,16 @@ paths.show()
5456

5557
# Specify edge filters or max path lengths
5658

57-
g.bfs("name = 'Esther'", "age < 32",\
58-
edgeFilter="relationship != 'friend'", maxPathLength=3)
59+
g.bfs("name = 'Esther'", "age < 32",
60+
edgeFilter="relationship != 'friend'", maxPathLength=3)
5961
```
6062

6163
### Scala API
6264

6365
For API details, refer to @:scaladoc(org.graphframes.lib.BFS).
6466

6567
```scala
66-
import org.graphframes.{examples,GraphFrame}
68+
import org.graphframes.{examples, GraphFrame}
6769

6870
val g: GraphFrame = examples.Graphs.friends // get example graph
6971

@@ -72,9 +74,11 @@ val paths = g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32").run()
7274
paths.show()
7375

7476
// Specify edge filters or max path lengths.
75-
val paths = { g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32")
76-
.edgeFilter("relationship != 'friend'")
77-
.maxPathLength(3).run() }
77+
val paths = {
78+
g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32")
79+
.edgeFilter("relationship != 'friend'")
80+
.maxPathLength(3).run()
81+
}
7882
paths.show()
7983
```
8084

@@ -84,7 +88,12 @@ Computes the connected component membership of each vertex and returns a graph w
8488

8589
See [Wikipedia](https://en.wikipedia.org/wiki/Connected_component_(graph_theory)) for the background.
8690

87-
**NOTE:** With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use the API `connectedComponents.setUseLocalCheckpoints(true)`. While `localCheckpoints` provides better performance they are not as reliable as the persistent checkpointing.
91+
**NOTE:** With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark
92+
checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting
93+
from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint
94+
directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use
95+
the API `connectedComponents.setUseLocalCheckpoints(true)`. While `localCheckpoints` provides better performance they
96+
are not as reliable as the persistent checkpointing.
8897

8998
### Python API
9099

@@ -106,7 +115,7 @@ result.select("id", "component").orderBy("component").show()
106115
For API details, refer to the @:scaladoc(org.graphframes.lib.ConnectedComponents).
107116

108117
```scala
109-
import org.graphframes.{examples,GraphFrame}
118+
import org.graphframes.{examples, GraphFrame}
110119

111120
val g: GraphFrame = examples.Graphs.friends // get example graph
112121

@@ -116,7 +125,8 @@ result.select("id", "component").orderBy("component").show()
116125

117126
### Strongly connected components
118127

119-
Compute the strongly connected component (SCC) of each vertex and return a graph with each vertex assigned to the SCC containing that vertex. At the moment, SCC in GraphFrames is a wrapper around GraphX implementation.
128+
Compute the strongly connected component (SCC) of each vertex and return a graph with each vertex assigned to the SCC
129+
containing that vertex. At the moment, SCC in GraphFrames is a wrapper around GraphX implementation.
120130

121131
See [Wikipedia](https://en.wikipedia.org/wiki/Strongly_connected_component) for the background.
122132

@@ -140,7 +150,7 @@ result.select("id", "component").orderBy("component").show()
140150
For API details, refer to the @:scaladoc(org.graphframes.lib.StronglyConnectedComponents).
141151

142152
```scala
143-
import org.graphframes.{examples,GraphFrame}
153+
import org.graphframes.{examples, GraphFrame}
144154

145155
val g: GraphFrame = examples.Graphs.friends // get example graph
146156

@@ -177,3 +187,40 @@ val g: GraphFrame = examples.Graphs.friends // get example graph
177187
val results = g.triangleCount.run()
178188
results.select("id", "count").show()
179189
```
190+
191+
## Cycles Detection
192+
193+
GraphFrames provides an implementation of
194+
the [Rocha–Thatte cycle detection algorithm](https://en.wikipedia.org/wiki/Rocha%E2%80%93Thatte_cycle_detection_algorithm).
195+
196+
### Scala API
197+
198+
```scala
199+
import org.graphframes.GraphFrame
200+
201+
val graph = GraphFrame(
202+
spark
203+
.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e")))
204+
.toDF("id", "attr"),
205+
spark
206+
.createDataFrame(Seq((1L, 2L), (2L, 1L), (1L, 3L), (3L, 1L), (2L, 5L), (5L, 1L)))
207+
.toDF("src", "dst"))
208+
val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
209+
res.show(false)
210+
211+
// Output:
212+
// +----+--------------+
213+
// | id | found_cycles |
214+
// +----+--------------+
215+
// |1 |[1, 3, 1] |
216+
// |1 |[1, 2, 1] |
217+
// |1 |[1, 2, 5, 1] |
218+
// |2 |[2, 1, 2] |
219+
// |2 |[2, 5, 1, 2] |
220+
// |3 |[3, 1, 3] |
221+
// |5 |[5, 1, 2, 5] |
222+
// +----+--------------+
223+
```
224+
225+
**WARNING:** This algorithm returns all the cycles, and users should handle deduplication themselves: for example,
226+
`[1, 2, 1]` and `[2, 1, 2]` are the same cycle.

0 commit comments

Comments
 (0)