@@ -840,6 +840,79 @@ impl Compiler {
840840 Ok ( ( ) )
841841 }
842842
843+ fn compile_chained_comparison (
844+ & mut self ,
845+ vals : & [ ast:: Expression ] ,
846+ ops : & [ ast:: Comparison ] ,
847+ ) -> Result < ( ) , CompileError > {
848+ assert ! ( ops. len( ) > 0 ) ;
849+ assert ! ( vals. len( ) == ops. len( ) + 1 ) ;
850+
851+ let to_operator = |op : & ast:: Comparison | match op {
852+ ast:: Comparison :: Equal => bytecode:: ComparisonOperator :: Equal ,
853+ ast:: Comparison :: NotEqual => bytecode:: ComparisonOperator :: NotEqual ,
854+ ast:: Comparison :: Less => bytecode:: ComparisonOperator :: Less ,
855+ ast:: Comparison :: LessOrEqual => bytecode:: ComparisonOperator :: LessOrEqual ,
856+ ast:: Comparison :: Greater => bytecode:: ComparisonOperator :: Greater ,
857+ ast:: Comparison :: GreaterOrEqual => bytecode:: ComparisonOperator :: GreaterOrEqual ,
858+ ast:: Comparison :: In => bytecode:: ComparisonOperator :: In ,
859+ ast:: Comparison :: NotIn => bytecode:: ComparisonOperator :: NotIn ,
860+ ast:: Comparison :: Is => bytecode:: ComparisonOperator :: Is ,
861+ ast:: Comparison :: IsNot => bytecode:: ComparisonOperator :: IsNot ,
862+ } ;
863+
864+ // a == b == c == d
865+ // compile into (pseudocode):
866+ // result = a == b
867+ // if result:
868+ // result = b == c
869+ // if result:
870+ // result = c == d
871+
872+ // initialize lhs outside of loop
873+ self . compile_expression ( & vals[ 0 ] ) ?;
874+
875+ let break_label = self . new_label ( ) ;
876+ let last_label = self . new_label ( ) ;
877+
878+ // for all comparisons except the last (as the last one doesn't need a conditional jump)
879+ let ops_slice = & ops[ 0 ..ops. len ( ) ] ;
880+ let vals_slice = & vals[ 1 ..ops. len ( ) ] ;
881+ for ( op, val) in ops_slice. iter ( ) . zip ( vals_slice. iter ( ) ) {
882+ self . compile_expression ( val) ?;
883+ // store rhs for the next comparison in chain
884+ self . emit ( Instruction :: Duplicate ) ;
885+ self . emit ( Instruction :: Rotate { amount : 3 } ) ;
886+
887+ self . emit ( Instruction :: CompareOperation {
888+ op : to_operator ( op) ,
889+ } ) ;
890+
891+ // if comparison result is false, we break with this value; if true, try the next one.
892+ // (CPython compresses these three opcodes into JUMP_IF_FALSE_OR_POP)
893+ self . emit ( Instruction :: Duplicate ) ;
894+ self . emit ( Instruction :: JumpIfFalse {
895+ target : break_label,
896+ } ) ;
897+ self . emit ( Instruction :: Pop ) ;
898+ }
899+
900+ // handle the last comparison
901+ self . compile_expression ( vals. last ( ) . unwrap ( ) ) ?;
902+ self . emit ( Instruction :: CompareOperation {
903+ op : to_operator ( ops. last ( ) . unwrap ( ) ) ,
904+ } ) ;
905+ self . emit ( Instruction :: Jump { target : last_label } ) ;
906+
907+ // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
908+ self . set_label ( break_label) ;
909+ self . emit ( Instruction :: Rotate { amount : 2 } ) ;
910+ self . emit ( Instruction :: Pop ) ;
911+
912+ self . set_label ( last_label) ;
913+ Ok ( ( ) )
914+ }
915+
843916 fn compile_store ( & mut self , target : & ast:: Expression ) -> Result < ( ) , CompileError > {
844917 match target {
845918 ast:: Expression :: Identifier { name } => {
@@ -1022,24 +1095,8 @@ impl Compiler {
10221095 name : name. to_string ( ) ,
10231096 } ) ;
10241097 }
1025- ast:: Expression :: Compare { a, op, b } => {
1026- self . compile_expression ( a) ?;
1027- self . compile_expression ( b) ?;
1028-
1029- let i = match op {
1030- ast:: Comparison :: Equal => bytecode:: ComparisonOperator :: Equal ,
1031- ast:: Comparison :: NotEqual => bytecode:: ComparisonOperator :: NotEqual ,
1032- ast:: Comparison :: Less => bytecode:: ComparisonOperator :: Less ,
1033- ast:: Comparison :: LessOrEqual => bytecode:: ComparisonOperator :: LessOrEqual ,
1034- ast:: Comparison :: Greater => bytecode:: ComparisonOperator :: Greater ,
1035- ast:: Comparison :: GreaterOrEqual => bytecode:: ComparisonOperator :: GreaterOrEqual ,
1036- ast:: Comparison :: In => bytecode:: ComparisonOperator :: In ,
1037- ast:: Comparison :: NotIn => bytecode:: ComparisonOperator :: NotIn ,
1038- ast:: Comparison :: Is => bytecode:: ComparisonOperator :: Is ,
1039- ast:: Comparison :: IsNot => bytecode:: ComparisonOperator :: IsNot ,
1040- } ;
1041- let i = Instruction :: CompareOperation { op : i } ;
1042- self . emit ( i) ;
1098+ ast:: Expression :: Compare { vals, ops } => {
1099+ self . compile_chained_comparison ( vals, ops) ?;
10431100 }
10441101 ast:: Expression :: Number { value } => {
10451102 let const_value = match value {
0 commit comments