Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,86 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.annotation.nowarn
import scala.collection.mutable

object SparkShims {

/**
* Extracts all column references from a Column expression, returning a map from top-level
* prefix to the set of nested field names accessed under that prefix.
*
* For nested column references like "src.id" or "edge.weight", this returns Map("src" ->
* Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct),
* it returns Map("src" -> Set()).
*
* This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and
* resolved expressions (AttributeReference, GetStructField).
*
* Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such
* cases, the prefix is recorded with an empty field set, which causes callers to conservatively
* assume the entire struct is needed. This is the safe/correct fallback behavior.
*
* @param spark
* the SparkSession (unused in Spark 3, included for API compatibility with Spark 4)
* @param expr
* the Column expression to analyze
* @return
* a Map from column prefix to the set of nested field names accessed
*/
@nowarn
def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = {
val refs = mutable.Map.empty[String, mutable.Set[String]]

def addRef(prefix: String, field: Option[String]): Unit = {
val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String])
field.foreach(fields += _)
}

expr.expr.foreach {
// Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id"))
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
addRef(nameParts.head, nameParts.lift(1))

// Unresolved: col("src")("id") -> UnresolvedExtractValue
case UnresolvedExtractValue(child, extraction) =>
child match {
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
extraction match {
case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName))
case Literal(fieldName, _) if fieldName != null =>
// Handle UTF8String (Spark's internal string representation)
addRef(nameParts.head, Some(fieldName.toString))
case _ => addRef(nameParts.head, None) // Unknown field access
}
case _ => // Nested extraction we can't easily parse - conservative fallback
}

// Resolved: AttributeReference for top-level columns
case attr: AttributeReference =>
addRef(attr.name, None)

// Resolved: GetStructField for nested field access like struct.field
// Note: Only handles single-level nesting; deeper nesting falls through to default case
case GetStructField(child, _, Some(fieldName)) =>
child match {
case attr: AttributeReference => addRef(attr.name, Some(fieldName))
case _ => // Deeply nested struct access - conservative fallback (join will be used)
}

case _ => // ignore other expression types
}

refs.map { case (k, v) => k -> v.toSet }.toMap
}

/**
* Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to
* the column itself.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,90 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.ClassicConversions.*
import org.apache.spark.sql.classic.DataFrame as ClassicDataFrame
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.classic.ExpressionUtils
import org.apache.spark.sql.classic.SparkSession as ClassicSparkSession

import scala.collection.mutable

object SparkShims {

/**
* Extracts all column references from a Column expression, returning a map from top-level
* prefix to the set of nested field names accessed under that prefix.
*
* For nested column references like "src.id" or "edge.weight", this returns Map("src" ->
* Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct),
* it returns Map("src" -> Set()).
*
* This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and
* resolved expressions (AttributeReference, GetStructField).
*
* Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such
* cases, the prefix is recorded with an empty field set, which causes callers to conservatively
* assume the entire struct is needed. This is the safe/correct fallback behavior.
*
* @param spark
* the SparkSession (needed for expression conversion in Spark 4)
* @param expr
* the Column expression to analyze
* @return
* a Map from column prefix to the set of nested field names accessed
*/
def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = {
val refs = mutable.Map.empty[String, mutable.Set[String]]

def addRef(prefix: String, field: Option[String]): Unit = {
val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String])
field.foreach(fields += _)
}

val converted = spark.asInstanceOf[ClassicSparkSession].converter(expr.node)
converted.foreach {
// Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id"))
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
addRef(nameParts.head, nameParts.lift(1))

// Unresolved: col("src")("id") -> UnresolvedExtractValue
case UnresolvedExtractValue(child, extraction) =>
child match {
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
extraction match {
case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName))
case Literal(fieldName, _) if fieldName != null =>
// Handle UTF8String (Spark's internal string representation)
addRef(nameParts.head, Some(fieldName.toString))
case _ => addRef(nameParts.head, None) // Unknown field access
}
case _ => // Nested extraction we can't easily parse - conservative fallback
}

// Resolved: AttributeReference for top-level columns
case attr: AttributeReference =>
addRef(attr.name, None)

// Resolved: GetStructField for nested field access like struct.field
// Note: Only handles single-level nesting; deeper nesting falls through to default case
case GetStructField(child, _, Some(fieldName)) =>
child match {
case attr: AttributeReference => addRef(attr.name, Some(fieldName))
case _ => // Deeply nested struct access - conservative fallback (join will be used)
}

case _ => // ignore other expression types
}

refs.map { case (k, v) => k -> v.toSet }.toMap
}

/**
* Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to
* the column itself.
Expand Down
56 changes: 49 additions & 7 deletions core/src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.graphframes.SparkShims
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame.*
import org.graphframes.Logging
Expand Down Expand Up @@ -397,9 +398,30 @@ class Pregel(val graph: GraphFrame)
((initialAttributes :+ initialActiveVertexExpression.alias(
Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*)

// Automatic optimization: detect if destination vertex state is needed by analyzing
// the MESSAGE expressions only (not the target ID expressions, since dst.id is always
// available from the edge). If no message expression references dst.* columns,
// we can skip the second join entirely.
// Additionally, if the only dst field referenced is "id", we can still skip since
// dst.id is available from the edge's dst column.
val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr }
val allDstRefs = messageExpressions.flatMap { expr =>
SparkShims.extractColumnReferences(graph.spark, expr).get(DST)
}
val dstPrefixReferenced = allDstRefs.nonEmpty
val dstFieldsReferenced = allDstRefs.flatten.toSet

// We need the dst join if dst is referenced AND fields other than just "id" are accessed
val needsDstState =
dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))
if (!needsDstState) {
logDebug(
"Optimization: skipping second join (dst state not required by message expressions)")
}

val edges = graph.edges
.select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE))
.repartition(col("edge_src"), col("edge_dst"))
.repartition(col("edge_src"))
.persist(intermediateStorageLevel)

var iteration = 1
Expand Down Expand Up @@ -431,15 +453,35 @@ class Pregel(val graph: GraphFrame)
val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]()
currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel))

var tripletsDF = currentVertices
// Prune non-active vertices early if skipMessagesFromNonActiveVertices
// is enabled and we don't need the dst state.
val srcVertices =
if (!needsDstState && skipMessagesFromNonActiveVertices)
currentVertices.filter(col(Pregel.ACTIVE_FLAG_COL))
else currentVertices

// Build triplets: start with src vertex state joined with edges
val srcWithEdges = srcVertices
.select(struct(srcCols: _*).as(SRC))
.join(edges, Pregel.src(ID) === col("edge_src"))
.join(
currentVertices.select(struct(dstCols: _*).as(DST)),
col("edge_dst") === Pregel.dst(ID))
.drop(col("edge_src"), col("edge_dst"))

if (skipMessagesFromNonActiveVertices) {
// Only perform the second join (adding dst vertex state) if needed
var tripletsDF = if (needsDstState) {
srcWithEdges
.join(
currentVertices.select(struct(dstCols: _*).as(DST)),
col("edge_dst") === Pregel.dst(ID))
.drop(col("edge_src"), col("edge_dst"))
} else {
// Skip second join - dst state not needed by any message expression.
// Create a minimal dst struct with just the id from edge_dst for sendMsgToDst to work.
srcWithEdges
.withColumn(DST, struct(col("edge_dst").as(ID)))
.drop(col("edge_src"), col("edge_dst"))
}

// Only prune here if we didn't prune above.
if (needsDstState && skipMessagesFromNonActiveVertices) {
tripletsDF = tripletsDF.filter(
Pregel.src(Pregel.ACTIVE_FLAG_COL) || Pregel.dst(Pregel.ACTIVE_FLAG_COL))
}
Expand Down
Loading
Loading