Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package org.graphframes.examples;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.graphframes.GraphFrame;
import org.graphframes.lib.TriangleCount;

import java.nio.file.Path;
import java.nio.file.Paths;

/**
* The TriangleCount class demonstrates how to use the GraphFrames library in Apache Spark
* to count triangles in a graph dataset. A triangle in a graph is defined as a set of
* three interconnected vertices.
* <p>
* This examples uses graphs from the LDBC Graphalytics benchmark datasets.
* The first argument is the name of the benchmark dataset, the second argument is the path where datasets are stored.
*/
public class TriangleCountExample {
public static void main(String[] args) {
String benchmarkName;
if (args.length > 0) {
benchmarkName = args[0];
} else {
benchmarkName = "kgs";
}

Path resourcesPath;
if (args.length > 1) {
resourcesPath = Paths.get(args[1]);
} else {
resourcesPath = Paths.get("/tmp/ldbc_graphalitics_datesets");
}

Path caseRoot = resourcesPath.resolve(benchmarkName);
SparkConf sparkConf = new SparkConf()
.setAppName("TriangleCountExample")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate();
SparkContext context = spark.sparkContext();
context.setLogLevel("ERROR");
context.setCheckpointDir("/tmp/graphframes-checkpoints");

LDBCUtils.downloadLDBCIfNotExists(resourcesPath, benchmarkName);
StructField[] edgeFields = new StructField[]{
new StructField("src", DataTypes.LongType, true, Metadata.empty()),
new StructField("dst", DataTypes.LongType, true, Metadata.empty())
};
Dataset<Row> edges = spark.read()
.format("csv")
.option("header", "false")
.option("delimiter", " ")
.schema(new StructType(edgeFields))
.load(caseRoot.resolve(benchmarkName + ".e").toString())
.persist(StorageLevel.MEMORY_AND_DISK_SER());
System.out.println("Edges loaded: " + edges.count());

StructField[] vertexFields = new StructField[]{
new StructField("id", DataTypes.LongType, true, Metadata.empty()),
};
Dataset<Row> vertices = spark.read()
.format("csv")
.option("header", "false")
.option("delimiter", " ")
.schema(new StructType(vertexFields))
.load(caseRoot.resolve(benchmarkName + ".v").toString())
.persist(StorageLevel.MEMORY_AND_DISK_SER());
System.out.println("Vertices loaded: " + vertices.count());

var start = System.currentTimeMillis();
GraphFrame graph = GraphFrame.apply(vertices, edges);
TriangleCount counter = graph.triangleCount();
Dataset<Row> triangles = counter.run();

triangles.show(20, false);
long triangleCount = triangles.select(functions.sum("count")).first().getLong(0);
System.out.println("Found triangles: " + triangleCount);
var end = System.currentTimeMillis();
System.out.println("Total running time in seconds: " + (end - start) / 1000.0);
}
}
108 changes: 69 additions & 39 deletions core/src/main/scala/org/graphframes/lib/TriangleCount.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,66 +18,96 @@
package org.graphframes.lib

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame.DST
import org.graphframes.GraphFrame.ID
import org.graphframes.GraphFrame.LONG_DST
import org.graphframes.GraphFrame.LONG_SRC
import org.graphframes.GraphFrame.SRC
import org.graphframes.GraphFrame.quote
import org.graphframes.Logging
import org.graphframes.WithIntermediateStorageLevel

/**
* Computes the number of triangles passing through each vertex.
*
* This algorithm ignores edge direction; i.e., all edges are treated as undirected. In a
* multigraph, duplicate edges will be counted only once.
*
* Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides a
* graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize the graph to
* put it in this format.
* **WARNING** This implementation is based on intersections of neighbor sets, which requires
* collecting both SRC and DST neighbors per edge! This will blow up memory in case the graph
* contains very high-degree nodes (power-law networks). Consider sampling strategies for that
* case!
*
* The returned DataFrame contains all the original vertex information and one additional column:
* - count (`LongType`): the count of triangles
*/
class TriangleCount private[graphframes] (private val graph: GraphFrame) extends Arguments {
class TriangleCount private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Serializable
with WithIntermediateStorageLevel {

def run(): DataFrame = {
TriangleCount.run(graph)
TriangleCount.run(graph, intermediateStorageLevel)
}
}

private object TriangleCount {
private object TriangleCount extends Logging {
import org.graphframes.GraphFrame.*

private def run(graph: GraphFrame): DataFrame = {
// Dedup edges by flipping them to have LONG_SRC < LONG_DST
// TODO (when we drop support for Spark 1.4): Use functions greatest, smallest instead of UDFs
val dedupedE = graph.indexedEdges
.filter(s"$LONG_SRC != $LONG_DST")
.selectExpr(
s"if($LONG_SRC < $LONG_DST, $SRC, $DST) as $SRC",
s"if($LONG_SRC < $LONG_DST, $DST, $SRC) as $DST")
.dropDuplicates(Seq(SRC, DST))
val g2 = GraphFrame(graph.vertices, dedupedE)
private def prepareGraph(graph: GraphFrame): GraphFrame = {
// Dedup edges by flipping them to have SRC < DST
// Remove self-loops
val dedupedE = graph.edges
.filter(col(SRC) =!= col(DST))
.select(
when(col(SRC) < col(DST), col(SRC)).otherwise(col(DST)).as(SRC),
when(col(SRC) < col(DST), col(DST)).otherwise(col(SRC)).as(DST))
.distinct()

// Because SRC < DST, there exists only one type of triangles:
// - Non-cycle with one edge flipped. These are counted 1 time each by motif finding.
val triangles = g2.find("(a)-[]->(b); (b)-[]->(c); (a)-[]->(c)")
// Prepare the graph with no isolated vertices.
GraphFrame(graph.vertices.select(ID), dedupedE).dropIsolatedVertices()
}

private def run(graph: GraphFrame, intermediateStorageLevel: StorageLevel): DataFrame = {
Comment thread
SemyonSinchenko marked this conversation as resolved.
val g2 = prepareGraph(graph)

val verticesWithNeighbors = g2.aggregateMessages
.setIntermediateStorageLevel(intermediateStorageLevel)
.sendToSrc(AggregateMessages.dst(ID))
.sendToDst(AggregateMessages.src(ID))
.agg(collect_set(AggregateMessages.msg).alias("neighbors"))

val triangles = verticesWithNeighbors
.select(col(ID), col("neighbors").alias("src_set"))
.join(g2.edges, col(ID) === col(SRC))
.drop(ID)
.join(
verticesWithNeighbors.select(col(ID), col("neighbors").alias("dst_set")),
col(ID) === col(DST))
.drop(ID)
// Count of common neighbors of SRC and DST
.withColumn("triplets", array_size(array_intersect(col("src_set"), col("dst_set"))))
.filter(col("triplets") > lit(0))
.persist(intermediateStorageLevel)

val srcTriangles = triangles.groupBy(SRC).agg(sum(col("triplets")).alias("src_triplets"))
val dstTriangles = triangles.groupBy(DST).agg(sum(col("triplets")).alias("dst_triplets"))

val triangleCounts = triangles
.select(explode(array(col("a.id"), col("b.id"), col("c.id"))).as(ID))
.groupBy(ID)
.count()
val result = graph.vertices
.join(srcTriangles, col(ID) === col(SRC), "left_outer")
.join(dstTriangles, col(ID) === col(DST), "left_outer")
// Each triangle counted twice, so divide by 2.
.withColumn(
COUNT_ID,
floor(
when(col("src_triplets").isNull && col("dst_triplets").isNull, lit(0))
.when(col("src_triplets").isNull, col("dst_triplets"))
.when(col("dst_triplets").isNull, col("src_triplets"))
.otherwise(col("src_triplets") + col("dst_triplets")) / lit(2)))

val v = graph.vertices
val countsCol = when(col("count").isNull, 0L).otherwise(col("count"))
val newV = v
.join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer")
.select((countsCol.as(COUNT_ID) +: v.columns.map(quote).map(v.apply)).toSeq: _*)
newV
result.persist(intermediateStorageLevel)
result.count()
verticesWithNeighbors.unpersist()
triangles.unpersist()
resultIsPersistent()
result
}

private val COUNT_ID = "count"
Expand Down