diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 1850bba4b..68cd397c8 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -259,7 +259,10 @@ def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None): # Standard algorithms def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2, - broadcastThreshold = 1000000): + broadcastThreshold = 1000000, + intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK, + intermediateGraphxVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateGraphxEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Computes the connected components of the graph. @@ -270,6 +273,14 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2, :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) :param broadcastThreshold: broadcast threshold in propagating component assignments (default: 1000000) + :param intermediateStorageLevel: storage level for intermediate datasets that require + multiple passes (default: ``MEMORY_AND_DISK``) + :param intermediateGraphxVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). This parameter is only used when the + algorithm is set to "graphx" + :param intermediateGraphxEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). This parameter is only used when the + algorithm is set to "graphx" :return: DataFrame with new vertices column "component" """ @@ -277,23 +288,38 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2, .setAlgorithm(algorithm) \ .setCheckpointInterval(checkpointInterval) \ .setBroadcastThreshold(broadcastThreshold) \ + .setIntermediateStorageLevel(self._sc._getJavaStorageLevel(intermediateStorageLevel)) \ + .setIntermediateGraphxVertexStorageLevel(self._sc._getJavaStorageLevel(intermediateGraphxVertexStorageLevel)) \ + .setIntermediateGraphxEdgeStorageLevel(self._sc._getJavaStorageLevel(intermediateGraphxEdgeStorageLevel)) \ .run() return DataFrame(jdf, self._sqlContext) - def labelPropagation(self, maxIter): + def labelPropagation(self, maxIter, intermediateVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Runs static label propagation for detecting communities in networks. See Scala documentation for more details. :param maxIter: the number of iterations to be performed + :param intermediateVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). + :param intermediateEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). :return: DataFrame with new vertices column "label" """ - jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() + + javaVertexStorageLevel = self._sc._getJavaStorageLevel(intermediateVertexStorageLevel) + javaEdgeStorageLevel = self._sc._getJavaStorageLevel(intermediateEdgeStorageLevel) + jdf = self._jvm_graph.labelPropagation().maxIter(maxIter) \ + .setIntermediateVertexStorageLevel(javaVertexStorageLevel) \ + .setIntermediateEdgeStorageLevel(javaEdgeStorageLevel) \ + .run() return DataFrame(jdf, self._sqlContext) def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None, - tol = None): + tol = None, intermediateVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -306,8 +332,14 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None, of iterations. This may not be set if the `tol` parameter is set. :param tol: If set, the algorithm is run until the given tolerance. This may not be set if the `numIter` parameter is set. + :param intermediateVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). + :param intermediateEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). :return: GraphFrame with new vertices column "pagerank" and new edges column "weight" """ + javaVertexStorageLevel = self._sc._getJavaStorageLevel(intermediateVertexStorageLevel) + javaEdgeStorageLevel = self._sc._getJavaStorageLevel(intermediateEdgeStorageLevel) builder = self._jvm_graph.pageRank().resetProbability(resetProbability) if sourceId is not None: builder = builder.sourceId(sourceId) @@ -317,11 +349,15 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None, else: assert tol is not None, "Exactly one of maxIter or tol should be set." builder = builder.tol(tol) - jgf = builder.run() + jgf = builder.setIntermediateVertexStorageLevel(javaVertexStorageLevel) \ + .setIntermediateEdgeStorageLevel(javaEdgeStorageLevel) \ + .run() return _from_java_gf(jgf, self._sqlContext) def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None, - maxIter = None): + maxIter = None, + intermediateVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. @@ -331,6 +367,10 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None :param resetProbability: Probability of resetting to a random vertex :param sourceIds: the source vertices for a personalized PageRank :param maxIter: the fixed number of iterations this algorithm runs + :param intermediateVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). + :param intermediateEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" """ assert sourceIds is not None and len(sourceIds) > 0, "Source vertices Ids sourceIds must be provided" @@ -340,31 +380,55 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None builder = builder.resetProbability(resetProbability) builder = builder.sourceIds(sourceIds) builder = builder.maxIter(maxIter) + javaVertexStorageLevel = self._sc._getJavaStorageLevel(intermediateVertexStorageLevel) + builder = builder.setIntermediateVertexStorageLevel(javaVertexStorageLevel) + javaEdgeStorageLevel = self._sc._getJavaStorageLevel(intermediateEdgeStorageLevel) + builder = builder.setIntermediateEdgeStorageLevel(javaEdgeStorageLevel) jgf = builder.run() return _from_java_gf(jgf, self._sqlContext) - def shortestPaths(self, landmarks): + def shortestPaths(self, landmarks, intermediateVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Runs the shortest path algorithm from a set of landmark vertices in the graph. See Scala documentation for more details. :param landmarks: a set of one or more landmarks + :param intermediateVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). + :param intermediateEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). :return: DataFrame with new vertices column "distances" """ - jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() + javaVertexStorageLevel = self._sc._getJavaStorageLevel(intermediateVertexStorageLevel) + javaEdgeStorageLevel = self._sc._getJavaStorageLevel(intermediateEdgeStorageLevel) + jdf = self._jvm_graph.shortestPaths().landmarks(landmarks) \ + .setIntermediateVertexStorageLevel(javaVertexStorageLevel) \ + .setIntermediateEdgeStorageLevel(javaEdgeStorageLevel) \ + .run() return DataFrame(jdf, self._sqlContext) - def stronglyConnectedComponents(self, maxIter): + def stronglyConnectedComponents(self, maxIter, intermediateVertexStorageLevel = StorageLevel.MEMORY_ONLY, + intermediateEdgeStorageLevel = StorageLevel.MEMORY_ONLY): """ Runs the strongly connected components algorithm on this graph. See Scala documentation for more details. :param maxIter: the number of iterations to run + :param intermediateVertexStorageLevel: storage level for intermediate `Graph` vertices that + require multiple passes (default: `MEMORY_ONLY`). + :param intermediateEdgeStorageLevel: storage level for intermediate `Graph` edges that + require multiple passes (default: `MEMORY_ONLY`). :return: DataFrame with new vertex column "component" """ - jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() + javaVertexStorageLevel = self._sc._getJavaStorageLevel(intermediateVertexStorageLevel) + javaEdgeStorageLevel = self._sc._getJavaStorageLevel(intermediateEdgeStorageLevel) + jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter) \ + .setIntermediateVertexStorageLevel(javaVertexStorageLevel) \ + .setIntermediateEdgeStorageLevel(javaEdgeStorageLevel) \ + .run() return DataFrame(jdf, self._sqlContext) def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0, diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index 1c070eb23..1c49e046e 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -35,6 +35,7 @@ from .graphframe import GraphFrame, _java_api, _from_java_gf from .lib import AggregateMessages as AM from .examples import Graphs, BeliefPropagation +from pyspark.storagelevel import StorageLevel class GraphFrameTestUtils(object): @@ -273,6 +274,23 @@ def test_connected_components_friends(self): for c in comps_tests: self.assertEqual(c.groupBy("component").count().count(), 2) + def test_connected_components_intermediate_storage_level(self): + #graphx implementation + g = self._graph("friends") + expected = g.connectedComponents(algorithm="graphx").collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + out = g.connectedComponents(algorithm="graphx", intermediateGraphxVertexStorageLevel=vLabel, intermediateGraphxEdgeStorageLevel=eLabel).collect() + self.assertEqual(out, expected) + + #graphframe implementation + expected = g.connectedComponents().collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for label in levels: + out = g.connectedComponents(intermediateStorageLevel=label).collect() + self.assertEqual(out, expected) + def test_label_progagation(self): n = 5 g = self._graph("twoBlobs", n) @@ -284,6 +302,16 @@ def test_label_progagation(self): all2 = set([x.label for x in labels2]) self.assertEqual(all2, set([n])) + def test_label_progagation_intermediate_storage_level(self): + g = self._graph("friends") + expected = g.labelPropagation(maxIter=1).collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + out = g.labelPropagation(maxIter=1, intermediateVertexStorageLevel=vLabel, intermediateEdgeStorageLevel=eLabel).collect() + self.assertEqual(out, expected) + + def test_page_rank(self): n = 100 g = self._graph("star", n) @@ -292,6 +320,18 @@ def test_page_rank(self): pr = g.pageRank(resetProb, tol=errorTol) self._hasCols(pr, vcols=['id', 'pagerank'], ecols=['src', 'dst', 'weight']) + def test_page_rank_intermediate_storage_level(self): + g = self._graph("friends") + pr = g.pageRank(maxIter=1) + expected_vertex = pr.vertices.collect() + expected_edge = pr.edges.collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + pr_out = g.pageRank(maxIter=1, intermediateVertexStorageLevel=vLabel, intermediateEdgeStorageLevel=eLabel) + self.assertEqual(pr_out.vertices.collect(), expected_vertex) + self.assertEqual(pr_out.edges.collect(), expected_edge) + def test_parallel_personalized_page_rank(self): if not GraphFrameTestUtils.spark_at_least_of_version("2.1"): self.skipTest("Parallel Personalized PageRank is only available in Apache Spark 2.1+") @@ -303,6 +343,24 @@ def test_parallel_personalized_page_rank(self): pr = g.parallelPersonalizedPageRank(resetProb, sourceIds=sourceIds, maxIter=maxIter) self._hasCols(pr, vcols=['id', 'pageranks'], ecols=['src', 'dst', 'weight']) + def test_parallel_personalized_page_rank_intermediate_storage_level(self): + if GraphFrameTestUtils.spark_at_least_of_version("2.2"): + self.skipTest("in Spark 2.2, sourceIds must be smaller than Int.MaxValue \ + which might not be the case for LONG_ID in graph.indexedVertices") + if not GraphFrameTestUtils.spark_at_least_of_version("2.1"): + self.skipTest("Parallel Personalized PageRank is only available in Apache Spark 2.1+") + g = self._graph("friends") + sourceIds = ["a", "b"] + ppr = g.parallelPersonalizedPageRank(maxIter=1, sourceIds=sourceIds) + expected_vertex = ppr.vertices.collect() + expected_edge = ppr.edges.collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + ppr_out = g.parallelPersonalizedPageRank(maxIter=1, sourceIds=sourceIds, intermediateVertexStorageLevel=vLabel, intermediateEdgeStorageLevel=eLabel) + self.assertEqual(ppr_out.vertices.collect(), expected_vertex) + self.assertEqual(ppr_out.edges.collect(), expected_edge) + def test_shortest_paths(self): edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] all_edges = [z for (a, b) in edges for z in [(a, b), (b, a)]] @@ -313,6 +371,16 @@ def test_shortest_paths(self): v2 = g.shortestPaths(landmarks) self._df_hasCols(v2, vcols=["id", "distances"]) + def test_shortest_paths_intermediate_storage_level(self): + g = self._graph("friends") + landmarks = ["a", "d"] + expected = g.shortestPaths(landmarks=landmarks).collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + out = g.shortestPaths(landmarks=landmarks, intermediateVertexStorageLevel=vLabel, intermediateEdgeStorageLevel=eLabel).collect() + self.assertEqual(out, expected) + def test_svd_plus_plus(self): g = self._graph("ALSSyntheticData") (v2, cost) = g.svdPlusPlus() @@ -327,6 +395,15 @@ def test_strongly_connected_components(self): for row in c.collect(): self.assertEqual(row.id, row.component) + def test_strongly_connected_components_intermediate_storage_level(self): + g = self._graph("friends") + expected = g.stronglyConnectedComponents(maxIter=1).collect() + levels = [StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK] + for vLabel in levels: + for eLabel in levels: + out = g.stronglyConnectedComponents(maxIter=1, intermediateVertexStorageLevel=vLabel, intermediateEdgeStorageLevel=eLabel).collect() + self.assertEqual(out, expected) + def test_triangle_counts(self): edges = self.sqlContext.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) vertices = self.sqlContext.createDataFrame([(0,), (1,), (2,)], ["id"]) diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index d6c6c73d7..174ebaf5f 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -525,6 +525,30 @@ class GraphFrame private( cachedGraphX.mapVertices((_, _) => ()).mapEdges(e => ()) } + /** + * creates a new GraphX `Graph` from the cachedTopologyGraphX `Graph` with the vertex and edge storage levels passed + * in this method + * @param vertexStorageLevel vertex storage level set in the returned GraphX `Graph` instance + * @param edgeStorageLevel edge storage level set in the returned GraphX `Graph` instance + * @return a copy of the cachedTopologyGraphX `Graph` instance with the provided storage levels or the + * cachedTopologyGraphX instance if the default storage level values are provided (``MEMORY_ONLY`` in both cases) + * + */ + private[graphframes] def cachedTopologyGraphXWithStorageLevel( + vertexStorageLevel: StorageLevel, + edgeStorageLevel: StorageLevel) = + if(vertexStorageLevel == StorageLevel.MEMORY_ONLY && edgeStorageLevel == StorageLevel.MEMORY_ONLY) { + cachedTopologyGraphX + } else { + Graph( + vertices = cachedTopologyGraphX.vertices, + edges = cachedTopologyGraphX.edges, + defaultVertexAttr = null.asInstanceOf[Unit], + vertexStorageLevel = vertexStorageLevel, + edgeStorageLevel = edgeStorageLevel + ) + } + /** * A cached conversion of this graph to the GraphX structure, with the data stored for each edge and vertex. */ diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index d5d250457..6536cb004 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -147,6 +147,37 @@ class ConnectedComponents private[graphframes] ( */ def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel + private var graphxVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + private var graphxEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + * This parameter is only used when the algorithm is set to "graphx". + */ + def setIntermediateGraphxVertexStorageLevel(value: StorageLevel): this.type = { + graphxVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateGraphxVertexStorageLevel: StorageLevel = graphxVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + * This parameter is only used when the algorithm is set to "graphx". + */ + def setIntermediateGraphxEdgeStorageLevel(value: StorageLevel): this.type = { + graphxEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateGraphxEdgeStorageLevel: StorageLevel = graphxEdgeStorageLevel + /** * Runs the algorithm. */ @@ -155,7 +186,9 @@ class ConnectedComponents private[graphframes] ( algorithm = algorithm, broadcastThreshold = broadcastThreshold, checkpointInterval = checkpointInterval, - intermediateStorageLevel = intermediateStorageLevel) + intermediateStorageLevel = intermediateStorageLevel, + graphxVertexStorageLevel = graphxVertexStorageLevel, + graphxEdgeStorageLevel = graphxEdgeStorageLevel) } } @@ -263,8 +296,12 @@ object ConnectedComponents extends Logging { new ConnectedComponents(graph).run() } - private def runGraphX(graph: GraphFrame): DataFrame = { - val components = org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX) + private def runGraphX( + graph: GraphFrame, + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): DataFrame = { + val components = org.apache.spark.graphx.lib.ConnectedComponents.run( + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel)) GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices } @@ -273,12 +310,14 @@ object ConnectedComponents extends Logging { algorithm: String, broadcastThreshold: Int, checkpointInterval: Int, - intermediateStorageLevel: StorageLevel): DataFrame = { + intermediateStorageLevel: StorageLevel, + graphxVertexStorageLevel: StorageLevel, + graphxEdgeStorageLevel: StorageLevel): DataFrame = { require(supportedAlgorithms.contains(algorithm), s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.") if (algorithm == ALGO_GRAPHX) { - return runGraphX(graph) + return runGraphX(graph, graphxVertexStorageLevel, graphxEdgeStorageLevel) } val runId = UUID.randomUUID().toString.takeRight(8) diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala index a740dd5fd..dc2ed4cfe 100644 --- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -19,7 +19,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame - +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame /** @@ -40,6 +40,36 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) exte private var maxIter: Option[Int] = None + private var intermediateVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + private var intermediateEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateVertexStorageLevel(value: StorageLevel): this.type = { + intermediateVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateVertexStorageLevel: StorageLevel = intermediateVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateEdgeStorageLevel(value: StorageLevel): this.type = { + intermediateEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateEdgeStorageLevel: StorageLevel = intermediateEdgeStorageLevel + /** * The max number of iterations of LPA to be performed. Because this is a static * implementation, the algorithm will run for exactly this many iterations. @@ -52,14 +82,21 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) exte def run(): DataFrame = { LabelPropagation.run( graph, - check(maxIter, "maxIter")) + check(maxIter, "maxIter"), + intermediateVertexStorageLevel, + intermediateEdgeStorageLevel) } } private object LabelPropagation { - private def run(graph: GraphFrame, maxIter: Int): DataFrame = { - val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter) + private def run( + graph: GraphFrame, + maxIter: Int, + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): DataFrame = { + val gx = graphxlib.LabelPropagation.run( + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), maxIter) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices } diff --git a/src/main/scala/org/graphframes/lib/PageRank.scala b/src/main/scala/org/graphframes/lib/PageRank.scala index 2c9a9cf2e..b602e2715 100644 --- a/src/main/scala/org/graphframes/lib/PageRank.scala +++ b/src/main/scala/org/graphframes/lib/PageRank.scala @@ -18,7 +18,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} - +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame /** @@ -71,6 +71,34 @@ class PageRank private[graphframes] ( private var resetProb: Option[Double] = Some(0.15) private var maxIter: Option[Int] = None private var srcId : Option[Any] = None + private var intermediateVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + private var intermediateEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateVertexStorageLevel(value: StorageLevel): this.type = { + intermediateVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateVertexStorageLevel: StorageLevel = intermediateVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateEdgeStorageLevel(value: StorageLevel): this.type = { + intermediateEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateEdgeStorageLevel: StorageLevel = intermediateEdgeStorageLevel /** Source vertex for a Personalized Page Rank (optional) */ def sourceId(value: Any): this.type = { @@ -99,9 +127,9 @@ class PageRank private[graphframes] ( case Some(t) => assert(maxIter.isEmpty, "You cannot specify maxIter() and tol() at the same time.") - PageRank.runUntilConvergence(graph, t, resetProb.get, srcId) + PageRank.runUntilConvergence(graph, t, resetProb.get, srcId, intermediateVertexStorageLevel, intermediateEdgeStorageLevel) case None => - PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId) + PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId, intermediateVertexStorageLevel, intermediateEdgeStorageLevel) } } } @@ -125,10 +153,12 @@ private object PageRank { graph: GraphFrame, maxIter: Int, resetProb: Double = 0.15, - srcId: Option[Any] = None): GraphFrame = { + srcId: Option[Any] = None, + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): GraphFrame = { val longSrcId = srcId.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.PageRank.runWithOptions( - graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId) + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), maxIter, resetProb, longSrcId) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) } @@ -147,10 +177,12 @@ private object PageRank { graph: GraphFrame, tol: Double, resetProb: Double = 0.15, - srcId: Option[Any] = None): GraphFrame = { + srcId: Option[Any] = None, + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): GraphFrame = { val longSrcId = srcId.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.PageRank.runUntilConvergenceWithOptions( - graph.cachedTopologyGraphX, tol, resetProb, longSrcId) + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), tol, resetProb, longSrcId) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) } diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 33444a461..6d3a896f6 100644 --- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -18,6 +18,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} +import org.apache.spark.storage.StorageLevel import org.graphframes.{GraphFrame, Logging} /** @@ -59,6 +60,34 @@ class ParallelPersonalizedPageRank private[graphframes] ( private var resetProb: Option[Double] = Some(0.15) private var maxIter: Option[Int] = None private var srcIds: Array[Any] = Array() + private var intermediateVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + private var intermediateEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateVertexStorageLevel(value: StorageLevel): this.type = { + intermediateVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateVertexStorageLevel: StorageLevel = intermediateVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateEdgeStorageLevel(value: StorageLevel): this.type = { + intermediateEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateEdgeStorageLevel: StorageLevel = intermediateEdgeStorageLevel /** Source vertices for a Personalized Page Rank */ def sourceIds(values: Array[Any]): this.type = { @@ -81,7 +110,7 @@ class ParallelPersonalizedPageRank private[graphframes] ( def run(): GraphFrame = { require(maxIter != None, s"Max number of iterations maxIter() must be provided") require(srcIds.nonEmpty, s"Source vertices Ids sourceIds() must be provided") - ParallelPersonalizedPageRank.run(graph, maxIter.get, resetProb.get, srcIds) + ParallelPersonalizedPageRank.run(graph, maxIter.get, resetProb.get, srcIds, intermediateVertexStorageLevel, intermediateEdgeStorageLevel) } } @@ -110,10 +139,15 @@ private object ParallelPersonalizedPageRank { graph: GraphFrame, maxIter: Int, resetProb: Double, - sourceIds: Array[Any]): GraphFrame = { + sourceIds: Array[Any], + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): GraphFrame = { val longSrcIds = sourceIds.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.GraphXHelpers.runParallelPersonalizedPageRank( - graph.cachedTopologyGraphX, maxIter, resetProb, longSrcIds) + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), + maxIter, + resetProb, + longSrcIds) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANKS), edgeNames = Seq(WEIGHT)) } } diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index c5be8b81b..d05d0ff7d 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -20,13 +20,12 @@ package org.graphframes.lib import java.util import scala.collection.JavaConverters._ - import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.functions.col import org.apache.spark.sql.SQLHelpers._ import org.apache.spark.sql.types.{IntegerType, MapType} - +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame /** @@ -40,6 +39,34 @@ import org.graphframes.GraphFrame */ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments { private var lmarks: Option[Seq[Any]] = None + private var intermediateVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + private var intermediateEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateVertexStorageLevel(value: StorageLevel): this.type = { + intermediateVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateVertexStorageLevel: StorageLevel = intermediateVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateEdgeStorageLevel(value: StorageLevel): this.type = { + intermediateEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateEdgeStorageLevel: StorageLevel = intermediateEdgeStorageLevel /** * The list of landmark vertex ids. Shortest paths will be computed to each landmark. @@ -58,17 +85,21 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends } def run(): DataFrame = { - ShortestPaths.run(graph, check(lmarks, "landmarks")) + ShortestPaths.run(graph, check(lmarks, "landmarks"), intermediateVertexStorageLevel, intermediateEdgeStorageLevel) } } private object ShortestPaths { - private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { + private def run( + graph: GraphFrame, + landmarks: Seq[Any], + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): DataFrame = { val idType = graph.vertices.schema(GraphFrame.ID).dataType val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap val gx = graphxlib.ShortestPaths.run( - graph.cachedTopologyGraphX, + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), longIdToLandmark.keys.toSeq.sorted).mapVertices { case (_, m) => m.toSeq } val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID)) val distanceCol: Column = if (graph.hasIntegralIdType) { diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index 2914a287a..fdc865a84 100644 --- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -19,7 +19,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame - +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame /** @@ -33,6 +33,34 @@ class StronglyConnectedComponents private[graphframes] (private val graph: Graph extends Arguments { private var maxIter: Option[Int] = None + private var intermediateVertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + private var intermediateEdgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + /** + * Sets storage level for intermediate `Graph` vertices that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateVertexStorageLevel(value: StorageLevel): this.type = { + intermediateVertexStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` vertices that require multiple passes. + */ + def getIntermediateVertexStorageLevel: StorageLevel = intermediateVertexStorageLevel + + /** + * Sets storage level for intermediate `Graph` edges that require multiple passes (default: ``MEMORY_ONLY``). + */ + def setIntermediateEdgeStorageLevel(value: StorageLevel): this.type = { + intermediateEdgeStorageLevel = value + this + } + + /** + * Gets storage level for intermediate `Graph` edges that require multiple passes. + */ + def getIntermediateEdgeStorageLevel: StorageLevel = intermediateEdgeStorageLevel def maxIter(value: Int): this.type = { maxIter = Some(value) @@ -40,15 +68,21 @@ class StronglyConnectedComponents private[graphframes] (private val graph: Graph } def run(): DataFrame = { - StronglyConnectedComponents.run(graph, check(maxIter, "maxIter")) + StronglyConnectedComponents.run(graph, check(maxIter, "maxIter"), intermediateVertexStorageLevel, intermediateEdgeStorageLevel) } } /** Strongly connected components algorithm implementation. */ private object StronglyConnectedComponents { - private def run(graph: GraphFrame, numIter: Int): DataFrame = { - val gx = graphxlib.StronglyConnectedComponents.run(graph.cachedTopologyGraphX, numIter) + private def run( + graph: GraphFrame, + numIter: Int, + intermediateVertexStorageLevel: StorageLevel, + intermediateEdgeStorageLevel: StorageLevel): DataFrame = { + val gx = graphxlib.StronglyConnectedComponents.run( + graph.cachedTopologyGraphXWithStorageLevel(intermediateVertexStorageLevel, intermediateEdgeStorageLevel), + numIter) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(COMPONENT_ID)).vertices } diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 65534a463..b80967054 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -230,6 +230,26 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon } } + test("intermediate storage level for graphx") { + val friends = Graphs.friends + val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) + + val cc = friends.connectedComponents.setAlgorithm("graphx") + assert(cc.getIntermediateGraphxVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(cc.getIntermediateGraphxEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK) + for (vLevel <- levels; eLevel <- levels) { + val components = cc + .setIntermediateGraphxVertexStorageLevel(vLevel) + .setIntermediateGraphxEdgeStorageLevel(eLevel) + .run() + assertComponents(components, expected) + assert(cc.getIntermediateGraphxVertexStorageLevel === vLevel) + assert(cc.getIntermediateGraphxEdgeStorageLevel === eLevel) + } + } + private def assertComponents[T: ClassTag:TypeTag]( actual: DataFrame, expected: Set[Set[T]]): Unit = { diff --git a/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala b/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala index df9ba7445..818537aae 100644 --- a/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala +++ b/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala @@ -18,7 +18,7 @@ package org.graphframes.lib import org.apache.spark.sql.types.DataTypes - +import org.apache.spark.storage.StorageLevel import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} import org.graphframes.examples.Graphs @@ -39,4 +39,25 @@ class LabelPropagationSuite extends SparkFunSuite with GraphFrameTestSparkContex assert(clique2.size === 1) assert(clique1 !== clique2) } + + test("intermediate storage levels") { + val lb = Graphs.friends.labelPropagation.maxIter(2) + assert(lb.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(lb.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + val expected = lb.run().collect().map(r => r.getAs[String]("id") -> r.getAs[Long]("label")).toMap + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.NONE, StorageLevel.MEMORY_AND_DISK) + for(vLevel <- levels; eLevel <- levels) { + val output = lb + .setIntermediateVertexStorageLevel(vLevel) + .setIntermediateEdgeStorageLevel(eLevel) + .run() + .collect() + .map(r => r.getAs[String]("id") -> r.getAs[Long]("label")) + .toMap + assert(output === expected) + assert(lb.getIntermediateVertexStorageLevel === vLevel) + assert(lb.getIntermediateEdgeStorageLevel === eLevel) + } + } } diff --git a/src/test/scala/org/graphframes/lib/PageRankSuite.scala b/src/test/scala/org/graphframes/lib/PageRankSuite.scala index eca810ec7..89f65d623 100644 --- a/src/test/scala/org/graphframes/lib/PageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/PageRankSuite.scala @@ -19,7 +19,7 @@ package org.graphframes.lib import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes - +import org.apache.spark.storage.StorageLevel import org.graphframes.examples.Graphs import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} @@ -46,4 +46,32 @@ class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(gRank === 0.0, s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got $gRank.") } + + test("intermediate storage levels") { + val prIter = Graphs.friends.pageRank.maxIter(1) + val prConv = Graphs.friends.pageRank.tol(0.9) + assert(prIter.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(prIter.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + assert(prConv.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(prConv.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + + val graphIter = prIter.run() + val graphConv = prConv.run() + val expected = Seq( + (graphIter.vertices.collect(), graphIter.edges.collect()), + (graphConv.vertices.collect(), graphConv.edges.collect())) + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK) + for(vLevel <- levels; eLevel <- levels; prExpected <- Seq(prIter, prConv).zip(expected)) { + val graph = prExpected._1 + .setIntermediateVertexStorageLevel(vLevel) + .setIntermediateEdgeStorageLevel(eLevel) + .run() + + assert(graph.vertices.collect() === prExpected._2._1) + assert(graph.edges.collect() === prExpected._2._2) + assert(prExpected._1.getIntermediateVertexStorageLevel === vLevel) + assert(prExpected._1.getIntermediateEdgeStorageLevel === eLevel) + } + } } diff --git a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala index 5115612f7..ff2b3b2b6 100644 --- a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala @@ -17,11 +17,11 @@ package org.graphframes.lib -import org.apache.spark.ml.linalg.{SparseVector, SQLDataTypes} +import org.apache.spark.ml.linalg.{SQLDataTypes, SparseVector} import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.Row - +import org.apache.spark.storage.StorageLevel import org.graphframes.examples.Graphs import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} @@ -106,4 +106,32 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes } } + test("intermediate storage levels") { + val ppr = Graphs.friends.parallelPersonalizedPageRank.maxIter(1).sourceIds(Array("a", "b")) + assert(ppr.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(ppr.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + + if (isLaterVersion("2.2")) { + intercept[java.lang.IllegalArgumentException] { ppr.run() } + } else if (isLaterVersion("2.1")) { + val graph = ppr.run() + val expected = (graph.vertices.collect(), graph.edges.collect()) + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK) + for (vLevel <- levels; eLevel <- levels) { + val graph = ppr + .setIntermediateVertexStorageLevel(vLevel) + .setIntermediateEdgeStorageLevel(eLevel) + .run() + + assert(graph.vertices.collect() === expected._1) + assert(graph.edges.collect() === expected._2) + assert(ppr.getIntermediateVertexStorageLevel === vLevel) + assert(ppr.getIntermediateEdgeStorageLevel === eLevel) + } + } else { + intercept[NotImplementedError] { ppr.run() } + } + } + } diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index debdfcf92..5d437fab9 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -19,8 +19,9 @@ package org.graphframes.lib import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes - +import org.apache.spark.storage.StorageLevel import org.graphframes._ +import org.graphframes.examples.Graphs class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -62,4 +63,24 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(results === expected) } + test("intermediate storage levels") { + val sp = Graphs.friends.shortestPaths.landmarks(Seq("a", "d")) + assert(sp.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(sp.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + + val expected = sp.run().collect() + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK) + for (vLevel <- levels; eLevel <- levels) { + val output = sp + .setIntermediateVertexStorageLevel(vLevel) + .setIntermediateEdgeStorageLevel(eLevel) + .run() + + assert(output.collect() === expected) + assert(sp.getIntermediateVertexStorageLevel === vLevel) + assert(sp.getIntermediateEdgeStorageLevel === eLevel) + } + } + } diff --git a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala index 36e318a64..8932f35e1 100644 --- a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala @@ -19,8 +19,9 @@ package org.graphframes.lib import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes - -import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite, TestUtils} +import org.apache.spark.storage.StorageLevel +import org.graphframes.examples.Graphs +import org.graphframes.{GraphFrame, GraphFrameTestSparkContext, SparkFunSuite, TestUtils} class StronglyConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Island Strongly Connected Components") { @@ -40,4 +41,24 @@ class StronglyConnectedComponentsSuite extends SparkFunSuite with GraphFrameTest assert(id === component) } } + + test("intermediate storage levels") { + val scc = Graphs.friends.stronglyConnectedComponents.maxIter(1) + assert(scc.getIntermediateVertexStorageLevel === StorageLevel.MEMORY_ONLY) + assert(scc.getIntermediateEdgeStorageLevel === StorageLevel.MEMORY_ONLY) + + val expected = scc.run().collect() + + val levels = Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_AND_DISK) + for (vLevel <- levels; eLevel <- levels) { + val output = scc + .setIntermediateVertexStorageLevel(vLevel) + .setIntermediateEdgeStorageLevel(eLevel) + .run() + + assert(output.collect() === expected) + assert(scc.getIntermediateVertexStorageLevel === vLevel) + assert(scc.getIntermediateEdgeStorageLevel === eLevel) + } + } }