diff --git a/.gitignore b/.gitignore index a07973c1e..9881378cf 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ src_managed/ project/boot/ project/plugins/project/ .bsp +.metals # intellij .idea/ diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 000000000..347451104 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,15 @@ +# The content of this file is copied from the Apache Spark Project + +align = none +align.openParenDefnSite = false +align.openParenCallSite = false +align.tokens = [] +importSelectors = "singleLine" +optIn = { + configStyleArguments = false +} +danglingParentheses.preset = false +docstrings.style = Asterisk +maxColumn = 98 +runner.dialect = scala213 +version = 3.8.5 \ No newline at end of file diff --git a/build.sbt b/build.sbt index 061901717..4ee4d9bd5 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,4 @@ -import ReleaseTransformations._ +import ReleaseTransformations.* lazy val sparkVer = sys.props.getOrElse("spark.version", "3.5.4") lazy val sparkBranch = sparkVer.substring(0, 3) diff --git a/project/plugins.sbt b/project/plugins.sbt index 46028c336..feb5a0677 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -8,3 +8,4 @@ ThisBuild / libraryDependencySchemes ++= Seq( addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.10") addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4") \ No newline at end of file diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index c9a221a6a..d94305364 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -40,9 +40,11 @@ import org.graphframes.pattern._ * @groupname degree Graph topology * @groupname motif Motif finding */ -class GraphFrame private( +class GraphFrame private ( @transient private val _vertices: DataFrame, - @transient
private val _edges: DataFrame) extends Logging with Serializable { + @transient private val _edges: DataFrame) + extends Logging + with Serializable { import GraphFrame._ @@ -53,7 +55,8 @@ class GraphFrame private( // We call select on the vertices and edges to ensure that ID, SRC, DST always come first // in the printed schema. val vCols = (ID +: vertices.columns.filter(_ != ID).toIndexedSeq).map(col) - val eCols = (SRC +: DST +: edges.columns.filter(c => c != SRC && c != DST).toIndexedSeq).map(col) + val eCols = + (SRC +: DST +: edges.columns.filter(c => c != SRC && c != DST).toIndexedSeq).map(col) val v = vertices.select(vCols.toSeq: _*).toString val e = edges.select(eCols.toSeq: _*).toString "GraphFrame(v:" + v + ", e:" + e + ")" @@ -80,8 +83,9 @@ class GraphFrame private( /** * Persist the dataframe representation of vertices and edges of the graph with the given * storage level. - * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, - * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.. + * @param newLevel + * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`, + * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.. */ def persist(newLevel: StorageLevel): this.type = { vertices.persist(newLevel) @@ -102,7 +106,8 @@ class GraphFrame private( /** * Mark the dataframe representation of vertices and edges of the graph as non-persistent, and * remove all blocks for it from memory and disk. - * @param blocking Whether to block until all blocks are deleted. + * @param blocking + * Whether to block until all blocks are deleted. */ def unpersist(blocking: Boolean): this.type = { vertices.unpersist(blocking) @@ -115,8 +120,8 @@ class GraphFrame private( /** * The dataframe representation of the vertices of the graph. * - * It contains a column called [[GraphFrame.ID]] with the id of the vertex, - * and various other user-defined attributes with other attributes. 
+ * It contains a column called [[GraphFrame.ID]] with the id of the vertex, and various other + * user-defined attributes with other attributes. * * The order of the columns is available in [[vertexColumns]]. * @@ -132,9 +137,9 @@ class GraphFrame private( /** * The dataframe representation of the edges of the graph. * - * It contains two columns called [[GraphFrame.SRC]] and [[GraphFrame.DST]] that contain - * the ids of the source vertex and the destination vertex of each edge, respectively. - * It may also contain various other columns with user-defined attributes for each edge. + * It contains two columns called [[GraphFrame.SRC]] and [[GraphFrame.DST]] that contain the ids + * of the source vertex and the destination vertex of each edge, respectively. It may also + * contain various other columns with user-defined attributes for each edge. * * For symmetric graphs, both pairs src -> dst and dst -> src are present with the same * attributes for each pair. @@ -154,7 +159,7 @@ class GraphFrame private( /** * Returns triplets: (source vertex)-[edge]->(destination vertex) for all edges in the graph. * The DataFrame returned has 3 columns, with names: [[GraphFrame.SRC]], [[GraphFrame.EDGE]], - * and [[GraphFrame.DST]]. The 2 vertex columns have schema matching [[GraphFrame.vertices]], + * and [[GraphFrame.DST]]. The 2 vertex columns have schema matching [[GraphFrame.vertices]], * and the edge column has a schema matching [[GraphFrame.edges]]. * * @group structure @@ -164,11 +169,11 @@ class GraphFrame private( // ============================ Conversions ======================================== /** - * Converts this [[GraphFrame]] instance to a GraphX `Graph`. - * Vertex and edge attributes are the original rows in [[vertices]] and [[edges]], respectively. + * Converts this [[GraphFrame]] instance to a GraphX `Graph`. Vertex and edge attributes are the + * original rows in [[vertices]] and [[edges]], respectively. 
* - * Note that vertex (and edge) attributes include vertex IDs (and source, destination IDs) - * in order to support non-Long vertex IDs. If the vertex IDs are not convertible to Long values, + * Note that vertex (and edge) attributes include vertex IDs (and source, destination IDs) in + * order to support non-Long vertex IDs. If the vertex IDs are not convertible to Long values, * then the values are indexed in order to generate corresponding Long vertex IDs (which is an * expensive operation). * @@ -179,16 +184,22 @@ class GraphFrame private( */ def toGraphX: Graph[Row, Row] = { if (hasIntegralIdType) { - val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR)) - .rdd.map { case Row(id: Long, attr: Row) => (id, attr) } - val ee = edges.select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR)) - .rdd.map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) } + val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR)).rdd.map { + case Row(id: Long, attr: Row) => (id, attr) + } + val ee = edges + .select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR)) + .rdd + .map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) } Graph(vv, ee) } else { // Compute Long vertex IDs - val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map { case Row(long_id: Long, attr: Row) => (long_id, attr) } - val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map { case Row(long_src: Long, long_dst: Long, attr: Row) => - Edge(long_src, long_dst, attr) + val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map { + case Row(long_id: Long, attr: Row) => (long_id, attr) + } + val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map { + case Row(long_src: Long, long_dst: Long, attr: Row) => + Edge(long_src, long_dst, attr) } Graph(vv, ee) } @@ -197,9 +208,9 @@ class GraphFrame private( /** * The column names in the [[vertices]] DataFrame, in order. 
* - * Helper method for [[toGraphX]] which specifies the schema of vertex attributes. - * The vertex attributes of the returned `Graph` are given as a `Row`, - * and this method defines the column ordering in that `Row`. + * Helper method for [[toGraphX]] which specifies the schema of vertex attributes. The vertex + * attributes of the returned `Graph` are given as a `Row`, and this method defines the column + * ordering in that `Row`. * * @group conversions */ @@ -215,9 +226,9 @@ class GraphFrame private( /** * The vertex names in the [[vertices]] DataFrame, in order. * - * Helper method for [[toGraphX]] which specifies the schema of edge attributes. - * The edge attributes of the returned `edges` are given as a `Row`, - * and this method defines the column ordering in that `Row`. + * Helper method for [[toGraphX]] which specifies the schema of edge attributes. The edge + * attributes of the returned `edges` are given as a `Row`, and this method defines the column + * ordering in that `Row`. * * @group conversions */ @@ -234,8 +245,8 @@ class GraphFrame private( /** * The out-degree of each vertex in the graph, returned as a DataFrame with two columns: - * - [[GraphFrame.ID]] the ID of the vertex - * - "outDegree" (integer) storing the out-degree of the vertex + * - [[GraphFrame.ID]] the ID of the vertex + * - "outDegree" (integer) storing the out-degree of the vertex * Note that vertices with 0 out-edges are not returned in the result. * * @group degree @@ -246,9 +257,9 @@ class GraphFrame private( /** * The in-degree of each vertex in the graph, returned as a DataFame with two columns: - * - [[GraphFrame.ID]] the ID of the vertex - * "- "inDegree" (int) storing the in-degree of the vertex - * Note that vertices with 0 in-edges are not returned in the result. + * - [[GraphFrame.ID]] the ID of the vertex + * - "inDegree" (int) storing the in-degree of the vertex + * Note that vertices with 0 in-edges are not returned in the result.
* * @group degree */ @@ -258,14 +269,17 @@ class GraphFrame private( /** * The degree of each vertex in the graph, returned as a DataFrame with two columns: - * - [[GraphFrame.ID]] the ID of the vertex - * - 'degree' (integer) the degree of the vertex + * - [[GraphFrame.ID]] the ID of the vertex + * - 'degree' (integer) the degree of the vertex * Note that vertices with 0 edges are not returned in the result. * * @group degree */ @transient lazy val degrees: DataFrame = { - edges.select(explode(array(SRC, DST)).as(ID)).groupBy(ID).agg(count("*").cast("int").as("degree")) + edges + .select(explode(array(SRC, DST)).as(ID)) + .groupBy(ID) + .agg(count("*").cast("int").as("degree")) } // ============================ Motif finding ======================================== @@ -275,59 +289,59 @@ class GraphFrame private( * * Motif finding uses a simple Domain-Specific Language (DSL) for expressing structural queries. * For example, `graph.find("(a)-[e]->(b); (b)-[e2]->(a)")` will search for pairs of vertices - * `a,b` connected by edges in both directions. It will return a `DataFrame` of all such - * structures in the graph, with columns for each of the named elements (vertices or edges) - * in the motif. In this case, the returned columns will be in order of the pattern: - * "a, e, b, e2." + * `a,b` connected by edges in both directions. It will return a `DataFrame` of all such + * structures in the graph, with columns for each of the named elements (vertices or edges) in + * the motif. In this case, the returned columns will be in order of the pattern: "a, e, b, e2." * * DSL for expressing structural patterns: - * - The basic unit of a pattern is an edge. - * For example, `"(a)-[e]->(b)"` expresses an edge `e` from vertex `a` to vertex `b`. - * Note that vertices are denoted by parentheses `(a)`, while edges are denoted by - * square brackets `[e]`. - * - A pattern is expressed as a union of edges. Edge patterns can be joined with semicolons. 
- * Motif `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies two edges from `a` to `b` to `c`. - * - Within a pattern, names can be assigned to vertices and edges. For example, - * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`. - * These names serve two purposes: - * - The names can identify common elements among edges. For example, + * - The basic unit of a pattern is an edge. For example, `"(a)-[e]->(b)"` expresses an edge + * `e` from vertex `a` to vertex `b`. Note that vertices are denoted by parentheses `(a)`, + * while edges are denoted by square brackets `[e]`. + * - A pattern is expressed as a union of edges. Edge patterns can be joined with semicolons. + * Motif `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies two edges from `a` to `b` to `c`. + * - Within a pattern, names can be assigned to vertices and edges. For example, + * `"(a)-[e]->(b)"` has three named elements: vertices `a,b` and edge `e`. These names serve + * two purposes: + * - The names can identify common elements among edges. For example, * `"(a)-[e]->(b); (b)-[e2]->(c)"` specifies that the same vertex `b` is the destination * of edge `e` and source of edge `e2`. - * - The names are used as column names in the result `DataFrame`. If a motif contains - * named vertex `a`, then the result `DataFrame` will contain a column "a" which is a + * - The names are used as column names in the result `DataFrame`. If a motif contains named + * vertex `a`, then the result `DataFrame` will contain a column "a" which is a * `StructType` with sub-fields equivalent to the schema (columns) of - * [[GraphFrame.vertices]]. Similarly, an edge `e` in a motif will produce a column "e" - * in the result `DataFrame` with sub-fields equivalent to the schema (columns) of + * [[GraphFrame.vertices]]. Similarly, an edge `e` in a motif will produce a column "e" in + * the result `DataFrame` with sub-fields equivalent to the schema (columns) of * [[GraphFrame.edges]]. 
* - Be aware that names do *not* identify *distinct* elements: two elements with different - * names may refer to the same graph element. For example, in the motif + * names may refer to the same graph element. For example, in the motif * `"(a)-[e]->(b); (b)-[e2]->(c)"`, the names `a` and `c` could refer to the same vertex. - * To restrict named elements to be distinct vertices or edges, use post-hoc filters - * such as `resultDataframe.filter("a.id != c.id")`. - * - It is acceptable to omit names for vertices or edges in motifs when not needed. - * E.g., `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name - * to the edge. There will be no column for the anonymous edge in the result `DataFrame`. - * Similarly, `"(a)-[e]->()"` indicates an out-edge of vertex `a` but does not name - * the destination vertex. These are called *anonymous* vertices and edges. - * - An edge can be negated to indicate that the edge should *not* be present in the graph. - * E.g., `"(a)-[]->(b); !(b)-[]->(a)"` finds edges from `a` to `b` for which there is *no* - * edge from `b` to `a`. + * To restrict named elements to be distinct vertices or edges, use post-hoc filters such + * as `resultDataframe.filter("a.id != c.id")`. + * - It is acceptable to omit names for vertices or edges in motifs when not needed. E.g., + * `"(a)-[]->(b)"` expresses an edge between vertices `a,b` but does not assign a name to + * the edge. There will be no column for the anonymous edge in the result `DataFrame`. + * Similarly, `"(a)-[e]->()"` indicates an out-edge of vertex `a` but does not name the + * destination vertex. These are called *anonymous* vertices and edges. + * - An edge can be negated to indicate that the edge should *not* be present in the graph. + * E.g., `"(a)-[]->(b); !(b)-[]->(a)"` finds edges from `a` to `b` for which there is *no* + * edge from `b` to `a`. 
* * Restrictions: - * - Motifs are not allowed to contain edges without any named elements: `"()-[]->()"` and - * `"!()-[]->()"` are prohibited terms. - * - Motifs are not allowed to contain named edges within negated terms (since these named - * edges would never appear within results). E.g., `"!(a)-[ab]->(b)"` is invalid, but - * `"!(a)-[]->(b)"` is valid. + * - Motifs are not allowed to contain edges without any named elements: `"()-[]->()"` and + * `"!()-[]->()"` are prohibited terms. + * - Motifs are not allowed to contain named edges within negated terms (since these named + * edges would never appear within results). E.g., `"!(a)-[ab]->(b)"` is invalid, but + * `"!(a)-[]->(b)"` is valid. * - * More complex queries, such as queries which operate on vertex or edge attributes, - * can be expressed by applying filters to the result `DataFrame`. + * More complex queries, such as queries which operate on vertex or edge attributes, can be + * expressed by applying filters to the result `DataFrame`. * - * This can return duplicate rows. E.g., a query `"(u)-[]->()"` will return a result for each + * This can return duplicate rows. E.g., a query `"(u)-[]->()"` will return a result for each * matching edge, even if those edges share the same vertex `u`. * - * @param pattern Pattern specifying a motif to search for. - * @return `DataFrame` containing all instances of the motif. + * @param pattern + * Pattern specifying a motif to search for. + * @return + * `DataFrame` containing all instances of the motif. 
* @group motif */ def find(pattern: String): DataFrame = { @@ -342,7 +356,7 @@ class GraphFrame private( val df = findSimple(augmentedPatterns) val names = Pattern.findNamedElementsInOrder(patterns, includeEdges = true) - if (names.isEmpty) df else df.select(names.head, names.tail : _*) + if (names.isEmpty) df else df.select(names.head, names.tail: _*) } // ======================== Other queries =================================== @@ -357,9 +371,9 @@ class GraphFrame private( def bfs: BFS = new BFS(this) /** - * This is a primitive for implementing graph algorithms. - * This method aggregates values from the neighboring edges and vertices of each vertex. - * See [[org.graphframes.lib.AggregateMessages AggregateMessages]] for detailed documentation. + * This is a primitive for implementing graph algorithms. This method aggregates values from the + * neighboring edges and vertices of each vertex. See + * [[org.graphframes.lib.AggregateMessages AggregateMessages]] for detailed documentation. */ def aggregateMessages: AggregateMessages = new AggregateMessages(this) @@ -370,8 +384,9 @@ class GraphFrame private( */ def filterVertices(condition: Column): GraphFrame = { val vv = vertices.filter(condition) - val ee = edges.join(vv, vv(ID) === edges(SRC), "left_semi") - .join(vv, vv(ID) === edges(DST), "left_semi") + val ee = edges + .join(vv, vv(ID) === edges(SRC), "left_semi") + .join(vv, vv(ID) === edges(DST), "left_semi") GraphFrame(vv, ee) } @@ -382,7 +397,7 @@ class GraphFrame private( */ def filterVertices(conditionExpr: String): GraphFrame = filterVertices(expr(conditionExpr)) - /** + /** * Filter the edges according to Column expression, keep all vertices. * @group subgraph */ @@ -392,7 +407,7 @@ class GraphFrame private( GraphFrame(vv, ee) } - /** + /** * Filter the edges according to String expression. * @group subgraph */ @@ -439,19 +454,20 @@ class GraphFrame private( def pageRank: PageRank = new PageRank(this) /** - * Parallel personalized PageRank algorithm. 
- * - * See [[org.graphframes.lib.ParallelPersonalizedPageRank]] for more details. - * - * @group stdlib - */ + * Parallel personalized PageRank algorithm. + * + * See [[org.graphframes.lib.ParallelPersonalizedPageRank]] for more details. + * + * @group stdlib + */ def parallelPersonalizedPageRank: ParallelPersonalizedPageRank = new ParallelPersonalizedPageRank(this) /** * Pregel algorithm. * - * @see [[org.graphframes.lib.Pregel]] + * @see + * [[org.graphframes.lib.Pregel]] * @group stdlib */ def pregel = new Pregel(this) @@ -496,11 +512,12 @@ class GraphFrame private( // ========= Motif finding (private) ========= /** - * Primary method implementing motif finding. - * This iterative method handles one pattern (via [[findIncremental()]] on each iteration, - * augmenting the `DataFrame` in prevDF with each new pattern. + * Primary method implementing motif finding. This iterative method handles one pattern (via + * [[findIncremental()]]) on each iteration, augmenting the `DataFrame` in prevDF with each new + * pattern. * - * @return `DataFrame` containing all instances of the motif specified by the given patterns + * @return + * `DataFrame` containing all instances of the motif specified by the given patterns */ private def findSimple(patterns: Seq[Pattern]): DataFrame = { val (_, finalDFOpt, _) = @@ -519,38 +536,42 @@ class GraphFrame private( /** * True if the id type can be cast to Long. * - * This is important for performance reasons. The underlying graphx - * implementation only deals with Long types. + * This is important for performance reasons. The underlying graphx implementation only deals + * with Long types. */ private[graphframes] lazy val hasIntegralIdType: Boolean = { vertices.schema(ID).dataType match { - case _ @ (ByteType | IntegerType | LongType | ShortType) => true + case _ @(ByteType | IntegerType | LongType | ShortType) => true case _ => false } } /** - * Vertices with each vertex assigned a unique long ID.
- * If the vertex ID type is integral, this casts the original IDs to long. + * Vertices with each vertex assigned a unique long ID. If the vertex ID type is integral, this + * casts the original IDs to long. * * Columns: - * - $LONG_ID: the new ID of LongType - * - $ORIGINAL_ID: the ID provided by the user - * - $ATTR: all the original vertex attributes + * - $LONG_ID: the new ID of LongType + * - $ORIGINAL_ID: the ID provided by the user + * - $ATTR: all the original vertex attributes */ private[graphframes] lazy val indexedVertices: DataFrame = { if (hasIntegralIdType) { val indexedVertices = vertices.select(nestAsCol(vertices, ATTR)) indexedVertices.select( - col(ATTR + "." + ID).cast("long").as(LONG_ID), col(ATTR + "." + ID).as(ID), col(ATTR)) + col(ATTR + "." + ID).cast("long").as(LONG_ID), + col(ATTR + "." + ID).as(ID), + col(ATTR)) } else { - val withLongIds = vertices.select(ID) + val withLongIds = vertices + .select(ID) .repartition(col(ID)) .distinct() .sortWithinPartitions(ID) .withColumn(LONG_ID, monotonically_increasing_id()) .persist(StorageLevel.MEMORY_AND_DISK) - vertices.select(col(ID), nestAsCol(vertices, ATTR)) + vertices + .select(col(ID), nestAsCol(vertices, ATTR)) .join(withLongIds, ID) .select(LONG_ID, ID, ATTR) } @@ -558,31 +579,41 @@ class GraphFrame private( /** * Columns: - * - $SRC - * - $LONG_SRC - * - $DST - * - $LONG_DST - * - $ATTR + * - $SRC + * - $LONG_SRC + * - $DST + * - $LONG_DST + * - $ATTR */ private[graphframes] lazy val indexedEdges: DataFrame = { val packedEdges = edges.select(col(SRC), col(DST), nestAsCol(edges, ATTR)) if (hasIntegralIdType) { packedEdges.select( - col(SRC), col(SRC).cast("long").as(LONG_SRC), - col(DST), col(DST).cast("long").as(LONG_DST), + col(SRC), + col(SRC).cast("long").as(LONG_SRC), + col(DST), + col(DST).cast("long").as(LONG_DST), col(ATTR)) } else { val threshold = broadcastThreshold - val hubs: Set[Any] = degrees.filter(col("degree") >= threshold).select(ID) - .collect().map(_.get(0)).toSet + 
val hubs: Set[Any] = degrees + .filter(col("degree") >= threshold) + .select(ID) + .collect() + .map(_.get(0)) + .toSet val indexedSourceEdges = GraphFrame.skewedJoin( packedEdges, indexedVertices.select(col(ID).as(SRC), col(LONG_ID).as(LONG_SRC)), - SRC, hubs, "GraphFrame.indexedEdges:") + SRC, + hubs, + "GraphFrame.indexedEdges:") val indexedEdges = GraphFrame.skewedJoin( indexedSourceEdges, indexedVertices.select(col(ID).as(DST), col(LONG_ID).as(LONG_DST)), - DST, hubs, "GraphFrame.indexedEdges:") + DST, + hubs, + "GraphFrame.indexedEdges:") indexedEdges.select(SRC, LONG_SRC, DST, LONG_DST, ATTR) } } @@ -595,26 +626,33 @@ class GraphFrame private( } /** - * A cached conversion of this graph to the GraphX structure, with the data stored for each edge and vertex. + * A cached conversion of this graph to the GraphX structure, with the data stored for each edge + * and vertex. */ @transient private lazy val cachedGraphX: Graph[Row, Row] = { toGraphX } } - object GraphFrame extends Serializable with Logging { /** * Implements `a.join(b, joinCol)`, handling skew in the join keys. - * @param a DataFrame which may have multiple rows with the same key in `joinCol` - * @param b DataFrame which has exactly 1 row for every key in `a.joinCol`. - * @param joinCol Name of column on which to do join - * @param hubs Set of join keys which are high-degree (skewed) - * @param logPrefix Prefix for logging, e.g., name of algorithm doing the join - * @return `a.join(b, joinCol)` - * @tparam T DataType for join key - */ - private[graphframes] def skewedJoin[T : TypeTag]( + * @param a + * DataFrame which may have multiple rows with the same key in `joinCol` + * @param b + * DataFrame which has exactly 1 row for every key in `a.joinCol`. 
+ * @param joinCol + * Name of column on which to do join + * @param hubs + * Set of join keys which are high-degree (skewed) + * @param logPrefix + * Prefix for logging, e.g., name of algorithm doing the join + * @return + * `a.join(b, joinCol)` + * @tparam T + * DataType for join key + */ + private[graphframes] def skewedJoin[T: TypeTag]( a: DataFrame, b: DataFrame, joinCol: String, @@ -630,43 +668,44 @@ object GraphFrame extends Serializable with Logging { val isHub = udf { id: T => hubs.contains(id) } - val hashJoined = a.filter(!isHub(col(joinCol))) + val hashJoined = a + .filter(!isHub(col(joinCol))) .join(b.filter(!isHub(col(joinCol))), joinCol) - val broadcastJoined = a.filter(isHub(col(joinCol))) + val broadcastJoined = a + .filter(isHub(col(joinCol))) .join(broadcast(b.filter(isHub(col(joinCol)))), joinCol) hashJoined.unionAll(broadcastJoined) } } /** - * Column name for vertex IDs in [[GraphFrame.vertices]] - * Note that GraphFrame assigns a unique long ID to each vertex, - * If the vertex ID type is one of byte / int / long / short type, - * GraphFrame casts the original IDs to long as the unique long ID, - * otherwise GraphFrame generates the unique long ID by Spark function - * ``monotonically_increasing_id`` which is less performant. + * Column name for vertex IDs in [[GraphFrame.vertices]]. Note that GraphFrame assigns a unique + * long ID to each vertex. If the vertex ID type is one of byte / int / long / short type, + * GraphFrame casts the original IDs to long as the unique long ID, otherwise GraphFrame + * generates the unique long ID by Spark function ``monotonically_increasing_id`` which is less + * performant. */ val ID: String = "id" /** * Column name for source vertices of edges. - * - In [[GraphFrame.edges]], this is a column of vertex IDs. - * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching - * [[GraphFrame.vertices]]. + * - In [[GraphFrame.edges]], this is a column of vertex IDs.
+ * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching + * [[GraphFrame.vertices]]. */ val SRC: String = "src" /** * Column name for destination vertices of edges. - * - In [[GraphFrame.edges]], this is a column of vertex IDs. - * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching - * [[GraphFrame.vertices]]. + * - In [[GraphFrame.edges]], this is a column of vertex IDs. + * - In [[GraphFrame.triplets]], this is a column of vertices with schema matching + * [[GraphFrame.vertices]]. */ val DST: String = "dst" /** - * Column name for edge in [[GraphFrame.triplets]]. In [[GraphFrame.triplets]], - * this is a column of edges with schema matching [[GraphFrame.edges]]. + * Column name for edge in [[GraphFrame.triplets]]. In [[GraphFrame.triplets]], this is a column + * of edges with schema matching [[GraphFrame.edges]]. */ val EDGE: String = "edge" @@ -675,20 +714,26 @@ object GraphFrame extends Serializable with Logging { /** * Create a new [[GraphFrame]] from vertex and edge `DataFrame`s. * - * @param vertices Vertex DataFrame. This must include a column "id" containing unique vertex IDs. - * All other columns are treated as vertex attributes. - * @param edges Edge DataFrame. This must include columns "src" and "dst" containing source and - * destination vertex IDs. All other columns are treated as edge attributes. - * @return New [[GraphFrame]] instance + * @param vertices + * Vertex DataFrame. This must include a column "id" containing unique vertex IDs. All other + * columns are treated as vertex attributes. + * @param edges + * Edge DataFrame. This must include columns "src" and "dst" containing source and destination + * vertex IDs. All other columns are treated as edge attributes. 
+ * @return + * New [[GraphFrame]] instance */ def apply(vertices: DataFrame, edges: DataFrame): GraphFrame = { - require(vertices.columns.contains(ID), + require( + vertices.columns.contains(ID), s"Vertex ID column '$ID' missing from vertex DataFrame, which has columns: " + vertices.columns.mkString(",")) - require(edges.columns.contains(SRC), + require( + edges.columns.contains(SRC), s"Source vertex ID column '$SRC' missing from edge DataFrame, which has columns: " + edges.columns.mkString(",")) - require(edges.columns.contains(DST), + require( + edges.columns.contains(DST), s"Destination vertex ID column '$DST' missing from edge DataFrame, which has columns: " + edges.columns.mkString(",")) @@ -696,14 +741,16 @@ object GraphFrame extends Serializable with Logging { } /** - * Create a new [[GraphFrame]] from an edge `DataFrame`. - * The resulting [[GraphFrame]] will have [[GraphFrame.vertices]] with a single "id" column. + * Create a new [[GraphFrame]] from an edge `DataFrame`. The resulting [[GraphFrame]] will have + * [[GraphFrame.vertices]] with a single "id" column. * * Note: The [[GraphFrame.vertices]] DataFrame will be persisted at level - * `StorageLevel.MEMORY_AND_DISK`. - * @param e Edge DataFrame. This must include columns "src" and "dst" containing source and - * destination vertex IDs. All other columns are treated as edge attributes. - * @return New [[GraphFrame]] instance + * `StorageLevel.MEMORY_AND_DISK`. + * @param e + * Edge DataFrame. This must include columns "src" and "dst" containing source and destination + * vertex IDs. All other columns are treated as edge attributes. + * @return + * New [[GraphFrame]] instance * * @group conversions */ @@ -718,64 +765,66 @@ object GraphFrame extends Serializable with Logging { /** * Converts a GraphX `Graph` instance into a [[GraphFrame]]. * - * This converts each `org.apache.spark.rdd.RDD` in the `Graph` to a `DataFrame` using - * schema inference. 
+ * This converts each `org.apache.spark.rdd.RDD` in the `Graph` to a `DataFrame` using schema + * inference. * - * Vertex ID column names will be converted to "id" for the vertex DataFrame, - * and to "src" and "dst" for the edge DataFrame. + * Vertex ID column names will be converted to "id" for the vertex DataFrame, and to "src" and + * "dst" for the edge DataFrame. * * @group conversions */ // TODO: Add version which takes explicit schemas. - def fromGraphX[VD : TypeTag, ED : TypeTag](graph: Graph[VD, ED]): GraphFrame = { + def fromGraphX[VD: TypeTag, ED: TypeTag](graph: Graph[VD, ED]): GraphFrame = { val spark = SparkSession.builder().getOrCreate() val vv = spark.createDataFrame(graph.vertices).toDF(ID, ATTR) val ee = spark.createDataFrame(graph.edges).toDF(SRC, DST, ATTR) GraphFrame(vv, ee) } - /** * Given: - * - a GraphFrame `originalGraph` - * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] + * - a GraphFrame `originalGraph` + * - a GraphX graph derived from the GraphFrame using [[GraphFrame.toGraphX]] * this method merges attributes from the GraphX graph into the original GraphFrame. * - * This method is useful for doing computations using the GraphX API and then merging the results - * with a GraphFrame. For example, given: - * - GraphFrame `originalGraph` - * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call "category" - * and an Int edge attribute we want to call "count" - * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce - * a new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, - * with new [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column - * "count" added. + * This method is useful for doing computations using the GraphX API and then merging the + * results with a GraphFrame. 
For example, given: + * - GraphFrame `originalGraph` + * - GraphX Graph[String, Int] `graph` with a String vertex attribute we want to call + * "category" and an Int edge attribute we want to call "count" + * We can call `fromGraphX(originalGraph, graph, Seq("category"), Seq("count"))` to produce a + * new GraphFrame. The new GraphFrame will be an augmented version of `originalGraph`, with new + * [[GraphFrame.vertices]] column "category" and new [[GraphFrame.edges]] column "count" added. * * See [[org.graphframes.examples.BeliefPropagation]] for example usage. * - * @param originalGraph Original GraphFrame used to compute the GraphX graph. - * @param graph GraphX graph. Vertex and edge attributes, if any, will be merged into - * the original graph as new columns. If the attributes are `Product` types - * such as tuples, then each element of the `Product` will be put in a separate - * column. If the attributes are other types, then the entire GraphX attribute - * will become a single new column. - * @param vertexNames Column name(s) for vertex attributes in the GraphX graph. - * If there is no vertex attribute, this should be empty. - * If there is a singleton attribute, this should have a single column name. - * If the attribute is a `Product` type, this should be a list of names - * matching the order of the attribute elements. - * @param edgeNames Column name(s) for edge attributes in the GraphX graph. - * If there is no edge attribute, this should be empty. - * If there is a singleton attribute, this should have a single column name. - * If the attribute is a `Product` type, this should be a list of names - * matching the order of the attribute elements. - * @tparam V the type of the vertex data - * @tparam E the type of the edge data - * @return original graph augmented with vertex and column attributes from the GraphX graph + * @param originalGraph + * Original GraphFrame used to compute the GraphX graph. + * @param graph + * GraphX graph. 
Vertex and edge attributes, if any, will be merged into the original graph as + * new columns. If the attributes are `Product` types such as tuples, then each element of the + * `Product` will be put in a separate column. If the attributes are other types, then the + * entire GraphX attribute will become a single new column. + * @param vertexNames + * Column name(s) for vertex attributes in the GraphX graph. If there is no vertex attribute, + * this should be empty. If there is a singleton attribute, this should have a single column + * name. If the attribute is a `Product` type, this should be a list of names matching the + * order of the attribute elements. + * @param edgeNames + * Column name(s) for edge attributes in the GraphX graph. If there is no edge attribute, this + * should be empty. If there is a singleton attribute, this should have a single column name. + * If the attribute is a `Product` type, this should be a list of names matching the order of + * the attribute elements. + * @tparam V + * the type of the vertex data + * @tparam E + * the type of the edge data + * @return + * original graph augmented with vertex and column attributes from the GraphX graph * * @group conversions */ - def fromGraphX[V : TypeTag, E : TypeTag]( + def fromGraphX[V: TypeTag, E: TypeTag]( originalGraph: GraphFrame, graph: Graph[V, E], vertexNames: Seq[String] = Nil, @@ -783,7 +832,6 @@ object GraphFrame extends Serializable with Logging { GraphXConversions.fromGraphX[V, E](originalGraph, graph, vertexNames, edgeNames) } - // ============== Private constants ============== /** Default name for attribute columns when converting from GraphX [[Graph]] format */ @@ -798,16 +846,15 @@ object GraphFrame extends Serializable with Logging { private[graphframes] val LONG_DST: String = "new_dst" private[graphframes] val GX_ATTR: String = "graphx_attr" - - /** Helper for using [col].* in Spark 1.4. 
Returns sequence of [col].[field] for all fields */ private[graphframes] def colStar(df: DataFrame, col: String): Seq[String] = { df.schema(col).dataType match { case s: StructType => s.fieldNames.map(f => col + "." + f).toIndexedSeq case other => - throw new RuntimeException(s"Unknown error in GraphFrame. Expected column $col to be" + - s" StructType, but found type: $other") + throw new RuntimeException( + s"Unknown error in GraphFrame. Expected column $col to be" + + s" StructType, but found type: $other") } } @@ -825,7 +872,6 @@ object GraphFrame extends Serializable with Logging { private def eSrcId(name: String): String = prefixWithName(name, SRC) private def eDstId(name: String): String = prefixWithName(name, DST) - private def maybeCrossJoin(aOpt: Option[DataFrame], b: DataFrame): DataFrame = { aOpt match { case Some(a) => a.crossJoin(b) @@ -843,7 +889,6 @@ object GraphFrame extends Serializable with Logging { } } - /** Indicate whether a named vertex has been seen in any of the given patterns */ private def seen(v: NamedVertex, patterns: Seq[Pattern]) = patterns.exists(p => seen1(v, p)) @@ -861,14 +906,17 @@ object GraphFrame extends Serializable with Logging { false } - /** * Augment the given DataFrame based on a pattern. 
* - * @param prevPatterns Patterns which have contributed to the given DataFrame - * @param prev Given DataFrame - * @param pattern Pattern to search for - * @return DataFrame augmented with the current search pattern + * @param prevPatterns + * Patterns which have contributed to the given DataFrame + * @param prev + * Given DataFrame + * @param pattern + * Pattern to search for + * @return + * DataFrame augmented with the current search pattern */ private def findIncremental( gf: GraphFrame, @@ -899,93 +947,114 @@ object GraphFrame extends Serializable with Logging { case NamedEdge(name, AnonymousVertex, dst @ NamedVertex(dstName)) => if (seen(dst, prevPatterns)) { val eRen = nestE(name) - (Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))), + ( + Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName)))), prevNames :+ name) } else { val eRen = nestE(name) val dstV = nestV(dstName) - (Some(maybeCrossJoin(prev, eRen) - .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), + ( + Some( + maybeCrossJoin(prev, eRen) + .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), prevNames :+ name :+ dstName) } case NamedEdge(name, src @ NamedVertex(srcName), AnonymousVertex) => if (seen(src, prevPatterns)) { val eRen = nestE(name) - (Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))), + ( + Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName)))), prevNames :+ name) } else { val eRen = nestE(name) val srcV = nestV(srcName) - (Some(maybeCrossJoin(prev, eRen) - .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))), - prevNames :+ srcName :+ name) + ( + Some( + maybeCrossJoin(prev, eRen) + .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))), + prevNames :+ srcName :+ name) } case NamedEdge(name, src @ NamedVertex(srcName), dst @ NamedVertex(dstName)) => (seen(src, prevPatterns), seen(dst, prevPatterns)) match { case (true, true) => val eRen = nestE(name) - 
(Some(maybeJoin(prev, eRen, prev => - eRen(eSrcId(name)) === prev(vId(srcName)) && eRen(eDstId(name)) === prev(vId(dstName)))), + ( + Some( + maybeJoin( + prev, + eRen, + prev => + eRen(eSrcId(name)) === prev(vId(srcName)) && eRen(eDstId(name)) === prev( + vId(dstName)))), prevNames :+ name) case (true, false) => val eRen = nestE(name) val dstV = nestV(dstName) - (Some(maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName))) - .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), + ( + Some( + maybeJoin(prev, eRen, prev => eRen(eSrcId(name)) === prev(vId(srcName))) + .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), prevNames :+ name :+ dstName) case (false, true) => val eRen = nestE(name) val srcV = nestV(srcName) - (Some(maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName))) - .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))), + ( + Some( + maybeJoin(prev, eRen, prev => eRen(eDstId(name)) === prev(vId(dstName))) + .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName)))), prevNames :+ srcName :+ name) case (false, false) if srcName != dstName => val eRen = nestE(name) val srcV = nestV(srcName) val dstV = nestV(dstName) - (Some(maybeCrossJoin(prev, eRen) - .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName))) - .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), + ( + Some( + maybeCrossJoin(prev, eRen) + .join(srcV, eRen(eSrcId(name)) === srcV(vId(srcName))) + .join(dstV, eRen(eDstId(name)) === dstV(vId(dstName)))), prevNames :+ srcName :+ name :+ dstName) // TODO: expose the plans from joining these in the opposite order case (false, false) if srcName == dstName => val eRen = nestE(name) val srcV = nestV(srcName) - (Some(maybeCrossJoin(prev, eRen) - .join(srcV, - eRen(eSrcId(name)) === srcV(vId(srcName)) && - eRen(eDstId(name)) === srcV(vId(srcName)))), + ( + Some( + maybeCrossJoin(prev, eRen) + .join( + srcV, + eRen(eSrcId(name)) === srcV(vId(srcName)) && + eRen(eDstId(name)) === srcV(vId(srcName)))), 
prevNames :+ srcName :+ name) } case AnonymousEdge(src, dst) => val tmpName = "__tmp" + random.nextLong.toString - val (df, names) = findIncremental(gf, prevPatterns, prev, prevNames, NamedEdge(tmpName, src, dst)) + val (df, names) = + findIncremental(gf, prevPatterns, prev, prevNames, NamedEdge(tmpName, src, dst)) (df.map(_.drop(tmpName)), names.filter(_ != tmpName)) - case Negation(edge) => prev match { - case Some(p) => - val (df, names) = findIncremental(gf, prevPatterns, Some(p), prevNames, edge) - (df.map(result => p.except(result)), names) - case None => - throw new InvalidPatternException - } + case Negation(edge) => + prev match { + case Some(p) => + val (df, names) = findIncremental(gf, prevPatterns, Some(p), prevNames, edge) + (df.map(result => p.except(result)), names) + case None => + throw new InvalidPatternException + } } } /** - * Controls broadcast threshold in skewed joins. - * Use normal joins for vertices with degrees less than the threshold, - * and broadcast joins otherwise. - * The default value is 1000000. - * If we have less than 100 billion edges, this would collect at most - * 2e11 / 1000000 = 200000 hubs, which could be handled by the driver. + * Controls broadcast threshold in skewed joins. Use normal joins for vertices with degrees less + * than the threshold, and broadcast joins otherwise. The default value is 1000000. If we have + * less than 100 billion edges, this would collect at most 2e11 / 1000000 = 200000 hubs, which + * could be handled by the driver. 
*/ private[this] var _broadcastThreshold: Int = 1000000 diff --git a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala index 9074bd13c..df97ba85a 100644 --- a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala +++ b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala @@ -26,45 +26,43 @@ import org.graphframes.GraphFrame import org.graphframes.examples.Graphs.gridIsingModel import org.graphframes.lib.AggregateMessages - /** * Example code for Belief Propagation (BP) * - * This provides a template for building customized BP algorithms for different types of - * graphical models. + * This provides a template for building customized BP algorithms for different types of graphical + * models. * * This example: - * - Ising model on a grid - * - Parallel Belief Propagation using colored fields + * - Ising model on a grid + * - Parallel Belief Propagation using colored fields * - * Ising models are probabilistic graphical models over binary variables x,,i,,. - * Each binary variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. - * The probability distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, - * and edge factors b,,ij,,: + * Ising models are probabilistic graphical models over binary variables x,,i,,. Each binary + * variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. The probability + * distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, and edge factors + * b,,ij,,: * {{{ * P(X) = (1/Z) * exp[ \sum_i a_i x_i + \sum_{ij} b_{ij} x_i x_j ] * }}} - * where Z is the normalization constant (partition function). - * See [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models. + * where Z is the normalization constant (partition function). 
See + * [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models. * * Belief Propagation (BP) provides marginal probabilities of the values of the variables x,,i,,, - * i.e., P(x,,i,,) for each i. This allows a user to understand likely values of variables. - * See [[https://en.wikipedia.org/wiki/Belief_propagation Wikipedia]] for more information on BP. + * i.e., P(x,,i,,) for each i. This allows a user to understand likely values of variables. See + * [[https://en.wikipedia.org/wiki/Belief_propagation Wikipedia]] for more information on BP. * * We use a batch synchronous BP algorithm, where batches of vertices are updated synchronously. * We follow the mean field update algorithm in Slide 13 of the - * [[http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial talk slides]] from: - * Wainwright. "Graphical models, message-passing algorithms, and convex optimization." + * [[http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial talk slides]] from: + * Wainwright. "Graphical models, message-passing algorithms, and convex optimization." * - * The batches are chosen according to a coloring. For background on graph colorings for inference, - * see for example: - * Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to Thin Junction Trees." - * AISTATS, 2011. + * The batches are chosen according to a coloring. For background on graph colorings for + * inference, see for example: Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to + * Thin Junction Trees." AISTATS, 2011. * * The BP algorithm works by: - * - Coloring the graph by assigning a color to each vertex such that no neighboring vertices - * share the same color. - * - In each step of BP, update all vertices of a single color. Alternate colors. + * - Coloring the graph by assigning a color to each vertex such that no neighboring vertices + * share the same color. + * - In each step of BP, update all vertices of a single color. 
Alternate colors. */ object BeliefPropagation { @@ -94,14 +92,16 @@ object BeliefPropagation { } /** - * Given a GraphFrame, choose colors for each vertex. No neighboring vertices will share the - * same color. The number of colors is minimized. + * Given a GraphFrame, choose colors for each vertex. No neighboring vertices will share the + * same color. The number of colors is minimized. * * This is written specifically for grid graphs. For non-grid graphs, it should be generalized, * such as by using a greedy coloring scheme. * - * @param g Grid graph generated by [[org.graphframes.examples.Graphs.gridIsingModel()]] - * @return Same graph, but with a new vertex column "color" of type Int (0 or 1) + * @param g + * Grid graph generated by [[org.graphframes.examples.Graphs.gridIsingModel()]] + * @return + * Same graph, but with a new vertex column "color" of type Int (0 or 1) */ private def colorGraph(g: GraphFrame): GraphFrame = { val colorUDF = udf { (i: Int, j: Int) => (i + j) % 2 } @@ -112,19 +112,22 @@ object BeliefPropagation { /** * Run Belief Propagation. * - * This implementation of BP shows how to use GraphX's aggregateMessages method. - * It is simple to convert to and from GraphX format. This method does the following: - * - Color GraphFrame vertices for BP scheduling. - * - Convert GraphFrame to GraphX format. - * - Run BP using GraphX's aggregateMessages API. - * - Augment the original GraphFrame with the BP results (vertex beliefs). + * This implementation of BP shows how to use GraphX's aggregateMessages method. It is simple to + * convert to and from GraphX format. This method does the following: + * - Color GraphFrame vertices for BP scheduling. + * - Convert GraphFrame to GraphX format. + * - Run BP using GraphX's aggregateMessages API. + * - Augment the original GraphFrame with the BP results (vertex beliefs). 
* - * @param g Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()` - * @param numIter Number of iterations of BP to run. One iteration includes updating each - * vertex's belief once. - * @return Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column - * "belief" containing P(x,,i,, = +1), the marginal probability of vertex i taking - * value +1 instead of -1. + * @param g + * Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()` + * @param numIter + * Number of iterations of BP to run. One iteration includes updating each vertex's belief + * once. + * @return + * Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column "belief" + * containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of + * -1. */ def runBPwithGraphX(g: GraphFrame, numIter: Int): GraphFrame = { // Choose colors for vertices for BP scheduling. @@ -166,15 +169,14 @@ object BeliefPropagation { }, _ + _) // Receive messages, and update beliefs for vertices of the current color. - gx = gx.outerJoinVertices(msgs) { - case (vID, vAttr, optMsg) => - if (vAttr.color == color) { - val x = vAttr.a + optMsg.getOrElse(0.0) - val newBelief = math.exp(-log1pExp(-x)) - VertexAttr(vAttr.a, newBelief, color) - } else { - vAttr - } + gx = gx.outerJoinVertices(msgs) { case (vID, vAttr, optMsg) => + if (vAttr.color == color) { + val x = vAttr.a + optMsg.getOrElse(0.0) + val newBelief = math.exp(-log1pExp(-x)) + VertexAttr(vAttr.a, newBelief, color) + } else { + vAttr + } } } } @@ -192,16 +194,19 @@ object BeliefPropagation { * Run Belief Propagation. * * This implementation of BP shows how to use GraphFrame's aggregateMessages method. - * - Color GraphFrame vertices for BP scheduling. - * - Run BP using GraphFrame's aggregateMessages API. - * - Augment the original GraphFrame with the BP results (vertex beliefs). + * - Color GraphFrame vertices for BP scheduling. 
+ * - Run BP using GraphFrame's aggregateMessages API. + * - Augment the original GraphFrame with the BP results (vertex beliefs). * - * @param g Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()` - * @param numIter Number of iterations of BP to run. One iteration includes updating each - * vertex's belief once. - * @return Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column - * "belief" containing P(x,,i,, = +1), the marginal probability of vertex i taking - * value +1 instead of -1. + * @param g + * Graphical model created by `org.graphframes.examples.Graphs.gridIsingModel()` + * @param numIter + * Number of iterations of BP to run. One iteration includes updating each vertex's belief + * once. + * @return + * Same graphical model, but with [[GraphFrame.vertices]] augmented with a new column "belief" + * containing P(x,,i,, = +1), the marginal probability of vertex i taking value +1 instead of + * -1. */ def runBPwithGraphFrames(g: GraphFrame, numIter: Int): GraphFrame = { // Choose colors for vertices for BP scheduling. @@ -231,15 +236,16 @@ object BeliefPropagation { .agg(sum(AM.msg).as("aggMess")) val v = gx.vertices // Receive messages, and update beliefs for vertices of the current color. 
- val newBeliefCol = when(v("color") === color && aggregates("aggMess").isNotNull, + val newBeliefCol = when( + v("color") === color && aggregates("aggMess").isNotNull, logistic(aggregates("aggMess") + v("a"))) - .otherwise(v("belief")) // keep old beliefs for other colors + .otherwise(v("belief")) // keep old beliefs for other colors val newVertices = v - .join(aggregates, v("id") === aggregates("id"), "left_outer") // join messages, vertices - .drop(aggregates("id")) // drop duplicate ID column (from outer join) - .withColumn("newBelief", newBeliefCol) // compute new beliefs - .drop("aggMess") // drop messages - .drop("belief") // drop old beliefs + .join(aggregates, v("id") === aggregates("id"), "left_outer") // join messages, vertices + .drop(aggregates("id")) // drop duplicate ID column (from outer join) + .withColumn("newBelief", newBeliefCol) // compute new beliefs + .drop("aggMess") // drop messages + .drop("belief") // drop old beliefs .withColumnRenamed("newBelief", "belief") // Cache new vertices using workaround for SPARK-13346 val cachedNewVertices = AM.getCachedDataFrame(newVertices) diff --git a/src/main/scala/org/graphframes/examples/Graphs.scala b/src/main/scala/org/graphframes/examples/Graphs.scala index 2ea19e525..5066fc2ce 100644 --- a/src/main/scala/org/graphframes/examples/Graphs.scala +++ b/src/main/scala/org/graphframes/examples/Graphs.scala @@ -43,13 +43,15 @@ class Graphs private[graphframes] () { } /** - * Returns a chain graph of the given size with Long ID type. - * The vertex IDs are 0, 1, ..., n-1, and the edges are (0, 1), (1, 2), ...., (n-2, n-1). + * Returns a chain graph of the given size with Long ID type. The vertex IDs are 0, 1, ..., n-1, + * and the edges are (0, 1), (1, 2), ...., (n-2, n-1). 
*/ def chain(n: Long): GraphFrame = { require(n >= 0, s"Chain graph size must be nonnegative but got $n.") val vertices = spark.range(n).toDF(ID) - val edges = spark.range(n - 1L).toDF(ID) + val edges = spark + .range(n - 1L) + .toDF(ID) .select(col(ID).as(SRC), (col(ID) + 1L).as(DST)) GraphFrame(vertices, edges) } @@ -60,33 +62,38 @@ class Graphs private[graphframes] () { def friends: GraphFrame = { // For the same reason as above, this cannot be a value. // Vertex DataFrame - val v = spark.createDataFrame(List( - ("a", "Alice", 34), - ("b", "Bob", 36), - ("c", "Charlie", 30), - ("d", "David", 29), - ("e", "Esther", 32), - ("f", "Fanny", 36), - ("g", "Gabby", 60) - )).toDF("id", "name", "age") + val v = spark + .createDataFrame( + List( + ("a", "Alice", 34), + ("b", "Bob", 36), + ("c", "Charlie", 30), + ("d", "David", 29), + ("e", "Esther", 32), + ("f", "Fanny", 36), + ("g", "Gabby", 60))) + .toDF("id", "name", "age") // Edge DataFrame - val e = spark.createDataFrame(List( - ("a", "b", "friend"), - ("b", "c", "follow"), - ("c", "b", "follow"), - ("f", "c", "follow"), - ("e", "f", "follow"), - ("e", "d", "friend"), - ("d", "a", "friend"), - ("a", "e", "friend") - )).toDF("src", "dst", "relationship") + val e = spark + .createDataFrame( + List( + ("a", "b", "friend"), + ("b", "c", "follow"), + ("c", "b", "follow"), + ("f", "c", "follow"), + ("e", "f", "follow"), + ("e", "d", "friend"), + ("d", "a", "friend"), + ("a", "e", "friend"))) + .toDF("src", "dst", "relationship") // Create a GraphFrame GraphFrame(v, e) } /** * Two densely connected blobs (vertices 0->n-1 and n->2n-1) connected by a single edge (0->n) - * @param blobSize the size of each blob. + * @param blobSize + * the size of each blob. 
* @return */ def twoBlobs(blobSize: Int): GraphFrame = { @@ -94,7 +101,8 @@ class Graphs private[graphframes] () { val edges1 = for (v1 <- 0 until n; v2 <- 0 until n) yield (v1.toLong, v2.toLong, s"$v1-$v2") val edges2 = for { v1 <- n until (2 * n) - v2 <- n until (2 * n) } yield (v1.toLong, v2.toLong, s"$v1-$v2") + v2 <- n until (2 * n) + } yield (v1.toLong, v2.toLong, s"$v1-$v2") val edges = edges1 ++ edges2 :+ (0L, n.toLong, s"0-$n") val vertices = (0 until (2 * n)).map { v => (v.toLong, s"$v", v) } val e = spark.createDataFrame(edges).toDF("src", "dst", "e_attr1") @@ -105,7 +113,8 @@ class Graphs private[graphframes] () { /** * Returns a star graph with Long ID type, consisting of a central element indexed 0 (the root) * and the n other leaf vertices 1, 2, ..., n. - * @param n the number of leaves + * @param n + * the number of leaves */ def star(n: Long): GraphFrame = { require(n >= 0L) @@ -127,7 +136,8 @@ class Graphs private[graphframes] () { (fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val edges = spark.createDataFrame(data).toDF("src", "dst", "weight") - val vs = data.flatMap(r => r._1 :: r._2 :: Nil).collect().distinct.map(x => Tuple1(x)).toIndexedSeq + val vs = + data.flatMap(r => r._1 :: r._2 :: Nil).collect().distinct.map(x => Tuple1(x)).toIndexedSeq val vertices = spark.createDataFrame(vs).toDF("id") GraphFrame(vertices, edges) } @@ -155,29 +165,32 @@ class Graphs private[graphframes] () { /** * This method generates a grid Ising model with random parameters. * - * Ising models are probabilistic graphical models over binary variables x,,i,,. - * Each binary variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. - * The probability distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, - * and edge factors b,,ij,,: + * Ising models are probabilistic graphical models over binary variables x,,i,,. 
Each binary + * variable x,,i,, corresponds to one vertex, and it may take values -1 or +1. The probability + * distribution P(X) (over all x,,i,,) is parameterized by vertex factors a,,i,, and edge + * factors b,,ij,,: * {{{ * P(X) = (1/Z) * exp[ \sum_i a_i x_i + \sum_{ij} b_{ij} x_i x_j ] * }}} * where Z is the normalization constant (partition function). See * [[https://en.wikipedia.org/wiki/Ising_model Wikipedia]] for more information on Ising models. * - * Each vertex is parameterized by a single scalar a,,i,,. - * Each edge is parameterized by a single scalar b,,ij,,. + * Each vertex is parameterized by a single scalar a,,i,,. Each edge is parameterized by a + * single scalar b,,ij,,. * - * @param n Length of one side of the grid. The grid will be of size n x n. - * @param vStd Standard deviation of normal distribution used to generate vertex factors "a". - * Default of 1.0. - * @param eStd Standard deviation of normal distribution used to generate edge factors "b". - * Default of 1.0. - * @return GraphFrame. Vertices have columns "id" and "a". - * Edges have columns "src", "dst", and "b". Edges are directed, but they should be - * treated as undirected in any algorithms run on this model. - * Vertex IDs are of the form "i,j". E.g., vertex "1,3" is in the second row and fourth - * column of the grid. + * @param n + * Length of one side of the grid. The grid will be of size n x n. + * @param vStd + * Standard deviation of normal distribution used to generate vertex factors "a". Default of + * 1.0. + * @param eStd + * Standard deviation of normal distribution used to generate edge factors "b". Default of + * 1.0. + * @return + * GraphFrame. Vertices have columns "id" and "a". Edges have columns "src", "dst", and "b". + * Edges are directed, but they should be treated as undirected in any algorithms run on this + * model. Vertex IDs are of the form "i,j". E.g., vertex "1,3" is in the second row and fourth + * column of the grid. 
*/ def gridIsingModel(spark: SparkSession, n: Int, vStd: Double, eStd: Double): GraphFrame = { require(n >= 1, s"Grid graph must have size >= 1, but was given invalid value n = $n") @@ -195,20 +208,23 @@ class Graphs private[graphframes] () { val vIDcol = toIDudf(col("i"), col("j")) // Add random parameters generated from a normal distribution val seed = 12345 - val vertices = coordinates.withColumn("id", vIDcol) // vertex IDs "i,j" - .withColumn("a", randn(seed) * vStd) // Ising parameter for vertex + val vertices = coordinates + .withColumn("id", vIDcol) // vertex IDs "i,j" + .withColumn("a", randn(seed) * vStd) // Ising parameter for vertex // Create the edge DataFrame // Create SQL expression for converting coordinates (i,j+1) and (i+1,j) to string IDs val rightIDcol = toIDudf(col("i"), col("j") + 1) val downIDcol = toIDudf(col("i") + 1, col("j")) - val horizontalEdges = coordinates.filter(col("j") =!= n - 1) + val horizontalEdges = coordinates + .filter(col("j") =!= n - 1) .select(vIDcol.as("src"), rightIDcol.as("dst")) - val verticalEdges = coordinates.filter(col("i") =!= n - 1) + val verticalEdges = coordinates + .filter(col("i") =!= n - 1) .select(vIDcol.as("src"), downIDcol.as("dst")) val allEdges = horizontalEdges.union(verticalEdges) // Add random parameters from a normal distribution - val edges = allEdges.withColumn("b", randn(seed + 1) * eStd) // Ising parameter for edge + val edges = allEdges.withColumn("b", randn(seed + 1) * eStd) // Ising parameter for edge // Create the GraphFrame val g = GraphFrame(vertices, edges) diff --git a/src/main/scala/org/graphframes/exceptions.scala b/src/main/scala/org/graphframes/exceptions.scala index e98ff84ee..fe2be9ba3 100644 --- a/src/main/scala/org/graphframes/exceptions.scala +++ b/src/main/scala/org/graphframes/exceptions.scala @@ -2,7 +2,6 @@ package org.graphframes // All the public exceptions thrown by GraphFrame methods - /** * Exception thrown when a pattern String for motif finding cannot be parsed. 
*/ @@ -17,4 +16,4 @@ class NoSuchVertexException(message: String) extends Exception(message) * Exception thrown when a parsed pattern for motif finding cannot be translated into a DataFrame * query. */ -class InvalidPatternException() extends Exception() \ No newline at end of file +class InvalidPatternException() extends Exception() diff --git a/src/main/scala/org/graphframes/lib/AggregateMessages.scala b/src/main/scala/org/graphframes/lib/AggregateMessages.scala index 4c9efec1d..c3f721b21 100644 --- a/src/main/scala/org/graphframes/lib/AggregateMessages.scala +++ b/src/main/scala/org/graphframes/lib/AggregateMessages.scala @@ -23,41 +23,41 @@ import org.apache.spark.sql.{Column, DataFrame} import org.graphframes.{GraphFrame, Logging} /** - * This is a primitive for implementing graph algorithms. - * This method aggregates messages from the neighboring edges and vertices of each vertex. + * This is a primitive for implementing graph algorithms. This method aggregates messages from the + * neighboring edges and vertices of each vertex. * - * For each triplet (source vertex, edge, destination vertex) in [[GraphFrame.triplets]], - * this can send a message to the source and/or destination vertices. - * - `AggregateMessages.sendToSrc()` sends a message to the source vertex of each - * triplet - * - `AggregateMessages.sendToDst()` sends a message to the destination vertex of each - * triplet - * - `AggregateMessages.agg` specifies an aggregation function for aggregating the - * messages sent to each vertex. It also runs the aggregation, computing a DataFrame - * with one row for each vertex which receives > 0 messages. The DataFrame has 2 columns: + * For each triplet (source vertex, edge, destination vertex) in [[GraphFrame.triplets]], this can + * send a message to the source and/or destination vertices. 
+ * - `AggregateMessages.sendToSrc()` sends a message to the source vertex of each triplet + * - `AggregateMessages.sendToDst()` sends a message to the destination vertex of each triplet + * - `AggregateMessages.agg` specifies an aggregation function for aggregating the messages sent + * to each vertex. It also runs the aggregation, computing a DataFrame with one row for each + * vertex which receives > 0 messages. The DataFrame has 2 columns: * - vertex column ID (named [[GraphFrame.ID]]) - * - aggregate from messages sent to vertex (with the name given to the `Column` specified - * in `AggregateMessages.agg()`) + * - aggregate from messages sent to vertex (with the name given to the `Column` specified in + * `AggregateMessages.agg()`) * * When specifying the messages and aggregation function, the user may reference columns using: - * - [[AggregateMessages.src]]: column for source vertex of edge - * - [[AggregateMessages.edge]]: column for edge - * - [[AggregateMessages.dst]]: column for destination vertex of edge - * - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function) + * - [[AggregateMessages.src]]: column for source vertex of edge + * - [[AggregateMessages.edge]]: column for edge + * - [[AggregateMessages.dst]]: column for destination vertex of edge + * - [[AggregateMessages.msg]]: message sent to vertex (for aggregation function) * * Note: If you use this operation to write an iterative algorithm, you may want to use * [[AggregateMessages$.getCachedDataFrame getCachedDataFrame()]] as a workaround for caching * issues. 
* - * @example We can use this function to compute the in-degree of each vertex - * {{{ + * @example + * We can use this function to compute the in-degree of each vertex + * {{{ * val g: GraphFrame = Graph.textFile("twittergraph") * val inDeg: DataFrame = * g.aggregateMessages().sendToDst(lit(1)).agg(sum(AggregateMessagesBuilder.msg)) - * }}} + * }}} */ class AggregateMessages private[graphframes] (private val g: GraphFrame) - extends Arguments with Serializable { + extends Arguments + with Serializable { import org.graphframes.GraphFrame.{DST, ID, SRC} @@ -84,15 +84,14 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame) def sendToDst(value: String): this.type = sendToDst(expr(value)) /** - * Run the aggregation, returning the resulting DataFrame of aggregated messages. - * This is a lazy operation, so the DataFrame will not be materialized until an action is - * executed on it. + * Run the aggregation, returning the resulting DataFrame of aggregated messages. This is a lazy + * operation, so the DataFrame will not be materialized until an action is executed on it. * * This returns a DataFrame with schema: - * - column "id": vertex ID - * - aggCol: aggregate result - * If you need to join this with the original [[GraphFrame.vertices]], you can run an inner - * join of the form: + * - column "id": vertex ID + * - aggCol: aggregate result + * If you need to join this with the original [[GraphFrame.vertices]], you can run an inner join + * of the form: * {{{ * val g: GraphFrame = ... * val aggResult = g.AggregateMessagesBuilder.sendToSrc(msg).agg(aggFunc) @@ -100,22 +99,24 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame) * }}} */ def agg(aggCol: Column): DataFrame = { - require(msgToSrc.nonEmpty || msgToDst.nonEmpty, s"To run GraphFrame.aggregateMessages," + - s" messages must be sent to src, dst, or both. 
Set using sendToSrc(), sendToDst().") + require( + msgToSrc.nonEmpty || msgToDst.nonEmpty, + s"To run GraphFrame.aggregateMessages," + + s" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().") val triplets = g.triplets val sentMsgsToSrc = msgToSrc.map { msg => - val msgsToSrc = triplets.select( - msg.as(AggregateMessages.MSG_COL_NAME), - triplets(SRC)(ID).as(ID)) + val msgsToSrc = + triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(SRC)(ID).as(ID)) // Inner join: only send messages to vertices with edges - msgsToSrc.join(g.vertices, ID) + msgsToSrc + .join(g.vertices, ID) .select(msgsToSrc(AggregateMessages.MSG_COL_NAME), col(ID)) } val sentMsgsToDst = msgToDst.map { msg => - val msgsToDst = triplets.select( - msg.as(AggregateMessages.MSG_COL_NAME), - triplets(DST)(ID).as(ID)) - msgsToDst.join(g.vertices, ID) + val msgsToDst = + triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(DST)(ID).as(ID)) + msgsToDst + .join(g.vertices, ID) .select(msgsToDst(AggregateMessages.MSG_COL_NAME), col(ID)) } val unionedMsgs = (sentMsgsToSrc, sentMsgsToDst) match { @@ -158,12 +159,12 @@ object AggregateMessages extends Logging with Serializable { /** * Create a new cached copy of a DataFrame. For iterative DataFrame-based algorithms. * - * WARNING: This is NOT the same as `DataFrame.cache()`. - * The original DataFrame will NOT be cached. + * WARNING: This is NOT the same as `DataFrame.cache()`. The original DataFrame will NOT be + * cached. * * This is a workaround for SPARK-13346, which makes it difficult to use DataFrames in iterative - * algorithms. This workaround converts the DataFrame to an RDD, caches the RDD, and creates - * a new DataFrame. This is important for avoiding the creation of extremely complex DataFrame + * algorithms. This workaround converts the DataFrame to an RDD, caches the RDD, and creates a + * new DataFrame. 
This is important for avoiding the creation of extremely complex DataFrame * query plans when using DataFrames in iterative algorithms. */ def getCachedDataFrame(df: DataFrame): DataFrame = { diff --git a/src/main/scala/org/graphframes/lib/BFS.scala b/src/main/scala/org/graphframes/lib/BFS.scala index 80ebe313e..802d8f63c 100644 --- a/src/main/scala/org/graphframes/lib/BFS.scala +++ b/src/main/scala/org/graphframes/lib/BFS.scala @@ -27,64 +27,59 @@ import org.graphframes.GraphFrame.nestAsCol /** * Breadth-first search (BFS) * - * This method returns a DataFrame of valid shortest paths from vertices matching `fromExpr` - * to vertices matching `toExpr`. If multiple paths are valid and have the same length, - * the DataFrame will return one Row for each path. If no paths are valid, the DataFrame will - * be empty. - * Note: "Shortest" means globally shortest path. I.e., if the shortest path between two vertices + * This method returns a DataFrame of valid shortest paths from vertices matching `fromExpr` to + * vertices matching `toExpr`. If multiple paths are valid and have the same length, the DataFrame + * will return one Row for each path. If no paths are valid, the DataFrame will be empty. Note: + * "Shortest" means globally shortest path. I.e., if the shortest path between two vertices * matching `fromExpr` and `toExpr` is length 5 (edges) but no path is shorter than 5, then all * paths returned by BFS will have length 5. * * The returned DataFrame will have the following columns: - * - `from` start vertex of path - * - `e[i]` edge i in the path, indexed from 0 - * - `v[i]` intermediate vertex i in the path, indexed from 1 - * - `to` end vertex of path + * - `from` start vertex of path + * - `e[i]` edge i in the path, indexed from 0 + * - `v[i]` intermediate vertex i in the path, indexed from 1 + * - `to` end vertex of path * Each of these columns is a StructType whose fields are the same as the columns of * [[GraphFrame.vertices]] or [[GraphFrame.edges]]. 
* - * For example, suppose we have a graph g. Say the vertices DataFrame of g has columns "id" and + * For example, suppose we have a graph g. Say the vertices DataFrame of g has columns "id" and * "job", and the edges DataFrame of g has columns "src", "dst", and "relation". * {{{ * // Search from vertex "Joe" to find the closet vertices with attribute job = CEO. * g.bfs(col("id") === "Joe", col("job") === "CEO").run() * }}} * If we found a path of 3 edges, each row would have columns: - * {{{from | e0 | v1 | e1 | v2 | e2 | to}}} - * In the above row, each vertex column (from, v1, v2, to) would have fields "id" and "job" - * (just like g.vertices). - * Each edge column (e0, e1, e2) would have fields "src", "dst", and "relation". + * {{{from | e0 | v1 | e1 | v2 | e2 | to}}} In the above row, each vertex column (from, v1, v2, + * to) would have fields "id" and "job" (just like g.vertices). Each edge column (e0, e1, e2) + * would have fields "src", "dst", and "relation". * * If there are ties, then each of the equal paths will be returned as a separate Row. * - * If one or more vertices match both the from and to conditions, then there is a 0-hop path. - * The returned DataFrame will have the "from" and "to" columns (as above); however, - * the "from" and "to" columns will be exactly the same. There will be one row for each vertex - * in [[GraphFrame.vertices]] matching both `fromExpr` and `toExpr`. + * If one or more vertices match both the from and to conditions, then there is a 0-hop path. The + * returned DataFrame will have the "from" and "to" columns (as above); however, the "from" and + * "to" columns will be exactly the same. There will be one row for each vertex in + * [[GraphFrame.vertices]] matching both `fromExpr` and `toExpr`. * * Parameters: * - * - `fromExpr` Spark SQL expression specifying valid starting vertices for the BFS. - * This condition will be matched against each vertex's id or attributes. 
- * To start from a specific vertex, this could be "id = [start vertex id]". - * To start from multiple valid vertices, this can operate on vertex attributes. - * - * - `toExpr` Spark SQL expression specifying valid target vertices for the BFS. - * This condition will be matched against each vertex's id or attributes. - * - * - `maxPathLength` Limit on the length of paths. If no valid paths of length - * <= maxPathLength are found, then the BFS is terminated. - * (default = 10) - * - `edgeFilter` Spark SQL expression specifying edges which may be used in the search. - * This allows the user to disallow crossing certain edges. Such filters - * can be applied post-hoc after BFS, run specifying the filter here is more - * efficient. + * - `fromExpr` Spark SQL expression specifying valid starting vertices for the BFS. This + * condition will be matched against each vertex's id or attributes. To start from a specific + * vertex, this could be "id = [start vertex id]". To start from multiple valid vertices, this + * can operate on vertex attributes. + * - `toExpr` Spark SQL expression specifying valid target vertices for the BFS. This condition + * will be matched against each vertex's id or attributes. + * - `maxPathLength` Limit on the length of paths. If no valid paths of length <= maxPathLength + * are found, then the BFS is terminated. (default = 10) + * - `edgeFilter` Spark SQL expression specifying edges which may be used in the search. This + * allows the user to disallow crossing certain edges. Such filters can be applied post-hoc + * after BFS, run specifying the filter here is more efficient. 
* * Returns: - * - DataFrame of valid shortest paths found in the BFS + * - DataFrame of valid shortest paths found in the BFS */ class BFS private[graphframes] (private val graph: GraphFrame) - extends Arguments with Serializable { + extends Arguments + with Serializable { private var maxPathLength: Int = 10 private var edgeFilter: Option[Column] = None @@ -125,7 +120,6 @@ class BFS private[graphframes] (private val graph: GraphFrame) } } - private object BFS extends Logging with Serializable { private def run( @@ -147,7 +141,8 @@ private object BFS extends Logging with Serializable { if (fromEqualsToDF.take(1).nonEmpty) { // from == to, so return matching vertices return fromEqualsToDF.select( - nestAsCol(fromEqualsToDF, "from"), nestAsCol(fromEqualsToDF, "to")) + nestAsCol(fromEqualsToDF, "from"), + nestAsCol(fromEqualsToDF, "to")) } // We handled edge cases above, so now we do BFS. @@ -179,15 +174,20 @@ private object BFS extends Logging with Serializable { if (iter == 0) { // Note: We could avoid this special case by initializing paths with just 1 "from" column, // but that would create a longer lineage for the result DataFrame. 
- paths = a2b.filter(fromAExpr) - .filter(col("a.id") =!= col("b.id")) // remove self-loops - .withColumnRenamed("a", "from").withColumnRenamed("e", nextEdge) + paths = a2b + .filter(fromAExpr) + .filter(col("a.id") =!= col("b.id")) // remove self-loops + .withColumnRenamed("a", "from") + .withColumnRenamed("e", nextEdge) .withColumnRenamed("b", nextVertex) } else { val prevVertex = s"v$iter" - val nextLinks = a2b.withColumnRenamed("a", prevVertex).withColumnRenamed("e", nextEdge) + val nextLinks = a2b + .withColumnRenamed("a", prevVertex) + .withColumnRenamed("e", nextEdge) .withColumnRenamed("b", nextVertex) - paths = paths.join(nextLinks, paths(prevVertex + ".id") === nextLinks(prevVertex + ".id")) + paths = paths + .join(nextLinks, paths(prevVertex + ".id") === nextLinks(prevVertex + ".id")) .drop(paths(prevVertex)) // Make sure we are not backtracking within each path. // TODO: Avoid crossing paths; i.e., touch each vertex at most once. @@ -222,24 +222,24 @@ private object BFS extends Logging with Serializable { } else { logInfo(s"GraphFrame.bfs failed to find a path of length <= $maxPathLength.") // Return empty DataFrame - g.spark.createDataFrame( - g.spark.sparkContext.parallelize(Seq.empty[Row]), - g.vertices.schema) + g.spark.createDataFrame(g.spark.sparkContext.parallelize(Seq.empty[Row]), g.vertices.schema) } } - /** - * Apply the given SQL expression (such as `id = 3`) to the field in a column, - * rather than to the column itself. + * Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to + * the column itself. 
* - * @param expr SQL expression, such as `id = 3` - * @param colName Column name, such as `myVertex` - * @return SQL expression applied to the column fields, such as `myVertex.id = 3` + * @param expr + * SQL expression, such as `id = 3` + * @param colName + * Column name, such as `myVertex` + * @return + * SQL expression applied to the column fields, such as `myVertex.id = 3` */ private def applyExprToCol(expr: Column, colName: String) = { - new Column(expr.expr.transform { - case UnresolvedAttribute(nameParts) => UnresolvedAttribute(colName +: nameParts) + new Column(expr.expr.transform { case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(colName +: nameParts) }) } } diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index dbc1d98f7..68c13b85b 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -36,21 +36,22 @@ import org.graphframes.{GraphFrame, Logging} * information with each vertex assigned a component ID. * * The resulting DataFrame contains all the vertex information and one additional column: - * - component (`LongType`): unique ID for this component + * - component (`LongType`): unique ID for this component */ -class ConnectedComponents private[graphframes] ( - private val graph: GraphFrame) extends Arguments with Logging { +class ConnectedComponents private[graphframes] (private val graph: GraphFrame) + extends Arguments + with Logging { import org.graphframes.lib.ConnectedComponents._ private var broadcastThreshold: Int = 1000000 /** - * Sets broadcast threshold in propagating component assignments (default: 1000000). - * If a node degree is greater than this threshold at some iteration, its component assignment - * will be collected and then broadcasted back to propagate the assignment to its neighbors. 
- * Otherwise, the assignment propagation is done by a normal Spark join. - * This parameter is only used when the algorithm is set to "graphframes". + * Sets broadcast threshold in propagating component assignments (default: 1000000). If a node + * degree is greater than this threshold at some iteration, its component assignment will be + * collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise, + * the assignment propagation is done by a normal Spark join. This parameter is only used when + * the algorithm is set to "graphframes". */ def setBroadcastThreshold(value: Int): this.type = { require(value >= 0, s"Broadcast threshold must be non-negative but got $value.") @@ -65,24 +66,27 @@ class ConnectedComponents private[graphframes] ( /** * Gets broadcast threshold in propagating component assignment. - * @see [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]] + * @see + * [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]] */ def getBroadcastThreshold: Int = broadcastThreshold private var algorithm: String = ALGO_GRAPHFRAMES /** - * Sets the connected components algorithm to use (default: "graphframes"). - * Supported algorithms are: + * Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms + * are: * - "graphframes": Uses alternating large star and small star iterations proposed in * [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]] * with skewed join optimization. * - "graphx": Converts the graph to a GraphX graph and then uses the connected components * implementation in GraphX. 
- * @see [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]] + * @see + * [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]] */ def setAlgorithm(value: String): this.type = { - require(supportedAlgorithms.contains(value), + require( + supportedAlgorithms.contains(value), s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.") algorithm = value this @@ -90,26 +94,26 @@ class ConnectedComponents private[graphframes] ( /** * Gets the connected component algorithm to use. - * @see [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]. + * @see + * [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]. */ def getAlgorithm: String = algorithm private var checkpointInterval: Int = 2 /** - * Sets checkpoint interval in terms of number of iterations (default: 2). - * Checkpointing regularly helps recover from failures, clean shuffle files, shorten the - * lineage of the computation graph, and reduce the complexity of plan optimization. - * As of Spark 2.0, the complexity of plan optimization would grow exponentially without - * checkpointing. - * Hence disabling or setting longer-than-default checkpoint intervals are not recommended. - * Checkpoint data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with - * prefix "connected-components". - * If the checkpoint directory is not set, this throws a `java.io.IOException`. - * Set a nonpositive value to disable checkpointing. - * This parameter is only used when the algorithm is set to "graphframes". - * Its default value might change in the future. - * @see `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc + * Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing + * regularly helps recover from failures, clean shuffle files, shorten the lineage of the + * computation graph, and reduce the complexity of plan optimization. 
As of Spark 2.0, the + * complexity of plan optimization would grow exponentially without checkpointing. Hence + * disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint + * data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix + * "connected-components". If the checkpoint directory is not set, this throws a + * `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is + * only used when the algorithm is set to "graphframes". Its default value might change in the + * future. + * @see + * `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc */ def setCheckpointInterval(value: Int): this.type = { if (value <= 0 || value > 2) { @@ -128,14 +132,16 @@ class ConnectedComponents private[graphframes] ( /** * Gets checkpoint interval. - * @see [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]] + * @see + * [[org.graphframes.lib.ConnectedComponents.setCheckpointInterval]] */ def getCheckpointInterval: Int = checkpointInterval private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK /** - * Sets storage level for intermediate datasets that require multiple passes (default: ``MEMORY_AND_DISK``). + * Sets storage level for intermediate datasets that require multiple passes (default: + * ``MEMORY_AND_DISK``). */ def setIntermediateStorageLevel(value: StorageLevel): this.type = { intermediateStorageLevel = value @@ -151,7 +157,8 @@ class ConnectedComponents private[graphframes] ( * Runs the algorithm. */ def run(): DataFrame = { - ConnectedComponents.run(graph, + ConnectedComponents.run( + graph, algorithm = algorithm, broadcastThreshold = broadcastThreshold, checkpointInterval = checkpointInterval, @@ -173,21 +180,20 @@ object ConnectedComponents extends Logging { private val ALGO_GRAPHFRAMES = "graphframes" /** - * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: "graphframes" - * and "graphx". 
+ * Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]: + * "graphframes" and "graphx". */ val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES) /** * Returns the symmetric directed graph of the graph specified by input edges. - * @param ee non-bidirectional edges + * @param ee + * non-bidirectional edges */ private def symmetrize(ee: DataFrame): DataFrame = { val EDGE = "_edge" - ee.select(explode(array( - struct(col(SRC), col(DST)), - struct(col(DST).as(SRC), col(SRC).as(DST))) - ).as(EDGE)) + ee.select(explode( + array(struct(col(SRC), col(DST)), struct(col(DST).as(SRC), col(SRC).as(DST)))).as(EDGE)) .select(col(s"$EDGE.$SRC").as(SRC), col(s"$EDGE.$DST").as(DST)) } @@ -208,11 +214,12 @@ object ConnectedComponents extends Logging { // TODO: This assignment job might fail if the graph is skewed. val vertices = graph.indexedVertices .select(col(LONG_ID).as(ID), col(ATTR)) - // TODO: confirm the contract for a graph and decide whether we need distinct here - // .distinct() + // TODO: confirm the contract for a graph and decide whether we need distinct here + // .distinct() val edges = graph.indexedEdges .select(col(LONG_SRC).as(SRC), col(LONG_DST).as(DST)) - val orderedEdges = edges.filter(col(SRC) =!= col(DST)) + val orderedEdges = edges + .filter(col(SRC) =!= col(DST)) .select(minValue(col(SRC), col(DST)).as(SRC), maxValue(col(SRC), col(DST)).as(DST)) .distinct() GraphFrame(vertices, orderedEdges) @@ -226,7 +233,8 @@ object ConnectedComponents extends Logging { */ private def minNbrs(ee: DataFrame): DataFrame = { symmetrize(ee) - .groupBy(SRC).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) + .groupBy(SRC) + .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) .withColumn(MIN_NBR, minValue(col(SRC), col(MIN_NBR))) } @@ -239,8 +247,8 @@ object ConnectedComponents extends Logging { } /** - * Performs a possibly skewed join between edges and current component assignments. 
- * The skew join is done by broadcast join for frequent keys and normal join for the rest. + * Performs a possibly skewed join between edges and current component assignments. The skew + * join is done by broadcast join for frequent keys and normal join for the rest. */ private def skewedJoin( edges: DataFrame, @@ -248,7 +256,8 @@ object ConnectedComponents extends Logging { broadcastThreshold: Int, logPrefix: String): DataFrame = { import edges.sparkSession.implicits._ - val hubs = minNbrs.filter(col(CNT) > broadcastThreshold) + val hubs = minNbrs + .filter(col(CNT) > broadcastThreshold) .select(SRC) .as[Long] .collect() @@ -264,7 +273,8 @@ object ConnectedComponents extends Logging { } private def runGraphX(graph: GraphFrame): DataFrame = { - val components = org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX) + val components = + org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX) GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices } @@ -274,7 +284,8 @@ object ConnectedComponents extends Logging { broadcastThreshold: Int, checkpointInterval: Int, intermediateStorageLevel: StorageLevel): DataFrame = { - require(supportedAlgorithms.contains(algorithm), + require( + supportedAlgorithms.contains(algorithm), s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.") if (algorithm == ALGO_GRAPHX) { @@ -295,12 +306,14 @@ object ConnectedComponents extends Logging { val shouldCheckpoint = checkpointInterval > 0 val checkpointDir: Option[String] = if (shouldCheckpoint) { - val dir = sc.getCheckpointDir.map { d => - new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString - }.getOrElse { - throw new IOException( - "Checkpoint directory is not set. 
Please set it first using sc.setCheckpointDir().") - } + val dir = sc.getCheckpointDir + .map { d => + new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString + } + .getOrElse { + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().") + } logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") Some(dir) } else { @@ -323,13 +336,15 @@ object ConnectedComponents extends Logging { // Taking the sum in DecimalType to preserve precision. // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. // It should be able to handle 200 billion edges without overflow. - val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd + val (minNbrSum, cnt) = minNbrsDF + .select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")) + .rdd .map { r => (r.getAs[BigDecimal](0), r.getLong(1)) - }.first() + } + .first() if (cnt != 0L && minNbrSum == null) { - throw new ArithmeticException( - s""" + throw new ArithmeticException(s""" |The total sum of edge src IDs is used to determine convergence during iterations. |However, the total sum at iteration $iteration exceeded 30 digits (1e30), |which should happen only if the graph contains more than 200 billion edges. 
@@ -357,7 +372,9 @@ object ConnectedComponents extends Logging { // small-star step // compute min neighbors (excluding self-min) - val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr + val minNbrs2 = ee + .groupBy(col(SRC)) + .agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr .persist(intermediateStorageLevel) currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 @@ -366,7 +383,8 @@ object ConnectedComponents extends Logging { .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst .filter(col(SRC) =!= col(DST)) // src < dst // connect self to the min neighbor - ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst + ee = ee + .union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst .distinct() // checkpointing @@ -406,7 +424,7 @@ object ConnectedComponents extends Logging { // materialize all persisted DataFrames in current round, // then we can unpersist last round persisted DataFrames. for (persisted_df <- currRoundPersistedDFs) { - persisted_df.count() // materialize it. + persisted_df.count() // materialize it. } for (persisted_df <- lastRoundPersistedDFs) { persisted_df.unpersist() diff --git a/src/main/scala/org/graphframes/lib/GraphXConversions.scala b/src/main/scala/org/graphframes/lib/GraphXConversions.scala index 13f9beb65..0143f0b7b 100644 --- a/src/main/scala/org/graphframes/lib/GraphXConversions.scala +++ b/src/main/scala/org/graphframes/lib/GraphXConversions.scala @@ -27,28 +27,28 @@ import org.apache.spark.sql.types.{StructField, StructType} import org.graphframes.{NoSuchVertexException, GraphFrame} /** - * Convenience functions to map GraphX graphs to GraphFrames, - * checking for the types expected by GraphX. + * Convenience functions to map GraphX graphs to GraphFrames, checking for the types expected by + * GraphX. 
*/ private[graphframes] object GraphXConversions { import GraphFrame._ /** Indicates if T is a Unit type */ - private def isUnitType[T : TypeTag] : Boolean = { + private def isUnitType[T: TypeTag]: Boolean = { val t = typeOf[T] typeOf[Unit] =:= t } /** Indicates if T is a Product type */ - private def isProductType[T : TypeTag] : Boolean = { + private def isProductType[T: TypeTag]: Boolean = { val t = typeOf[T] // See http://stackoverflow.com/questions/21209006/how-to-check-if-reflected-type-represents-a-tuple t.typeSymbol.fullName.startsWith("scala.Tuple") } /** See [[GraphFrame.fromGraphX()]] for documentation */ - def fromGraphX[V : TypeTag, E : TypeTag]( + def fromGraphX[V: TypeTag, E: TypeTag]( originalGraph: GraphFrame, graph: Graph[V, E], vertexNames: Seq[String] = Nil, @@ -64,7 +64,7 @@ private[graphframes] object GraphXConversions { renameStructFields(vertexDF0, GX_ATTR, vertexNames) } else { // Assume it is just one field, and pack it in a tuple to have a structure. - val vertexData = graph.vertices.map { case (vid, data) => (vid, Tuple1(data)) } + val vertexData = graph.vertices.map { case (vid, data) => (vid, Tuple1(data)) } val vertexDF0 = spark.createDataFrame(vertexData).toDF(LONG_ID, GX_ATTR) renameStructFields(vertexDF0, GX_ATTR, vertexNames) } @@ -85,12 +85,14 @@ private[graphframes] object GraphXConversions { } /** - * Given the name of a column (assumed to contain a struct), - * renames all the fields of this struct. + * Given the name of a column (assumed to contain a struct), renames all the fields of this + * struct. * - * @param structName Struct name whose fields will be renamed. This method assumes this field - * exists and will not check for errors. - * @param fieldNames List of new field names corresponding to all fields in the struct col. + * @param structName + * Struct name whose fields will be renamed. This method assumes this field exists and will + * not check for errors. 
+ * @param fieldNames + * List of new field names corresponding to all fields in the struct col. */ private[lib] def renameStructFields( df: DataFrame, @@ -127,12 +129,13 @@ private[graphframes] object GraphXConversions { } /** - * Joins all the data from the original columns against the new data. Assumes the columns - * are not going to conflict. + * Joins all the data from the original columns against the new data. Assumes the columns are + * not going to conflict. * - * @param gxVertexData DataFrame with column [[LONG_ID]] and optionally column [[GX_ATTR]] - * @param gxEdgeData DataFrame with columns [[LONG_DST]], [[LONG_SRC]] and optionally column - * [[GX_ATTR]] + * @param gxVertexData + * DataFrame with column [[LONG_ID]] and optionally column [[GX_ATTR]] + * @param gxEdgeData + * DataFrame with columns [[LONG_DST]], [[LONG_SRC]] and optionally column [[GX_ATTR]] */ private def fromGraphX( originalGraph: GraphFrame, @@ -146,16 +149,19 @@ private[graphframes] object GraphXConversions { val indexedEdges = originalGraph.indexedEdges // Handle 2 cases: GraphX edge has attr, or not. val hasGxAttr = gxEdgeData.schema.exists(_.name == GX_ATTR) - val gxCol = if (hasGxAttr) { Seq(col(GX_ATTR)) } else { Seq() } + val gxCol = if (hasGxAttr) { Seq(col(GX_ATTR)) } + else { Seq() } val sel1 = Seq(col(LONG_SRC), col(LONG_DST)) ++ gxCol val gxe = gxEdgeData.select(sel1: _*) val sel3 = Seq(col(ATTR)) ++ gxCol // TODO: CHECK IN UNIT TESTS: Drop the src and dst columns from the index, they are already // in the attributes and will be unpacked with the rest of the user columns. // TODO(tjh) 2-step join? 
- gxe.join( - indexedEdges.select(indexedEdges(LONG_SRC), indexedEdges(LONG_DST), indexedEdges(ATTR)), - (gxe(LONG_SRC) === indexedEdges(LONG_SRC)) && (gxe(LONG_DST) === indexedEdges(LONG_DST))) + gxe + .join( + indexedEdges.select(indexedEdges(LONG_SRC), indexedEdges(LONG_DST), indexedEdges(ATTR)), + (gxe(LONG_SRC) === indexedEdges(LONG_SRC)) && (gxe(LONG_DST) === indexedEdges( + LONG_DST))) .select(sel3: _*) } val edgeDF = unpackStructFields(drop(packedEdges, LONG_SRC, LONG_DST)) @@ -164,8 +170,8 @@ private[graphframes] object GraphXConversions { } /** - * Given a graph and an object, gets the the corresponding integral id in the - * internal representation. + * Given a graph and an object, gets the the corresponding integral id in the internal + * representation. */ private[graphframes] def integralId(graph: GraphFrame, vertexId: Any): Long = { // Check if we can directly convert it @@ -179,10 +185,12 @@ private[graphframes] object GraphXConversions { // If the vertex is a non-integral type such as a String, we need to use the translation table. val longIdRow: Array[Row] = graph.indexedVertices .filter(col(GraphFrame.ID) === vertexId) - .select(GraphFrame.LONG_ID).take(1) + .select(GraphFrame.LONG_ID) + .take(1) if (longIdRow.isEmpty) { - throw new NoSuchVertexException(s"GraphFrame algorithm given vertex ID which does not exist" + - s" in Graph. Vertex ID $vertexId not contained in $graph") + throw new NoSuchVertexException( + s"GraphFrame algorithm given vertex ID which does not exist" + + s" in Graph. 
Vertex ID $vertexId not contained in $graph") } // TODO(tjh): could do more informative message longIdRow.head.getLong(0) diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala index a740dd5fd..877d7345b 100644 --- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -30,19 +30,19 @@ import org.graphframes.GraphFrame * affiliation of incoming messages. * * LPA is a standard community detection algorithm for graphs. It is very inexpensive - * computationally, although (1) convergence is not guaranteed and (2) one can end up with - * trivial solutions (all nodes are identified into a single community). + * computationally, although (1) convergence is not guaranteed and (2) one can end up with trivial + * solutions (all nodes are identified into a single community). * * The resulting DataFrame contains all the original vertex information and one additional column: - * - label (`LongType`): label of community affiliation + * - label (`LongType`): label of community affiliation */ class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments { private var maxIter: Option[Int] = None /** - * The max number of iterations of LPA to be performed. Because this is a static - * implementation, the algorithm will run for exactly this many iterations. + * The max number of iterations of LPA to be performed. Because this is a static implementation, + * the algorithm will run for exactly this many iterations. 
*/ def maxIter(value: Int): this.type = { maxIter = Some(value) @@ -50,13 +50,10 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame) exte } def run(): DataFrame = { - LabelPropagation.run( - graph, - check(maxIter, "maxIter")) + LabelPropagation.run(graph, check(maxIter, "maxIter")) } } - private object LabelPropagation { private def run(graph: GraphFrame, maxIter: Int): DataFrame = { val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter) diff --git a/src/main/scala/org/graphframes/lib/PageRank.scala b/src/main/scala/org/graphframes/lib/PageRank.scala index 2c9a9cf2e..b5bcfa53a 100644 --- a/src/main/scala/org/graphframes/lib/PageRank.scala +++ b/src/main/scala/org/graphframes/lib/PageRank.scala @@ -24,9 +24,9 @@ import org.graphframes.GraphFrame /** * PageRank algorithm implementation. There are two implementations of PageRank. * - * The first one uses the `org.apache.spark.graphx.graph` interface with `aggregateMessages` and runs - * PageRank for a fixed number of iterations. This can be executed by setting `maxIter`. Conceptually, - * the algorithm does the following: + * The first one uses the `org.apache.spark.graphx.graph` interface with `aggregateMessages` and + * runs PageRank for a fixed number of iterations. This can be executed by setting `maxIter`. + * Conceptually, the algorithm does the following: * {{{ * var PR = Array.fill(n)( 1.0 ) * val oldPR = Array.fill(n)( 1.0 ) @@ -38,8 +38,9 @@ import org.graphframes.GraphFrame * } * }}} * - * The second implementation uses the `org.apache.spark.graphx.Pregel` interface and runs PageRank until - * convergence and this can be run by setting `tol`. Conceptually, the algorithm does the following: + * The second implementation uses the `org.apache.spark.graphx.Pregel` interface and runs PageRank + * until convergence and this can be run by setting `tol`. 
Conceptually, the algorithm does the + * following: * {{{ * var PR = Array.fill(n)( 1.0 ) * val oldPR = Array.fill(n)( 0.0 ) @@ -51,26 +52,24 @@ import org.graphframes.GraphFrame * } * }}} * - * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of - * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. + * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of neighbors + * which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no - * inlinks will have a PageRank of alpha. In particular, the pageranks may have some values - * greater than 1. + * Note that this is not the "normalized" PageRank and as a consequence pages that have no inlinks + * will have a PageRank of alpha. In particular, the pageranks may have some values greater than 1. * * The resulting vertices DataFrame contains one additional column: - * - pagerank (`DoubleType`): the pagerank of this vertex + * - pagerank (`DoubleType`): the pagerank of this vertex * * The resulting edges DataFrame contains one additional column: - * - weight (`DoubleType`): the normalized weight of this edge after running PageRank + * - weight (`DoubleType`): the normalized weight of this edge after running PageRank */ -class PageRank private[graphframes] ( - private val graph: GraphFrame) extends Arguments { +class PageRank private[graphframes] (private val graph: GraphFrame) extends Arguments { private var tol: Option[Double] = None private var resetProb: Option[Double] = Some(0.15) private var maxIter: Option[Int] = None - private var srcId : Option[Any] = None + private var srcId: Option[Any] = None /** Source vertex for a Personalized Page Rank (optional) */ def sourceId(value: Any): this.type = { @@ -97,8 +96,7 @@ class PageRank private[graphframes] ( def run(): GraphFrame = { tol match { case Some(t) => - assert(maxIter.isEmpty, - 
"You cannot specify maxIter() and tol() at the same time.") + assert(maxIter.isEmpty, "You cannot specify maxIter() and tol() at the same time.") PageRank.runUntilConvergence(graph, t, resetProb.get, srcId) case None => PageRank.run(graph, check(maxIter, "maxIter"), resetProb.get, srcId) @@ -106,20 +104,23 @@ class PageRank private[graphframes] ( } } - // TODO: srcID's type should be checked. The most futureproof check would be Encoder because it is // compatible with Datasets after that. private object PageRank { + /** - * Run PageRank for a fixed number of iterations returning a graph - * with vertex attributes containing the PageRank and edge - * attributes the normalized edge weight. + * Run PageRank for a fixed number of iterations returning a graph with vertex attributes + * containing the PageRank and edge attributes the normalized edge weight. * - * @param graph the graph on which to compute PageRank - * @param maxIter the number of iterations of PageRank to run - * @param resetProb the random reset probability (alpha) - * @return the graph containing with each vertex containing the PageRank and each edge - * containing the normalized weight. + * @param graph + * the graph on which to compute PageRank + * @param maxIter + * the number of iterations of PageRank to run + * @param resetProb + * the random reset probability (alpha) + * @return + * the graph containing with each vertex containing the PageRank and each edge containing the + * normalized weight. 
*/ def run( graph: GraphFrame, @@ -127,8 +128,8 @@ private object PageRank { resetProb: Double = 0.15, srcId: Option[Any] = None): GraphFrame = { val longSrcId = srcId.map(GraphXConversions.integralId(graph, _)) - val gx = graphxlib.PageRank.runWithOptions( - graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId) + val gx = + graphxlib.PageRank.runWithOptions(graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) } @@ -136,12 +137,17 @@ private object PageRank { * Run a dynamic version of PageRank returning a graph with vertex attributes containing the * PageRank and edge attributes containing the normalized edge weight. * - * @param graph the graph on which to compute PageRank - * @param tol the tolerance allowed at convergence (smaller => more accurate). - * @param resetProb the random reset probability (alpha) - * @param srcId the source vertex for a Personalized Page Rank (optional) - * @return the graph containing with each vertex containing the PageRank and each edge - * containing the normalized weight. + * @param graph + * the graph on which to compute PageRank + * @param tol + * the tolerance allowed at convergence (smaller => more accurate). + * @param resetProb + * the random reset probability (alpha) + * @param srcId + * the source vertex for a Personalized Page Rank (optional) + * @return + * the graph containing with each vertex containing the PageRank and each edge containing the + * normalized weight. 
*/ def runUntilConvergence( graph: GraphFrame, @@ -150,7 +156,10 @@ private object PageRank { srcId: Option[Any] = None): GraphFrame = { val longSrcId = srcId.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.PageRank.runUntilConvergenceWithOptions( - graph.cachedTopologyGraphX, tol, resetProb, longSrcId) + graph.cachedTopologyGraphX, + tol, + resetProb, + longSrcId) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT)) } diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index f8f55d732..4f4ecfa41 100644 --- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -23,11 +23,10 @@ import org.graphframes.{GraphFrame, Logging} /** * Parallel Personalized PageRank algorithm implementation. * - * This implementation uses the standalone [[GraphFrame]] interface and - * runs personalized PageRank in parallel for a fixed number of iterations. - * This can be run by setting `maxIter`. - * The source vertex Ids are set in `sourceIds`. - * A simple local implementation of this algorithm is as follows. + * This implementation uses the standalone [[GraphFrame]] interface and runs personalized PageRank + * in parallel for a fixed number of iterations. This can be run by setting `maxIter`. The source + * vertex Ids are set in `sourceIds`. A simple local implementation of this algorithm is as + * follows. * {{{ * var oldPR = Array.fill(n)( 1.0 ) * val PR = (0 until n).map(i => if sourceIds.contains(i) alpha else 0.0) @@ -40,21 +39,20 @@ import org.graphframes.{GraphFrame, Logging} * } * }}} * - * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of - * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. 
+ * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of neighbors + * which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no - * inlinks will have a PageRank of alpha. In particular, the pageranks may have some values - * greater than 1. + * Note that this is not the "normalized" PageRank and as a consequence pages that have no inlinks + * will have a PageRank of alpha. In particular, the pageranks may have some values greater than 1. * * The resulting vertices DataFrame contains one additional column: - * - pageranks (`VectorType`): the pageranks of this vertex from all input source vertices + * - pageranks (`VectorType`): the pageranks of this vertex from all input source vertices * * The resulting edges DataFrame contains one additional column: - * - weight (`DoubleType`): the normalized weight of this edge after running PageRank + * - weight (`DoubleType`): the normalized weight of this edge after running PageRank */ -class ParallelPersonalizedPageRank private[graphframes] ( - private val graph: GraphFrame) extends Arguments { +class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame) + extends Arguments { private var resetProb: Option[Double] = Some(0.15) private var maxIter: Option[Int] = None @@ -86,6 +84,7 @@ class ParallelPersonalizedPageRank private[graphframes] ( } private object ParallelPersonalizedPageRank { + /** Default name for the pageranks column. */ private val PAGERANKS = "pageranks" @@ -93,18 +92,21 @@ private object ParallelPersonalizedPageRank { private val WEIGHT = "weight" /** - * Run Personalized PageRank for a fixed number of iterations, for a - * set of starting nodes in parallel. 
Returns a graph with vertex attributes - * containing the pageranks relative to all starting nodes (as a vector) and - * edge attributes the normalized edge weight + * Run Personalized PageRank for a fixed number of iterations, for a set of starting nodes in + * parallel. Returns a graph with vertex attributes containing the pageranks relative to all + * starting nodes (as a vector) and edge attributes the normalized edge weight * - * @param graph The graph on which to compute personalized pagerank - * @param maxIter The number of iterations to run - * @param resetProb The random reset probability - * @param sourceIds The list of sources to compute personalized pagerank from - * @return the graph with vertex attributes - * containing the pageranks relative to all starting nodes as a vector and - * edge attributes the normalized edge weight + * @param graph + * The graph on which to compute personalized pagerank + * @param maxIter + * The number of iterations to run + * @param resetProb + * The random reset probability + * @param sourceIds + * The list of sources to compute personalized pagerank from + * @return + * the graph with vertex attributes containing the pageranks relative to all starting nodes as + * a vector and edge attributes the normalized edge weight */ def run( graph: GraphFrame, @@ -113,7 +115,10 @@ private object ParallelPersonalizedPageRank { sourceIds: Array[Any]): GraphFrame = { val longSrcIds = sourceIds.map(GraphXConversions.integralId(graph, _)) val gx = graphxlib.PageRank.runParallelPersonalizedPageRank( - graph.cachedTopologyGraphX, maxIter, resetProb, longSrcIds) + graph.cachedTopologyGraphX, + maxIter, + resetProb, + longSrcIds) GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANKS), edgeNames = Seq(WEIGHT)) } } diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index 50b207092..beba0eb4e 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ 
b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -25,25 +25,26 @@ import org.apache.spark.sql.functions.{array, col, explode, struct} /** * Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. * - * See Malewicz et al., Pregel: a system for large-scale graph - * processing for a detailed description of the Pregel algorithm. + * See Malewicz et al., Pregel: a system for + * large-scale graph processing for a detailed description of the Pregel algorithm. * - * You can construct a Pregel instance using either this constructor or [[org.graphframes.GraphFrame#pregel]], - * then use builder pattern to describe the operations, and then call [[run]] to start a run. - * It returns a DataFrame of vertices from the last iteration. + * You can construct a Pregel instance using either this constructor or + * [[org.graphframes.GraphFrame#pregel]], then use builder pattern to describe the operations, and + * then call [[run]] to start a run. It returns a DataFrame of vertices from the last iteration. * - * When a run starts, it expands the vertices DataFrame using column expressions defined by [[withVertexColumn]]. - * Those additional vertex properties can be changed during Pregel iterations. - * In each Pregel iteration, there are three phases: - * - Given each edge triplet, generate messages and specify target vertices to send, - * described by [[sendMsgToDst]] and [[sendMsgToSrc]]. + * When a run starts, it expands the vertices DataFrame using column expressions defined by + * [[withVertexColumn]]. Those additional vertex properties can be changed during Pregel + * iterations. In each Pregel iteration, there are three phases: + * - Given each edge triplet, generate messages and specify target vertices to send, described + * by [[sendMsgToDst]] and [[sendMsgToSrc]]. * - Aggregate messages by target vertex IDs, described by [[aggMsgs]]. 
- * - Update additional vertex properties based on aggregated messages and states from previous iteration, - * described by [[withVertexColumn]]. + * - Update additional vertex properties based on aggregated messages and states from previous + * iteration, described by [[withVertexColumn]]. * * Please find what columns you can reference at each phase in the method API docs. * - * You can control the number of iterations by [[setMaxIter]] and check API docs for advanced controls. + * You can control the number of iterations by [[setMaxIter]] and check API docs for advanced + * controls. * * Example code for Page Rank: * @@ -61,11 +62,13 @@ import org.apache.spark.sql.functions.{array, col, explode, struct} * .run() * }}} * - * @param graph The graph that Pregel will run on. - * @see [[org.graphframes.GraphFrame#pregel]] - * @see - * Malewicz et al., Pregel: a system for large-scale graph processing. - * + * @param graph + * The graph that Pregel will run on. + * @see + * [[org.graphframes.GraphFrame#pregel]] + * @see + * Malewicz et al., Pregel: a system for + * large-scale graph processing. */ class Pregel(val graph: GraphFrame) { @@ -99,27 +102,37 @@ class Pregel(val graph: GraphFrame) { } /** - * Defines an additional vertex column at the start of run and how to update it in each iteration. + * Defines an additional vertex column at the start of run and how to update it in each + * iteration. * * You can call it multiple times to add more than one additional vertex columns. * - * @param colName the name of the additional vertex column. - * It cannot be an existing vertex column in the graph. - * @param initialExpr the expression to initialize the additional vertex column. - * You can reference all original vertex columns in this expression. - * @param updateAfterAggMsgsExpr the expression to update the additional vertex column after messages aggregation. 
- * You can reference all original vertex columns, additional vertex columns, and the - * aggregated message column using [[Pregel$#msg]]. - * If the vertex received no messages, the message column would be null. + * @param colName + * the name of the additional vertex column. It cannot be an existing vertex column in the + * graph. + * @param initialExpr + * the expression to initialize the additional vertex column. You can reference all original + * vertex columns in this expression. + * @param updateAfterAggMsgsExpr + * the expression to update the additional vertex column after messages aggregation. You can + * reference all original vertex columns, additional vertex columns, and the aggregated + * message column using [[Pregel$#msg]]. If the vertex received no messages, the message + * column would be null. */ - def withVertexColumn(colName: String, initialExpr: Column, updateAfterAggMsgsExpr: Column): this.type = { + def withVertexColumn( + colName: String, + initialExpr: Column, + updateAfterAggMsgsExpr: Column): this.type = { // TODO: check if this column exists. - require(colName != null && colName != ID && colName != Pregel.MSG_COL_NAME, + require( + colName != null && colName != ID && colName != Pregel.MSG_COL_NAME, "additional column name cannot be null and cannot be the same name with ID column or " + - "msg column.") + "msg column.") require(initialExpr != null, "additional column should provide a nonnull initial expression.") - require(updateAfterAggMsgsExpr != null, "additional column should provide a nonnull " + - "updateAfterAggMsgs expression.") + require( + updateAfterAggMsgsExpr != null, + "additional column should provide a nonnull " + + "updateAfterAggMsgs expression.") withVertexColumnList += Tuple3(colName, initialExpr, updateAfterAggMsgsExpr) this } @@ -129,12 +142,14 @@ class Pregel(val graph: GraphFrame) { * * You can call it multiple times to send more than one messages. 
* - * @param msgExpr the expression of the message to send to the source vertex given a (src, edge, dst) triplet. - * Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, - * and `edge`, respectively. - * You can reference them using [[Pregel$#src]], [[Pregel$#dst]], and [[Pregel$#edge]]. - * Null messages are not included in message aggregation. - * @see [[sendMsgToDst]] + * @param msgExpr + * the expression of the message to send to the source vertex given a (src, edge, dst) + * triplet. Source/destination vertex properties and edge properties are nested under columns + * `src`, `dst`, and `edge`, respectively. You can reference them using [[Pregel$#src]], + * [[Pregel$#dst]], and [[Pregel$#edge]]. Null messages are not included in message + * aggregation. + * @see + * [[sendMsgToDst]] */ def sendMsgToSrc(msgExpr: Column): this.type = { sendMsgs += Tuple2(Pregel.src(ID), msgExpr) @@ -146,12 +161,14 @@ class Pregel(val graph: GraphFrame) { * * You can call it multiple times to send more than one messages. * - * @param msgExpr the message expression to send to the destination vertex given a (`src`, `edge`, `dst`) triplet. - * Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, - * and `edge`, respectively. - * You can reference them using [[Pregel$#src]], [[Pregel$#dst]], and [[Pregel$#edge]]. - * Null messages are not included in message aggregation. - * @see [[sendMsgToSrc]] + * @param msgExpr + * the message expression to send to the destination vertex given a (`src`, `edge`, `dst`) + * triplet. Source/destination vertex properties and edge properties are nested under columns + * `src`, `dst`, and `edge`, respectively. You can reference them using [[Pregel$#src]], + * [[Pregel$#dst]], and [[Pregel$#edge]]. Null messages are not included in message + * aggregation. 
+ * @see + * [[sendMsgToSrc]] */ def sendMsgToDst(msgExpr: Column): this.type = { sendMsgs += Tuple2(Pregel.dst(ID), msgExpr) @@ -161,9 +178,10 @@ class Pregel(val graph: GraphFrame) { /** * Defines how messages are aggregated after grouped by target vertex IDs. * - * @param aggExpr the message aggregation expression, such as `sum(Pregel.msg)`. - * You can reference the message column by [[Pregel$#msg]] and the vertex ID by [[GraphFrame$#ID]], - * while the latter is usually not used. + * @param aggExpr + * the message aggregation expression, such as `sum(Pregel.msg)`. You can reference the + * message column by [[Pregel$#msg]] and the vertex ID by [[GraphFrame$#ID]], while the latter + * is usually not used. */ def aggMsgs(aggExpr: Column): this.type = { aggMsgsCol = aggExpr @@ -173,14 +191,22 @@ class Pregel(val graph: GraphFrame) { /** * Runs the defined Pregel algorithm. * - * @return the result vertex DataFrame from the final iteration including both original and additional columns. + * @return + * the result vertex DataFrame from the final iteration including both original and additional + * columns. 
*/ def run(): DataFrame = { - require(sendMsgs.length > 0, "We need to set at least one message expression for pregel running.") + require( + sendMsgs.length > 0, + "We need to set at least one message expression for pregel running.") require(aggMsgsCol != null, "We need to set aggMsgs for pregel running.") require(maxIter >= 1, "The max iteration number should be >= 1.") - require(checkpointInterval >= 0, "The checkpoint interval should be >= 0, 0 indicates no checkpoint.") - require(withVertexColumnList.size > 0, "There should be at least one additional vertex columns for updating.") + require( + checkpointInterval >= 0, + "The checkpoint interval should be >= 0, 0 indicates no checkpoint.") + require( + withVertexColumnList.size > 0, + "There should be at least one additional vertex columns for updating.") val sendMsgsColList = sendMsgs.toList.map { case (id, msg) => struct(id.as(ID), msg.as("msg")) @@ -203,9 +229,12 @@ class Pregel(val graph: GraphFrame) { val shouldCheckpoint = checkpointInterval > 0 while (iteration <= maxIter) { - val tripletsDF = currentVertices.select(struct(col("*")).as(SRC)) + val tripletsDF = currentVertices + .select(struct(col("*")).as(SRC)) .join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC)) - .join(currentVertices.select(struct(col("*")).as(DST)), Pregel.edge(DST) === Pregel.dst(ID)) + .join( + currentVertices.select(struct(col("*")).as(DST)), + Pregel.edge(DST) === Pregel.dst(ID)) var msgDF: DataFrame = tripletsDF .select(explode(array(sendMsgsColList: _*)).as("msg")) @@ -258,31 +287,38 @@ object Pregel extends Serializable { /** * References the message column in aggregating messages and updating additional vertex columns. * - * @see [[Pregel.aggMsgs]] and [[Pregel.withVertexColumn]] + * @see + * [[Pregel.aggMsgs]] and [[Pregel.withVertexColumn]] */ val msg: Column = col(MSG_COL_NAME) /** * References a source vertex column in generating messages to send. * - * @param colName the vertex column name. 
- * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] + * @param colName + * the vertex column name. + * @see + * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] */ def src(colName: String): Column = col(GraphFrame.SRC + "." + colName) /** * References a destination vertex column in generating messages to send. * - * @param colName the vertex column name. - * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] + * @param colName + * the vertex column name. + * @see + * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] */ def dst(colName: String): Column = col(GraphFrame.DST + "." + colName) /** * References an edge column in generating messages to send. * - * @param colName the edge column name. - * @see [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] + * @param colName + * the edge column name. + * @see + * [[Pregel.sendMsgToSrc]] and [[Pregel.sendMsgToDst]] */ def edge(colName: String): Column = col(GraphFrame.EDGE + "." + colName) } diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index 43ceedd92..35b18a742 100644 --- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.{DataFrame, Row} import org.graphframes.{GraphFrame, Logging} /** - * Implement SVD++ based on "Factorization Meets the Neighborhood: - * a Multifaceted Collaborative Filtering Model", - * available at [[https://dl.acm.org/citation.cfm?id=1401944]]. + * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative + * Filtering Model", available at [[https://dl.acm.org/citation.cfm?id=1401944]]. * * Note: The status of this algorithm is EXPERIMENTAL. Its API and implementation may be changed * in the future. @@ -35,9 +34,8 @@ import org.graphframes.{GraphFrame, Logging} * * Configuration parameters: see the description of each parameter in the article. 
* - * Returns a DataFrame with vertex attributes containing the trained model. See the object + * Returns a DataFrame with vertex attributes containing the trained model. See the object * (static) members for the names of the output columns. - * */ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments { private var _rank: Int = 10 @@ -120,7 +118,9 @@ object SVDPlusPlus { case Row(src: Long, dst: Long, w: Double) => Edge(src, dst, w) } val (gx, res) = graphxlib.SVDPlusPlus.run(edges, conf) - val gf = GraphXConversions.fromGraphX(graph, gx, + val gf = GraphXConversions.fromGraphX( + graph, + gx, vertexNames = Seq(COLUMN1, COLUMN2, COLUMN3, COLUMN4)) (gf.vertices, res) } @@ -133,32 +133,32 @@ object SVDPlusPlus { val COLUMN_WEIGHT = "weight" /** - * Name for output vertexDataFrame column containing first parameter of learned model, - * of type `Array[Double]`. + * Name for output vertexDataFrame column containing first parameter of learned model, of type + * `Array[Double]`. * * Note: This column name may change in the future! */ val COLUMN1 = "column1" /** - * Name for output vertexDataFrame column containing second parameter of learned model, - * of type `Array[Double]`. + * Name for output vertexDataFrame column containing second parameter of learned model, of type + * `Array[Double]`. * * Note: This column name may change in the future! */ val COLUMN2 = "column2" /** - * Name for output vertexDataFrame column containing third parameter of learned model, - * of type `Double`. + * Name for output vertexDataFrame column containing third parameter of learned model, of type + * `Double`. * * Note: This column name may change in the future! */ val COLUMN3 = "column3" /** - * Name for output vertexDataFrame column containing fourth parameter of learned model, - * of type `Double`. + * Name for output vertexDataFrame column containing fourth parameter of learned model, of type + * `Double`. 
* * Note: This column name may change in the future! */ diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 6b1f15c74..0d0b52407 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -30,13 +30,13 @@ import org.apache.spark.sql.types.{IntegerType, MapType} import org.graphframes.GraphFrame /** - * Computes shortest paths from every vertex to the given set of landmark vertices. - * Note that this takes edge direction into account. + * Computes shortest paths from every vertex to the given set of landmark vertices. Note that this + * takes edge direction into account. * * The returned DataFrame contains all the original vertex information as well as one additional * column: - * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing - * the shortest-path distance to each reachable landmark vertex. + * - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing the + * shortest-path distance to each reachable landmark vertex. 
*/ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments { private var lmarks: Option[Seq[Any]] = None @@ -67,9 +67,9 @@ private object ShortestPaths { private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = { val idType = graph.vertices.schema(GraphFrame.ID).dataType val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap - val gx = graphxlib.ShortestPaths.run( - graph.cachedTopologyGraphX, - longIdToLandmark.keys.toSeq.sorted).mapVertices { case (_, m) => m.toSeq } + val gx = graphxlib.ShortestPaths + .run(graph.cachedTopologyGraphX, longIdToLandmark.keys.toSeq.sorted) + .mapVertices { case (_, m) => m.toSeq } val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID)) val distanceCol: Column = if (graph.hasIntegralIdType) { // It seems there are no easy way to convert a sequence of pairs into a map @@ -83,7 +83,7 @@ private object ShortestPaths { val func = new UDF1[Seq[Row], Map[Any, Int]] { override def call(t1: Seq[Row]): Map[Any, Int] = { t1.map { case Row(k: Long, v: Int) => - longIdToLandmark(k) -> v + longIdToLandmark(k) -> v }.toMap } } diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index 2914a287a..e8d9ecde8 100644 --- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -27,10 +27,10 @@ import org.graphframes.GraphFrame * vertex assigned to the SCC containing that vertex. 
* * The resulting DataFrame contains all the original vertex information and one additional column: - * - component (`LongType`): unique ID for this component + * - component (`LongType`): unique ID for this component */ class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame) - extends Arguments { + extends Arguments { private var maxIter: Option[Int] = None @@ -44,7 +44,6 @@ class StronglyConnectedComponents private[graphframes] (private val graph: Graph } } - /** Strongly connected components algorithm implementation. */ private object StronglyConnectedComponents { private def run(graph: GraphFrame, numIter: Int): DataFrame = { diff --git a/src/main/scala/org/graphframes/lib/TriangleCount.scala b/src/main/scala/org/graphframes/lib/TriangleCount.scala index fb8c51462..5eede3c67 100644 --- a/src/main/scala/org/graphframes/lib/TriangleCount.scala +++ b/src/main/scala/org/graphframes/lib/TriangleCount.scala @@ -26,15 +26,15 @@ import org.graphframes.GraphFrame.{DST, ID, LONG_DST, LONG_SRC, SRC} /** * Computes the number of triangles passing through each vertex. * - * This algorithm ignores edge direction; i.e., all edges are treated as undirected. - * In a multigraph, duplicate edges will be counted only once. + * This algorithm ignores edge direction; i.e., all edges are treated as undirected. In a + * multigraph, duplicate edges will be counted only once. * - * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides - * a graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize - * the graph to put it in this format. + * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides a + * graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize the graph to + * put it in this format. 
* * The returned DataFrame contains all the original vertex information and one additional column: - * - count (`LongType`): the count of triangles + * - count (`LongType`): the count of triangles */ class TriangleCount private[graphframes] (private val graph: GraphFrame) extends Arguments { @@ -67,7 +67,8 @@ private object TriangleCount { val v = graph.vertices val countsCol = when(col("count").isNull, 0L).otherwise(col("count")) - val newV = v.join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer") + val newV = v + .join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer") .select((countsCol.as(COUNT_ID) +: v.columns.map(v.apply)).toSeq: _*) newV } diff --git a/src/main/scala/org/graphframes/pattern/patterns.scala b/src/main/scala/org/graphframes/pattern/patterns.scala index bd566d578..fcc9f22c5 100644 --- a/src/main/scala/org/graphframes/pattern/patterns.scala +++ b/src/main/scala/org/graphframes/pattern/patterns.scala @@ -34,13 +34,13 @@ private[graphframes] object PatternParser extends RegexParsers { case src ~ "-" ~ "[" ~ name ~ "]" ~ "->" ~ dst => NamedEdge(name, src, dst) } val anonymousEdge: Parser[Edge] = - vertex ~ "-" ~ "[" ~ "]" ~ "->" ~ vertex ^^ { - case src ~ "-" ~ "[" ~ "]" ~ "->" ~ dst => AnonymousEdge(src, dst) + vertex ~ "-" ~ "[" ~ "]" ~ "->" ~ vertex ^^ { case src ~ "-" ~ "[" ~ "]" ~ "->" ~ dst => + AnonymousEdge(src, dst) } private val edge: Parser[Edge] = namedEdge | anonymousEdge private val negatedEdge: Parser[Pattern] = - "!" ~ edge ^^ { - case _ ~ e => Negation(e) + "!" 
~ edge ^^ { case _ ~ e => + Negation(e) } private val pattern: Parser[Pattern] = edge | vertex | negatedEdge val patterns: Parser[List[Pattern]] = repsep(pattern, ";") @@ -62,10 +62,11 @@ private[graphframes] object Pattern { /** * Checks all Patterns for validity: - * - Disallow named edges within negated terms - * - Disallow term "()-[]->()" and its negation - * - Disallow name to be shared by a vertex and an edge - * @throws InvalidParseException if an negated terms contain named edges + * - Disallow named edges within negated terms + * - Disallow term "()-[]->()" and its negation + * - Disallow name to be shared by a vertex and an edge + * @throws InvalidParseException + * if an negated terms contain named edges */ private def assertValidPatterns(patterns: Seq[Pattern]): Unit = { @@ -75,21 +76,24 @@ private[graphframes] object Pattern { def addVertex(v: Vertex): Unit = v match { case NamedVertex(name) => if (edgeNames.contains(name)) { - throw new InvalidParseException(s"Motif reused name '$name' for both a vertex and " + - s"an edge, which is not allowed.") + throw new InvalidParseException( + s"Motif reused name '$name' for both a vertex and " + + s"an edge, which is not allowed.") } vertexNames += name - case AnonymousVertex => // pass + case AnonymousVertex => // pass } def addEdge(e: Edge): Unit = e match { case NamedEdge(name, src, dst) => if (vertexNames.contains(name)) { - throw new InvalidParseException(s"Motif reused name '$name' for both a vertex and " + - s"an edge, which is not allowed.") + throw new InvalidParseException( + s"Motif reused name '$name' for both a vertex and " + + s"an edge, which is not allowed.") } if (edgeNames.contains(name)) { - throw new InvalidParseException(s"Motif reused name '$name' for multiple edges, " + - s"which is not allowed.") + throw new InvalidParseException( + s"Motif reused name '$name' for multiple edges, " + + s"which is not allowed.") } edgeNames += name addVertex(src) @@ -103,8 +107,9 @@ private[graphframes] 
object Pattern { case Negation(edge) => edge match { case NamedEdge(name, src, dst) => - throw new InvalidParseException(s"Motif finding does not support negated named " + - s"edges, but the given pattern contained: !($src)-[$name]->($dst)") + throw new InvalidParseException( + s"Motif finding does not support negated named " + + s"edges, but the given pattern contained: !($src)-[$name]->($dst)") case AnonymousEdge(AnonymousVertex, AnonymousVertex) => throw new InvalidParseException(s"Motif finding does not support completely " + s"anonymous negated edges !()-[]->(). Users can check for 0 edges in the graph " + @@ -113,17 +118,19 @@ private[graphframes] object Pattern { addEdge(e) } case AnonymousEdge(AnonymousVertex, AnonymousVertex) => - throw new InvalidParseException(s"Motif finding does not support completely " + - s"anonymous edges ()-[]->(). Users can check for the existence of edges in the " + - s"graph using the edges DataFrame.") + throw new InvalidParseException( + s"Motif finding does not support completely " + + s"anonymous edges ()-[]->(). Users can check for the existence of edges in the " + + s"graph using the edges DataFrame.") case e @ AnonymousEdge(_, _) => addEdge(e) case e @ NamedEdge(_, _, _) => addEdge(e) case AnonymousVertex => - throw new InvalidParseException("Motif finding does not allow a lone anonymous vertex " + - "\"()\" in a motif. Users can check for the existence of vertices in the graph " + - "using the vertices DataFrame.") + throw new InvalidParseException( + "Motif finding does not allow a lone anonymous vertex " + + "\"()\" in a motif. Users can check for the existence of vertices in the graph " + + "using the vertices DataFrame.") case v @ NamedVertex(_) => addVertex(v) } @@ -132,27 +139,31 @@ private[graphframes] object Pattern { /** * Return the set of named vertices which only appear in negated terms, in sorted order. 
*/ - private[graphframes] - def findNamedVerticesOnlyInNegatedTerms(patterns: Seq[Pattern]): Seq[String] = { + private[graphframes] def findNamedVerticesOnlyInNegatedTerms( + patterns: Seq[Pattern]): Seq[String] = { val vPos = findNamedElementsInOrder( - patterns.filter(p => !p.isInstanceOf[Negation]), includeEdges = false).toSet + patterns.filter(p => !p.isInstanceOf[Negation]), + includeEdges = false).toSet val vNeg = findNamedElementsInOrder( - patterns.filter(p => p.isInstanceOf[Negation]), includeEdges = false).toSet + patterns.filter(p => p.isInstanceOf[Negation]), + includeEdges = false).toSet vNeg.diff(vPos).toSeq.sorted } /** - * Return the set of named vertices (and optionally edges) appearing in the given patterns, - * in the order they first appear in the sequence of patterns. - * @param includeEdges If true, include named edges in the returned sequence. + * Return the set of named vertices (and optionally edges) appearing in the given patterns, in + * the order they first appear in the sequence of patterns. + * @param includeEdges + * If true, include named edges in the returned sequence. 
*/ - private[graphframes] - def findNamedElementsInOrder(patterns: Seq[Pattern], includeEdges: Boolean): Seq[String] = { + private[graphframes] def findNamedElementsInOrder( + patterns: Seq[Pattern], + includeEdges: Boolean): Seq[String] = { val elementSet = mutable.LinkedHashSet.empty[String] def findNamedElementsHelper(pattern: Pattern): Unit = pattern match { case Negation(child) => findNamedElementsHelper(child) - case AnonymousVertex => // pass + case AnonymousVertex => // pass case NamedVertex(name) => if (!elementSet.contains(name)) { elementSet += name @@ -187,4 +198,3 @@ private[graphframes] sealed trait Edge extends Pattern private[graphframes] case class AnonymousEdge(src: Vertex, dst: Vertex) extends Edge private[graphframes] case class NamedEdge(name: String, src: Vertex, dst: Vertex) extends Edge - diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index 3ef5059ec..d8d761898 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.{DataFrame, Row} import org.graphframes.examples.Graphs - class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { import GraphFrame._ @@ -47,10 +46,11 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { super.beforeAll() tempDir = Files.createTempDir() vertices = spark.createDataFrame(localVertices.toSeq).toDF("id", "name") - edges = spark.createDataFrame(localEdges.toSeq.map { - case ((src, dst), action) => + edges = spark + .createDataFrame(localEdges.toSeq.map { case ((src, dst), action) => (src, dst, action) - }).toDF("src", "dst", "action") + }) + .toDF("src", "dst", "action") } override def afterAll(): Unit = { @@ -86,9 +86,14 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val idsFromVertices = g.vertices.select("id").rdd.map(_.getLong(0)).collect() val 
idsFromVerticesSet = idsFromVertices.toSet assert(idsFromVertices.length === idsFromVerticesSet.size) - val idsFromEdgesSet = g.edges.select("src", "dst").rdd.flatMap { case Row(src: Long, dst: Long) => - Seq(src, dst) - }.collect().toSet + val idsFromEdgesSet = g.edges + .select("src", "dst") + .rdd + .flatMap { case Row(src: Long, dst: Long) => + Seq(src, dst) + } + .collect() + .toSet assert(idsFromVerticesSet === idsFromEdgesSet) } @@ -127,8 +132,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("convert to GraphX: Int IDs") { val vv = vertices.select(col("id").cast(IntegerType).as("id"), col("name")) - val ee = edges.select(col("src").cast(IntegerType).as("src"), - col("dst").cast(IntegerType).as("dst"), col("action")) + val ee = edges.select( + col("src").cast(IntegerType).as("src"), + col("dst").cast(IntegerType).as("dst"), + col("action")) val gf = GraphFrame(vv, ee) val g = gf.toGraphX // Int IDs should be directly cast to Long, so ID values should match. 
@@ -140,31 +147,35 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(id0 === id1) assert(localVertices(id0) === name) } - g.edges.collect().foreach { - case Edge(src0: Long, dst0: Long, attr: Row) => - val src1 = attr.getInt(eCols("src")) - val dst1 = attr.getInt(eCols("dst")) - val action = attr.getString(eCols("action")) - assert(src0 === src1) - assert(dst0 === dst1) - assert(localEdges((src0, dst0)) === action) + g.edges.collect().foreach { case Edge(src0: Long, dst0: Long, attr: Row) => + val src1 = attr.getInt(eCols("src")) + val dst1 = attr.getInt(eCols("dst")) + val action = attr.getString(eCols("action")) + assert(src0 === src1) + assert(dst0 === dst1) + assert(localEdges((src0, dst0)) === action) } } test("convert to GraphX: String IDs") { try { val vv = vertices.select(col("id").cast(StringType).as("id"), col("name")) - val ee = edges.select(col("src").cast(StringType).as("src"), - col("dst").cast(StringType).as("dst"), col("action")) + val ee = edges.select( + col("src").cast(StringType).as("src"), + col("dst").cast(StringType).as("dst"), + col("action")) val gf = GraphFrame(vv, ee) val g = gf.toGraphX // String IDs will be re-indexed, so ID values may not match. val vCols = gf.vertexColumnMap val eCols = gf.edgeColumnMap // First, get index. 
- val new2oldID: Map[Long, String] = g.vertices.map { case (id: Long, attr: Row) => - (id, attr.getString(vCols("id"))) - }.collect().toMap + val new2oldID: Map[Long, String] = g.vertices + .map { case (id: Long, attr: Row) => + (id, attr.getString(vCols("id"))) + } + .collect() + .toMap // Same as in test with Int IDs, but with re-indexing g.vertices.collect().foreach { case (id0: Long, attr: Row) => val id1 = attr.getString(vCols("id")) @@ -172,14 +183,13 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(new2oldID(id0) === id1) assert(localVertices(new2oldID(id0).toInt) === name) } - g.edges.collect().foreach { - case Edge(src0: Long, dst0: Long, attr: Row) => - val src1 = attr.getString(eCols("src")) - val dst1 = attr.getString(eCols("dst")) - val action = attr.getString(eCols("action")) - assert(new2oldID(src0) === src1) - assert(new2oldID(dst0) === dst1) - assert(localEdges((new2oldID(src0).toInt, new2oldID(dst0).toInt)) === action) + g.edges.collect().foreach { case Edge(src0: Long, dst0: Long, attr: Row) => + val src1 = attr.getString(eCols("src")) + val dst1 = attr.getString(eCols("dst")) + val action = attr.getString(eCols("action")) + assert(new2oldID(src0) === src1) + assert(new2oldID(dst0) === dst1) + assert(localEdges((new2oldID(src0).toInt, new2oldID(dst0).toInt)) === action) } } catch { case e: Exception => @@ -211,21 +221,30 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val g = GraphFrame(vertices, edges) assert(g.outDegrees.columns === Seq("id", "outDegree")) - val outDegrees = g.outDegrees.collect().map { case Row(id: Long, outDeg: Int) => - (id, outDeg) - }.toMap + val outDegrees = g.outDegrees + .collect() + .map { case Row(id: Long, outDeg: Int) => + (id, outDeg) + } + .toMap assert(outDegrees === Map(1L -> 1, 2L -> 2)) assert(g.inDegrees.columns === Seq("id", "inDegree")) - val inDegrees = g.inDegrees.collect().map { case Row(id: Long, inDeg: Int) => - (id, inDeg) - 
}.toMap + val inDegrees = g.inDegrees + .collect() + .map { case Row(id: Long, inDeg: Int) => + (id, inDeg) + } + .toMap assert(inDegrees === Map(1L -> 1, 2L -> 1, 3L -> 1)) assert(g.degrees.columns === Seq("id", "degree")) - val degrees = g.degrees.collect().map { case Row(id: Long, deg: Int) => - (id, deg) - }.toMap + val degrees = g.degrees + .collect() + .map { case Row(id: Long, deg: Int) => + (id, deg) + } + .toMap assert(degrees === Map(1L -> 2, 2L -> 3, 3L -> 1)) } @@ -257,7 +276,8 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val chain = Graphs.chain(n + 1) val vertices = star.vertices.select(col(ID).cast("string").as(ID)) val edges = - star.edges.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST)) + star.edges + .select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST)) .unionAll( chain.edges.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST))) @@ -265,7 +285,8 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val localEdges = edges.select(SRC, DST).as[(String, String)].collect().toSet val defaultThreshold = GraphFrame.broadcastThreshold - assert(defaultThreshold === 1000000, + assert( + defaultThreshold === 1000000, s"Default broadcast threshold should be 1000000 but got $defaultThreshold.") for (threshold <- Seq(0, 4, 10)) { @@ -274,18 +295,19 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val g = GraphFrame(vertices, edges) g.persist(StorageLevel.MEMORY_AND_DISK) - val indexedVertices = g.indexedVertices.select(ID, LONG_ID).as[(String, Long)].collect().toMap + val indexedVertices = + g.indexedVertices.select(ID, LONG_ID).as[(String, Long)].collect().toMap assert(indexedVertices.keySet === localVertices) assert(indexedVertices.values.toSeq.distinct.size === localVertices.size) val origEdges = g.indexedEdges.select(SRC, DST).as[(String, String)].collect().toSet assert(origEdges === localEdges) 
g.indexedEdges - .select(SRC, LONG_SRC, DST, LONG_DST).as[(String, Long, String, Long)] + .select(SRC, LONG_SRC, DST, LONG_DST) + .as[(String, Long, String, Long)] .collect() - .foreach { - case (src, longSrc, dst, longDst) => - assert(indexedVertices(src) === longSrc) - assert(indexedVertices(dst) === longDst) + .foreach { case (src, longSrc, dst, longDst) => + assert(indexedVertices(src) === longSrc) + assert(indexedVertices(dst) === longDst) } } diff --git a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala index 97bdf6e4a..99c8b56c6 100644 --- a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala +++ b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala @@ -35,9 +35,9 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite => /** * A helper object for importing SQL implicits. * - * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the `SQLContext` immediately before the first test is run, - * but the implicits import is needed in the constructor. + * Note that the alternative of importing `spark.implicits._` is not possible here. This is + * because we create the `SQLContext` immediately before the first test is run, but the + * implicits import is needed in the constructor. 
*/ protected object testImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.sqlContext @@ -56,7 +56,8 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite => override def beforeAll() { super.beforeAll() - spark = SparkSession.builder() + spark = SparkSession + .builder() .master("local[2]") .appName("GraphFramesUnitTest") .config("spark.sql.shuffle.partitions", 4) diff --git a/src/test/scala/org/graphframes/PatternMatchSuite.scala b/src/test/scala/org/graphframes/PatternMatchSuite.scala index 6feac0294..fc06d2eb6 100644 --- a/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.functions.{col, lit, when} /** * Cases to go through: - * - Any negated terms? - * - Any anonymous vertices + * - Any negated terms? + * - Any anonymous vertices * - in non-negated terms? * - in negated terms? - * - # named vertices grounding a negated term to non-negated terms: 2, 1, 0 - * - Named edges? + * - # named vertices grounding a negated term to non-negated terms: 2, 1, 0 + * - Named edges? 
*/ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -40,18 +40,20 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { override def beforeAll(): Unit = { super.beforeAll() - v = spark.createDataFrame(List( - (0L, "a", "f"), - (1L, "b", "m"), - (2L, "c", "m"), - (3L, "d", "f"))).toDF("id", "attr", "gender") - e = spark.createDataFrame(List( - (0L, 1L, "friend"), - (1L, 0L, "follow"), - (1L, 2L, "friend"), - (2L, 3L, "follow"), - (2L, 0L, "unknown"))).toDF("src", "dst", "relationship") - noEdges = v.select(col("id").alias("src")) + v = spark + .createDataFrame(List((0L, "a", "f"), (1L, "b", "m"), (2L, "c", "m"), (3L, "d", "f"))) + .toDF("id", "attr", "gender") + e = spark + .createDataFrame( + List( + (0L, 1L, "friend"), + (1L, 0L, "follow"), + (1L, 2L, "friend"), + (2L, 3L, "follow"), + (2L, 0L, "unknown"))) + .toDF("src", "dst", "relationship") + noEdges = v + .select(col("id").alias("src")) .crossJoin(v.select(col("id").alias("dst"))) .except(e.select("src", "dst")) g = GraphFrame(v, e) @@ -66,11 +68,12 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { private def compareResultToExpected[A](result: Set[A], expected: Set[A]): Unit = { if (result !== expected) { - throw new AssertionError("result !== expected.\n" + - s"Result contained additional values: ${result.diff(expected)}\n" + - s"Expected contained additional values: ${expected.diff(result)}\n" + - s"Result: $result\n" + - s"Expected: $expected") + throw new AssertionError( + "result !== expected.\n" + + s"Result contained additional values: ${result.diff(expected)}\n" + + s"Expected contained additional values: ${expected.diff(result)}\n" + + s"Result: $result\n" + + s"Expected: $expected") } } @@ -90,17 +93,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val s = "relationship = 'friend'" // column expression val c = col("relationship") === "friend" - // expected subgraph 
vertices - val expected_v = Set( - Row(0L, "a", "f"), - Row(1L, "b", "m"), - Row(2L, "c", "m") - ) + // expected subgraph vertices + val expected_v = Set(Row(0L, "a", "f"), Row(1L, "b", "m"), Row(2L, "c", "m")) // expected subgraph edges - val expected_e = Set( - Row(0L, 1L, "friend"), - Row(1L, 2L, "friend") - ) + val expected_e = Set(Row(0L, 1L, "friend"), Row(1L, 2L, "friend")) val res_s = g.filterEdges(s) assert(res_s.vertices.collect().toSet === v.collect().toSet) @@ -120,17 +116,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val s = "id > 0" // column expression val c = col("id") > 0 - // expected subgraph vertices - val expected_v = Set( - Row(1L, "b", "m"), - Row(2L, "c", "m"), - Row(3L, "d", "f") - ) + // expected subgraph vertices + val expected_v = Set(Row(1L, "b", "m"), Row(2L, "c", "m"), Row(3L, "d", "f")) // expected subgraph edges - val expected_e = Set( - Row(1L, 2L, "friend"), - Row(2L, 3L, "follow") - ) + val expected_e = Set(Row(1L, 2L, "friend"), Row(2L, 3L, "follow")) val res_s = g.filterVertices(s) assert(res_s.vertices.collect().toSet === expected_v) @@ -142,14 +131,11 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("triangles") { - val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)") + val triangles = g + .find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)") .select("a.id", "b.id", "c.id") - assert(triangles.collect().toSet === Set( - Row(0L, 1L, 2L), - Row(2L, 0L, 1L), - Row(1L, 2L, 0L) - )) + assert(triangles.collect().toSet === Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L))) } /* ====================================== Vertex queries ===================================== */ @@ -169,29 +155,32 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(triplets.columns === Array("u", "v")) val res = triplets.select("u.id", "u.attr", "v.id", "v.attr").collect().toSet - compareResultToExpected(res, Set( - Row(0L, "a", 1L, 
"b"), - Row(1L, "b", 0L, "a"), - Row(1L, "b", 2L, "c"), - Row(2L, "c", 3L, "d"), - Row(2L, "c", 0L, "a") - )) + compareResultToExpected( + res, + Set( + Row(0L, "a", 1L, "b"), + Row(1L, "b", 0L, "a"), + Row(1L, "b", 2L, "c"), + Row(2L, "c", 3L, "d"), + Row(2L, "c", 0L, "a"))) } test("triplet with named edge") { val triplets = g.find("(u)-[uv]->(v)") assert(triplets.columns === Array("u", "uv", "v")) - val res = triplets.select("u.id", "u.attr", "uv.src", "uv.dst", "uv.relationship", - "v.id", "v.attr") - .collect().toSet - compareResultToExpected(res, Set( - Row(0L, "a", 0L, 1L, "friend", 1L, "b"), - Row(1L, "b", 1L, 0L, "follow", 0L, "a"), - Row(1L, "b", 1L, 2L, "friend", 2L, "c"), - Row(2L, "c", 2L, 3L, "follow", 3L, "d"), - Row(2L, "c", 2L, 0L, "unknown", 0L, "a") - )) + val res = triplets + .select("u.id", "u.attr", "uv.src", "uv.dst", "uv.relationship", "v.id", "v.attr") + .collect() + .toSet + compareResultToExpected( + res, + Set( + Row(0L, "a", 0L, 1L, "friend", 1L, "b"), + Row(1L, "b", 1L, 0L, "follow", 0L, "a"), + Row(1L, "b", 1L, 2L, "friend", 2L, "c"), + Row(2L, "c", 2L, 3L, "follow", 3L, "d"), + Row(2L, "c", 2L, 0L, "unknown", 0L, "a"))) } test("triplet with anonymous vertex") { @@ -199,13 +188,13 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(triplets.columns === Array("u")) // Do not use compareResultToExpected since it uses sets, and we expect duplicates. 
- assert(triplets.select("u.id", "u.attr").collect().sortBy(_.getLong(0)) === Array( - Row(0L, "a"), - Row(1L, "b"), - Row(1L, "b"), - Row(2L, "c"), - Row(2L, "c") - ).sortBy(_.getLong(0))) + assert( + triplets.select("u.id", "u.attr").collect().sortBy(_.getLong(0)) === Array( + Row(0L, "a"), + Row(1L, "b"), + Row(1L, "b"), + Row(2L, "c"), + Row(2L, "c")).sortBy(_.getLong(0))) } test("triplet with 2 anonymous vertices") { @@ -218,9 +207,9 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("self-loop") { - val myE = spark.createDataFrame(List( - (1L, 1L, "self"), - (3L, 3L, "self"))).toDF("src", "dst", "relationship") + val myE = spark + .createDataFrame(List((1L, 1L, "self"), (3L, 3L, "self"))) + .toDF("src", "dst", "relationship") .union(e) val myG = GraphFrame(v, myE) @@ -231,32 +220,29 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val selfLoops2 = myG.find("(a)-[]->(b); (a)-[]->(a)") assert(selfLoops2.columns === Array("a", "b")) - val res2 = selfLoops2.select("a.id", "b.id") + val res2 = selfLoops2 + .select("a.id", "b.id") .where("a.id != b.id") - .collect().toSet - compareResultToExpected(res2, Set( - Row(1L, 0L), - Row(1L, 2L) - )) + .collect() + .toSet + compareResultToExpected(res2, Set(Row(1L, 0L), Row(1L, 2L))) } test("duplicate edges") { - val myE = spark.createDataFrame(List( - (1L, 0L, "dup"), - (1L, 2L, "dup"))).toDF("src", "dst", "relationship") + val myE = spark + .createDataFrame(List((1L, 0L, "dup"), (1L, 2L, "dup"))) + .toDF("src", "dst", "relationship") .union(e) val myG = GraphFrame(v, myE) - val edges = myG.find("(a)-[]->(b)") + val edges = myG + .find("(a)-[]->(b)") .where("a.id = 1") - val res = edges.select("a.id", "b.id") - .collect().sortBy(_.getLong(1)) - val expected = Array( - Row(1L, 0L), - Row(1L, 0L), - Row(1L, 2L), - Row(1L, 2L) - ) + val res = edges + .select("a.id", "b.id") + .collect() + .sortBy(_.getLong(1)) + val expected = Array(Row(1L, 0L), Row(1L, 
0L), Row(1L, 2L), Row(1L, 2L)) assert(res === expected) } @@ -266,26 +252,28 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (c)-[]->(a)") assert(triangles.columns === Array("a", "b", "c")) - val res = triangles.select("a.id", "b.id", "c.id") - .collect().toSet - compareResultToExpected(res, Set( - Row(0L, 1L, 2L), - Row(2L, 0L, 1L), - Row(1L, 2L, 0L) - )) + val res = triangles + .select("a.id", "b.id", "c.id") + .collect() + .toSet + compareResultToExpected(res, Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L))) } test("disconnected edges create an outer join") { val edgePairs = g.find("(a)-[]->(b); (c)-[]->(d)") assert(edgePairs.columns === Array("a", "b", "c", "d")) - val res = edgePairs.select("a.id", "b.id", "c.id", "d.id") - .collect().toSet + val res = edgePairs + .select("a.id", "b.id", "c.id", "d.id") + .collect() + .toSet val ab = e.select(col("src").alias("a"), col("dst").alias("b")) val cd = e.select(col("src").alias("c"), col("dst").alias("d")) - val expected = ab.crossJoin(cd) - .collect().toSet + val expected = ab + .crossJoin(cd) + .collect() + .toSet compareResultToExpected(res, expected) val numEdges = e.count() assert(expected.size === numEdges * numEdges) @@ -297,38 +285,32 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val edges = g.find("(a)-[]->(b); !(b)-[]->(a)") assert(edges.columns === Array("a", "b")) - val res = edges.select("a.id", "b.id") - .collect().toSet - compareResultToExpected(res, Set( - Row(1L, 2L), - Row(2L, 0L), - Row(2L, 3L) - )) + val res = edges + .select("a.id", "b.id") + .collect() + .toSet + compareResultToExpected(res, Set(Row(1L, 2L), Row(2L, 0L), Row(2L, 3L))) } test("a->b->c but not c->a") { val edges = g.find("(a)-[]->(b); (b)-[]->(c); !(c)-[]->(a)") assert(edges.columns === Array("a", "b", "c")) - val res = edges.select("a.id", "b.id", "c.id") - .collect().toSet - assert(res === Set( - 
Row(0L, 1L, 0L), - Row(1L, 0L, 1L), - Row(1L, 2L, 3L) - )) + val res = edges + .select("a.id", "b.id", "c.id") + .collect() + .toSet + assert(res === Set(Row(0L, 1L, 0L), Row(1L, 0L, 1L), Row(1L, 2L, 3L))) } test("three connected vertices not in a triangle") { - val fof = g.find("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)") + val fof = g + .find("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)") .select("u.id", "v.id", "w.id") - .collect().toSet + .collect() + .toSet - compareResultToExpected(fof, Set( - Row(1L, 0L, 1L), - Row(0L, 1L, 0L), - Row(1L, 2L, 3L) - )) + compareResultToExpected(fof, Set(Row(1L, 0L, 1L), Row(0L, 1L, 0L), Row(1L, 2L, 3L))) } /* ========== 1 named vertex grounding a negated term to non-negated terms =============== */ @@ -337,35 +319,38 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val edges = g.find("(a)-[]->(b); !(b)-[]->(c)") assert(edges.columns === Array("a", "b", "c")) - val res = edges.select("a.id", "b.id", "c.id") - .collect().toSet - compareResultToExpected(res, Set( - Row(0L, 1L, 1L), - Row(0L, 1L, 3L), - Row(1L, 0L, 0L), - Row(1L, 0L, 2L), - Row(1L, 0L, 3L), - Row(1L, 2L, 1L), - Row(1L, 2L, 2L), - Row(2L, 3L, 0L), - Row(2L, 3L, 1L), - Row(2L, 3L, 2L), - Row(2L, 3L, 3L), - Row(2L, 0L, 0L), - Row(2L, 0L, 2L), - Row(2L, 0L, 3L) - )) + val res = edges + .select("a.id", "b.id", "c.id") + .collect() + .toSet + compareResultToExpected( + res, + Set( + Row(0L, 1L, 1L), + Row(0L, 1L, 3L), + Row(1L, 0L, 0L), + Row(1L, 0L, 2L), + Row(1L, 0L, 3L), + Row(1L, 2L, 1L), + Row(1L, 2L, 2L), + Row(2L, 3L, 0L), + Row(2L, 3L, 1L), + Row(2L, 3L, 2L), + Row(2L, 3L, 3L), + Row(2L, 0L, 0L), + Row(2L, 0L, 2L), + Row(2L, 0L, 3L))) } test("a->b where b has no out edges") { val edges = g.find("(a)-[]->(b); !(b)-[]->()") assert(edges.columns === Array("a", "b")) - val res = edges.select("a.id", "b.id") - .collect().toSet - compareResultToExpected(res, Set( - Row(2L, 3L) - )) + val res = edges + 
.select("a.id", "b.id") + .collect() + .toSet + compareResultToExpected(res, Set(Row(2L, 3L))) } /* ========== 0 named vertices grounding a negated term to non-negated terms =============== */ @@ -374,29 +359,30 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val edgePairs = g.find("(a)-[]->(b); !(c)-[]->(d)") assert(edgePairs.columns === Array("a", "b", "c", "d")) - val res = edgePairs.select("a.id", "b.id", "c.id", "d.id") - .collect().toSet - val expected = e.select(col("src").alias("a"), col("dst").alias("b")) + val res = edgePairs + .select("a.id", "b.id", "c.id", "d.id") + .collect() + .toSet + val expected = e + .select(col("src").alias("a"), col("dst").alias("b")) .crossJoin(noEdges.select(col("src").alias("c"), col("dst").alias("d"))) .select("a", "b", "c", "d") - .collect().toSet + .collect() + .toSet compareResultToExpected(res, expected) - assert(expected.size === noEdges.count() * e.count()) // make sure there are no duplicates + assert(expected.size === noEdges.count() * e.count()) // make sure there are no duplicates } test("a->b, c where c has no out edges") { val triplets = g.find("(a)-[]->(b); !(c)-[]->()") assert(triplets.columns === Array("a", "b", "c")) - val res = triplets.select("a.id", "b.id", "c.id") - .collect().toSet - val expected = Set( - Row(0L, 1L, 3L), - Row(1L, 0L, 3L), - Row(1L, 2L, 3L), - Row(2L, 3L, 3L), - Row(2L, 0L, 3L) - ) + val res = triplets + .select("a.id", "b.id", "c.id") + .collect() + .toSet + val expected = + Set(Row(0L, 1L, 3L), Row(1L, 0L, 3L), Row(1L, 2L, 3L), Row(2L, 3L, 3L), Row(2L, 0L, 3L)) compareResultToExpected(res, expected) } @@ -408,8 +394,10 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val seq = g.find("(a)-[]->(b); !(b)-[]->(c); !(c)-[]->(a)") assert(seq.columns === Array("a", "b", "c")) - val res = seq.select("a.id", "b.id", "c.id") - .collect().toSet + val res = seq + .select("a.id", "b.id", "c.id") + .collect() + .toSet val 
expected = Set( Row(0L, 1L, 3L), Row(1L, 0L, 2L), @@ -421,8 +409,7 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { Row(2L, 3L, 3L), Row(2L, 0L, 0L), Row(2L, 0L, 2L), - Row(2L, 0L, 3L) - ) + Row(2L, 0L, 3L)) compareResultToExpected(res, expected) } @@ -430,9 +417,11 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { val edgePairs = g.find("(a)-[]->(b); !(a)-[]->(c); !(c)-[]->(d)") assert(edgePairs.columns === Array("a", "b", "c", "d")) - val res = edgePairs.select("a.id", "b.id", "c.id", "d.id") - .where("a.id = 0 AND a.id != b.id") // check subset for brevity - .collect().toSet + val res = edgePairs + .select("a.id", "b.id", "c.id", "d.id") + .where("a.id = 0 AND a.id != b.id") // check subset for brevity + .collect() + .toSet val expected = Set( Row(0L, 1L, 0L, 0L), Row(0L, 1L, 0L, 2L), @@ -442,26 +431,31 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { Row(0L, 1L, 3L, 0L), Row(0L, 1L, 3L, 1L), Row(0L, 1L, 3L, 2L), - Row(0L, 1L, 3L, 3L) - ) + Row(0L, 1L, 3L, 3L)) compareResultToExpected(res, expected) } /* ============================== 0 non-negated terms ============================== */ test("query without non-negated terms, with one named vertex") { - val res = g.find("!(v)-[]->()") + val res = g + .find("!(v)-[]->()") .select("v.id") - .collect().toSet + .collect() + .toSet compareResultToExpected(res, Set(Row(3L))) } test("query without non-negated terms, with two named vertices") { - val res = g.find("!(u)-[]->(v)") + val res = g + .find("!(u)-[]->(v)") .select("u.id", "v.id") - .collect().toSet - val expected = noEdges.select(col("src").alias("u"), col("dst").alias("v")) - .collect().toSet + .collect() + .toSet + val expected = noEdges + .select(col("src").alias("u"), col("dst").alias("v")) + .collect() + .toSet compareResultToExpected(res, expected) } @@ -469,7 +463,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { 
test("named edges") { // edges whose destination leads nowhere - val edges = g.find("()-[e]->(v); !(v)-[]->()") + val edges = g + .find("()-[e]->(v); !(v)-[]->()") .select("e.src", "e.dst") val res = edges.collect().toSet compareResultToExpected(res, Set(Row(2L, 3L))) @@ -490,7 +485,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("find column order") { - val fof = g.find("(u)-[e]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)") + val fof = g + .find("(u)-[e]->(v); (v)-[]->(w); !(u)-[]->(w); !(w)-[]->(u)") .where("u.id != v.id AND v.id != w.id AND u.id != w.id") assert(fof.columns === Array("u", "e", "v", "w")) compareResultToExpected( @@ -539,20 +535,18 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { /* ============================= More complex use case examples ============================== */ test("triangles via post-hoc filter") { - val triangles = g.find("(a)-[]->(b); (b)-[]->(c); (d)-[]->(e)") + val triangles = g + .find("(a)-[]->(b); (b)-[]->(c); (d)-[]->(e)") .where("c.id = d.id AND e.id = a.id") .select("a.id", "b.id", "c.id") val res = triangles.collect().toSet - compareResultToExpected(res, Set( - Row(0L, 1L, 2L), - Row(2L, 0L, 1L), - Row(1L, 2L, 0L) - )) + compareResultToExpected(res, Set(Row(0L, 1L, 2L), Row(2L, 0L, 1L), Row(1L, 2L, 0L))) } test("stateful predicates via UDFs") { - val chain4 = g.find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)") + val chain4 = g + .find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)") .where("a.id != b.id AND b.id != c.id AND c.id != a.id") // Using DataFrame operations, but not really operating in a stateful manner @@ -562,7 +556,9 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { .reduce(_ + _) >= 2) assert(chainWith2Friends.count() === 4) - chainWith2Friends.select("ab.relationship", "bc.relationship", "cd.relationship").collect() + chainWith2Friends + .select("ab.relationship", "bc.relationship", 
"cd.relationship") + .collect() .foreach { case Row(ab: String, bc: String, cd: String) => val numFriends = Seq(ab, bc, cd).map(r => if (r == "friend") 1 else 0).sum assert(numFriends >= 2) @@ -572,8 +568,8 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { def sumFriends(cnt: Column, relationship: Column): Column = { when(relationship === "friend", cnt + 1).otherwise(cnt) } - val condition = Seq("ab", "bc", "cd"). - foldLeft(lit(0))((cnt, e) => sumFriends(cnt, col(e)("relationship"))) + val condition = + Seq("ab", "bc", "cd").foldLeft(lit(0))((cnt, e) => sumFriends(cnt, col(e)("relationship"))) val chainWith2Friends2 = chain4.where(condition >= 2) compareResultToExpected(chainWith2Friends.collect().toSet, chainWith2Friends2.collect().toSet) @@ -604,5 +600,5 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { } assert(joins.isEmpty, s"joins was non-empty: ${joins.map(_.toString()).mkString("; ")}") } - */ + */ } diff --git a/src/test/scala/org/graphframes/SparkFunSuite.scala b/src/test/scala/org/graphframes/SparkFunSuite.scala index aaf1fb536..e53c1111a 100644 --- a/src/test/scala/org/graphframes/SparkFunSuite.scala +++ b/src/test/scala/org/graphframes/SparkFunSuite.scala @@ -27,9 +27,8 @@ private[graphframes] abstract class SparkFunSuite extends FunSuite with Logging /** * Log the suite name and the test name before and after each test. * - * Subclasses should never override this method. If they wish to run - * custom code before and after each test, they should mix in the - * {{org.scalatest.BeforeAndAfter}} trait instead. + * Subclasses should never override this method. If they wish to run custom code before and + * after each test, they should mix in the {{org.scalatest.BeforeAndAfter}} trait instead. 
*/ final protected override def withFixture(test: NoArgTest): Outcome = { val testName = test.text diff --git a/src/test/scala/org/graphframes/TestUtils.scala b/src/test/scala/org/graphframes/TestUtils.scala index c05a4f49d..629984da7 100644 --- a/src/test/scala/org/graphframes/TestUtils.scala +++ b/src/test/scala/org/graphframes/TestUtils.scala @@ -15,16 +15,19 @@ object TestUtils { case Some(m) => (m.group(1).toInt, m.group(2).toInt) case None => - throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + - s" version string, but it could not find the major and minor version numbers.") + throw new IllegalArgumentException( + s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major and minor version numbers.") } } /** * Check whether the given schema contains a column of the required data type. * - * @param colName column name - * @param dataType required column data type + * @param colName + * column name + * @param dataType + * required column data type */ def checkColumnType( schema: StructType, @@ -33,7 +36,8 @@ object TestUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.equals(dataType), + require( + actualDataType.equals(dataType), s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } @@ -47,10 +51,9 @@ object TestUtils { } /** - * Test validity of both GraphFrames. - * Also ensure that the GraphFrames match: - * - vertex column schema match - * - `before` columns are a subset of the `after` columns, and schema match + * Test validity of both GraphFrames. 
Also ensure that the GraphFrames match: + * - vertex column schema match + * - `before` columns are a subset of the `after` columns, and schema match */ def testSchemaInvariants(before: GraphFrame, after: GraphFrame): Unit = { testSchemaInvariant(before) @@ -76,23 +79,24 @@ object TestUtils { if (!afterVNames.contains(f.name)) { throw new Exception(s"vertex error: ${f.name} should be in ${afterVNames.mkString(", ")}") } - assert(before.vertices.schema(f.name) == after.vertices.schema(f.name), + assert( + before.vertices.schema(f.name) == after.vertices.schema(f.name), s"${before.vertices.schema} != ${after.vertices.schema}") } for (f <- before.edges.schema.iterator) { val a = before.edges.schema(f.name) val b = after.edges.schema(f.name) - assert(a.dataType == b.dataType, + assert( + a.dataType == b.dataType, s"${before.edges.schema} not a subset of ${after.edges.schema}") } } /** - * Test validity of both GraphFrames. - * Also ensure that the GraphFrames match: - * - vertex column schema match - * - `before` columns are a subset of the `after` columns, and schema match + * Test validity of both GraphFrames. 
Also ensure that the GraphFrames match: + * - vertex column schema match + * - `before` columns are a subset of the `after` columns, and schema match */ def testSchemaInvariants(before: GraphFrame, afterVertices: DataFrame): Unit = { testSchemaInvariants(before, GraphFrame(afterVertices, before.edges)) diff --git a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala index 5f7e3382f..e2a7af557 100644 --- a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala +++ b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala @@ -23,11 +23,10 @@ import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} import org.graphframes.examples.BeliefPropagation._ import org.graphframes.examples.Graphs.gridIsingModel - class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("BP using GraphX and GraphFrame aggregateMessages") { - val n = 3 // graph is n x n + val n = 3 // graph is n x n val numIter = 5 // iterations of BP // Create graphical model g. @@ -41,7 +40,8 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte // Check beliefs. def checkResults(v: DataFrame): Unit = { v.select("belief").collect().foreach { case Row(belief: Double) => - assert(belief >= 0.0 && belief <= 1.0, + assert( + belief >= 0.0 && belief <= 1.0, s"Expected belief to be probability in [0,1], but found $belief") } } @@ -51,10 +51,12 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte // Compare beliefs. 
val gxBeliefs = gxResults.vertices.select("id", "belief") val gfBeliefs = gfResults.vertices.select("id", "belief") - gxBeliefs.join(gfBeliefs, "id") + gxBeliefs + .join(gfBeliefs, "id") .select(gxBeliefs("belief").as("gxBelief"), gfBeliefs("belief").as("gfBelief")) - .collect().foreach { case Row(gxBelief: Double, gfBelief: Double) => + .collect() + .foreach { case Row(gxBelief: Double, gfBelief: Double) => assert(math.abs(gxBelief - gfBelief) <= 1e-6) - } + } } } diff --git a/src/test/scala/org/graphframes/examples/GraphsSuite.scala b/src/test/scala/org/graphframes/examples/GraphsSuite.scala index 9165cfaa9..5d3076424 100644 --- a/src/test/scala/org/graphframes/examples/GraphsSuite.scala +++ b/src/test/scala/org/graphframes/examples/GraphsSuite.scala @@ -19,7 +19,6 @@ package org.graphframes.examples import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} - class GraphsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("empty graph") { diff --git a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala index 6f199df9e..3ea405ec4 100644 --- a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala +++ b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.functions._ import org.graphframes.examples.Graphs import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} - class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("aggregateMessages") { @@ -41,18 +40,28 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte .agg(sum(AM.msg).as("summedAges")) // Convert agg to a Map. 
import org.apache.spark.sql._ - val aggMap: Map[String, Long] = agg.select("id", "summedAges").collect().map { - case Row(id: String, s: Long) => id -> s - }.toMap + val aggMap: Map[String, Long] = agg + .select("id", "summedAges") + .collect() + .map { case Row(id: String, s: Long) => + id -> s + } + .toMap // Compute the truth via brute force for comparison. val trueAgg: Map[String, Int] = { - val user2age = g.vertices.select("id", "age").collect().map { - case Row(id: String, age: Int) => id -> age - }.toMap + val user2age = g.vertices + .select("id", "age") + .collect() + .map { case Row(id: String, age: Int) => + id -> age + } + .toMap val a = mutable.HashMap.empty[String, Int] g.edges.select("src", "dst", "relationship").collect().foreach { case Row(src: String, dst: String, relationship: String) => - a.put(src, a.getOrElse(src, 0) + user2age(dst) + (if (relationship == "friend") 1 else 0)) + a.put( + src, + a.getOrElse(src, 0) + user2age(dst) + (if (relationship == "friend") 1 else 0)) a.put(dst, a.getOrElse(dst, 0) + user2age(src)) } a.toMap @@ -69,9 +78,13 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte .sendToDst(msgToDst2) .agg("sum(MSG) AS `summedAges`") // Convert agg2 to a Map. - val agg2Map: Map[String, Long] = agg2.select("id", "summedAges").collect().map { - case Row(id: String, s: Long) => id -> s - }.toMap + val agg2Map: Map[String, Long] = agg2 + .select("id", "summedAges") + .collect() + .map { case Row(id: String, s: Long) => + id -> s + } + .toMap // Compare to the true values. 
agg2Map.keys.foreach { case user => assert(agg2Map(user) === trueAgg(user), s"Failure on user $user") diff --git a/src/test/scala/org/graphframes/lib/BFSSuite.scala b/src/test/scala/org/graphframes/lib/BFSSuite.scala index d5982f9ca..081c3de99 100644 --- a/src/test/scala/org/graphframes/lib/BFSSuite.scala +++ b/src/test/scala/org/graphframes/lib/BFSSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.{DataFrame, Row} import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite} - class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { // First graph uses String IDs @@ -44,36 +43,27 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { | | d <- e --> f Also, self-edge for f */ - v = spark.createDataFrame(List( - ("a", "f"), - ("b", "f"), - ("c", "m"), - ("d", "f"), - ("e", "m"), - ("f", "m") - )).toDF("id", "gender") - e = spark.createDataFrame(List( - ("a", "b", "friend"), - ("b", "c", "follow"), - ("c", "b", "follow"), - ("f", "c", "follow"), - ("e", "f", "follow"), - ("e", "d", "friend"), - ("d", "a", "friend"), - ("f", "f", "self") - )).toDF("src", "dst", "relationship") + v = spark + .createDataFrame( + List(("a", "f"), ("b", "f"), ("c", "m"), ("d", "f"), ("e", "m"), ("f", "m"))) + .toDF("id", "gender") + e = spark + .createDataFrame( + List( + ("a", "b", "friend"), + ("b", "c", "follow"), + ("c", "b", "follow"), + ("f", "c", "follow"), + ("e", "f", "follow"), + ("e", "d", "friend"), + ("d", "a", "friend"), + ("f", "f", "self"))) + .toDF("src", "dst", "relationship") g = GraphFrame(v, e) - v2 = spark.createDataFrame(List( - (0L, "f"), - (1L, "m"), - (2L, "m"), - (3L, "f"))).toDF("id", "gender") - e2 = spark.createDataFrame(List( - (0L, 1L), - (1L, 2L), - (2L, 3L), - (2L, 0L))).toDF("src", "dst") + v2 = + spark.createDataFrame(List((0L, "f"), (1L, "m"), (2L, "m"), (3L, "f"))).toDF("id", "gender") + e2 = spark.createDataFrame(List((0L, 1L), (1L, 2L), (2L, 3L), (2L, 0L))).toDF("src", "dst") g2 = 
GraphFrame(v2, e2) } @@ -121,26 +111,32 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("maxPathLength: length 1") { val paths = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(1).run() assert(paths.count() === 1) - val paths0 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(0).run() + val paths0 = + g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "f").maxPathLength(0).run() assert(paths0.count() === 0) } test("maxPathLength: length > 1") { val paths = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(3).run() assert(paths.count() === 2) - val paths0 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(2).run() + val paths0 = + g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").maxPathLength(2).run() assert(paths0.count() === 0) } test("edge filter") { - val paths1 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b") + val paths1 = g.bfs + .fromExpr(col("id") === "e") + .toExpr(col("id") === "b") .edgeFilter(col("src") =!= "d") .run() assert(paths1.count() === 1) paths1.select("e0.dst").collect().foreach { case Row(id: String) => assert(id === "f") } - val paths2 = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b") + val paths2 = g.bfs + .fromExpr(col("id") === "e") + .toExpr(col("id") === "b") .edgeFilter(col("relationship") === "friend") .run() assert(paths2.count() === 1) @@ -150,7 +146,9 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("string expressions") { - val paths1 = g.bfs.fromExpr("id = 'e'").toExpr("id = 'b'") + val paths1 = g.bfs + .fromExpr("id = 'e'") + .toExpr("id = 'b'") .edgeFilter("src != 'd'") .run() assert(paths1.count() === 1) diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 96acd0bde..f55ae4edd 100644 --- 
a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -49,18 +49,20 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon } test("single vertex") { - val v = spark.createDataFrame(List( - (0L, "a", "b"))).toDF("id", "vattr", "gender") + val v = spark.createDataFrame(List((0L, "a", "b"))).toDF("id", "vattr", "gender") // Create an empty dataframe with the proper columns. - val e = spark.createDataFrame(List((0L, 0L, 1L))).toDF("src", "dst", "test") + val e = spark + .createDataFrame(List((0L, 0L, 1L))) + .toDF("src", "dst", "test") .filter("src > 10") val g = GraphFrame(v, e) val comps = ConnectedComponents.run(g) TestUtils.testSchemaInvariants(g, comps) TestUtils.checkColumnType(comps.schema, "component", DataTypes.LongType) assert(comps.count() === 1) - assert(comps.select("id", "component", "vattr", "gender").collect() - === Seq(Row(0L, 0L, "a", "b"))) + assert( + comps.select("id", "component", "vattr", "gender").collect() + === Seq(Row(0L, 0L, "a", "b"))) } test("disconnected vertices") { @@ -74,11 +76,8 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon } test("two connected vertices") { - val v = spark.createDataFrame(List( - (0L, "a0", "b0"), - (1L, "a1", "b1"))).toDF("id", "A", "B") - val e = spark.createDataFrame(List( - (0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") + val v = spark.createDataFrame(List((0L, "a0", "b0"), (1L, "a1", "b1"))).toDF("id", "A", "B") + val e = spark.createDataFrame(List((0L, 1L, "a01", "b01"))).toDF("src", "dst", "A", "B") val g = GraphFrame(v, e) val comps = g.connectedComponents.run() TestUtils.testSchemaInvariants(g, comps) @@ -113,10 +112,9 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon test("two components") { val vertices = spark.range(6L).toDF(ID) - val edges = spark.createDataFrame(Seq( - (0L, 1L), (1L, 2L), (2L, 0L), - (3L, 4L), 
(4L, 5L), (5L, 3L) - )).toDF(SRC, DST) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) val g = GraphFrame(vertices, edges) val components = g.connectedComponents.run() val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L)) @@ -125,10 +123,15 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon test("one component, differing edge directions") { val vertices = spark.range(5L).toDF(ID) - val edges = spark.createDataFrame(Seq( - // 0 -> 4 -> 3 <- 2 -> 1 - (0L, 4L), (4L, 3L), (2L, 3L), (2L, 1L) - )).toDF(SRC, DST) + val edges = spark + .createDataFrame( + Seq( + // 0 -> 4 -> 3 <- 2 -> 1 + (0L, 4L), + (4L, 3L), + (2L, 3L), + (2L, 1L))) + .toDF(SRC, DST) val g = GraphFrame(vertices, edges) val components = g.connectedComponents.run() val expected = Set((0L to 4L).toSet) @@ -137,10 +140,9 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon test("two components and two dangling vertices") { val vertices = spark.range(8L).toDF(ID) - val edges = spark.createDataFrame(Seq( - (0L, 1L), (1L, 2L), (2L, 0L), - (3L, 4L), (4L, 5L), (5L, 3L) - )).toDF(SRC, DST) + val edges = spark + .createDataFrame(Seq((0L, 1L), (1L, 2L), (2L, 0L), (3L, 4L), (4L, 5L), (5L, 3L))) + .toDF(SRC, DST) val g = GraphFrame(vertices, edges) val components = g.connectedComponents.run() val expected = Set(Set(0L, 1L, 2L), Set(3L, 4L, 5L), Set(6L), Set(7L)) @@ -151,7 +153,7 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon val friends = Graphs.friends val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) for ((algorithm, broadcastThreshold) <- - Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) { + Seq(("graphx", 1000000), ("graphframes", 100000), ("graphframes", 1))) { val components = friends.connectedComponents .setAlgorithm(algorithm) .setBroadcastThreshold(broadcastThreshold) @@ -176,15 +178,17 @@ class 
ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g")) val cc = new ConnectedComponents(friends) - assert(cc.getCheckpointInterval === 2, + assert( + cc.getCheckpointInterval === 2, s"Default checkpoint interval should be 2, but got ${cc.getCheckpointInterval}.") val checkpointDir = sc.getCheckpointDir assert(checkpointDir.nonEmpty) sc.setCheckpointDir(null) - withClue("Should throw an IOException if sc.getCheckpointDir is empty " + - "and checkpointInterval is positive.") { + withClue( + "Should throw an IOException if sc.getCheckpointDir is empty " + + "and checkpointInterval is positive.") { intercept[IOException] { cc.run() } @@ -198,19 +202,22 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon val components0 = cc.setCheckpointInterval(0).run() assertComponents(components0, expected) - assert(!isFromCheckpoint(components0), + assert( + !isFromCheckpoint(components0), "The result shouldn't depend on checkpoint data if checkpointing is disabled.") sc.setCheckpointDir(checkpointDir.get) val components1 = cc.setCheckpointInterval(1).run() assertComponents(components1, expected) - assert(isFromCheckpoint(components1), + assert( + isFromCheckpoint(components1), "The result should depend on checkpoint data if checkpoint interval is 1.") val components10 = cc.setCheckpointInterval(10).run() assertComponents(components10, expected) - assert(!isFromCheckpoint(components10), + assert( + !isFromCheckpoint(components10), "The result shouldn't depend on checkpoint data if converged before first checkpoint.") } @@ -226,7 +233,10 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon val cc = friends.connectedComponents assert(cc.getIntermediateStorageLevel === StorageLevel.MEMORY_AND_DISK) - for (storageLevel <- Seq(StorageLevel.DISK_ONLY, StorageLevel.MEMORY_ONLY, StorageLevel.NONE)) { + for (storageLevel <- Seq( + 
StorageLevel.DISK_ONLY, + StorageLevel.MEMORY_ONLY, + StorageLevel.NONE)) { // TODO: it is not trivial to confirm the actual storage level used val components = cc .setIntermediateStorageLevel(storageLevel) @@ -243,22 +253,27 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon } } - private def assertComponents[T: ClassTag:TypeTag]( + private def assertComponents[T: ClassTag: TypeTag]( actual: DataFrame, expected: Set[Set[T]]): Unit = { import actual.sparkSession.implicits._ // note: not using agg + collect_list because collect_list is not available in 1.6.2 w/o hive - val actualComponents = actual.select("component", "id").as[(Long, T)].rdd + val actualComponents = actual + .select("component", "id") + .as[(Long, T)] + .rdd .groupByKey() .values .map(_.toSeq) .collect() .map { ids => val idSet = ids.toSet - assert(idSet.size === ids.size, + assert( + idSet.size === ids.size, s"Found duplicated component assignment in [${ids.mkString(",")}].") idSet - }.toSet + } + .toSet assert(actualComponents === expected) } } diff --git a/src/test/scala/org/graphframes/lib/PageRankSuite.scala b/src/test/scala/org/graphframes/lib/PageRankSuite.scala index eca810ec7..40343f3f5 100644 --- a/src/test/scala/org/graphframes/lib/PageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/PageRankSuite.scala @@ -33,7 +33,8 @@ class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext { val errorTol = 1.0e-5 val pr = g.pageRank .resetProbability(resetProb) - .tol(errorTol).run() + .tol(errorTol) + .run() TestUtils.testSchemaInvariants(g, pr) TestUtils.checkColumnType(pr.vertices.schema, "pagerank", DataTypes.DoubleType) TestUtils.checkColumnType(pr.edges.schema, "weight", DataTypes.DoubleType) @@ -43,7 +44,8 @@ class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext { val results = Graphs.friends.pageRank.resetProbability(0.15).maxIter(10).sourceId("a").run() val gRank = results.vertices.filter(col("id") === 
"g").select("pagerank").first().getDouble(0) - assert(gRank === 0.0, + assert( + gRank === 0.0, s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got $gRank.") } } diff --git a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala index 1fc2d0daa..0fa75ce5b 100644 --- a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala @@ -70,8 +70,9 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes // In Spark <2.4, sourceIds must be smaller than Int.MaxValue, // which might not be the case for LONG_ID in graph.indexedVertices. - if (Version.valueOf(org.apache.spark.SPARK_VERSION) - .greaterThanOrEqualTo(Version.valueOf("2.4.0"))) { + if (Version + .valueOf(org.apache.spark.SPARK_VERSION) + .greaterThanOrEqualTo(Version.valueOf("2.4.0"))) { test("friends graph with parallel personalized PageRank") { val g = Graphs.friends val resetProb = 0.15 @@ -89,14 +90,17 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes .filter { row: Row => vertexIds.size != row.getAs[SparseVector](0).size } - assert(prInvalid.size === 0, + assert( + prInvalid.size === 0, s"found ${prInvalid.size} entries with invalid number of returned personalized pagerank vector") val gRank = pr.vertices .filter(col("id") === "g") .select("pageranks") - .first().getAs[SparseVector](0) - assert(gRank.numNonzeros === 0, + .first() + .getAs[SparseVector](0) + assert( + gRank.numNonzeros === 0, s"User g (Gabby) doesn't connect with a. 
So its pagerank should be 0 but we got ${gRank.numNonzeros}.") } } diff --git a/src/test/scala/org/graphframes/lib/PregelSuite.scala b/src/test/scala/org/graphframes/lib/PregelSuite.scala index e6af54440..dc6cc35c8 100644 --- a/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -35,8 +35,7 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { (2L, 0L), (3L, 4L), // 3 has no in-links (4L, 0L), - (4L, 2L) - ).toDF("src", "dst").cache() + (4L, 2L)).toDF("src", "dst").cache() val vertices = GraphFrame.fromEdges(edges).outDegrees.cache() val numVertices = vertices.count() val graph = GraphFrame(vertices, edges) @@ -45,14 +44,19 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { // NOTE: This version doesn't handle nodes with no out-links. val ranks = graph.pregel .setMaxIter(5) - .withVertexColumn("rank", lit(1.0 / numVertices), + .withVertexColumn( + "rank", + lit(1.0 / numVertices), coalesce(Pregel.msg, lit(0.0)) * (1.0 - alpha) + alpha / numVertices) .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) .aggMsgs(sum(Pregel.msg)) .run() - val result = ranks.sort(col("id")) - .select("rank").as[Double].collect() + val result = ranks + .sort(col("id")) + .select("rank") + .as[Double] + .collect() assert(result.sum === 1.0 +- 1e-6) val expected = Seq(0.245, 0.224, 0.303, 0.03, 0.197) result.zip(expected).foreach { case (r, e) => @@ -63,20 +67,20 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("chain propagation") { val n = 5 val verDF = (1 to n).toDF("id").repartition(3) - val edgeDF = (1 until n).map(x => (x, x + 1)) - .toDF("src", "dst").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) val graph = GraphFrame(verDF, edgeDF) val resultDF = graph.pregel .setMaxIter(n - 1) - .withVertexColumn("value", + .withVertexColumn( + "value", when(col("id") === lit(1), 
lit(1)).otherwise(lit(0)), - when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")) - ) - .sendMsgToDst( - when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value")) - ) + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) .aggMsgs(max(Pregel.msg)) .run() @@ -86,20 +90,20 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("reverse chain propagation") { val n = 5 val verDF = (1 to n).toDF("id").repartition(3) - val edgeDF = (1 until n).map(x => (x + 1, x)) - .toDF("src", "dst").repartition(3) + val edgeDF = (1 until n) + .map(x => (x + 1, x)) + .toDF("src", "dst") + .repartition(3) val graph = GraphFrame(verDF, edgeDF) val resultDF = graph.pregel .setMaxIter(n - 1) - .withVertexColumn("value", + .withVertexColumn( + "value", when(col("id") === lit(1), lit(1)).otherwise(lit(0)), - when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value")) - ) - .sendMsgToSrc( - when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value")) - ) + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToSrc(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.dst("value"))) .aggMsgs(max(Pregel.msg)) .run() diff --git a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala index 890ed5e14..b7a2df5e7 100644 --- a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala +++ b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.types.DataTypes import org.graphframes.{GraphFrame, GraphFrameTestSparkContext, SparkFunSuite, TestUtils} import org.graphframes.examples.Graphs - class SVDPlusPlusSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Test SVD++ with mean square error on training set") { @@ -33,16 +32,21 @@ class SVDPlusPlusSuite extends SparkFunSuite 
with GraphFrameTestSparkContext { val v2 = g.svdPlusPlus.maxIter(2).run() TestUtils.testSchemaInvariants(g, v2) Seq(SVDPlusPlus.COLUMN1, SVDPlusPlus.COLUMN2).foreach { case c => - TestUtils.checkColumnType(v2.schema, c, + TestUtils.checkColumnType( + v2.schema, + c, DataTypes.createArrayType(DataTypes.DoubleType, false)) } Seq(SVDPlusPlus.COLUMN3, SVDPlusPlus.COLUMN4).foreach { case c => TestUtils.checkColumnType(v2.schema, c, DataTypes.DoubleType) } - val err = v2.select(GraphFrame.ID, SVDPlusPlus.COLUMN4).rdd.map { - case Row(vid: Long, vd: Double) => + val err = v2 + .select(GraphFrame.ID, SVDPlusPlus.COLUMN4) + .rdd + .map { case Row(vid: Long, vd: Double) => if (vid % 2 == 1) vd else 0.0 - }.reduce(_ + _) / g.edges.count() + } + .reduce(_ + _) / g.edges.count() assert(err <= svdppErr) } } diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index fc0262d1b..69899ed9a 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -25,26 +25,34 @@ import org.graphframes._ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Simple test") { - val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)).flatMap { - case e => Seq(e, e.swap) - } .map { case (src, dst) => (src.toLong, dst.toLong) } + val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)) + .flatMap { case e => + Seq(e, e.swap) + } + .map { case (src, dst) => (src.toLong, dst.toLong) } val edges = spark.createDataFrame(edgeSeq).toDF("src", "dst") val graph = GraphFrame.fromEdges(edges) // Ground truth val shortestPaths = Set( - (1, Map(1 -> 0, 4 -> 2)), (2, Map(1 -> 1, 4 -> 2)), (3, Map(1 -> 2, 4 -> 1)), - (4, Map(1 -> 2, 4 -> 0)), (5, Map(1 -> 1, 4 -> 1)), (6, Map(1 -> 3, 4 -> 1))) + (1, Map(1 -> 0, 4 -> 2)), + (2, Map(1 -> 1, 4 -> 2)), + (3, Map(1 -> 2, 4 -> 1)), + (4, Map(1 
-> 2, 4 -> 0)), + (5, Map(1 -> 1, 4 -> 1)), + (6, Map(1 -> 3, 4 -> 1))) val landmarks = Seq(1, 4).map(_.toLong) val v2 = graph.shortestPaths.landmarks(landmarks).run() TestUtils.testSchemaInvariants(graph, v2) - TestUtils.checkColumnType(v2.schema, "distances", + TestUtils.checkColumnType( + v2.schema, + "distances", DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, false)) val newVs = v2.select("id", "distances").collect().toSeq - val results = newVs.map { - case Row(id: Long, spMap: Map[Long, Int] @unchecked) => (id, spMap) + val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) => + (id, spMap) } assert(results.toSet === shortestPaths) } @@ -52,13 +60,21 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("friends graph") { val friends = examples.Graphs.friends val v = friends.shortestPaths.landmarks(Seq("a", "d")).run() - val expected = Set[(String, Map[String, Int])](("a", Map("a" -> 0, "d" -> 2)), ("b", Map.empty), - ("c", Map.empty), ("d", Map("a" -> 1, "d" -> 0)), ("e", Map("a" -> 2, "d" -> 1)), - ("f", Map.empty), ("g", Map.empty)) - val results = v.select("id", "distances").collect().map { - case Row(id: String, spMap: Map[String, Int] @unchecked) => + val expected = Set[(String, Map[String, Int])]( + ("a", Map("a" -> 0, "d" -> 2)), + ("b", Map.empty), + ("c", Map.empty), + ("d", Map("a" -> 1, "d" -> 0)), + ("e", Map("a" -> 2, "d" -> 1)), + ("f", Map.empty), + ("g", Map.empty)) + val results = v + .select("id", "distances") + .collect() + .map { case Row(id: String, spMap: Map[String, Int] @unchecked) => (id, spMap) - }.toSet + } + .toSet assert(results === expected) } diff --git a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala index 65ac324a0..c549cc5ca 100644 --- a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala +++ 
b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala @@ -24,19 +24,15 @@ import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite, T class StronglyConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Island Strongly Connected Components") { - val vertices = spark.createDataFrame(Seq( - (1L, "a"), - (2L, "b"), - (3L, "c"), - (4L, "d"), - (5L, "e"))).toDF("id", "value") + val vertices = spark + .createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "d"), (5L, "e"))) + .toDF("id", "value") val edges = spark.createDataFrame(Seq.empty[(Long, Long)]).toDF("src", "dst") val graph = GraphFrame(vertices, edges) val c = graph.stronglyConnectedComponents.maxIter(5).run() TestUtils.testSchemaInvariants(graph, c) TestUtils.checkColumnType(c.schema, "component", DataTypes.LongType) - for (Row(id: Long, component: Long, _) - <- c.select("id", "component", "value").collect()) { + for (Row(id: Long, component: Long, _) <- c.select("id", "component", "value").collect()) { assert(id === component) } } diff --git a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala index 2ac2127a5..9ad04eeca 100644 --- a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala +++ b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala @@ -26,19 +26,24 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Count a single triangle") { val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") - val vertices = spark.createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c"))) + val vertices = spark + .createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c"))) .toDF("id", "a") val g = GraphFrame(vertices, edges) val v2 = g.triangleCount.run() TestUtils.testSchemaInvariants(g, v2) TestUtils.checkColumnType(v2.schema, "count", DataTypes.LongType) v2.select("id", "count", "a") - 
.collect().foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) } + .collect() + .foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) } } test("Count two triangles") { - val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> -1L, -1L -> -2L, -2L -> 0L)).toDF("src", "dst") + val edges = spark + .createDataFrame( + Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Array(0L -> -1L, -1L -> -2L, -2L -> 0L)) + .toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) => @@ -67,8 +72,11 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("Count a single triangle with duplicate edges") { - val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") + val edges = spark + .createDataFrame( + Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Array(0L -> 1L, 1L -> 2L, 2L -> 0L)) + .toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) => diff --git a/src/test/scala/org/graphframes/pattern/PatternSuite.scala b/src/test/scala/org/graphframes/pattern/PatternSuite.scala index 7e1e23928..5a7cdd05a 100644 --- a/src/test/scala/org/graphframes/pattern/PatternSuite.scala +++ b/src/test/scala/org/graphframes/pattern/PatternSuite.scala @@ -24,27 +24,32 @@ class PatternSuite extends SparkFunSuite { test("good parses") { assert(Pattern.parse("(abc)") === Seq(NamedVertex("abc"))) - assert(Pattern.parse("(u)-[e]->(v)") === - Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("v")))) - - assert(Pattern.parse("()-[]->(v)") === - Seq(AnonymousEdge(AnonymousVertex, NamedVertex("v")))) - - assert(Pattern.parse("()-[e]->()") === - Seq(NamedEdge("e", AnonymousVertex, AnonymousVertex))) - - assert(Pattern.parse("(u)-[e]->(u)") === - 
Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("u")))) - - assert(Pattern.parse("(u); ()-[]->(v)") === - Seq(NamedVertex("u"), AnonymousEdge(AnonymousVertex, NamedVertex("v")))) - - assert(Pattern.parse("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)") === - Seq( - AnonymousEdge(NamedVertex("u"), NamedVertex("v")), - AnonymousEdge(NamedVertex("v"), NamedVertex("w")), - Negation( - AnonymousEdge(NamedVertex("u"), NamedVertex("w"))))) + assert( + Pattern.parse("(u)-[e]->(v)") === + Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("v")))) + + assert( + Pattern.parse("()-[]->(v)") === + Seq(AnonymousEdge(AnonymousVertex, NamedVertex("v")))) + + assert( + Pattern.parse("()-[e]->()") === + Seq(NamedEdge("e", AnonymousVertex, AnonymousVertex))) + + assert( + Pattern.parse("(u)-[e]->(u)") === + Seq(NamedEdge("e", NamedVertex("u"), NamedVertex("u")))) + + assert( + Pattern.parse("(u); ()-[]->(v)") === + Seq(NamedVertex("u"), AnonymousEdge(AnonymousVertex, NamedVertex("v")))) + + assert( + Pattern.parse("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)") === + Seq( + AnonymousEdge(NamedVertex("u"), NamedVertex("v")), + AnonymousEdge(NamedVertex("v"), NamedVertex("w")), + Negation(AnonymousEdge(NamedVertex("u"), NamedVertex("w"))))) } test("bad parses") { @@ -133,13 +138,9 @@ class PatternSuite extends SparkFunSuite { "(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)", Seq.empty[String]) - testFindNamedVerticesOnlyInNegatedTerms( - "(u)-[]->(v); (v)-[]->(w)", - Seq.empty[String]) + testFindNamedVerticesOnlyInNegatedTerms("(u)-[]->(v); (v)-[]->(w)", Seq.empty[String]) - testFindNamedVerticesOnlyInNegatedTerms( - "!(u)-[]->(v)", - Seq("u", "v")) + testFindNamedVerticesOnlyInNegatedTerms("!(u)-[]->(v)", Seq("u", "v")) testFindNamedVerticesOnlyInNegatedTerms( "(u)-[]->(v); (v)-[]->(w); !(a)-[]->(b); !(v)-[]->(c)", @@ -153,13 +154,9 @@ class PatternSuite extends SparkFunSuite { } } - testFindNamedElementsInOrder( - "(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)", - Seq("u", "v", "w")) + 
testFindNamedElementsInOrder("(u)-[]->(v); (v)-[]->(w); !(u)-[]->(w)", Seq("u", "v", "w")) - testFindNamedElementsInOrder( - "(u)-[]->(v); ()-[vw]->()", - Seq("u", "v", "vw")) + testFindNamedElementsInOrder("(u)-[]->(v); ()-[vw]->()", Seq("u", "v", "vw")) testFindNamedElementsInOrder( "(u)-[uv]->(v); (v)-[vw]->(w); !(u)-[]->(w); (x)",