Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ lazy val commonSetting = Seq(
ScalacOptions.warnUnusedPrivates,
ScalacOptions.warnNumericWiden,
ScalacOptions.privateWarnNumericWiden,
ScalacOptions.warnUnusedNoWarn,
ScalacOptions.privateWarnUnusedNoWarn,
))

lazy val graphx = (project in file("graphx"))
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,25 @@ class GraphFrame private (
}
}

/**
* Find all cycles in the graph. An implementation of the Rocha–Thatte cycle detection
* algorithm.
*
* Rocha, Rodrigo Caetano, and Bhalchandra D. Thatte. "Distributed cycle detection in
* large-scale sparse graphs." Proceedings of Simpósio Brasileiro de Pesquisa Operacional
* (SBPO’15) (2015): 1-11.
*
* This method only creates the algorithm object; calling `run()` on it produces a DataFrame
* with the vertex ID and a `found_cycles` column holding one cycle per row. IDs are not
* unique: a vertex appears once for every cycle that starts from it. Every vertex on a cycle
* reports that cycle starting from itself, so the cycle 1 -> 2 -> 3 -> 1 yields three rows.
* E.g.: 1 -> [1, 2, 3, 1] 2 -> [2, 3, 1, 2] 3 -> [3, 1, 2, 3]
*
* Deduplication of cycles must be done by the user.
*
* @return
* an instance of DetectingCycles initialized with the current context
*/
def detectingCycles: DetectingCycles = new DetectingCycles(this)

// ========= Motif finding (private) =========

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ import org.graphframes.WithIntermediateStorageLevel
* - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function)
*
* Note: If you use this operation to write an iterative algorithm, you may want to use
* [[AggregateMessages$.getCachedDataFrame getCachedDataFrame()]] as a workaround for caching
* issues.
* `org.apache.spark.sql.Dataset.checkpoint` or `org.apache.spark.sql.Dataset.localCheckpoint` as
* a workaround for caching issues.
*
* @example
* We can use this function to compute the in-degree of each vertex
Expand Down
98 changes: 98 additions & 0 deletions core/src/main/scala/org/graphframes/lib/DetectingCycles.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.graphframes.lib

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithLocalCheckpoints

/**
 * Runner for cycle detection on a [[GraphFrame]].
 *
 * The iterative search itself is delegated to `DetectingCycles.run` in the companion object;
 * this class reshapes the per-vertex result into one row per detected cycle.
 */
class DetectingCycles private[graphframes] (private val graph: GraphFrame)
    extends Arguments
    with Serializable
    with Logging
    with WithIntermediateStorageLevel
    with WithLocalCheckpoints
    with WithCheckpointInterval {
  import DetectingCycles._

  /**
   * Runs the detection and flattens the result.
   *
   * @return
   *   a DataFrame persisted with `intermediateStorageLevel`, containing the vertex ID column and
   *   a `found_cycles` column that holds exactly one cycle per row; vertices without any cycle
   *   produce no rows
   */
  def run(): DataFrame = {
    val perVertex = DetectingCycles.run(
      graph,
      useLocalCheckpoints,
      checkpointInterval,
      intermediateStorageLevel)

    val vertexId = col(GraphFrame.ID)
    // Remove empty sequences from every vertex's list of found cycles.
    val nonEmptyCycles = filter(col(foundSeqCol), x => size(x) > lit(0)).alias(foundSeqCol)
    // from vid -> [[cycle1, cycle2, ...]]
    // to vid -> [cycle1], vid -> [cycle2], ...
    val oneCyclePerRow = explode(col(foundSeqCol)).alias(foundSeqCol)

    val cycles = perVertex
      .select(vertexId, nonEmptyCycles)
      // Keep only the vertices that still have at least one cycle left.
      .filter(size(col(foundSeqCol)) > lit(0))
      .select(vertexId, oneCyclePerRow)
      .persist(intermediateStorageLevel)

    // Force materialization of the persisted result before the upstream data is released.
    cycles.count()
    // NOTE(review): defined outside this file; presumably reports that the returned
    // DataFrame is persisted and should be unpersisted by the caller - confirm.
    resultIsPersistent()
    perVertex.unpersist()
    cycles
  }
}

/**
 * Column names and the Pregel-based implementation used by the [[DetectingCycles]] runner.
 */
object DetectingCycles {
  // Vertex column holding the sequences (paths) received in the previous iteration.
  private val storedSeqCol: String = "sequences"
  // Vertex column accumulating every cycle found so far for that vertex.
  val foundSeqCol: String = "found_cycles"

  /**
   * Runs the message-passing cycle search on the given graph.
   *
   * Only the vertex ID and the edge source/destination columns are used; all other attributes
   * are dropped before the iterations start.
   *
   * @param graph
   *   graph to search for cycles
   * @param useLocalCheckpoints
   *   forwarded to the Pregel runner's `setUseLocalCheckpoints`
   * @param checkpointInterval
   *   forwarded to the Pregel runner's `setCheckpointInterval`
   * @param intermediateStorageLevel
   *   forwarded to the Pregel runner's `setIntermediateStorageLevel`
   * @return
   *   the DataFrame produced by the Pregel run; callers read the vertex ID and the
   *   `found_cycles` column (an array of cycles per vertex) from it
   */
  def run(
      graph: GraphFrame,
      useLocalCheckpoints: Boolean,
      checkpointInterval: Int,
      intermediateStorageLevel: StorageLevel): DataFrame = {
    val preparedGraph = GraphFrame(
      graph.vertices.select(GraphFrame.ID),
      graph.edges.select(GraphFrame.SRC, GraphFrame.DST))

    val vertexDT = preparedGraph.vertices.schema(GraphFrame.ID).dataType
    // Both vertex columns and the messages are arrays of vertex-ID sequences.
    val sequencesType = ArrayType(ArrayType(vertexDT))

    // Each vertex stores sequences from the previous iteration, initial is just Array(Array(ID))
    val initSequences = array(array(col(GraphFrame.ID)))
    // Each vertex stores all the found cycles
    val foundSequences = array().cast(sequencesType)
    // Message is simply stored sequences; a null message is sent when there is nothing to forward
    val sentMessages = when(size(Pregel.src(storedSeqCol)) =!= lit(0), Pregel.src(storedSeqCol))
      .otherwise(lit(null).cast(sequencesType))
    // Aggregation merges the sequence lists received from all in-neighbours into one list.
    // The per-sequence filtering happens in the vertex update expressions below.
    val aggregatedMessages = flatten(collect_list(Pregel.msg))
    // update found sequences by appending all from messages that start from the current vertex ID
    val updateFound = when(Pregel.msg.isNull, col(foundSeqCol)).otherwise(
      array_union(
        col(foundSeqCol),
        transform(
          filter(Pregel.msg, x => try_element_at(x, lit(1)) === col(GraphFrame.ID)),
          x => array_append(x, col(GraphFrame.ID)))))
    // update stored sequences: drop every sequence that already contains the current vertex ID
    // (it is either a cycle recorded above or a path that revisits this vertex) and extend the
    // remaining ones with the current vertex ID
    val updateSequences = transform(
      filter(Pregel.msg, x => !array_contains(x, col(GraphFrame.ID))),
      x => array_append(x, col(GraphFrame.ID)))

    preparedGraph.pregel
      .setCheckpointInterval(checkpointInterval)
      .setUseLocalCheckpoints(useLocalCheckpoints)
      .setIntermediateStorageLevel(intermediateStorageLevel)
      .setEarlyStopping(false)
      .setSkipMessagesFromNonActiveVertices(true)
      .setInitialActiveVertexExpression(lit(true))
      .sendMsgToDst(sentMessages)
      .setUpdateActiveVertexExpression(Pregel.msg.isNotNull && (size(updateSequences) > lit(0)))
      .withVertexColumn(storedSeqCol, initSequences, updateSequences)
      .withVertexColumn(foundSeqCol, foundSequences, updateFound)
      .aggMsgs(aggregatedMessages)
      .run()
  }
}
69 changes: 69 additions & 0 deletions core/src/test/scala/org/graphframes/lib/DetectingCyclesSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.graphframes.lib

import org.graphframes.GraphFrame
import org.graphframes.GraphFrameTestSparkContext
import org.graphframes.SparkFunSuite

import scala.annotation.nowarn
import scala.collection.mutable

/**
 * Tests for [[DetectingCycles]] on small hand-built graphs over the vertices 1 to 5.
 */
class DetectingCyclesSuite extends SparkFunSuite with GraphFrameTestSparkContext {

  /** Builds a graph with the five labelled vertices 1..5 and the given directed edges. */
  private def graphWithEdges(edges: Seq[(Long, Long)]): GraphFrame = {
    val vertexDF = spark
      .createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e")))
      .toDF("id", "attr")
    val edgeDF = spark
      .createDataFrame(edges)
      .toDF("src", "dst")
    GraphFrame(vertexDF, edgeDF)
  }

  test("test detecting cycles") {
    // Single cycle 1 -> 2 -> 3 -> 1 plus two dangling edges.
    val g = graphWithEdges(Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 4L), (2L, 5L)))
    val detected = g.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(detected.count() == 3)

    val byVertex = detected.sort(GraphFrame.ID).select(DetectingCycles.foundSeqCol)
    @nowarn val cycles = byVertex.collect().map(r => r.getAs[mutable.WrappedArray[Long]](0))

    // The same cycle is reported once per participating vertex, starting from that vertex.
    assert(cycles(0) == Seq(1, 2, 3, 1))
    assert(cycles(1) == Seq(2, 3, 1, 2))
    assert(cycles(2) == Seq(3, 1, 2, 3))
  }

  test("test no cycles") {
    // A simple chain 1 -> 2 -> 3 -> 4 -> 5 has no cycle.
    val g = graphWithEdges(Seq((1L, 2L), (2L, 3L), (3L, 4L), (4L, 5L)))
    val detected = g.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(detected.count() == 0)
  }

  test("test multiple cycles from one source") {
    // Vertex 1 takes part in three cycles: 1-2-1, 1-3-1 and 1-2-5-1.
    val g = graphWithEdges(Seq((1L, 2L), (2L, 1L), (1L, 3L), (3L, 1L), (2L, 5L), (5L, 1L)))
    val detected = g.detectingCycles.setUseLocalCheckpoints(true).run()
    assert(detected.count() == 7)

    // Sorting by the cycle as well makes the row order deterministic per vertex.
    val ordered = detected
      .sort(GraphFrame.ID, DetectingCycles.foundSeqCol)
      .select(DetectingCycles.foundSeqCol)
    @nowarn val cycles = ordered.collect().map(r => r.getAs[mutable.WrappedArray[Long]](0))

    assert(cycles(0) == Seq(1, 2, 1))
    assert(cycles(1) == Seq(1, 2, 5, 1))
    assert(cycles(2) == Seq(1, 3, 1))
    assert(cycles(3) == Seq(2, 1, 2))
    assert(cycles(4) == Seq(2, 5, 1, 2))
    assert(cycles(5) == Seq(3, 1, 3))
    assert(cycles(6) == Seq(5, 1, 2, 5))
  }
}
71 changes: 59 additions & 12 deletions docs/src/04-user-guide/05-traversals.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Shortest paths

Computes shortest paths from each vertex to the given set of landmark vertices, where landmarks are specified by the vertex ID. Note that this takes an edge direction into account.
Computes shortest paths from each vertex to the given set of landmark vertices, where landmarks are specified by
vertex ID. Note that this takes edge direction into account.

See [Wikipedia](https://en.wikipedia.org/wiki/Shortest_path_problem) for a background.

Expand Down Expand Up @@ -34,7 +35,8 @@ results.select("id", "distances").show()

## Breadth-first search (BFS)

Breadth-first search (BFS) finds the shortest path(s) from one vertex (or a set of vertices) to another vertex (or a set of vertices). The beginning and end vertices are specified as Spark DataFrame expressions.
Breadth-first search (BFS) finds the shortest path(s) from one vertex (or a set of vertices) to another vertex (or a set
of vertices). The beginning and end vertices are specified as Spark DataFrame expressions.

See [Wikipedia on BFS](https://en.wikipedia.org/wiki/Breadth-first_search) for more background.

Expand All @@ -54,16 +56,16 @@ paths.show()

# Specify edge filters or max path lengths

g.bfs("name = 'Esther'", "age < 32",\
edgeFilter="relationship != 'friend'", maxPathLength=3)
g.bfs("name = 'Esther'", "age < 32",
edgeFilter="relationship != 'friend'", maxPathLength=3)
```

### Scala API

For API details, refer to @:scaladoc(org.graphframes.lib.BFS).

```scala
import org.graphframes.{examples,GraphFrame}
import org.graphframes.{examples, GraphFrame}

val g: GraphFrame = examples.Graphs.friends // get example graph

Expand All @@ -72,9 +74,11 @@ val paths = g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32").run()
paths.show()

// Specify edge filters or max path lengths.
val paths = { g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32")
.edgeFilter("relationship != 'friend'")
.maxPathLength(3).run() }
val paths = {
g.bfs.fromExpr("name = 'Esther'").toExpr("age < 32")
.edgeFilter("relationship != 'friend'")
.maxPathLength(3).run()
}
paths.show()
```

Expand All @@ -84,7 +88,12 @@ Computes the connected component membership of each vertex and returns a graph w

See [Wikipedia](https://en.wikipedia.org/wiki/Connected_component_(graph_theory)) for the background.

**NOTE:** With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use the API `connectedComponents.setUseLocalCheckpoints(true)`. While `localCheckpoints` provides better performance they are not as reliable as the persistent checkpointing.
**NOTE:** With GraphFrames 0.3.0 and later releases, the default Connected Components algorithm requires setting a Spark
checkpoint directory. Users can revert to the old algorithm using `connectedComponents.setAlgorithm("graphx")`. Starting
from GraphFrames 0.9.3 release, users can also use `localCheckpoints` that does not require setting a Spark checkpoint
directory. To use `localCheckpoints` users can set the config `spark.graphframes.useLocalCheckpoints` to `true` or use
the API `connectedComponents.setUseLocalCheckpoints(true)`. While `localCheckpoints` provide better performance, they
are not as reliable as persistent checkpointing.

### Python API

Expand All @@ -106,7 +115,7 @@ result.select("id", "component").orderBy("component").show()
For API details, refer to the @:scaladoc(org.graphframes.lib.ConnectedComponents).

```scala
import org.graphframes.{examples,GraphFrame}
import org.graphframes.{examples, GraphFrame}

val g: GraphFrame = examples.Graphs.friends // get example graph

Expand All @@ -116,7 +125,8 @@ result.select("id", "component").orderBy("component").show()

### Strongly connected components

Compute the strongly connected component (SCC) of each vertex and return a graph with each vertex assigned to the SCC containing that vertex. At the moment, SCC in GraphFrames is a wrapper around GraphX implementation.
Compute the strongly connected component (SCC) of each vertex and return a graph with each vertex assigned to the SCC
containing that vertex. At the moment, SCC in GraphFrames is a wrapper around the GraphX implementation.

See [Wikipedia](https://en.wikipedia.org/wiki/Strongly_connected_component) for the background.

Expand All @@ -140,7 +150,7 @@ result.select("id", "component").orderBy("component").show()
For API details, refer to the @:scaladoc(org.graphframes.lib.StronglyConnectedComponents).

```scala
import org.graphframes.{examples,GraphFrame}
import org.graphframes.{examples, GraphFrame}

val g: GraphFrame = examples.Graphs.friends // get example graph

Expand Down Expand Up @@ -177,3 +187,40 @@ val g: GraphFrame = examples.Graphs.friends // get example graph
val results = g.triangleCount.run()
results.select("id", "count").show()
```

## Cycles Detection

GraphFrames provides an implementation of
the [Rocha–Thatte cycle detection algorithm](https://en.wikipedia.org/wiki/Rocha%E2%80%93Thatte_cycle_detection_algorithm).

### Scala API
Comment thread
SemyonSinchenko marked this conversation as resolved.

```scala
import org.graphframes.GraphFrame

val graph = GraphFrame(
spark
.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e")))
.toDF("id", "attr"),
spark
.createDataFrame(Seq((1L, 2L), (2L, 1L), (1L, 3L), (3L, 1L), (2L, 5L), (5L, 1L)))
.toDF("src", "dst"))
val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
res.show(false)

// Output:
// +----+--------------+
// | id | found_cycles |
// +----+--------------+
// |1 |[1, 3, 1] |
// |1 |[1, 2, 1] |
// |1 |[1, 2, 5, 1] |
// |2 |[2, 1, 2] |
// |2 |[2, 5, 1, 2] |
// |3 |[3, 1, 3] |
// |5 |[5, 1, 2, 5] |
// +----+--------------+
```

**WARNING:** This algorithm reports every cycle once per vertex on it, so the same cycle appears several times (for
example, [1, 2, 1] and [2, 1, 2] are the same cycle). Users should handle deduplication themselves.