From 29d47419c4f93996b8892f12a639202f7f08f581 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 27 Feb 2025 11:52:55 +0100 Subject: [PATCH 1/6] PowerIterationClustering wrapper --- python/graphframes/graphframe.py | 164 ++++++++++----- python/graphframes/tests.py | 189 +++++++++++------- .../scala/org/graphframes/GraphFrame.scala | 75 ++++--- .../org/graphframes/GraphFrameSuite.scala | 45 ++++- 4 files changed, 317 insertions(+), 156 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 5381ec8b5..cb683ae05 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -18,7 +18,7 @@ import sys from typing import Any, Union, Optional -if sys.version > '3': +if sys.version > "3": basestring = str from graphframes.lib import Pregel @@ -27,7 +27,7 @@ from pyspark.storagelevel import StorageLevel -def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': +def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": """ (internal) creates a python GraphFrame wrapper from a java GraphFrame. @@ -37,10 +37,15 @@ def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': pe = DataFrame(jgf.edges(), spark) return GraphFrame(pv, pe) + def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" - return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ - .newInstance() + return ( + jsc._jvm.Thread.currentThread() + .getContextClassLoader() + .loadClass(javaClassName) + .newInstance() + ) class GraphFrame: @@ -76,16 +81,22 @@ def __init__(self, v: DataFrame, e: DataFrame) -> None: # Check that provided DataFrames contain required columns if self.ID not in v.columns: raise ValueError( - "Vertex ID column {} missing from vertex DataFrame, which has columns: {}" - .format(self.ID, ",".join(v.columns))) + "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( + self.ID, ",".join(v.columns) + ) + ) if self.SRC not in e.columns: raise ValueError( - "Source vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.SRC, ",".join(e.columns))) + "Source vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.SRC, ",".join(e.columns) + ) + ) if self.DST not in e.columns: raise ValueError( - "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.DST, ",".join(e.columns))) + "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.DST, ",".join(e.columns) + ) + ) self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) @@ -109,8 +120,8 @@ def edges(self) -> DataFrame: def __repr__(self): return self._jvm_graph.toString() - def cache(self) -> 'GraphFrame': - """ Persist the dataframe representation of vertices and edges of the graph with the default + def cache(self) -> "GraphFrame": + """Persist the dataframe representation of vertices and edges of the graph with the default storage level. """ self._jvm_graph.cache() @@ -124,7 +135,7 @@ def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "Gra self._jvm_graph.persist(javaStorageLevel) return self - def unpersist(self, blocking: bool = False) -> 'GraphFrame': + def unpersist(self, blocking: bool = False) -> "GraphFrame": """Mark the dataframe representation of vertices and edges of the graph as non-persistent, and remove all blocks for it from memory and disk. """ @@ -209,12 +220,12 @@ def find(self, pattern: str) -> DataFrame: jdf = self._jvm_graph.find(pattern) return DataFrame(jdf, self._spark) - def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterVertices(self, condition: Union[str, Column]) -> "GraphFrame": """ Filters the vertices based on expression, remove edges containing any dropped vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered vertices and edges. + :return: GraphFrame with filtered vertices and edges. """ if isinstance(condition, basestring): @@ -225,12 +236,12 @@ def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterEdges(self, condition: Union[str, Column]) -> "GraphFrame": """ Filters the edges based on expression, keep all vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered edges. + :return: GraphFrame with filtered edges. """ if isinstance(condition, basestring): jdf = self._jvm_graph.filterEdges(condition) @@ -240,18 +251,18 @@ def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def dropIsolatedVertices(self) -> 'GraphFrame': + def dropIsolatedVertices(self) -> "GraphFrame": """ Drops isolated vertices, vertices are not contained in any edges. - :return: GraphFrame with filtered vertices. + :return: GraphFrame with filtered vertices. """ jdf = self._jvm_graph.dropIsolatedVertices() return _from_java_gf(jdf, self._spark) - def bfs(self, fromExpr: str, toExpr: str, - edgeFilter: Optional[str] = None, - maxPathLength: int = 10) -> DataFrame: + def bfs( + self, fromExpr: str, toExpr: str, edgeFilter: Optional[str] = None, maxPathLength: int = 10 + ) -> DataFrame: """ Breadth-first search (BFS). @@ -259,18 +270,20 @@ def bfs(self, fromExpr: str, toExpr: str, :return: DataFrame with one Row for each shortest path between matching vertices. """ - builder = self._jvm_graph.bfs()\ - .fromExpr(fromExpr)\ - .toExpr(toExpr)\ - .maxPathLength(maxPathLength) + builder = ( + self._jvm_graph.bfs().fromExpr(fromExpr).toExpr(toExpr).maxPathLength(maxPathLength) + ) if edgeFilter is not None: builder.edgeFilter(edgeFilter) jdf = builder.run() return DataFrame(jdf, self._spark) - def aggregateMessages(self, aggCol: Union[Column, str], - sendToSrc: Union[Column, str, None] = None, - sendToDst: Union[Column, str, None] = None) -> DataFrame: + def aggregateMessages( + self, + aggCol: Union[Column, str], + sendToSrc: Union[Column, str, None] = None, + sendToDst: Union[Column, str, None] = None, + ) -> DataFrame: """ Aggregates messages from the neighbours. @@ -314,9 +327,12 @@ def aggregateMessages(self, aggCol: Union[Column, str], # Standard algorithms - def connectedComponents(self, algorithm: str = 'graphframes', - checkpointInterval: int = 2, - broadcastThreshold: int = 1000000) -> DataFrame: + def connectedComponents( + self, + algorithm: str = "graphframes", + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000, + ) -> DataFrame: """ Computes the connected components of the graph. @@ -330,11 +346,13 @@ def connectedComponents(self, algorithm: str = 'graphframes', :return: DataFrame with new vertices column "component" """ - jdf = self._jvm_graph.connectedComponents() \ - .setAlgorithm(algorithm) \ - .setCheckpointInterval(checkpointInterval) \ - .setBroadcastThreshold(broadcastThreshold) \ + jdf = ( + self._jvm_graph.connectedComponents() + .setAlgorithm(algorithm) + .setCheckpointInterval(checkpointInterval) + .setBroadcastThreshold(broadcastThreshold) .run() + ) return DataFrame(jdf, self._spark) def labelPropagation(self, maxIter: int) -> DataFrame: @@ -349,10 +367,13 @@ def labelPropagation(self, maxIter: int) -> DataFrame: jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def pageRank(self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None) -> 'GraphFrame': + def pageRank( + self, + resetProbability: float = 0.15, + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None, + ) -> "GraphFrame": """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -379,9 +400,12 @@ def pageRank(self, resetProbability: float = 0.15, jgf = builder.run() return _from_java_gf(jgf, self._spark) - def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None) -> 'GraphFrame': + def parallelPersonalizedPageRank( + self, + resetProbability: float = 0.15, + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None, + ) -> "GraphFrame": """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. @@ -393,7 +417,9 @@ def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, :param maxIter: the fixed number of iterations this algorithm runs :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" """ - assert sourceIds is not None and len(sourceIds) > 0, "Source vertices Ids sourceIds must be provided" + assert ( + sourceIds is not None and len(sourceIds) > 0 + ), "Source vertices Ids sourceIds must be provided" assert maxIter is not None, "Max number of iterations maxIter must be provided" sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) builder = self._jvm_graph.parallelPersonalizedPageRank() @@ -427,10 +453,17 @@ def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, - minValue: float = 0.0, maxValue: float = 5.0, - gamma1: float = 0.007, gamma2: float = 0.007, - gamma6: float = 0.005, gamma7: float = 0.015) -> tuple[DataFrame, float]: + def svdPlusPlus( + self, + rank: int = 10, + maxIter: int = 2, + minValue: float = 0.0, + maxValue: float = 5.0, + gamma1: float = 0.007, + gamma2: float = 0.007, + gamma6: float = 0.005, + gamma7: float = 0.015, + ) -> tuple[DataFrame, float]: """ Runs the SVD++ algorithm. @@ -458,16 +491,39 @@ def triangleCount(self) -> DataFrame: jdf = self._jvm_graph.triangleCount().run() return DataFrame(jdf, self._spark) + def powerIterationClustering( + self, k: int, maxIter: int, weightCol: Optional[str] = None + ) -> DataFrame: + """ + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. + From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration + on a normalized pair-wise similarity matrix of the data. + + :param k: the numbers of clusters to create + :param maxIter: param for maximum number of iterations (>= 0) + :param weightCol: optional name of weight column, 1.0 is used if not provided + + :return: DataFrame with new column "cluster" + """ + if weightCol: + weightCol = self._spark._jvm.scala.Option.apply(weightCol) + else: + weightCol = self._spark._jvm.scala.Option.empty() + jdf = self._jvm_graph.powerIterationClustering(k, maxIter, weightCol) + return DataFrame(jdf, self._spark) + def _test(): import doctest import graphframe + globs = graphframe.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate() + globs["sc"] = SparkContext("local[4]", "PythonTest", batchSize=2) + globs["spark"] = SparkSession(globs["sc"]).builder.getOrCreate() (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE + ) + globs["sc"].stop() if failure_count: exit(-1) diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index 9a7ad1371..c4d70d32d 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -24,7 +24,7 @@ try: import unittest2 as unittest except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.stderr.write("Please install unittest2 to test with Python 2.6 or earlier") sys.exit(1) else: import unittest @@ -36,39 +36,42 @@ from .lib import AggregateMessages as AM from .examples import Graphs, BeliefPropagation + class GraphFrameTestUtils(object): @classmethod def parse_spark_version(cls, version_str): - """ take an input version string - return version items in a dictionary + """take an input version string + return version items in a dictionary """ - _sc_ver_patt = r'(\d+)\.(\d+)(\.(\d+)(-(.+))?)?' + _sc_ver_patt = r"(\d+)\.(\d+)(\.(\d+)(-(.+))?)?" m = re.match(_sc_ver_patt, version_str) if not m: - raise TypeError("version {} shoud be in ..".format(version_str)) + raise TypeError( + "version {} shoud be in ..".format(version_str) + ) version_info = {} try: - version_info['major'] = int(m.group(1)) + version_info["major"] = int(m.group(1)) except: raise TypeError("invalid minor version") try: - version_info['minor'] = int(m.group(2)) + version_info["minor"] = int(m.group(2)) except: raise TypeError("invalid major version") try: - version_info['maintenance'] = int(m.group(4)) + version_info["maintenance"] = int(m.group(4)) except: - version_info['maintenance'] = 0 + version_info["maintenance"] = 0 try: - version_info['special'] = m.group(6) + version_info["special"] = m.group(6) except: pass return version_info @classmethod def createSparkContext(cls): - cls.sc = sc = SparkContext('local[4]', "GraphFramesTests") + cls.sc = sc = SparkContext("local[4]", "GraphFramesTests") cls.checkpointDir = tempfile.mkdtemp() cls.sc.setCheckpointDir(cls.checkpointDir) cls.spark_version = cls.parse_spark_version(sc.version) @@ -81,10 +84,10 @@ def stopSparkContext(cls): @classmethod def spark_at_least_of_version(cls, version_str): - assert hasattr(cls, 'spark_version') + assert hasattr(cls, "spark_version") required_version = cls.parse_spark_version(version_str) spark_version = cls.spark_version - for _name in ['major', 'minor', 'maintenance']: + for _name in ["major", "minor", "maintenance"]: sc_ver = spark_version[_name] req_ver = required_version[_name] if sc_ver != req_ver: @@ -92,9 +95,11 @@ def spark_at_least_of_version(cls, version_str): # All major.minor.maintenance equal return True + def setUpModule(): GraphFrameTestUtils.createSparkContext() + def tearDownModule(): GraphFrameTestUtils.stopSparkContext() @@ -104,7 +109,11 @@ class GraphFrameTestCase(unittest.TestCase): @classmethod def setUpClass(cls): # Small tests run much faster with spark.sql.shuffle.partitions = 4 - cls.spark = SparkSession(GraphFrameTestUtils.sc).builder.config('spark.sql.shuffle.partitions', 4).getOrCreate() + cls.spark = ( + SparkSession(GraphFrameTestUtils.sc) + .builder.config("spark.sql.shuffle.partitions", 4) + .getOrCreate() + ) @classmethod def tearDownClass(cls): @@ -136,14 +145,20 @@ def test_construction(self): assert sorted(vertexIDs) == [1, 2, 3] edgeActions = map(lambda x: x[0], g.edges.select("action").collect()) assert sorted(edgeActions) == ["follow", "hate", "love"] - tripletsFirst = list(map(lambda x: (x[0][1], x[1][1], x[2][2]), - g.triplets.sort("src.id").select("src", "dst", "edge").take(1))) + tripletsFirst = list( + map( + lambda x: (x[0][1], x[1][1], x[2][2]), + g.triplets.sort("src.id").select("src", "dst", "edge").take(1), + ) + ) assert tripletsFirst == [("A", "B", "love")], tripletsFirst # Try with invalid vertices and edges DataFrames v_invalid = self.spark.createDataFrame( - [(1, "A"), (2, "B"), (3, "C")], ["invalid_colname_1", "invalid_colname_2"]) + [(1, "A"), (2, "B"), (3, "C")], ["invalid_colname_1", "invalid_colname_2"] + ) e_invalid = self.spark.createDataFrame( - [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"]) + [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"] + ) with self.assertRaises(ValueError): GraphFrame(v_invalid, e_invalid) @@ -217,6 +232,37 @@ def test_bfs(self): paths3 = g.bfs("name='A'", "name='C'", maxPathLength=1) self.assertEqual(paths3.count(), 0) + def test_power_iteration_clustering(self): + vertices = [ + (1, 0, 0.5), + (2, 0, 0.5), + (2, 1, 0.7), + (3, 0, 0.5), + (3, 1, 0.7), + (3, 2, 0.9), + (4, 0, 0.5), + (4, 1, 0.7), + (4, 2, 0.9), + (4, 3, 1.1), + (5, 0, 0.5), + (5, 1, 0.7), + (5, 2, 0.9), + (5, 3, 1.1), + (5, 4, 1.3), + ] + edges = [(0,), (1,), (2,), (3,), (4,), (5,)] + g = GraphFrame( + v=self.spark.createDataFrame(vertices).toDF("src", "dst", "weight"), + e=self.spark.createDataFrame(edges).toDF("id"), + ) + clusters = [ + r["cluster"] + for r in g.powerIterationClustering(k=2, maxIter=40, weightCol="weight") + .sort("id") + .collect() + ] + self.assertEqual(clusters, [0, 0, 0, 0, 1]) + class PregelTest(GraphFrameTestCase): def setUp(self): @@ -224,13 +270,11 @@ def setUp(self): def test_page_rank(self): from pyspark.sql.functions import coalesce, col, lit, sum, when - edges = self.spark.createDataFrame([[0, 1], - [1, 2], - [2, 4], - [2, 0], - [3, 4], # 3 has no in-links - [4, 0], - [4, 2]], ["src", "dst"]) + + edges = self.spark.createDataFrame( + [[0, 1], [1, 2], [2, 4], [2, 0], [3, 4], [4, 0], [4, 2]], # 3 has no in-links + ["src", "dst"], + ) edges.cache() vertices = self.spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) numVertices = vertices.count() @@ -238,19 +282,22 @@ def test_page_rank(self): vertices.cache() graph = GraphFrame(vertices, edges) alpha = 0.15 - ranks = graph.pregel \ - .setMaxIter(5) \ - .withVertexColumn("rank", lit(1.0 / numVertices), - coalesce(Pregel.msg(), - lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices)) \ - .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) \ - .aggMsgs(sum(Pregel.msg())) \ + ranks = ( + graph.pregel.setMaxIter(5) + .withVertexColumn( + "rank", + lit(1.0 / numVertices), + coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices), + ) + .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) + .aggMsgs(sum(Pregel.msg())) .run() + ) resultRows = ranks.sort(ranks.id).collect() result = map(lambda x: x.rank, resultRows) expected = [0.245, 0.224, 0.303, 0.03, 0.197] for a, b in zip(result, expected): - self.assertAlmostEqual(a, b, delta = 1e-3) + self.assertAlmostEqual(a, b, delta=1e-3) class GraphFrameLibTest(GraphFrameTestCase): @@ -258,11 +305,11 @@ def setUp(self): super(GraphFrameLibTest, self).setUp() self.japi = _java_api(self.spark._sc) - def _hasCols(self, graph, vcols = [], ecols = []): + def _hasCols(self, graph, vcols=[], ecols=[]): map(lambda c: self.assertIn(c, graph.vertices.columns), vcols) map(lambda c: self.assertIn(c, graph.edges.columns), ecols) - def _df_hasCols(self, vertices, vcols = []): + def _df_hasCols(self, vertices, vcols=[]): map(lambda c: self.assertIn(c, vertices.columns), vcols) def _graph(self, name, *args): @@ -281,30 +328,27 @@ def test_aggregate_messages(self): g = self._graph("friends") # For each user, sum the ages of the adjacent users, # plus 1 for the src's sum if the edge is "friend". - sendToSrc = ( - AM.dst['age'] + - sqlfunctions.when( - AM.edge['relationship'] == 'friend', - sqlfunctions.lit(1) - ).otherwise(0)) - sendToDst = AM.src['age'] + sendToSrc = AM.dst["age"] + sqlfunctions.when( + AM.edge["relationship"] == "friend", sqlfunctions.lit(1) + ).otherwise(0) + sendToDst = AM.src["age"] agg = g.aggregateMessages( - sqlfunctions.sum(AM.msg).alias('summedAges'), - sendToSrc=sendToSrc, - sendToDst=sendToDst) + sqlfunctions.sum(AM.msg).alias("summedAges"), sendToSrc=sendToSrc, sendToDst=sendToDst + ) # Run the aggregation again providing SQL expressions as String instead. agg2 = g.aggregateMessages( "sum(MSG) AS `summedAges`", sendToSrc="(dst['age'] + CASE WHEN (edge['relationship'] = 'friend') THEN 1 ELSE 0 END)", - sendToDst="src['age']") + sendToDst="src['age']", + ) # Convert agg and agg2 to a mapping from id to the aggregated message. - aggMap = {id_: s for id_, s in agg.select('id', 'summedAges').collect()} - agg2Map = {id_: s for id_, s in agg2.select('id', 'summedAges').collect()} + aggMap = {id_: s for id_, s in agg.select("id", "summedAges").collect()} + agg2Map = {id_: s for id_, s in agg2.select("id", "summedAges").collect()} # Compute the truth via brute force. - user2age = {id_: age for id_, age in g.vertices.select('id', 'age').collect()} + user2age = {id_: age for id_, age in g.vertices.select("id", "age").collect()} trueAgg = {} for src, dst, rel in g.edges.select("src", "dst", "relationship").collect(): - trueAgg[src] = trueAgg.get(src, 0) + user2age[dst] + (1 if rel == 'friend' else 0) + trueAgg[src] = trueAgg.get(src, 0) + user2age[dst] + (1 if rel == "friend" else 0) trueAgg[dst] = trueAgg.get(dst, 0) + user2age[src] # Compare if the agg mappings match the brute force mapping self.assertEqual(aggMap, trueAgg) @@ -312,22 +356,19 @@ def test_aggregate_messages(self): # Check that TypeError is raises with messages of wrong type with self.assertRaises(TypeError): g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc=object(), - sendToDst="src['age']") + "sum(MSG) AS `summedAges`", sendToSrc=object(), sendToDst="src['age']" + ) with self.assertRaises(TypeError): g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc=dst['age'], - sendToDst=object()) + "sum(MSG) AS `summedAges`", sendToSrc=dst["age"], sendToDst=object() + ) def test_connected_components(self): - v = self.spark.createDataFrame([ - (0, "a", "b")], ["id", "vattr", "gender"]) + v = self.spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) e = self.spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter("src > 10") g = GraphFrame(v, e) comps = g.connectedComponents() - self._df_hasCols(comps, vcols=['id', 'component', 'vattr', 'gender']) + self._df_hasCols(comps, vcols=["id", "component", "vattr", "gender"]) self.assertEqual(comps.count(), 1) def test_connected_components2(self): @@ -335,7 +376,7 @@ def test_connected_components2(self): e = self.spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) g = GraphFrame(v, e) comps = g.connectedComponents() - self._df_hasCols(comps, vcols=['id', 'component', 'A', 'B']) + self._df_hasCols(comps, vcols=["id", "component", "A", "B"]) self.assertEqual(comps.count(), 2) def test_connected_components_friends(self): @@ -367,7 +408,7 @@ def test_page_rank(self): resetProb = 0.15 errorTol = 1.0e-5 pr = g.pageRank(resetProb, tol=errorTol) - self._hasCols(pr, vcols=['id', 'pagerank'], ecols=['src', 'dst', 'weight']) + self._hasCols(pr, vcols=["id", "pagerank"], ecols=["src", "dst", "weight"]) def test_parallel_personalized_page_rank(self): n = 100 @@ -376,7 +417,7 @@ def test_parallel_personalized_page_rank(self): maxIter = 15 sourceIds = [1, 2, 3, 4] pr = g.parallelPersonalizedPageRank(resetProb, sourceIds=sourceIds, maxIter=maxIter) - self._hasCols(pr, vcols=['id', 'pageranks'], ecols=['src', 'dst', 'weight']) + self._hasCols(pr, vcols=["id", "pageranks"], ecols=["src", "dst", "weight"]) def test_shortest_paths(self): edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] @@ -391,7 +432,7 @@ def test_shortest_paths(self): def test_svd_plus_plus(self): g = self._graph("ALSSyntheticData") (v2, cost) = g.svdPlusPlus() - self._df_hasCols(v2, vcols=['id', 'column1', 'column2', 'column3', 'column4']) + self._df_hasCols(v2, vcols=["id", "column1", "column2", "column3", "column4"]) def test_strongly_connected_components(self): # Simple island test @@ -408,25 +449,26 @@ def test_triangle_counts(self): g = GraphFrame(vertices, edges) c = g.triangleCount() for row in c.select("id", "count").collect(): - self.assertEqual(row.asDict()['count'], 1) - + self.assertEqual(row.asDict()["count"], 1) + def test_mutithreaded_sparksession_usage(self): # Test that we can use the GraphFrame API from multiple threads localVertices = [(1, "A"), (2, "B"), (3, "C")] localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] v = self.spark.createDataFrame(localVertices, ["id", "name"]) e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) - - + exc = None + def run_graphframe() -> None: try: GraphFrame(v, e) except Exception as _e: nonlocal exc exc = _e - + import threading + thread = threading.Thread(target=run_graphframe) thread.start() thread.join() @@ -445,11 +487,12 @@ def test_belief_propagation(self): numIter = 5 results = BeliefPropagation.runBPwithGraphFrames(g, numIter) # check beliefs are valid - for row in results.vertices.select('belief').collect(): - belief = row['belief'] + for row in results.vertices.select("belief").collect(): + belief = row["belief"] self.assertTrue( 0 <= belief <= 1, - msg="Expected belief to be probability in [0,1], but found {}".format(belief)) + msg="Expected belief to be probability in [0,1], but found {}".format(belief), + ) def test_graph_friends(self): # construct graph @@ -462,7 +505,7 @@ def test_graph_grid_ising_model(self): n = 3 g = Graphs(self.spark).gridIsingModel(n) # check that all the vertices exist - ids = [v['id'] for v in g.vertices.collect()] + ids = [v["id"] for v in g.vertices.collect()] for i in range(n): for j in range(n): - self.assertIn('{},{}'.format(i, j), ids) + self.assertIn("{},{}".format(i, j), ids) diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index 01b829065..fac754300 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -21,15 +21,16 @@ import java.util.Random import scala.reflect.runtime.universe.TypeTag +import org.graphframes.lib._ +import org.graphframes.pattern._ + import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.ml.clustering.PowerIterationClustering import org.apache.spark.sql._ -import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, struct, udf, monotonically_increasing_id, expr} +import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, expr, lit, max, monotonically_increasing_id, struct, udf} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.graphframes.lib._ -import org.graphframes.pattern._ - /** * A representation of a graph using `DataFrame`s. * @@ -246,8 +247,8 @@ class GraphFrame private ( /** * The out-degree of each vertex in the graph, returned as a DataFrame with two columns: * - [[GraphFrame.ID]] the ID of the vertex - * - "outDegree" (integer) storing the out-degree of the vertex - * Note that vertices with 0 out-edges are not returned in the result. + * - "outDegree" (integer) storing the out-degree of the vertex Note that vertices with 0 + * out-edges are not returned in the result. * * @group degree */ @@ -257,9 +258,8 @@ class GraphFrame private ( /** * The in-degree of each vertex in the graph, returned as a DataFame with two columns: - * - [[GraphFrame.ID]] the ID of the vertex - * "- "inDegree" (int) storing the in-degree of the vertex Note that vertices with 0 in-edges - * are not returned in the result. + * - [[GraphFrame.ID]] the ID of the vertex "- "inDegree" (int) storing the in-degree of the + * vertex Note that vertices with 0 in-edges are not returned in the result. * * @group degree */ @@ -270,8 +270,8 @@ class GraphFrame private ( /** * The degree of each vertex in the graph, returned as a DataFrame with two columns: * - [[GraphFrame.ID]] the ID of the vertex - * - 'degree' (integer) the degree of the vertex - * Note that vertices with 0 edges are not returned in the result. + * - 'degree' (integer) the degree of the vertex Note that vertices with 0 edges are not + * returned in the result. * * @group degree */ @@ -302,9 +302,9 @@ class GraphFrame private ( * - Within a pattern, names can be assigned to vertices and edges. For example, * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`. These names serve * two purposes: - * - The names can identify common elements among edges. For example, - * `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination - * of edge `e` and source of edge `e2`. + * - The names can identify common elements among edges. For example, `"(a)-[e]->(b); + * (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination of edge `e` and + * source of edge `e2`. * - The names are used as column names in the result `DataFrame`. If a motif contains named * vertex `a`, then the result `DataFrame` will contain a column "a" which is a * `StructType` with sub-fields equivalent to the schema (columns) of @@ -312,10 +312,10 @@ class GraphFrame private ( * the result `DataFrame` with sub-fields equivalent to the schema (columns) of * [[GraphFrame.edges]]. * - Be aware that names do *not* identify *distinct* elements: two elements with different - * names may refer to the same graph element. For example, in the motif - * `"(a)-[e]->(b); (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex. - * To restrict named elements to be distinct vertices or edges, use post-hoc filters such - * as `resultDataframe.filter("a.id != c.id")`. + * names may refer to the same graph element. For example, in the motif `"(a)-[e]->(b); + * (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex. To restrict + * named elements to be distinct vertices or edges, use post-hoc filters such as + * `resultDataframe.filter("a.id != c.id")`. * - It is acceptable to omit names for vertices or edges in motifs when not needed. E.g., * `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name to * the edge. There will be no column for the anonymous edge in the result `DataFrame`. @@ -509,6 +509,32 @@ class GraphFrame private ( */ def triangleCount: TriangleCount = new TriangleCount(this) + /** + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and + * Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using + * truncated power iteration on a normalized pair-wise similarity matrix of the data. + * + * PowerIterationClustering algorithm. + * @param k + * The number of clusters to create (k). + * @param maxIter + * Param for maximum number of iterations (>= 0). + * @param weightCol + * Param for weight column name. + * @return + */ + def powerIterationClustering(k: Int, maxIter: Int, weightCol: Option[String]): DataFrame = { + val powerIterationClustering = + new PowerIterationClustering().setK(k).setMaxIter(maxIter).setDstCol(DST).setSrcCol(SRC) + weightCol match { + case Some(col) => powerIterationClustering.setWeightCol(col).assignClusters(edges) + case None => + powerIterationClustering + .setWeightCol("_weight") + .assignClusters(edges.withColumn("_weight", lit(1.0))) + } + } + // ========= Motif finding (private) ========= /** @@ -784,17 +810,18 @@ object GraphFrame extends Serializable with Logging { /** * Given: * - a GraphFrame `originalGraph` - * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] - * this method merges attributes from the GraphX graph into the original GraphFrame. + * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] this method + * merges attributes from the GraphX graph into the original GraphFrame. * * This method is useful for doing computations using the GraphX API and then merging the * results with a GraphFrame. For example, given: * - GraphFrame `originalGraph` * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call - * "category" and an Int edge attribute we want to call "count" - * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a - * new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new - * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" added. + * "category" and an Int edge attribute we want to call "count" We can call + * `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a new + * GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new + * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" + * added. * * See [[org.graphframes.examples.BeliefPropagation]] for example usage. * diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index d8d761898..aaa12d4a4 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -19,18 +19,18 @@ package org.graphframes import java.io.File -import com.google.common.io.Files +import org.graphframes.examples.Graphs + import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path - import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StringType} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.storage.StorageLevel -import org.graphframes.examples.Graphs +import com.google.common.io.Files class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -313,4 +313,39 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { GraphFrame.setBroadcastThreshold(defaultThreshold) } + + test("power iteration clustering wrapper") { + val spark = this.spark + import spark.implicits._ + val edges = spark + .createDataFrame( + Seq( + (1, 0, 0.5), + (2, 0, 0.5), + (2, 1, 0.7), + (3, 0, 0.5), + (3, 1, 0.7), + (3, 2, 0.9), + (4, 0, 0.5), + (4, 1, 0.7), + (4, 2, 0.9), + (4, 3, 1.1), + (5, 0, 0.5), + (5, 1, 0.7), + (5, 2, 0.9), + (5, 3, 1.1), + (5, 4, 1.3))) + .toDF("src", "dst", "weight") + val vertices = Seq(0, 1, 2, 3, 4, 5).toDF("id") + val gf = GraphFrame(vertices, edges) + val clusters = gf + .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) + .sort("id") + .collect() + assert( + clusters + .zip(Seq(0, 0, 0, 0, 1)) + .map { case (r: Row, expected: Int) => r.getInt(1) == expected } + .foldLeft(true) { case (r: Boolean, row: Boolean) => r && row }) + } } From d0ebbaa7c25dca27f43051101154656dd35a1e0c Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 28 Feb 2025 09:01:55 +0100 Subject: [PATCH 2/6] fix tests --- src/test/scala/org/graphframes/GraphFrameSuite.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index aaa12d4a4..d163f9956 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -342,10 +342,6 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) .sort("id") .collect() - assert( - clusters - .zip(Seq(0, 0, 0, 0, 1)) - .map { case (r: Row, expected: Int) => r.getInt(1) == expected } - .foldLeft(true) { case (r: Boolean, row: Boolean) => r && row }) + assert(Seq(0, 0, 0, 0, 1, 0) == clusters.map(_.getAs[Int]("cluster")).toSeq) } } From 91c5492dd298bf3a62625a4863c509853d9aecf9 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 28 Feb 2025 09:17:41 +0100 Subject: [PATCH 3/6] Fix collect-order --- python/graphframes/tests.py | 4 +++- src/test/scala/org/graphframes/GraphFrameSuite.scala | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index 0f93b8bc9..c037f7818 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -255,13 +255,15 @@ def test_power_iteration_clustering(self): v=self.spark.createDataFrame(vertices).toDF("src", "dst", "weight"), e=self.spark.createDataFrame(edges).toDF("id"), ) + clusters = [ r["cluster"] for r in g.powerIterationClustering(k=2, maxIter=40, weightCol="weight") .sort("id") .collect() ] - self.assertEqual(clusters, [0, 0, 0, 0, 1]) + + self.assertEqual(clusters, [0, 0, 0, 0, 1, 0]) class PregelTest(GraphFrameTestCase): diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index d163f9956..ad3dbc9fa 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -340,8 +340,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val gf = GraphFrame(vertices, edges) val clusters = gf .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) - .sort("id") .collect() - assert(Seq(0, 0, 0, 0, 1, 0) == clusters.map(_.getAs[Int]("cluster")).toSeq) + .sortBy(_.getAs[Int]("id")) + .map(_.getAs[Int]("cluster")) + .toSeq + assert(Seq(0, 0, 0, 0, 1, 0) == clusters) } } From 04b4fb99f38da682758dc1ae71d5bf7160dfc556 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 28 Feb 2025 09:28:51 +0100 Subject: [PATCH 4/6] fix typo --- src/test/scala/org/graphframes/GraphFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index ad3dbc9fa..508ed926e 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -341,7 +341,7 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val clusters = gf .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) .collect() - .sortBy(_.getAs[Int]("id")) + .sortBy(_.getAs[Long]("id")) .map(_.getAs[Int]("cluster")) .toSeq assert(Seq(0, 0, 0, 0, 1, 0) == clusters) From 0cd4d4bfd36087820511e44f960e124386502b32 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 6 Mar 2025 17:00:06 +0100 Subject: [PATCH 5/6] fix merge artifacts && apply pre-commit --- python/graphframes/graphframe.py | 4 ++-- python/graphframes/tests.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 4ac5d22ae..27706f48c 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -27,7 +27,7 @@ from graphframes.lib import Pregel -def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': + def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": """ (internal) creates a python GraphFrame wrapper from a java GraphFrame. @@ -509,7 +509,7 @@ def powerIterationClustering( :param weightCol: optional name of weight column, 1.0 is used if not provided :return: DataFrame with new column "cluster" - """ + """ # noqa: E501 if weightCol: weightCol = self._spark._jvm.scala.Option.apply(weightCol) else: diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index c037f7818..72373c17d 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -38,7 +38,6 @@ from .lib import AggregateMessages as AM - class GraphFrameTestUtils(object): @classmethod def parse_spark_version(cls, version_str): From b1c3ecb4d6caecba96e66358d8119f027793bf5b Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 6 Mar 2025 17:16:11 +0100 Subject: [PATCH 6/6] merge-conflicts artifacts --- python/graphframes/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index 72373c17d..b0463c942 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -251,8 +251,8 @@ def test_power_iteration_clustering(self): ] edges = [(0,), (1,), (2,), (3,), (4,), (5,)] g = GraphFrame( - v=self.spark.createDataFrame(vertices).toDF("src", "dst", "weight"), - e=self.spark.createDataFrame(edges).toDF("id"), + v=self.spark.createDataFrame(edges).toDF("id"), + e=self.spark.createDataFrame(vertices).toDF("src", "dst", "weight"), ) clusters = [