diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index cdd3e5786..27706f48c 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -496,6 +496,27 @@ def triangleCount(self) -> DataFrame: jdf = self._jvm_graph.triangleCount().run() return DataFrame(jdf, self._spark) + def powerIterationClustering( + self, k: int, maxIter: int, weightCol: Optional[str] = None + ) -> DataFrame: + """ + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. + From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration + on a normalized pair-wise similarity matrix of the data. + + :param k: the numbers of clusters to create + :param maxIter: param for maximum number of iterations (>= 0) + :param weightCol: optional name of weight column, 1.0 is used if not provided + + :return: DataFrame with new column "cluster" + """ # noqa: E501 + if weightCol: + weightCol = self._spark._jvm.scala.Option.apply(weightCol) + else: + weightCol = self._spark._jvm.scala.Option.empty() + jdf = self._jvm_graph.powerIterationClustering(k, maxIter, weightCol) + return DataFrame(jdf, self._spark) + def _test(): import doctest diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py index 80b1bd075..b0463c942 100644 --- a/python/graphframes/tests.py +++ b/python/graphframes/tests.py @@ -231,6 +231,39 @@ def test_bfs(self): paths3 = g.bfs("name='A'", "name='C'", maxPathLength=1) self.assertEqual(paths3.count(), 0) + def test_power_iteration_clustering(self): + vertices = [ + (1, 0, 0.5), + (2, 0, 0.5), + (2, 1, 0.7), + (3, 0, 0.5), + (3, 1, 0.7), + (3, 2, 0.9), + (4, 0, 0.5), + (4, 1, 0.7), + (4, 2, 0.9), + (4, 3, 1.1), + (5, 0, 0.5), + (5, 1, 0.7), + (5, 2, 0.9), + (5, 3, 1.1), + (5, 4, 1.3), + ] + edges = [(0,), (1,), (2,), (3,), (4,), (5,)] + g = GraphFrame( + v=self.spark.createDataFrame(edges).toDF("id"), + e=self.spark.createDataFrame(vertices).toDF("src", "dst", "weight"), + ) + + clusters = [ + r["cluster"] + for r in g.powerIterationClustering(k=2, maxIter=40, weightCol="weight") + .sort("id") + .collect() + ] + + self.assertEqual(clusters, [0, 0, 0, 0, 1, 0]) + class PregelTest(GraphFrameTestCase): def setUp(self): diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index 01b829065..fac754300 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -21,15 +21,16 @@ import java.util.Random import scala.reflect.runtime.universe.TypeTag +import org.graphframes.lib._ +import org.graphframes.pattern._ + import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.ml.clustering.PowerIterationClustering import org.apache.spark.sql._ -import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, struct, udf, monotonically_increasing_id, expr} +import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, expr, lit, max, monotonically_increasing_id, struct, udf} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.graphframes.lib._ -import org.graphframes.pattern._ - /** * A representation of a graph using `DataFrame`s. * @@ -246,8 +247,8 @@ class GraphFrame private ( /** * The out-degree of each vertex in the graph, returned as a DataFrame with two columns: * - [[GraphFrame.ID]] the ID of the vertex - * - "outDegree" (integer) storing the out-degree of the vertex - * Note that vertices with 0 out-edges are not returned in the result. + * - "outDegree" (integer) storing the out-degree of the vertex Note that vertices with 0 + * out-edges are not returned in the result. * * @group degree */ @@ -257,9 +258,8 @@ class GraphFrame private ( /** * The in-degree of each vertex in the graph, returned as a DataFame with two columns: - * - [[GraphFrame.ID]] the ID of the vertex - * "- "inDegree" (int) storing the in-degree of the vertex Note that vertices with 0 in-edges - * are not returned in the result. + * - [[GraphFrame.ID]] the ID of the vertex "- "inDegree" (int) storing the in-degree of the + * vertex Note that vertices with 0 in-edges are not returned in the result. * * @group degree */ @@ -270,8 +270,8 @@ class GraphFrame private ( /** * The degree of each vertex in the graph, returned as a DataFrame with two columns: * - [[GraphFrame.ID]] the ID of the vertex - * - 'degree' (integer) the degree of the vertex - * Note that vertices with 0 edges are not returned in the result. + * - 'degree' (integer) the degree of the vertex Note that vertices with 0 edges are not + * returned in the result. * * @group degree */ @@ -302,9 +302,9 @@ class GraphFrame private ( * - Within a pattern, names can be assigned to vertices and edges. For example, * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`. These names serve * two purposes: - * - The names can identify common elements among edges. For example, - * `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination - * of edge `e` and source of edge `e2`. + * - The names can identify common elements among edges. For example, `"(a)-[e]->(b); + * (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination of edge `e` and + * source of edge `e2`. * - The names are used as column names in the result `DataFrame`. If a motif contains named * vertex `a`, then the result `DataFrame` will contain a column "a" which is a * `StructType` with sub-fields equivalent to the schema (columns) of @@ -312,10 +312,10 @@ class GraphFrame private ( * the result `DataFrame` with sub-fields equivalent to the schema (columns) of * [[GraphFrame.edges]]. * - Be aware that names do *not* identify *distinct* elements: two elements with different - * names may refer to the same graph element. For example, in the motif - * `"(a)-[e]->(b); (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex. - * To restrict named elements to be distinct vertices or edges, use post-hoc filters such - * as `resultDataframe.filter("a.id != c.id")`. + * names may refer to the same graph element. For example, in the motif `"(a)-[e]->(b); + * (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex. To restrict + * named elements to be distinct vertices or edges, use post-hoc filters such as + * `resultDataframe.filter("a.id != c.id")`. * - It is acceptable to omit names for vertices or edges in motifs when not needed. E.g., * `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name to * the edge. There will be no column for the anonymous edge in the result `DataFrame`. @@ -509,6 +509,32 @@ class GraphFrame private ( */ def triangleCount: TriangleCount = new TriangleCount(this) + /** + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and + * Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using + * truncated power iteration on a normalized pair-wise similarity matrix of the data. + * + * PowerIterationClustering algorithm. + * @param k + * The number of clusters to create (k). + * @param maxIter + * Param for maximum number of iterations (>= 0). + * @param weightCol + * Param for weight column name. + * @return + */ + def powerIterationClustering(k: Int, maxIter: Int, weightCol: Option[String]): DataFrame = { + val powerIterationClustering = + new PowerIterationClustering().setK(k).setMaxIter(maxIter).setDstCol(DST).setSrcCol(SRC) + weightCol match { + case Some(col) => powerIterationClustering.setWeightCol(col).assignClusters(edges) + case None => + powerIterationClustering + .setWeightCol("_weight") + .assignClusters(edges.withColumn("_weight", lit(1.0))) + } + } + // ========= Motif finding (private) ========= /** @@ -784,17 +810,18 @@ object GraphFrame extends Serializable with Logging { /** * Given: * - a GraphFrame `originalGraph` - * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] - * this method merges attributes from the GraphX graph into the original GraphFrame. + * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] this method + * merges attributes from the GraphX graph into the original GraphFrame. * * This method is useful for doing computations using the GraphX API and then merging the * results with a GraphFrame. For example, given: * - GraphFrame `originalGraph` * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call - * "category" and an Int edge attribute we want to call "count" - * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a - * new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new - * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" added. + * "category" and an Int edge attribute we want to call "count" We can call + * `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a new + * GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new + * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" + * added. * * See [[org.graphframes.examples.BeliefPropagation]] for example usage. * diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index d8d761898..508ed926e 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -19,18 +19,18 @@ package org.graphframes import java.io.File -import com.google.common.io.Files +import org.graphframes.examples.Graphs + import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path - import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StringType} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.storage.StorageLevel -import org.graphframes.examples.Graphs +import com.google.common.io.Files class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -313,4 +313,37 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { GraphFrame.setBroadcastThreshold(defaultThreshold) } + + test("power iteration clustering wrapper") { + val spark = this.spark + import spark.implicits._ + val edges = spark + .createDataFrame( + Seq( + (1, 0, 0.5), + (2, 0, 0.5), + (2, 1, 0.7), + (3, 0, 0.5), + (3, 1, 0.7), + (3, 2, 0.9), + (4, 0, 0.5), + (4, 1, 0.7), + (4, 2, 0.9), + (4, 3, 1.1), + (5, 0, 0.5), + (5, 1, 0.7), + (5, 2, 0.9), + (5, 3, 1.1), + (5, 4, 1.3))) + .toDF("src", "dst", "weight") + val vertices = Seq(0, 1, 2, 3, 4, 5).toDF("id") + val gf = GraphFrame(vertices, edges) + val clusters = gf + .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) + .collect() + .sortBy(_.getAs[Long]("id")) + .map(_.getAs[Int]("cluster")) + .toSeq + assert(Seq(0, 0, 0, 0, 1, 0) == clusters) + } }