From 2d3c72d42334503df66d6673b23de7abd0adf914 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 6 Feb 2026 13:19:41 -0800 Subject: [PATCH 01/13] feat(pregel): automatically skip second join when dst columns not needed Implements automatic optimization for Pregel triplet generation that skips the second join (adding destination vertex state) when no message expressions reference dst.* columns. The optimization works by: 1. Analyzing all message expressions before the iteration loop 2. Extracting column prefixes (src, dst, edge) from the expression AST 3. Skipping the dst vertex join if no dst.* columns are referenced AND skipMessagesFromNonActiveVertices is disabled This provides significant performance improvement for algorithms like PageRank, directed LabelPropagation, and DetectingCycles that only need source vertex or edge columns in their message expressions. Closes #790 --- .../spark/sql/graphframes/SparkShims.scala | 26 +++ .../spark/sql/graphframes/SparkShims.scala | 27 +++ .../scala/org/graphframes/lib/Pregel.scala | 35 +++- .../org/graphframes/lib/PregelSuite.scala | 159 ++++++++++++++++++ 4 files changed, 242 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index 77593803b..1089177b0 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -26,9 +26,35 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import scala.annotation.nowarn +import scala.collection.mutable object SparkShims { + /** + * Extracts all column name prefixes (e.g., "src", "dst", "edge") from a Column expression. + * + * This is used to detect which triplet columns are referenced in message expressions, + * enabling automatic optimization to skip the second join when destination columns + * are not needed. + * + * @param spark + * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) + * @param expr + * the Column expression to analyze + * @return + * a Set of column name prefixes found in the expression + */ + @nowarn + def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = { + val prefixes = mutable.Set.empty[String] + expr.expr.foreach { + case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => + prefixes += nameParts.head + case _ => // ignore other expression types + } + prefixes.toSet + } + /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index fea8e553a..1799186a9 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -29,8 +29,35 @@ import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.classic.ExpressionUtils import org.apache.spark.sql.classic.SparkSession as ClassicSparkSession +import scala.collection.mutable + object SparkShims { + /** + * Extracts all column name prefixes (e.g., "src", "dst", "edge") from a Column expression. + * + * This is used to detect which triplet columns are referenced in message expressions, + * enabling automatic optimization to skip the second join when destination columns + * are not needed. + * + * @param spark + * the SparkSession (needed for expression conversion in Spark 4) + * @param expr + * the Column expression to analyze + * @return + * a Set of column name prefixes found in the expression + */ + def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = { + val prefixes = mutable.Set.empty[String] + val converted = spark.asInstanceOf[ClassicSparkSession].converter(expr.node) + converted.foreach { + case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => + prefixes += nameParts.head + case _ => // ignore other expression types + } + prefixes.toSet + } + /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 9232b1f82..df24f6fd6 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.explode import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.graphframes.SparkShims import org.graphframes.GraphFrame import org.graphframes.GraphFrame.* import org.graphframes.Logging @@ -425,19 +426,43 @@ class Pregel(val graph: GraphFrame) if (requiredDstColumnsList.isEmpty) Seq(col("*")) else (Seq(ID, Pregel.ACTIVE_FLAG_COL) ++ requiredDstColumnsList).distinct.map(col) + // Automatic optimization: detect if destination vertex state is needed by analyzing + // all message expressions. If no expression references dst.* columns (other than dst.id + // which is implicitly available), we can skip the second join entirely. + // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. + val allMessageExpressions = sendMsgs.toList.flatMap { case (idExpr, msgExpr) => + Seq(idExpr, msgExpr) + } + val allReferencedPrefixes = allMessageExpressions.flatMap { expr => + SparkShims.extractColumnPrefixes(graph.spark, expr) + }.toSet + val needsDstState = allReferencedPrefixes.contains(DST) || skipMessagesFromNonActiveVertices + if (!needsDstState) { + logInfo("Optimization: skipping second join (dst state not required by message expressions)") + } + breakable { while (iteration <= maxIter) { logInfo(s"start Pregel iteration $iteration / $maxIter") val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]() currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel)) - var tripletsDF = currentVertices + // Build triplets: start with src vertex state joined with edges + val srcWithEdges = currentVertices .select(struct(srcCols: _*).as(SRC)) .join(edges, Pregel.src(ID) === col("edge_src")) - .join( - currentVertices.select(struct(dstCols: _*).as(DST)), - col("edge_dst") === Pregel.dst(ID)) - .drop(col("edge_src"), col("edge_dst")) + + // Only perform the second join (adding dst vertex state) if needed + var tripletsDF = if (needsDstState) { + srcWithEdges + .join( + currentVertices.select(struct(dstCols: _*).as(DST)), + col("edge_dst") === Pregel.dst(ID)) + .drop(col("edge_src"), col("edge_dst")) + } else { + // Skip second join - dst state not needed by any message expression + srcWithEdges.drop(col("edge_src"), col("edge_dst")) + } if (skipMessagesFromNonActiveVertices) { tripletsDF = tripletsDF.filter( diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index a1e7ca00a..39be126ab 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -293,4 +293,163 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { .collect() assert(result.sum === 1.0 +- 1e-6) } + + test("automatic dst join skipping - PageRank only uses src columns") { + // PageRank only references Pregel.src("rank") and Pregel.src("outDegree"), + // so the second join (for dst vertex state) should be automatically skipped. + // This test verifies the optimization produces correct results. + + val edges = Seq( + (0L, 1L), + (1L, 2L), + (2L, 4L), + (2L, 0L), + (3L, 4L), + (4L, 0L), + (4L, 2L)).toDF("src", "dst").cache() + val vertices = GraphFrame.fromEdges(edges).outDegrees.cache() + val numVertices = vertices.count() + val graph = GraphFrame(vertices, edges) + + val alpha = 0.15 + // PageRank only uses Pregel.src(...) - dst state should be automatically skipped + val ranks = graph.pregel + .setMaxIter(5) + .withVertexColumn( + "rank", + lit(1.0 / numVertices), + coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices) + .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) + .aggMsgs(sum(Pregel.msg)) + .run() + + val result = ranks + .sort(col("id")) + .select("rank") + .as[Double] + .collect() + assert(result.sum === 1.0 +- 1e-6) + val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197) + result.zip(expected).foreach { case (r, e) => + assert(r === e +- 1e-3) + } + } + + test("automatic dst join NOT skipped when dst columns are referenced") { + // This test uses Pregel.dst("value") in the message expression, + // so the second join must NOT be skipped. + + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + val resultDF = graph.pregel + .setMaxIter(n - 1) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + // This references BOTH src and dst - dst join should NOT be skipped + .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + } + + test("automatic dst join NOT skipped when skipMessagesFromNonActiveVertices is enabled") { + // When skipMessagesFromNonActiveVertices is true, we need dst._pregel_is_active, + // so the second join must NOT be skipped even if message expressions don't use dst. + + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + // This only uses Pregel.src("value"), but skipMessagesFromNonActiveVertices + // requires dst._pregel_is_active, so dst join should NOT be skipped + val resultDF = graph.pregel + .setMaxIter(n - 1) + .setSkipMessagesFromNonActiveVertices(true) + .setUpdateActiveVertexExpression(Pregel.msg.isNotNull) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToDst(Pregel.src("value")) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + } + + test("automatic dst join skipping - sendMsgToSrc with only edge columns") { + // When sending messages to src using only edge columns, dst join should be skipped + + val edges = Seq( + (1L, 0L, 10L), + (2L, 1L, 20L), + (3L, 2L, 30L), + (4L, 3L, 40L)).toDF("src", "dst", "weight").cache() + val vertices = (0L to 4L).toDF("id").cache() + + val graph = GraphFrame(vertices, edges) + + // Only uses Pregel.edge("weight") - dst join should be skipped + val resultDF = graph.pregel + .setMaxIter(1) + .withVertexColumn( + "received", + lit(0L), + coalesce(Pregel.msg, col("received"))) + .sendMsgToSrc(Pregel.edge("weight")) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Each src vertex receives the weight from its outgoing edge + val received = resultDF.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no outgoing edges + assert(received(1) === 10L) // vertex 1: edge 1->0 with weight 10 + assert(received(2) === 20L) // vertex 2: edge 2->1 with weight 20 + assert(received(3) === 30L) // vertex 3: edge 3->2 with weight 30 + assert(received(4) === 40L) // vertex 4: edge 4->3 with weight 40 + } + + test("automatic dst join skipping - edge columns only") { + // When message expressions only reference edge columns, dst join should be skipped + + val edges = Seq( + (0L, 1L, 1.0), + (1L, 2L, 2.0), + (2L, 3L, 3.0)).toDF("src", "dst", "weight").cache() + val vertices = Seq(0L, 1L, 2L, 3L).toDF("id").cache() + val graph = GraphFrame(vertices, edges) + + // Only uses Pregel.edge("weight") - dst join should be skipped + val result = graph.pregel + .setMaxIter(1) // Single iteration to simplify testing + .withVertexColumn( + "total", + lit(0.0), + coalesce(Pregel.msg, col("total"))) + .sendMsgToDst(Pregel.edge("weight")) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Verify results: vertex 1 gets weight 1.0, vertex 2 gets 2.0, vertex 3 gets 3.0 + val totals = result.sort("id").select("total").as[Double].collect() + assert(totals(0) === 0.0 +- 1e-6) // vertex 0: no incoming edges + assert(totals(1) === 1.0 +- 1e-6) // vertex 1: edge 0->1 with weight 1.0 + assert(totals(2) === 2.0 +- 1e-6) // vertex 2: edge 1->2 with weight 2.0 + assert(totals(3) === 3.0 +- 1e-6) // vertex 3: edge 2->3 with weight 3.0 + } } From 65b7264152badbecbc4fece4d4faa3b290032db8 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 6 Feb 2026 13:27:43 -0800 Subject: [PATCH 02/13] fix: only analyze message expressions for dst detection, provide dst.id when skipping join The previous implementation incorrectly checked both the target ID expression and message expression for dst.* references. Since sendMsgToDst uses Pregel.dst(ID) as the target, it would always detect 'dst' as referenced even when the message itself only used src columns. This fix: 1. Only analyzes the message expressions (not target ID) for dst.* references 2. When skipping the join, creates a minimal dst struct with just the id from edge_dst so that sendMsgToDst can still route messages correctly Added test: 'sendMsgToDst with only src columns in message' to verify the optimization works correctly when dst.id is implicitly used for routing. --- .../scala/org/graphframes/lib/Pregel.scala | 20 ++++++------ .../org/graphframes/lib/PregelSuite.scala | 31 +++++++++++++++++++ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index df24f6fd6..f3c74466a 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -427,16 +427,15 @@ class Pregel(val graph: GraphFrame) else (Seq(ID, Pregel.ACTIVE_FLAG_COL) ++ requiredDstColumnsList).distinct.map(col) // Automatic optimization: detect if destination vertex state is needed by analyzing - // all message expressions. If no expression references dst.* columns (other than dst.id - // which is implicitly available), we can skip the second join entirely. + // the MESSAGE expressions only (not the target ID expressions, since dst.id is always + // available from the edge). If no message expression references dst.* columns, + // we can skip the second join entirely. // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. - val allMessageExpressions = sendMsgs.toList.flatMap { case (idExpr, msgExpr) => - Seq(idExpr, msgExpr) - } - val allReferencedPrefixes = allMessageExpressions.flatMap { expr => + val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } + val referencedPrefixesInMessages = messageExpressions.flatMap { expr => SparkShims.extractColumnPrefixes(graph.spark, expr) }.toSet - val needsDstState = allReferencedPrefixes.contains(DST) || skipMessagesFromNonActiveVertices + val needsDstState = referencedPrefixesInMessages.contains(DST) || skipMessagesFromNonActiveVertices if (!needsDstState) { logInfo("Optimization: skipping second join (dst state not required by message expressions)") } @@ -460,8 +459,11 @@ class Pregel(val graph: GraphFrame) col("edge_dst") === Pregel.dst(ID)) .drop(col("edge_src"), col("edge_dst")) } else { - // Skip second join - dst state not needed by any message expression - srcWithEdges.drop(col("edge_src"), col("edge_dst")) + // Skip second join - dst state not needed by any message expression. + // Create a minimal dst struct with just the id from edge_dst for sendMsgToDst to work. + srcWithEdges + .withColumn(DST, struct(col("edge_dst").as(ID))) + .drop(col("edge_src"), col("edge_dst")) } if (skipMessagesFromNonActiveVertices) { diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index 39be126ab..601048e39 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -452,4 +452,35 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(totals(2) === 2.0 +- 1e-6) // vertex 2: edge 1->2 with weight 2.0 assert(totals(3) === 3.0 +- 1e-6) // vertex 3: edge 2->3 with weight 3.0 } + + test("automatic dst join skipping - sendMsgToDst with only src columns in message") { + // When sendMsgToDst is used but the message expression only references src columns, + // the second join should be skipped. The dst.id needed for message routing is + // obtained from the edge's dst column, not from a vertex join. + + val edges = Seq( + (0L, 1L), + (1L, 2L), + (2L, 3L)).toDF("src", "dst").cache() + val vertices = (0L to 3L).toDF("id").cache() + val graph = GraphFrame(vertices, edges) + + // sendMsgToDst but message only uses Pregel.src("id") - dst join should be skipped + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn( + "received", + lit(0L), + coalesce(Pregel.msg, col("received"))) + .sendMsgToDst(Pregel.src("id")) // Message only uses src.id + .aggMsgs(sum(Pregel.msg)) + .run() + + // Each dst vertex receives the src.id from incoming edges + val received = result.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no incoming edges + assert(received(1) === 0L) // vertex 1: edge 0->1, receives src.id = 0 + assert(received(2) === 1L) // vertex 2: edge 1->2, receives src.id = 1 + assert(received(3) === 2L) // vertex 3: edge 2->3, receives src.id = 2 + } } From ffbfc0326f4069636a2806277dc0b80ed8dd4395 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 6 Feb 2026 13:38:04 -0800 Subject: [PATCH 03/13] style: run scalafmt and update SparkShims docs to be implementation-agnostic --- .../spark/sql/graphframes/SparkShims.scala | 8 ++-- .../spark/sql/graphframes/SparkShims.scala | 8 ++-- .../scala/org/graphframes/lib/Pregel.scala | 6 ++- .../org/graphframes/lib/PregelSuite.scala | 47 +++++-------------- 4 files changed, 25 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index 1089177b0..2b561dee7 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -31,11 +31,11 @@ import scala.collection.mutable object SparkShims { /** - * Extracts all column name prefixes (e.g., "src", "dst", "edge") from a Column expression. + * Extracts all top-level column name prefixes from a Column expression. * - * This is used to detect which triplet columns are referenced in message expressions, - * enabling automatic optimization to skip the second join when destination columns - * are not needed. + * For nested column references like "src.id" or "edge.weight", this extracts the first + * component ("src", "edge"). This is useful for analyzing which struct columns are referenced + * in an expression. * * @param spark * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index 1799186a9..a7f95a88e 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -34,11 +34,11 @@ import scala.collection.mutable object SparkShims { /** - * Extracts all column name prefixes (e.g., "src", "dst", "edge") from a Column expression. + * Extracts all top-level column name prefixes from a Column expression. * - * This is used to detect which triplet columns are referenced in message expressions, - * enabling automatic optimization to skip the second join when destination columns - * are not needed. + * For nested column references like "src.id" or "edge.weight", this extracts + * the first component ("src", "edge"). This is useful for analyzing which + * struct columns are referenced in an expression. * * @param spark * the SparkSession (needed for expression conversion in Spark 4) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index f3c74466a..4e78f2a62 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -435,9 +435,11 @@ class Pregel(val graph: GraphFrame) val referencedPrefixesInMessages = messageExpressions.flatMap { expr => SparkShims.extractColumnPrefixes(graph.spark, expr) }.toSet - val needsDstState = referencedPrefixesInMessages.contains(DST) || skipMessagesFromNonActiveVertices + val needsDstState = + referencedPrefixesInMessages.contains(DST) || skipMessagesFromNonActiveVertices if (!needsDstState) { - logInfo("Optimization: skipping second join (dst state not required by message expressions)") + logInfo( + "Optimization: skipping second join (dst state not required by message expressions)") } breakable { diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index 601048e39..1b2a62d3c 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -299,14 +299,9 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { // so the second join (for dst vertex state) should be automatically skipped. // This test verifies the optimization produces correct results. - val edges = Seq( - (0L, 1L), - (1L, 2L), - (2L, 4L), - (2L, 0L), - (3L, 4L), - (4L, 0L), - (4L, 2L)).toDF("src", "dst").cache() + val edges = Seq((0L, 1L), (1L, 2L), (2L, 4L), (2L, 0L), (3L, 4L), (4L, 0L), (4L, 2L)) + .toDF("src", "dst") + .cache() val vertices = GraphFrame.fromEdges(edges).outDegrees.cache() val numVertices = vertices.count() val graph = GraphFrame(vertices, edges) @@ -395,11 +390,9 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("automatic dst join skipping - sendMsgToSrc with only edge columns") { // When sending messages to src using only edge columns, dst join should be skipped - val edges = Seq( - (1L, 0L, 10L), - (2L, 1L, 20L), - (3L, 2L, 30L), - (4L, 3L, 40L)).toDF("src", "dst", "weight").cache() + val edges = Seq((1L, 0L, 10L), (2L, 1L, 20L), (3L, 2L, 30L), (4L, 3L, 40L)) + .toDF("src", "dst", "weight") + .cache() val vertices = (0L to 4L).toDF("id").cache() val graph = GraphFrame(vertices, edges) @@ -407,17 +400,14 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { // Only uses Pregel.edge("weight") - dst join should be skipped val resultDF = graph.pregel .setMaxIter(1) - .withVertexColumn( - "received", - lit(0L), - coalesce(Pregel.msg, col("received"))) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) .sendMsgToSrc(Pregel.edge("weight")) .aggMsgs(sum(Pregel.msg)) .run() // Each src vertex receives the weight from its outgoing edge val received = resultDF.sort("id").select("received").as[Long].collect() - assert(received(0) === 0L) // vertex 0: no outgoing edges + assert(received(0) === 0L) // vertex 0: no outgoing edges assert(received(1) === 10L) // vertex 1: edge 1->0 with weight 10 assert(received(2) === 20L) // vertex 2: edge 2->1 with weight 20 assert(received(3) === 30L) // vertex 3: edge 3->2 with weight 30 @@ -427,20 +417,15 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("automatic dst join skipping - edge columns only") { // When message expressions only reference edge columns, dst join should be skipped - val edges = Seq( - (0L, 1L, 1.0), - (1L, 2L, 2.0), - (2L, 3L, 3.0)).toDF("src", "dst", "weight").cache() + val edges = + Seq((0L, 1L, 1.0), (1L, 2L, 2.0), (2L, 3L, 3.0)).toDF("src", "dst", "weight").cache() val vertices = Seq(0L, 1L, 2L, 3L).toDF("id").cache() val graph = GraphFrame(vertices, edges) // Only uses Pregel.edge("weight") - dst join should be skipped val result = graph.pregel .setMaxIter(1) // Single iteration to simplify testing - .withVertexColumn( - "total", - lit(0.0), - coalesce(Pregel.msg, col("total"))) + .withVertexColumn("total", lit(0.0), coalesce(Pregel.msg, col("total"))) .sendMsgToDst(Pregel.edge("weight")) .aggMsgs(sum(Pregel.msg)) .run() @@ -458,20 +443,14 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { // the second join should be skipped. The dst.id needed for message routing is // obtained from the edge's dst column, not from a vertex join. - val edges = Seq( - (0L, 1L), - (1L, 2L), - (2L, 3L)).toDF("src", "dst").cache() + val edges = Seq((0L, 1L), (1L, 2L), (2L, 3L)).toDF("src", "dst").cache() val vertices = (0L to 3L).toDF("id").cache() val graph = GraphFrame(vertices, edges) // sendMsgToDst but message only uses Pregel.src("id") - dst join should be skipped val result = graph.pregel .setMaxIter(1) - .withVertexColumn( - "received", - lit(0L), - coalesce(Pregel.msg, col("received"))) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) .sendMsgToDst(Pregel.src("id")) // Message only uses src.id .aggMsgs(sum(Pregel.msg)) .run() From ea67cfc1233d80b0809bb3f22b5bf5429666720e Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 6 Feb 2026 14:13:07 -0800 Subject: [PATCH 04/13] feat: skip dst join when only dst.id is referenced - Add extractColumnReferences to SparkShims returning Map[String, Set[String]] to track which specific fields are accessed under each prefix - Handle resolved expressions (AttributeReference, GetStructField) in addition to unresolved ones for more robust column detection - Update Pregel optimization to skip dst join when only dst.id is referenced since dst.id is available from the edge's dst column - Change optimization log message from logInfo to logDebug - Add test for dst.id-only reference case --- .../spark/sql/graphframes/SparkShims.scala | 74 ++++++++++++++++--- .../spark/sql/graphframes/SparkShims.scala | 74 ++++++++++++++++--- .../scala/org/graphframes/lib/Pregel.scala | 19 +++-- .../org/graphframes/lib/PregelSuite.scala | 24 ++++++ 4 files changed, 168 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index 2b561dee7..f353336dd 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -22,7 +22,11 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dataset import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GetStructField +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import scala.annotation.nowarn @@ -31,30 +35,82 @@ import scala.collection.mutable object SparkShims { /** - * Extracts all top-level column name prefixes from a Column expression. + * Extracts all column references from a Column expression, returning a map from top-level + * prefix to the set of nested field names accessed under that prefix. * - * For nested column references like "src.id" or "edge.weight", this extracts the first - * component ("src", "edge"). This is useful for analyzing which struct columns are referenced - * in an expression. + * For nested column references like "src.id" or "edge.weight", this returns Map("src" -> + * Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct), + * it returns Map("src" -> Set()). + * + * This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and + * resolved expressions (AttributeReference, GetStructField). * * @param spark * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) * @param expr * the Column expression to analyze * @return - * a Set of column name prefixes found in the expression + * a Map from column prefix to the set of nested field names accessed */ @nowarn - def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = { - val prefixes = mutable.Set.empty[String] + def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = { + val refs = mutable.Map.empty[String, mutable.Set[String]] + + def addRef(prefix: String, field: Option[String]): Unit = { + val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String]) + field.foreach(fields += _) + } + expr.expr.foreach { + // Unresolved: col("src.id") or Pregel.src("id") -> UnresolvedAttribute(Seq("src", "id")) case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => - prefixes += nameParts.head + addRef(nameParts.head, nameParts.lift(1)) + + // Unresolved: col("src")("id") -> UnresolvedExtractValue + case UnresolvedExtractValue(child, extraction) => + child match { + case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => + extraction match { + case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) + case _ => addRef(nameParts.head, None) // Unknown field access + } + case _ => // Nested extraction we can't easily parse + } + + // Resolved: AttributeReference for top-level columns + case attr: AttributeReference => + addRef(attr.name, None) + + // Resolved: GetStructField for nested field access like struct.field + case GetStructField(child, _, Some(fieldName)) => + child match { + case attr: AttributeReference => addRef(attr.name, Some(fieldName)) + case _ => // Nested struct access we can't easily parse + } + case _ => // ignore other expression types } - prefixes.toSet + + refs.map { case (k, v) => k -> v.toSet }.toMap } + /** + * Extracts all top-level column name prefixes from a Column expression. + * + * For nested column references like "src.id" or "edge.weight", this extracts the first + * component ("src", "edge"). This is useful for analyzing which struct columns are referenced + * in an expression. + * + * @param spark + * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) + * @param expr + * the Column expression to analyze + * @return + * a Set of column name prefixes found in the expression + */ + def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = + extractColumnReferences(spark, expr).keySet + /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index a7f95a88e..ce9b11525 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -21,7 +21,11 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.GetStructField +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.ClassicConversions.* import org.apache.spark.sql.classic.DataFrame as ClassicDataFrame @@ -34,30 +38,82 @@ import scala.collection.mutable object SparkShims { /** - * Extracts all top-level column name prefixes from a Column expression. + * Extracts all column references from a Column expression, returning a map from top-level + * prefix to the set of nested field names accessed under that prefix. + * + * For nested column references like "src.id" or "edge.weight", this returns Map("src" -> + * Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct), + * it returns Map("src" -> Set()). * - * For nested column references like "src.id" or "edge.weight", this extracts - * the first component ("src", "edge"). This is useful for analyzing which - * struct columns are referenced in an expression. + * This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and + * resolved expressions (AttributeReference, GetStructField). * * @param spark * the SparkSession (needed for expression conversion in Spark 4) * @param expr * the Column expression to analyze * @return - * a Set of column name prefixes found in the expression + * a Map from column prefix to the set of nested field names accessed */ - def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = { - val prefixes = mutable.Set.empty[String] + def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = { + val refs = mutable.Map.empty[String, mutable.Set[String]] + + def addRef(prefix: String, field: Option[String]): Unit = { + val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String]) + field.foreach(fields += _) + } + val converted = spark.asInstanceOf[ClassicSparkSession].converter(expr.node) converted.foreach { + // Unresolved: col("src.id") or Pregel.src("id") -> UnresolvedAttribute(Seq("src", "id")) case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => - prefixes += nameParts.head + addRef(nameParts.head, nameParts.lift(1)) + + // Unresolved: col("src")("id") -> UnresolvedExtractValue + case UnresolvedExtractValue(child, extraction) => + child match { + case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => + extraction match { + case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) + case _ => addRef(nameParts.head, None) // Unknown field access + } + case _ => // Nested extraction we can't easily parse + } + + // Resolved: AttributeReference for top-level columns + case attr: AttributeReference => + addRef(attr.name, None) + + // Resolved: GetStructField for nested field access like struct.field + case GetStructField(child, _, Some(fieldName)) => + child match { + case attr: AttributeReference => addRef(attr.name, Some(fieldName)) + case _ => // Nested struct access we can't easily parse + } + case _ => // ignore other expression types } - prefixes.toSet + + refs.map { case (k, v) => k -> v.toSet }.toMap } + /** + * Extracts all top-level column name prefixes from a Column expression. + * + * For nested column references like "src.id" or "edge.weight", this extracts the first + * component ("src", "edge"). This is useful for analyzing which struct columns are referenced + * in an expression. + * + * @param spark + * the SparkSession (needed for expression conversion in Spark 4) + * @param expr + * the Column expression to analyze + * @return + * a Set of column name prefixes found in the expression + */ + def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = + extractColumnReferences(spark, expr).keySet + /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 4e78f2a62..cd4eae7a8 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -431,14 +431,23 @@ class Pregel(val graph: GraphFrame) // available from the edge). If no message expression references dst.* columns, // we can skip the second join entirely. // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. + // Additionally, if the only dst field referenced is "id", we can still skip since + // dst.id is available from the edge's dst column. val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } - val referencedPrefixesInMessages = messageExpressions.flatMap { expr => - SparkShims.extractColumnPrefixes(graph.spark, expr) + val dstFieldsReferenced = messageExpressions.flatMap { expr => + SparkShims.extractColumnReferences(graph.spark, expr).getOrElse(DST, Set.empty) }.toSet - val needsDstState = - referencedPrefixesInMessages.contains(DST) || skipMessagesFromNonActiveVertices + val dstPrefixReferenced = messageExpressions.exists { expr => + SparkShims.extractColumnReferences(graph.spark, expr).contains(DST) + } + // We need the dst join if: + // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR + // 2. dst is referenced AND fields other than just "id" are accessed + // (empty set means whole struct access like col("dst"), which also needs the join) + val needsDstState = skipMessagesFromNonActiveVertices || + (dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))) if (!needsDstState) { - logInfo( + logDebug( "Optimization: skipping second join (dst state not required by message expressions)") } diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index 1b2a62d3c..6263d3891 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -462,4 +462,28 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(received(2) === 1L) // vertex 2: edge 1->2, receives src.id = 1 assert(received(3) === 2L) // vertex 3: edge 2->3, receives src.id = 2 } + + test("automatic dst join skipping - message references only dst.id") { + // When message expressions only reference dst.id (not other dst fields), + // the join should still be skipped since dst.id is available from the edge. + + val edges = Seq((0L, 1L), (1L, 2L), (2L, 3L)).toDF("src", "dst").cache() + val vertices = (0L to 3L).toDF("id").cache() + val graph = GraphFrame(vertices, edges) + + // Message uses Pregel.dst("id") - but since only id is used, dst join should be skipped + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) + .sendMsgToDst(Pregel.src("id") + Pregel.dst("id")) // Uses dst.id only + .aggMsgs(sum(Pregel.msg)) + .run() + + // Each dst vertex receives src.id + dst.id from incoming edges + val received = result.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no incoming edges + assert(received(1) === 1L) // vertex 1: edge 0->1, receives 0 + 1 = 1 + assert(received(2) === 3L) // vertex 2: edge 1->2, receives 1 + 2 = 3 + assert(received(3) === 5L) // vertex 3: edge 2->3, receives 2 + 3 = 5 + } } From 9b90e54cf21e3c23eebb43eb46ee29a0642e76c6 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 6 Feb 2026 14:17:07 -0800 Subject: [PATCH 05/13] refactor: address reviewer feedback on dst join optimization - Remove unused extractColumnPrefixes method from SparkShims (both Spark 3/4) - Refactor Pregel.scala to parse expressions once instead of twice - Add documentation for deeply nested struct access fallback behavior --- .../spark/sql/graphframes/SparkShims.scala | 26 +++++-------------- .../spark/sql/graphframes/SparkShims.scala | 26 +++++-------------- .../scala/org/graphframes/lib/Pregel.scala | 9 +++---- 3 files changed, 18 insertions(+), 43 deletions(-) diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index f353336dd..101dd86e1 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -45,6 +45,10 @@ object SparkShims { * This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and * resolved expressions (AttributeReference, GetStructField). * + * Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such + * cases, the prefix is recorded with an empty field set, which causes callers to conservatively + * assume the entire struct is needed. This is the safe/correct fallback behavior. + * * @param spark * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) * @param expr @@ -74,7 +78,7 @@ object SparkShims { case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) case _ => addRef(nameParts.head, None) // Unknown field access } - case _ => // Nested extraction we can't easily parse + case _ => // Nested extraction we can't easily parse - conservative fallback } // Resolved: AttributeReference for top-level columns @@ -82,10 +86,11 @@ object SparkShims { addRef(attr.name, None) // Resolved: GetStructField for nested field access like struct.field + // Note: Only handles single-level nesting; deeper nesting falls through to default case case GetStructField(child, _, Some(fieldName)) => child match { case attr: AttributeReference => addRef(attr.name, Some(fieldName)) - case _ => // Nested struct access we can't easily parse + case _ => // Deeply nested struct access - conservative fallback (join will be used) } case _ => // ignore other expression types @@ -94,23 +99,6 @@ object SparkShims { refs.map { case (k, v) => k -> v.toSet }.toMap } - /** - * Extracts all top-level column name prefixes from a Column expression. - * - * For nested column references like "src.id" or "edge.weight", this extracts the first - * component ("src", "edge"). This is useful for analyzing which struct columns are referenced - * in an expression. - * - * @param spark - * the SparkSession (unused in Spark 3, included for API compatibility with Spark 4) - * @param expr - * the Column expression to analyze - * @return - * a Set of column name prefixes found in the expression - */ - def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = - extractColumnReferences(spark, expr).keySet - /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index ce9b11525..d6e8f3c17 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -48,6 +48,10 @@ object SparkShims { * This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and * resolved expressions (AttributeReference, GetStructField). * + * Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such + * cases, the prefix is recorded with an empty field set, which causes callers to conservatively + * assume the entire struct is needed. This is the safe/correct fallback behavior. + * * @param spark * the SparkSession (needed for expression conversion in Spark 4) * @param expr @@ -77,7 +81,7 @@ object SparkShims { case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) case _ => addRef(nameParts.head, None) // Unknown field access } - case _ => // Nested extraction we can't easily parse + case _ => // Nested extraction we can't easily parse - conservative fallback } // Resolved: AttributeReference for top-level columns @@ -85,10 +89,11 @@ object SparkShims { addRef(attr.name, None) // Resolved: GetStructField for nested field access like struct.field + // Note: Only handles single-level nesting; deeper nesting falls through to default case case GetStructField(child, _, Some(fieldName)) => child match { case attr: AttributeReference => addRef(attr.name, Some(fieldName)) - case _ => // Nested struct access we can't easily parse + case _ => // Deeply nested struct access - conservative fallback (join will be used) } case _ => // ignore other expression types @@ -97,23 +102,6 @@ object SparkShims { refs.map { case (k, v) => k -> v.toSet }.toMap } - /** - * Extracts all top-level column name prefixes from a Column expression. - * - * For nested column references like "src.id" or "edge.weight", this extracts the first - * component ("src", "edge"). This is useful for analyzing which struct columns are referenced - * in an expression. - * - * @param spark - * the SparkSession (needed for expression conversion in Spark 4) - * @param expr - * the Column expression to analyze - * @return - * a Set of column name prefixes found in the expression - */ - def extractColumnPrefixes(spark: SparkSession, expr: Column): Set[String] = - extractColumnReferences(spark, expr).keySet - /** * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to * the column itself. diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index cd4eae7a8..9177c55a0 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -434,12 +434,11 @@ class Pregel(val graph: GraphFrame) // Additionally, if the only dst field referenced is "id", we can still skip since // dst.id is available from the edge's dst column. val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } - val dstFieldsReferenced = messageExpressions.flatMap { expr => - SparkShims.extractColumnReferences(graph.spark, expr).getOrElse(DST, Set.empty) - }.toSet - val dstPrefixReferenced = messageExpressions.exists { expr => - SparkShims.extractColumnReferences(graph.spark, expr).contains(DST) + val allDstRefs = messageExpressions.flatMap { expr => + SparkShims.extractColumnReferences(graph.spark, expr).get(DST) } + val dstPrefixReferenced = allDstRefs.nonEmpty + val dstFieldsReferenced = allDstRefs.flatten.toSet // We need the dst join if: // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR // 2. dst is referenced AND fields other than just "id" are accessed From 5bd03f508d72c90fe28c84949261526d3b1e04d1 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Sat, 7 Feb 2026 11:55:40 -0800 Subject: [PATCH 06/13] test: add comprehensive tests for extractColumnReferences and dst join optimization - Add SparkShimsSuite with 22 unit tests for column reference extraction - Add 4 integration tests to PregelSuite for complex dst usage patterns - Fix UTF8String handling in UnresolvedExtractValue pattern matching --- .../spark/sql/graphframes/SparkShims.scala | 3 + .../spark/sql/graphframes/SparkShims.scala | 3 + .../org/graphframes/SparkShimsSuite.scala | 204 ++++++++++++++++++ .../org/graphframes/lib/PregelSuite.scala | 96 +++++++++ 4 files changed, 306 insertions(+) create mode 100644 core/src/test/scala/org/graphframes/SparkShimsSuite.scala diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index 101dd86e1..0338ba175 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -76,6 +76,9 @@ object SparkShims { case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => extraction match { case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) + case Literal(fieldName, _) if fieldName != null => + // Handle UTF8String (Spark's internal string representation) + addRef(nameParts.head, Some(fieldName.toString)) case _ => addRef(nameParts.head, None) // Unknown field access } case _ => // Nested extraction we can't easily parse - conservative fallback diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index d6e8f3c17..68c00fbbd 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -79,6 +79,9 @@ object SparkShims { case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => extraction match { case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName)) + case Literal(fieldName, _) if fieldName != null => + // Handle UTF8String (Spark's internal string representation) + addRef(nameParts.head, Some(fieldName.toString)) case _ => addRef(nameParts.head, None) // Unknown field access } case _ => // Nested extraction we can't easily parse - conservative fallback diff --git a/core/src/test/scala/org/graphframes/SparkShimsSuite.scala b/core/src/test/scala/org/graphframes/SparkShimsSuite.scala new file mode 100644 index 000000000..c314f2c93 --- /dev/null +++ b/core/src/test/scala/org/graphframes/SparkShimsSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.graphframes + +import org.apache.spark.sql.functions.* +import org.apache.spark.sql.graphframes.SparkShims +import org.graphframes.lib.Pregel + +/** + * Unit tests for SparkShims.extractColumnReferences. + * + * These tests verify that column references are correctly extracted from various expression + * patterns, which is critical for the Pregel dst join optimization. The optimization skips a join + * when dst columns are not referenced, so we must correctly identify all column references. + */ +class SparkShimsSuite extends SparkFunSuite with GraphFrameTestSparkContext { + + // ============================================================================ + // Basic Column References + // ============================================================================ + + test("extractColumnReferences - simple dot notation") { + val refs = SparkShims.extractColumnReferences(spark, col("src.id")) + assert(refs === Map("src" -> Set("id"))) + } + + test("extractColumnReferences - Pregel.src helper") { + val refs = SparkShims.extractColumnReferences(spark, Pregel.src("rank")) + assert(refs === Map("src" -> Set("rank"))) + } + + test("extractColumnReferences - Pregel.dst helper") { + val refs = SparkShims.extractColumnReferences(spark, Pregel.dst("value")) + assert(refs === Map("dst" -> Set("value"))) + } + + test("extractColumnReferences - Pregel.edge helper") { + val refs = SparkShims.extractColumnReferences(spark, Pregel.edge("weight")) + assert(refs === Map("edge" -> Set("weight"))) + } + + test("extractColumnReferences - whole struct reference") { + val refs = SparkShims.extractColumnReferences(spark, col("dst")) + assert(refs === Map("dst" -> Set())) + } + + // ============================================================================ + // Bracket Notation + // ============================================================================ + + test("extractColumnReferences - bracket notation") { + val refs = SparkShims.extractColumnReferences(spark, col("src")("id")) + assert(refs === Map("src" -> Set("id"))) + } + + // ============================================================================ + // Complex Expressions with Multiple References + // ============================================================================ + + test("extractColumnReferences - arithmetic with multiple refs from same prefix") { + val refs = + SparkShims.extractColumnReferences(spark, Pregel.src("rank") / Pregel.src("outDegree")) + assert(refs === Map("src" -> Set("rank", "outDegree"))) + } + + test("extractColumnReferences - expression with src, dst, and edge") { + val refs = SparkShims.extractColumnReferences( + spark, + Pregel.src("value") + Pregel.dst("value") + Pregel.edge("weight")) + assert(refs === Map("src" -> Set("value"), "dst" -> Set("value"), "edge" -> Set("weight"))) + } + + test("extractColumnReferences - when/case expression") { + val refs = SparkShims.extractColumnReferences( + spark, + when(Pregel.dst("value") > Pregel.src("value"), Pregel.edge("weight"))) + assert(refs.contains("dst"), "Should detect dst reference") + assert(refs.contains("src"), "Should detect src reference") + assert(refs.contains("edge"), "Should detect edge reference") + } + + test("extractColumnReferences - coalesce with multiple refs") { + val refs = + SparkShims.extractColumnReferences(spark, coalesce(col("dst.value"), col("src.default"))) + assert(refs.contains("dst")) + assert(refs.contains("src")) + } + + // ============================================================================ + // Column Used as Map/Array Key - Critical Cases! + // These verify that foreach traversal catches column refs used as arguments + // ============================================================================ + + test("extractColumnReferences - column used as map key via element_at") { + // element_at(col("edge.weights"), col("dst.name")) + // Should detect BOTH "edge" and "dst" references + val refs = + SparkShims.extractColumnReferences(spark, element_at(col("edge.weights"), col("dst.name"))) + assert(refs.contains("edge"), "Should detect edge reference (the map)") + assert(refs.contains("dst"), "Should detect dst reference (used as map key)") + } + + test("extractColumnReferences - column used as array index via element_at") { + // element_at(col("edge.values"), col("dst.index")) + // Should detect BOTH "edge" and "dst" references + val refs = + SparkShims.extractColumnReferences(spark, element_at(col("edge.values"), col("dst.index"))) + assert(refs.contains("edge"), "Should detect edge reference (the array)") + assert(refs.contains("dst"), "Should detect dst reference (used as array index)") + } + + test("extractColumnReferences - column in nested function call") { + // concat(col("src.prefix"), col("dst.suffix")) + val refs = + SparkShims.extractColumnReferences(spark, concat(col("src.prefix"), col("dst.suffix"))) + assert(refs.contains("src")) + assert(refs.contains("dst")) + } + + test("extractColumnReferences - column in aggregate-like expression") { + // greatest(col("src.value"), col("dst.value"), col("edge.weight")) + val refs = SparkShims.extractColumnReferences( + spark, + greatest(col("src.value"), col("dst.value"), col("edge.weight"))) + assert(refs.contains("src")) + assert(refs.contains("dst")) + assert(refs.contains("edge")) + } + + // ============================================================================ + // Deeply Nested Struct Access + // ============================================================================ + + test("extractColumnReferences - deeply nested struct via bracket notation") { + // col("dst")("location")("city") - three levels deep + // Should still detect "dst" as a prefix + val refs = SparkShims.extractColumnReferences(spark, col("dst")("location")("city")) + assert(refs.contains("dst"), "Should detect dst prefix even for deeply nested access") + } + + test("extractColumnReferences - two-level nesting via dot notation") { + // col("dst.location.city") - parsed as UnresolvedAttribute(Seq("dst", "location", "city")) + val refs = SparkShims.extractColumnReferences(spark, col("dst.location.city")) + assert(refs.contains("dst")) + // We should get "location" as the first-level field (we only parse one level deep) + assert(refs("dst").contains("location")) + } + + // ============================================================================ + // Edge Cases - No Column References + // ============================================================================ + + test("extractColumnReferences - literal only expression") { + val refs = SparkShims.extractColumnReferences(spark, lit(42)) + assert(refs.isEmpty) + } + + test("extractColumnReferences - null literal") { + val refs = SparkShims.extractColumnReferences(spark, lit(null)) + assert(refs.isEmpty) + } + + test("extractColumnReferences - complex math with no columns") { + val refs = SparkShims.extractColumnReferences(spark, lit(1) + lit(2) * lit(3)) + assert(refs.isEmpty) + } + + test("extractColumnReferences - string literal") { + val refs = SparkShims.extractColumnReferences(spark, lit("hello")) + assert(refs.isEmpty) + } + + // ============================================================================ + // Mixed Expressions - Columns and Literals + // ============================================================================ + + test("extractColumnReferences - column plus literal") { + val refs = SparkShims.extractColumnReferences(spark, col("src.value") + lit(10)) + assert(refs === Map("src" -> Set("value"))) + } + + test("extractColumnReferences - conditional with literal fallback") { + val refs = SparkShims.extractColumnReferences( + spark, + when(col("dst.flag"), col("src.value")).otherwise(lit(0))) + assert(refs.contains("dst")) + assert(refs.contains("src")) + } +} diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index 6263d3891..84a4e828f 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -486,4 +486,100 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(received(2) === 3L) // vertex 2: edge 1->2, receives 1 + 2 = 3 assert(received(3) === 5L) // vertex 3: edge 2->3, receives 2 + 3 = 5 } + + // ============================================================================ + // Integration tests for complex expression patterns + // These verify that dst join is correctly performed when dst columns are used + // in non-trivial ways (map keys, array indices, conditionals, nested structs) + // ============================================================================ + + test("dst join required when dst column used in conditional") { + // when(Pregel.dst("flag"), Pregel.src("value")) - dst.flag requires the join + val vertices = Seq((0L, true, 10L), (1L, false, 20L), (2L, true, 30L)) + .toDF("id", "flag", "value") + val edges = Seq((0L, 1L), (1L, 2L)).toDF("src", "dst") + val graph = GraphFrame(vertices, edges) + + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) + .sendMsgToDst(when(Pregel.dst("flag"), Pregel.src("value")).otherwise(lit(null))) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Verify correct behavior: message only sent when dst.flag is true + val received = result.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no incoming + assert(received(1) === 0L) // vertex 1: dst.flag=false, so null message (filtered) + assert(received(2) === 20L) // vertex 2: dst.flag=true, receives src.value=20 + } + + test("dst join required when dst column used as map key") { + // Create edges with a map column, use dst vertex attribute as key + val vertices = Seq((0L, "a"), (1L, "b"), (2L, "a")).toDF("id", "key") + val edges = Seq((0L, 1L, Map("a" -> 10L, "b" -> 20L)), (1L, 2L, Map("a" -> 30L, "b" -> 40L))) + .toDF("src", "dst", "weights") + val graph = GraphFrame(vertices, edges) + + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) + // Use dst.key to look up value in edge.weights map + .sendMsgToDst(element_at(Pregel.edge("weights"), Pregel.dst("key"))) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Verify: dst.key is used to index into map + val received = result.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no incoming + assert(received(1) === 20L) // vertex 1: key="b", edge weights has b->20 + assert(received(2) === 30L) // vertex 2: key="a", edge weights has a->30 + } + + test("dst join required when dst column used as array index") { + // Create edges with array column, use dst vertex attribute as index + val vertices = Seq((0L, 1), (1L, 2), (2L, 1)).toDF("id", "idx") + val edges = Seq((0L, 1L, Array(100L, 200L)), (1L, 2L, Array(300L, 400L))) + .toDF("src", "dst", "values") + val graph = GraphFrame(vertices, edges) + + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn("received", lit(0L), coalesce(Pregel.msg, col("received"))) + // Use dst.idx to index into edge.values array (element_at is 1-based) + .sendMsgToDst(element_at(Pregel.edge("values"), Pregel.dst("idx"))) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Verify: dst.idx is used to index into array + val received = result.sort("id").select("received").as[Long].collect() + assert(received(0) === 0L) // vertex 0: no incoming + assert(received(1) === 200L) // vertex 1: idx=2, array element 2 = 200 + assert(received(2) === 300L) // vertex 2: idx=1, array element 1 = 300 + } + + test("dst join required for nested struct field access") { + // Create vertices with nested struct + val vertices = spark + .createDataFrame(Seq((0L, 1.0, 2.0), (1L, 3.0, 4.0), (2L, 5.0, 6.0))) + .toDF("id", "x", "y") + .selectExpr("id", "named_struct('x', x, 'y', y) as location") + + val edges = Seq((0L, 1L), (1L, 2L)).toDF("src", "dst") + val graph = GraphFrame(vertices, edges) + + val result = graph.pregel + .setMaxIter(1) + .withVertexColumn("received", lit(0.0), coalesce(Pregel.msg, col("received"))) + // Access nested field dst.location.x and src.location.y + .sendMsgToDst(Pregel.dst("location")("x") + Pregel.src("location")("y")) + .aggMsgs(sum(Pregel.msg)) + .run() + + // Verify: nested struct fields are accessed correctly + val received = result.sort("id").select("received").as[Double].collect() + assert(received(0) === 0.0 +- 1e-6) // vertex 0: no incoming + assert(received(1) === 5.0 +- 1e-6) // vertex 1: dst.location.x=3.0 + src.location.y=2.0 + assert(received(2) === 9.0 +- 1e-6) // vertex 2: dst.location.x=5.0 + src.location.y=4.0 + } } From 7e34f89adf09ba827286b27a29fc876c24c2df76 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 28 Feb 2026 12:19:58 -0800 Subject: [PATCH 07/13] refactor: remove Pregel references from SparkShims comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address peer feedback to keep SparkShims implementation-agnostic by removing specific algorithm references from comments while maintaining functional clarity. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../org/apache/spark/sql/graphframes/SparkShims.scala | 2 +- .../org/apache/spark/sql/graphframes/SparkShims.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala index 0338ba175..282e41d5d 100644 --- a/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-3/org/apache/spark/sql/graphframes/SparkShims.scala @@ -66,7 +66,7 @@ object SparkShims { } expr.expr.foreach { - // Unresolved: col("src.id") or Pregel.src("id") -> UnresolvedAttribute(Seq("src", "id")) + // Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id")) case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => addRef(nameParts.head, nameParts.lift(1)) diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index 68c00fbbd..0316fa535 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -69,7 +69,7 @@ object SparkShims { val converted = spark.asInstanceOf[ClassicSparkSession].converter(expr.node) converted.foreach { - // Unresolved: col("src.id") or Pregel.src("id") -> UnresolvedAttribute(Seq("src", "id")) + // Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id")) case UnresolvedAttribute(nameParts) if nameParts.nonEmpty => addRef(nameParts.head, nameParts.lift(1)) From f7ccf17114bca1cb79066a559d9460b19085b3db Mon Sep 17 00:00:00 2001 From: James Date: Sat, 28 Feb 2026 12:24:14 -0800 Subject: [PATCH 08/13] perf: optimize caching and partitioning when skipping dst join MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement peer feedback suggestions: 1. Move cache/checkpoint logic before expensive operations - persist srcWithEdges when skipping dst join to avoid recomputation 2. Change partitioning to src-only when dst join is skipped since dst partitioning is unnecessary 3. Move dst state detection earlier to enable these optimizations These changes provide additional performance improvements for algorithms like PageRank that only need source vertex data. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../scala/org/graphframes/lib/Pregel.scala | 58 ++++++++++--------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 9177c55a0..b872a2fb1 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -398,9 +398,33 @@ class Pregel(val graph: GraphFrame) ((initialAttributes :+ initialActiveVertexExpression.alias( Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) + // Automatic optimization: detect if destination vertex state is needed by analyzing + // the MESSAGE expressions only (not the target ID expressions, since dst.id is always + // available from the edge). If no message expression references dst.* columns, + // we can skip the second join entirely. + // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. + // Additionally, if the only dst field referenced is "id", we can still skip since + // dst.id is available from the edge's dst column. + val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } + val allDstRefs = messageExpressions.flatMap { expr => + SparkShims.extractColumnReferences(graph.spark, expr).get(DST) + } + val dstPrefixReferenced = allDstRefs.nonEmpty + val dstFieldsReferenced = allDstRefs.flatten.toSet + // We need the dst join if: + // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR + // 2. dst is referenced AND fields other than just "id" are accessed + // (empty set means whole struct access like col("dst"), which also needs the join) + val needsDstState = skipMessagesFromNonActiveVertices || + (dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))) + if (!needsDstState) { + logDebug( + "Optimization: skipping second join (dst state not required by message expressions)") + } + val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) - .repartition(col("edge_src"), col("edge_dst")) + .repartition((if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) .persist(intermediateStorageLevel) var iteration = 1 @@ -426,30 +450,6 @@ class Pregel(val graph: GraphFrame) if (requiredDstColumnsList.isEmpty) Seq(col("*")) else (Seq(ID, Pregel.ACTIVE_FLAG_COL) ++ requiredDstColumnsList).distinct.map(col) - // Automatic optimization: detect if destination vertex state is needed by analyzing - // the MESSAGE expressions only (not the target ID expressions, since dst.id is always - // available from the edge). If no message expression references dst.* columns, - // we can skip the second join entirely. - // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. - // Additionally, if the only dst field referenced is "id", we can still skip since - // dst.id is available from the edge's dst column. - val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } - val allDstRefs = messageExpressions.flatMap { expr => - SparkShims.extractColumnReferences(graph.spark, expr).get(DST) - } - val dstPrefixReferenced = allDstRefs.nonEmpty - val dstFieldsReferenced = allDstRefs.flatten.toSet - // We need the dst join if: - // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR - // 2. dst is referenced AND fields other than just "id" are accessed - // (empty set means whole struct access like col("dst"), which also needs the join) - val needsDstState = skipMessagesFromNonActiveVertices || - (dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))) - if (!needsDstState) { - logDebug( - "Optimization: skipping second join (dst state not required by message expressions)") - } - breakable { while (iteration <= maxIter) { logInfo(s"start Pregel iteration $iteration / $maxIter") @@ -457,10 +457,16 @@ class Pregel(val graph: GraphFrame) currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel)) // Build triplets: start with src vertex state joined with edges - val srcWithEdges = currentVertices + var srcWithEdges = currentVertices .select(struct(srcCols: _*).as(SRC)) .join(edges, Pregel.src(ID) === col("edge_src")) + // Optimization: persist srcWithEdges when skipping dst join to avoid recomputation + if (!needsDstState) { + srcWithEdges = srcWithEdges.persist(intermediateStorageLevel) + currRoundPersistent.enqueue(srcWithEdges) + } + // Only perform the second join (adding dst vertex state) if needed var tripletsDF = if (needsDstState) { srcWithEdges From d3e5b5a9c26e42c9a10e252befe64247199614e9 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 28 Feb 2026 12:37:59 -0800 Subject: [PATCH 09/13] fix: correct repartition syntax for conditional partitioning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix Scala syntax error in repartition call that was causing CI build failures. Use proper sequence expansion syntax for multiple column repartitioning. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- core/src/main/scala/org/graphframes/lib/Pregel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index b872a2fb1..235e1ac36 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -424,7 +424,7 @@ class Pregel(val graph: GraphFrame) val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) - .repartition((if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) + .repartition(if (needsDstState) Seq(col("edge_src"), col("edge_dst")): _* else Seq(col("edge_src")): _*) .persist(intermediateStorageLevel) var iteration = 1 From d92982cb5eefe9aa0380b26aae8ca90074554882 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 28 Feb 2026 13:04:00 -0800 Subject: [PATCH 10/13] fix: correct repartition syntax for conditional partitioning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix Scala syntax error in repartition call that was causing CI build failures. Use proper sequence expansion syntax for multiple column repartitioning. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- core/src/main/scala/org/graphframes/lib/Pregel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 235e1ac36..b872a2fb1 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -424,7 +424,7 @@ class Pregel(val graph: GraphFrame) val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) - .repartition(if (needsDstState) Seq(col("edge_src"), col("edge_dst")): _* else Seq(col("edge_src")): _*) + .repartition((if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) .persist(intermediateStorageLevel) var iteration = 1 From 1fc9b3cbc0e6fabcebfebe813a84f8db978a1060 Mon Sep 17 00:00:00 2001 From: James Date: Sat, 28 Feb 2026 13:27:22 -0800 Subject: [PATCH 11/13] style: apply scalafmt formatting to Pregel.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix formatting issues that were causing CI scalafmt checks to fail. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- core/src/main/scala/org/graphframes/lib/Pregel.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index b872a2fb1..1dffaac72 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -424,7 +424,8 @@ class Pregel(val graph: GraphFrame) val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) - .repartition((if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) + .repartition( + (if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) .persist(intermediateStorageLevel) var iteration = 1 From b24f943387779d610ecae4c9d0a2e38d0cecef94 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Wed, 11 Mar 2026 13:03:36 -0700 Subject: [PATCH 12/13] address Sem's PR comments --- core/src/main/scala/org/graphframes/lib/Pregel.scala | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 1dffaac72..dacac661f 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -424,8 +424,7 @@ class Pregel(val graph: GraphFrame) val edges = graph.edges .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) - .repartition( - (if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) + .repartition(col("edge_src")) .persist(intermediateStorageLevel) var iteration = 1 @@ -458,16 +457,10 @@ class Pregel(val graph: GraphFrame) currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel)) // Build triplets: start with src vertex state joined with edges - var srcWithEdges = currentVertices + val srcWithEdges = currentVertices .select(struct(srcCols: _*).as(SRC)) .join(edges, Pregel.src(ID) === col("edge_src")) - // Optimization: persist srcWithEdges when skipping dst join to avoid recomputation - if (!needsDstState) { - srcWithEdges = srcWithEdges.persist(intermediateStorageLevel) - currRoundPersistent.enqueue(srcWithEdges) - } - // Only perform the second join (adding dst vertex state) if needed var tripletsDF = if (needsDstState) { srcWithEdges From 1be371ca6bd30b9323edd955d2906b6628944074 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Thu, 12 Mar 2026 11:35:02 -0700 Subject: [PATCH 13/13] push triplet filtering when no dst join --- .../scala/org/graphframes/lib/Pregel.scala | 23 +++++++++++-------- .../org/graphframes/lib/PregelSuite.scala | 11 +++++---- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index dacac661f..4030f3d17 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -402,7 +402,6 @@ class Pregel(val graph: GraphFrame) // the MESSAGE expressions only (not the target ID expressions, since dst.id is always // available from the edge). If no message expression references dst.* columns, // we can skip the second join entirely. - // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. // Additionally, if the only dst field referenced is "id", we can still skip since // dst.id is available from the edge's dst column. val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } @@ -411,12 +410,10 @@ class Pregel(val graph: GraphFrame) } val dstPrefixReferenced = allDstRefs.nonEmpty val dstFieldsReferenced = allDstRefs.flatten.toSet - // We need the dst join if: - // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR - // 2. dst is referenced AND fields other than just "id" are accessed - // (empty set means whole struct access like col("dst"), which also needs the join) - val needsDstState = skipMessagesFromNonActiveVertices || - (dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))) + + // We need the dst join if dst is referenced AND fields other than just "id" are accessed + val needsDstState = + dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID)) if (!needsDstState) { logDebug( "Optimization: skipping second join (dst state not required by message expressions)") @@ -456,8 +453,15 @@ class Pregel(val graph: GraphFrame) val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]() currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel)) + // Prune non-active vertices early if skipMessagesFromNonActiveVertices + // is enabled and we don't need the dst state. + val srcVertices = + if (!needsDstState && skipMessagesFromNonActiveVertices) + currentVertices.filter(col(Pregel.ACTIVE_FLAG_COL)) + else currentVertices + // Build triplets: start with src vertex state joined with edges - val srcWithEdges = currentVertices + val srcWithEdges = srcVertices .select(struct(srcCols: _*).as(SRC)) .join(edges, Pregel.src(ID) === col("edge_src")) @@ -476,7 +480,8 @@ class Pregel(val graph: GraphFrame) .drop(col("edge_src"), col("edge_dst")) } - if (skipMessagesFromNonActiveVertices) { + // Only prune here if we didn't prune above. + if (needsDstState && skipMessagesFromNonActiveVertices) { tripletsDF = tripletsDF.filter( Pregel.src(Pregel.ACTIVE_FLAG_COL) || Pregel.dst(Pregel.ACTIVE_FLAG_COL)) } diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index 84a4e828f..8261d39d2 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -357,9 +357,10 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) } - test("automatic dst join NOT skipped when skipMessagesFromNonActiveVertices is enabled") { - // When skipMessagesFromNonActiveVertices is true, we need dst._pregel_is_active, - // so the second join must NOT be skipped even if message expressions don't use dst. + test("automatic dst join skipping with skipMessagesFromNonActiveVertices enabled") { + // When skipMessagesFromNonActiveVertices is true but message expressions don't + // reference dst columns, the dst join is still skipped. Active-vertex filtering + // is pushed before the src-edge join to reduce data volume. val n = 5 val verDF = (1 to n).toDF("id").repartition(3) @@ -370,8 +371,8 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { val graph = GraphFrame(verDF, edgeDF) - // This only uses Pregel.src("value"), but skipMessagesFromNonActiveVertices - // requires dst._pregel_is_active, so dst join should NOT be skipped + // This only uses Pregel.src("value") - dst join should be skipped, + // and active-vertex filtering is applied before the src-edge join. val resultDF = graph.pregel .setMaxIter(n - 1) .setSkipMessagesFromNonActiveVertices(true)