From 5060edbbe1bb40f99a77aba7fe1dc3c6af45dbf8 Mon Sep 17 00:00:00 2001 From: Goun Na Date: Mon, 29 Sep 2025 02:04:16 +0900 Subject: [PATCH] undirected edge --- .../scala/org/graphframes/GraphFrame.scala | 21 +++++++-- .../org/graphframes/PatternMatchSuite.scala | 45 +++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala index e753d422c..8d5b5b112 100644 --- a/core/src/main/scala/org/graphframes/GraphFrame.scala +++ b/core/src/main/scala/org/graphframes/GraphFrame.scala @@ -373,9 +373,11 @@ class GraphFrame private ( * @group motif */ def find(pattern: String): DataFrame = { - val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]->\((\w+)\)""".r + val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]-(>?)\((\w+)\)""".r + val UndirectedPattern = """\((\w+)\)-\[(\w*)\]-\((\w+)\)""".r + pattern match { - case VarLengthPattern(src, name, min, max, dst) => + case VarLengthPattern(src, name, min, max, direction, dst) => if (min.isEmpty || max.isEmpty) { throw new InvalidParseException( s"Unbounded length patten ${pattern} is not supported! " + @@ -384,9 +386,22 @@ class GraphFrame private ( val strToSeq: Seq[String] = (min.toInt to max.toInt).reverse.map { hop => s"($src)-[$name*$hop]->($dst)" } - strToSeq + val strToSeqReverse: Seq[String] = if (direction.isEmpty) { + (min.toInt to max.toInt).reverse.map(hop => s"($dst)-[$name*$hop]->($src)") + } else { + Seq.empty[String] + } + + (strToSeq ++ strToSeqReverse) .map(findAugmentedPatterns) .reduce((a, b) => a.unionByName(b, allowMissingColumns = true)) + + case UndirectedPattern(src, name, dst) => + val out: DataFrame = findAugmentedPatterns(s"($src)-[$name]->($dst)") + val in: DataFrame = findAugmentedPatterns(s"($dst)-[$name]->($src)") + + out.unionByName(in) + case _ => findAugmentedPatterns(pattern) } diff --git a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala index bdb413fb7..097c76fd3 100644 --- a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -658,6 +658,51 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(varEdge.except(unionEdge).isEmpty && unionEdge.except(varEdge).isEmpty) } + test("undirected edge") { + val res = g + .find("(u)-[]-(v)") + .where("u.id == 0") + .select("u.id", "v.id") + .collect() + .toSet + + val expected = Set(Row(0L, 1L), Row(0L, 2L)) + + compareResultToExpected(res, expected) + } + + test("undirected with edge name") { + val res = g + .find("(u)-[e]-(v)") + .where("u.id == 0") + .select("e.src", "e.dst", "e.relationship") + .collect() + .toSet + + val expected = Set(Row(0L, 1L, "friend"), Row(1L, 0L, "follow"), Row(2L, 0L, "unknown")) + + compareResultToExpected(res, expected) + } + + test("undirected var-length pattern") { + val res = g + .find("(u)-[e*1..3]-(v)") + .where("u.id == 2") + + val df1 = g + .find("(u)-[e*1..3]->(v)") + .where("u.id == 2") + + val df2 = g + .find("(v)-[e*1..3]->(u)") + .where("u.id == 2") + + val expected = df1.unionByName(df2, allowMissingColumns = true) + + assert(res.schema === expected.schema) + assert(res.except(expected).isEmpty && expected.except(res).isEmpty) + } + test("stateful predicates via UDFs") { val chain4 = g .find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")