diff --git a/.gitignore b/.gitignore
index a07973c1e..9881378cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ src_managed/
project/boot/
project/plugins/project/
.bsp
+.metals
# intellij
.idea/
diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 000000000..347451104
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1,15 @@
+# The content of this file is copied from the Apache Spark Project
+
+align = none
+align.openParenDefnSite = false
+align.openParenCallSite = false
+align.tokens = []
+importSelectors = "singleLine"
+optIn = {
+ configStyleArguments = false
+}
+danglingParentheses.preset = false
+docstrings.style = Asterisk
+maxColumn = 98
+runner.dialect = scala213
+version = 3.8.5
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 061901717..4ee4d9bd5 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,4 +1,4 @@
-import ReleaseTransformations._
+import ReleaseTransformations.*
lazy val sparkVer = sys.props.getOrElse("spark.version", "3.5.4")
lazy val sparkBranch = sparkVer.substring(0, 3)
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 46028c336..feb5a0677 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -8,3 +8,4 @@ ThisBuild / libraryDependencySchemes ++= Seq(
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.10")
addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1")
+addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4")
\ No newline at end of file
diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala
index c9a221a6a..d94305364 100644
--- a/src/main/scala/org/graphframes/GraphFrame.scala
+++ b/src/main/scala/org/graphframes/GraphFrame.scala
@@ -40,9 +40,11 @@ import org.graphframes.pattern._
* @groupname degree Graph topology
* @groupname motif Motif finding
*/
-class GraphFrame private(
+class GraphFrame private (
@transient private val _vertices: DataFrame,
- @transient private val _edges: DataFrame) extends Logging with Serializable {
+ @transient private val _edges: DataFrame)
+ extends Logging
+ with Serializable {
import GraphFrame._
@@ -53,7 +55,8 @@ class GraphFrame private(
// We call select on the vertices and edges to ensure that ID, SRC, DST always come first
// in the printed schema.
val vCols = (ID +: vertices.columns.filter(_ != ID).toIndexedSeq).map(col)
- val eCols = (SRC +: DST +: edges.columns.filter(c => c != SRC && c != DST).toIndexedSeq).map(col)
+ val eCols =
+ (SRC +: DST +: edges.columns.filter(c => c != SRC && c != DST).toIndexedSeq).map(col)
val v = vertices.select(vCols.toSeq: _*).toString
val e = edges.select(eCols.toSeq: _*).toString
"GraphFrame(v:" + v + ", e:" + e + ")"
@@ -80,8 +83,9 @@ class GraphFrame private(
/**
* Persist the dataframe representation of vertices and edges of the graph with the given
* storage level.
- * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
- * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc..
+ * @param newLevel
+ * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`,
+ * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
*/
def persist(newLevel: StorageLevel): this.type = {
vertices.persist(newLevel)
@@ -102,7 +106,8 @@ class GraphFrame private(
/**
* Mark the dataframe representation of vertices and edges of the graph as non-persistent, and
* remove all blocks for it from memory and disk.
- * @param blocking Whether to block until all blocks are deleted.
+ * @param blocking
+ * Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): this.type = {
vertices.unpersist(blocking)
@@ -115,8 +120,8 @@ class GraphFrame private(
/**
* The dataframe representation of the vertices of the graph.
*
- * It contains a column called [[GraphFrame.ID]] with the id of the vertex,
- * and various other user-defined attributes with other attributes.
+ * It contains a column called [[GraphFrame.ID]] with the id of the vertex, and various other
+ * user-defined attributes with other attributes.
*
* The order of the columns is available in [[vertexColumns]].
*
@@ -132,9 +137,9 @@ class GraphFrame private(
/**
* The dataframe representation of the edges of the graph.
*
- * It contains two columns called [[GraphFrame.SRC]] and [[GraphFrame.DST]] that contain
- * the ids of the source vertex and the destination vertex of each edge, respectively.
- * It may also contain various other columns with user-defined attributes for each edge.
+ * It contains two columns called [[GraphFrame.SRC]] and [[GraphFrame.DST]] that contain the ids
+ * of the source vertex and the destination vertex of each edge, respectively. It may also
+ * contain various other columns with user-defined attributes for each edge.
*
* For symmetric graphs, both pairs src -> dst and dst -> src are present with the same
* attributes for each pair.
@@ -154,7 +159,7 @@ class GraphFrame private(
/**
* Returns triplets: (source vertex)-[edge]->(destination vertex) for all edges in the graph.
* The DataFrame returned has 3 columns, with names: [[GraphFrame.SRC]], [[GraphFrame.EDGE]],
- * and [[GraphFrame.DST]]. The 2 vertex columns have schema matching [[GraphFrame.vertices]],
+ * and [[GraphFrame.DST]]. The 2 vertex columns have schema matching [[GraphFrame.vertices]],
* and the edge column has a schema matching [[GraphFrame.edges]].
*
* @group structure
@@ -164,11 +169,11 @@ class GraphFrame private(
// ============================ Conversions ========================================
/**
- * Converts this [[GraphFrame]] instance to a GraphX `Graph`.
- * Vertex and edge attributes are the original rows in [[vertices]] and [[edges]], respectively.
+ * Converts this [[GraphFrame]] instance to a GraphX `Graph`. Vertex and edge attributes are the
+ * original rows in [[vertices]] and [[edges]], respectively.
*
- * Note that vertex (and edge) attributes include vertex IDs (and source, destination IDs)
- * in order to support non-Long vertex IDs. If the vertex IDs are not convertible to Long values,
+ * Note that vertex (and edge) attributes include vertex IDs (and source, destination IDs) in
+ * order to support non-Long vertex IDs. If the vertex IDs are not convertible to Long values,
* then the values are indexed in order to generate corresponding Long vertex IDs (which is an
* expensive operation).
*
@@ -179,16 +184,22 @@ class GraphFrame private(
*/
def toGraphX: Graph[Row, Row] = {
if (hasIntegralIdType) {
- val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR))
- .rdd.map { case Row(id: Long, attr: Row) => (id, attr) }
- val ee = edges.select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR))
- .rdd.map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) }
+ val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR)).rdd.map {
+ case Row(id: Long, attr: Row) => (id, attr)
+ }
+ val ee = edges
+ .select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR))
+ .rdd
+ .map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) }
Graph(vv, ee)
} else {
// Compute Long vertex IDs
- val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map { case Row(long_id: Long, attr: Row) => (long_id, attr) }
- val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map { case Row(long_src: Long, long_dst: Long, attr: Row) =>
- Edge(long_src, long_dst, attr)
+ val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map {
+ case Row(long_id: Long, attr: Row) => (long_id, attr)
+ }
+ val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map {
+ case Row(long_src: Long, long_dst: Long, attr: Row) =>
+ Edge(long_src, long_dst, attr)
}
Graph(vv, ee)
}
@@ -197,9 +208,9 @@ class GraphFrame private(
/**
* The column names in the [[vertices]] DataFrame, in order.
*
- * Helper method for [[toGraphX]] which specifies the schema of vertex attributes.
- * The vertex attributes of the returned `Graph` are given as a `Row`,
- * and this method defines the column ordering in that `Row`.
+ * Helper method for [[toGraphX]] which specifies the schema of vertex attributes. The vertex
+ * attributes of the returned `Graph` are given as a `Row`, and this method defines the column
+ * ordering in that `Row`.
*
* @group conversions
*/
@@ -215,9 +226,9 @@ class GraphFrame private(
/**
* The vertex names in the [[vertices]] DataFrame, in order.
*
- * Helper method for [[toGraphX]] which specifies the schema of edge attributes.
- * The edge attributes of the returned `edges` are given as a `Row`,
- * and this method defines the column ordering in that `Row`.
+ * Helper method for [[toGraphX]] which specifies the schema of edge attributes. The edge
+ * attributes of the returned `edges` are given as a `Row`, and this method defines the column
+ * ordering in that `Row`.
*
* @group conversions
*/
@@ -234,8 +245,8 @@ class GraphFrame private(
/**
* The out-degree of each vertex in the graph, returned as a DataFrame with two columns:
- * - [[GraphFrame.ID]] the ID of the vertex
- * - "outDegree" (integer) storing the out-degree of the vertex
+ * - [[GraphFrame.ID]] the ID of the vertex
+ * - "outDegree" (integer) storing the out-degree of the vertex
* Note that vertices with 0 out-edges are not returned in the result.
*
* @group degree
@@ -246,9 +257,9 @@ class GraphFrame private(
/**
* The in-degree of each vertex in the graph, returned as a DataFame with two columns:
- * - [[GraphFrame.ID]] the ID of the vertex
- * "- "inDegree" (int) storing the in-degree of the vertex
- * Note that vertices with 0 in-edges are not returned in the result.
+ * - [[GraphFrame.ID]] the ID of the vertex
+ * - "inDegree" (int) storing the in-degree of the vertex
+ * Note that vertices with 0 in-edges are not returned in the result.
*
* @group degree
*/
@@ -258,14 +269,17 @@ class GraphFrame private(
/**
* The degree of each vertex in the graph, returned as a DataFrame with two columns:
- * - [[GraphFrame.ID]] the ID of the vertex
- * - 'degree' (integer) the degree of the vertex
+ * - [[GraphFrame.ID]] the ID of the vertex
+ * - 'degree' (integer) the degree of the vertex
* Note that vertices with 0 edges are not returned in the result.
*
* @group degree
*/
@transient lazy val degrees: DataFrame = {
- edges.select(explode(array(SRC, DST)).as(ID)).groupBy(ID).agg(count("*").cast("int").as("degree"))
+ edges
+ .select(explode(array(SRC, DST)).as(ID))
+ .groupBy(ID)
+ .agg(count("*").cast("int").as("degree"))
}
// ============================ Motif finding ========================================
@@ -275,59 +289,59 @@ class GraphFrame private(
*
* Motif finding uses a simple Domain-Specific Language (DSL) for expressing structural queries.
* For example, `graph.find("(a)-[e]->(b); (b)-[e2]->(a)")` will search for pairs of vertices
- * `a,b` connected by edges in both directions. It will return a `DataFrame` of all such
- * structures in the graph, with columns for each of the named elements (vertices or edges)
- * in the motif. In this case, the returned columns will be in order of the pattern:
- * "a, e, b, e2."
+ * `a,b` connected by edges in both directions. It will return a `DataFrame` of all such
+ * structures in the graph, with columns for each of the named elements (vertices or edges) in
+ * the motif. In this case, the returned columns will be in order of the pattern: "a, e, b, e2."
*
* DSL for expressing structural patterns:
- * - The basic unit of a pattern is an edge.
- * For example, `"(a)-[e]->(b)"` expresses an edge `e` from vertex `a` to vertex `b`.
- * Note that vertices are denoted by parentheses `(a)`, while edges are denoted by
- * square brackets `[e]`.
- * - A pattern is expressed as a union of edges. Edge patterns can be joined with semicolons.
- * Motif `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies two edges from `a` to `b` to `c`.
- * - Within a pattern, names can be assigned to vertices and edges. For example,
- * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`.
- * These names serve two purposes:
- * - The names can identify common elements among edges. For example,
+ * - The basic unit of a pattern is an edge. For example, `"(a)-[e]->(b)"` expresses an edge
+ * `e` from vertex `a` to vertex `b`. Note that vertices are denoted by parentheses `(a)`,
+ * while edges are denoted by square brackets `[e]`.
+ * - A pattern is expressed as a union of edges. Edge patterns can be joined with semicolons.
+ * Motif `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies two edges from `a` to `b` to `c`.
+ * - Within a pattern, names can be assigned to vertices and edges. For example,
+ * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`. These names serve
+ * two purposes:
+ * - The names can identify common elements among edges. For example,
* `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination
* of edge `e` and source of edge `e2`.
- * - The names are used as column names in the result `DataFrame`. If a motif contains
- * named vertex `a`, then the result `DataFrame` will contain a column "a" which is a
+ * - The names are used as column names in the result `DataFrame`. If a motif contains named
+ * vertex `a`, then the result `DataFrame` will contain a column "a" which is a
* `StructType` with sub-fields equivalent to the schema (columns) of
- * [[GraphFrame.vertices]]. Similarly, an edge `e` in a motif will produce a column "e"
- * in the result `DataFrame` with sub-fields equivalent to the schema (columns) of
+ * [[GraphFrame.vertices]]. Similarly, an edge `e` in a motif will produce a column "e" in
+ * the result `DataFrame` with sub-fields equivalent to the schema (columns) of
* [[GraphFrame.edges]].
* - Be aware that names do *not* identify *distinct* elements: two elements with different
- * names may refer to the same graph element. For example, in the motif
+ * names may refer to the same graph element. For example, in the motif
* `"(a)-[e]->(b); (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex.
- * To restrict named elements to be distinct vertices or edges, use post-hoc filters
- * such as `resultDataframe.filter("a.id != c.id")`.
- * - It is acceptable to omit names for vertices or edges in motifs when not needed.
- * E.g., `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name
- * to the edge. There will be no column for the anonymous edge in the result `DataFrame`.
- * Similarly, `"(a)-[e]->()"` indicates an out-edge of vertex `a` but does not name
- * the destination vertex. These are called *anonymous* vertices and edges.
- * - An edge can be negated to indicate that the edge should *not* be present in the graph.
- * E.g., `"(a)-[]->(b); !(b)-[]->(a)"` finds edges from `a` to `b` for which there is *no*
- * edge from `b` to `a`.
+ * To restrict named elements to be distinct vertices or edges, use post-hoc filters such
+ * as `resultDataframe.filter("a.id != c.id")`.
+ * - It is acceptable to omit names for vertices or edges in motifs when not needed. E.g.,
+ * `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name to
+ * the edge. There will be no column for the anonymous edge in the result `DataFrame`.
+ * Similarly, `"(a)-[e]->()"` indicates an out-edge of vertex `a` but does not name the
+ * destination vertex. These are called *anonymous* vertices and edges.
+ * - An edge can be negated to indicate that the edge should *not* be present in the graph.
+ * E.g., `"(a)-[]->(b); !(b)-[]->(a)"` finds edges from `a` to `b` for which there is *no*
+ * edge from `b` to `a`.
*
* Restrictions:
- * - Motifs are not allowed to contain edges without any named elements: `"()-[]->()"` and
- * `"!()-[]->()"` are prohibited terms.
- * - Motifs are not allowed to contain named edges within negated terms (since these named
- * edges would never appear within results). E.g., `"!(a)-[ab]->(b)"` is invalid, but
- * `"!(a)-[]->(b)"` is valid.
+ * - Motifs are not allowed to contain edges without any named elements: `"()-[]->()"` and
+ * `"!()-[]->()"` are prohibited terms.
+ * - Motifs are not allowed to contain named edges within negated terms (since these named
+ * edges would never appear within results). E.g., `"!(a)-[ab]->(b)"` is invalid, but
+ * `"!(a)-[]->(b)"` is valid.
*
- * More complex queries, such as queries which operate on vertex or edge attributes,
- * can be expressed by applying filters to the result `DataFrame`.
+ * More complex queries, such as queries which operate on vertex or edge attributes, can be
+ * expressed by applying filters to the result `DataFrame`.
*
- * This can return duplicate rows. E.g., a query `"(u)-[]->()"` will return a result for each
+ * This can return duplicate rows. E.g., a query `"(u)-[]->()"` will return a result for each
* matching edge, even if those edges share the same vertex `u`.
*
- * @param pattern Pattern specifying a motif to search for.
- * @return `DataFrame` containing all instances of the motif.
+ * @param pattern
+ * Pattern specifying a motif to search for.
+ * @return
+ * `DataFrame` containing all instances of the motif.
* @group motif
*/
def find(pattern: String): DataFrame = {
@@ -342,7 +356,7 @@ class GraphFrame private(
val df = findSimple(augmentedPatterns)
val names = Pattern.findNamedElementsInOrder(patterns, includeEdges = true)
- if (names.isEmpty) df else df.select(names.head, names.tail : _*)
+ if (names.isEmpty) df else df.select(names.head, names.tail: _*)
}
// ======================== Other queries ===================================
@@ -357,9 +371,9 @@ class GraphFrame private(
def bfs: BFS = new BFS(this)
/**
- * This is a primitive for implementing graph algorithms.
- * This method aggregates values from the neighboring edges and vertices of each vertex.
- * See [[org.graphframes.lib.AggregateMessages AggregateMessages]] for detailed documentation.
+ * This is a primitive for implementing graph algorithms. This method aggregates values from the
+ * neighboring edges and vertices of each vertex. See
+ * [[org.graphframes.lib.AggregateMessages AggregateMessages]] for detailed documentation.
*/
def aggregateMessages: AggregateMessages = new AggregateMessages(this)
@@ -370,8 +384,9 @@ class GraphFrame private(
*/
def filterVertices(condition: Column): GraphFrame = {
val vv = vertices.filter(condition)
- val ee = edges.join(vv, vv(ID) === edges(SRC), "left_semi")
- .join(vv, vv(ID) === edges(DST), "left_semi")
+ val ee = edges
+ .join(vv, vv(ID) === edges(SRC), "left_semi")
+ .join(vv, vv(ID) === edges(DST), "left_semi")
GraphFrame(vv, ee)
}
@@ -382,7 +397,7 @@ class GraphFrame private(
*/
def filterVertices(conditionExpr: String): GraphFrame = filterVertices(expr(conditionExpr))
- /**
+ /**
* Filter the edges according to Column expression, keep all vertices.
* @group subgraph
*/
@@ -392,7 +407,7 @@ class GraphFrame private(
GraphFrame(vv, ee)
}
- /**
+ /**
* Filter the edges according to String expression.
* @group subgraph
*/
@@ -439,19 +454,20 @@ class GraphFrame private(
def pageRank: PageRank = new PageRank(this)
/**
- * Parallel personalized PageRank algorithm.
- *
- * See [[org.graphframes.lib.ParallelPersonalizedPageRank]] for more details.
- *
- * @group stdlib
- */
+ * Parallel personalized PageRank algorithm.
+ *
+ * See [[org.graphframes.lib.ParallelPersonalizedPageRank]] for more details.
+ *
+ * @group stdlib
+ */
def parallelPersonalizedPageRank: ParallelPersonalizedPageRank =
new ParallelPersonalizedPageRank(this)
/**
* Pregel algorithm.
*
- * @see [[org.graphframes.lib.Pregel]]
+ * @see
+ * [[org.graphframes.lib.Pregel]]
* @group stdlib
*/
def pregel = new Pregel(this)
@@ -496,11 +512,12 @@ class GraphFrame private(
// ========= Motif finding (private) =========
/**
- * Primary method implementing motif finding.
- * This iterative method handles one pattern (via [[findIncremental()]] on each iteration,
- * augmenting the `DataFrame` in prevDF with each new pattern.
+ * Primary method implementing motif finding. This iterative method handles one pattern (via
+ * [[findIncremental()]]) on each iteration, augmenting the `DataFrame` in prevDF with each new
+ * pattern.
*
- * @return `DataFrame` containing all instances of the motif specified by the given patterns
+ * @return
+ * `DataFrame` containing all instances of the motif specified by the given patterns
*/
private def findSimple(patterns: Seq[Pattern]): DataFrame = {
val (_, finalDFOpt, _) =
@@ -519,38 +536,42 @@ class GraphFrame private(
/**
* True if the id type can be cast to Long.
*
- * This is important for performance reasons. The underlying graphx
- * implementation only deals with Long types.
+ * This is important for performance reasons. The underlying graphx implementation only deals
+ * with Long types.
*/
private[graphframes] lazy val hasIntegralIdType: Boolean = {
vertices.schema(ID).dataType match {
- case _ @ (ByteType | IntegerType | LongType | ShortType) => true
+ case _ @(ByteType | IntegerType | LongType | ShortType) => true
case _ => false
}
}
/**
- * Vertices with each vertex assigned a unique long ID.
- * If the vertex ID type is integral, this casts the original IDs to long.
+ * Vertices with each vertex assigned a unique long ID. If the vertex ID type is integral, this
+ * casts the original IDs to long.
*
* Columns:
- * - $LONG_ID: the new ID of LongType
- * - $ORIGINAL_ID: the ID provided by the user
- * - $ATTR: all the original vertex attributes
+ * - $LONG_ID: the new ID of LongType
+ * - $ORIGINAL_ID: the ID provided by the user
+ * - $ATTR: all the original vertex attributes
*/
private[graphframes] lazy val indexedVertices: DataFrame = {
if (hasIntegralIdType) {
val indexedVertices = vertices.select(nestAsCol(vertices, ATTR))
indexedVertices.select(
- col(ATTR + "." + ID).cast("long").as(LONG_ID), col(ATTR + "." + ID).as(ID), col(ATTR))
+ col(ATTR + "." + ID).cast("long").as(LONG_ID),
+ col(ATTR + "." + ID).as(ID),
+ col(ATTR))
} else {
- val withLongIds = vertices.select(ID)
+ val withLongIds = vertices
+ .select(ID)
.repartition(col(ID))
.distinct()
.sortWithinPartitions(ID)
.withColumn(LONG_ID, monotonically_increasing_id())
.persist(StorageLevel.MEMORY_AND_DISK)
- vertices.select(col(ID), nestAsCol(vertices, ATTR))
+ vertices
+ .select(col(ID), nestAsCol(vertices, ATTR))
.join(withLongIds, ID)
.select(LONG_ID, ID, ATTR)
}
@@ -558,31 +579,41 @@ class GraphFrame private(
/**
* Columns:
- * - $SRC
- * - $LONG_SRC
- * - $DST
- * - $LONG_DST
- * - $ATTR
+ * - $SRC
+ * - $LONG_SRC
+ * - $DST
+ * - $LONG_DST
+ * - $ATTR
*/
private[graphframes] lazy val indexedEdges: DataFrame = {
val packedEdges = edges.select(col(SRC), col(DST), nestAsCol(edges, ATTR))
if (hasIntegralIdType) {
packedEdges.select(
- col(SRC), col(SRC).cast("long").as(LONG_SRC),
- col(DST), col(DST).cast("long").as(LONG_DST),
+ col(SRC),
+ col(SRC).cast("long").as(LONG_SRC),
+ col(DST),
+ col(DST).cast("long").as(LONG_DST),
col(ATTR))
} else {
val threshold = broadcastThreshold
- val hubs: Set[Any] = degrees.filter(col("degree") >= threshold).select(ID)
- .collect().map(_.get(0)).toSet
+ val hubs: Set[Any] = degrees
+ .filter(col("degree") >= threshold)
+ .select(ID)
+ .collect()
+ .map(_.get(0))
+ .toSet
val indexedSourceEdges = GraphFrame.skewedJoin(
packedEdges,
indexedVertices.select(col(ID).as(SRC), col(LONG_ID).as(LONG_SRC)),
- SRC, hubs, "GraphFrame.indexedEdges:")
+ SRC,
+ hubs,
+ "GraphFrame.indexedEdges:")
val indexedEdges = GraphFrame.skewedJoin(
indexedSourceEdges,
indexedVertices.select(col(ID).as(DST), col(LONG_ID).as(LONG_DST)),
- DST, hubs, "GraphFrame.indexedEdges:")
+ DST,
+ hubs,
+ "GraphFrame.indexedEdges:")
indexedEdges.select(SRC, LONG_SRC, DST, LONG_DST, ATTR)
}
}
@@ -595,26 +626,33 @@ class GraphFrame private(
}
/**
- * A cached conversion of this graph to the GraphX structure, with the data stored for each edge and vertex.
+ * A cached conversion of this graph to the GraphX structure, with the data stored for each edge
+ * and vertex.
*/
@transient private lazy val cachedGraphX: Graph[Row, Row] = { toGraphX }
}
-
object GraphFrame extends Serializable with Logging {
/**
* Implements `a.join(b, joinCol)`, handling skew in the join keys.
- * @param a DataFrame which may have multiple rows with the same key in `joinCol`
- * @param b DataFrame which has exactly 1 row for every key in `a.joinCol`.
- * @param joinCol Name of column on which to do join
- * @param hubs Set of join keys which are high-degree (skewed)
- * @param logPrefix Prefix for logging, e.g., name of algorithm doing the join
- * @return `a.join(b, joinCol)`
- * @tparam T DataType for join key
- */
- private[graphframes] def skewedJoin[T : TypeTag](
+ * @param a
+ * DataFrame which may have multiple rows with the same key in `joinCol`
+ * @param b
+ * DataFrame which has exactly 1 row for every key in `a.joinCol`.
+ * @param joinCol
+ * Name of column on which to do join
+ * @param hubs
+ * Set of join keys which are high-degree (skewed)
+ * @param logPrefix
+ * Prefix for logging, e.g., name of algorithm doing the join
+ * @return
+ * `a.join(b, joinCol)`
+ * @tparam T
+ * DataType for join key
+ */
+ private[graphframes] def skewedJoin[T: TypeTag](
a: DataFrame,
b: DataFrame,
joinCol: String,
@@ -630,43 +668,44 @@ object GraphFrame extends Serializable with Logging {
val isHub = udf { id: T =>
hubs.contains(id)
}
- val hashJoined = a.filter(!isHub(col(joinCol)))
+ val hashJoined = a
+ .filter(!isHub(col(joinCol)))
.join(b.filter(!isHub(col(joinCol))), joinCol)
- val broadcastJoined = a.filter(isHub(col(joinCol)))
+ val broadcastJoined = a
+ .filter(isHub(col(joinCol)))
.join(broadcast(b.filter(isHub(col(joinCol)))), joinCol)
hashJoined.unionAll(broadcastJoined)
}
}
/**
- * Column name for vertex IDs in [[GraphFrame.vertices]]
- * Note that GraphFrame assigns a unique long ID to each vertex,
- * If the vertex ID type is one of byte / int / long / short type,
- * GraphFrame casts the original IDs to long as the unique long ID,
- * otherwise GraphFrame generates the unique long ID by Spark function
- * ``monotonically_increasing_id`` which is less performant.
+ * Column name for vertex IDs in [[GraphFrame.vertices]]. Note that GraphFrame assigns a unique
+ * long ID to each vertex. If the vertex ID type is one of byte / int / long / short type,
+ * GraphFrame casts the original IDs to long as the unique long ID, otherwise GraphFrame
+ * generates the unique long ID by Spark function ``monotonically_increasing_id`` which is less
+ * performant.
*/
val ID: String = "id"
/**
* Column name for source vertices of edges.
- * - In [[GraphFrame.edges]], this is a column of vertex IDs.
- * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching
- * [[GraphFrame.vertices]].
+ * - In [[GraphFrame.edges]], this is a column of vertex IDs.
+ * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching
+ * [[GraphFrame.vertices]].
*/
val SRC: String = "src"
/**
* Column name for destination vertices of edges.
- * - In [[GraphFrame.edges]], this is a column of vertex IDs.
- * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching
- * [[GraphFrame.vertices]].
+ * - In [[GraphFrame.edges]], this is a column of vertex IDs.
+ * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching
+ * [[GraphFrame.vertices]].
*/
val DST: String = "dst"
/**
- * Column name for edge in [[GraphFrame.triplets]]. In [[GraphFrame.triplets]],
- * this is a column of edges with schema matching [[GraphFrame.edges]].
+ * Column name for edge in [[GraphFrame.triplets]]. In [[GraphFrame.triplets]], this is a column
+ * of edges with schema matching [[GraphFrame.edges]].
*/
val EDGE: String = "edge"
@@ -675,20 +714,26 @@ object GraphFrame extends Serializable with Logging {
/**
* Create a new [[GraphFrame]] from vertex and edge `DataFrame`s.
*
- * @param vertices Vertex DataFrame. This must include a column "id" containing unique vertex IDs.
- * All other columns are treated as vertex attributes.
- * @param edges Edge DataFrame. This must include columns "src" and "dst" containing source and
- * destination vertex IDs. All other columns are treated as edge attributes.
- * @return New [[GraphFrame]] instance
+ * @param vertices
+ * Vertex DataFrame. This must include a column "id" containing unique vertex IDs. All other
+ * columns are treated as vertex attributes.
+ * @param edges
+ * Edge DataFrame. This must include columns "src" and "dst" containing source and destination
+ * vertex IDs. All other columns are treated as edge attributes.
+ * @return
+ * New [[GraphFrame]] instance
*/
def apply(vertices: DataFrame, edges: DataFrame): GraphFrame = {
- require(vertices.columns.contains(ID),
+ require(
+ vertices.columns.contains(ID),
s"Vertex ID column '$ID' missing from vertex DataFrame, which has columns: "
+ vertices.columns.mkString(","))
- require(edges.columns.contains(SRC),
+ require(
+ edges.columns.contains(SRC),
s"Source vertex ID column '$SRC' missing from edge DataFrame, which has columns: "
+ edges.columns.mkString(","))
- require(edges.columns.contains(DST),
+ require(
+ edges.columns.contains(DST),
s"Destination vertex ID column '$DST' missing from edge DataFrame, which has columns: "
+ edges.columns.mkString(","))
@@ -696,14 +741,16 @@ object GraphFrame extends Serializable with Logging {
}
/**
- * Create a new [[GraphFrame]] from an edge `DataFrame`.
- * The resulting [[GraphFrame]] will have [[GraphFrame.vertices]] with a single "id" column.
+ * Create a new [[GraphFrame]] from an edge `DataFrame`. The resulting [[GraphFrame]] will have
+ * [[GraphFrame.vertices]] with a single "id" column.
*
* Note: The [[GraphFrame.vertices]] DataFrame will be persisted at level
- * `StorageLevel.MEMORY_AND_DISK`.
- * @param e Edge DataFrame. This must include columns "src" and "dst" containing source and
- * destination vertex IDs. All other columns are treated as edge attributes.
- * @return New [[GraphFrame]] instance
+ * `StorageLevel.MEMORY_AND_DISK`.
+ * @param e
+ * Edge DataFrame. This must include columns "src" and "dst" containing source and destination
+ * vertex IDs. All other columns are treated as edge attributes.
+ * @return
+ * New [[GraphFrame]] instance
*
* @group conversions
*/
@@ -718,64 +765,66 @@ object GraphFrame extends Serializable with Logging {
/**
* Converts a GraphX `Graph` instance into a [[GraphFrame]].
*
- * This converts each `org.apache.spark.rdd.RDD` in the `Graph` to a `DataFrame` using
- * schema inference.
+ * This converts each `org.apache.spark.rdd.RDD` in the `Graph` to a `DataFrame` using schema
+ * inference.
*
- * Vertex ID column names will be converted to "id" for the vertex DataFrame,
- * and to "src" and "dst" for the edge DataFrame.
+ * Vertex ID column names will be converted to "id" for the vertex DataFrame, and to "src" and
+ * "dst" for the edge DataFrame.
*
* @group conversions
*/
// TODO: Add version which takes explicit schemas.
- def fromGraphX[VD : TypeTag, ED : TypeTag](graph: Graph[VD, ED]): GraphFrame = {
+ def fromGraphX[VD: TypeTag, ED: TypeTag](graph: Graph[VD, ED]): GraphFrame = {
val spark = SparkSession.builder().getOrCreate()
val vv = spark.createDataFrame(graph.vertices).toDF(ID, ATTR)
val ee = spark.createDataFrame(graph.edges).toDF(SRC, DST, ATTR)
GraphFrame(vv, ee)
}
-
/**
* Given:
- * - a GraphFrame `originalGraph`
- * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]]
+ * - a GraphFrame `originalGraph`
+ * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]]
* this method merges attributes from the GraphX graph into the original GraphFrame.
*
- * This method is useful for doing computations using the GraphX API and then merging the results
- * with a GraphFrame. For example, given:
- * - GraphFrame `originalGraph`
- * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call "category"
- * and an Int edge attribute we want to call "count"
- * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce
- * a new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`,
- * with new [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column
- * "count" added.
+ * This method is useful for doing computations using the GraphX API and then merging the
+ * results with a GraphFrame. For example, given:
+ * - GraphFrame `originalGraph`
+ * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call
+ * "category" and an Int edge attribute we want to call "count"
+ * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a
+ * new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new
+ * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" added.
*
* See [[org.graphframes.examples.BeliefPropagation]] for example usage.
*
- * @param originalGraph Original GraphFrame used to compute the GraphX graph.
- * @param graph GraphX graph. Vertex and edge attributes, if any, will be merged into
- * the original graph as new columns. If the attributes are `Product` types
- * such as tuples, then each element of the `Product` will be put in a separate
- * column. If the attributes are other types, then the entire GraphX attribute
- * will become a single new column.
- * @param vertexNames Column name(s) for vertex attributes in the GraphX graph.
- * If there is no vertex attribute, this should be empty.
- * If there is a singleton attribute, this should have a single column name.
- * If the attribute is a `Product` type, this should be a list of names
- * matching the order of the attribute elements.
- * @param edgeNames Column name(s) for edge attributes in the GraphX graph.
- * If there is no edge attribute, this should be empty.
- * If there is a singleton attribute, this should have a single column name.
- * If the attribute is a `Product` type, this should be a list of names
- * matching the order of the attribute elements.
- * @tparam V the type of the vertex data
- * @tparam E the type of the edge data
- * @return original graph augmented with vertex and column attributes from the GraphX graph
+ * @param originalGraph
+ * Original GraphFrame used to compute the GraphX graph.
+ * @param graph
+ * GraphX graph. Vertex and edge attributes, if any, will be merged into the original graph as
+ * new columns. If the attributes are `Product` types such as tuples, then each element of the
+ * `Product` will be put in a separate column. If the attributes are other types, then the
+ * entire GraphX attribute will become a single new column.
+ * @param vertexNames
+ * Column name(s) for vertex attributes in the GraphX graph. If there is no vertex attribute,
+ * this should be empty. If there is a singleton attribute, this should have a single column
+ * name. If the attribute is a `Product` type, this should be a list of names matching the
+ * order of the attribute elements.
+ * @param edgeNames
+ * Column name(s) for edge attributes in the GraphX graph. If there is no edge attribute, this
+ * should be empty. If there is a singleton attribute, this should have a single column name.
+ * If the attribute is a `Product` type, this should be a list of names matching the order of
+ * the attribute elements.
+ * @tparam V
+ * the type of the vertex data
+ * @tparam E
+ * the type of the edge data
+ * @return
+ * original graph augmented with vertex and edge attributes from the GraphX graph
*
* @group conversions
*/
- def fromGraphX[V : TypeTag, E : TypeTag](
+ def fromGraphX[V: TypeTag, E: TypeTag](
originalGraph: GraphFrame,
graph: Graph[V, E],
vertexNames: Seq[String] = Nil,
@@ -783,7 +832,6 @@ object GraphFrame extends Serializable with Logging {
GraphXConversions.fromGraphX[V, E](originalGraph, graph, vertexNames, edgeNames)
}
-
// ============== Private constants ==============
/** Default name for attribute columns when converting from GraphX [[Graph]] format */
@@ -798,16 +846,15 @@ object GraphFrame extends Serializable with Logging {
private[graphframes] val LONG_DST: String = "new_dst"
private[graphframes] val GX_ATTR: String = "graphx_attr"
-
-
/** Helper for using [col].* in Spark 1.4. Returns sequence of [col].[field] for all fields */
private[graphframes] def colStar(df: DataFrame, col: String): Seq[String] = {
df.schema(col).dataType match {
case s: StructType =>
s.fieldNames.map(f => col + "." + f).toIndexedSeq
case other =>
- throw new RuntimeException(s"Unknown error in GraphFrame. Expected column $col to be" +
- s" StructType, but found type: $other")
+ throw new RuntimeException(
+ s"Unknown error in GraphFrame. Expected column $col to be" +
+ s" StructType, but found type: $other")
}
}
@@ -825,7 +872,6 @@ object GraphFrame extends Serializable with Logging {
private def eSrcId(name: String): String = prefixWithName(name, SRC)
private def eDstId(name: String): String = prefixWithName(name, DST)
-
private def maybeCrossJoin(aOpt: Option[DataFrame], b: DataFrame): DataFrame = {
aOpt match {
case Some(a) => a.crossJoin(b)
@@ -843,7 +889,6 @@ object GraphFrame extends Serializable with Logging {
}
}
-
/** Indicate whether a named vertex has been seen in any of the given patterns */
private def seen(v: NamedVertex, patterns: Seq[Pattern]) = patterns.exists(p => seen1(v, p))
@@ -861,14 +906,17 @@ object GraphFrame extends Serializable with Logging {
false
}
-
/**
* Augment the given DataFrame based on a pattern.
*
- * @param prevPatterns Patterns which have contributed to the given DataFrame
- * @param prev Given DataFrame
- * @param pattern Pattern to search for
- * @return DataFrame augmented with the current search pattern
+ * @param prevPatterns
+ * Patterns which have contributed to the given DataFrame
+ * @param prev
+ * Given DataFrame
+ * @param pattern
+ * Pattern to search for
+ * @return
+ * DataFrame augmented with the current search pattern
*/
private def findIncremental(
gf: GraphFrame,
@@ -899,93 +947,114 @@ object GraphFrame extends Serializable with Logging {
case NamedEdge(name, AnonymousVertex, dst @ NamedVertex(dstName)) =>
if (seen(dst, prevPatterns)) {
val eRen = nestE(name)
- (Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))),
+ (
+ Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))),
prevNames :+ name)
} else {
val eRen = nestE(name)
val dstV = nestV(dstName)
- (Some(maybeCrossJoin(prev, eRen)
- .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
+ (
+ Some(
+ maybeCrossJoin(prev, eRen)
+ .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
prevNames :+ name :+ dstName)
}
case NamedEdge(name, src @ NamedVertex(srcName), AnonymousVertex) =>
if (seen(src, prevPatterns)) {
val eRen = nestE(name)
- (Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))),
+ (
+ Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))),
prevNames :+ name)
} else {
val eRen = nestE(name)
val srcV = nestV(srcName)
- (Some(maybeCrossJoin(prev, eRen)
- .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))),
- prevNames :+ srcName :+ name)
+ (
+ Some(
+ maybeCrossJoin(prev, eRen)
+ .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))),
+ prevNames :+ srcName :+ name)
}
case NamedEdge(name, src @ NamedVertex(srcName), dst @ NamedVertex(dstName)) =>
(seen(src, prevPatterns), seen(dst, prevPatterns)) match {
case (true, true) =>
val eRen = nestE(name)
- (Some(maybeJoin(prev, eRen, prev =>
- eRen(eSrcId(name)) === prev(vId(srcName)) && eRen(eDstId(name)) === prev(vId(dstName)))),
+ (
+ Some(
+ maybeJoin(
+ prev,
+ eRen,
+ prev =>
+ eRen(eSrcId(name)) === prev(vId(srcName)) && eRen(eDstId(name)) === prev(
+ vId(dstName)))),
prevNames :+ name)
case (true, false) =>
val eRen = nestE(name)
val dstV = nestV(dstName)
- (Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))
- .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
+ (
+ Some(
+ maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))
+ .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
prevNames :+ name :+ dstName)
case (false, true) =>
val eRen = nestE(name)
val srcV = nestV(srcName)
- (Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))
- .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))),
+ (
+ Some(
+ maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))
+ .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))),
prevNames :+ srcName :+ name)
case (false, false) if srcName != dstName =>
val eRen = nestE(name)
val srcV = nestV(srcName)
val dstV = nestV(dstName)
- (Some(maybeCrossJoin(prev, eRen)
- .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))
- .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
+ (
+ Some(
+ maybeCrossJoin(prev, eRen)
+ .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))
+ .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))),
prevNames :+ srcName :+ name :+ dstName)
// TODO: expose the plans from joining these in the opposite order
case (false, false) if srcName == dstName =>
val eRen = nestE(name)
val srcV = nestV(srcName)
- (Some(maybeCrossJoin(prev, eRen)
- .join(srcV,
- eRen(eSrcId(name)) === srcV(vId(srcName)) &&
- eRen(eDstId(name)) === srcV(vId(srcName)))),
+ (
+ Some(
+ maybeCrossJoin(prev, eRen)
+ .join(
+ srcV,
+ eRen(eSrcId(name)) === srcV(vId(srcName)) &&
+ eRen(eDstId(name)) === srcV(vId(srcName)))),
prevNames :+ srcName :+ name)
}
case AnonymousEdge(src, dst) =>
val tmpName = "__tmp" + random.nextLong.toString
- val (df, names) = findIncremental(gf, prevPatterns, prev, prevNames, NamedEdge(tmpName, src, dst))
+ val (df, names) =
+ findIncremental(gf, prevPatterns, prev, prevNames, NamedEdge(tmpName, src, dst))
(df.map(_.drop(tmpName)), names.filter(_ != tmpName))
- case Negation(edge) => prev match {
- case Some(p) =>
- val (df, names) = findIncremental(gf, prevPatterns, Some(p), prevNames, edge)
- (df.map(result => p.except(result)), names)
- case None =>
- throw new InvalidPatternException
- }
+ case Negation(edge) =>
+ prev match {
+ case Some(p) =>
+ val (df, names) = findIncremental(gf, prevPatterns, Some(p), prevNames, edge)
+ (df.map(result => p.except(result)), names)
+ case None =>
+ throw new InvalidPatternException
+ }
}
}
/**
- * Controls broadcast threshold in skewed joins.
- * Use normal joins for vertices with degrees less than the threshold,
- * and broadcast joins otherwise.
- * The default value is 1000000.
- * If we have less than 100 billion edges, this would collect at most
- * 2e11 / 1000000 = 200000 hubs, which could be handled by the driver.
+ * Controls broadcast threshold in skewed joins. Use normal joins for vertices with degrees less
+ * than the threshold, and broadcast joins otherwise. The default value is 1000000. If we have
+ * fewer than 100 billion edges, this would collect at most 2e11 / 1000000 = 200000 hubs, which
+ * could be handled by the driver.
*/
private[this] var _broadcastThreshold: Int = 1000000
diff --git a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala
index 9074bd13c..df97ba85a 100644
--- a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala
+++ b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala
@@ -26,45 +26,43 @@ import org.graphframes.GraphFrame
import org.graphframes.examples.Graphs.gridIsingModel
import org.graphframes.lib.AggregateMessages
-
/**
* Example code for Belief Propagation (BP)
*
- * This provides a template for building customized BP algorithms for different types of
- * graphical models.
+ * This provides a template for building customized BP algorithms for different types of graphical
+ * models.
*
* This example:
- * - Ising model on a grid
- * - Parallel Belief Propagation using colored fields
+ * - Ising model on a grid
+ * - Parallel Belief Propagation using colored fields
*
- * Ising models are probabilistic graphical models over binary variables x,,i,,.
- * Each binary variable x,,i,, corresponds to one vertex, and it may take values -1 or +1.
- * The probability distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,,
- * and edge factors b,,ij,,:
+ * Ising models are probabilistic graphical models over binary variables x,,i,,. Each binary
+ * variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. The probability
+ * distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, and edge factors
+ * b,,ij,,:
* {{{
* P(X) = (1/Z) * exp[ \sum_i a_i x_i + \sum_{ij} b_{ij} x_i x_j ]
* }}}
- * where Z is the normalization constant (partition function).
- * See [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models.
+ * where Z is the normalization constant (partition function). See
+ * [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models.
*
* Belief Propagation (BP) provides marginal probabilities of the values of the variables x,,i,,,
- * i.e., P(x,,i,,) for each i. This allows a user to understand likely values of variables.
- * See [[https://en.wikipedia.org/wiki/Belief_propagation Wikipedia]] for more information on BP.
+ * i.e., P(x,,i,,) for each i. This allows a user to understand likely values of variables. See
+ * [[https://en.wikipedia.org/wiki/Belief_propagation Wikipedia]] for more information on BP.
*
* We use a batch synchronous BP algorithm, where batches of vertices are updated synchronously.
* We follow the mean field update algorithm in Slide 13 of the
- * [[http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial talk slides]] from:
- * Wainwright. "Graphical models, message-passing algorithms, and convex optimization."
+ * [[http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial talk slides]] from:
+ * Wainwright. "Graphical models, message-passing algorithms, and convex optimization."
*
- * The batches are chosen according to a coloring. For background on graph colorings for inference,
- * see for example:
- * Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to Thin Junction Trees."
- * AISTATS, 2011.
+ * The batches are chosen according to a coloring. For background on graph colorings for
+ * inference, see for example: Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to
+ * Thin Junction Trees." AISTATS, 2011.
*
* The BP algorithm works by:
- * - Coloring the graph by assigning a color to each vertex such that no neighboring vertices
- * share the same color.
- * - In each step of BP, update all vertices of a single color. Alternate colors.
+ * - Coloring the graph by assigning a color to each vertex such that no neighboring vertices
+ * share the same color.
+ * - In each step of BP, update all vertices of a single color. Alternate colors.
*/
object BeliefPropagation {
@@ -94,14 +92,16 @@ object BeliefPropagation {
}
/**
- * Given a GraphFrame, choose colors for each vertex. No neighboring vertices will share the
- * same color. The number of colors is minimized.
+ * Given a GraphFrame, choose colors for each vertex. No neighboring vertices will share the
+ * same color. The number of colors is minimized.
*
* This is written specifically for grid graphs. For non-grid graphs, it should be generalized,
* such as by using a greedy coloring scheme.
*
- * @param g Grid graph generated by [[org.graphframes.examples.Graphs.gridIsingModel()]]
- * @return Same graph, but with a new vertex column "color" of type Int (0 or 1)
+ * @param g
+ * Grid graph generated by [[org.graphframes.examples.Graphs.gridIsingModel()]]
+ * @return
+ * Same graph, but with a new vertex column "color" of type Int (0 or 1)
*/
private def colorGraph(g: GraphFrame): GraphFrame = {
val colorUDF = udf { (i: Int, j: Int) => (i + j) % 2 }
@@ -112,19 +112,22 @@ object BeliefPropagation {
/**
* Run Belief Propagation.
*
- * This implementation of BP shows how to use GraphX's aggregateMessages method.
- * It is simple to convert to and from GraphX format. This method does the following:
- * - Color GraphFrame vertices for BP scheduling.
- * - Convert GraphFrame to GraphX format.
- * - Run BP using GraphX's aggregateMessages API.
- * - Augment the original GraphFrame with the BP results (vertex beliefs).
+ * This implementation of BP shows how to use GraphX's aggregateMessages method. It is simple to
+ * convert to and from GraphX format. This method does the following:
+ * - Color GraphFrame vertices for BP scheduling.
+ * - Convert GraphFrame to GraphX format.
+ * - Run BP using GraphX's aggregateMessages API.
+ * - Augment the original GraphFrame with the BP results (vertex beliefs).
*
- * @param g Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
- * @param numIter Number of iterations of BP to run. One iteration includes updating each
- * vertex's belief once.
- * @return Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column
- * "belief" containing P(x,,i,, = +1), the marginal probability of vertex i taking
- * value +1 instead of -1.
+ * @param g
+ * Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
+ * @param numIter
+ * Number of iterations of BP to run. One iteration includes updating each vertex's belief
+ * once.
+ * @return
+ * Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column "belief"
+ * containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of
+ * -1.
*/
def runBPwithGraphX(g: GraphFrame, numIter: Int): GraphFrame = {
// Choose colors for vertices for BP scheduling.
@@ -166,15 +169,14 @@ object BeliefPropagation {
},
_ + _)
// Receive messages, and update beliefs for vertices of the current color.
- gx = gx.outerJoinVertices(msgs) {
- case (vID, vAttr, optMsg) =>
- if (vAttr.color == color) {
- val x = vAttr.a + optMsg.getOrElse(0.0)
- val newBelief = math.exp(-log1pExp(-x))
- VertexAttr(vAttr.a, newBelief, color)
- } else {
- vAttr
- }
+ gx = gx.outerJoinVertices(msgs) { case (vID, vAttr, optMsg) =>
+ if (vAttr.color == color) {
+ val x = vAttr.a + optMsg.getOrElse(0.0)
+ val newBelief = math.exp(-log1pExp(-x))
+ VertexAttr(vAttr.a, newBelief, color)
+ } else {
+ vAttr
+ }
}
}
}
@@ -192,16 +194,19 @@ object BeliefPropagation {
* Run Belief Propagation.
*
* This implementation of BP shows how to use GraphFrame's aggregateMessages method.
- * - Color GraphFrame vertices for BP scheduling.
- * - Run BP using GraphFrame's aggregateMessages API.
- * - Augment the original GraphFrame with the BP results (vertex beliefs).
+ * - Color GraphFrame vertices for BP scheduling.
+ * - Run BP using GraphFrame's aggregateMessages API.
+ * - Augment the original GraphFrame with the BP results (vertex beliefs).
*
- * @param g Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
- * @param numIter Number of iterations of BP to run. One iteration includes updating each
- * vertex's belief once.
- * @return Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column
- * "belief" containing P(x,,i,, = +1), the marginal probability of vertex i taking
- * value +1 instead of -1.
+ * @param g
+ * Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()`
+ * @param numIter
+ * Number of iterations of BP to run. One iteration includes updating each vertex's belief
+ * once.
+ * @return
+ * Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column "belief"
+ * containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of
+ * -1.
*/
def runBPwithGraphFrames(g: GraphFrame, numIter: Int): GraphFrame = {
// Choose colors for vertices for BP scheduling.
@@ -231,15 +236,16 @@ object BeliefPropagation {
.agg(sum(AM.msg).as("aggMess"))
val v = gx.vertices
// Receive messages, and update beliefs for vertices of the current color.
- val newBeliefCol = when(v("color") === color && aggregates("aggMess").isNotNull,
+ val newBeliefCol = when(
+ v("color") === color && aggregates("aggMess").isNotNull,
logistic(aggregates("aggMess") + v("a")))
- .otherwise(v("belief")) // keep old beliefs for other colors
+ .otherwise(v("belief")) // keep old beliefs for other colors
val newVertices = v
- .join(aggregates, v("id") === aggregates("id"), "left_outer") // join messages, vertices
- .drop(aggregates("id")) // drop duplicate ID column (from outer join)
- .withColumn("newBelief", newBeliefCol) // compute new beliefs
- .drop("aggMess") // drop messages
- .drop("belief") // drop old beliefs
+ .join(aggregates, v("id") === aggregates("id"), "left_outer") // join messages, vertices
+ .drop(aggregates("id")) // drop duplicate ID column (from outer join)
+ .withColumn("newBelief", newBeliefCol) // compute new beliefs
+ .drop("aggMess") // drop messages
+ .drop("belief") // drop old beliefs
.withColumnRenamed("newBelief", "belief")
// Cache new vertices using workaround for SPARK-13346
val cachedNewVertices = AM.getCachedDataFrame(newVertices)
diff --git a/src/main/scala/org/graphframes/examples/Graphs.scala b/src/main/scala/org/graphframes/examples/Graphs.scala
index 2ea19e525..5066fc2ce 100644
--- a/src/main/scala/org/graphframes/examples/Graphs.scala
+++ b/src/main/scala/org/graphframes/examples/Graphs.scala
@@ -43,13 +43,15 @@ class Graphs private[graphframes] () {
}
/**
- * Returns a chain graph of the given size with Long ID type.
- * The vertex IDs are 0, 1, ..., n-1, and the edges are (0, 1), (1, 2), ...., (n-2, n-1).
+ * Returns a chain graph of the given size with Long ID type. The vertex IDs are 0, 1, ..., n-1,
+ * and the edges are (0, 1), (1, 2), ..., (n-2, n-1).
*/
def chain(n: Long): GraphFrame = {
require(n >= 0, s"Chain graph size must be nonnegative but got $n.")
val vertices = spark.range(n).toDF(ID)
- val edges = spark.range(n - 1L).toDF(ID)
+ val edges = spark
+ .range(n - 1L)
+ .toDF(ID)
.select(col(ID).as(SRC), (col(ID) + 1L).as(DST))
GraphFrame(vertices, edges)
}
@@ -60,33 +62,38 @@ class Graphs private[graphframes] () {
def friends: GraphFrame = {
// For the same reason as above, this cannot be a value.
// Vertex DataFrame
- val v = spark.createDataFrame(List(
- ("a", "Alice", 34),
- ("b", "Bob", 36),
- ("c", "Charlie", 30),
- ("d", "David", 29),
- ("e", "Esther", 32),
- ("f", "Fanny", 36),
- ("g", "Gabby", 60)
- )).toDF("id", "name", "age")
+ val v = spark
+ .createDataFrame(
+ List(
+ ("a", "Alice", 34),
+ ("b", "Bob", 36),
+ ("c", "Charlie", 30),
+ ("d", "David", 29),
+ ("e", "Esther", 32),
+ ("f", "Fanny", 36),
+ ("g", "Gabby", 60)))
+ .toDF("id", "name", "age")
// Edge DataFrame
- val e = spark.createDataFrame(List(
- ("a", "b", "friend"),
- ("b", "c", "follow"),
- ("c", "b", "follow"),
- ("f", "c", "follow"),
- ("e", "f", "follow"),
- ("e", "d", "friend"),
- ("d", "a", "friend"),
- ("a", "e", "friend")
- )).toDF("src", "dst", "relationship")
+ val e = spark
+ .createDataFrame(
+ List(
+ ("a", "b", "friend"),
+ ("b", "c", "follow"),
+ ("c", "b", "follow"),
+ ("f", "c", "follow"),
+ ("e", "f", "follow"),
+ ("e", "d", "friend"),
+ ("d", "a", "friend"),
+ ("a", "e", "friend")))
+ .toDF("src", "dst", "relationship")
// Create a GraphFrame
GraphFrame(v, e)
}
/**
* Two densely connected blobs (vertices 0->n-1 and n->2n-1) connected by a single edge (0->n)
- * @param blobSize the size of each blob.
+ * @param blobSize
+ * the size of each blob.
* @return
*/
def twoBlobs(blobSize: Int): GraphFrame = {
@@ -94,7 +101,8 @@ class Graphs private[graphframes] () {
val edges1 = for (v1 <- 0 until n; v2 <- 0 until n) yield (v1.toLong, v2.toLong, s"$v1-$v2")
val edges2 = for {
v1 <- n until (2 * n)
- v2 <- n until (2 * n) } yield (v1.toLong, v2.toLong, s"$v1-$v2")
+ v2 <- n until (2 * n)
+ } yield (v1.toLong, v2.toLong, s"$v1-$v2")
val edges = edges1 ++ edges2 :+ (0L, n.toLong, s"0-$n")
val vertices = (0 until (2 * n)).map { v => (v.toLong, s"$v", v) }
val e = spark.createDataFrame(edges).toDF("src", "dst", "e_attr1")
@@ -105,7 +113,8 @@ class Graphs private[graphframes] () {
/**
* Returns a star graph with Long ID type, consisting of a central element indexed 0 (the root)
* and the n other leaf vertices 1, 2, ..., n.
- * @param n the number of leaves
+ * @param n
+ * the number of leaves
*/
def star(n: Long): GraphFrame = {
require(n >= 0L)
@@ -127,7 +136,8 @@ class Graphs private[graphframes] () {
(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val edges = spark.createDataFrame(data).toDF("src", "dst", "weight")
- val vs = data.flatMap(r => r._1 :: r._2 :: Nil).collect().distinct.map(x => Tuple1(x)).toIndexedSeq
+ val vs =
+ data.flatMap(r => r._1 :: r._2 :: Nil).collect().distinct.map(x => Tuple1(x)).toIndexedSeq
val vertices = spark.createDataFrame(vs).toDF("id")
GraphFrame(vertices, edges)
}
@@ -155,29 +165,32 @@ class Graphs private[graphframes] () {
/**
* This method generates a grid Ising model with random parameters.
*
- * Ising models are probabilistic graphical models over binary variables x,,i,,.
- * Each binary variable x,,i,, corresponds to one vertex, and it may take values -1 or +1.
- * The probability distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,,
- * and edge factors b,,ij,,:
+ * Ising models are probabilistic graphical models over binary variables x,,i,,. Each binary
+ * variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. The probability
+ * distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, and edge
+ * factors b,,ij,,:
* {{{
* P(X) = (1/Z) * exp[ \sum_i a_i x_i + \sum_{ij} b_{ij} x_i x_j ]
* }}}
* where Z is the normalization constant (partition function). See
* [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models.
*
- * Each vertex is parameterized by a single scalar a,,i,,.
- * Each edge is parameterized by a single scalar b,,ij,,.
+ * Each vertex is parameterized by a single scalar a,,i,,. Each edge is parameterized by a
+ * single scalar b,,ij,,.
*
- * @param n Length of one side of the grid. The grid will be of size n x n.
- * @param vStd Standard deviation of normal distribution used to generate vertex factors "a".
- * Default of 1.0.
- * @param eStd Standard deviation of normal distribution used to generate edge factors "b".
- * Default of 1.0.
- * @return GraphFrame. Vertices have columns "id" and "a".
- * Edges have columns "src", "dst", and "b". Edges are directed, but they should be
- * treated as undirected in any algorithms run on this model.
- * Vertex IDs are of the form "i,j". E.g., vertex "1,3" is in the second row and fourth
- * column of the grid.
+ * @param n
+ * Length of one side of the grid. The grid will be of size n x n.
+ * @param vStd
+ * Standard deviation of normal distribution used to generate vertex factors "a". Default of
+ * 1.0.
+ * @param eStd
+ * Standard deviation of normal distribution used to generate edge factors "b". Default of
+ * 1.0.
+ * @return
+ * GraphFrame. Vertices have columns "id" and "a". Edges have columns "src", "dst", and "b".
+ * Edges are directed, but they should be treated as undirected in any algorithms run on this
+ * model. Vertex IDs are of the form "i,j". E.g., vertex "1,3" is in the second row and fourth
+ * column of the grid.
*/
def gridIsingModel(spark: SparkSession, n: Int, vStd: Double, eStd: Double): GraphFrame = {
require(n >= 1, s"Grid graph must have size >= 1, but was given invalid value n = $n")
@@ -195,20 +208,23 @@ class Graphs private[graphframes] () {
val vIDcol = toIDudf(col("i"), col("j"))
// Add random parameters generated from a normal distribution
val seed = 12345
- val vertices = coordinates.withColumn("id", vIDcol) // vertex IDs "i,j"
- .withColumn("a", randn(seed) * vStd) // Ising parameter for vertex
+ val vertices = coordinates
+ .withColumn("id", vIDcol) // vertex IDs "i,j"
+ .withColumn("a", randn(seed) * vStd) // Ising parameter for vertex
// Create the edge DataFrame
// Create SQL expression for converting coordinates (i,j+1) and (i+1,j) to string IDs
val rightIDcol = toIDudf(col("i"), col("j") + 1)
val downIDcol = toIDudf(col("i") + 1, col("j"))
- val horizontalEdges = coordinates.filter(col("j") =!= n - 1)
+ val horizontalEdges = coordinates
+ .filter(col("j") =!= n - 1)
.select(vIDcol.as("src"), rightIDcol.as("dst"))
- val verticalEdges = coordinates.filter(col("i") =!= n - 1)
+ val verticalEdges = coordinates
+ .filter(col("i") =!= n - 1)
.select(vIDcol.as("src"), downIDcol.as("dst"))
val allEdges = horizontalEdges.union(verticalEdges)
// Add random parameters from a normal distribution
- val edges = allEdges.withColumn("b", randn(seed + 1) * eStd) // Ising parameter for edge
+ val edges = allEdges.withColumn("b", randn(seed + 1) * eStd) // Ising parameter for edge
// Create the GraphFrame
val g = GraphFrame(vertices, edges)
diff --git a/src/main/scala/org/graphframes/exceptions.scala b/src/main/scala/org/graphframes/exceptions.scala
index e98ff84ee..fe2be9ba3 100644
--- a/src/main/scala/org/graphframes/exceptions.scala
+++ b/src/main/scala/org/graphframes/exceptions.scala
@@ -2,7 +2,6 @@ package org.graphframes
// All the public exceptions thrown by GraphFrame methods
-
/**
* Exception thrown when a pattern String for motif finding cannot be parsed.
*/
@@ -17,4 +16,4 @@ class NoSuchVertexException(message: String) extends Exception(message)
* Exception thrown when a parsed pattern for motif finding cannot be translated into a DataFrame
* query.
*/
-class InvalidPatternException() extends Exception()
\ No newline at end of file
+class InvalidPatternException() extends Exception()
diff --git a/src/main/scala/org/graphframes/lib/AggregateMessages.scala b/src/main/scala/org/graphframes/lib/AggregateMessages.scala
index 4c9efec1d..c3f721b21 100644
--- a/src/main/scala/org/graphframes/lib/AggregateMessages.scala
+++ b/src/main/scala/org/graphframes/lib/AggregateMessages.scala
@@ -23,41 +23,41 @@ import org.apache.spark.sql.{Column, DataFrame}
import org.graphframes.{GraphFrame, Logging}
/**
- * This is a primitive for implementing graph algorithms.
- * This method aggregates messages from the neighboring edges and vertices of each vertex.
+ * This is a primitive for implementing graph algorithms. This method aggregates messages from the
+ * neighboring edges and vertices of each vertex.
*
- * For each triplet (source vertex, edge, destination vertex) in [[GraphFrame.triplets]],
- * this can send a message to the source and/or destination vertices.
- * - `AggregateMessages.sendToSrc()` sends a message to the source vertex of each
- * triplet
- * - `AggregateMessages.sendToDst()` sends a message to the destination vertex of each
- * triplet
- * - `AggregateMessages.agg` specifies an aggregation function for aggregating the
- * messages sent to each vertex. It also runs the aggregation, computing a DataFrame
- * with one row for each vertex which receives > 0 messages. The DataFrame has 2 columns:
+ * For each triplet (source vertex, edge, destination vertex) in [[GraphFrame.triplets]], this can
+ * send a message to the source and/or destination vertices.
+ * - `AggregateMessages.sendToSrc()` sends a message to the source vertex of each triplet
+ * - `AggregateMessages.sendToDst()` sends a message to the destination vertex of each triplet
+ * - `AggregateMessages.agg` specifies an aggregation function for aggregating the messages sent
+ * to each vertex. It also runs the aggregation, computing a DataFrame with one row for each
+ * vertex which receives > 0 messages. The DataFrame has 2 columns:
* - vertex column ID (named [[GraphFrame.ID]])
- * - aggregate from messages sent to vertex (with the name given to the `Column` specified
- * in `AggregateMessages.agg()`)
+ * - aggregate from messages sent to vertex (with the name given to the `Column` specified in
+ * `AggregateMessages.agg()`)
*
* When specifying the messages and aggregation function, the user may reference columns using:
- * - [[AggregateMessages.src]]: column for source vertex of edge
- * - [[AggregateMessages.edge]]: column for edge
- * - [[AggregateMessages.dst]]: column for destination vertex of edge
- * - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function)
+ * - [[AggregateMessages.src]]: column for source vertex of edge
+ * - [[AggregateMessages.edge]]: column for edge
+ * - [[AggregateMessages.dst]]: column for destination vertex of edge
+ * - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function)
*
* Note: If you use this operation to write an iterative algorithm, you may want to use
* [[AggregateMessages$.getCachedDataFrame getCachedDataFrame()]] as a workaround for caching
* issues.
*
- * @example We can use this function to compute the in-degree of each vertex
- * {{{
+ * @example
+ * We can use this function to compute the in-degree of each vertex
+ * {{{
* val g: GraphFrame = Graph.textFile("twittergraph")
* val inDeg: DataFrame =
* g.aggregateMessages().sendToDst(lit(1)).agg(sum(AggregateMessagesBuilder.msg))
- * }}}
+ * }}}
*/
class AggregateMessages private[graphframes] (private val g: GraphFrame)
- extends Arguments with Serializable {
+ extends Arguments
+ with Serializable {
import org.graphframes.GraphFrame.{DST, ID, SRC}
@@ -84,15 +84,14 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame)
def sendToDst(value: String): this.type = sendToDst(expr(value))
/**
- * Run the aggregation, returning the resulting DataFrame of aggregated messages.
- * This is a lazy operation, so the DataFrame will not be materialized until an action is
- * executed on it.
+ * Run the aggregation, returning the resulting DataFrame of aggregated messages. This is a lazy
+ * operation, so the DataFrame will not be materialized until an action is executed on it.
*
* This returns a DataFrame with schema:
- * - column "id": vertex ID
- * - aggCol: aggregate result
- * If you need to join this with the original [[GraphFrame.vertices]], you can run an inner
- * join of the form:
+ * - column "id": vertex ID
+ * - aggCol: aggregate result
+ * If you need to join this with the original [[GraphFrame.vertices]], you can run an inner join
+ * of the form:
* {{{
* val g: GraphFrame = ...
* val aggResult = g.AggregateMessagesBuilder.sendToSrc(msg).agg(aggFunc)
@@ -100,22 +99,24 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame)
* }}}
*/
def agg(aggCol: Column): DataFrame = {
- require(msgToSrc.nonEmpty || msgToDst.nonEmpty, s"To run GraphFrame.aggregateMessages," +
- s" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().")
+ require(
+ msgToSrc.nonEmpty || msgToDst.nonEmpty,
+ s"To run GraphFrame.aggregateMessages," +
+ s" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().")
val triplets = g.triplets
val sentMsgsToSrc = msgToSrc.map { msg =>
- val msgsToSrc = triplets.select(
- msg.as(AggregateMessages.MSG_COL_NAME),
- triplets(SRC)(ID).as(ID))
+ val msgsToSrc =
+ triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(SRC)(ID).as(ID))
// Inner join: only send messages to vertices with edges
- msgsToSrc.join(g.vertices, ID)
+ msgsToSrc
+ .join(g.vertices, ID)
.select(msgsToSrc(AggregateMessages.MSG_COL_NAME), col(ID))
}
val sentMsgsToDst = msgToDst.map { msg =>
- val msgsToDst = triplets.select(
- msg.as(AggregateMessages.MSG_COL_NAME),
- triplets(DST)(ID).as(ID))
- msgsToDst.join(g.vertices, ID)
+ val msgsToDst =
+ triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(DST)(ID).as(ID))
+ msgsToDst
+ .join(g.vertices, ID)
.select(msgsToDst(AggregateMessages.MSG_COL_NAME), col(ID))
}
val unionedMsgs = (sentMsgsToSrc, sentMsgsToDst) match {
@@ -158,12 +159,12 @@ object AggregateMessages extends Logging with Serializable {
/**
* Create a new cached copy of a DataFrame. For iterative DataFrame-based algorithms.
*
- * WARNING: This is NOT the same as `DataFrame.cache()`.
- * The original DataFrame will NOT be cached.
+ * WARNING: This is NOT the same as `DataFrame.cache()`. The original DataFrame will NOT be
+ * cached.
*
* This is a workaround for SPARK-13346, which makes it difficult to use DataFrames in iterative
- * algorithms. This workaround converts the DataFrame to an RDD, caches the RDD, and creates
- * a new DataFrame. This is important for avoiding the creation of extremely complex DataFrame
+ * algorithms. This workaround converts the DataFrame to an RDD, caches the RDD, and creates a
+ * new DataFrame. This is important for avoiding the creation of extremely complex DataFrame
* query plans when using DataFrames in iterative algorithms.
*/
def getCachedDataFrame(df: DataFrame): DataFrame = {
diff --git a/src/main/scala/org/graphframes/lib/BFS.scala b/src/main/scala/org/graphframes/lib/BFS.scala
index 80ebe313e..802d8f63c 100644
--- a/src/main/scala/org/graphframes/lib/BFS.scala
+++ b/src/main/scala/org/graphframes/lib/BFS.scala
@@ -27,64 +27,59 @@ import org.graphframes.GraphFrame.nestAsCol
/**
* Breadth-first search (BFS)
*
- * This method returns a DataFrame of valid shortest paths from vertices matching `fromExpr`
- * to vertices matching `toExpr`. If multiple paths are valid and have the same length,
- * the DataFrame will return one Row for each path. If no paths are valid, the DataFrame will
- * be empty.
- * Note: "Shortest" means globally shortest path. I.e., if the shortest path between two vertices
+ * This method returns a DataFrame of valid shortest paths from vertices matching `fromExpr` to
+ * vertices matching `toExpr`. If multiple paths are valid and have the same length, the DataFrame
+ * will return one Row for each path. If no paths are valid, the DataFrame will be empty. Note:
+ * "Shortest" means globally shortest path. I.e., if the shortest path between two vertices
* matching `fromExpr` and `toExpr` is length 5 (edges) but no path is shorter than 5, then all
* paths returned by BFS will have length 5.
*
* The returned DataFrame will have the following columns:
- * - `from` start vertex of path
- * - `e[i]` edge i in the path, indexed from 0
- * - `v[i]` intermediate vertex i in the path, indexed from 1
- * - `to` end vertex of path
+ * - `from` start vertex of path
+ * - `e[i]` edge i in the path, indexed from 0
+ * - `v[i]` intermediate vertex i in the path, indexed from 1
+ * - `to` end vertex of path
* Each of these columns is a StructType whose fields are the same as the columns of
* [[GraphFrame.vertices]] or [[GraphFrame.edges]].
*
- * For example, suppose we have a graph g. Say the vertices DataFrame of g has columns "id" and
+ * For example, suppose we have a graph g. Say the vertices DataFrame of g has columns "id" and
* "job", and the edges DataFrame of g has columns "src", "dst", and "relation".
* {{{
* // Search from vertex "Joe" to find the closet vertices with attribute job = CEO.
* g.bfs(col("id") === "Joe", col("job") === "CEO").run()
* }}}
* If we found a path of 3 edges, each row would have columns:
- * {{{from | e0 | v1 | e1 | v2 | e2 | to}}}
- * In the above row, each vertex column (from, v1, v2, to) would have fields "id" and "job"
- * (just like g.vertices).
- * Each edge column (e0, e1, e2) would have fields "src", "dst", and "relation".
+ * {{{from | e0 | v1 | e1 | v2 | e2 | to}}} In the above row, each vertex column (from, v1, v2,
+ * to) would have fields "id" and "job" (just like g.vertices). Each edge column (e0, e1, e2)
+ * would have fields "src", "dst", and "relation".
*
* If there are ties, then each of the equal paths will be returned as a separate Row.
*
- * If one or more vertices match both the from and to conditions, then there is a 0-hop path.
- * The returned DataFrame will have the "from" and "to" columns (as above); however,
- * the "from" and "to" columns will be exactly the same. There will be one row for each vertex
- * in [[GraphFrame.vertices]] matching both `fromExpr` and `toExpr`.
+ * If one or more vertices match both the from and to conditions, then there is a 0-hop path. The
+ * returned DataFrame will have the "from" and "to" columns (as above); however, the "from" and
+ * "to" columns will be exactly the same. There will be one row for each vertex in
+ * [[GraphFrame.vertices]] matching both `fromExpr` and `toExpr`.
*
* Parameters:
*
- * - `fromExpr` Spark SQL expression specifying valid starting vertices for the BFS.
- * This condition will be matched against each vertex's id or attributes.
- * To start from a specific vertex, this could be "id = [start vertex id]".
- * To start from multiple valid vertices, this can operate on vertex attributes.
- *
- * - `toExpr` Spark SQL expression specifying valid target vertices for the BFS.
- * This condition will be matched against each vertex's id or attributes.
- *
- * - `maxPathLength` Limit on the length of paths. If no valid paths of length
- * <= maxPathLength are found, then the BFS is terminated.
- * (default = 10)
- * - `edgeFilter` Spark SQL expression specifying edges which may be used in the search.
- * This allows the user to disallow crossing certain edges. Such filters
- * can be applied post-hoc after BFS, run specifying the filter here is more
- * efficient.
+ * - `fromExpr` Spark SQL expression specifying valid starting vertices for the BFS. This
+ * condition will be matched against each vertex's id or attributes. To start from a specific
+ * vertex, this could be "id = [start vertex id]". To start from multiple valid vertices, this
+ * can operate on vertex attributes.
+ * - `toExpr` Spark SQL expression specifying valid target vertices for the BFS. This condition
+ * will be matched against each vertex's id or attributes.
+ * - `maxPathLength` Limit on the length of paths. If no valid paths of length <= maxPathLength
+ * are found, then the BFS is terminated. (default = 10)
+ * - `edgeFilter` Spark SQL expression specifying edges which may be used in the search. This
+ * allows the user to disallow crossing certain edges. Such filters can be applied post-hoc
+ * after BFS, but specifying the filter here is more efficient.
*
* Returns:
- * - DataFrame of valid shortest paths found in the BFS
+ * - DataFrame of valid shortest paths found in the BFS
*/
class BFS private[graphframes] (private val graph: GraphFrame)
- extends Arguments with Serializable {
+ extends Arguments
+ with Serializable {
private var maxPathLength: Int = 10
private var edgeFilter: Option[Column] = None
@@ -125,7 +120,6 @@ class BFS private[graphframes] (private val graph: GraphFrame)
}
}
-
private object BFS extends Logging with Serializable {
private def run(
@@ -147,7 +141,8 @@ private object BFS extends Logging with Serializable {
if (fromEqualsToDF.take(1).nonEmpty) {
// from == to, so return matching vertices
return fromEqualsToDF.select(
- nestAsCol(fromEqualsToDF, "from"), nestAsCol(fromEqualsToDF, "to"))
+ nestAsCol(fromEqualsToDF, "from"),
+ nestAsCol(fromEqualsToDF, "to"))
}
// We handled edge cases above, so now we do BFS.
@@ -179,15 +174,20 @@ private object BFS extends Logging with Serializable {
if (iter == 0) {
// Note: We could avoid this special case by initializing paths with just 1 "from" column,
// but that would create a longer lineage for the result DataFrame.
- paths = a2b.filter(fromAExpr)
- .filter(col("a.id") =!= col("b.id")) // remove self-loops
- .withColumnRenamed("a", "from").withColumnRenamed("e", nextEdge)
+ paths = a2b
+ .filter(fromAExpr)
+ .filter(col("a.id") =!= col("b.id")) // remove self-loops
+ .withColumnRenamed("a", "from")
+ .withColumnRenamed("e", nextEdge)
.withColumnRenamed("b", nextVertex)
} else {
val prevVertex = s"v$iter"
- val nextLinks = a2b.withColumnRenamed("a", prevVertex).withColumnRenamed("e", nextEdge)
+ val nextLinks = a2b
+ .withColumnRenamed("a", prevVertex)
+ .withColumnRenamed("e", nextEdge)
.withColumnRenamed("b", nextVertex)
- paths = paths.join(nextLinks, paths(prevVertex + ".id") === nextLinks(prevVertex + ".id"))
+ paths = paths
+ .join(nextLinks, paths(prevVertex + ".id") === nextLinks(prevVertex + ".id"))
.drop(paths(prevVertex))
// Make sure we are not backtracking within each path.
// TODO: Avoid crossing paths; i.e., touch each vertex at most once.
@@ -222,24 +222,24 @@ private object BFS extends Logging with Serializable {
} else {
logInfo(s"GraphFrame.bfs failed to find a path of length <= $maxPathLength.")
// Return empty DataFrame
- g.spark.createDataFrame(
- g.spark.sparkContext.parallelize(Seq.empty[Row]),
- g.vertices.schema)
+ g.spark.createDataFrame(g.spark.sparkContext.parallelize(Seq.empty[Row]), g.vertices.schema)
}
}
-
/**
- * Apply the given SQL expression (such as `id = 3`) to the field in a column,
- * rather than to the column itself.
+ * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to
+ * the column itself.
*
- * @param expr SQL expression, such as `id = 3`
- * @param colName Column name, such as `myVertex`
- * @return SQL expression applied to the column fields, such as `myVertex.id = 3`
+ * @param expr
+ * SQL expression, such as `id = 3`
+ * @param colName
+ * Column name, such as `myVertex`
+ * @return
+ * SQL expression applied to the column fields, such as `myVertex.id = 3`
*/
private def applyExprToCol(expr: Column, colName: String) = {
- new Column(expr.expr.transform {
- case UnresolvedAttribute(nameParts) => UnresolvedAttribute(colName +: nameParts)
+ new Column(expr.expr.transform { case UnresolvedAttribute(nameParts) =>
+ UnresolvedAttribute(colName +: nameParts)
})
}
}
diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
index dbc1d98f7..68c13b85b 100644
--- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
+++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -36,21 +36,22 @@ import org.graphframes.{GraphFrame, Logging}
* information with each vertex assigned a component ID.
*
* The resulting DataFrame contains all the vertex information and one additional column:
- * - component (`LongType`): unique ID for this component
+ * - component (`LongType`): unique ID for this component
*/
-class ConnectedComponents private[graphframes] (
- private val graph: GraphFrame) extends Arguments with Logging {
+class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
+ extends Arguments
+ with Logging {
import org.graphframes.lib.ConnectedComponents._
private var broadcastThreshold: Int = 1000000
/**
- * Sets broadcast threshold in propagating component assignments (default: 1000000).
- * If a node degree is greater than this threshold at some iteration, its component assignment
- * will be collected and then broadcasted back to propagate the assignment to its neighbors.
- * Otherwise, the assignment propagation is done by a normal Spark join.
- * This parameter is only used when the algorithm is set to "graphframes".
+ * Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
+ * degree is greater than this threshold at some iteration, its component assignment will be
+ * collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise,
+ * the assignment propagation is done by a normal Spark join. This parameter is only used when
+ * the algorithm is set to "graphframes".
*/
def setBroadcastThreshold(value: Int): this.type = {
require(value >= 0, s"Broadcast threshold must be non-negative but got $value.")
@@ -65,24 +66,27 @@ class ConnectedComponents private[graphframes] (
/**
* Gets broadcast threshold in propagating component assignment.
- * @see [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
+ * @see
+ * [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
*/
def getBroadcastThreshold: Int = broadcastThreshold
private var algorithm: String = ALGO_GRAPHFRAMES
/**
- * Sets the connected components algorithm to use (default: "graphframes").
- * Supported algorithms are:
+ * Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms
+ * are:
* - "graphframes": Uses alternating large star and small star iterations proposed in
* [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]]
* with skewed join optimization.
* - "graphx": Converts the graph to a GraphX graph and then uses the connected components
* implementation in GraphX.
- * @see [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]]
+ * @see
+ * [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]]
*/
def setAlgorithm(value: String): this.type = {
- require(supportedAlgorithms.contains(value),
+ require(
+ supportedAlgorithms.contains(value),
s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.")
algorithm = value
this
@@ -90,26 +94,26 @@ class ConnectedComponents private[graphframes] (
/**
* Gets the connected component algorithm to use.
- * @see [[org.graphframes.lib.ConnectedComponents.setAlgorithm]].
+ * @see
+ * [[org.graphframes.lib.ConnectedComponents.setAlgorithm]].
*/
def getAlgorithm: String = algorithm
private var checkpointInterval: Int = 2
/**
- * Sets checkpoint interval in terms of number of iterations (default: 2).
- * Checkpointing regularly helps recover from failures, clean shuffle files, shorten the
- * lineage of the computation graph, and reduce the complexity of plan optimization.
- * As of Spark 2.0, the complexity of plan optimization would grow exponentially without
- * checkpointing.
- * Hence disabling or setting longer-than-default checkpoint intervals are not recommended.
- * Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with
- * prefix "connected-components".
- * If the checkpoint directory is not set, this throws a `java.io.IOException`.
- * Set a nonpositive value to disable checkpointing.
- * This parameter is only used when the algorithm is set to "graphframes".
- * Its default value might change in the future.
- * @see `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc
+ * Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing
+ * regularly helps recover from failures, clean shuffle files, shorten the lineage of the
+ * computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the
+ * complexity of plan optimization would grow exponentially without checkpointing. Hence
+ * disabling or setting longer-than-default checkpoint intervals is not recommended. Checkpoint
+ * data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix
+ * "connected-components". If the checkpoint directory is not set, this throws a
+ * `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is
+ * only used when the algorithm is set to "graphframes". Its default value might change in the
+ * future.
+ * @see
+ * `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc
*/
def setCheckpointInterval(value: Int): this.type = {
if (value <= 0 || value > 2) {
@@ -128,14 +132,16 @@ class ConnectedComponents private[graphframes] (
/**
* Gets checkpoint interval.
- * @see [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]]
+ * @see
+ * [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]]
*/
def getCheckpointInterval: Int = checkpointInterval
private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
/**
- * Sets storage level for intermediate datasets that require multiple passes (default: ``MEMORY_AND_DISK``).
+ * Sets storage level for intermediate datasets that require multiple passes (default:
+ * ``MEMORY_AND_DISK``).
*/
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
intermediateStorageLevel = value
@@ -151,7 +157,8 @@ class ConnectedComponents private[graphframes] (
* Runs the algorithm.
*/
def run(): DataFrame = {
- ConnectedComponents.run(graph,
+ ConnectedComponents.run(
+ graph,
algorithm = algorithm,
broadcastThreshold = broadcastThreshold,
checkpointInterval = checkpointInterval,
@@ -173,21 +180,20 @@ object ConnectedComponents extends Logging {
private val ALGO_GRAPHFRAMES = "graphframes"
/**
- * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: "graphframes"
- * and "graphx".
+ * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]:
+ * "graphframes" and "graphx".
*/
val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)
/**
* Returns the symmetric directed graph of the graph specified by input edges.
- * @param ee non-bidirectional edges
+ * @param ee
+ * non-bidirectional edges
*/
private def symmetrize(ee: DataFrame): DataFrame = {
val EDGE = "_edge"
- ee.select(explode(array(
- struct(col(SRC), col(DST)),
- struct(col(DST).as(SRC), col(SRC).as(DST)))
- ).as(EDGE))
+ ee.select(explode(
+ array(struct(col(SRC), col(DST)), struct(col(DST).as(SRC), col(SRC).as(DST)))).as(EDGE))
.select(col(s"$EDGE.$SRC").as(SRC), col(s"$EDGE.$DST").as(DST))
}
@@ -208,11 +214,12 @@ object ConnectedComponents extends Logging {
// TODO: This assignment job might fail if the graph is skewed.
val vertices = graph.indexedVertices
.select(col(LONG_ID).as(ID), col(ATTR))
- // TODO: confirm the contract for a graph and decide whether we need distinct here
- // .distinct()
+ // TODO: confirm the contract for a graph and decide whether we need distinct here
+ // .distinct()
val edges = graph.indexedEdges
.select(col(LONG_SRC).as(SRC), col(LONG_DST).as(DST))
- val orderedEdges = edges.filter(col(SRC) =!= col(DST))
+ val orderedEdges = edges
+ .filter(col(SRC) =!= col(DST))
.select(minValue(col(SRC), col(DST)).as(SRC), maxValue(col(SRC), col(DST)).as(DST))
.distinct()
GraphFrame(vertices, orderedEdges)
@@ -226,7 +233,8 @@ object ConnectedComponents extends Logging {
*/
private def minNbrs(ee: DataFrame): DataFrame = {
symmetrize(ee)
- .groupBy(SRC).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
+ .groupBy(SRC)
+ .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT))
.withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR)))
}
@@ -239,8 +247,8 @@ object ConnectedComponents extends Logging {
}
/**
- * Performs a possibly skewed join between edges and current component assignments.
- * The skew join is done by broadcast join for frequent keys and normal join for the rest.
+ * Performs a possibly skewed join between edges and current component assignments. The skew
+ * join is done by broadcast join for frequent keys and normal join for the rest.
*/
private def skewedJoin(
edges: DataFrame,
@@ -248,7 +256,8 @@ object ConnectedComponents extends Logging {
broadcastThreshold: Int,
logPrefix: String): DataFrame = {
import edges.sparkSession.implicits._
- val hubs = minNbrs.filter(col(CNT) > broadcastThreshold)
+ val hubs = minNbrs
+ .filter(col(CNT) > broadcastThreshold)
.select(SRC)
.as[Long]
.collect()
@@ -264,7 +273,8 @@ object ConnectedComponents extends Logging {
}
private def runGraphX(graph: GraphFrame): DataFrame = {
- val components = org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX)
+ val components =
+ org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX)
GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices
}
@@ -274,7 +284,8 @@ object ConnectedComponents extends Logging {
broadcastThreshold: Int,
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel): DataFrame = {
- require(supportedAlgorithms.contains(algorithm),
+ require(
+ supportedAlgorithms.contains(algorithm),
s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.")
if (algorithm == ALGO_GRAPHX) {
@@ -295,12 +306,14 @@ object ConnectedComponents extends Logging {
val shouldCheckpoint = checkpointInterval > 0
val checkpointDir: Option[String] = if (shouldCheckpoint) {
- val dir = sc.getCheckpointDir.map { d =>
- new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
- }.getOrElse {
- throw new IOException(
- "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().")
- }
+ val dir = sc.getCheckpointDir
+ .map { d =>
+ new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
+ }
+ .getOrElse {
+ throw new IOException(
+ "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().")
+ }
logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.")
Some(dir)
} else {
@@ -323,13 +336,15 @@ object ConnectedComponents extends Logging {
// Taking the sum in DecimalType to preserve precision.
// We use 20 digits for long values and Spark SQL will add 10 digits for the sum.
// It should be able to handle 200 billion edges without overflow.
- val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd
+ val (minNbrSum, cnt) = minNbrsDF
+ .select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*"))
+ .rdd
.map { r =>
(r.getAs[BigDecimal](0), r.getLong(1))
- }.first()
+ }
+ .first()
if (cnt != 0L && minNbrSum == null) {
- throw new ArithmeticException(
- s"""
+ throw new ArithmeticException(s"""
|The total sum of edge src IDs is used to determine convergence during iterations.
|However, the total sum at iteration $iteration exceeded 30 digits (1e30),
|which should happen only if the graph contains more than 200 billion edges.
@@ -357,7 +372,9 @@ object ConnectedComponents extends Logging {
// small-star step
// compute min neighbors (excluding self-min)
- val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
+ val minNbrs2 = ee
+ .groupBy(col(SRC))
+ .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
.persist(intermediateStorageLevel)
currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2
@@ -366,7 +383,8 @@ object ConnectedComponents extends Logging {
.select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
.filter(col(SRC) =!= col(DST)) // src < dst
// connect self to the min neighbor
- ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
+ ee = ee
+ .union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
.distinct()
// checkpointing
@@ -406,7 +424,7 @@ object ConnectedComponents extends Logging {
// materialize all persisted DataFrames in current round,
// then we can unpersist last round persisted DataFrames.
for (persisted_df <- currRoundPersistedDFs) {
- persisted_df.count() // materialize it.
+ persisted_df.count() // materialize it.
}
for (persisted_df <- lastRoundPersistedDFs) {
persisted_df.unpersist()
diff --git a/src/main/scala/org/graphframes/lib/GraphXConversions.scala b/src/main/scala/org/graphframes/lib/GraphXConversions.scala
index 13f9beb65..0143f0b7b 100644
--- a/src/main/scala/org/graphframes/lib/GraphXConversions.scala
+++ b/src/main/scala/org/graphframes/lib/GraphXConversions.scala
@@ -27,28 +27,28 @@ import org.apache.spark.sql.types.{StructField, StructType}
import org.graphframes.{NoSuchVertexException, GraphFrame}
/**
- * Convenience functions to map GraphX graphs to GraphFrames,
- * checking for the types expected by GraphX.
+ * Convenience functions to map GraphX graphs to GraphFrames, checking for the types expected by
+ * GraphX.
*/
private[graphframes] object GraphXConversions {
import GraphFrame._
/** Indicates if T is a Unit type */
- private def isUnitType[T : TypeTag] : Boolean = {
+ private def isUnitType[T: TypeTag]: Boolean = {
val t = typeOf[T]
typeOf[Unit] =:= t
}
/** Indicates if T is a Product type */
- private def isProductType[T : TypeTag] : Boolean = {
+ private def isProductType[T: TypeTag]: Boolean = {
val t = typeOf[T]
// See http://stackoverflow.com/questions/21209006/how-to-check-if-reflected-type-represents-a-tuple
t.typeSymbol.fullName.startsWith("scala.Tuple")
}
/** See [[GraphFrame.fromGraphX()]] for documentation */
- def fromGraphX[V : TypeTag, E : TypeTag](
+ def fromGraphX[V: TypeTag, E: TypeTag](
originalGraph: GraphFrame,
graph: Graph[V, E],
vertexNames: Seq[String] = Nil,
@@ -64,7 +64,7 @@ private[graphframes] object GraphXConversions {
renameStructFields(vertexDF0, GX_ATTR, vertexNames)
} else {
// Assume it is just one field, and pack it in a tuple to have a structure.
- val vertexData = graph.vertices.map { case (vid, data) => (vid, Tuple1(data)) }
+ val vertexData = graph.vertices.map { case (vid, data) => (vid, Tuple1(data)) }
val vertexDF0 = spark.createDataFrame(vertexData).toDF(LONG_ID, GX_ATTR)
renameStructFields(vertexDF0, GX_ATTR, vertexNames)
}
@@ -85,12 +85,14 @@ private[graphframes] object GraphXConversions {
}
/**
- * Given the name of a column (assumed to contain a struct),
- * renames all the fields of this struct.
+ * Given the name of a column (assumed to contain a struct), renames all the fields of this
+ * struct.
*
- * @param structName Struct name whose fields will be renamed. This method assumes this field
- * exists and will not check for errors.
- * @param fieldNames List of new field names corresponding to all fields in the struct col.
+ * @param structName
+ * Struct name whose fields will be renamed. This method assumes this field exists and will
+ * not check for errors.
+ * @param fieldNames
+ * List of new field names corresponding to all fields in the struct col.
*/
private[lib] def renameStructFields(
df: DataFrame,
@@ -127,12 +129,13 @@ private[graphframes] object GraphXConversions {
}
/**
- * Joins all the data from the original columns against the new data. Assumes the columns
- * are not going to conflict.
+ * Joins all the data from the original columns against the new data. Assumes the columns are
+ * not going to conflict.
*
- * @param gxVertexData DataFrame with column [[LONG_ID]] and optionally column [[GX_ATTR]]
- * @param gxEdgeData DataFrame with columns [[LONG_DST]], [[LONG_SRC]] and optionally column
- * [[GX_ATTR]]
+ * @param gxVertexData
+ * DataFrame with column [[LONG_ID]] and optionally column [[GX_ATTR]]
+ * @param gxEdgeData
+ * DataFrame with columns [[LONG_DST]], [[LONG_SRC]] and optionally column [[GX_ATTR]]
*/
private def fromGraphX(
originalGraph: GraphFrame,
@@ -146,16 +149,19 @@ private[graphframes] object GraphXConversions {
val indexedEdges = originalGraph.indexedEdges
// Handle 2 cases: GraphX edge has attr, or not.
val hasGxAttr = gxEdgeData.schema.exists(_.name == GX_ATTR)
- val gxCol = if (hasGxAttr) { Seq(col(GX_ATTR)) } else { Seq() }
+ val gxCol = if (hasGxAttr) { Seq(col(GX_ATTR)) }
+ else { Seq() }
val sel1 = Seq(col(LONG_SRC), col(LONG_DST)) ++ gxCol
val gxe = gxEdgeData.select(sel1: _*)
val sel3 = Seq(col(ATTR)) ++ gxCol
// TODO: CHECK IN UNIT TESTS: Drop the src and dst columns from the index, they are already
// in the attributes and will be unpacked with the rest of the user columns.
// TODO(tjh) 2-step join?
- gxe.join(
- indexedEdges.select(indexedEdges(LONG_SRC), indexedEdges(LONG_DST), indexedEdges(ATTR)),
- (gxe(LONG_SRC) === indexedEdges(LONG_SRC)) && (gxe(LONG_DST) === indexedEdges(LONG_DST)))
+ gxe
+ .join(
+ indexedEdges.select(indexedEdges(LONG_SRC), indexedEdges(LONG_DST), indexedEdges(ATTR)),
+ (gxe(LONG_SRC) === indexedEdges(LONG_SRC)) && (gxe(LONG_DST) === indexedEdges(
+ LONG_DST)))
.select(sel3: _*)
}
val edgeDF = unpackStructFields(drop(packedEdges, LONG_SRC, LONG_DST))
@@ -164,8 +170,8 @@ private[graphframes] object GraphXConversions {
}
/**
- * Given a graph and an object, gets the the corresponding integral id in the
- * internal representation.
+ * Given a graph and an object, gets the corresponding integral id in the internal
+ * representation.
*/
private[graphframes] def integralId(graph: GraphFrame, vertexId: Any): Long = {
// Check if we can directly convert it
@@ -179,10 +185,12 @@ private[graphframes] object GraphXConversions {
// If the vertex is a non-integral type such as a String, we need to use the translation table.
val longIdRow: Array[Row] = graph.indexedVertices
.filter(col(GraphFrame.ID) === vertexId)
- .select(GraphFrame.LONG_ID).take(1)
+ .select(GraphFrame.LONG_ID)
+ .take(1)
if (longIdRow.isEmpty) {
- throw new NoSuchVertexException(s"GraphFrame algorithm given vertex ID which does not exist" +
- s" in Graph. Vertex ID $vertexId not contained in $graph")
+ throw new NoSuchVertexException(
+ s"GraphFrame algorithm given vertex ID which does not exist" +
+ s" in Graph. Vertex ID $vertexId not contained in $graph")
}
// TODO(tjh): could do more informative message
longIdRow.head.getLong(0)
diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala
index a740dd5fd..877d7345b 100644
--- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala
+++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala
@@ -30,19 +30,19 @@ import org.graphframes.GraphFrame
* affiliation of incoming messages.
*
* LPA is a standard community detection algorithm for graphs. It is very inexpensive
- * computationally, although (1) convergence is not guaranteed and (2) one can end up with
- * trivial solutions (all nodes are identified into a single community).
+ * computationally, although (1) convergence is not guaranteed and (2) one can end up with trivial
+ * solutions (all nodes are identified into a single community).
*
* The resulting DataFrame contains all the original vertex information and one additional column:
- * - label (`LongType`): label of community affiliation
+ * - label (`LongType`): label of community affiliation
*/
class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments {
private var maxIter: Option[Int] = None
/**
- * The max number of iterations of LPA to be performed. Because this is a static
- * implementation, the algorithm will run for exactly this many iterations.
+ * The max number of iterations of LPA to be performed. Because this is a static implementation,
+ * the algorithm will run for exactly this many iterations.
*/
def maxIter(value: Int): this.type = {
maxIter = Some(value)
@@ -50,13 +50,10 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) exte
}
def run(): DataFrame = {
- LabelPropagation.run(
- graph,
- check(maxIter, "maxIter"))
+ LabelPropagation.run(graph, check(maxIter, "maxIter"))
}
}
-
private object LabelPropagation {
private def run(graph: GraphFrame, maxIter: Int): DataFrame = {
val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
diff --git a/src/main/scala/org/graphframes/lib/PageRank.scala b/src/main/scala/org/graphframes/lib/PageRank.scala
index 2c9a9cf2e..b5bcfa53a 100644
--- a/src/main/scala/org/graphframes/lib/PageRank.scala
+++ b/src/main/scala/org/graphframes/lib/PageRank.scala
@@ -24,9 +24,9 @@ import org.graphframes.GraphFrame
/**
* PageRank algorithm implementation. There are two implementations of PageRank.
*
- * The first one uses the `org.apache.spark.graphx.graph` interface with `aggregateMessages` and runs
- * PageRank for a fixed number of iterations. This can be executed by setting `maxIter`. Conceptually,
- * the algorithm does the following:
+ * The first one uses the `org.apache.spark.graphx.graph` interface with `aggregateMessages` and
+ * runs PageRank for a fixed number of iterations. This can be executed by setting `maxIter`.
+ * Conceptually, the algorithm does the following:
* {{{
* var PR = Array.fill(n)( 1.0 )
* val oldPR = Array.fill(n)( 1.0 )
@@ -38,8 +38,9 @@ import org.graphframes.GraphFrame
* }
* }}}
*
- * The second implementation uses the `org.apache.spark.graphx.Pregel` interface and runs PageRank until
- * convergence and this can be run by setting `tol`. Conceptually, the algorithm does the following:
+ * The second implementation uses the `org.apache.spark.graphx.Pregel` interface and runs PageRank
+ * until convergence and this can be run by setting `tol`. Conceptually, the algorithm does the
+ * following:
* {{{
* var PR = Array.fill(n)( 1.0 )
* val oldPR = Array.fill(n)( 0.0 )
@@ -51,26 +52,24 @@ import org.graphframes.GraphFrame
* }
* }}}
*
- * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of
- * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
+ * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of neighbors
+ * which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
*
- * Note that this is not the "normalized" PageRank and as a consequence pages that have no
- * inlinks will have a PageRank of alpha. In particular, the pageranks may have some values
- * greater than 1.
+ * Note that this is not the "normalized" PageRank and as a consequence pages that have no inlinks
+ * will have a PageRank of alpha. In particular, the pageranks may have some values greater than 1.
*
* The resulting vertices DataFrame contains one additional column:
- * - pagerank (`DoubleType`): the pagerank of this vertex
+ * - pagerank (`DoubleType`): the pagerank of this vertex
*
* The resulting edges DataFrame contains one additional column:
- * - weight (`DoubleType`): the normalized weight of this edge after running PageRank
+ * - weight (`DoubleType`): the normalized weight of this edge after running PageRank
*/
-class PageRank private[graphframes] (
- private val graph: GraphFrame) extends Arguments {
+class PageRank private[graphframes] (private val graph: GraphFrame) extends Arguments {
private var tol: Option[Double] = None
private var resetProb: Option[Double] = Some(0.15)
private var maxIter: Option[Int] = None
- private var srcId : Option[Any] = None
+ private var srcId: Option[Any] = None
/** Source vertex for a Personalized Page Rank (optional) */
def sourceId(value: Any): this.type = {
@@ -97,8 +96,7 @@ class PageRank private[graphframes] (
def run(): GraphFrame = {
tol match {
case Some(t) =>
- assert(maxIter.isEmpty,
- "You cannot specify maxIter() and tol() at the same time.")
+ assert(maxIter.isEmpty, "You cannot specify maxIter() and tol() at the same time.")
PageRank.runUntilConvergence(graph, t, resetProb.get, srcId)
case None =>
PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId)
@@ -106,20 +104,23 @@ class PageRank private[graphframes] (
}
}
-
// TODO: srcID's type should be checked. The most futureproof check would be Encoder because it is
// compatible with Datasets after that.
private object PageRank {
+
/**
- * Run PageRank for a fixed number of iterations returning a graph
- * with vertex attributes containing the PageRank and edge
- * attributes the normalized edge weight.
+ * Run PageRank for a fixed number of iterations returning a graph with vertex attributes
+ * containing the PageRank and edge attributes the normalized edge weight.
*
- * @param graph the graph on which to compute PageRank
- * @param maxIter the number of iterations of PageRank to run
- * @param resetProb the random reset probability (alpha)
- * @return the graph containing with each vertex containing the PageRank and each edge
- * containing the normalized weight.
+ * @param graph
+ * the graph on which to compute PageRank
+ * @param maxIter
+ * the number of iterations of PageRank to run
+ * @param resetProb
+ * the random reset probability (alpha)
+ * @return
+ * the graph with each vertex containing the PageRank and each edge containing the
+ * normalized weight.
*/
def run(
graph: GraphFrame,
@@ -127,8 +128,8 @@ private object PageRank {
resetProb: Double = 0.15,
srcId: Option[Any] = None): GraphFrame = {
val longSrcId = srcId.map(GraphXConversions.integralId(graph, _))
- val gx = graphxlib.PageRank.runWithOptions(
- graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId)
+ val gx =
+ graphxlib.PageRank.runWithOptions(graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId)
GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT))
}
@@ -136,12 +137,17 @@ private object PageRank {
* Run a dynamic version of PageRank returning a graph with vertex attributes containing the
* PageRank and edge attributes containing the normalized edge weight.
*
- * @param graph the graph on which to compute PageRank
- * @param tol the tolerance allowed at convergence (smaller => more accurate).
- * @param resetProb the random reset probability (alpha)
- * @param srcId the source vertex for a Personalized Page Rank (optional)
- * @return the graph containing with each vertex containing the PageRank and each edge
- * containing the normalized weight.
+ * @param graph
+ * the graph on which to compute PageRank
+ * @param tol
+ * the tolerance allowed at convergence (smaller => more accurate).
+ * @param resetProb
+ * the random reset probability (alpha)
+ * @param srcId
+ * the source vertex for a Personalized Page Rank (optional)
+ * @return
+ * the graph with each vertex containing the PageRank and each edge containing the
+ * normalized weight.
*/
def runUntilConvergence(
graph: GraphFrame,
@@ -150,7 +156,10 @@ private object PageRank {
srcId: Option[Any] = None): GraphFrame = {
val longSrcId = srcId.map(GraphXConversions.integralId(graph, _))
val gx = graphxlib.PageRank.runUntilConvergenceWithOptions(
- graph.cachedTopologyGraphX, tol, resetProb, longSrcId)
+ graph.cachedTopologyGraphX,
+ tol,
+ resetProb,
+ longSrcId)
GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT))
}
diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
index f8f55d732..4f4ecfa41 100644
--- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
+++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
@@ -23,11 +23,10 @@ import org.graphframes.{GraphFrame, Logging}
/**
* Parallel Personalized PageRank algorithm implementation.
*
- * This implementation uses the standalone [[GraphFrame]] interface and
- * runs personalized PageRank in parallel for a fixed number of iterations.
- * This can be run by setting `maxIter`.
- * The source vertex Ids are set in `sourceIds`.
- * A simple local implementation of this algorithm is as follows.
+ * This implementation uses the standalone [[GraphFrame]] interface and runs personalized PageRank
+ * in parallel for a fixed number of iterations. This can be run by setting `maxIter`. The source
+ * vertex Ids are set in `sourceIds`. A simple local implementation of this algorithm is as
+ * follows.
* {{{
* var oldPR = Array.fill(n)( 1.0 )
* val PR = (0 until n).map(i => if sourceIds.contains(i) alpha else 0.0)
@@ -40,21 +39,20 @@ import org.graphframes.{GraphFrame, Logging}
* }
* }}}
*
- * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of
- * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
+ * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of neighbors
+ * which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
*
- * Note that this is not the "normalized" PageRank and as a consequence pages that have no
- * inlinks will have a PageRank of alpha. In particular, the pageranks may have some values
- * greater than 1.
+ * Note that this is not the "normalized" PageRank and as a consequence pages that have no inlinks
+ * will have a PageRank of alpha. In particular, the pageranks may have some values greater than 1.
*
* The resulting vertices DataFrame contains one additional column:
- * - pageranks (`VectorType`): the pageranks of this vertex from all input source vertices
+ * - pageranks (`VectorType`): the pageranks of this vertex from all input source vertices
*
* The resulting edges DataFrame contains one additional column:
- * - weight (`DoubleType`): the normalized weight of this edge after running PageRank
+ * - weight (`DoubleType`): the normalized weight of this edge after running PageRank
*/
-class ParallelPersonalizedPageRank private[graphframes] (
- private val graph: GraphFrame) extends Arguments {
+class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame)
+ extends Arguments {
private var resetProb: Option[Double] = Some(0.15)
private var maxIter: Option[Int] = None
@@ -86,6 +84,7 @@ class ParallelPersonalizedPageRank private[graphframes] (
}
private object ParallelPersonalizedPageRank {
+
/** Default name for the pageranks column. */
private val PAGERANKS = "pageranks"
@@ -93,18 +92,21 @@ private object ParallelPersonalizedPageRank {
private val WEIGHT = "weight"
/**
- * Run Personalized PageRank for a fixed number of iterations, for a
- * set of starting nodes in parallel. Returns a graph with vertex attributes
- * containing the pageranks relative to all starting nodes (as a vector) and
- * edge attributes the normalized edge weight
+ * Run Personalized PageRank for a fixed number of iterations, for a set of starting nodes in
+ * parallel. Returns a graph with vertex attributes containing the pageranks relative to all
+ * starting nodes (as a vector) and edge attributes the normalized edge weight
*
- * @param graph The graph on which to compute personalized pagerank
- * @param maxIter The number of iterations to run
- * @param resetProb The random reset probability
- * @param sourceIds The list of sources to compute personalized pagerank from
- * @return the graph with vertex attributes
- * containing the pageranks relative to all starting nodes as a vector and
- * edge attributes the normalized edge weight
+ * @param graph
+ * The graph on which to compute personalized pagerank
+ * @param maxIter
+ * The number of iterations to run
+ * @param resetProb
+ * The random reset probability
+ * @param sourceIds
+ * The list of sources to compute personalized pagerank from
+ * @return
+ * the graph with vertex attributes containing the pageranks relative to all starting nodes as
+ * a vector and edge attributes the normalized edge weight
*/
def run(
graph: GraphFrame,
@@ -113,7 +115,10 @@ private object ParallelPersonalizedPageRank {
sourceIds: Array[Any]): GraphFrame = {
val longSrcIds = sourceIds.map(GraphXConversions.integralId(graph, _))
val gx = graphxlib.PageRank.runParallelPersonalizedPageRank(
- graph.cachedTopologyGraphX, maxIter, resetProb, longSrcIds)
+ graph.cachedTopologyGraphX,
+ maxIter,
+ resetProb,
+ longSrcIds)
GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANKS), edgeNames = Seq(WEIGHT))
}
}
diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala
index 50b207092..beba0eb4e 100644
--- a/src/main/scala/org/graphframes/lib/Pregel.scala
+++ b/src/main/scala/org/graphframes/lib/Pregel.scala
@@ -25,25 +25,26 @@ import org.apache.spark.sql.functions.{array, col, explode, struct}
/**
* Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations.
*
- * See Malewicz et al., Pregel: a system for large-scale graph
- * processing for a detailed description of the Pregel algorithm.
+ * See Malewicz et al., Pregel: a system for
+ * large-scale graph processing for a detailed description of the Pregel algorithm.
*
- * You can construct a Pregel instance using either this constructor or [[org.graphframes.GraphFrame#pregel]],
- * then use builder pattern to describe the operations, and then call [[run]] to start a run.
- * It returns a DataFrame of vertices from the last iteration.
+ * You can construct a Pregel instance using either this constructor or
+ * [[org.graphframes.GraphFrame#pregel]], then use builder pattern to describe the operations, and
+ * then call [[run]] to start a run. It returns a DataFrame of vertices from the last iteration.
*
- * When a run starts, it expands the vertices DataFrame using column expressions defined by [[withVertexColumn]].
- * Those additional vertex properties can be changed during Pregel iterations.
- * In each Pregel iteration, there are three phases:
- * - Given each edge triplet, generate messages and specify target vertices to send,
- * described by [[sendMsgToDst]] and [[sendMsgToSrc]].
+ * When a run starts, it expands the vertices DataFrame using column expressions defined by
+ * [[withVertexColumn]]. Those additional vertex properties can be changed during Pregel
+ * iterations. In each Pregel iteration, there are three phases:
+ * - Given each edge triplet, generate messages and specify target vertices to send, described
+ * by [[sendMsgToDst]] and [[sendMsgToSrc]].
* - Aggregate messages by target vertex IDs, described by [[aggMsgs]].
- * - Update additional vertex properties based on aggregated messages and states from previous iteration,
- * described by [[withVertexColumn]].
+ * - Update additional vertex properties based on aggregated messages and states from previous
+ * iteration, described by [[withVertexColumn]].
*
* Please find what columns you can reference at each phase in the method API docs.
*
- * You can control the number of iterations by [[setMaxIter]] and check API docs for advanced controls.
+ * You can control the number of iterations by [[setMaxIter]] and check API docs for advanced
+ * controls.
*
* Example code for Page Rank:
*
@@ -61,11 +62,13 @@ import org.apache.spark.sql.functions.{array, col, explode, struct}
* .run()
* }}}
*
- * @param graph The graph that Pregel will run on.
- * @see [[org.graphframes.GraphFrame#pregel]]
- * @see
- * Malewicz et al., Pregel: a system for large-scale graph processing.
- *
+ * @param graph
+ * The graph that Pregel will run on.
+ * @see
+ * [[org.graphframes.GraphFrame#pregel]]
+ * @see
+ * Malewicz et al., Pregel: a system for
+ * large-scale graph processing.
*/
class Pregel(val graph: GraphFrame) {
@@ -99,27 +102,37 @@ class Pregel(val graph: GraphFrame) {
}
/**
- * Defines an additional vertex column at the start of run and how to update it in each iteration.
+ * Defines an additional vertex column at the start of run and how to update it in each
+ * iteration.
*
* You can call it multiple times to add more than one additional vertex columns.
*
- * @param colName the name of the additional vertex column.
- * It cannot be an existing vertex column in the graph.
- * @param initialExpr the expression to initialize the additional vertex column.
- * You can reference all original vertex columns in this expression.
- * @param updateAfterAggMsgsExpr the expression to update the additional vertex column after messages aggregation.
- * You can reference all original vertex columns, additional vertex columns, and the
- * aggregated message column using [[Pregel$#msg]].
- * If the vertex received no messages, the message column would be null.
+ * @param colName
+ * the name of the additional vertex column. It cannot be an existing vertex column in the
+ * graph.
+ * @param initialExpr
+ * the expression to initialize the additional vertex column. You can reference all original
+ * vertex columns in this expression.
+ * @param updateAfterAggMsgsExpr
+ * the expression to update the additional vertex column after messages aggregation. You can
+ * reference all original vertex columns, additional vertex columns, and the aggregated
+ * message column using [[Pregel$#msg]]. If the vertex received no messages, the message
+ * column would be null.
*/
- def withVertexColumn(colName: String, initialExpr: Column, updateAfterAggMsgsExpr: Column): this.type = {
+ def withVertexColumn(
+ colName: String,
+ initialExpr: Column,
+ updateAfterAggMsgsExpr: Column): this.type = {
// TODO: check if this column exists.
- require(colName != null && colName != ID && colName != Pregel.MSG_COL_NAME,
+ require(
+ colName != null && colName != ID && colName != Pregel.MSG_COL_NAME,
"additional column name cannot be null and cannot be the same name with ID column or " +
- "msg column.")
+ "msg column.")
require(initialExpr != null, "additional column should provide a nonnull initial expression.")
- require(updateAfterAggMsgsExpr != null, "additional column should provide a nonnull " +
- "updateAfterAggMsgs expression.")
+ require(
+ updateAfterAggMsgsExpr != null,
+ "additional column should provide a nonnull " +
+ "updateAfterAggMsgs expression.")
withVertexColumnList += Tuple3(colName, initialExpr, updateAfterAggMsgsExpr)
this
}
@@ -129,12 +142,14 @@ class Pregel(val graph: GraphFrame) {
*
* You can call it multiple times to send more than one messages.
*
- * @param msgExpr the expression of the message to send to the source vertex given a (src, edge, dst) triplet.
- * Source/destination vertex properties and edge properties are nested under columns `src`, `dst`,
- * and `edge`, respectively.
- * You can reference them using [[Pregel$#src]], [[Pregel$#dst]], and [[Pregel$#edge]].
- * Null messages are not included in message aggregation.
- * @see [[sendMsgToDst]]
+ * @param msgExpr
+ * the expression of the message to send to the source vertex given a (src, edge, dst)
+ * triplet. Source/destination vertex properties and edge properties are nested under columns
+ * `src`, `dst`, and `edge`, respectively. You can reference them using [[Pregel$#src]],
+ * [[Pregel$#dst]], and [[Pregel$#edge]]. Null messages are not included in message
+ * aggregation.
+ * @see
+ * [[sendMsgToDst]]
*/
def sendMsgToSrc(msgExpr: Column): this.type = {
sendMsgs += Tuple2(Pregel.src(ID), msgExpr)
@@ -146,12 +161,14 @@ class Pregel(val graph: GraphFrame) {
*
* You can call it multiple times to send more than one messages.
*
- * @param msgExpr the message expression to send to the destination vertex given a (`src`, `edge`, `dst`) triplet.
- * Source/destination vertex properties and edge properties are nested under columns `src`, `dst`,
- * and `edge`, respectively.
- * You can reference them using [[Pregel$#src]], [[Pregel$#dst]], and [[Pregel$#edge]].
- * Null messages are not included in message aggregation.
- * @see [[sendMsgToSrc]]
+ * @param msgExpr
+ * the message expression to send to the destination vertex given a (`src`, `edge`, `dst`)
+ * triplet. Source/destination vertex properties and edge properties are nested under columns
+ * `src`, `dst`, and `edge`, respectively. You can reference them using [[Pregel$#src]],
+ * [[Pregel$#dst]], and [[Pregel$#edge]]. Null messages are not included in message
+ * aggregation.
+ * @see
+ * [[sendMsgToSrc]]
*/
def sendMsgToDst(msgExpr: Column): this.type = {
sendMsgs += Tuple2(Pregel.dst(ID), msgExpr)
@@ -161,9 +178,10 @@ class Pregel(val graph: GraphFrame) {
/**
* Defines how messages are aggregated after grouped by target vertex IDs.
*
- * @param aggExpr the message aggregation expression, such as `sum(Pregel.msg)`.
- * You can reference the message column by [[Pregel$#msg]] and the vertex ID by [[GraphFrame$#ID]],
- * while the latter is usually not used.
+ * @param aggExpr
+ * the message aggregation expression, such as `sum(Pregel.msg)`. You can reference the
+ * message column by [[Pregel$#msg]] and the vertex ID by [[GraphFrame$#ID]], while the latter
+ * is usually not used.
*/
def aggMsgs(aggExpr: Column): this.type = {
aggMsgsCol = aggExpr
@@ -173,14 +191,22 @@ class Pregel(val graph: GraphFrame) {
/**
* Runs the defined Pregel algorithm.
*
- * @return the result vertex DataFrame from the final iteration including both original and additional columns.
+ * @return
+ * the result vertex DataFrame from the final iteration including both original and additional
+ * columns.
*/
def run(): DataFrame = {
- require(sendMsgs.length > 0, "We need to set at least one message expression for pregel running.")
+ require(
+ sendMsgs.length > 0,
+ "We need to set at least one message expression for pregel running.")
require(aggMsgsCol != null, "We need to set aggMsgs for pregel running.")
require(maxIter >= 1, "The max iteration number should be >= 1.")
- require(checkpointInterval >= 0, "The checkpoint interval should be >= 0, 0 indicates no checkpoint.")
- require(withVertexColumnList.size > 0, "There should be at least one additional vertex columns for updating.")
+ require(
+ checkpointInterval >= 0,
+ "The checkpoint interval should be >= 0, 0 indicates no checkpoint.")
+ require(
+ withVertexColumnList.size > 0,
+ "There should be at least one additional vertex columns for updating.")
val sendMsgsColList = sendMsgs.toList.map { case (id, msg) =>
struct(id.as(ID), msg.as("msg"))
@@ -203,9 +229,12 @@ class Pregel(val graph: GraphFrame) {
val shouldCheckpoint = checkpointInterval > 0
while (iteration <= maxIter) {
- val tripletsDF = currentVertices.select(struct(col("*")).as(SRC))
+ val tripletsDF = currentVertices
+ .select(struct(col("*")).as(SRC))
.join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC))
- .join(currentVertices.select(struct(col("*")).as(DST)), Pregel.edge(DST) === Pregel.dst(ID))
+ .join(
+ currentVertices.select(struct(col("*")).as(DST)),
+ Pregel.edge(DST) === Pregel.dst(ID))
var msgDF: DataFrame = tripletsDF
.select(explode(array(sendMsgsColList: _*)).as("msg"))
@@ -258,31 +287,38 @@ object Pregel extends Serializable {
/**
* References the message column in aggregating messages and updating additional vertex columns.
*
- * @see [[Pregel.aggMsgs]] and [[Pregel.withVertexColumn]]
+ * @see
+ * [[Pregel.aggMsgs]] and [[Pregel.withVertexColumn]]
*/
val msg: Column = col(MSG_COL_NAME)
/**
* References a source vertex column in generating messages to send.
*
- * @param colName the vertex column name.
- * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
+ * @param colName
+ * the vertex column name.
+ * @see
+ * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
*/
def src(colName: String): Column = col(GraphFrame.SRC + "." + colName)
/**
* References a destination vertex column in generating messages to send.
*
- * @param colName the vertex column name.
- * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
+ * @param colName
+ * the vertex column name.
+ * @see
+ * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
*/
def dst(colName: String): Column = col(GraphFrame.DST + "." + colName)
/**
* References an edge column in generating messages to send.
*
- * @param colName the edge column name.
- * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
+ * @param colName
+ * the edge column name.
+ * @see
+ * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]]
*/
def edge(colName: String): Column = col(GraphFrame.EDGE + "." + colName)
}
diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
index 43ceedd92..35b18a742 100644
--- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
+++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.graphframes.{GraphFrame, Logging}
/**
- * Implement SVD++ based on "Factorization Meets the Neighborhood:
- * a Multifaceted Collaborative Filtering Model",
- * available at [[https://dl.acm.org/citation.cfm?id=1401944]].
+ * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative
+ * Filtering Model", available at [[https://dl.acm.org/citation.cfm?id=1401944]].
*
* Note: The status of this algorithm is EXPERIMENTAL. Its API and implementation may be changed
* in the future.
@@ -35,9 +34,8 @@ import org.graphframes.{GraphFrame, Logging}
*
* Configuration parameters: see the description of each parameter in the article.
*
- * Returns a DataFrame with vertex attributes containing the trained model. See the object
+ * Returns a DataFrame with vertex attributes containing the trained model. See the object
* (static) members for the names of the output columns.
- *
*/
class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments {
private var _rank: Int = 10
@@ -120,7 +118,9 @@ object SVDPlusPlus {
case Row(src: Long, dst: Long, w: Double) => Edge(src, dst, w)
}
val (gx, res) = graphxlib.SVDPlusPlus.run(edges, conf)
- val gf = GraphXConversions.fromGraphX(graph, gx,
+ val gf = GraphXConversions.fromGraphX(
+ graph,
+ gx,
vertexNames = Seq(COLUMN1, COLUMN2, COLUMN3, COLUMN4))
(gf.vertices, res)
}
@@ -133,32 +133,32 @@ object SVDPlusPlus {
val COLUMN_WEIGHT = "weight"
/**
- * Name for output vertexDataFrame column containing first parameter of learned model,
- * of type `Array[Double]`.
+ * Name for output vertexDataFrame column containing first parameter of learned model, of type
+ * `Array[Double]`.
*
* Note: This column name may change in the future!
*/
val COLUMN1 = "column1"
/**
- * Name for output vertexDataFrame column containing second parameter of learned model,
- * of type `Array[Double]`.
+ * Name for output vertexDataFrame column containing second parameter of learned model, of type
+ * `Array[Double]`.
*
* Note: This column name may change in the future!
*/
val COLUMN2 = "column2"
/**
- * Name for output vertexDataFrame column containing third parameter of learned model,
- * of type `Double`.
+ * Name for output vertexDataFrame column containing third parameter of learned model, of type
+ * `Double`.
*
* Note: This column name may change in the future!
*/
val COLUMN3 = "column3"
/**
- * Name for output vertexDataFrame column containing fourth parameter of learned model,
- * of type `Double`.
+ * Name for output vertexDataFrame column containing fourth parameter of learned model, of type
+ * `Double`.
*
* Note: This column name may change in the future!
*/
diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala
index 6b1f15c74..0d0b52407 100644
--- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala
+++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala
@@ -30,13 +30,13 @@ import org.apache.spark.sql.types.{IntegerType, MapType}
import org.graphframes.GraphFrame
/**
- * Computes shortest paths from every vertex to the given set of landmark vertices.
- * Note that this takes edge direction into account.
+ * Computes shortest paths from every vertex to the given set of landmark vertices. Note that this
+ * takes edge direction into account.
*
* The returned DataFrame contains all the original vertex information as well as one additional
* column:
- * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing
- * the shortest-path distance to each reachable landmark vertex.
+ * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing the
+ * shortest-path distance to each reachable landmark vertex.
*/
class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments {
private var lmarks: Option[Seq[Any]] = None
@@ -67,9 +67,9 @@ private object ShortestPaths {
private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
val idType = graph.vertices.schema(GraphFrame.ID).dataType
val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap
- val gx = graphxlib.ShortestPaths.run(
- graph.cachedTopologyGraphX,
- longIdToLandmark.keys.toSeq.sorted).mapVertices { case (_, m) => m.toSeq }
+ val gx = graphxlib.ShortestPaths
+ .run(graph.cachedTopologyGraphX, longIdToLandmark.keys.toSeq.sorted)
+ .mapVertices { case (_, m) => m.toSeq }
val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID))
val distanceCol: Column = if (graph.hasIntegralIdType) {
// It seems there are no easy way to convert a sequence of pairs into a map
@@ -83,7 +83,7 @@ private object ShortestPaths {
val func = new UDF1[Seq[Row], Map[Any, Int]] {
override def call(t1: Seq[Row]): Map[Any, Int] = {
t1.map { case Row(k: Long, v: Int) =>
- longIdToLandmark(k) -> v
+ longIdToLandmark(k) -> v
}.toMap
}
}
diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
index 2914a287a..e8d9ecde8 100644
--- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
+++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
@@ -27,10 +27,10 @@ import org.graphframes.GraphFrame
* vertex assigned to the SCC containing that vertex.
*
* The resulting DataFrame contains all the original vertex information and one additional column:
- * - component (`LongType`): unique ID for this component
+ * - component (`LongType`): unique ID for this component
*/
class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame)
- extends Arguments {
+ extends Arguments {
private var maxIter: Option[Int] = None
@@ -44,7 +44,6 @@ class StronglyConnectedComponents private[graphframes] (private val graph: Graph
}
}
-
/** Strongly connected components algorithm implementation. */
private object StronglyConnectedComponents {
private def run(graph: GraphFrame, numIter: Int): DataFrame = {
diff --git a/src/main/scala/org/graphframes/lib/TriangleCount.scala b/src/main/scala/org/graphframes/lib/TriangleCount.scala
index fb8c51462..5eede3c67 100644
--- a/src/main/scala/org/graphframes/lib/TriangleCount.scala
+++ b/src/main/scala/org/graphframes/lib/TriangleCount.scala
@@ -26,15 +26,15 @@ import org.graphframes.GraphFrame.{DST, ID, LONG_DST, LONG_SRC, SRC}
/**
* Computes the number of triangles passing through each vertex.
*
- * This algorithm ignores edge direction; i.e., all edges are treated as undirected.
- * In a multigraph, duplicate edges will be counted only once.
+ * This algorithm ignores edge direction; i.e., all edges are treated as undirected. In a
+ * multigraph, duplicate edges will be counted only once.
*
- * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides
- * a graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize
- * the graph to put it in this format.
+ * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides a
+ * graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize the graph to
+ * put it in this format.
*
* The returned DataFrame contains all the original vertex information and one additional column:
- * - count (`LongType`): the count of triangles
+ * - count (`LongType`): the count of triangles
*/
class TriangleCount private[graphframes] (private val graph: GraphFrame) extends Arguments {
@@ -67,7 +67,8 @@ private object TriangleCount {
val v = graph.vertices
val countsCol = when(col("count").isNull, 0L).otherwise(col("count"))
- val newV = v.join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer")
+ val newV = v
+ .join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer")
.select((countsCol.as(COUNT_ID) +: v.columns.map(v.apply)).toSeq: _*)
newV
}
diff --git a/src/main/scala/org/graphframes/pattern/patterns.scala b/src/main/scala/org/graphframes/pattern/patterns.scala
index bd566d578..fcc9f22c5 100644
--- a/src/main/scala/org/graphframes/pattern/patterns.scala
+++ b/src/main/scala/org/graphframes/pattern/patterns.scala
@@ -34,13 +34,13 @@ private[graphframes] object PatternParser extends RegexParsers {
case src ~ "-" ~ "[" ~ name ~ "]" ~ "->" ~ dst => NamedEdge(name, src, dst)
}
val anonymousEdge: Parser[Edge] =
- vertex ~ "-" ~ "[" ~ "]" ~ "->" ~ vertex ^^ {
- case src ~ "-" ~ "[" ~ "]" ~ "->" ~ dst => AnonymousEdge(src, dst)
+ vertex ~ "-" ~ "[" ~ "]" ~ "->" ~ vertex ^^ { case src ~ "-" ~ "[" ~ "]" ~ "->" ~ dst =>
+ AnonymousEdge(src, dst)
}
private val edge: Parser[Edge] = namedEdge | anonymousEdge
private val negatedEdge: Parser[Pattern] =
- "!" ~ edge ^^ {
- case _ ~ e => Negation(e)
+ "!" ~ edge ^^ { case _ ~ e =>
+ Negation(e)
}
private val pattern: Parser[Pattern] = edge | vertex | negatedEdge
val patterns: Parser[List[Pattern]] = repsep(pattern, ";")
@@ -62,10 +62,11 @@ private[graphframes] object Pattern {
/**
* Checks all Patterns for validity:
- * - Disallow named edges within negated terms
- * - Disallow term "()-[]->()" and its negation
- * - Disallow name to be shared by a vertex and an edge
- * @throws InvalidParseException if an negated terms contain named edges
+ * - Disallow named edges within negated terms
+ * - Disallow term "()-[]->()" and its negation
+ * - Disallow name to be shared by a vertex and an edge
+ * @throws InvalidParseException
+ * if any negated term contains a named edge
*/
private def assertValidPatterns(patterns: Seq[Pattern]): Unit = {
@@ -75,21 +76,24 @@ private[graphframes] object Pattern {
def addVertex(v: Vertex): Unit = v match {
case NamedVertex(name) =>
if (edgeNames.contains(name)) {
- throw new InvalidParseException(s"Motif reused name '$name' for both a vertex and " +
- s"an edge, which is not allowed.")
+ throw new InvalidParseException(
+ s"Motif reused name '$name' for both a vertex and " +
+ s"an edge, which is not allowed.")
}
vertexNames += name
- case AnonymousVertex => // pass
+ case AnonymousVertex => // pass
}
def addEdge(e: Edge): Unit = e match {
case NamedEdge(name, src, dst) =>
if (vertexNames.contains(name)) {
- throw new InvalidParseException(s"Motif reused name '$name' for both a vertex and " +
- s"an edge, which is not allowed.")
+ throw new InvalidParseException(
+ s"Motif reused name '$name' for both a vertex and " +
+ s"an edge, which is not allowed.")
}
if (edgeNames.contains(name)) {
- throw new InvalidParseException(s"Motif reused name '$name' for multiple edges, " +
- s"which is not allowed.")
+ throw new InvalidParseException(
+ s"Motif reused name '$name' for multiple edges, " +
+ s"which is not allowed.")
}
edgeNames += name
addVertex(src)
@@ -103,8 +107,9 @@ private[graphframes] object Pattern {
case Negation(edge) =>
edge match {
case NamedEdge(name, src, dst) =>
- throw new InvalidParseException(s"Motif finding does not support negated named " +
- s"edges, but the given pattern contained: !($src)-[$name]->($dst)")
+ throw new InvalidParseException(
+ s"Motif finding does not support negated named " +
+ s"edges, but the given pattern contained: !($src)-[$name]->($dst)")
case AnonymousEdge(AnonymousVertex, AnonymousVertex) =>
throw new InvalidParseException(s"Motif finding does not support completely " +
s"anonymous negated edges !()-[]->(). Users can check for 0 edges in the graph " +
@@ -113,17 +118,19 @@ private[graphframes] object Pattern {
addEdge(e)
}
case AnonymousEdge(AnonymousVertex, AnonymousVertex) =>
- throw new InvalidParseException(s"Motif finding does not support completely " +
- s"anonymous edges ()-[]->(). Users can check for the existence of edges in the " +
- s"graph using the edges DataFrame.")
+ throw new InvalidParseException(
+ s"Motif finding does not support completely " +
+ s"anonymous edges ()-[]->(). Users can check for the existence of edges in the " +
+ s"graph using the edges DataFrame.")
case e @ AnonymousEdge(_, _) =>
addEdge(e)
case e @ NamedEdge(_, _, _) =>
addEdge(e)
case AnonymousVertex =>
- throw new InvalidParseException("Motif finding does not allow a lone anonymous vertex " +
- "\"()\" in a motif. Users can check for the existence of vertices in the graph " +
- "using the vertices DataFrame.")
+ throw new InvalidParseException(
+ "Motif finding does not allow a lone anonymous vertex " +
+ "\"()\" in a motif. Users can check for the existence of vertices in the graph " +
+ "using the vertices DataFrame.")
case v @ NamedVertex(_) =>
addVertex(v)
}
@@ -132,27 +139,31 @@ private[graphframes] object Pattern {
/**
* Return the set of named vertices which only appear in negated terms, in sorted order.
*/
- private[graphframes]
- def findNamedVerticesOnlyInNegatedTerms(patterns: Seq[Pattern]): Seq[String] = {
+ private[graphframes] def findNamedVerticesOnlyInNegatedTerms(
+ patterns: Seq[Pattern]): Seq[String] = {
val vPos = findNamedElementsInOrder(
- patterns.filter(p => !p.isInstanceOf[Negation]), includeEdges = false).toSet
+ patterns.filter(p => !p.isInstanceOf[Negation]),
+ includeEdges = false).toSet
val vNeg = findNamedElementsInOrder(
- patterns.filter(p => p.isInstanceOf[Negation]), includeEdges = false).toSet
+ patterns.filter(p => p.isInstanceOf[Negation]),
+ includeEdges = false).toSet
vNeg.diff(vPos).toSeq.sorted
}
/**
- * Return the set of named vertices (and optionally edges) appearing in the given patterns,
- * in the order they first appear in the sequence of patterns.
- * @param includeEdges If true, include named edges in the returned sequence.
+ * Return the set of named vertices (and optionally edges) appearing in the given patterns, in
+ * the order they first appear in the sequence of patterns.
+ * @param includeEdges
+ * If true, include named edges in the returned sequence.
*/
- private[graphframes]
- def findNamedElementsInOrder(patterns: Seq[Pattern], includeEdges: Boolean): Seq[String] = {
+ private[graphframes] def findNamedElementsInOrder(
+ patterns: Seq[Pattern],
+ includeEdges: Boolean): Seq[String] = {
val elementSet = mutable.LinkedHashSet.empty[String]
def findNamedElementsHelper(pattern: Pattern): Unit = pattern match {
case Negation(child) =>
findNamedElementsHelper(child)
- case AnonymousVertex => // pass
+ case AnonymousVertex => // pass
case NamedVertex(name) =>
if (!elementSet.contains(name)) {
elementSet += name
@@ -187,4 +198,3 @@ private[graphframes] sealed trait Edge extends Pattern
private[graphframes] case class AnonymousEdge(src: Vertex, dst: Vertex) extends Edge
private[graphframes] case class NamedEdge(name: String, src: Vertex, dst: Vertex) extends Edge
-
diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala
index 3ef5059ec..d8d761898 100644
--- a/src/test/scala/org/graphframes/GraphFrameSuite.scala
+++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.graphframes.examples.Graphs
-
class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
import GraphFrame._
@@ -47,10 +46,11 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
super.beforeAll()
tempDir = Files.createTempDir()
vertices = spark.createDataFrame(localVertices.toSeq).toDF("id", "name")
- edges = spark.createDataFrame(localEdges.toSeq.map {
- case ((src, dst), action) =>
+ edges = spark
+ .createDataFrame(localEdges.toSeq.map { case ((src, dst), action) =>
(src, dst, action)
- }).toDF("src", "dst", "action")
+ })
+ .toDF("src", "dst", "action")
}
override def afterAll(): Unit = {
@@ -86,9 +86,14 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val idsFromVertices = g.vertices.select("id").rdd.map(_.getLong(0)).collect()
val idsFromVerticesSet = idsFromVertices.toSet
assert(idsFromVertices.length === idsFromVerticesSet.size)
- val idsFromEdgesSet = g.edges.select("src", "dst").rdd.flatMap { case Row(src: Long, dst: Long) =>
- Seq(src, dst)
- }.collect().toSet
+ val idsFromEdgesSet = g.edges
+ .select("src", "dst")
+ .rdd
+ .flatMap { case Row(src: Long, dst: Long) =>
+ Seq(src, dst)
+ }
+ .collect()
+ .toSet
assert(idsFromVerticesSet === idsFromEdgesSet)
}
@@ -127,8 +132,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("convert to GraphX: Int IDs") {
val vv = vertices.select(col("id").cast(IntegerType).as("id"), col("name"))
- val ee = edges.select(col("src").cast(IntegerType).as("src"),
- col("dst").cast(IntegerType).as("dst"), col("action"))
+ val ee = edges.select(
+ col("src").cast(IntegerType).as("src"),
+ col("dst").cast(IntegerType).as("dst"),
+ col("action"))
val gf = GraphFrame(vv, ee)
val g = gf.toGraphX
// Int IDs should be directly cast to Long, so ID values should match.
@@ -140,31 +147,35 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(id0 === id1)
assert(localVertices(id0) === name)
}
- g.edges.collect().foreach {
- case Edge(src0: Long, dst0: Long, attr: Row) =>
- val src1 = attr.getInt(eCols("src"))
- val dst1 = attr.getInt(eCols("dst"))
- val action = attr.getString(eCols("action"))
- assert(src0 === src1)
- assert(dst0 === dst1)
- assert(localEdges((src0, dst0)) === action)
+ g.edges.collect().foreach { case Edge(src0: Long, dst0: Long, attr: Row) =>
+ val src1 = attr.getInt(eCols("src"))
+ val dst1 = attr.getInt(eCols("dst"))
+ val action = attr.getString(eCols("action"))
+ assert(src0 === src1)
+ assert(dst0 === dst1)
+ assert(localEdges((src0, dst0)) === action)
}
}
test("convert to GraphX: String IDs") {
try {
val vv = vertices.select(col("id").cast(StringType).as("id"), col("name"))
- val ee = edges.select(col("src").cast(StringType).as("src"),
- col("dst").cast(StringType).as("dst"), col("action"))
+ val ee = edges.select(
+ col("src").cast(StringType).as("src"),
+ col("dst").cast(StringType).as("dst"),
+ col("action"))
val gf = GraphFrame(vv, ee)
val g = gf.toGraphX
// String IDs will be re-indexed, so ID values may not match.
val vCols = gf.vertexColumnMap
val eCols = gf.edgeColumnMap
// First, get index.
- val new2oldID: Map[Long, String] = g.vertices.map { case (id: Long, attr: Row) =>
- (id, attr.getString(vCols("id")))
- }.collect().toMap
+ val new2oldID: Map[Long, String] = g.vertices
+ .map { case (id: Long, attr: Row) =>
+ (id, attr.getString(vCols("id")))
+ }
+ .collect()
+ .toMap
// Same as in test with Int IDs, but with re-indexing
g.vertices.collect().foreach { case (id0: Long, attr: Row) =>
val id1 = attr.getString(vCols("id"))
@@ -172,14 +183,13 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(new2oldID(id0) === id1)
assert(localVertices(new2oldID(id0).toInt) === name)
}
- g.edges.collect().foreach {
- case Edge(src0: Long, dst0: Long, attr: Row) =>
- val src1 = attr.getString(eCols("src"))
- val dst1 = attr.getString(eCols("dst"))
- val action = attr.getString(eCols("action"))
- assert(new2oldID(src0) === src1)
- assert(new2oldID(dst0) === dst1)
- assert(localEdges((new2oldID(src0).toInt, new2oldID(dst0).toInt)) === action)
+ g.edges.collect().foreach { case Edge(src0: Long, dst0: Long, attr: Row) =>
+ val src1 = attr.getString(eCols("src"))
+ val dst1 = attr.getString(eCols("dst"))
+ val action = attr.getString(eCols("action"))
+ assert(new2oldID(src0) === src1)
+ assert(new2oldID(dst0) === dst1)
+ assert(localEdges((new2oldID(src0).toInt, new2oldID(dst0).toInt)) === action)
}
} catch {
case e: Exception =>
@@ -211,21 +221,30 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val g = GraphFrame(vertices, edges)
assert(g.outDegrees.columns === Seq("id", "outDegree"))
- val outDegrees = g.outDegrees.collect().map { case Row(id: Long, outDeg: Int) =>
- (id, outDeg)
- }.toMap
+ val outDegrees = g.outDegrees
+ .collect()
+ .map { case Row(id: Long, outDeg: Int) =>
+ (id, outDeg)
+ }
+ .toMap
assert(outDegrees === Map(1L -> 1, 2L -> 2))
assert(g.inDegrees.columns === Seq("id", "inDegree"))
- val inDegrees = g.inDegrees.collect().map { case Row(id: Long, inDeg: Int) =>
- (id, inDeg)
- }.toMap
+ val inDegrees = g.inDegrees
+ .collect()
+ .map { case Row(id: Long, inDeg: Int) =>
+ (id, inDeg)
+ }
+ .toMap
assert(inDegrees === Map(1L -> 1, 2L -> 1, 3L -> 1))
assert(g.degrees.columns === Seq("id", "degree"))
- val degrees = g.degrees.collect().map { case Row(id: Long, deg: Int) =>
- (id, deg)
- }.toMap
+ val degrees = g.degrees
+ .collect()
+ .map { case Row(id: Long, deg: Int) =>
+ (id, deg)
+ }
+ .toMap
assert(degrees === Map(1L -> 2, 2L -> 3, 3L -> 1))
}
@@ -257,7 +276,8 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val chain = Graphs.chain(n + 1)
val vertices = star.vertices.select(col(ID).cast("string").as(ID))
val edges =
- star.edges.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST))
+ star.edges
+ .select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST))
.unionAll(
chain.edges.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST)))
@@ -265,7 +285,8 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val localEdges = edges.select(SRC, DST).as[(String, String)].collect().toSet
val defaultThreshold = GraphFrame.broadcastThreshold
- assert(defaultThreshold === 1000000,
+ assert(
+ defaultThreshold === 1000000,
s"Default broadcast threshold should be 1000000 but got $defaultThreshold.")
for (threshold <- Seq(0, 4, 10)) {
@@ -274,18 +295,19 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val g = GraphFrame(vertices, edges)
g.persist(StorageLevel.MEMORY_AND_DISK)
- val indexedVertices = g.indexedVertices.select(ID, LONG_ID).as[(String, Long)].collect().toMap
+ val indexedVertices =
+ g.indexedVertices.select(ID, LONG_ID).as[(String, Long)].collect().toMap
assert(indexedVertices.keySet === localVertices)
assert(indexedVertices.values.toSeq.distinct.size === localVertices.size)
val origEdges = g.indexedEdges.select(SRC, DST).as[(String, String)].collect().toSet
assert(origEdges === localEdges)
g.indexedEdges
- .select(SRC, LONG_SRC, DST, LONG_DST).as[(String, Long, String, Long)]
+ .select(SRC, LONG_SRC, DST, LONG_DST)
+ .as[(String, Long, String, Long)]
.collect()
- .foreach {
- case (src, longSrc, dst, longDst) =>
- assert(indexedVertices(src) === longSrc)
- assert(indexedVertices(dst) === longDst)
+ .foreach { case (src, longSrc, dst, longDst) =>
+ assert(indexedVertices(src) === longSrc)
+ assert(indexedVertices(dst) === longDst)
}
}
diff --git a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala
index 97bdf6e4a..99c8b56c6 100644
--- a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala
+++ b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala
@@ -35,9 +35,9 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite =>
/**
* A helper object for importing SQL implicits.
*
- * Note that the alternative of importing `spark.implicits._` is not possible here.
- * This is because we create the `SQLContext` immediately before the first test is run,
- * but the implicits import is needed in the constructor.
+ * Note that the alternative of importing `spark.implicits._` is not possible here. This is
+ * because we create the `SQLContext` immediately before the first test is run, but the
+ * implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self.sqlContext
@@ -56,7 +56,8 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite =>
override def beforeAll() {
super.beforeAll()
- spark = SparkSession.builder()
+ spark = SparkSession
+ .builder()
.master("local[2]")
.appName("GraphFramesUnitTest")
.config("spark.sql.shuffle.partitions", 4)
diff --git a/src/test/scala/org/graphframes/PatternMatchSuite.scala b/src/test/scala/org/graphframes/PatternMatchSuite.scala
index 6feac0294..fc06d2eb6 100644
--- a/src/test/scala/org/graphframes/PatternMatchSuite.scala
+++ b/src/test/scala/org/graphframes/PatternMatchSuite.scala
@@ -23,12 +23,12 @@ import org.apache.spark.sql.functions.{col, lit, when}
/**
* Cases to go through:
- * - Any negated terms?
- * - Any anonymous vertices
+ * - Any negated terms?
+ * - Any anonymous vertices
* - in non-negated terms?
* - in negated terms?
- * - # named vertices grounding a negated term to non-negated terms: 2, 1, 0
- * - Named edges?
+ * - # named vertices grounding a negated term to non-negated terms: 2, 1, 0
+ * - Named edges?
*/
class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
@@ -40,18 +40,20 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
override def beforeAll(): Unit = {
super.beforeAll()
- v = spark.createDataFrame(List(
- (0L, "a", "f"),
- (1L, "b", "m"),
- (2L, "c", "m"),
- (3L, "d", "f"))).toDF("id", "attr", "gender")
- e = spark.createDataFrame(List(
- (0L, 1L, "friend"),
- (1L, 0L, "follow"),
- (1L, 2L, "friend"),
- (2L, 3L, "follow"),
- (2L, 0L, "unknown"))).toDF("src", "dst", "relationship")
- noEdges = v.select(col("id").alias("src"))
+ v = spark
+ .createDataFrame(List((0L, "a", "f"), (1L, "b", "m"), (2L, "c", "m"), (3L, "d", "f")))
+ .toDF("id", "attr", "gender")
+ e = spark
+ .createDataFrame(
+ List(
+ (0L, 1L, "friend"),
+ (1L, 0L, "follow"),
+ (1L, 2L, "friend"),
+ (2L, 3L, "follow"),
+ (2L, 0L, "unknown")))
+ .toDF("src", "dst", "relationship")
+ noEdges = v
+ .select(col("id").alias("src"))
.crossJoin(v.select(col("id").alias("dst")))
.except(e.select("src", "dst"))
g = GraphFrame(v, e)
@@ -66,11 +68,12 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
private def compareResultToExpected[A](result: Set[A], expected: Set[A]): Unit = {
if (result !== expected) {
- throw new AssertionError("result !== expected.\n" +
- s"Result contained additional values: ${result.diff(expected)}\n" +
- s"Expected contained additional values: ${expected.diff(result)}\n" +
- s"Result: $result\n" +
- s"Expected: $expected")
+ throw new AssertionError(
+ "result !== expected.\n" +
+ s"Result contained additional values: ${result.diff(expected)}\n" +
+ s"Expected contained additional values: ${expected.diff(result)}\n" +
+ s"Result: $result\n" +
+ s"Expected: $expected")
}
}
@@ -90,17 +93,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val s = "relationship = 'friend'"
// column expression
val c = col("relationship") === "friend"
- // expected subgraph vertices
- val expected_v = Set(
- Row(0L, "a", "f"),
- Row(1L, "b", "m"),
- Row(2L, "c", "m")
- )
+ // expected subgraph vertices
+ val expected_v = Set(Row(0L, "a", "f"), Row(1L, "b", "m"), Row(2L, "c", "m"))
// expected subgraph edges
- val expected_e = Set(
- Row(0L, 1L, "friend"),
- Row(1L, 2L, "friend")
- )
+ val expected_e = Set(Row(0L, 1L, "friend"), Row(1L, 2L, "friend"))
val res_s = g.filterEdges(s)
assert(res_s.vertices.collect().toSet === v.collect().toSet)
@@ -120,17 +116,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val s = "id > 0"
// column expression
val c = col("id") > 0
- // expected subgraph vertices
- val expected_v = Set(
- Row(1L, "b", "m"),
- Row(2L, "c", "m"),
- Row(3L, "d", "f")
- )
+ // expected subgraph vertices
+ val expected_v = Set(Row(1L, "b", "m"), Row(2L, "c", "m"), Row(3L, "d", "f"))
// expected subgraph edges
- val expected_e = Set(
- Row(1L, 2L, "friend"),
- Row(2L, 3L, "follow")
- )
+ val expected_e = Set(Row(1L, 2L, "friend"), Row(2L, 3L, "follow"))
val res_s = g.filterVertices(s)
assert(res_s.vertices.collect().toSet === expected_v)
@@ -142,14 +131,11 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
test("triangles") {
- val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)")
+ val triangles = g
+ .find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)")
.select("a.id", "b.id", "c.id")
- assert(triangles.collect().toSet === Set(
- Row(0L, 1L, 2L),
- Row(2L, 0L, 1L),
- Row(1L, 2L, 0L)
- ))
+ assert(triangles.collect().toSet === Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L)))
}
/* ====================================== Vertex queries ===================================== */
@@ -169,29 +155,32 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(triplets.columns === Array("u", "v"))
val res = triplets.select("u.id", "u.attr", "v.id", "v.attr").collect().toSet
- compareResultToExpected(res, Set(
- Row(0L, "a", 1L, "b"),
- Row(1L, "b", 0L, "a"),
- Row(1L, "b", 2L, "c"),
- Row(2L, "c", 3L, "d"),
- Row(2L, "c", 0L, "a")
- ))
+ compareResultToExpected(
+ res,
+ Set(
+ Row(0L, "a", 1L, "b"),
+ Row(1L, "b", 0L, "a"),
+ Row(1L, "b", 2L, "c"),
+ Row(2L, "c", 3L, "d"),
+ Row(2L, "c", 0L, "a")))
}
test("triplet with named edge") {
val triplets = g.find("(u)-[uv]->(v)")
assert(triplets.columns === Array("u", "uv", "v"))
- val res = triplets.select("u.id", "u.attr", "uv.src", "uv.dst", "uv.relationship",
- "v.id", "v.attr")
- .collect().toSet
- compareResultToExpected(res, Set(
- Row(0L, "a", 0L, 1L, "friend", 1L, "b"),
- Row(1L, "b", 1L, 0L, "follow", 0L, "a"),
- Row(1L, "b", 1L, 2L, "friend", 2L, "c"),
- Row(2L, "c", 2L, 3L, "follow", 3L, "d"),
- Row(2L, "c", 2L, 0L, "unknown", 0L, "a")
- ))
+ val res = triplets
+ .select("u.id", "u.attr", "uv.src", "uv.dst", "uv.relationship", "v.id", "v.attr")
+ .collect()
+ .toSet
+ compareResultToExpected(
+ res,
+ Set(
+ Row(0L, "a", 0L, 1L, "friend", 1L, "b"),
+ Row(1L, "b", 1L, 0L, "follow", 0L, "a"),
+ Row(1L, "b", 1L, 2L, "friend", 2L, "c"),
+ Row(2L, "c", 2L, 3L, "follow", 3L, "d"),
+ Row(2L, "c", 2L, 0L, "unknown", 0L, "a")))
}
test("triplet with anonymous vertex") {
@@ -199,13 +188,13 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(triplets.columns === Array("u"))
// Do not use compareResultToExpected since it uses sets, and we expect duplicates.
- assert(triplets.select("u.id", "u.attr").collect().sortBy(_.getLong(0)) === Array(
- Row(0L, "a"),
- Row(1L, "b"),
- Row(1L, "b"),
- Row(2L, "c"),
- Row(2L, "c")
- ).sortBy(_.getLong(0)))
+ assert(
+ triplets.select("u.id", "u.attr").collect().sortBy(_.getLong(0)) === Array(
+ Row(0L, "a"),
+ Row(1L, "b"),
+ Row(1L, "b"),
+ Row(2L, "c"),
+ Row(2L, "c")).sortBy(_.getLong(0)))
}
test("triplet with 2 anonymous vertices") {
@@ -218,9 +207,9 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
test("self-loop") {
- val myE = spark.createDataFrame(List(
- (1L, 1L, "self"),
- (3L, 3L, "self"))).toDF("src", "dst", "relationship")
+ val myE = spark
+ .createDataFrame(List((1L, 1L, "self"), (3L, 3L, "self")))
+ .toDF("src", "dst", "relationship")
.union(e)
val myG = GraphFrame(v, myE)
@@ -231,32 +220,29 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val selfLoops2 = myG.find("(a)-[]->(b); (a)-[]->(a)")
assert(selfLoops2.columns === Array("a", "b"))
- val res2 = selfLoops2.select("a.id", "b.id")
+ val res2 = selfLoops2
+ .select("a.id", "b.id")
.where("a.id != b.id")
- .collect().toSet
- compareResultToExpected(res2, Set(
- Row(1L, 0L),
- Row(1L, 2L)
- ))
+ .collect()
+ .toSet
+ compareResultToExpected(res2, Set(Row(1L, 0L), Row(1L, 2L)))
}
test("duplicate edges") {
- val myE = spark.createDataFrame(List(
- (1L, 0L, "dup"),
- (1L, 2L, "dup"))).toDF("src", "dst", "relationship")
+ val myE = spark
+ .createDataFrame(List((1L, 0L, "dup"), (1L, 2L, "dup")))
+ .toDF("src", "dst", "relationship")
.union(e)
val myG = GraphFrame(v, myE)
- val edges = myG.find("(a)-[]->(b)")
+ val edges = myG
+ .find("(a)-[]->(b)")
.where("a.id = 1")
- val res = edges.select("a.id", "b.id")
- .collect().sortBy(_.getLong(1))
- val expected = Array(
- Row(1L, 0L),
- Row(1L, 0L),
- Row(1L, 2L),
- Row(1L, 2L)
- )
+ val res = edges
+ .select("a.id", "b.id")
+ .collect()
+ .sortBy(_.getLong(1))
+ val expected = Array(Row(1L, 0L), Row(1L, 0L), Row(1L, 2L), Row(1L, 2L))
assert(res === expected)
}
@@ -266,26 +252,28 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)")
assert(triangles.columns === Array("a", "b", "c"))
- val res = triangles.select("a.id", "b.id", "c.id")
- .collect().toSet
- compareResultToExpected(res, Set(
- Row(0L, 1L, 2L),
- Row(2L, 0L, 1L),
- Row(1L, 2L, 0L)
- ))
+ val res = triangles
+ .select("a.id", "b.id", "c.id")
+ .collect()
+ .toSet
+ compareResultToExpected(res, Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L)))
}
test("disconnected edges create an outer join") {
val edgePairs = g.find("(a)-[]->(b); (c)-[]->(d)")
assert(edgePairs.columns === Array("a", "b", "c", "d"))
- val res = edgePairs.select("a.id", "b.id", "c.id", "d.id")
- .collect().toSet
+ val res = edgePairs
+ .select("a.id", "b.id", "c.id", "d.id")
+ .collect()
+ .toSet
val ab = e.select(col("src").alias("a"), col("dst").alias("b"))
val cd = e.select(col("src").alias("c"), col("dst").alias("d"))
- val expected = ab.crossJoin(cd)
- .collect().toSet
+ val expected = ab
+ .crossJoin(cd)
+ .collect()
+ .toSet
compareResultToExpected(res, expected)
val numEdges = e.count()
assert(expected.size === numEdges * numEdges)
@@ -297,38 +285,32 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val edges = g.find("(a)-[]->(b); !(b)-[]->(a)")
assert(edges.columns === Array("a", "b"))
- val res = edges.select("a.id", "b.id")
- .collect().toSet
- compareResultToExpected(res, Set(
- Row(1L, 2L),
- Row(2L, 0L),
- Row(2L, 3L)
- ))
+ val res = edges
+ .select("a.id", "b.id")
+ .collect()
+ .toSet
+ compareResultToExpected(res, Set(Row(1L, 2L), Row(2L, 0L), Row(2L, 3L)))
}
test("a->b->c but not c->a") {
val edges = g.find("(a)-[]->(b); (b)-[]->(c); !(c)-[]->(a)")
assert(edges.columns === Array("a", "b", "c"))
- val res = edges.select("a.id", "b.id", "c.id")
- .collect().toSet
- assert(res === Set(
- Row(0L, 1L, 0L),
- Row(1L, 0L, 1L),
- Row(1L, 2L, 3L)
- ))
+ val res = edges
+ .select("a.id", "b.id", "c.id")
+ .collect()
+ .toSet
+ assert(res === Set(Row(0L, 1L, 0L), Row(1L, 0L, 1L), Row(1L, 2L, 3L)))
}
test("three connected vertices not in a triangle") {
- val fof = g.find("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)")
+ val fof = g
+ .find("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)")
.select("u.id", "v.id", "w.id")
- .collect().toSet
+ .collect()
+ .toSet
- compareResultToExpected(fof, Set(
- Row(1L, 0L, 1L),
- Row(0L, 1L, 0L),
- Row(1L, 2L, 3L)
- ))
+ compareResultToExpected(fof, Set(Row(1L, 0L, 1L), Row(0L, 1L, 0L), Row(1L, 2L, 3L)))
}
/* ========== 1 named vertex grounding a negated term to non-negated terms =============== */
@@ -337,35 +319,38 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val edges = g.find("(a)-[]->(b); !(b)-[]->(c)")
assert(edges.columns === Array("a", "b", "c"))
- val res = edges.select("a.id", "b.id", "c.id")
- .collect().toSet
- compareResultToExpected(res, Set(
- Row(0L, 1L, 1L),
- Row(0L, 1L, 3L),
- Row(1L, 0L, 0L),
- Row(1L, 0L, 2L),
- Row(1L, 0L, 3L),
- Row(1L, 2L, 1L),
- Row(1L, 2L, 2L),
- Row(2L, 3L, 0L),
- Row(2L, 3L, 1L),
- Row(2L, 3L, 2L),
- Row(2L, 3L, 3L),
- Row(2L, 0L, 0L),
- Row(2L, 0L, 2L),
- Row(2L, 0L, 3L)
- ))
+ val res = edges
+ .select("a.id", "b.id", "c.id")
+ .collect()
+ .toSet
+ compareResultToExpected(
+ res,
+ Set(
+ Row(0L, 1L, 1L),
+ Row(0L, 1L, 3L),
+ Row(1L, 0L, 0L),
+ Row(1L, 0L, 2L),
+ Row(1L, 0L, 3L),
+ Row(1L, 2L, 1L),
+ Row(1L, 2L, 2L),
+ Row(2L, 3L, 0L),
+ Row(2L, 3L, 1L),
+ Row(2L, 3L, 2L),
+ Row(2L, 3L, 3L),
+ Row(2L, 0L, 0L),
+ Row(2L, 0L, 2L),
+ Row(2L, 0L, 3L)))
}
test("a->b where b has no out edges") {
val edges = g.find("(a)-[]->(b); !(b)-[]->()")
assert(edges.columns === Array("a", "b"))
- val res = edges.select("a.id", "b.id")
- .collect().toSet
- compareResultToExpected(res, Set(
- Row(2L, 3L)
- ))
+ val res = edges
+ .select("a.id", "b.id")
+ .collect()
+ .toSet
+ compareResultToExpected(res, Set(Row(2L, 3L)))
}
/* ========== 0 named vertices grounding a negated term to non-negated terms =============== */
@@ -374,29 +359,30 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val edgePairs = g.find("(a)-[]->(b); !(c)-[]->(d)")
assert(edgePairs.columns === Array("a", "b", "c", "d"))
- val res = edgePairs.select("a.id", "b.id", "c.id", "d.id")
- .collect().toSet
- val expected = e.select(col("src").alias("a"), col("dst").alias("b"))
+ val res = edgePairs
+ .select("a.id", "b.id", "c.id", "d.id")
+ .collect()
+ .toSet
+ val expected = e
+ .select(col("src").alias("a"), col("dst").alias("b"))
.crossJoin(noEdges.select(col("src").alias("c"), col("dst").alias("d")))
.select("a", "b", "c", "d")
- .collect().toSet
+ .collect()
+ .toSet
compareResultToExpected(res, expected)
- assert(expected.size === noEdges.count() * e.count()) // make sure there are no duplicates
+ assert(expected.size === noEdges.count() * e.count()) // make sure there are no duplicates
}
test("a->b, c where c has no out edges") {
val triplets = g.find("(a)-[]->(b); !(c)-[]->()")
assert(triplets.columns === Array("a", "b", "c"))
- val res = triplets.select("a.id", "b.id", "c.id")
- .collect().toSet
- val expected = Set(
- Row(0L, 1L, 3L),
- Row(1L, 0L, 3L),
- Row(1L, 2L, 3L),
- Row(2L, 3L, 3L),
- Row(2L, 0L, 3L)
- )
+ val res = triplets
+ .select("a.id", "b.id", "c.id")
+ .collect()
+ .toSet
+ val expected =
+ Set(Row(0L, 1L, 3L), Row(1L, 0L, 3L), Row(1L, 2L, 3L), Row(2L, 3L, 3L), Row(2L, 0L, 3L))
compareResultToExpected(res, expected)
}
@@ -408,8 +394,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val seq = g.find("(a)-[]->(b); !(b)-[]->(c); !(c)-[]->(a)")
assert(seq.columns === Array("a", "b", "c"))
- val res = seq.select("a.id", "b.id", "c.id")
- .collect().toSet
+ val res = seq
+ .select("a.id", "b.id", "c.id")
+ .collect()
+ .toSet
val expected = Set(
Row(0L, 1L, 3L),
Row(1L, 0L, 2L),
@@ -421,8 +409,7 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
Row(2L, 3L, 3L),
Row(2L, 0L, 0L),
Row(2L, 0L, 2L),
- Row(2L, 0L, 3L)
- )
+ Row(2L, 0L, 3L))
compareResultToExpected(res, expected)
}
@@ -430,9 +417,11 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val edgePairs = g.find("(a)-[]->(b); !(a)-[]->(c); !(c)-[]->(d)")
assert(edgePairs.columns === Array("a", "b", "c", "d"))
- val res = edgePairs.select("a.id", "b.id", "c.id", "d.id")
- .where("a.id = 0 AND a.id != b.id") // check subset for brevity
- .collect().toSet
+ val res = edgePairs
+ .select("a.id", "b.id", "c.id", "d.id")
+ .where("a.id = 0 AND a.id != b.id") // check subset for brevity
+ .collect()
+ .toSet
val expected = Set(
Row(0L, 1L, 0L, 0L),
Row(0L, 1L, 0L, 2L),
@@ -442,26 +431,31 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
Row(0L, 1L, 3L, 0L),
Row(0L, 1L, 3L, 1L),
Row(0L, 1L, 3L, 2L),
- Row(0L, 1L, 3L, 3L)
- )
+ Row(0L, 1L, 3L, 3L))
compareResultToExpected(res, expected)
}
/* ============================== 0 non-negated terms ============================== */
test("query without non-negated terms, with one named vertex") {
- val res = g.find("!(v)-[]->()")
+ val res = g
+ .find("!(v)-[]->()")
.select("v.id")
- .collect().toSet
+ .collect()
+ .toSet
compareResultToExpected(res, Set(Row(3L)))
}
test("query without non-negated terms, with two named vertices") {
- val res = g.find("!(u)-[]->(v)")
+ val res = g
+ .find("!(u)-[]->(v)")
.select("u.id", "v.id")
- .collect().toSet
- val expected = noEdges.select(col("src").alias("u"), col("dst").alias("v"))
- .collect().toSet
+ .collect()
+ .toSet
+ val expected = noEdges
+ .select(col("src").alias("u"), col("dst").alias("v"))
+ .collect()
+ .toSet
compareResultToExpected(res, expected)
}
@@ -469,7 +463,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("named edges") {
// edges whose destination leads nowhere
- val edges = g.find("()-[e]->(v); !(v)-[]->()")
+ val edges = g
+ .find("()-[e]->(v); !(v)-[]->()")
.select("e.src", "e.dst")
val res = edges.collect().toSet
compareResultToExpected(res, Set(Row(2L, 3L)))
@@ -490,7 +485,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
test("find column order") {
- val fof = g.find("(u)-[e]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)")
+ val fof = g
+ .find("(u)-[e]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)")
.where("u.id != v.id AND v.id != w.id AND u.id != w.id")
assert(fof.columns === Array("u", "e", "v", "w"))
compareResultToExpected(
@@ -539,20 +535,18 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
/* ============================= More complex use case examples ============================== */
test("triangles via post-hoc filter") {
- val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (d)-[]->(e)")
+ val triangles = g
+ .find("(a)-[]->(b); (b)-[]->(c); (d)-[]->(e)")
.where("c.id = d.id AND e.id = a.id")
.select("a.id", "b.id", "c.id")
val res = triangles.collect().toSet
- compareResultToExpected(res, Set(
- Row(0L, 1L, 2L),
- Row(2L, 0L, 1L),
- Row(1L, 2L, 0L)
- ))
+ compareResultToExpected(res, Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L)))
}
test("stateful predicates via UDFs") {
- val chain4 = g.find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")
+ val chain4 = g
+ .find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")
.where("a.id != b.id AND b.id != c.id AND c.id != a.id")
// Using DataFrame operations, but not really operating in a stateful manner
@@ -562,7 +556,9 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
.reduce(_ + _) >= 2)
assert(chainWith2Friends.count() === 4)
- chainWith2Friends.select("ab.relationship", "bc.relationship", "cd.relationship").collect()
+ chainWith2Friends
+ .select("ab.relationship", "bc.relationship", "cd.relationship")
+ .collect()
.foreach { case Row(ab: String, bc: String, cd: String) =>
val numFriends = Seq(ab, bc, cd).map(r => if (r == "friend") 1 else 0).sum
assert(numFriends >= 2)
@@ -572,8 +568,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
def sumFriends(cnt: Column, relationship: Column): Column = {
when(relationship === "friend", cnt + 1).otherwise(cnt)
}
- val condition = Seq("ab", "bc", "cd").
- foldLeft(lit(0))((cnt, e) => sumFriends(cnt, col(e)("relationship")))
+ val condition =
+ Seq("ab", "bc", "cd").foldLeft(lit(0))((cnt, e) => sumFriends(cnt, col(e)("relationship")))
val chainWith2Friends2 = chain4.where(condition >= 2)
compareResultToExpected(chainWith2Friends.collect().toSet, chainWith2Friends2.collect().toSet)
@@ -604,5 +600,5 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
assert(joins.isEmpty, s"joins was non-empty: ${joins.map(_.toString()).mkString("; ")}")
}
- */
+ */
}
diff --git a/src/test/scala/org/graphframes/SparkFunSuite.scala b/src/test/scala/org/graphframes/SparkFunSuite.scala
index aaf1fb536..e53c1111a 100644
--- a/src/test/scala/org/graphframes/SparkFunSuite.scala
+++ b/src/test/scala/org/graphframes/SparkFunSuite.scala
@@ -27,9 +27,8 @@ private[graphframes] abstract class SparkFunSuite extends FunSuite with Logging
/**
* Log the suite name and the test name before and after each test.
*
- * Subclasses should never override this method. If they wish to run
- * custom code before and after each test, they should mix in the
- * {{org.scalatest.BeforeAndAfter}} trait instead.
+ * Subclasses should never override this method. If they wish to run custom code before and
+ * after each test, they should mix in the `org.scalatest.BeforeAndAfter` trait instead.
*/
final protected override def withFixture(test: NoArgTest): Outcome = {
val testName = test.text
diff --git a/src/test/scala/org/graphframes/TestUtils.scala b/src/test/scala/org/graphframes/TestUtils.scala
index c05a4f49d..629984da7 100644
--- a/src/test/scala/org/graphframes/TestUtils.scala
+++ b/src/test/scala/org/graphframes/TestUtils.scala
@@ -15,16 +15,19 @@ object TestUtils {
case Some(m) =>
(m.group(1).toInt, m.group(2).toInt)
case None =>
- throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" +
- s" version string, but it could not find the major and minor version numbers.")
+ throw new IllegalArgumentException(
+ s"Spark tried to parse '$sparkVersion' as a Spark" +
+ s" version string, but it could not find the major and minor version numbers.")
}
}
/**
* Check whether the given schema contains a column of the required data type.
*
- * @param colName column name
- * @param dataType required column data type
+ * @param colName
+ * column name
+ * @param dataType
+ * required column data type
*/
def checkColumnType(
schema: StructType,
@@ -33,7 +36,8 @@ object TestUtils {
msg: String = ""): Unit = {
val actualDataType = schema(colName).dataType
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
- require(actualDataType.equals(dataType),
+ require(
+ actualDataType.equals(dataType),
s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
}
@@ -47,10 +51,9 @@ object TestUtils {
}
/**
- * Test validity of both GraphFrames.
- * Also ensure that the GraphFrames match:
- * - vertex column schema match
- * - `before` columns are a subset of the `after` columns, and schema match
+ * Test validity of both GraphFrames. Also ensure that the GraphFrames match:
+ * - vertex column schema match
+ * - `before` columns are a subset of the `after` columns, and schema match
*/
def testSchemaInvariants(before: GraphFrame, after: GraphFrame): Unit = {
testSchemaInvariant(before)
@@ -76,23 +79,24 @@ object TestUtils {
if (!afterVNames.contains(f.name)) {
throw new Exception(s"vertex error: ${f.name} should be in ${afterVNames.mkString(", ")}")
}
- assert(before.vertices.schema(f.name) == after.vertices.schema(f.name),
+ assert(
+ before.vertices.schema(f.name) == after.vertices.schema(f.name),
s"${before.vertices.schema} != ${after.vertices.schema}")
}
for (f <- before.edges.schema.iterator) {
val a = before.edges.schema(f.name)
val b = after.edges.schema(f.name)
- assert(a.dataType == b.dataType,
+ assert(
+ a.dataType == b.dataType,
s"${before.edges.schema} not a subset of ${after.edges.schema}")
}
}
/**
- * Test validity of both GraphFrames.
- * Also ensure that the GraphFrames match:
- * - vertex column schema match
- * - `before` columns are a subset of the `after` columns, and schema match
+ * Test validity of both GraphFrames. Also ensure that the GraphFrames match:
+ * - vertex column schema match
+ * - `before` columns are a subset of the `after` columns, and schema match
*/
def testSchemaInvariants(before: GraphFrame, afterVertices: DataFrame): Unit = {
testSchemaInvariants(before, GraphFrame(afterVertices, before.edges))
diff --git a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala
index 5f7e3382f..e2a7af557 100644
--- a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala
+++ b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala
@@ -23,11 +23,10 @@ import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite}
import org.graphframes.examples.BeliefPropagation._
import org.graphframes.examples.Graphs.gridIsingModel
-
class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("BP using GraphX and GraphFrame aggregateMessages") {
- val n = 3 // graph is n x n
+ val n = 3 // graph is n x n
val numIter = 5 // iterations of BP
// Create graphical model g.
@@ -41,7 +40,8 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte
// Check beliefs.
def checkResults(v: DataFrame): Unit = {
v.select("belief").collect().foreach { case Row(belief: Double) =>
- assert(belief >= 0.0 && belief <= 1.0,
+ assert(
+ belief >= 0.0 && belief <= 1.0,
s"Expected belief to be probability in [0,1], but found $belief")
}
}
@@ -51,10 +51,12 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte
// Compare beliefs.
val gxBeliefs = gxResults.vertices.select("id", "belief")
val gfBeliefs = gfResults.vertices.select("id", "belief")
- gxBeliefs.join(gfBeliefs, "id")
+ gxBeliefs
+ .join(gfBeliefs, "id")
.select(gxBeliefs("belief").as("gxBelief"), gfBeliefs("belief").as("gfBelief"))
- .collect().foreach { case Row(gxBelief: Double, gfBelief: Double) =>
+ .collect()
+ .foreach { case Row(gxBelief: Double, gfBelief: Double) =>
assert(math.abs(gxBelief - gfBelief) <= 1e-6)
- }
+ }
}
}
diff --git a/src/test/scala/org/graphframes/examples/GraphsSuite.scala b/src/test/scala/org/graphframes/examples/GraphsSuite.scala
index 9165cfaa9..5d3076424 100644
--- a/src/test/scala/org/graphframes/examples/GraphsSuite.scala
+++ b/src/test/scala/org/graphframes/examples/GraphsSuite.scala
@@ -19,7 +19,6 @@ package org.graphframes.examples
import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite}
-
class GraphsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("empty graph") {
diff --git a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala
index 6f199df9e..3ea405ec4 100644
--- a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala
+++ b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.functions._
import org.graphframes.examples.Graphs
import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite}
-
class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("aggregateMessages") {
@@ -41,18 +40,28 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte
.agg(sum(AM.msg).as("summedAges"))
// Convert agg to a Map.
import org.apache.spark.sql._
- val aggMap: Map[String, Long] = agg.select("id", "summedAges").collect().map {
- case Row(id: String, s: Long) => id -> s
- }.toMap
+ val aggMap: Map[String, Long] = agg
+ .select("id", "summedAges")
+ .collect()
+ .map { case Row(id: String, s: Long) =>
+ id -> s
+ }
+ .toMap
// Compute the truth via brute force for comparison.
val trueAgg: Map[String, Int] = {
- val user2age = g.vertices.select("id", "age").collect().map {
- case Row(id: String, age: Int) => id -> age
- }.toMap
+ val user2age = g.vertices
+ .select("id", "age")
+ .collect()
+ .map { case Row(id: String, age: Int) =>
+ id -> age
+ }
+ .toMap
val a = mutable.HashMap.empty[String, Int]
g.edges.select("src", "dst", "relationship").collect().foreach {
case Row(src: String, dst: String, relationship: String) =>
- a.put(src, a.getOrElse(src, 0) + user2age(dst) + (if (relationship == "friend") 1 else 0))
+ a.put(
+ src,
+ a.getOrElse(src, 0) + user2age(dst) + (if (relationship == "friend") 1 else 0))
a.put(dst, a.getOrElse(dst, 0) + user2age(src))
}
a.toMap
@@ -69,9 +78,13 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte
.sendToDst(msgToDst2)
.agg("sum(MSG) AS `summedAges`")
// Convert agg2 to a Map.
- val agg2Map: Map[String, Long] = agg2.select("id", "summedAges").collect().map {
- case Row(id: String, s: Long) => id -> s
- }.toMap
+ val agg2Map: Map[String, Long] = agg2
+ .select("id", "summedAges")
+ .collect()
+ .map { case Row(id: String, s: Long) =>
+ id -> s
+ }
+ .toMap
// Compare to the true values.
agg2Map.keys.foreach { case user =>
assert(agg2Map(user) === trueAgg(user), s"Failure on user $user")
diff --git a/src/test/scala/org/graphframes/lib/BFSSuite.scala b/src/test/scala/org/graphframes/lib/BFSSuite.scala
index d5982f9ca..081c3de99 100644
--- a/src/test/scala/org/graphframes/lib/BFSSuite.scala
+++ b/src/test/scala/org/graphframes/lib/BFSSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite}
-
class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext {
// First graph uses String IDs
@@ -44,36 +43,27 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext {
| |
d <- e --> f Also, self-edge for f
*/
- v = spark.createDataFrame(List(
- ("a", "f"),
- ("b", "f"),
- ("c", "m"),
- ("d", "f"),
- ("e", "m"),
- ("f", "m")
- )).toDF("id", "gender")
- e = spark.createDataFrame(List(
- ("a", "b", "friend"),
- ("b", "c", "follow"),
- ("c", "b", "follow"),
- ("f", "c", "follow"),
- ("e", "f", "follow"),
- ("e", "d", "friend"),
- ("d", "a", "friend"),
- ("f", "f", "self")
- )).toDF("src", "dst", "relationship")
+ v = spark
+ .createDataFrame(
+ List(("a", "f"), ("b", "f"), ("c", "m"), ("d", "f"), ("e", "m"), ("f", "m")))
+ .toDF("id", "gender")
+ e = spark
+ .createDataFrame(
+ List(
+ ("a", "b", "friend"),
+ ("b", "c", "follow"),
+ ("c", "b", "follow"),
+ ("f", "c", "follow"),
+ ("e", "f", "follow"),
+ ("e", "d", "friend"),
+ ("d", "a", "friend"),
+ ("f", "f", "self")))
+ .toDF("src", "dst", "relationship")
g = GraphFrame(v, e)
- v2 = spark.createDataFrame(List(
- (0L, "f"),
- (1L, "m"),
- (2L, "m"),
- (3L, "f"))).toDF("id", "gender")
- e2 = spark.createDataFrame(List(
- (0L, 1L),
- (1L, 2L),
- (2L, 3L),
- (2L, 0L))).toDF("src", "dst")
+ v2 =
+ spark.createDataFrame(List((0L, "f"), (1L, "m"), (2L, "m"), (3L, "f"))).toDF("id", "gender")
+ e2 = spark.createDataFrame(List((0L, 1L), (1L, 2L), (2L, 3L), (2L, 0L))).toDF("src", "dst")
g2 = GraphFrame(v2, e2)
}
@@ -121,26 +111,32 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("maxPathLength: length 1") {
val paths = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(1).run()
assert(paths.count() === 1)
- val paths0 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(0).run()
+ val paths0 =
+ g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(0).run()
assert(paths0.count() === 0)
}
test("maxPathLength: length > 1") {
val paths = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(3).run()
assert(paths.count() === 2)
- val paths0 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(2).run()
+ val paths0 =
+ g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(2).run()
assert(paths0.count() === 0)
}
test("edge filter") {
- val paths1 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b")
+ val paths1 = g.bfs
+ .fromExpr(col("id") === "e")
+ .toExpr(col("id") === "b")
.edgeFilter(col("src") =!= "d")
.run()
assert(paths1.count() === 1)
paths1.select("e0.dst").collect().foreach { case Row(id: String) =>
assert(id === "f")
}
- val paths2 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b")
+ val paths2 = g.bfs
+ .fromExpr(col("id") === "e")
+ .toExpr(col("id") === "b")
.edgeFilter(col("relationship") === "friend")
.run()
assert(paths2.count() === 1)
@@ -150,7 +146,9 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
test("string expressions") {
- val paths1 = g.bfs.fromExpr("id = 'e'").toExpr("id = 'b'")
+ val paths1 = g.bfs
+ .fromExpr("id = 'e'")
+ .toExpr("id = 'b'")
.edgeFilter("src != 'd'")
.run()
assert(paths1.count() === 1)
diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
index 96acd0bde..f55ae4edd 100644
--- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
+++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
@@ -49,18 +49,20 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
}
test("single vertex") {
- val v = spark.createDataFrame(List(
- (0L, "a", "b"))).toDF("id", "vattr", "gender")
+ val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender")
// Create an empty dataframe with the proper columns.
- val e = spark.createDataFrame(List((0L, 0L, 1L))).toDF("src", "dst", "test")
+ val e = spark
+ .createDataFrame(List((0L, 0L, 1L)))
+ .toDF("src", "dst", "test")
.filter("src > 10")
val g = GraphFrame(v, e)
val comps = ConnectedComponents.run(g)
TestUtils.testSchemaInvariants(g, comps)
TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType)
assert(comps.count() === 1)
- assert(comps.select("id", "component", "vattr", "gender").collect()
- === Seq(Row(0L, 0L, "a", "b")))
+ assert(
+ comps.select("id", "component", "vattr", "gender").collect()
+ === Seq(Row(0L, 0L, "a", "b")))
}
test("disconnected vertices") {
@@ -74,11 +76,8 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
}
test("two connected vertices") {
- val v = spark.createDataFrame(List(
- (0L, "a0", "b0"),
- (1L, "a1", "b1"))).toDF("id", "A", "B")
- val e = spark.createDataFrame(List(
- (0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B")
+ val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B")
+ val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B")
val g = GraphFrame(v, e)
val comps = g.connectedComponents.run()
TestUtils.testSchemaInvariants(g, comps)
@@ -113,10 +112,9 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
test("two components") {
val vertices = spark.range(6L).toDF(ID)
- val edges = spark.createDataFrame(Seq(
- (0L, 1L), (1L, 2L), (2L, 0L),
- (3L, 4L), (4L, 5L), (5L, 3L)
- )).toDF(SRC, DST)
+ val edges = spark
+ .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
+ .toDF(SRC, DST)
val g = GraphFrame(vertices, edges)
val components = g.connectedComponents.run()
val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L))
@@ -125,10 +123,15 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
test("one component, differing edge directions") {
val vertices = spark.range(5L).toDF(ID)
- val edges = spark.createDataFrame(Seq(
- // 0 -> 4 -> 3 <- 2 -> 1
- (0L, 4L), (4L, 3L), (2L, 3L), (2L, 1L)
- )).toDF(SRC, DST)
+ val edges = spark
+ .createDataFrame(
+ Seq(
+ // 0 -> 4 -> 3 <- 2 -> 1
+ (0L, 4L),
+ (4L, 3L),
+ (2L, 3L),
+ (2L, 1L)))
+ .toDF(SRC, DST)
val g = GraphFrame(vertices, edges)
val components = g.connectedComponents.run()
val expected = Set((0L to 4L).toSet)
@@ -137,10 +140,9 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
test("two components and two dangling vertices") {
val vertices = spark.range(8L).toDF(ID)
- val edges = spark.createDataFrame(Seq(
- (0L, 1L), (1L, 2L), (2L, 0L),
- (3L, 4L), (4L, 5L), (5L, 3L)
- )).toDF(SRC, DST)
+ val edges = spark
+ .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L)))
+ .toDF(SRC, DST)
val g = GraphFrame(vertices, edges)
val components = g.connectedComponents.run()
val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L))
@@ -151,7 +153,7 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
val friends = Graphs.friends
val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
for ((algorithm, broadcastThreshold) <-
- Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) {
+ Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) {
val components = friends.connectedComponents
.setAlgorithm(algorithm)
.setBroadcastThreshold(broadcastThreshold)
@@ -176,15 +178,17 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
val cc = new ConnectedComponents(friends)
- assert(cc.getCheckpointInterval === 2,
+ assert(
+ cc.getCheckpointInterval === 2,
s"Default checkpoint interval should be 2, but got ${cc.getCheckpointInterval}.")
val checkpointDir = sc.getCheckpointDir
assert(checkpointDir.nonEmpty)
sc.setCheckpointDir(null)
- withClue("Should throw an IOException if sc.getCheckpointDir is empty " +
- "and checkpointInterval is positive.") {
+ withClue(
+ "Should throw an IOException if sc.getCheckpointDir is empty " +
+ "and checkpointInterval is positive.") {
intercept[IOException] {
cc.run()
}
@@ -198,19 +202,22 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
val components0 = cc.setCheckpointInterval(0).run()
assertComponents(components0, expected)
- assert(!isFromCheckpoint(components0),
+ assert(
+ !isFromCheckpoint(components0),
"The result shouldn't depend on checkpoint data if checkpointing is disabled.")
sc.setCheckpointDir(checkpointDir.get)
val components1 = cc.setCheckpointInterval(1).run()
assertComponents(components1, expected)
- assert(isFromCheckpoint(components1),
+ assert(
+ isFromCheckpoint(components1),
"The result should depend on checkpoint data if checkpoint interval is 1.")
val components10 = cc.setCheckpointInterval(10).run()
assertComponents(components10, expected)
- assert(!isFromCheckpoint(components10),
+ assert(
+ !isFromCheckpoint(components10),
"The result shouldn't depend on checkpoint data if converged before first checkpoint.")
}
@@ -226,7 +233,10 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
val cc = friends.connectedComponents
assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK)
- for (storageLevel <- Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_ONLY, StorageLevel.NONE)) {
+ for (storageLevel <- Seq(
+ StorageLevel.DISK_ONLY,
+ StorageLevel.MEMORY_ONLY,
+ StorageLevel.NONE)) {
// TODO: it is not trivial to confirm the actual storage level used
val components = cc
.setIntermediateStorageLevel(storageLevel)
@@ -243,22 +253,27 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
}
}
- private def assertComponents[T: ClassTag:TypeTag](
+ private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {
import actual.sparkSession.implicits._
// note: not using agg + collect_list because collect_list is not available in 1.6.2 w/o hive
- val actualComponents = actual.select("component", "id").as[(Long, T)].rdd
+ val actualComponents = actual
+ .select("component", "id")
+ .as[(Long, T)]
+ .rdd
.groupByKey()
.values
.map(_.toSeq)
.collect()
.map { ids =>
val idSet = ids.toSet
- assert(idSet.size === ids.size,
+ assert(
+ idSet.size === ids.size,
s"Found duplicated component assignment in [${ids.mkString(",")}].")
idSet
- }.toSet
+ }
+ .toSet
assert(actualComponents === expected)
}
}
diff --git a/src/test/scala/org/graphframes/lib/PageRankSuite.scala b/src/test/scala/org/graphframes/lib/PageRankSuite.scala
index eca810ec7..40343f3f5 100644
--- a/src/test/scala/org/graphframes/lib/PageRankSuite.scala
+++ b/src/test/scala/org/graphframes/lib/PageRankSuite.scala
@@ -33,7 +33,8 @@ class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val errorTol = 1.0e-5
val pr = g.pageRank
.resetProbability(resetProb)
- .tol(errorTol).run()
+ .tol(errorTol)
+ .run()
TestUtils.testSchemaInvariants(g, pr)
TestUtils.checkColumnType(pr.vertices.schema, "pagerank", DataTypes.DoubleType)
TestUtils.checkColumnType(pr.edges.schema, "weight", DataTypes.DoubleType)
@@ -43,7 +44,8 @@ class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val results = Graphs.friends.pageRank.resetProbability(0.15).maxIter(10).sourceId("a").run()
val gRank = results.vertices.filter(col("id") === "g").select("pagerank").first().getDouble(0)
- assert(gRank === 0.0,
+ assert(
+ gRank === 0.0,
s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got $gRank.")
}
}
diff --git a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala
index 1fc2d0daa..0fa75ce5b 100644
--- a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala
+++ b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala
@@ -70,8 +70,9 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes
// In Spark <2.4, sourceIds must be smaller than Int.MaxValue,
// which might not be the case for LONG_ID in graph.indexedVertices.
- if (Version.valueOf(org.apache.spark.SPARK_VERSION)
- .greaterThanOrEqualTo(Version.valueOf("2.4.0"))) {
+ if (Version
+ .valueOf(org.apache.spark.SPARK_VERSION)
+ .greaterThanOrEqualTo(Version.valueOf("2.4.0"))) {
test("friends graph with parallel personalized PageRank") {
val g = Graphs.friends
val resetProb = 0.15
@@ -89,14 +90,17 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes
.filter { row: Row =>
vertexIds.size != row.getAs[SparseVector](0).size
}
- assert(prInvalid.size === 0,
+ assert(
+ prInvalid.size === 0,
s"found ${prInvalid.size} entries with invalid number of returned personalized pagerank vector")
val gRank = pr.vertices
.filter(col("id") === "g")
.select("pageranks")
- .first().getAs[SparseVector](0)
- assert(gRank.numNonzeros === 0,
+ .first()
+ .getAs[SparseVector](0)
+ assert(
+ gRank.numNonzeros === 0,
s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got ${gRank.numNonzeros}.")
}
}
diff --git a/src/test/scala/org/graphframes/lib/PregelSuite.scala b/src/test/scala/org/graphframes/lib/PregelSuite.scala
index e6af54440..dc6cc35c8 100644
--- a/src/test/scala/org/graphframes/lib/PregelSuite.scala
+++ b/src/test/scala/org/graphframes/lib/PregelSuite.scala
@@ -35,8 +35,7 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
(2L, 0L),
(3L, 4L), // 3 has no in-links
(4L, 0L),
- (4L, 2L)
- ).toDF("src", "dst").cache()
+ (4L, 2L)).toDF("src", "dst").cache()
val vertices = GraphFrame.fromEdges(edges).outDegrees.cache()
val numVertices = vertices.count()
val graph = GraphFrame(vertices, edges)
@@ -45,14 +44,19 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
// NOTE: This version doesn't handle nodes with no out-links.
val ranks = graph.pregel
.setMaxIter(5)
- .withVertexColumn("rank", lit(1.0 / numVertices),
+ .withVertexColumn(
+ "rank",
+ lit(1.0 / numVertices),
coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices)
.sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
.aggMsgs(sum(Pregel.msg))
.run()
- val result = ranks.sort(col("id"))
- .select("rank").as[Double].collect()
+ val result = ranks
+ .sort(col("id"))
+ .select("rank")
+ .as[Double]
+ .collect()
assert(result.sum === 1.0 +- 1e-6)
val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197)
result.zip(expected).foreach { case (r, e) =>
@@ -63,20 +67,20 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("chain propagation") {
val n = 5
val verDF = (1 to n).toDF("id").repartition(3)
- val edgeDF = (1 until n).map(x => (x, x + 1))
- .toDF("src", "dst").repartition(3)
+ val edgeDF = (1 until n)
+ .map(x => (x, x + 1))
+ .toDF("src", "dst")
+ .repartition(3)
val graph = GraphFrame(verDF, edgeDF)
val resultDF = graph.pregel
.setMaxIter(n - 1)
- .withVertexColumn("value",
+ .withVertexColumn(
+ "value",
when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
- when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))
- )
- .sendMsgToDst(
- when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))
- )
+ when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
+ .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")))
.aggMsgs(max(Pregel.msg))
.run()
@@ -86,20 +90,20 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("reverse chain propagation") {
val n = 5
val verDF = (1 to n).toDF("id").repartition(3)
- val edgeDF = (1 until n).map(x => (x + 1, x))
- .toDF("src", "dst").repartition(3)
+ val edgeDF = (1 until n)
+ .map(x => (x + 1, x))
+ .toDF("src", "dst")
+ .repartition(3)
val graph = GraphFrame(verDF, edgeDF)
val resultDF = graph.pregel
.setMaxIter(n - 1)
- .withVertexColumn("value",
+ .withVertexColumn(
+ "value",
when(col("id") === lit(1), lit(1)).otherwise(lit(0)),
- when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))
- )
- .sendMsgToSrc(
- when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value"))
- )
+ when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")))
+ .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value")))
.aggMsgs(max(Pregel.msg))
.run()
diff --git a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala
index 890ed5e14..b7a2df5e7 100644
--- a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala
+++ b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.types.DataTypes
import org.graphframes.{GraphFrame, GraphFrameTestSparkContext, SparkFunSuite, TestUtils}
import org.graphframes.examples.Graphs
-
class SVDPlusPlusSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("Test SVD++ with mean square error on training set") {
@@ -33,16 +32,21 @@ class SVDPlusPlusSuite extends SparkFunSuite with GraphFrameTestSparkContext {
val v2 = g.svdPlusPlus.maxIter(2).run()
TestUtils.testSchemaInvariants(g, v2)
Seq(SVDPlusPlus.COLUMN1, SVDPlusPlus.COLUMN2).foreach { case c =>
- TestUtils.checkColumnType(v2.schema, c,
+ TestUtils.checkColumnType(
+ v2.schema,
+ c,
DataTypes.createArrayType(DataTypes.DoubleType, false))
}
Seq(SVDPlusPlus.COLUMN3, SVDPlusPlus.COLUMN4).foreach { case c =>
TestUtils.checkColumnType(v2.schema, c, DataTypes.DoubleType)
}
- val err = v2.select(GraphFrame.ID, SVDPlusPlus.COLUMN4).rdd.map {
- case Row(vid: Long, vd: Double) =>
+ val err = v2
+ .select(GraphFrame.ID, SVDPlusPlus.COLUMN4)
+ .rdd
+ .map { case Row(vid: Long, vd: Double) =>
if (vid % 2 == 1) vd else 0.0
- }.reduce(_ + _) / g.edges.count()
+ }
+ .reduce(_ + _) / g.edges.count()
assert(err <= svdppErr)
}
}
diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala
index fc0262d1b..69899ed9a 100644
--- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala
+++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala
@@ -25,26 +25,34 @@ import org.graphframes._
class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("Simple test") {
- val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)).flatMap {
- case e => Seq(e, e.swap)
- } .map { case (src, dst) => (src.toLong, dst.toLong) }
+ val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6))
+ .flatMap { case e =>
+ Seq(e, e.swap)
+ }
+ .map { case (src, dst) => (src.toLong, dst.toLong) }
val edges = spark.createDataFrame(edgeSeq).toDF("src", "dst")
val graph = GraphFrame.fromEdges(edges)
// Ground truth
val shortestPaths = Set(
- (1, Map(1 -> 0, 4 -> 2)), (2, Map(1 -> 1, 4 -> 2)), (3, Map(1 -> 2, 4 -> 1)),
- (4, Map(1 -> 2, 4 -> 0)), (5, Map(1 -> 1, 4 -> 1)), (6, Map(1 -> 3, 4 -> 1)))
+ (1, Map(1 -> 0, 4 -> 2)),
+ (2, Map(1 -> 1, 4 -> 2)),
+ (3, Map(1 -> 2, 4 -> 1)),
+ (4, Map(1 -> 2, 4 -> 0)),
+ (5, Map(1 -> 1, 4 -> 1)),
+ (6, Map(1 -> 3, 4 -> 1)))
val landmarks = Seq(1, 4).map(_.toLong)
val v2 = graph.shortestPaths.landmarks(landmarks).run()
TestUtils.testSchemaInvariants(graph, v2)
- TestUtils.checkColumnType(v2.schema, "distances",
+ TestUtils.checkColumnType(
+ v2.schema,
+ "distances",
DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, false))
val newVs = v2.select("id", "distances").collect().toSeq
- val results = newVs.map {
- case Row(id: Long, spMap: Map[Long, Int] @unchecked) => (id, spMap)
+ val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) =>
+ (id, spMap)
}
assert(results.toSet === shortestPaths)
}
@@ -52,13 +60,21 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("friends graph") {
val friends = examples.Graphs.friends
val v = friends.shortestPaths.landmarks(Seq("a", "d")).run()
- val expected = Set[(String, Map[String, Int])](("a", Map("a" -> 0, "d" -> 2)), ("b", Map.empty),
- ("c", Map.empty), ("d", Map("a" -> 1, "d" -> 0)), ("e", Map("a" -> 2, "d" -> 1)),
- ("f", Map.empty), ("g", Map.empty))
- val results = v.select("id", "distances").collect().map {
- case Row(id: String, spMap: Map[String, Int] @unchecked) =>
+ val expected = Set[(String, Map[String, Int])](
+ ("a", Map("a" -> 0, "d" -> 2)),
+ ("b", Map.empty),
+ ("c", Map.empty),
+ ("d", Map("a" -> 1, "d" -> 0)),
+ ("e", Map("a" -> 2, "d" -> 1)),
+ ("f", Map.empty),
+ ("g", Map.empty))
+ val results = v
+ .select("id", "distances")
+ .collect()
+ .map { case Row(id: String, spMap: Map[String, Int] @unchecked) =>
(id, spMap)
- }.toSet
+ }
+ .toSet
assert(results === expected)
}
diff --git a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala
index 65ac324a0..c549cc5ca 100644
--- a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala
+++ b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala
@@ -24,19 +24,15 @@ import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite, T
class StronglyConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("Island Strongly Connected Components") {
- val vertices = spark.createDataFrame(Seq(
- (1L, "a"),
- (2L, "b"),
- (3L, "c"),
- (4L, "d"),
- (5L, "e"))).toDF("id", "value")
+ val vertices = spark
+ .createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e")))
+ .toDF("id", "value")
val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF("src", "dst")
val graph = GraphFrame(vertices, edges)
val c = graph.stronglyConnectedComponents.maxIter(5).run()
TestUtils.testSchemaInvariants(graph, c)
TestUtils.checkColumnType(c.schema, "component", DataTypes.LongType)
- for (Row(id: Long, component: Long, _)
- <- c.select("id", "component", "value").collect()) {
+ for (Row(id: Long, component: Long, _) <- c.select("id", "component", "value").collect()) {
assert(id === component)
}
}
diff --git a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala
index 2ac2127a5..9ad04eeca 100644
--- a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala
+++ b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala
@@ -26,19 +26,24 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext {
test("Count a single triangle") {
val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst")
- val vertices = spark.createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c")))
+ val vertices = spark
+ .createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c")))
.toDF("id", "a")
val g = GraphFrame(vertices, edges)
val v2 = g.triangleCount.run()
TestUtils.testSchemaInvariants(g, v2)
TestUtils.checkColumnType(v2.schema, "count", DataTypes.LongType)
v2.select("id", "count", "a")
- .collect().foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) }
+ .collect()
+ .foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) }
}
test("Count two triangles") {
- val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
- Array(0L -> -1L, -1L -> -2L, -2L -> 0L)).toDF("src", "dst")
+ val edges = spark
+ .createDataFrame(
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> -1L, -1L -> -2L, -2L -> 0L))
+ .toDF("src", "dst")
val g = GraphFrame.fromEdges(edges)
val v2 = g.triangleCount.run()
v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) =>
@@ -67,8 +72,11 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext {
}
test("Count a single triangle with duplicate edges") {
- val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
- Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst")
+ val edges = spark
+ .createDataFrame(
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L))
+ .toDF("src", "dst")
val g = GraphFrame.fromEdges(edges)
val v2 = g.triangleCount.run()
v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) =>
diff --git a/src/test/scala/org/graphframes/pattern/PatternSuite.scala b/src/test/scala/org/graphframes/pattern/PatternSuite.scala
index 7e1e23928..5a7cdd05a 100644
--- a/src/test/scala/org/graphframes/pattern/PatternSuite.scala
+++ b/src/test/scala/org/graphframes/pattern/PatternSuite.scala
@@ -24,27 +24,32 @@ class PatternSuite extends SparkFunSuite {
test("good parses") {
assert(Pattern.parse("(abc)") === Seq(NamedVertex("abc")))
- assert(Pattern.parse("(u)-[e]->(v)") ===
- Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("v"))))
-
- assert(Pattern.parse("()-[]->(v)") ===
- Seq(AnonymousEdge(AnonymousVertex, NamedVertex("v"))))
-
- assert(Pattern.parse("()-[e]->()") ===
- Seq(NamedEdge("e", AnonymousVertex, AnonymousVertex)))
-
- assert(Pattern.parse("(u)-[e]->(u)") ===
- Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("u"))))
-
- assert(Pattern.parse("(u); ()-[]->(v)") ===
- Seq(NamedVertex("u"), AnonymousEdge(AnonymousVertex, NamedVertex("v"))))
-
- assert(Pattern.parse("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)") ===
- Seq(
- AnonymousEdge(NamedVertex("u"), NamedVertex("v")),
- AnonymousEdge(NamedVertex("v"), NamedVertex("w")),
- Negation(
- AnonymousEdge(NamedVertex("u"), NamedVertex("w")))))
+ assert(
+ Pattern.parse("(u)-[e]->(v)") ===
+ Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("v"))))
+
+ assert(
+ Pattern.parse("()-[]->(v)") ===
+ Seq(AnonymousEdge(AnonymousVertex, NamedVertex("v"))))
+
+ assert(
+ Pattern.parse("()-[e]->()") ===
+ Seq(NamedEdge("e", AnonymousVertex, AnonymousVertex)))
+
+ assert(
+ Pattern.parse("(u)-[e]->(u)") ===
+ Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("u"))))
+
+ assert(
+ Pattern.parse("(u); ()-[]->(v)") ===
+ Seq(NamedVertex("u"), AnonymousEdge(AnonymousVertex, NamedVertex("v"))))
+
+ assert(
+ Pattern.parse("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)") ===
+ Seq(
+ AnonymousEdge(NamedVertex("u"), NamedVertex("v")),
+ AnonymousEdge(NamedVertex("v"), NamedVertex("w")),
+ Negation(AnonymousEdge(NamedVertex("u"), NamedVertex("w")))))
}
test("bad parses") {
@@ -133,13 +138,9 @@ class PatternSuite extends SparkFunSuite {
"(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)",
Seq.empty[String])
- testFindNamedVerticesOnlyInNegatedTerms(
- "(u)-[]->(v); (v)-[]->(w)",
- Seq.empty[String])
+ testFindNamedVerticesOnlyInNegatedTerms("(u)-[]->(v); (v)-[]->(w)", Seq.empty[String])
- testFindNamedVerticesOnlyInNegatedTerms(
- "!(u)-[]->(v)",
- Seq("u", "v"))
+ testFindNamedVerticesOnlyInNegatedTerms("!(u)-[]->(v)", Seq("u", "v"))
testFindNamedVerticesOnlyInNegatedTerms(
"(u)-[]->(v); (v)-[]->(w); !(a)-[]->(b); !(v)-[]->(c)",
@@ -153,13 +154,9 @@ class PatternSuite extends SparkFunSuite {
}
}
- testFindNamedElementsInOrder(
- "(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)",
- Seq("u", "v", "w"))
+ testFindNamedElementsInOrder("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)", Seq("u", "v", "w"))
- testFindNamedElementsInOrder(
- "(u)-[]->(v); ()-[vw]->()",
- Seq("u", "v", "vw"))
+ testFindNamedElementsInOrder("(u)-[]->(v); ()-[vw]->()", Seq("u", "v", "vw"))
testFindNamedElementsInOrder(
"(u)-[uv]->(v); (v)-[vw]->(w); !(u)-[]->(w); (x)",