Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ import org.graphframes.pattern._
*/
class GraphFrame private(
@transient private val _vertices: DataFrame,
@transient private val _edges: DataFrame) extends Logging with Serializable {
@transient private val _edges: DataFrame,
@transient private val _storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) extends Logging with Serializable {

import GraphFrame._

Expand Down Expand Up @@ -480,7 +481,7 @@ class GraphFrame private(
.distinct()
.sortWithinPartitions(ID)
.withColumn(LONG_ID, monotonically_increasing_id())
.persist(StorageLevel.MEMORY_AND_DISK)
.persist(_storageLevel)
vertices.select(col(ID), nestAsCol(vertices, ATTR))
.join(withLongIds, ID)
.select(LONG_ID, ID, ATTR)
Expand Down Expand Up @@ -605,7 +606,7 @@ object GraphFrame extends Serializable with Logging {
* destination vertex IDs. All other columns are treated as edge attributes.
* @return New [[GraphFrame]] instance
*/
def apply(vertices: DataFrame, edges: DataFrame): GraphFrame = {
def apply(vertices: DataFrame, edges: DataFrame, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): GraphFrame = {
require(vertices.columns.contains(ID),
s"Vertex ID column '$ID' missing from vertex DataFrame, which has columns: "
+ vertices.columns.mkString(","))
Expand All @@ -616,7 +617,7 @@ object GraphFrame extends Serializable with Logging {
s"Destination vertex ID column '$DST' missing from edge DataFrame, which has columns: "
+ edges.columns.mkString(","))

new GraphFrame(vertices, edges)
new GraphFrame(vertices, edges, storageLevel)
}

/**
Expand All @@ -631,12 +632,12 @@ object GraphFrame extends Serializable with Logging {
*
* @group conversions
*/
def fromEdges(e: DataFrame): GraphFrame = {
def fromEdges(e: DataFrame, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): GraphFrame = {
val srcs = e.select(e("src").as("id"))
val dsts = e.select(e("dst").as("id"))
val v = srcs.unionAll(dsts).distinct
v.persist(StorageLevel.MEMORY_AND_DISK)
apply(v, e)
v.persist(storageLevel)
apply(v, e, storageLevel)
}

/**
Expand Down