From 531bd9961bfea081cfc6899fb73ed130529a465d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 12 Apr 2025 11:31:48 +0200 Subject: [PATCH 1/5] Add -Xlint and scalafix && fix all --- .pre-commit-config.yaml | 8 ++ .scalafix.conf | 8 ++ build.sbt | 24 +++++- .../sql/graphframes/GraphFramesConnect.scala | 6 +- .../graphframes/GraphFramesConnectUtils.scala | 21 ++++-- project/plugins.sbt | 3 + .../scala/org/graphframes/GraphFrame.scala | 28 ++++--- .../org/graphframes/GraphFramePythonAPI.scala | 5 +- .../scala/org/graphframes/LDBCUtils.scala | 6 +- src/main/scala/org/graphframes/Logging.scala | 3 +- .../examples/BeliefPropagation.scala | 16 ++-- .../org/graphframes/examples/Graphs.scala | 13 ++-- .../graphframes/lib/AggregateMessages.scala | 14 ++-- src/main/scala/org/graphframes/lib/BFS.scala | 13 ++-- .../graphframes/lib/ConnectedComponents.scala | 17 ++--- .../graphframes/lib/GraphXConversions.scala | 14 ++-- .../graphframes/lib/LabelPropagation.scala | 1 - .../scala/org/graphframes/lib/PageRank.scala | 1 - .../lib/ParallelPersonalizedPageRank.scala | 6 +- .../scala/org/graphframes/lib/Pregel.scala | 19 ++--- .../org/graphframes/lib/SVDPlusPlus.scala | 9 ++- .../org/graphframes/lib/ShortestPaths.scala | 28 ++++--- .../lib/StronglyConnectedComponents.scala | 1 - .../org/graphframes/lib/TriangleCount.scala | 13 +++- .../org/graphframes/pattern/patterns.scala | 25 ++++--- .../org/graphframes/GraphFrameSuite.scala | 25 ++++--- .../GraphFrameTestSparkContext.scala | 17 +++-- .../org/graphframes/PatternMatchSuite.scala | 7 +- .../scala/org/graphframes/SparkFunSuite.scala | 3 +- .../scala/org/graphframes/TestUtils.scala | 6 +- .../examples/BeliefPropagationSuite.scala | 7 +- .../graphframes/examples/GraphsSuite.scala | 3 +- .../org/graphframes/ldbc/TestLDBCCases.scala | 17 +++-- .../lib/AggregateMessagesSuite.scala | 8 +- .../scala/org/graphframes/lib/BFSSuite.scala | 8 +- .../lib/ConnectedComponentsSuite.scala | 18 ++--- .../lib/LabelPropagationSuite.scala | 5 +- .../org/graphframes/lib/PageRankSuite.scala | 5 +- .../ParallelPersonalizedPageRankSuite.scala | 73 +++++++++---------- .../org/graphframes/lib/PregelSuite.scala | 4 +- .../graphframes/lib/SVDPlusPlusSuite.scala | 6 +- .../graphframes/lib/ShortestPathsSuite.scala | 4 +- .../StronglyConnectedComponentsSuite.scala | 6 +- .../graphframes/lib/TriangleCountSuite.scala | 6 +- .../graphframes/pattern/PatternSuite.scala | 3 +- 45 files changed, 314 insertions(+), 219 deletions(-) create mode 100644 .scalafix.conf diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 726a85df1..9b89cba9a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,3 +25,11 @@ repos: language: system types: [scala] pass_filenames: false + + - id: scalafix + name: scalafix + entry: build/sbt scalafixAll + language: system + types: [scala] + pass_filenames: false + diff --git a/.scalafix.conf b/.scalafix.conf new file mode 100644 index 000000000..82bd389eb --- /dev/null +++ b/.scalafix.conf @@ -0,0 +1,8 @@ +rules = [ + RemoveUnused + DisableSyntax + ProcedureSyntax + RedundantSyntax + OrganizeImports + ExplicitResultTypes +] diff --git a/build.sbt b/build.sbt index b41af000b..468f283d1 100644 --- a/build.sbt +++ b/build.sbt @@ -24,6 +24,10 @@ ThisBuild / scalaVersion := scalaVer ThisBuild / organization := "org.graphframes" ThisBuild / crossScalaVersions := Seq("2.12.18", "2.13.8") +// Scalafix configuration +ThisBuild / semanticdbEnabled := true +ThisBuild / semanticdbVersion := scalafixSemanticdb.revision + lazy val commonSetting = Seq( libraryDependencies ++= Seq( "org.apache.spark" %% "spark-graphx" % sparkVer % "provided" cross CrossVersion.for3Use2_13, @@ -55,7 +59,22 @@ lazy val commonSetting = Seq( "--add-opens=java.base/java.nio=ALL-UNNAMED", "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", "--add-opens=java.base/java.util=ALL-UNNAMED"), - credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials")) + credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"), + + // Scalafix + scalacOptions ++= Seq( + "-Xlint", // to enforce code quality checks + if (scalaVersion.value.startsWith("2.12")) { + // fail on warning + "-Xfatal-warnings" + } else { + "-Werror" // the same but in 2.13 + }, + // scalastyle related things + if (scalaVersion.value.startsWith("2.12")) + "-Ywarn-unused-import" + else + "-Wunused:imports")) lazy val root = (project in file(".")) .settings( @@ -108,5 +127,4 @@ lazy val connect = (project in file("graphframes-connect")) case x => val oldStrategy = (assembly / assemblyMergeStrategy).value oldStrategy(x) - } - ) + }) diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala index a8088a84b..49a58f918 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala @@ -1,12 +1,10 @@ package org.apache.spark.sql.graphframes -import org.graphframes.connect.proto.GraphFramesAPI - +import com.google.protobuf import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.plugin.RelationPlugin - -import com.google.protobuf +import org.graphframes.connect.proto.GraphFramesAPI class GraphFramesConnect extends RelationPlugin { override def transform( diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index 0e20c29ed..4d9e8d2a0 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -2,16 +2,23 @@ // Same about Column helper object. package org.apache.spark.sql.graphframes -import scala.jdk.CollectionConverters._ -import org.graphframes.{GraphFrame, GraphFramesUnreachableException} -import org.graphframes.connect.proto.{ColumnOrExpression, GraphFramesAPI, StringOrLongID} +import com.google.protobuf.ByteString +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.lit +import org.graphframes.GraphFrame +import org.graphframes.GraphFramesUnreachableException +import org.graphframes.connect.proto.ColumnOrExpression import org.graphframes.connect.proto.ColumnOrExpression.ColOrExprCase +import org.graphframes.connect.proto.GraphFramesAPI import org.graphframes.connect.proto.GraphFramesAPI.MethodCase +import org.graphframes.connect.proto.StringOrLongID import org.graphframes.connect.proto.StringOrLongID.IdCase -import org.apache.spark.sql.{Column, DataFrame, Dataset} -import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.functions.{col, expr, lit} -import com.google.protobuf.ByteString + +import scala.jdk.CollectionConverters._ object GraphFramesConnectUtils { private[graphframes] def parseColumnOrExpression( diff --git a/project/plugins.sbt b/project/plugins.sbt index 4a02df846..86469d44e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -13,3 +13,6 @@ addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4") // Protobuf things needed for the Spark Connect addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7") libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.10" + +// Scalafix +addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.2") diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index 74b905c0c..135734456 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -17,19 +17,27 @@ package org.graphframes -import java.util.Random - -import scala.reflect.runtime.universe.TypeTag - -import org.graphframes.lib._ -import org.graphframes.pattern._ - -import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.graphx.Edge +import org.apache.spark.graphx.Graph import org.apache.spark.ml.clustering.PowerIterationClustering import org.apache.spark.sql._ -import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, expr, lit, max, monotonically_increasing_id, struct, udf} +import org.apache.spark.sql.functions.array +import org.apache.spark.sql.functions.broadcast +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.graphframes.lib._ +import org.graphframes.pattern._ + +import java.util.Random +import scala.reflect.runtime.universe.TypeTag /** * A representation of a graph using `DataFrame`s. @@ -686,8 +694,6 @@ object GraphFrame extends Serializable with Logging { joinCol: String, hubs: Set[T], logPrefix: String): DataFrame = { - val spark = a.sparkSession - import spark.implicits._ if (hubs.isEmpty) { // No skew. Do regular join. a.join(b, joinCol) diff --git a/src/main/scala/org/graphframes/GraphFramePythonAPI.scala b/src/main/scala/org/graphframes/GraphFramePythonAPI.scala index c9ada1788..bfe30de53 100644 --- a/src/main/scala/org/graphframes/GraphFramePythonAPI.scala +++ b/src/main/scala/org/graphframes/GraphFramePythonAPI.scala @@ -1,13 +1,12 @@ package org.graphframes import org.apache.spark.sql.DataFrame - -import org.graphframes.lib.AggregateMessages import org.graphframes.examples.Graphs +import org.graphframes.lib.AggregateMessages private[graphframes] class GraphFramePythonAPI { - def createGraph(v: DataFrame, e: DataFrame) = GraphFrame(v, e) + def createGraph(v: DataFrame, e: DataFrame): GraphFrame = GraphFrame(v, e) val ID: String = GraphFrame.ID val SRC: String = GraphFrame.SRC diff --git a/src/main/scala/org/graphframes/LDBCUtils.scala b/src/main/scala/org/graphframes/LDBCUtils.scala index 9a83304ce..b628cf810 100644 --- a/src/main/scala/org/graphframes/LDBCUtils.scala +++ b/src/main/scala/org/graphframes/LDBCUtils.scala @@ -2,12 +2,8 @@ package org.graphframes.examples import java.net.URL import java.nio.file._ -import java.util.Properties - import scala.sys.process._ -import org.graphframes.GraphFrame - object LDBCUtils { private val LDBC_URL_PREFIX = "https://datasets.ldbcouncil.org/graphalytics/" private val bufferSize = 8192 // 8Kb @@ -37,7 +33,7 @@ object LDBCUtils { private def checkZSTD(): Unit = { try { - s"zstd --version".! + "zstd --version".! } catch { case e: Exception => throw new RuntimeException( diff --git a/src/main/scala/org/graphframes/Logging.scala b/src/main/scala/org/graphframes/Logging.scala index 93ad0a96a..594178ba3 100644 --- a/src/main/scala/org/graphframes/Logging.scala +++ b/src/main/scala/org/graphframes/Logging.scala @@ -17,7 +17,8 @@ package org.graphframes -import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.Logger +import org.slf4j.LoggerFactory // This needs to be accessible to org.apache.spark.graphx.lib.backport private[org] trait Logging { diff --git a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala index df97ba85a..ab9a25814 100644 --- a/src/main/scala/org/graphframes/examples/BeliefPropagation.scala +++ b/src/main/scala/org/graphframes/examples/BeliefPropagation.scala @@ -17,11 +17,17 @@ package org.graphframes.examples -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.graphx.{Graph, VertexRDD, Edge => GXEdge} -import org.apache.spark.sql.{Column, Row, SparkSession} -import org.apache.spark.sql.functions.{col, lit, sum, udf, when} - +import org.apache.spark.graphx.Graph +import org.apache.spark.graphx.VertexRDD +import org.apache.spark.graphx.{Edge => GXEdge} +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.sum +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions.when import org.graphframes.GraphFrame import org.graphframes.examples.Graphs.gridIsingModel import org.graphframes.lib.AggregateMessages diff --git a/src/main/scala/org/graphframes/examples/Graphs.scala b/src/main/scala/org/graphframes/examples/Graphs.scala index 5066fc2ce..847f209d2 100644 --- a/src/main/scala/org/graphframes/examples/Graphs.scala +++ b/src/main/scala/org/graphframes/examples/Graphs.scala @@ -17,15 +17,16 @@ package org.graphframes.examples -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.{col, lit, randn, udf} - +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.randn +import org.apache.spark.sql.functions.udf import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ +import scala.reflect.runtime.universe.TypeTag + class Graphs private[graphframes] () { // Note: this cannot be values: we are creating and destroying spark contexts during the tests, // and turning these into vals means we would hold onto a potentially destroyed spark context. @@ -103,7 +104,7 @@ class Graphs private[graphframes] () { v1 <- n until (2 * n) v2 <- n until (2 * n) } yield (v1.toLong, v2.toLong, s"$v1-$v2") - val edges = edges1 ++ edges2 :+ (0L, n.toLong, s"0-$n") + val edges = edges1 ++ edges2 ++ Seq((0L, n.toLong, s"0-$n")) val vertices = (0 until (2 * n)).map { v => (v.toLong, s"$v", v) } val e = spark.createDataFrame(edges).toDF("src", "dst", "e_attr1") val v = spark.createDataFrame(vertices).toDF("id", "v_attr1", "v_attr2") diff --git a/src/main/scala/org/graphframes/lib/AggregateMessages.scala b/src/main/scala/org/graphframes/lib/AggregateMessages.scala index c3f721b21..34f42f28d 100644 --- a/src/main/scala/org/graphframes/lib/AggregateMessages.scala +++ b/src/main/scala/org/graphframes/lib/AggregateMessages.scala @@ -17,10 +17,12 @@ package org.graphframes.lib -import org.apache.spark.sql.functions.{col, expr} -import org.apache.spark.sql.{Column, DataFrame} - -import org.graphframes.{GraphFrame, Logging} +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr +import org.graphframes.GraphFrame +import org.graphframes.Logging /** * This is a primitive for implementing graph algorithms. This method aggregates messages from the @@ -101,8 +103,8 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame) def agg(aggCol: Column): DataFrame = { require( msgToSrc.nonEmpty || msgToDst.nonEmpty, - s"To run GraphFrame.aggregateMessages," + - s" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().") + "To run GraphFrame.aggregateMessages," + + " messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().") val triplets = g.triplets val sentMsgsToSrc = msgToSrc.map { msg => val msgsToSrc = diff --git a/src/main/scala/org/graphframes/lib/BFS.scala b/src/main/scala/org/graphframes/lib/BFS.scala index 802d8f63c..a8f11270c 100644 --- a/src/main/scala/org/graphframes/lib/BFS.scala +++ b/src/main/scala/org/graphframes/lib/BFS.scala @@ -17,12 +17,15 @@ package org.graphframes.lib +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.functions.{col, expr} -import org.apache.spark.sql.{Column, DataFrame, Row} - -import org.graphframes.{GraphFrame, Logging} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr +import org.graphframes.GraphFrame import org.graphframes.GraphFrame.nestAsCol +import org.graphframes.Logging /** * Breadth-first search (BFS) @@ -193,7 +196,7 @@ private object BFS extends Logging with Serializable { // TODO: Avoid crossing paths; i.e., touch each vertex at most once. val previousVertexChecks = Range(1, iter + 1) .map(i => paths(s"v$i.id") =!= paths(nextVertex + ".id")) - .foldLeft(paths(s"from.id") =!= paths(nextVertex + ".id"))((c1, c2) => c1 && c2) + .foldLeft(paths("from.id") =!= paths(nextVertex + ".id"))((c1, c2) => c1 && c2) paths = paths.filter(previousVertexChecks) } // Check if done by applying toExpr to column nextVertex diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index dd5f59647..7b1cb6f1d 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -17,19 +17,20 @@ package org.graphframes.lib -import java.io.IOException -import java.math.BigDecimal -import java.util.UUID - import org.apache.hadoop.fs.Path - -import org.graphframes.{GraphFrame, Logging} -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel +import org.graphframes.GraphFrame +import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice +import java.io.IOException +import java.math.BigDecimal +import java.util.UUID + /** * Connected Components algorithm. * @@ -44,8 +45,6 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) with Logging with WithAlgorithmChoice { - import org.graphframes.lib.ConnectedComponents._ - private var broadcastThreshold: Int = 1000000 setAlgorithm(ALGO_GRAPHFRAMES) diff --git a/src/main/scala/org/graphframes/lib/GraphXConversions.scala b/src/main/scala/org/graphframes/lib/GraphXConversions.scala index dce8006f0..f8da1f091 100644 --- a/src/main/scala/org/graphframes/lib/GraphXConversions.scala +++ b/src/main/scala/org/graphframes/lib/GraphXConversions.scala @@ -17,14 +17,16 @@ package org.graphframes.lib -import scala.reflect.runtime.universe._ - import org.apache.spark.graphx.Graph -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.graphframes.GraphFrame +import org.graphframes.NoSuchVertexException -import org.graphframes.{NoSuchVertexException, GraphFrame} +import scala.reflect.runtime.universe._ /** * Convenience functions to map GraphX graphs to GraphFrames, checking for the types expected by @@ -189,7 +191,7 @@ private[graphframes] object GraphXConversions { .take(1) if (longIdRow.isEmpty) { throw new NoSuchVertexException( - s"GraphFrame algorithm given vertex ID which does not exist" + + "GraphFrame algorithm given vertex ID which does not exist" + s" in Graph. Vertex ID $vertexId not contained in $graph") } // TODO(tjh): could do more informative message diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala index 877d7345b..00e20de7d 100644 --- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -19,7 +19,6 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame - import org.graphframes.GraphFrame /** diff --git a/src/main/scala/org/graphframes/lib/PageRank.scala b/src/main/scala/org/graphframes/lib/PageRank.scala index b5bcfa53a..44e78b3e5 100644 --- a/src/main/scala/org/graphframes/lib/PageRank.scala +++ b/src/main/scala/org/graphframes/lib/PageRank.scala @@ -18,7 +18,6 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} - import org.graphframes.GraphFrame /** diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 4f4ecfa41..22057b2fc 100644 --- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -18,7 +18,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} -import org.graphframes.{GraphFrame, Logging} +import org.graphframes.GraphFrame /** * Parallel Personalized PageRank algorithm implementation. @@ -77,8 +77,8 @@ class ParallelPersonalizedPageRank private[graphframes] (private val graph: Grap } def run(): GraphFrame = { - require(maxIter != None, s"Max number of iterations maxIter() must be provided") - require(srcIds.nonEmpty, s"Source vertices Ids sourceIds() must be provided") + require(maxIter != None, "Max number of iterations maxIter() must be provided") + require(srcIds.nonEmpty, "Source vertices Ids sourceIds() must be provided") ParallelPersonalizedPageRank.run(graph, maxIter.get, resetProb.get, srcIds) } } diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index 8f8abd238..ed416fe0e 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -17,15 +17,18 @@ package org.graphframes.lib -import java.io.IOException - +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.array +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.struct import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ -import org.apache.spark.sql.{Column, DataFrame} -import org.apache.spark.sql.functions.{array, col, explode, struct} - -import scala.util.control.Breaks.{breakable, break} +import java.io.IOException +import scala.util.control.Breaks.break +import scala.util.control.Breaks.breakable /** * Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. @@ -83,11 +86,9 @@ class Pregel(val graph: GraphFrame) { private var checkpointInterval = 2 private var earlyStopping = false - private var sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)] + private val sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)] private var aggMsgsCol: Column = null - private val CHECKPOINT_NAME_PREFIX = "pregel" - /** Sets the max number of iterations (default: 10). */ def setMaxIter(value: Int): this.type = { maxIter = value diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index 35b18a742..e23b4661b 100644 --- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -17,10 +17,11 @@ package org.graphframes.lib -import org.apache.spark.graphx.{Edge, lib => graphxlib} -import org.apache.spark.sql.{DataFrame, Row} - -import org.graphframes.{GraphFrame, Logging} +import org.apache.spark.graphx.Edge +import org.apache.spark.graphx.{lib => graphxlib} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.graphframes.GraphFrame /** * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 6e8358f3b..ab11da8dd 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -17,20 +17,30 @@ package org.graphframes.lib -import java.util - -import scala.jdk.CollectionConverters._ - import org.apache.spark.graphx.{lib => graphxlib} -import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.api.java.UDF1 -import org.apache.spark.sql.functions.{col, udf, map, lit, when, map_zip_with, reduce, map_values, transform_values, collect_list} -import org.apache.spark.sql.types.{IntegerType, MapType} - +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.collect_list +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.map +import org.apache.spark.sql.functions.map_values +import org.apache.spark.sql.functions.map_zip_with +import org.apache.spark.sql.functions.reduce +import org.apache.spark.sql.functions.transform_values +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions.when +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.MapType import org.graphframes.GraphFrame +import org.graphframes.GraphFrame.quote import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice -import org.graphframes.GraphFrame.quote + +import java.util +import scala.jdk.CollectionConverters._ /** * Computes shortest paths from every vertex to the given set of landmark vertices. Note that this diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index e8d9ecde8..41b2a2d8e 100644 --- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -19,7 +19,6 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame - import org.graphframes.GraphFrame /** diff --git a/src/main/scala/org/graphframes/lib/TriangleCount.scala b/src/main/scala/org/graphframes/lib/TriangleCount.scala index 740a1b7cc..a3bec9402 100644 --- a/src/main/scala/org/graphframes/lib/TriangleCount.scala +++ b/src/main/scala/org/graphframes/lib/TriangleCount.scala @@ -18,10 +18,17 @@ package org.graphframes.lib import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{array, col, explode, when} - +import org.apache.spark.sql.functions.array +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.when import org.graphframes.GraphFrame -import org.graphframes.GraphFrame.{DST, ID, LONG_DST, LONG_SRC, SRC, quote} +import org.graphframes.GraphFrame.DST +import org.graphframes.GraphFrame.ID +import org.graphframes.GraphFrame.LONG_DST +import org.graphframes.GraphFrame.LONG_SRC +import org.graphframes.GraphFrame.SRC +import org.graphframes.GraphFrame.quote /** * Computes the number of triangles passing through each vertex. diff --git a/src/main/scala/org/graphframes/pattern/patterns.scala b/src/main/scala/org/graphframes/pattern/patterns.scala index 4f9d0f8cd..308577a71 100644 --- a/src/main/scala/org/graphframes/pattern/patterns.scala +++ b/src/main/scala/org/graphframes/pattern/patterns.scala @@ -17,9 +17,11 @@ package org.graphframes.pattern +import org.graphframes.GraphFramesUnreachableException +import org.graphframes.InvalidParseException + import scala.collection.mutable import scala.util.parsing.combinator._ -import org.graphframes.{GraphFramesUnreachableException, InvalidParseException} /** * Parser for graph patterns for motif finding. Copied from GraphFrames with minor modification. @@ -79,7 +81,7 @@ private[graphframes] object Pattern { if (edgeNames.contains(name)) { throw new InvalidParseException( s"Motif reused name '$name' for both a vertex and " + - s"an edge, which is not allowed.") + "an edge, which is not allowed.") } vertexNames += name case AnonymousVertex => // pass @@ -89,12 +91,12 @@ private[graphframes] object Pattern { if (vertexNames.contains(name)) { throw new InvalidParseException( s"Motif reused name '$name' for both a vertex and " + - s"an edge, which is not allowed.") + "an edge, which is not allowed.") } if (edgeNames.contains(name)) { throw new InvalidParseException( s"Motif reused name '$name' for multiple edges, " + - s"which is not allowed.") + "which is not allowed.") } edgeNames += name addVertex(src) @@ -109,20 +111,21 @@ private[graphframes] object Pattern { edge match { case NamedEdge(name, src, dst) => throw new InvalidParseException( - s"Motif finding does not support negated named " + + "Motif finding does not support negated named " + s"edges, but the given pattern contained: !($src)-[$name]->($dst)") case AnonymousEdge(AnonymousVertex, AnonymousVertex) => - throw new InvalidParseException(s"Motif finding does not support completely " + - s"anonymous negated edges !()-[]->(). Users can check for 0 edges in the graph " + - s"using the edges DataFrame.") + throw new InvalidParseException( + "Motif finding does not support completely " + + "anonymous negated edges !()-[]->(). Users can check for 0 edges in the graph " + + "using the edges DataFrame.") case e @ AnonymousEdge(_, _) => addEdge(e) } case AnonymousEdge(AnonymousVertex, AnonymousVertex) => throw new InvalidParseException( - s"Motif finding does not support completely " + - s"anonymous edges ()-[]->(). Users can check for the existence of edges in the " + - s"graph using the edges DataFrame.") + "Motif finding does not support completely " + + "anonymous edges ()-[]->(). Users can check for the existence of edges in the " + + "graph using the edges DataFrame.") case e @ AnonymousEdge(_, _) => addEdge(e) case e @ NamedEdge(_, _, _) => diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index 59f710ae5..bf2fcfdec 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -17,28 +17,33 @@ package org.graphframes -import java.io.File - -import org.graphframes.examples.Graphs - +import com.google.common.io.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path -import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.graphx.Edge +import org.apache.spark.graphx.Graph import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel +import org.graphframes.examples.Graphs -import com.google.common.io.Files +import java.io.File class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { import GraphFrame._ var vertices: DataFrame = _ - val localVertices = Map(1L -> "A", 2L -> "B", 3L -> "C") - val localEdges = Map((1L, 2L) -> "love", (2L, 1L) -> "hate", (2L, 3L) -> "follow") + val localVertices: Map[Long, String] = Map(1L -> "A", 2L -> "B", 3L -> "C") + val localEdges: Map[(Long, Long), String] = + Map((1L, 2L) -> "love", (2L, 1L) -> "hate", (2L, 3L) -> "follow") var edges: DataFrame = _ var tempDir: File = _ diff --git a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala index 99c8b56c6..4b1434806 100644 --- a/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala +++ b/src/test/scala/org/graphframes/GraphFrameTestSparkContext.scala @@ -17,14 +17,17 @@ package org.graphframes +import org.apache.commons.io.FileUtils +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SQLImplicits +import org.apache.spark.sql.SparkSession +import org.scalatest.BeforeAndAfterAll +import org.scalatest.Suite + import java.io.File import java.nio.file.Files -import org.apache.commons.io.FileUtils -import org.scalatest.{BeforeAndAfterAll, Suite} -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} - trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite => @transient var spark: SparkSession = _ @transient var sc: SparkContext = _ @@ -53,7 +56,7 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite => } } - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() spark = SparkSession @@ -73,7 +76,7 @@ trait GraphFrameTestSparkContext extends BeforeAndAfterAll { self: Suite => sparkMinorVersion = verMinor } - override def afterAll() { + override def afterAll(): Unit = { val checkpointDir = sc.getCheckpointDir if (spark != null) { spark.stop() diff --git a/src/test/scala/org/graphframes/PatternMatchSuite.scala b/src/test/scala/org/graphframes/PatternMatchSuite.scala index fc06d2eb6..79ef51045 100644 --- a/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -17,9 +17,12 @@ package org.graphframes -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.Column -import org.apache.spark.sql.functions.{col, lit, when} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.when /** * Cases to go through: diff --git a/src/test/scala/org/graphframes/SparkFunSuite.scala b/src/test/scala/org/graphframes/SparkFunSuite.scala index e53c1111a..4bf1db939 100644 --- a/src/test/scala/org/graphframes/SparkFunSuite.scala +++ b/src/test/scala/org/graphframes/SparkFunSuite.scala @@ -17,7 +17,8 @@ package org.graphframes -import org.scalatest.{FunSuite, Outcome} +import org.scalatest.FunSuite +import org.scalatest.Outcome /** * Base abstract class for all unit tests in Spark for handling common functionality. diff --git a/src/test/scala/org/graphframes/TestUtils.scala b/src/test/scala/org/graphframes/TestUtils.scala index 629984da7..8e4f9e932 100644 --- a/src/test/scala/org/graphframes/TestUtils.scala +++ b/src/test/scala/org/graphframes/TestUtils.scala @@ -1,8 +1,8 @@ package org.graphframes import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{DataType, StructType} - +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.StructType import org.graphframes.GraphFrame._ object TestUtils { @@ -17,7 +17,7 @@ object TestUtils { case None => throw new IllegalArgumentException( s"Spark tried to parse '$sparkVersion' as a Spark" + - s" version string, but it could not find the major and minor version numbers.") + " version string, but it could not find the major and minor version numbers.") } } diff --git a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala index e2a7af557..11c2d3bfb 100644 --- a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala +++ b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala @@ -17,9 +17,10 @@ package org.graphframes.examples -import org.apache.spark.sql.{DataFrame, Row} - -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite import org.graphframes.examples.BeliefPropagation._ import org.graphframes.examples.Graphs.gridIsingModel diff --git a/src/test/scala/org/graphframes/examples/GraphsSuite.scala b/src/test/scala/org/graphframes/examples/GraphsSuite.scala index 5d3076424..8e071065e 100644 --- a/src/test/scala/org/graphframes/examples/GraphsSuite.scala +++ b/src/test/scala/org/graphframes/examples/GraphsSuite.scala @@ -17,7 +17,8 @@ package org.graphframes.examples -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite class GraphsSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala b/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala index b51305faf..969599edf 100644 --- a/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala +++ b/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala @@ -1,14 +1,18 @@ package org.graphframes.ldbc -import org.graphframes.SparkFunSuite -import org.graphframes.GraphFrameTestSparkContext -import java.nio.file._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.abs +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.sum +import org.apache.spark.sql.types.LongType import org.graphframes.GraphFrame +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite import org.graphframes.examples.LDBCUtils + +import java.nio.file._ import java.util.Properties -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{col, lit, abs, sum} -import org.apache.spark.sql.types.LongType class TestLDBCCases extends SparkFunSuite with GraphFrameTestSparkContext { private val resourcesPath = Paths.get(getClass().getResource("/").toURI()) @@ -46,7 +50,6 @@ class TestLDBCCases extends SparkFunSuite with GraphFrameTestSparkContext { LDBCUtils.downloadLDBCIfNotExists(resourcesPath, LDBCUtils.TEST_BFS_UNDIRECTED) val caseRoot = resourcesPath.resolve(LDBCUtils.TEST_BFS_UNDIRECTED) - val edgesPath = caseRoot.resolve(s"${LDBCUtils.TEST_BFS_UNDIRECTED}.e") val expectedPath = caseRoot.resolve(s"${LDBCUtils.TEST_BFS_UNDIRECTED}-BFS") val expectedDistances = spark.read diff --git a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala index 3ea405ec4..945bb263b 100644 --- a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala +++ b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala @@ -17,12 +17,12 @@ package org.graphframes.lib -import scala.collection.mutable - import org.apache.spark.sql.functions._ - +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite import org.graphframes.examples.Graphs -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite} + +import scala.collection.mutable class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/BFSSuite.scala b/src/test/scala/org/graphframes/lib/BFSSuite.scala index 081c3de99..addaf2fc5 100644 --- a/src/test/scala/org/graphframes/lib/BFSSuite.scala +++ b/src/test/scala/org/graphframes/lib/BFSSuite.scala @@ -17,10 +17,12 @@ package org.graphframes.lib +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.{DataFrame, Row} - -import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite} +import org.graphframes.GraphFrame +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 3b3fcf4ea..9614d1e20 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -17,20 +17,20 @@ package org.graphframes.lib -import java.io.IOException - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types.DataTypes import org.apache.spark.storage.StorageLevel - -import org.graphframes._ import org.graphframes.GraphFrame._ +import org.graphframes._ import org.graphframes.examples.Graphs +import java.io.IOException +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("default params") { diff --git a/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala b/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala index df9ba7445..ed17bd754 100644 --- a/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala +++ b/src/test/scala/org/graphframes/lib/LabelPropagationSuite.scala @@ -18,8 +18,9 @@ package org.graphframes.lib import org.apache.spark.sql.types.DataTypes - -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils import org.graphframes.examples.Graphs class LabelPropagationSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/PageRankSuite.scala b/src/test/scala/org/graphframes/lib/PageRankSuite.scala index 40343f3f5..f94c0bca7 100644 --- a/src/test/scala/org/graphframes/lib/PageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/PageRankSuite.scala @@ -19,9 +19,10 @@ package org.graphframes.lib import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes - +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils import org.graphframes.examples.Graphs -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} class PageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala index 0fa75ce5b..29f9b02b7 100644 --- a/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala +++ b/src/test/scala/org/graphframes/lib/ParallelPersonalizedPageRankSuite.scala @@ -16,16 +16,15 @@ */ package org.graphframes.lib - -import com.github.zafarkhaja.semver.Version - -import org.apache.spark.ml.linalg.{SQLDataTypes, SparseVector} +import org.apache.spark.ml.linalg.SQLDataTypes +import org.apache.spark.ml.linalg.SparseVector import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes - +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils import org.graphframes.examples.Graphs -import org.graphframes.{GraphFrameTestSparkContext, SparkFunSuite, TestUtils} class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -68,40 +67,34 @@ class ParallelPersonalizedPageRankSuite extends SparkFunSuite with GraphFrameTes TestUtils.checkColumnType(pr.edges.schema, "weight", DataTypes.DoubleType) } - // In Spark <2.4, sourceIds must be smaller than Int.MaxValue, - // which might not be the case for LONG_ID in graph.indexedVertices. - if (Version - .valueOf(org.apache.spark.SPARK_VERSION) - .greaterThanOrEqualTo(Version.valueOf("2.4.0"))) { - test("friends graph with parallel personalized PageRank") { - val g = Graphs.friends - val resetProb = 0.15 - val maxIter = 10 - val vertexIds: Array[Any] = Array("a") - lazy val prc = g.parallelPersonalizedPageRank - .maxIter(maxIter) - .sourceIds(vertexIds) - .resetProbability(resetProb) - - val pr = prc.run() - val prInvalid = pr.vertices - .select("pageranks") - .collect() - .filter { row: Row => - vertexIds.size != row.getAs[SparseVector](0).size - } - assert( - prInvalid.size === 0, - s"found ${prInvalid.size} entries with invalid number of returned personalized pagerank vector") + test("friends graph with parallel personalized PageRank") { + val g = Graphs.friends + val resetProb = 0.15 + val maxIter = 10 + val vertexIds: Array[Any] = Array("a") + lazy val prc = g.parallelPersonalizedPageRank + .maxIter(maxIter) + .sourceIds(vertexIds) + .resetProbability(resetProb) - val gRank = pr.vertices - .filter(col("id") === "g") - .select("pageranks") - .first() - .getAs[SparseVector](0) - assert( - gRank.numNonzeros === 0, - s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got ${gRank.numNonzeros}.") - } + val pr = prc.run() + val prInvalid = pr.vertices + .select("pageranks") + .collect() + .filter { row: Row => + vertexIds.size != row.getAs[SparseVector](0).size + } + assert( + prInvalid.size === 0, + s"found ${prInvalid.size} entries with invalid number of returned personalized pagerank vector") + + val gRank = pr.vertices + .filter(col("id") === "g") + .select("pageranks") + .first() + .getAs[SparseVector](0) + assert( + gRank.numNonzeros === 0, + s"User g (Gabby) doesn't connect with a. So its pagerank should be 0 but we got ${gRank.numNonzeros}.") } } diff --git a/src/test/scala/org/graphframes/lib/PregelSuite.scala b/src/test/scala/org/graphframes/lib/PregelSuite.scala index 669172a56..342d8964d 100644 --- a/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -17,11 +17,9 @@ package org.graphframes.lib -import org.scalactic.Tolerance._ - import org.apache.spark.sql.functions._ - import org.graphframes._ +import org.scalactic.Tolerance._ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala index b7a2df5e7..83d86187e 100644 --- a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala +++ b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala @@ -19,8 +19,10 @@ package org.graphframes.lib import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes - -import org.graphframes.{GraphFrame, GraphFrameTestSparkContext, SparkFunSuite, TestUtils} +import org.graphframes.GraphFrame +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils import org.graphframes.examples.Graphs class SVDPlusPlusSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index 23f0f92ff..31fc0598b 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -17,10 +17,10 @@ package org.graphframes.lib -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes - import org.graphframes.GraphFrame.quote import org.graphframes._ diff --git a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala index c549cc5ca..ac37e3b75 100644 --- a/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/StronglyConnectedComponentsSuite.scala @@ -19,8 +19,10 @@ package org.graphframes.lib import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes - -import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite, TestUtils} +import org.graphframes.GraphFrame +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils class StronglyConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Island Strongly Connected Components") { diff --git a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala index 27501423f..0ca88ef7e 100644 --- a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala +++ b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala @@ -19,9 +19,11 @@ package org.graphframes.lib import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes - +import org.graphframes.GraphFrame import org.graphframes.GraphFrame.quote -import org.graphframes.{GraphFrameTestSparkContext, GraphFrame, SparkFunSuite, TestUtils} +import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.SparkFunSuite +import org.graphframes.TestUtils class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/src/test/scala/org/graphframes/pattern/PatternSuite.scala b/src/test/scala/org/graphframes/pattern/PatternSuite.scala index 5a7cdd05a..193d4e5dc 100644 --- a/src/test/scala/org/graphframes/pattern/PatternSuite.scala +++ b/src/test/scala/org/graphframes/pattern/PatternSuite.scala @@ -17,7 +17,8 @@ package org.graphframes.pattern -import org.graphframes.{InvalidParseException, SparkFunSuite} +import org.graphframes.InvalidParseException +import org.graphframes.SparkFunSuite class PatternSuite extends SparkFunSuite { From c035eda3c5deee7d3ba79ec96e3c10f4773ed8b5 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 12 Apr 2025 11:40:58 +0200 Subject: [PATCH 2/5] Fix semanticdb version && enforce scalastyle in CI --- .github/workflows/scala-ci.yml | 4 +++- build.sbt | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/scala-ci.yml b/.github/workflows/scala-ci.yml index d9234e9f7..26e944a79 100644 --- a/.github/workflows/scala-ci.yml +++ b/.github/workflows/scala-ci.yml @@ -36,7 +36,9 @@ jobs: ~/.ivy2/cache key: sbt-ivy-cache-spark-${{ matrix.spark-version}}-scala-${{ matrix.scala-version }}-java-${{ matrix.java-version }} - name: Check scalafmt - run: build/sbt root/scalafmtCheckAll + run: build/sbt scalafmtCheckAll + - name: Check scalastyle + run: build/sbt scalafixAll --check - name: Build and Test run: build/sbt -v ++${{ matrix.scala-version }} -Dspark.version=${{ matrix.spark-version }} coverage test coverageReport - uses: codecov/codecov-action@v3 diff --git a/build.sbt b/build.sbt index 468f283d1..8576db037 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ ThisBuild / crossScalaVersions := Seq("2.12.18", "2.13.8") // Scalafix configuration ThisBuild / semanticdbEnabled := true -ThisBuild / semanticdbVersion := scalafixSemanticdb.revision +ThisBuild / semanticdbVersion := "4.8.10" // The maximal version that supports both 2.13.8 and 2.12.18 lazy val commonSetting = Seq( libraryDependencies ++= Seq( From 5f0fc68f596c292619601f90a5625dd5634321e1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 12 Apr 2025 11:43:38 +0200 Subject: [PATCH 3/5] Fix --- .github/workflows/scala-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scala-ci.yml b/.github/workflows/scala-ci.yml index 26e944a79..f161c2c8f 100644 --- a/.github/workflows/scala-ci.yml +++ b/.github/workflows/scala-ci.yml @@ -38,7 +38,7 @@ jobs: - name: Check scalafmt run: build/sbt scalafmtCheckAll - name: Check scalastyle - run: build/sbt scalafixAll --check + run: build/sbt scalafixAll - name: Build and Test run: build/sbt -v ++${{ matrix.scala-version }} -Dspark.version=${{ matrix.spark-version }} coverage test coverageReport - uses: codecov/codecov-action@v3 From 262a4fc8b60024a5efe718c53829ce9691368589 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 12 Apr 2025 13:59:02 +0200 Subject: [PATCH 4/5] Fix incomplete pattern matching --- .../scala/org/graphframes/GraphFrame.scala | 8 +- .../scala/org/graphframes/lib/Pregel.scala | 2 +- .../org/graphframes/lib/SVDPlusPlus.scala | 2 + .../org/graphframes/lib/ShortestPaths.scala | 15 +- .../org/graphframes/GraphFrameSuite.scala | 128 ++++++++++++------ .../org/graphframes/PatternMatchSuite.scala | 8 +- .../examples/BeliefPropagationSuite.scala | 17 ++- .../lib/AggregateMessagesSuite.scala | 20 ++- .../scala/org/graphframes/lib/BFSSuite.scala | 25 ++-- .../graphframes/lib/SVDPlusPlusSuite.scala | 7 +- .../graphframes/lib/ShortestPathsSuite.scala | 18 ++- .../graphframes/lib/TriangleCountSuite.scala | 74 ++++++---- 12 files changed, 214 insertions(+), 110 deletions(-) diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index 135734456..9b7907488 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -197,20 +197,26 @@ class GraphFrame private ( if (hasIntegralIdType) { val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR)).rdd.map { case Row(id: Long, attr: Row) => (id, attr) + case _ => throw new GraphFramesUnreachableException() } val ee = edges .select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR)) .rdd - .map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) } + .map { + case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) + case _ => throw new GraphFramesUnreachableException() + } Graph(vv, ee) } else { // Compute Long vertex IDs val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map { case Row(long_id: Long, attr: Row) => (long_id, attr) + case _ => throw new GraphFramesUnreachableException() } val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map { case Row(long_src: Long, long_dst: Long, attr: Row) => Edge(long_src, long_dst, attr) + case _ => throw new GraphFramesUnreachableException() } Graph(vv, ee) } diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index ed416fe0e..ab34494d2 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -288,7 +288,7 @@ class Pregel(val graph: GraphFrame) { if (vertexUpdateColDF != null) { vertexUpdateColDF.unpersist() } - break + break() } val newAggMsgDF = msgDF diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index e23b4661b..21492cd6d 100644 --- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -22,6 +22,7 @@ import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.graphframes.GraphFrame +import org.graphframes.GraphFramesUnreachableException /** * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative @@ -117,6 +118,7 @@ object SVDPlusPlus { private def run(graph: GraphFrame, conf: graphxlib.SVDPlusPlus.Conf): (DataFrame, Double) = { val edges = graph.edges.select(GraphFrame.SRC, GraphFrame.DST, COLUMN_WEIGHT).rdd.map { case Row(src: Long, dst: Long, w: Double) => Edge(src, dst, w) + case _ => throw new GraphFramesUnreachableException() } val (gx, res) = graphxlib.SVDPlusPlus.run(edges, conf) val gf = GraphXConversions.fromGraphX( diff --git a/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/src/main/scala/org/graphframes/lib/ShortestPaths.scala index ab11da8dd..302be36eb 100644 --- a/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.MapType import org.graphframes.GraphFrame import org.graphframes.GraphFrame.quote +import org.graphframes.GraphFramesUnreachableException import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice @@ -79,6 +80,7 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) algorithm match { case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked) case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked) + case _ => throw new GraphFramesUnreachableException() } } } @@ -95,16 +97,21 @@ private object ShortestPaths extends Logging { val distanceCol: Column = if (graph.hasIntegralIdType) { // It seems there are no easy way to convert a sequence of pairs into a map val mapToLandmark = udf { distances: Seq[Row] => - distances.map { case Row(k: Long, v: Int) => - k -> v + distances.map { + case Row(k: Long, v: Int) => + k -> v + case _: Row => throw new GraphFramesUnreachableException() }.toMap } mapToLandmark(g.vertices(DISTANCE_ID)) } else { val func = new UDF1[Seq[Row], Map[Any, Int]] { override def call(t1: Seq[Row]): Map[Any, Int] = { - t1.map { case Row(k: Long, v: Int) => - longIdToLandmark(k) -> v + t1.map { + case Row(k: Long, v: Int) => + longIdToLandmark(k) -> v + + case _: Row => throw new GraphFramesUnreachableException() }.toMap } } diff --git a/src/test/scala/org/graphframes/GraphFrameSuite.scala b/src/test/scala/org/graphframes/GraphFrameSuite.scala index bf2fcfdec..057f6c492 100644 --- a/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -65,11 +65,15 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("construction from DataFrames") { val g = GraphFrame(vertices, edges) - g.vertices.collect().foreach { case Row(id: Long, name: String) => - assert(localVertices(id) === name) + g.vertices.collect().foreach { + case Row(id: Long, name: String) => + assert(localVertices(id) === name) + case _: Row => throw new GraphFramesUnreachableException() } - g.edges.collect().foreach { case Row(src: Long, dst: Long, action: String) => - assert(localEdges((src, dst)) === action) + g.edges.collect().foreach { + case Row(src: Long, dst: Long, action: String) => + assert(localEdges((src, dst)) === action) + case _: Row => throw new GraphFramesUnreachableException() } intercept[IllegalArgumentException] { val badVertices = vertices.select(col("id").as("uid"), col("name")) @@ -89,11 +93,15 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val g = GraphFrame( vertices.withColumnRenamed("name", "a.name"), edges.withColumnRenamed("action", "the.action")) - g.vertices.collect().foreach { case Row(id: Long, name: String) => - assert(localVertices(id) === name) + g.vertices.collect().foreach { + case Row(id: Long, name: String) => + assert(localVertices(id) === name) + case _: Row => throw new GraphFramesUnreachableException() } - g.edges.collect().foreach { case Row(src: Long, dst: Long, action: String) => - assert(localEdges((src, dst)) === action) + g.edges.collect().foreach { + case Row(src: Long, dst: Long, action: String) => + assert(localEdges((src, dst)) === action) + case _: Row => throw new GraphFramesUnreachableException() } g.pageRank.maxIter(10).run() } @@ -102,11 +110,15 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val g = GraphFrame( vertices.withColumnRenamed("name", "a `name`"), edges.withColumnRenamed("action", "the `action`")) - g.vertices.collect().foreach { case Row(id: Long, name: String) => - assert(localVertices(id) === name) + g.vertices.collect().foreach { + case Row(id: Long, name: String) => + assert(localVertices(id) === name) + case _: Row => throw new GraphFramesUnreachableException() } - g.edges.collect().foreach { case Row(src: Long, dst: Long, action: String) => - assert(localEdges((src, dst)) === action) + g.edges.collect().foreach { + case Row(src: Long, dst: Long, action: String) => + assert(localEdges((src, dst)) === action) + case _: Row => throw new GraphFramesUnreachableException() } g.pageRank.maxIter(10).run() } @@ -120,8 +132,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val idsFromEdgesSet = g.edges .select("src", "dst") .rdd - .flatMap { case Row(src: Long, dst: Long) => - Seq(src, dst) + .flatMap { + case Row(src: Long, dst: Long) => + Seq(src, dst) + case _: Row => throw new GraphFramesUnreachableException() } .collect() .toSet @@ -129,35 +143,45 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { } test("construction from GraphX") { - val vv: RDD[(Long, String)] = vertices.rdd.map { case Row(id: Long, name: String) => - (id, name) + val vv: RDD[(Long, String)] = vertices.rdd.map { + case Row(id: Long, name: String) => + (id, name) + case _: Row => throw new GraphFramesUnreachableException() } - val ee: RDD[Edge[String]] = edges.rdd.map { case Row(src: Long, dst: Long, action: String) => - Edge(src, dst, action) + val ee: RDD[Edge[String]] = edges.rdd.map { + case Row(src: Long, dst: Long, action: String) => + Edge(src, dst, action) + case _: Row => throw new GraphFramesUnreachableException() } val g = Graph(vv, ee) val gf = GraphFrame.fromGraphX(g) - gf.vertices.select("id", "attr").collect().foreach { case Row(id: Long, name: String) => - assert(localVertices(id) === name) + gf.vertices.select("id", "attr").collect().foreach { + case Row(id: Long, name: String) => + assert(localVertices(id) === name) + case _: Row => throw new GraphFramesUnreachableException() } gf.edges.select("src", "dst", "attr").collect().foreach { case Row(src: Long, dst: Long, action: String) => assert(localEdges((src, dst)) === action) + case _: Row => throw new GraphFramesUnreachableException() } } test("convert to GraphX: Long IDs") { val gf = GraphFrame(vertices, edges) val g = gf.toGraphX - g.vertices.collect().foreach { case (id0, Row(id1: Long, name: String)) => - assert(id0 === id1) - assert(localVertices(id0) === name) + g.vertices.collect().foreach { + case (id0, Row(id1: Long, name: String)) => + assert(id0 === id1) + assert(localVertices(id0) === name) + case _ => throw new GraphFramesUnreachableException() } g.edges.collect().foreach { case Edge(src0, dst0, Row(src1: Long, dst1: Long, action: String)) => assert(src0 === src1) assert(dst0 === dst1) assert(localEdges((src0, dst0)) === action) + case _ => throw new GraphFramesUnreachableException() } } @@ -172,19 +196,23 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { // Int IDs should be directly cast to Long, so ID values should match. val vCols = gf.vertexColumnMap val eCols = gf.edgeColumnMap - g.vertices.collect().foreach { case (id0: Long, attr: Row) => - val id1 = attr.getInt(vCols("id")) - val name = attr.getString(vCols("name")) - assert(id0 === id1) - assert(localVertices(id0) === name) + g.vertices.collect().foreach { + case (id0: Long, attr: Row) => + val id1 = attr.getInt(vCols("id")) + val name = attr.getString(vCols("name")) + assert(id0 === id1) + assert(localVertices(id0) === name) + case _ => throw new GraphFramesUnreachableException() } - g.edges.collect().foreach { case Edge(src0: Long, dst0: Long, attr: Row) => - val src1 = attr.getInt(eCols("src")) - val dst1 = attr.getInt(eCols("dst")) - val action = attr.getString(eCols("action")) - assert(src0 === src1) - assert(dst0 === dst1) - assert(localEdges((src0, dst0)) === action) + g.edges.collect().foreach { + case Edge(src0: Long, dst0: Long, attr: Row) => + val src1 = attr.getInt(eCols("src")) + val dst1 = attr.getInt(eCols("dst")) + val action = attr.getString(eCols("action")) + assert(src0 === src1) + assert(dst0 === dst1) + assert(localEdges((src0, dst0)) === action) + case _ => throw new GraphFramesUnreachableException() } } @@ -240,11 +268,15 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { val e1 = spark.read.parquet(ePath) val g1 = GraphFrame(v1, e1) - g1.vertices.collect().foreach { case Row(id: Long, name: String) => - assert(localVertices(id) === name) + g1.vertices.collect().foreach { + case Row(id: Long, name: String) => + assert(localVertices(id) === name) + case _ => throw new GraphFramesUnreachableException() } - g1.edges.collect().foreach { case Row(src: Long, dst: Long, action: String) => - assert(localEdges((src, dst)) === action) + g1.edges.collect().foreach { + case Row(src: Long, dst: Long, action: String) => + assert(localEdges((src, dst)) === action) + case _ => throw new GraphFramesUnreachableException() } } @@ -254,8 +286,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(g.outDegrees.columns === Seq("id", "outDegree")) val outDegrees = g.outDegrees .collect() - .map { case Row(id: Long, outDeg: Int) => - (id, outDeg) + .map { + case Row(id: Long, outDeg: Int) => + (id, outDeg) + case _ => throw new GraphFramesUnreachableException() } .toMap assert(outDegrees === Map(1L -> 1, 2L -> 2)) @@ -263,8 +297,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(g.inDegrees.columns === Seq("id", "inDegree")) val inDegrees = g.inDegrees .collect() - .map { case Row(id: Long, inDeg: Int) => - (id, inDeg) + .map { + case Row(id: Long, inDeg: Int) => + (id, inDeg) + case _ => throw new GraphFramesUnreachableException() } .toMap assert(inDegrees === Map(1L -> 1, 2L -> 1, 3L -> 1)) @@ -272,8 +308,10 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(g.degrees.columns === Seq("id", "degree")) val degrees = g.degrees .collect() - .map { case Row(id: Long, deg: Int) => - (id, deg) + .map { + case Row(id: Long, deg: Int) => + (id, deg) + case _ => throw new GraphFramesUnreachableException() } .toMap assert(degrees === Map(1L -> 2, 2L -> 3, 3L -> 1)) diff --git a/src/test/scala/org/graphframes/PatternMatchSuite.scala b/src/test/scala/org/graphframes/PatternMatchSuite.scala index 79ef51045..503280168 100644 --- a/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -562,9 +562,11 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { chainWith2Friends .select("ab.relationship", "bc.relationship", "cd.relationship") .collect() - .foreach { case Row(ab: String, bc: String, cd: String) => - val numFriends = Seq(ab, bc, cd).map(r => if (r == "friend") 1 else 0).sum - assert(numFriends >= 2) + .foreach { + case Row(ab: String, bc: String, cd: String) => + val numFriends = Seq(ab, bc, cd).map(r => if (r == "friend") 1 else 0).sum + assert(numFriends >= 2) + case _ => throw new GraphFramesUnreachableException() } // Operating in a stateful manner, where cnt is the state. diff --git a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala index 11c2d3bfb..ff1999977 100644 --- a/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala +++ b/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala @@ -20,6 +20,7 @@ package org.graphframes.examples import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite import org.graphframes.examples.BeliefPropagation._ import org.graphframes.examples.Graphs.gridIsingModel @@ -40,10 +41,12 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte // Check beliefs. def checkResults(v: DataFrame): Unit = { - v.select("belief").collect().foreach { case Row(belief: Double) => - assert( - belief >= 0.0 && belief <= 1.0, - s"Expected belief to be probability in [0,1], but found $belief") + v.select("belief").collect().foreach { + case Row(belief: Double) => + assert( + belief >= 0.0 && belief <= 1.0, + s"Expected belief to be probability in [0,1], but found $belief") + case _ => throw new GraphFramesUnreachableException() } } checkResults(gxResults.vertices) @@ -56,8 +59,10 @@ class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkConte .join(gfBeliefs, "id") .select(gxBeliefs("belief").as("gxBelief"), gfBeliefs("belief").as("gfBelief")) .collect() - .foreach { case Row(gxBelief: Double, gfBelief: Double) => - assert(math.abs(gxBelief - gfBelief) <= 1e-6) + .foreach { + case Row(gxBelief: Double, gfBelief: Double) => + assert(math.abs(gxBelief - gfBelief) <= 1e-6) + case _ => throw new GraphFramesUnreachableException() } } } diff --git a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala index 945bb263b..6e8c70d23 100644 --- a/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala +++ b/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala @@ -19,6 +19,7 @@ package org.graphframes.lib import org.apache.spark.sql.functions._ import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite import org.graphframes.examples.Graphs @@ -43,8 +44,10 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte val aggMap: Map[String, Long] = agg .select("id", "summedAges") .collect() - .map { case Row(id: String, s: Long) => - id -> s + .map { + case Row(id: String, s: Long) => + id -> s + case _: Row => throw new GraphFramesUnreachableException() } .toMap // Compute the truth via brute force for comparison. @@ -52,8 +55,10 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte val user2age = g.vertices .select("id", "age") .collect() - .map { case Row(id: String, age: Int) => - id -> age + .map { + case Row(id: String, age: Int) => + id -> age + case _: Row => throw new GraphFramesUnreachableException() } .toMap val a = mutable.HashMap.empty[String, Int] @@ -63,6 +68,7 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte src, a.getOrElse(src, 0) + user2age(dst) + (if (relationship == "friend") 1 else 0)) a.put(dst, a.getOrElse(dst, 0) + user2age(src)) + case _ => throw new GraphFramesUnreachableException() } a.toMap } @@ -81,8 +87,10 @@ class AggregateMessagesSuite extends SparkFunSuite with GraphFrameTestSparkConte val agg2Map: Map[String, Long] = agg2 .select("id", "summedAges") .collect() - .map { case Row(id: String, s: Long) => - id -> s + .map { + case Row(id: String, s: Long) => + id -> s + case _: Row => throw new GraphFramesUnreachableException() } .toMap // Compare to the true values. diff --git a/src/test/scala/org/graphframes/lib/BFSSuite.scala b/src/test/scala/org/graphframes/lib/BFSSuite.scala index addaf2fc5..8d992f5cc 100644 --- a/src/test/scala/org/graphframes/lib/BFSSuite.scala +++ b/src/test/scala/org/graphframes/lib/BFSSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.graphframes.GraphFrame import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { @@ -105,8 +106,10 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { val paths = g.bfs.fromExpr(col("id") === "e").toExpr(col("id") === "b").run() assert(paths.count() === 2) assert(paths.columns === Seq("from", "e0", "v1", "e1", "v2", "e2", "to")) - paths.select("to.id").collect().foreach { case Row(id: String) => - assert(id === "b") + paths.select("to.id").collect().foreach { + case Row(id: String) => + assert(id === "b") + case _ => throw new GraphFramesUnreachableException() } } @@ -133,8 +136,10 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { .edgeFilter(col("src") =!= "d") .run() assert(paths1.count() === 1) - paths1.select("e0.dst").collect().foreach { case Row(id: String) => - assert(id === "f") + paths1.select("e0.dst").collect().foreach { + case Row(id: String) => + assert(id === "f") + case _: Row => throw new GraphFramesUnreachableException() } val paths2 = g.bfs .fromExpr(col("id") === "e") @@ -142,8 +147,10 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { .edgeFilter(col("relationship") === "friend") .run() assert(paths2.count() === 1) - paths2.select("e0.dst").collect().foreach { case Row(id: String) => - assert(id === "d") + paths2.select("e0.dst").collect().foreach { + case Row(id: String) => + assert(id === "d") + case _: Row => throw new GraphFramesUnreachableException() } } @@ -154,8 +161,10 @@ class BFSSuite extends SparkFunSuite with GraphFrameTestSparkContext { .edgeFilter("src != 'd'") .run() assert(paths1.count() === 1) - paths1.select("e0.dst").collect().foreach { case Row(id: String) => - assert(id === "f") + paths1.select("e0.dst").collect().foreach { + case Row(id: String) => + assert(id === "f") + case _: Row => throw new GraphFramesUnreachableException() } } diff --git a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala index 83d86187e..95754e6e7 100644 --- a/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala +++ b/src/test/scala/org/graphframes/lib/SVDPlusPlusSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types.DataTypes import org.graphframes.GraphFrame import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite import org.graphframes.TestUtils import org.graphframes.examples.Graphs @@ -45,8 +46,10 @@ class SVDPlusPlusSuite extends SparkFunSuite with GraphFrameTestSparkContext { val err = v2 .select(GraphFrame.ID, SVDPlusPlus.COLUMN4) .rdd - .map { case Row(vid: Long, vd: Double) => - if (vid % 2 == 1) vd else 0.0 + .map { + case Row(vid: Long, vd: Double) => + if (vid % 2 == 1) vd else 0.0 + case _ => throw new GraphFramesUnreachableException() } .reduce(_ + _) / g.edges.count() assert(err <= svdppErr) diff --git a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index 31fc0598b..99f030e0c 100644 --- a/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -98,8 +98,10 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { "distances", DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, true)) val newVs = v2.select("id", "distances").collect().toSeq - val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) => - (id, spMap) + val results = newVs.map { + case Row(id: Long, spMap: Map[Long, Int] @unchecked) => + (id, spMap) + case _ => throw new GraphFramesUnreachableException() } assert(results.toSet === shortestPaths) } @@ -118,8 +120,10 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { val results = v .select("id", "distances") .collect() - .map { case Row(id: String, spMap: Map[String, Int] @unchecked) => - (id, spMap) + .map { + case Row(id: String, spMap: Map[String, Int] @unchecked) => + (id, spMap) + case _ => throw new GraphFramesUnreachableException() } .toSet assert(results === expected) @@ -139,8 +143,10 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { val results = v .select("id", "distances") .collect() - .map { case Row(id: String, spMap: Map[String, Int] @unchecked) => - (id, spMap) + .map { + case Row(id: String, spMap: Map[String, Int] @unchecked) => + (id, spMap) + case _ => throw new GraphFramesUnreachableException() } .toSet assert(results === expected) diff --git a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala index 0ca88ef7e..d3aa3674c 100644 --- a/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala +++ b/src/test/scala/org/graphframes/lib/TriangleCountSuite.scala @@ -22,13 +22,14 @@ import org.apache.spark.sql.types.DataTypes import org.graphframes.GraphFrame import org.graphframes.GraphFrame.quote import org.graphframes.GraphFrameTestSparkContext +import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite import org.graphframes.TestUtils class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { test("Count a single triangle") { - val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") + val edges = spark.createDataFrame(Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") val vertices = spark .createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c"))) .toDF("id", "a") @@ -38,57 +39,66 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { TestUtils.checkColumnType(v2.schema, "count", DataTypes.LongType) v2.select("id", "count", "a") .collect() - .foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) } + .foreach { + case Row(vid: Long, count: Long, _) => assert(count === 1) + case _: Row => throw new GraphFramesUnreachableException() + } } test("Count two triangles") { val edges = spark .createDataFrame( - Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> -1L, -1L -> -2L, -2L -> 0L)) + Seq(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Seq(0L -> -1L, -1L -> -2L, -2L -> 0L)) .toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() - v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) => - if (id == 0) { - assert(count === 2) - } else { - assert(count === 1) - } + v2.select("id", "count").collect().foreach { + case Row(id: Long, count: Long) => + if (id == 0) { + assert(count === 2) + } else { + assert(count === 1) + } + case _: Row => throw new GraphFramesUnreachableException() } } test("Count one triangles with bi-directed edges") { // Note: This is different from GraphX, which double-counts triangles with bidirected edges. - val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) + val triangles = Seq(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Seq(0L -> -1L, -1L -> -2L, -2L -> 0L) val revTriangles = triangles.map { case (a, b) => (b, a) } val edges = spark.createDataFrame(triangles ++ revTriangles).toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() - v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) => - if (id == 0) { - assert(count === 2) - } else { - assert(count === 1) - } + v2.select("id", "count").collect().foreach { + case Row(id: Long, count: Long) => + if (id == 0) { + assert(count === 2) + } else { + assert(count === 1) + } + case _: Row => throw new GraphFramesUnreachableException() } } test("Count a single triangle with duplicate edges") { val edges = spark .createDataFrame( - Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> 1L, 1L -> 2L, 2L -> 0L)) + Seq(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)) .toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() - v2.select("id", "count").collect().foreach { case Row(id: Long, count: Long) => - assert(count === 1) + v2.select("id", "count").collect().foreach { + case Row(id: Long, count: Long) => + assert(count === 1) + case _: Row => throw new GraphFramesUnreachableException() } } test("Count with dot column name") { - val edges = sqlContext.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") + val edges = sqlContext.createDataFrame(Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") val vertices = sqlContext .createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c"))) .toDF("id", "a.column") @@ -98,11 +108,14 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { TestUtils.checkColumnType(v2.schema, "count", DataTypes.LongType) v2.select("id", "count", quote("a.column")) .collect() - .foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) } + .foreach { + case Row(vid: Long, count: Long, _) => assert(count === 1) + case _: Row => throw new GraphFramesUnreachableException() + } } test("Count with backquote in column name") { - val edges = sqlContext.createDataFrame(Array(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") + val edges = sqlContext.createDataFrame(Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)).toDF("src", "dst") val vertices = sqlContext .createDataFrame(Seq((0L, "a"), (1L, "b"), (2L, "c"))) .toDF("id", "a `column`") @@ -112,15 +125,20 @@ class TriangleCountSuite extends SparkFunSuite with GraphFrameTestSparkContext { TestUtils.checkColumnType(v2.schema, "count", DataTypes.LongType) v2.select("id", "count", quote("a `column`")) .collect() - .foreach { case Row(vid: Long, count: Long, _) => assert(count === 1) } + .foreach { + case Row(vid: Long, count: Long, _) => assert(count === 1) + case _: Row => throw new GraphFramesUnreachableException() + } } test("no triangle") { - val edges = spark.createDataFrame(Array(0L -> 1L, 1L -> 2L)).toDF("src", "dst") + val edges = spark.createDataFrame(Seq(0L -> 1L, 1L -> 2L)).toDF("src", "dst") val g = GraphFrame.fromEdges(edges) val v2 = g.triangleCount.run() - v2.select("count").collect().foreach { case Row(count: Long) => - assert(count === 0) + v2.select("count").collect().foreach { + case Row(count: Long) => + assert(count === 0) + case _: Row => throw new GraphFramesUnreachableException() } } } From 597acf4091bc4eda0dd5bea68cff1db892c83a79 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 12 Apr 2025 20:22:34 +0200 Subject: [PATCH 5/5] Move LDBCUtils to align the package name --- src/main/scala/org/graphframes/{ => examples}/LDBCUtils.scala | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/main/scala/org/graphframes/{ => examples}/LDBCUtils.scala (100%) diff --git a/src/main/scala/org/graphframes/LDBCUtils.scala b/src/main/scala/org/graphframes/examples/LDBCUtils.scala similarity index 100% rename from src/main/scala/org/graphframes/LDBCUtils.scala rename to src/main/scala/org/graphframes/examples/LDBCUtils.scala