Skip to content

Commit 2fb3fc9

Browse files
committed
Support chained comparisons
1 parent a77b348 commit 2fb3fc9

File tree

7 files changed

+141
-43
lines changed

7 files changed

+141
-43
lines changed

parser/src/ast.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,8 @@ pub enum Expression {
159159
value: Box<Expression>,
160160
},
161161
Compare {
162-
a: Box<Expression>,
163-
op: Comparison,
164-
b: Box<Expression>,
162+
vals: Vec<Expression>,
163+
ops: Vec<Comparison>,
165164
},
166165
Attribute {
167166
value: Box<Expression>,

parser/src/parser.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -460,26 +460,30 @@ mod tests {
460460
},
461461
ifs: vec![
462462
ast::Expression::Compare {
463-
a: Box::new(ast::Expression::Identifier {
464-
name: "a".to_string()
465-
}),
466-
op: ast::Comparison::Less,
467-
b: Box::new(ast::Expression::Number {
468-
value: ast::Number::Integer {
469-
value: BigInt::from(5)
463+
vals: vec![
464+
ast::Expression::Identifier {
465+
name: "a".to_string()
466+
},
467+
ast::Expression::Number {
468+
value: ast::Number::Integer {
469+
value: BigInt::from(5)
470+
}
470471
}
471-
}),
472+
],
473+
ops: vec![ast::Comparison::Less],
472474
},
473475
ast::Expression::Compare {
474-
a: Box::new(ast::Expression::Identifier {
475-
name: "a".to_string()
476-
}),
477-
op: ast::Comparison::Greater,
478-
b: Box::new(ast::Expression::Number {
479-
value: ast::Number::Integer {
480-
value: BigInt::from(10)
476+
vals: vec![
477+
ast::Expression::Identifier {
478+
name: "a".to_string()
479+
},
480+
ast::Expression::Number {
481+
value: ast::Number::Integer {
482+
value: BigInt::from(10)
483+
}
481484
}
482-
}),
485+
],
486+
ops: vec![ast::Comparison::Greater],
483487
},
484488
],
485489
}

parser/src/python.lalrpop

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,15 @@ NotTest: ast::Expression = {
680680
};
681681

682682
Comparison: ast::Expression = {
683-
<e1:Comparison> <op:CompOp> <e2:Expression> => ast::Expression::Compare { a: Box::new(e1), op: op, b: Box::new(e2) },
683+
<e:Expression> <comparisons:(CompOp Expression)+> => {
684+
let mut vals = vec![e];
685+
let mut ops = vec![];
686+
for x in comparisons {
687+
ops.push(x.0);
688+
vals.push(x.1);
689+
}
690+
ast::Expression::Compare { vals, ops }
691+
},
684692
<e:Expression> => e,
685693
};
686694

tests/snippets/ast_snippet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,10 @@ def foo():
2121
assert foo.body[0].value.func.id == 'print'
2222
assert foo.body[0].lineno == 3
2323
assert foo.body[1].lineno == 4
24+
25+
n = ast.parse("3 < 4 > 5\n")
26+
assert n.body[0].value.left.n == 3
27+
assert 'Lt' in str(n.body[0].value.ops[0])
28+
assert 'Gt' in str(n.body[0].value.ops[1])
29+
assert n.body[0].value.comparators[0].n == 4
30+
assert n.body[0].value.comparators[1].n == 5

tests/snippets/comparisons.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
assert 1 < 2
3+
assert 1 < 2 < 3
4+
assert 5 == 5 == 5
5+
assert (5 == 5) == True
6+
assert 5 == 5 != 4 == 4 > 3 > 2 < 3 <= 3 != 0 == 0
7+
8+
assert not 1 > 2
9+
assert not 5 == 5 == True
10+
assert not 5 == 5 != 5 == 5
11+
assert not 1 < 2 < 3 > 4
12+
assert not 1 < 2 > 3 < 4
13+
assert not 1 > 2 < 3 < 4

vm/src/compile.rs

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,79 @@ impl Compiler {
840840
Ok(())
841841
}
842842

843+
fn compile_chained_comparison(
844+
&mut self,
845+
vals: &[ast::Expression],
846+
ops: &[ast::Comparison],
847+
) -> Result<(), CompileError> {
848+
assert!(ops.len() > 0);
849+
assert!(vals.len() == ops.len() + 1);
850+
851+
let to_operator = |op: &ast::Comparison| match op {
852+
ast::Comparison::Equal => bytecode::ComparisonOperator::Equal,
853+
ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual,
854+
ast::Comparison::Less => bytecode::ComparisonOperator::Less,
855+
ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual,
856+
ast::Comparison::Greater => bytecode::ComparisonOperator::Greater,
857+
ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual,
858+
ast::Comparison::In => bytecode::ComparisonOperator::In,
859+
ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn,
860+
ast::Comparison::Is => bytecode::ComparisonOperator::Is,
861+
ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot,
862+
};
863+
864+
// a == b == c == d
865+
// compile into (pseudocode):
866+
// result = a == b
867+
// if result:
868+
// result = b == c
869+
// if result:
870+
// result = c == d
871+
872+
// initialize lhs outside of loop
873+
self.compile_expression(&vals[0])?;
874+
875+
let break_label = self.new_label();
876+
let last_label = self.new_label();
877+
878+
// for all comparisons except the last (as the last one doesn't need a conditional jump)
879+
let ops_slice = &ops[0..ops.len()];
880+
let vals_slice = &vals[1..ops.len()];
881+
for (op, val) in ops_slice.iter().zip(vals_slice.iter()) {
882+
self.compile_expression(val)?;
883+
// store rhs for the next comparison in chain
884+
self.emit(Instruction::Duplicate);
885+
self.emit(Instruction::Rotate { amount: 3 });
886+
887+
self.emit(Instruction::CompareOperation {
888+
op: to_operator(op),
889+
});
890+
891+
// if comparison result is false, we break with this value; if true, try the next one.
892+
// (CPython compresses these three opcodes into JUMP_IF_FALSE_OR_POP)
893+
self.emit(Instruction::Duplicate);
894+
self.emit(Instruction::JumpIfFalse {
895+
target: break_label,
896+
});
897+
self.emit(Instruction::Pop);
898+
}
899+
900+
// handle the last comparison
901+
self.compile_expression(vals.last().unwrap())?;
902+
self.emit(Instruction::CompareOperation {
903+
op: to_operator(ops.last().unwrap()),
904+
});
905+
self.emit(Instruction::Jump { target: last_label });
906+
907+
// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
908+
self.set_label(break_label);
909+
self.emit(Instruction::Rotate { amount: 2 });
910+
self.emit(Instruction::Pop);
911+
912+
self.set_label(last_label);
913+
Ok(())
914+
}
915+
843916
fn compile_store(&mut self, target: &ast::Expression) -> Result<(), CompileError> {
844917
match target {
845918
ast::Expression::Identifier { name } => {
@@ -1022,24 +1095,8 @@ impl Compiler {
10221095
name: name.to_string(),
10231096
});
10241097
}
1025-
ast::Expression::Compare { a, op, b } => {
1026-
self.compile_expression(a)?;
1027-
self.compile_expression(b)?;
1028-
1029-
let i = match op {
1030-
ast::Comparison::Equal => bytecode::ComparisonOperator::Equal,
1031-
ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual,
1032-
ast::Comparison::Less => bytecode::ComparisonOperator::Less,
1033-
ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual,
1034-
ast::Comparison::Greater => bytecode::ComparisonOperator::Greater,
1035-
ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual,
1036-
ast::Comparison::In => bytecode::ComparisonOperator::In,
1037-
ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn,
1038-
ast::Comparison::Is => bytecode::ComparisonOperator::Is,
1039-
ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot,
1040-
};
1041-
let i = Instruction::CompareOperation { op: i };
1042-
self.emit(i);
1098+
ast::Expression::Compare { vals, ops } => {
1099+
self.compile_chained_comparison(vals, ops)?;
10431100
}
10441101
ast::Expression::Number { value } => {
10451102
let const_value = match value {

vm/src/stdlib/ast.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
327327

328328
node
329329
}
330-
ast::Expression::Compare { a, op, b } => {
330+
ast::Expression::Compare { vals, ops } => {
331331
let node = create_node(vm, "Compare");
332332

333-
let py_a = expression_to_ast(vm, a);
333+
let py_a = expression_to_ast(vm, &vals[0]);
334334
vm.ctx.set_attr(&node, "left", py_a);
335335

336336
// Operator:
337-
let str_op = match op {
337+
let to_operator = |op: &ast::Comparison| match op {
338338
ast::Comparison::Equal => "Eq",
339339
ast::Comparison::NotEqual => "NotEq",
340340
ast::Comparison::Less => "Lt",
@@ -346,10 +346,20 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
346346
ast::Comparison::Is => "Is",
347347
ast::Comparison::IsNot => "IsNot",
348348
};
349-
let py_ops = vm.ctx.new_list(vec![vm.ctx.new_str(str_op.to_string())]);
349+
let py_ops = vm.ctx.new_list(
350+
ops.iter()
351+
.map(|x| vm.ctx.new_str(to_operator(x).to_string()))
352+
.collect(),
353+
);
354+
350355
vm.ctx.set_attr(&node, "ops", py_ops);
351356

352-
let py_b = vm.ctx.new_list(vec![expression_to_ast(vm, b)]);
357+
let py_b = vm.ctx.new_list(
358+
vals.iter()
359+
.skip(1)
360+
.map(|x| expression_to_ast(vm, x))
361+
.collect(),
362+
);
353363
vm.ctx.set_attr(&node, "comparators", py_b);
354364
node
355365
}

0 commit comments

Comments
 (0)