Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions connect/src/main/protobuf/graphframes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ message Pregel {
optional ColumnOrExpression initial_active_expr = 13;
optional ColumnOrExpression update_active_expr = 14;
optional bool skip_messages_from_non_active = 15;
// Required columns for triplet construction (memory optimization)
// Column names separated by comma
optional string required_src_columns = 16;
optional string required_dst_columns = 17;
Comment thread
SemyonSinchenko marked this conversation as resolved.
}

message ShortestPaths {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ object GraphFramesConnectUtils {
pregel = pregel.setEarlyStopping(pregelProto.getEarlyStopping)
}

// Handle required columns for triplet optimization (comma-separated)
if (pregelProto.hasRequiredSrcColumns) {
val cols =
pregelProto.getRequiredSrcColumns.split(",").map(_.trim).filter(_.nonEmpty).toSeq
if (cols.nonEmpty) pregel = pregel.requiredSrcColumns(cols.head, cols.tail: _*)
}

if (pregelProto.hasRequiredDstColumns) {
val cols =
pregelProto.getRequiredDstColumns.split(",").map(_.trim).filter(_.nonEmpty).toSeq
if (cols.nonEmpty) pregel = pregel.requiredDstColumns(cols.head, cols.tail: _*)
}

pregel.run()
}
case proto.GraphFramesAPI.MethodCase.SHORTEST_PATHS => {
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/org/graphframes/lib/DetectingCycles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ object DetectingCycles {
.withVertexColumn(storedSeqCol, initSequences, updateSequences)
.withVertexColumn(foundSeqCol, foundSequences, updateFound)
.aggMsgs(filterOutSequences)
// Memory optimization: only include required columns in triplets
// For cycle detection, we only need the sequences from source vertex
// and just the ID from destination vertex (ID is always included)
.requiredSrcColumns(storedSeqCol)
.run()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ private object LabelPropagation {
.setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg))
.setUseLocalCheckpoints(useLocalCheckpoints)
.setIntermediateStorageLevel(intermediateStorageLevel)
// Memory optimization: only include required columns in triplets
.requiredSrcColumns(LABEL_ID)
.requiredDstColumns(LABEL_ID)

if (isDirected) {
pregel = pregel.sendMsgToDst(Pregel.src(LABEL_ID))
Expand Down
66 changes: 64 additions & 2 deletions core/src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class Pregel(val graph: GraphFrame)
private val sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)]
private var aggMsgsCol: Column = null

// Required columns for source and destination vertices in triplets
// When empty, all columns are selected (default behavior)
private val requiredSrcColumnsList = collection.mutable.ListBuffer.empty[String]
private val requiredDstColumnsList = collection.mutable.ListBuffer.empty[String]

/** Sets the max number of iterations (default: 10). */
def setMaxIter(value: Int): this.type = {
maxIter = value
Expand Down Expand Up @@ -291,6 +296,54 @@ class Pregel(val graph: GraphFrame)
this
}

/**
* Specifies which source vertex columns are required when constructing triplets.
*
* By default, all source vertex columns are included in triplets, which can create large
* intermediate datasets for algorithms with significant state (e.g., cycle detection, random
* walks). Use this method to reduce memory usage by specifying only the columns that are
* actually needed by the sendMsgToSrc and sendMsgToDst expressions.
*
* The ID column and the active flag column (if used) are always included automatically.
*
* @param colName
* the first required source vertex column name
* @param colNames
* additional required source vertex column names
* @see
* [[requiredDstColumns]]
*/
def requiredSrcColumns(colName: String, colNames: String*): this.type = {
requiredSrcColumnsList.clear()
requiredSrcColumnsList += colName
requiredSrcColumnsList ++= colNames
this
}

/**
* Specifies which destination vertex columns are required when constructing triplets.
*
* By default, all destination vertex columns are included in triplets, which can create large
* intermediate datasets for algorithms with significant state (e.g., cycle detection, random
* walks). Use this method to reduce memory usage by specifying only the columns that are
* actually needed by the sendMsgToSrc and sendMsgToDst expressions.
*
* The ID column and the active flag column (if used) are always included automatically.
*
* @param colName
* the first required destination vertex column name
* @param colNames
* additional required destination vertex column names
* @see
* [[requiredSrcColumns]]
*/
def requiredDstColumns(colName: String, colNames: String*): this.type = {
requiredDstColumnsList.clear()
requiredDstColumnsList += colName
requiredDstColumnsList ++= colNames
this
}

/**
* Defines how messages are aggregated after grouped by target vertex IDs.
*
Expand Down Expand Up @@ -364,16 +417,25 @@ class Pregel(val graph: GraphFrame)
}
}

// Columns to include in triplet structs (ID + active flag always included if specified)
val srcCols =
if (requiredSrcColumnsList.isEmpty) Seq(col("*"))
else (Seq(ID, Pregel.ACTIVE_FLAG_COL) ++ requiredSrcColumnsList).distinct.map(col)
val dstCols =
if (requiredDstColumnsList.isEmpty) Seq(col("*"))
else (Seq(ID, Pregel.ACTIVE_FLAG_COL) ++ requiredDstColumnsList).distinct.map(col)

breakable {
while (iteration <= maxIter) {
logInfo(s"start Pregel iteration $iteration / $maxIter")
val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]()
currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel))

var tripletsDF = currentVertices
.select(struct(col("*")).as(SRC))
.select(struct(srcCols: _*).as(SRC))
.join(edges, Pregel.src(ID) === col("edge_src"))
.join(
currentVertices.select(struct(col("*")).as(DST)),
currentVertices.select(struct(dstCols: _*).as(DST)),
col("edge_dst") === Pregel.dst(ID))
.drop(col("edge_src"), col("edge_dst"))

Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ private object ShortestPaths extends Logging {
.setSkipMessagesFromNonActiveVertices(true)
.setCheckpointInterval(checkpointInterval)
.setUseLocalCheckpoints(useLocalCheckpoints)
// Memory optimization: only include required columns in triplets
.requiredSrcColumns(DISTANCE_ID)
.requiredDstColumns(DISTANCE_ID)

// Experimental feature
if (isDirected) {
Expand Down
123 changes: 123 additions & 0 deletions core/src/test/scala/org/graphframes/lib/PregelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,127 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
.map(r => r.getAs[Long]("id") -> r.getAs[Int]("newColumn"))
.toMap === Map(1L -> 2, 2L -> 1, 3L -> 2, 4L -> 1))
}

test("requiredSrcColumns - only specified columns are used in triplets") {
// Test that requiredSrcColumns correctly limits the columns in triplets
// This is a memory optimization test - we verify the result is correct
// with only required source columns

val edges = Seq((0L, 1L), (1L, 2L), (2L, 4L), (2L, 0L), (3L, 4L), (4L, 0L), (4L, 2L))
.toDF("src", "dst")
.cache()
val vertices = GraphFrame.fromEdges(edges).outDegrees.cache()
val numVertices = vertices.count()
val graph = GraphFrame(vertices, edges)

val alpha = 0.15
// PageRank only needs "rank" and "outDegree" from source vertex
val ranks = graph.pregel
.setMaxIter(5)
.withVertexColumn(
"rank",
lit(1.0 / numVertices),
coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices)
.sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
.aggMsgs(sum(Pregel.msg))
.requiredSrcColumns("rank", "outDegree")
.run()

val result = ranks
.sort(col("id"))
.select("rank")
.as[Double]
.collect()
assert(result.sum === 1.0 +- 1e-6)
val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197)
result.zip(expected).foreach { case (r, e) =>
assert(r === e +- 1e-3)
}
}

test("requiredDstColumns - only specified columns are used in triplets") {
// Test that requiredDstColumns correctly limits the columns in triplets
// Reverse chain propagation where we only need dst("value") from destination

val n = 5
val verDF = (1 to n).toDF("id").repartition(3)
val edgeDF = (1 until n)
.map(x => (x + 1, x))
.toDF("src", "dst")
.repartition(3)

val graph = GraphFrame(verDF, edgeDF)

val resultDF = graph.pregel
.setMaxIter(n - 1)
.withVertexColumn(
"value",
when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
.sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value")))
.aggMsgs(max(Pregel.msg))
.requiredDstColumns("value") // Only need "value" from destination
.run()

assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
}

test("requiredSrcColumns and requiredDstColumns together") {
// Test using both requiredSrcColumns and requiredDstColumns
// Chain propagation where we need "value" from both src and dst

val n = 5
val verDF = (1 to n).toDF("id").repartition(3)
val edgeDF = (1 until n)
.map(x => (x, x + 1))
.toDF("src", "dst")
.repartition(3)

val graph = GraphFrame(verDF, edgeDF)

val resultDF = graph.pregel
.setMaxIter(n - 1)
.withVertexColumn(
"value",
when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
.sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
.aggMsgs(max(Pregel.msg))
.requiredSrcColumns("value") // Only need "value" from source
.requiredDstColumns("value") // Only need "value" from destination
.run()

assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1))
}

test("requiredSrcColumns with empty list uses all columns (default behavior)") {
// Verify that not calling requiredSrcColumns means all columns are used
// This is the same as the original page rank test

val edges = Seq((0L, 1L), (1L, 2L), (2L, 4L), (2L, 0L), (3L, 4L), (4L, 0L), (4L, 2L))
.toDF("src", "dst")
.cache()
val vertices = GraphFrame.fromEdges(edges).outDegrees.cache()
val numVertices = vertices.count()
val graph = GraphFrame(vertices, edges)

val alpha = 0.15
val ranks = graph.pregel
.setMaxIter(5)
.withVertexColumn(
"rank",
lit(1.0 / numVertices),
coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices)
.sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
.aggMsgs(sum(Pregel.msg))
// No requiredSrcColumns or requiredDstColumns - should use all columns
.run()

val result = ranks
.sort(col("id"))
.select("rank")
.as[Double]
.collect()
assert(result.sum === 1.0 +- 1e-6)
}
}
30 changes: 30 additions & 0 deletions docs/src/04-user-guide/10-pregel.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,36 @@ val edgeWeight = Pregel.edge("weight")

Under the hood, the passed name of the column will be resolved to get the corresponding element of the triplet structs.

### Memory Optimization for Triplets

By default, all vertex columns are included when constructing triplets. For algorithms with large per-vertex state (e.g., cycle detection storing sequences, random walks), this can create huge intermediate datasets in memory.

To reduce memory usage, you can specify only the columns that are actually needed using `requiredSrcColumns` and `requiredDstColumns`:

```scala
graph.pregel
.withVertexColumn("distances", ...)
.sendMsgToDst(Pregel.src("distances")) // Only needs "distances" from source
.requiredSrcColumns("distances") // Only include "distances" in src struct
.requiredDstColumns("distances") // Only include "distances" in dst struct
.aggMsgs(...)
.run()
```

In Python:

```python
graph.pregel \
.withVertexColumn("distances", ...) \
.sendMsgToDst(Pregel.src("distances")) \
.required_src_columns("distances") \
.required_dst_columns("distances") \
.aggMsgs(...) \
.run()
```

The `id` column and the active flag column (if used) are always included automatically, so you don't need to specify them.

### Sending Messages

GraphFrames Pregel API support arbitrary number of messages per vertex. Inside the Pregel API **graphs are always considered directed**. This means that if a vertex has an outgoing edge to another vertex, then the message will be sent to the destination vertex. To emulate the behavior of the undirected graph, the user can send the same message to both the source and the destination vertex.
Expand Down
Loading