From d38bd2b67858c013bcdc5f140fd3a2732c846258 Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Wed, 24 Sep 2025 12:52:09 +0200
Subject: [PATCH] better TriangleCount

---
 .../examples/TriangleCountExample.java        |  91 +++++++++++++++
 .../org/graphframes/lib/TriangleCount.scala   | 110 +++++++++++-------
 2 files changed, 162 insertions(+), 39 deletions(-)
 create mode 100644 core/src/main/java/org/graphframes/examples/TriangleCountExample.java

diff --git a/core/src/main/java/org/graphframes/examples/TriangleCountExample.java b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java
new file mode 100644
index 000000000..ff8fc2089
--- /dev/null
+++ b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java
@@ -0,0 +1,91 @@
+package org.graphframes.examples;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.StorageLevel;
+import org.graphframes.GraphFrame;
+import org.graphframes.lib.TriangleCount;
+
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+/**
+ * The TriangleCountExample class demonstrates how to use the GraphFrames library in Apache Spark
+ * to count triangles in a graph dataset. A triangle in a graph is defined as a set of
+ * three interconnected vertices.
+ * <p>
+ * This example uses graphs from the LDBC Graphalytics benchmark datasets.
+ * The first argument is the name of the benchmark dataset, the second argument is the path where datasets are stored.
+ */
+public class TriangleCountExample {
+    public static void main(String[] args) {
+        String benchmarkName;
+        if (args.length > 0) {
+            benchmarkName = args[0];
+        } else {
+            benchmarkName = "kgs";
+        }
+
+        Path resourcesPath;
+        if (args.length > 1) {
+            resourcesPath = Paths.get(args[1]);
+        } else {
+            resourcesPath = Paths.get("/tmp/ldbc_graphalitics_datesets");
+        }
+
+        Path caseRoot = resourcesPath.resolve(benchmarkName);
+        SparkConf sparkConf = new SparkConf()
+                .setAppName("TriangleCountExample")
+                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
+        SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate();
+        SparkContext context = spark.sparkContext();
+        context.setLogLevel("ERROR");
+        context.setCheckpointDir("/tmp/graphframes-checkpoints");
+
+        LDBCUtils.downloadLDBCIfNotExists(resourcesPath, benchmarkName);
+        StructField[] edgeFields = new StructField[]{
+                new StructField("src", DataTypes.LongType, true, Metadata.empty()),
+                new StructField("dst", DataTypes.LongType, true, Metadata.empty())
+        };
+        Dataset<Row> edges = spark.read()
+                .format("csv")
+                .option("header", "false")
+                .option("delimiter", " ")
+                .schema(new StructType(edgeFields))
+                .load(caseRoot.resolve(benchmarkName + ".e").toString())
+                .persist(StorageLevel.MEMORY_AND_DISK_SER());
+        System.out.println("Edges loaded: " + edges.count());
+
+        StructField[] vertexFields = new StructField[]{
+                new StructField("id", DataTypes.LongType, true, Metadata.empty()),
+        };
+        Dataset<Row> vertices = spark.read()
+                .format("csv")
+                .option("header", "false")
+                .option("delimiter", " ")
+                .schema(new StructType(vertexFields))
+                .load(caseRoot.resolve(benchmarkName + ".v").toString())
+                .persist(StorageLevel.MEMORY_AND_DISK_SER());
+        System.out.println("Vertices loaded: " + vertices.count());
+
+        var start = System.currentTimeMillis();
+        GraphFrame graph = GraphFrame.apply(vertices, edges);
+        TriangleCount counter = graph.triangleCount();
+        Dataset<Row> triangles = counter.run();
+
+        triangles.show(20, false);
+        // Each triangle is counted once at each of its three vertices, so the sum is 3x the total.
+        long triangleCount = triangles.select(functions.sum("count")).first().getLong(0) / 3;
+        System.out.println("Found triangles: " + triangleCount);
+        var end = System.currentTimeMillis();
+        System.out.println("Total running time in seconds: " + (end - start) / 1000.0);
+    }
+}
diff --git a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala
index a3bec9402..41adb729f 100644
--- a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala
+++ b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala
@@ -18,17 +18,11 @@
 package org.graphframes.lib
 
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.array
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.explode
-import org.apache.spark.sql.functions.when
+import org.apache.spark.sql.functions._
+import org.apache.spark.storage.StorageLevel
 import org.graphframes.GraphFrame
-import org.graphframes.GraphFrame.DST
-import org.graphframes.GraphFrame.ID
-import org.graphframes.GraphFrame.LONG_DST
-import org.graphframes.GraphFrame.LONG_SRC
-import org.graphframes.GraphFrame.SRC
-import org.graphframes.GraphFrame.quote
+import org.graphframes.Logging
+import org.graphframes.WithIntermediateStorageLevel
 
 /**
  * Computes the number of triangles passing through each vertex.
@@ -36,48 +30,86 @@ import org.graphframes.GraphFrame.quote
  * This algorithm ignores edge direction; i.e., all edges are treated as undirected. In a
  * multigraph, duplicate edges will be counted only once.
  *
- * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides a
- * graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize the graph to
- * put it in this format.
+ * **WARNING** This implementation is based on intersections of neighbor sets, which requires
+ * collecting both SRC and DST neighbors per edge! This will blow up memory in case the graph
+ * contains very high-degree nodes (power-law networks). Consider sampling strategies for that
+ * case!
  *
  * The returned DataFrame contains all the original vertex information and one additional column:
  *  - count (`LongType`): the count of triangles
  */
-class TriangleCount private[graphframes] (private val graph: GraphFrame) extends Arguments {
+class TriangleCount private[graphframes] (private val graph: GraphFrame)
+    extends Arguments
+    with Serializable
+    with WithIntermediateStorageLevel {
 
   def run(): DataFrame = {
-    TriangleCount.run(graph)
+    TriangleCount.run(graph, intermediateStorageLevel)
   }
 }
 
-private object TriangleCount {
+private object TriangleCount extends Logging {
+  import org.graphframes.GraphFrame._
 
-  private def run(graph: GraphFrame): DataFrame = {
-    // Dedup edges by flipping them to have LONG_SRC < LONG_DST
-    // TODO (when we drop support for Spark 1.4): Use functions greatest, smallest instead of UDFs
-    val dedupedE = graph.indexedEdges
-      .filter(s"$LONG_SRC != $LONG_DST")
-      .selectExpr(
-        s"if($LONG_SRC < $LONG_DST, $SRC, $DST) as $SRC",
-        s"if($LONG_SRC < $LONG_DST, $DST, $SRC) as $DST")
-      .dropDuplicates(Seq(SRC, DST))
-    val g2 = GraphFrame(graph.vertices, dedupedE)
+  private def prepareGraph(graph: GraphFrame): GraphFrame = {
+    // Dedup edges by flipping them to have SRC < DST
+    // Remove self-loops
+    val dedupedE = graph.edges
+      .filter(col(SRC) =!= col(DST))
+      .select(
+        when(col(SRC) < col(DST), col(SRC)).otherwise(col(DST)).as(SRC),
+        when(col(SRC) < col(DST), col(DST)).otherwise(col(SRC)).as(DST))
+      .distinct()
 
-    // Because SRC < DST, there exists only one type of triangles:
-    // - Non-cycle with one edge flipped. These are counted 1 time each by motif finding.
-    val triangles = g2.find("(a)-[]->(b); (b)-[]->(c); (a)-[]->(c)")
+    // Prepare the graph with no isolated vertices.
+    GraphFrame(graph.vertices.select(ID), dedupedE).dropIsolatedVertices()
+  }
+
+  private def run(graph: GraphFrame, intermediateStorageLevel: StorageLevel): DataFrame = {
+    val g2 = prepareGraph(graph)
+
+    val verticesWithNeighbors = g2.aggregateMessages
+      .setIntermediateStorageLevel(intermediateStorageLevel)
+      .sendToSrc(AggregateMessages.dst(ID))
+      .sendToDst(AggregateMessages.src(ID))
+      .agg(collect_set(AggregateMessages.msg).alias("neighbors"))
+
+    val triangles = verticesWithNeighbors
+      .select(col(ID), col("neighbors").alias("src_set"))
+      .join(g2.edges, col(ID) === col(SRC))
+      .drop(ID)
+      .join(
+        verticesWithNeighbors.select(col(ID), col("neighbors").alias("dst_set")),
+        col(ID) === col(DST))
+      .drop(ID)
+      // Count of common neighbors of SRC and DST
+      .withColumn("triplets", array_size(array_intersect(col("src_set"), col("dst_set"))))
+      .filter(col("triplets") > lit(0))
+      .persist(intermediateStorageLevel)
+
+    val srcTriangles = triangles.groupBy(SRC).agg(sum(col("triplets")).alias("src_triplets"))
+    val dstTriangles = triangles.groupBy(DST).agg(sum(col("triplets")).alias("dst_triplets"))
 
-    val triangleCounts = triangles
-      .select(explode(array(col("a.id"), col("b.id"), col("c.id"))).as(ID))
-      .groupBy(ID)
-      .count()
+    val result = graph.vertices
+      .join(srcTriangles, col(ID) === col(SRC), "left_outer")
+      .join(dstTriangles, col(ID) === col(DST), "left_outer")
+      // Each triangle counted twice, so divide by 2.
+      .withColumn(
+        COUNT_ID,
+        floor(
+          when(col("src_triplets").isNull && col("dst_triplets").isNull, lit(0))
+            .when(col("src_triplets").isNull, col("dst_triplets"))
+            .when(col("dst_triplets").isNull, col("src_triplets"))
+            .otherwise(col("src_triplets") + col("dst_triplets")) / lit(2)))
+      // Keep the documented contract: original vertex columns plus COUNT_ID only.
+      .drop(SRC, DST, "src_triplets", "dst_triplets")
 
-    val v = graph.vertices
-    val countsCol = when(col("count").isNull, 0L).otherwise(col("count"))
-    val newV = v
-      .join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer")
-      .select((countsCol.as(COUNT_ID) +: v.columns.map(quote).map(v.apply)).toSeq: _*)
-    newV
+    result.persist(intermediateStorageLevel)
+    result.count()
+    verticesWithNeighbors.unpersist()
+    triangles.unpersist()
+    resultIsPersistent()
+    result
   }
 
   private val COUNT_ID = "count"