diff --git a/.gitignore b/.gitignore index 4d0a174e7..8554743bd 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ project/boot/ project/plugins/project/ .bsp .metals +metals.sbt +.bloop # intellij .idea/ @@ -25,6 +27,9 @@ project/plugins/project/ # VSCode .vscode +# Helix +.helix + # Mac *.DS_Store diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 34daf0717..dd5f59647 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel +import org.graphframes.WithAlgorithmChoice /** * Connected Components algorithm. @@ -40,11 +41,13 @@ import org.apache.spark.storage.StorageLevel */ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments - with Logging { + with Logging + with WithAlgorithmChoice { import org.graphframes.lib.ConnectedComponents._ private var broadcastThreshold: Int = 1000000 + setAlgorithm(ALGO_GRAPHFRAMES) /** * Sets broadcast threshold in propagating component assignments (default: 1000000). If a node @@ -71,34 +74,6 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) */ def getBroadcastThreshold: Int = broadcastThreshold - private var algorithm: String = ALGO_GRAPHFRAMES - - /** - * Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms - * are: - * - "graphframes": Uses alternating large star and small star iterations proposed in - * [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]] - * with skewed join optimization. - * - "graphx": Converts the graph to a GraphX graph and then uses the connected components - * implementation in GraphX. - * @see - * [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]] - */ - def setAlgorithm(value: String): this.type = { - require( - supportedAlgorithms.contains(value), - s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.") - algorithm = value - this - } - - /** - * Gets the connected component algorithm to use. - * @see - * [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]. - */ - def getAlgorithm: String = algorithm - private var checkpointInterval: Int = 2 /** @@ -159,7 +134,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) def run(): DataFrame = { ConnectedComponents.run( graph, - algorithm = algorithm, + runInGraphX = algorithm == ALGO_GRAPHX, broadcastThreshold = broadcastThreshold, checkpointInterval = checkpointInterval, intermediateStorageLevel = intermediateStorageLevel) @@ -176,15 +151,6 @@ object ConnectedComponents extends Logging { private val CNT = "cnt" private val CHECKPOINT_NAME_PREFIX = "connected-components" - private val ALGO_GRAPHX = "graphx" - private val ALGO_GRAPHFRAMES = "graphframes" - - /** - * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: - * "graphframes" and "graphx". - */ - private val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) - /** * Returns the symmetric directed graph of the graph specified by input edges. * @param ee @@ -279,15 +245,11 @@ object ConnectedComponents extends Logging { private def run( graph: GraphFrame, - algorithm: String, + runInGraphX: Boolean, broadcastThreshold: Int, checkpointInterval: Int, intermediateStorageLevel: StorageLevel): DataFrame = { - require( - supportedAlgorithms.contains(algorithm), - s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.") - - if (algorithm == ALGO_GRAPHX) { + if (runInGraphX) { return runGraphX(graph) } diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index a376e2b5a..6e8358f3b 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -24,10 +24,12 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.api.java.UDF1 -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.{col, udf, map, lit, when, map_zip_with, reduce, map_values, transform_values, collect_list} import org.apache.spark.sql.types.{IntegerType, MapType} import org.graphframes.GraphFrame +import org.graphframes.Logging +import org.graphframes.WithAlgorithmChoice import org.graphframes.GraphFrame.quote /** @@ -39,7 +41,11 @@ import org.graphframes.GraphFrame.quote * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing the * shortest-path distance to each reachable landmark vertex. */ -class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments { +class ShortestPaths private[graphframes] (private val graph: GraphFrame) + extends Arguments + with WithAlgorithmChoice { + import org.graphframes.lib.ShortestPaths._ + private var lmarks: Option[Seq[Any]] = None /** @@ -59,13 +65,17 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends } def run(): DataFrame = { - ShortestPaths.run(graph, check(lmarks, "landmarks")) + val lmarksChecked = check(lmarks, "landmarks") + algorithm match { + case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) + case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked) + } } } -private object ShortestPaths { +private object ShortestPaths extends Logging { - private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { + private def runInGraphX(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { val idType = graph.vertices.schema(GraphFrame.ID).dataType val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap val gx = graphxlib.ShortestPaths @@ -95,6 +105,108 @@ private object ShortestPaths { g.vertices.select(cols.toSeq: _*) } - private val DISTANCE_ID = "distances" + private def runInGraphFrames( + graph: GraphFrame, + landmarks: Seq[Any], + isDirected: Boolean = true): DataFrame = { + logWarn("The GraphFrames based implementation is slow and considered experimental!") + val vertexType = graph.vertices.schema(GraphFrame.ID).dataType + + // For landmark vertices the initial distance to itself is set to 0 + // Example: graph with vertices a, b, c, d; landmarks = (c, d) + // we shoudl init the following: + // (a, Map()), (b, Map()), (c, Map(c -> 0)), (d, Map(d -> 0)) + // + // Inside the following function it is done by applying multiple case-when + // because we know exactly that only one landmark could be equal to the nodeId. + // For example, for vertex c it will be: + // when(id == "a", Map(a -> 0)) + // .when(id == "b", Map(b -> 0)) + // .when(id == "c", Map(c -> 0)) --> this one is the only true + // .when(id == "d", Map(d -> 0)) + def initDistancesMap(vertexId: Column): Column = { + val firstLmarkCol = lit(landmarks.head) + var initCol = when(vertexId === firstLmarkCol, map(firstLmarkCol, lit(0))) + for (lmark <- landmarks.tail) { + initCol = initCol.when(vertexId === lit(lmark), map(lit(lmark), lit(0))) + } + initCol + } + + // Concatenations of two distance maps: + // If one map is null just take another. + // In case both maps are not null: + // - iterate over keys + // - if value in the left map is null or greater than value from the right map take right one + // else take left one + def concatMaps(distancesLeft: Column, distancesRight: Column): Column = + when(distancesLeft.isNull, distancesRight) + .when(distancesRight.isNull, distancesLeft) + .otherwise(map_zip_with( + distancesLeft, + distancesRight, + (_, leftDistance, rightDistance) => { + when(leftDistance.isNull || (leftDistance > rightDistance), rightDistance) + .otherwise(leftDistance) + })) + + // If distance is null, result of d + 1 will be null too + def incrementDistances(distancesMap: Column): Column = + transform_values(distancesMap, (_, distance) => distance + lit(1)) + + // Takes an array of distance maps and reduce them with concatMaps + def aggregateArrayOfDistanceMaps(arrayCol: Column): Column = + reduce(arrayCol, lit(null).cast(MapType(vertexType, IntegerType)), concatMaps) + // Checks that a sent distances map can change the destination distances. + // Evaluation would be "true" in case in the new distances map + // for one of keys present a non-null value but in the old distances map it is null + // or new distance is less than old one. + def isDistanceImprovedWithMessage(newMap: Column, oldMap: Column): Column = reduce( + map_values( + map_zip_with( + newMap, + oldMap, + (_, newDistance, rightDistance) => + (newDistance.isNotNull && rightDistance.isNull) || (newDistance < rightDistance))), + lit(false), + (left, right) => left || right) + + val srcDistanceCol = Pregel.src(DISTANCE_ID) + val dstDistanceCol = Pregel.dst(DISTANCE_ID) + + // Overall: + // 1. Initialize distances + // 2. If new message can improve distances send it + // 3. Collect and aggregate messages + val pregel = graph.pregel + .setMaxIter(Int.MaxValue) // That is how the GraphX implementation works + .withVertexColumn( + DISTANCE_ID, + when(col(GraphFrame.ID).isInCollection(landmarks), initDistancesMap(col(GraphFrame.ID))) + .otherwise(map().cast(MapType(vertexType, IntegerType))), + concatMaps(col(DISTANCE_ID), Pregel.msg)) + .sendMsgToSrc(when( + isDistanceImprovedWithMessage(incrementDistances(dstDistanceCol), srcDistanceCol), + incrementDistances(dstDistanceCol))) + .aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg))) + .setEarlyStopping(true) + + // Experimental feature + if (isDirected) { + pregel.run() + } else { + // For consider edges as undirected, + // it is enough to send messages in both directions + pregel + .sendMsgToDst( + when( + isDistanceImprovedWithMessage(incrementDistances(srcDistanceCol), dstDistanceCol), + incrementDistances(srcDistanceCol))) + .run() + } + + } + + private val DISTANCE_ID = "distances" } diff --git a/src/main/scala/org/graphframes/mixins.scala b/src/main/scala/org/graphframes/mixins.scala new file mode 100644 index 000000000..88a348bb0 --- /dev/null +++ b/src/main/scala/org/graphframes/mixins.scala @@ -0,0 +1,18 @@ +package org.graphframes + +private[graphframes] trait WithAlgorithmChoice { + protected val ALGO_GRAPHX = "graphx" + protected val ALGO_GRAPHFRAMES = "graphframes" + protected var algorithm: String = ALGO_GRAPHX + val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) + + def setAlgorithm(value: String): this.type = { + require( + supportedAlgorithms.contains(value), + s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.") + algorithm = value + this + } + + def getAlgorithm: String = algorithm +} diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index 85f211d85..23f0f92ff 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -71,6 +71,39 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(results === expected) } + test("Simple test with GraphFrames") { + val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)) + .flatMap { case e => + Seq(e, e.swap) + } + .map { case (src, dst) => (src.toLong, dst.toLong) } + val edges = spark.createDataFrame(edgeSeq).toDF("src", "dst") + val graph = GraphFrame.fromEdges(edges) + + // Ground truth + val shortestPaths = Set( + (1, Map(1 -> 0, 4 -> 2)), + (2, Map(1 -> 1, 4 -> 2)), + (3, Map(1 -> 2, 4 -> 1)), + (4, Map(1 -> 2, 4 -> 0)), + (5, Map(1 -> 1, 4 -> 1)), + (6, Map(1 -> 3, 4 -> 1))) + + val landmarks = Seq(1, 4).map(_.toLong) + val v2 = graph.shortestPaths.landmarks(landmarks).setAlgorithm("graphframes").run() + + TestUtils.testSchemaInvariants(graph, v2) + TestUtils.checkColumnType( + v2.schema, + "distances", + DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, true)) + val newVs = v2.select("id", "distances").collect().toSeq + val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) => + (id, spMap) + } + assert(results.toSet === shortestPaths) + } + test("friends graph") { val friends = examples.Graphs.friends val v = friends.shortestPaths.landmarks(Seq("a", "d")).run() @@ -92,6 +125,26 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(results === expected) } + test("friends graph with GraphFrames") { + val friends = examples.Graphs.friends + val v = friends.shortestPaths.landmarks(Seq("a", "d")).setAlgorithm("graphframes").run() + val expected = Set[(String, Map[String, Int])]( + ("a", Map("a" -> 0, "d" -> 2)), + ("b", Map.empty), + ("c", Map.empty), + ("d", Map("a" -> 1, "d" -> 0)), + ("e", Map("a" -> 2, "d" -> 1)), + ("f", Map.empty), + ("g", Map.empty)) + val results = v + .select("id", "distances") + .collect() + .map { case Row(id: String, spMap: Map[String, Int] @unchecked) => + (id, spMap) + } + .toSet + assert(results === expected) + } test("Test vertices with column name") { val verticeSeq = Seq((1L, "one"), (2L, "two"), (3L, "three"), (4L, "four"), (5L, "five"), (6L, "six"))