diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala index 852ffc82e..a56f57304 100644 --- a/core/src/main/scala/org/graphframes/GraphFrame.scala +++ b/core/src/main/scala/org/graphframes/GraphFrame.scala @@ -470,6 +470,7 @@ class GraphFrame private ( */ def find(pattern: String): DataFrame = { val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]-(>?)\((\w+)\)""".r + val FixedLengthUndirectedPattern = """\((\w+)\)-\[(\w*)\*(\d*)\]-\((\w+)\)""".r pattern match { case VarLengthPattern(src, name, min, max, direction, dst) => @@ -478,37 +479,53 @@ class GraphFrame private ( s"Unbounded length patten ${pattern} is not supported! " + "Please a pattern of defined length.") } - val strToSeq: Seq[(Int, String)] = (min.toInt to max.toInt).reverse.map { hop => - (hop, s"($src)-[$name*$hop]->($dst)") - } - val strToSeqReverse: Seq[(Int, String)] = if (direction.isEmpty) { - (min.toInt to max.toInt).reverse.map(hop => (hop, s"($src)<-[$name*$hop]-($dst)")) - } else { - Seq.empty[(Int, String)] - } - - val out: Seq[DataFrame] = strToSeq.map { case (hop, patternStr) => - findAugmentedPatterns(patternStr) - .withColumn("_hop", lit(hop)) - .withColumn("_pattern", lit(patternStr)) - .withColumn("_direction", lit("out")) - } + findVarLengthPattern(src, name, min.toInt, max.toInt, direction, dst) - val in: Seq[DataFrame] = strToSeqReverse.map { case (hop, patternStr) => - findAugmentedPatterns(patternStr) - .withColumn("_hop", lit(hop)) - .withColumn("_pattern", lit(patternStr)) - .withColumn("_direction", lit("in")) + case FixedLengthUndirectedPattern(src, name, hop, dst) => + if (hop.isEmpty) { + throw new InvalidParseException("Missing hop!") } - - val ret = (out ++ in).reduce((a, b) => a.unionByName(b, allowMissingColumns = true)) - ret.orderBy("_hop", "_direction") + findVarLengthPattern(src, name, hop.toInt, hop.toInt, "", dst) case _ => findAugmentedPatterns(pattern) } } + def findVarLengthPattern( + src: String, + name: String, + min: Int, + max: Int, + direction: String, + dst: String): DataFrame = { + val strToSeq: Seq[(Int, String)] = (min to max).reverse.map { hop => + (hop, s"($src)-[$name*$hop]->($dst)") + } + val strToSeqReverse: Seq[(Int, String)] = if (direction.isEmpty) { + (min to max).reverse.map(hop => (hop, s"($src)<-[$name*$hop]-($dst)")) + } else { + Seq.empty[(Int, String)] + } + + val out: Seq[DataFrame] = strToSeq.map { case (hop, patternStr) => + findAugmentedPatterns(patternStr) + .withColumn("_hop", lit(hop)) + .withColumn("_pattern", lit(patternStr)) + .withColumn("_direction", lit("out")) + } + + val in: Seq[DataFrame] = strToSeqReverse.map { case (hop, patternStr) => + findAugmentedPatterns(patternStr) + .withColumn("_hop", lit(hop)) + .withColumn("_pattern", lit(patternStr)) + .withColumn("_direction", lit("in")) + } + + val ret = (out ++ in).reduce((a, b) => a.unionByName(b, allowMissingColumns = true)) + ret.orderBy("_hop", "_direction") + } + def findAugmentedPatterns(pattern: String): DataFrame = { val patterns = Pattern.parse(pattern) diff --git a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala index 5289f04f5..f8b57870b 100644 --- a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -790,6 +790,14 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(res.except(expected).isEmpty && expected.except(res).isEmpty) } + test("undirected fixed-length pattern") { + val res = g.find("(u)-[e*3]-(v)") + val expected = g.find("(u)-[e*3..3]-(v)") + + assert(res.schema === expected.schema) + assert(res.except(expected).isEmpty && expected.except(res).isEmpty) + } + test("stateful predicates via UDFs") { val chain4 = g .find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")