Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions core/src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ class GraphFrame private (
*/
def find(pattern: String): DataFrame = {
val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]-(>?)\((\w+)\)""".r
val FixedLengthUndirectedPattern = """\((\w+)\)-\[(\w*)\*(\d*)\]-\((\w+)\)""".r
Comment thread
SemyonSinchenko marked this conversation as resolved.

pattern match {
case VarLengthPattern(src, name, min, max, direction, dst) =>
Expand All @@ -478,37 +479,53 @@ class GraphFrame private (
s"Unbounded length patten ${pattern} is not supported! " +
"Please a pattern of defined length.")
}
val strToSeq: Seq[(Int, String)] = (min.toInt to max.toInt).reverse.map { hop =>
(hop, s"($src)-[$name*$hop]->($dst)")
}
val strToSeqReverse: Seq[(Int, String)] = if (direction.isEmpty) {
(min.toInt to max.toInt).reverse.map(hop => (hop, s"($src)<-[$name*$hop]-($dst)"))
} else {
Seq.empty[(Int, String)]
}

val out: Seq[DataFrame] = strToSeq.map { case (hop, patternStr) =>
findAugmentedPatterns(patternStr)
.withColumn("_hop", lit(hop))
.withColumn("_pattern", lit(patternStr))
.withColumn("_direction", lit("out"))
}
findVarLengthPattern(src, name, min.toInt, max.toInt, direction, dst)

val in: Seq[DataFrame] = strToSeqReverse.map { case (hop, patternStr) =>
findAugmentedPatterns(patternStr)
.withColumn("_hop", lit(hop))
.withColumn("_pattern", lit(patternStr))
.withColumn("_direction", lit("in"))
case FixedLengthUndirectedPattern(src, name, hop, dst) =>
if (hop.isEmpty) {
throw new InvalidParseException("Missing hop!")
}

val ret = (out ++ in).reduce((a, b) => a.unionByName(b, allowMissingColumns = true))
ret.orderBy("_hop", "_direction")
findVarLengthPattern(src, name, hop.toInt, hop.toInt, "", dst)

case _ =>
findAugmentedPatterns(pattern)
}
}

/**
 * Expands a variable-length pattern into one fixed-length pattern per hop count, runs each
 * through `findAugmentedPatterns`, and unions the results into a single DataFrame.
 *
 * Every partial result is tagged with three extra columns: `_hop` (the hop count),
 * `_pattern` (the fixed-length pattern string that produced the rows) and `_direction`
 * (`"out"` for the `->` form, `"in"` for the `<-` form). The result is ordered by `_hop`,
 * then `_direction`.
 *
 * @param src
 *   name of the source vertex as written in the pattern
 * @param name
 *   name of the edge as written in the pattern (may be empty)
 * @param min
 *   smallest hop count to search, inclusive
 * @param max
 *   largest hop count to search, inclusive
 * @param direction
 *   empty string to additionally search the reversed (`<-`) form of the pattern; any
 *   non-empty value restricts the search to the forward (`->`) form
 * @param dst
 *   name of the destination vertex as written in the pattern
 * @return
 *   union (by column name, missing columns allowed) of all per-hop matches
 * @throws IllegalArgumentException
 *   if `min` is greater than `max`
 */
def findVarLengthPattern(
    src: String,
    name: String,
    min: Int,
    max: Int,
    direction: String,
    dst: String): DataFrame = {
  // An inverted range would yield no patterns at all, and reducing an empty sequence fails
  // with an opaque "empty.reduceLeft" error; reject it up front with a clear message.
  require(
    min <= max,
    s"Invalid hop range $min..$max: the minimum number of hops must not exceed the maximum.")

  // Largest hop count first, so the union below starts from the longest pattern.
  // NOTE(review): presumably the longest pattern carries the widest schema — confirm.
  val hops: Seq[Int] = (min to max).reverse

  // Runs one fixed-length pattern and tags its rows with hop count, pattern and direction.
  def taggedMatches(patternStr: String, hop: Int, directionLabel: String): DataFrame =
    findAugmentedPatterns(patternStr)
      .withColumn("_hop", lit(hop))
      .withColumn("_pattern", lit(patternStr))
      .withColumn("_direction", lit(directionLabel))

  val out: Seq[DataFrame] = hops.map { hop =>
    taggedMatches(s"($src)-[$name*$hop]->($dst)", hop, "out")
  }

  // The reversed form is only searched when no direction was given in the pattern.
  val in: Seq[DataFrame] =
    if (direction.isEmpty) {
      hops.map(hop => taggedMatches(s"($src)<-[$name*$hop]-($dst)", hop, "in"))
    } else {
      Seq.empty[DataFrame]
    }

  // `out` is non-empty thanks to the range check above, so `reduce` is safe here.
  val ret = (out ++ in).reduce((a, b) => a.unionByName(b, allowMissingColumns = true))
  ret.orderBy("_hop", "_direction")
}

def findAugmentedPatterns(pattern: String): DataFrame = {
val patterns = Pattern.parse(pattern)

Expand Down
8 changes: 8 additions & 0 deletions core/src/test/scala/org/graphframes/PatternMatchSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,14 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(res.except(expected).isEmpty && expected.except(res).isEmpty)
}

test("undirected fixed-length pattern") {
  // The short form `[e*3]` without an arrow must give the same result as the explicit
  // single-value range `[e*3..3]`.
  val fixedLength = g.find("(u)-[e*3]-(v)")
  val explicitRange = g.find("(u)-[e*3..3]-(v)")

  assert(fixedLength.schema === explicitRange.schema)
  // Both set differences being empty means the two results hold the same distinct rows.
  assert(fixedLength.except(explicitRange).isEmpty)
  assert(explicitRange.except(fixedLength).isEmpty)
}

test("stateful predicates via UDFs") {
val chain4 = g
.find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")
Expand Down