diff --git a/src/main/scala/org/graphframes/GraphFrame.scala b/src/main/scala/org/graphframes/GraphFrame.scala index d6c6c73d7..64e22d095 100644 --- a/src/main/scala/org/graphframes/GraphFrame.scala +++ b/src/main/scala/org/graphframes/GraphFrame.scala @@ -41,7 +41,8 @@ import org.graphframes.pattern._ */ class GraphFrame private( @transient private val _vertices: DataFrame, - @transient private val _edges: DataFrame) extends Logging with Serializable { + @transient private val _edges: DataFrame, + @transient private val _storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) extends Logging with Serializable { import GraphFrame._ @@ -480,7 +481,7 @@ class GraphFrame private( .distinct() .sortWithinPartitions(ID) .withColumn(LONG_ID, monotonically_increasing_id()) - .persist(StorageLevel.MEMORY_AND_DISK) + .persist(_storageLevel) vertices.select(col(ID), nestAsCol(vertices, ATTR)) .join(withLongIds, ID) .select(LONG_ID, ID, ATTR) @@ -605,7 +606,7 @@ object GraphFrame extends Serializable with Logging { * destination vertex IDs. All other columns are treated as edge attributes. * @return New [[GraphFrame]] instance */ - def apply(vertices: DataFrame, edges: DataFrame): GraphFrame = { + def apply(vertices: DataFrame, edges: DataFrame, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): GraphFrame = { require(vertices.columns.contains(ID), s"Vertex ID column '$ID' missing from vertex DataFrame, which has columns: " + vertices.columns.mkString(",")) @@ -616,7 +617,7 @@ object GraphFrame extends Serializable with Logging { s"Destination vertex ID column '$DST' missing from edge DataFrame, which has columns: " + edges.columns.mkString(",")) - new GraphFrame(vertices, edges) + new GraphFrame(vertices, edges, storageLevel) } /** @@ -631,12 +632,12 @@ object GraphFrame extends Serializable with Logging { * * @group conversions */ - def fromEdges(e: DataFrame): GraphFrame = { + def fromEdges(e: DataFrame, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): GraphFrame = { val srcs = e.select(e("src").as("id")) val dsts = e.select(e("dst").as("id")) val v = srcs.unionAll(dsts).distinct - v.persist(StorageLevel.MEMORY_AND_DISK) - apply(v, e) + v.persist(storageLevel) + apply(v, e, storageLevel) } /**