Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Root cause: The assertion rewriter's visit_NamedExpr returned the ast…
….NamedExpr node directly, which was then referenced in multiple places in the rewritten AST (the comparison statement, the results tuple for

   _call_reprcompare, and format expressions). Each reference re-evaluated the walrus operator at runtime, causing side effects to fire multiple times.

Fix (in src/_pytest/assertion/rewrite.py):

  1. visit_NamedExpr: Returns the NamedExpr inline (preserving evaluation order for function args and comparators) but references the target variable for display instead of re-evaluating the expression.
  2. visit_Compare (left side): When comp.left is a NamedExpr, hoists it into a temp variable before processing comparators. This ensures comparators that reference the walrus target see the assigned value
  (fixing the assert (obj := "foo") == f(obj) case).
  3. visit_Compare (comparators): For NamedExpr comparators, uses the target variable Name in results (failure message) instead of the NamedExpr node, preventing re-evaluation in the failure path. Also saves the
   left value in a temp when a walrus will overwrite it (for correct failure messages).
  4. visit_BoolOp: Saves the short-circuit condition in a per-operand temp variable for the explanation path, so walrus modifications to the original variable don't corrupt the failure message.
  5. visit_Assert: Clears variables_overwrite at the start of each assert, preventing stale walrus mappings from leaking between statements.

Closes #14445
  • Loading branch information
shuckc committed May 7, 2026
commit f4319185abaae6148a14f40e0d1f57c197fd8d58
86 changes: 51 additions & 35 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
self.statements: list[ast.stmt] = []
self.variables: list[str] = []
self.variable_counter = itertools.count()
# Clear walrus overwrite tracking — only valid within a single assert.
self.variables_overwrite[self.scope] = {}

if self.enable_assertion_pass_hook:
self.format_variables: list[str] = []
Expand Down Expand Up @@ -964,15 +966,16 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
# Return the NamedExpr node itself to preserve in-place evaluation order.
# For the display (used in failure messages), reference the target variable
# rather than re-evaluating the NamedExpr.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id
target_name = ast.Name(target_id, ast.Load())
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id))
return name, self.explanation_param(expr)

def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
Expand All @@ -998,20 +1001,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
for i, v in enumerate(boolop.values):
if i:
fail_inner: list[ast.stmt] = []
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
# expl_cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
left=ast.NamedExpr(target=ast.Name(id=target_id))
) if target_id in [
e.id for e in boolop.values[:i] if hasattr(e, "id")
]:
pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
Expand All @@ -1022,6 +1014,12 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
cond: ast.expr = res
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
# Save the condition value for the explanation path. A walrus
# in a later operand may modify the variable, but the saved
# value preserves the original truthiness for display purposes.
expl_cond_var = self.variable()
body.append(ast.Assign([ast.Name(expl_cond_var, ast.Store())], cond))
expl_cond: ast.expr = ast.Name(expl_cond_var, ast.Load())
inner: list[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
Expand Down Expand Up @@ -1100,15 +1098,21 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
match comp.left:
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
self.scope, {}
):
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
case ast.NamedExpr(target=ast.Name(id=target_id)):
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.NamedExpr):
# Hoist the NamedExpr into a temp variable BEFORE visiting comparators,
# so that comparators referencing the walrus target see the assigned value.
# The assign evaluates (target := value), storing result in @py_assertN.
target_id = comp.left.target.id
left_res = self.assign(comp.left)
target_name = ast.Name(target_id, ast.Load())
locs = ast.Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(left_res), ast.Constant(target_id))
left_expl = self.explanation_param(expr)
else:
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.Compare | ast.BoolOp):
left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))]
Expand All @@ -1119,18 +1123,30 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
# If the comparator is a NamedExpr that overwrites the left operand's
# variable, save the left value in a temp BEFORE the comparison so
# the failure message can display the pre-walrus value.
if (
isinstance(next_operand, ast.NamedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
saved_left = self.variable()
self.statements.append(
ast.Assign([ast.Name(saved_left, ast.Store())], left_res)
)
results[-1] = ast.Name(saved_left, ast.Load())

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
next_expl = f"({next_expl})"
results.append(next_res)
# For NamedExpr comparators, use the target variable in results
# (for the failure message) instead of the NamedExpr node itself,
# to avoid re-evaluating the walrus operator in the failure path.
if isinstance(next_operand, ast.NamedExpr):
results.append(ast.Name(next_operand.target.id, ast.Load()))
else:
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
Expand Down
4 changes: 2 additions & 2 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
result.stdout.fnmatch_lines(["*assert not (False and False is False)"])

def test_assertion_walrus_operator_boolean_none_fails(
self, pytester: Pytester
Expand All @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
result.stdout.fnmatch_lines(["*assert not (None and None is None)"])

def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
self, pytester: Pytester
Expand Down