diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 1b1148915..34daf0717 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -171,6 +171,7 @@ object ConnectedComponents extends Logging { import org.graphframes.GraphFrame._ private val COMPONENT = "component" + private val ORIG_ID = "orig_id" private val MIN_NBR = "min_nbr" private val CNT = "cnt" private val CHECKPOINT_NAME_PREFIX = "connected-components" @@ -435,11 +436,27 @@ object ConnectedComponents extends Logging { logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.") logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.") - val output = vv + val indexedLabel = vv .join(ee, vv(ID) === ee(DST), "left_outer") - .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT)) - .select(col(s"$ATTR.*"), col(COMPONENT)) - .persist(intermediateStorageLevel) + .select( + vv(ATTR), + when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT), + col(ATTR + "." + ID).as(ID)) + val output = if (graph.hasIntegralIdType) { + indexedLabel + .select(col(s"$ATTR.*"), col(COMPONENT)) + .persist(intermediateStorageLevel) + } else { + indexedLabel + .join( + indexedLabel + .groupBy(col(COMPONENT)) + .agg(min(col(ID)).as(ORIG_ID)) + .select(col(COMPONENT), col(ORIG_ID)), + COMPONENT) + .select(col(s"$ATTR.*"), col(ORIG_ID).as(COMPONENT)) + .persist(intermediateStorageLevel) + } // An action must be performed on the DataFrame for the cache to load output.count() diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index 32221311b..3b3fcf4ea 100644 --- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -271,7 +271,7 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon // note: not using agg + collect_list because collect_list is not available in 1.6.2 w/o hive val actualComponents = actual .select("component", "id") - .as[(Long, T)] + .as[(T, T)] .rdd .groupByKey() .values