From fc3709a8076a3cc424007eb231a1ec9af3d7c0e0 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Tue, 20 Jan 2026 13:11:31 +0100 Subject: [PATCH] 757: Support string IDs in GraphFrame powerIterationClustering The powerIterationClustering algorithm requires integral vertex IDs, but the GraphFrame API supports string IDs. Previously, this method would fail when called on a GraphFrame with string IDs. Now, we internally convert string IDs to long integers, run the clustering algorithm, then map the results back to the original string IDs. Changes: - In `GraphFrame.powerIterationClustering`: - For non-integral ID types (e.g., string), use the precomputed `indexedEdges` (which contains the `LONG_SRC`/`LONG_DST` long ID columns) to create a temporary edge DataFrame with long IDs. - Preserve the optional weight column if specified (it is extracted from the `attr` struct of `indexedEdges`). - Execute PowerIterationClustering on the long-ID edges. - Join the results (which are keyed by long vertex IDs) back with `indexedVertices` to map the long vertex IDs back to the original vertex IDs; the cluster assignments themselves are unchanged. - For integral ID types, the original behavior is unchanged. - Added a test `power iteration clustering string ids` in `GraphFrameSuite` to verify correctness with string IDs. 
--- .../scala/org/graphframes/GraphFrame.scala | 32 ++++++++++++++++-- .../org/graphframes/GraphFrameSuite.scala | 33 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala index a1a44ca49..3dc558bed 100644 --- a/core/src/main/scala/org/graphframes/GraphFrame.scala +++ b/core/src/main/scala/org/graphframes/GraphFrame.scala @@ -689,14 +689,40 @@ class GraphFrame private ( * @return */ def powerIterationClustering(k: Int, maxIter: Int, weightCol: Option[String]): DataFrame = { + val integralTypeEdges = if (hasIntegralIdType) { + edges + } else { + val pureIds = + indexedEdges.drop(SRC, DST).withColumnsRenamed(Map(LONG_SRC -> SRC, LONG_DST -> DST)) + if (weightCol.isDefined) { + pureIds.select( + col(SRC), + col(DST), + col("attr").getField(weightCol.get).alias(weightCol.get)) + } else { + pureIds + } + } val powerIterationClustering = new PowerIterationClustering().setK(k).setMaxIter(maxIter).setDstCol(DST).setSrcCol(SRC) - weightCol match { - case Some(col) => powerIterationClustering.setWeightCol(col).assignClusters(edges) + val result = weightCol match { + case Some(col) => + powerIterationClustering.setWeightCol(col).assignClusters(integralTypeEdges) case None => powerIterationClustering .setWeightCol("_weight") - .assignClusters(edges.withColumn("_weight", lit(1.0))) + .assignClusters(integralTypeEdges.withColumn("_weight", lit(1.0))) + } + + if (hasIntegralIdType) { + result + } else { + result + .join( + indexedVertices.select(col(LONG_ID).alias(ID), col(ID).alias("_ID")), + Seq(ID), + "inner") + .select(col("_ID").alias(ID), col("cluster")) } } diff --git a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala index f9e032d4f..196c9d27f 100644 --- a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ 
b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -588,6 +588,39 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(Seq(0, 0, 0, 0, 1, 0) == clusters) } + test("power iteration clustering string ids") { + val spark = this.spark + import spark.implicits._ + val edges = spark + .createDataFrame( + Seq( + ("1", "0", 0.5), + ("2", "0", 0.5), + ("2", "1", 0.7), + ("3", "0", 0.5), + ("3", "1", 0.7), + ("3", "2", 0.9), + ("4", "0", 0.5), + ("4", "1", 0.7), + ("4", "2", 0.9), + ("4", "3", 1.1), + ("5", "0", 0.5), + ("5", "1", 0.7), + ("5", "2", 0.9), + ("5", "3", 1.1), + ("5", "4", 1.3))) + .toDF("src", "dst", "weight") + val vertices = Seq("0", "1", "2", "3", "4", "5").toDF("id") + val gf = GraphFrame(vertices, edges) + val clusters = gf + .powerIterationClustering(k = 2, maxIter = 40, weightCol = Some("weight")) + .collect() + .sortBy(_.getAs[String]("id")) + .map(_.getAs[Int]("cluster")) + .toSeq + assert(Seq(1, 1, 1, 1, 1, 0) == clusters) + } + test("convert directed graph to undirected") { val v = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "name") val e = spark.createDataFrame(Seq((1L, 2L), (2L, 3L))).toDF("src", "dst")