diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 550dff5f9..8b84d6d82 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -6,7 +6,7 @@ jobs: fail-fast: false matrix: include: - - spark-version: 3.5.0 + - spark-version: 3.5.4 scala-version: 2.12.18 python-version: 3.9.19 runs-on: ubuntu-22.04 diff --git a/.github/workflows/scala-ci.yml b/.github/workflows/scala-ci.yml index fab50dc1a..032b46592 100644 --- a/.github/workflows/scala-ci.yml +++ b/.github/workflows/scala-ci.yml @@ -6,9 +6,9 @@ jobs: fail-fast: false matrix: include: - - spark-version: 3.5.0 + - spark-version: 3.5.4 scala-version: 2.13.8 - - spark-version: 3.5.0 + - spark-version: 3.5.4 scala-version: 2.12.12 runs-on: ubuntu-22.04 env: diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 33c6b1ab5..dbc1d98f7 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -281,140 +281,149 @@ object ConnectedComponents extends Logging { return runGraphX(graph) } - val runId = UUID.randomUUID().toString.takeRight(8) - val logPrefix = s"[CC $runId]" - logInfo(s"$logPrefix Start connected components with run ID $runId.") - val spark = graph.spark val sc = spark.sparkContext - - val shouldCheckpoint = checkpointInterval > 0 - val checkpointDir: Option[String] = if (shouldCheckpoint) { - val dir = sc.getCheckpointDir.map { d => - new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString - }.getOrElse { - throw new IOException( - "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().") + // Store original AQE setting + val originalAQE = spark.conf.get("spark.sql.adaptive.enabled") + + try { + spark.conf.set("spark.sql.adaptive.enabled", "false") + + val runId = UUID.randomUUID().toString.takeRight(8) + val logPrefix = s"[CC $runId]" + logInfo(s"$logPrefix Start connected components with run ID $runId.") + + val shouldCheckpoint = checkpointInterval > 0 + val checkpointDir: Option[String] = if (shouldCheckpoint) { + val dir = sc.getCheckpointDir.map { d => + new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString + }.getOrElse { + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().") + } + logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") + Some(dir) + } else { + logInfo( + s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.") + None } - logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") - Some(dir) - } else { - logInfo( - s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.") - None - } - logInfo(s"$logPrefix Preparing the graph for connected component computation ...") - val g = prepare(graph) - val vv = g.vertices - var ee = g.edges.persist(intermediateStorageLevel) // src < dst - val numEdges = ee.count() - logInfo(s"$logPrefix Found $numEdges edges after preparation.") - - var converged = false - var iteration = 1 - - def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { - // Taking the sum in DecimalType to preserve precision. - // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. - // It should be able to handle 200 billion edges without overflow. - val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd - .map { r => - (r.getAs[BigDecimal](0), r.getLong(1)) - }.first() - if (cnt != 0L && minNbrSum == null) { - throw new ArithmeticException( - s""" - |The total sum of edge src IDs is used to determine convergence during iterations. - |However, the total sum at iteration $iteration exceeded 30 digits (1e30), - |which should happen only if the graph contains more than 200 billion edges. - |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. - """.stripMargin) + logInfo(s"$logPrefix Preparing the graph for connected component computation ...") + val g = prepare(graph) + val vv = g.vertices + var ee = g.edges.persist(intermediateStorageLevel) // src < dst + val numEdges = ee.count() + logInfo(s"$logPrefix Found $numEdges edges after preparation.") + + var converged = false + var iteration = 1 + + def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = { + // Taking the sum in DecimalType to preserve precision. + // We use 20 digits for long values and Spark SQL will add 10 digits for the sum. + // It should be able to handle 200 billion edges without overflow. + val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd + .map { r => + (r.getAs[BigDecimal](0), r.getLong(1)) + }.first() + if (cnt != 0L && minNbrSum == null) { + throw new ArithmeticException( + s""" + |The total sum of edge src IDs is used to determine convergence during iterations. + |However, the total sum at iteration $iteration exceeded 30 digits (1e30), + |which should happen only if the graph contains more than 200 billion edges. + |If not, please file a bug report at https://github.com/graphframes/graphframes/issues. + """.stripMargin) + } + minNbrSum } - minNbrSum - } - // compute min neighbors (including self-min) - var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr - .persist(intermediateStorageLevel) - - var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1) - - var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) - while (!converged) { - var currRoundPersistedDFs = Seq[DataFrame]() - // large-star step - // connect all strictly larger neighbors to the min neighbor (including self) - ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) - .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst - .distinct() + // compute min neighbors (including self-min) + var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr .persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ ee - // small-star step - // compute min neighbors (excluding self-min) - val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr - .persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 - - // connect all smaller neighbors to the min neighbor - ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) - .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst - .filter(col(SRC) =!= col(DST)) // src < dst - // connect self to the min neighbor - ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst - .distinct() - - // checkpointing - if (shouldCheckpoint && (iteration % checkpointInterval == 0)) { - // TODO: remove this after DataFrame.checkpoint is implemented - val out = s"${checkpointDir.get}/$iteration" - ee.write.parquet(out) - // may hit S3 eventually consistent issue - ee = spark.read.parquet(out) - - // remove previous checkpoint - if (iteration > checkpointInterval) { - val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") - path.getFileSystem(sc.hadoopConfiguration).delete(path, true) + var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1) + + var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1) + while (!converged) { + var currRoundPersistedDFs = Seq[DataFrame]() + // large-star step + // connect all strictly larger neighbors to the min neighbor (including self) + ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix) + .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst + .distinct() + .persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ ee + + // small-star step + // compute min neighbors (excluding self-min) + val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr + .persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2 + + // connect all smaller neighbors to the min neighbor + ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix) + .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst + .filter(col(SRC) =!= col(DST)) // src < dst + // connect self to the min neighbor + ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst + .distinct() + + // checkpointing + if (shouldCheckpoint && (iteration % checkpointInterval == 0)) { + // TODO: remove this after DataFrame.checkpoint is implemented + val out = s"${checkpointDir.get}/$iteration" + ee.write.parquet(out) + // may hit S3 eventually consistent issue + ee = spark.read.parquet(out) + + // remove previous checkpoint + if (iteration > checkpointInterval) { + val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}") + path.getFileSystem(sc.hadoopConfiguration).delete(path, true) + } + + System.gc() // hint Spark to clean shuffle directories } - System.gc() // hint Spark to clean shuffle directories - } - - ee.persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ ee - - minNbrs1 = minNbrs(ee) // src >= min_nbr - .persist(intermediateStorageLevel) - currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 - - // test convergence - val currSum = _calcMinNbrSum(minNbrs1) - logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.") - if (currSum == prevSum) { - // This also covers the case when cnt = 0 and currSum is null, which means no edges. - converged = true - } else { - prevSum = currSum - } + ee.persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ ee + + minNbrs1 = minNbrs(ee) // src >= min_nbr + .persist(intermediateStorageLevel) + currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1 + + // test convergence + val currSum = _calcMinNbrSum(minNbrs1) + logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.") + if (currSum == prevSum) { + // This also covers the case when cnt = 0 and currSum is null, which means no edges. + converged = true + } else { + prevSum = currSum + } - // materialize all persisted DataFrames in current round, - // then we can unpersist last round persisted DataFrames. - for (persisted_df <- currRoundPersistedDFs) { - persisted_df.count() // materialize it. - } - for (persisted_df <- lastRoundPersistedDFs) { - persisted_df.unpersist() + // materialize all persisted DataFrames in current round, + // then we can unpersist last round persisted DataFrames. + for (persisted_df <- currRoundPersistedDFs) { + persisted_df.count() // materialize it. + } + for (persisted_df <- lastRoundPersistedDFs) { + persisted_df.unpersist() + } + lastRoundPersistedDFs = currRoundPersistedDFs + iteration += 1 } - lastRoundPersistedDFs = currRoundPersistedDFs - iteration += 1 - } - logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.") + logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.") - logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.") - vv.join(ee, vv(ID) === ee(DST), "left_outer") - .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT)) - .select(col(s"$ATTR.*"), col(COMPONENT)) + logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.") + vv.join(ee, vv(ID) === ee(DST), "left_outer") + .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT)) + .select(col(s"$ATTR.*"), col(COMPONENT)) + } finally { + // Restore original AQE setting + spark.conf.set("spark.sql.adaptive.enabled", originalAQE) + } } }