@@ -29,6 +29,7 @@ class TraverserTest extends Specification {
2929 enter : { TraverserContext context ->
3030 preOrderNodes << context. thisNode(). number
3131 println " enter:$preOrderNodes "
32+ context. setResult(context. thisNode())
3233 TraversalControl . CONTINUE
3334 },
3435 leave : { TraverserContext context ->
@@ -44,6 +45,7 @@ class TraverserTest extends Specification {
4445 then :
4546 ! result. encounteredCycle
4647 result. fullTraversal
48+ result. result. number == 5
4749 preOrderNodes == [0 , 1 , 3 , 2 , 4 , 5 ]
4850 postOrderNodes == [3 , 1 , 4 , 5 , 2 , 0 ]
4951 }
@@ -56,6 +58,7 @@ class TraverserTest extends Specification {
5658 def visitor = [
5759 enter : { TraverserContext context ->
5860 enterData << context. thisNode(). number
61+ context. setResult(context. thisNode())
5962 println " enter:$enterData "
6063 TraversalControl . CONTINUE
6164 },
@@ -71,6 +74,7 @@ class TraverserTest extends Specification {
7174 then :
7275 ! result. encounteredCycle
7376 result. fullTraversal
77+ result. result. number == 5
7478 enterData == [0 , 1 , 2 , 3 , 4 , 5 ]
7579 leaveData == [0 , 1 , 2 , 3 , 4 , 5 ]
7680 }
@@ -256,5 +260,78 @@ class TraverserTest extends Specification {
256260 0 * visitor. backRef(_)
257261 }
258262
263+
264+ def " test context variables" () {
265+ given :
266+ def visitor = [
267+ enter : { TraverserContext context ->
268+ assert context. getParentContext(). getVar(Object . class) == " var1"
269+ assert context. getParentContext(). getVar(String . class) == " var2"
270+ context. setVar(Object . class, " var1" )
271+ context. setVar(String . class, " var2" )
272+
273+ TraversalControl . CONTINUE
274+ },
275+ leave : { TraverserContext context ->
276+ TraversalControl . CONTINUE
277+ }
278+ ] as TraverserVisitor
279+ when :
280+ def result = Traverser . breadthFirst({ n -> n. children },)
281+ .rootVars([(Object . class): " var1" , (String . class): " var2" ])
282+ .traverse(root, visitor)
283+
284+
285+ then :
286+ true
287+ }
288+
289+ def " test parent result chain" () {
290+ given :
291+ def visitor = [
292+ enter : { TraverserContext context ->
293+ List visited = context. getParentResult()
294+ visited = visited == null ? new ArrayList<> () : visited
295+ visited. add(context. thisNode(). number)
296+ context. setVar(List . class, visited)
297+ context. setResult(visited)
298+ TraversalControl . CONTINUE
299+ },
300+ leave : { TraverserContext context ->
301+ TraversalControl . CONTINUE
302+ }
303+ ] as TraverserVisitor
304+ when :
305+ def result = Traverser . breadthFirst({ n -> n. children },)
306+ .traverse(root, visitor)
307+
308+
309+ then :
310+ result. result == [0 , 1 , 2 , 3 , 4 , 5 ]
311+ }
312+
313+ def " test initial data" () {
314+ def visitor = [
315+ enter : { TraverserContext context ->
316+ assert context. getInitialData() == " foo"
317+ TraversalControl . CONTINUE
318+ },
319+ leave : { TraverserContext context ->
320+ assert context. getInitialData() == " foo"
321+ TraversalControl . QUIT
322+ },
323+
324+ ] as TraverserVisitor
325+
326+ when :
327+ Traverser . depthFirst({ n -> n. children }, " foo" ). traverse(root, visitor)
328+
329+ then :
330+ true
331+
332+ }
333+
259334}
260335
336+
337+
0 commit comments