Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Inplace ops support in JIT
  • Loading branch information
ShaharNaveh committed Dec 1, 2025
commit c7c5e678e2a79a46f528ff3e12b95549b95941a7
138 changes: 87 additions & 51 deletions crates/jit/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,18 +446,30 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
let b_type = b.to_jit_type();

let val = match (op, a, b) {
(BinaryOperator::Add, JitValue::Int(a), JitValue::Int(b)) => {
(
BinaryOperator::Add | BinaryOperator::InplaceAdd,
JitValue::Int(a),
JitValue::Int(b),
) => {
let (out, carry) = self.builder.ins().sadd_overflow(a, b);
self.builder.ins().trapnz(carry, TrapCode::INTEGER_OVERFLOW);
JitValue::Int(out)
}
(BinaryOperator::Subtract, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.compile_sub(a, b))
}
(BinaryOperator::FloorDivide, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().sdiv(a, b))
}
(BinaryOperator::TrueDivide, JitValue::Int(a), JitValue::Int(b)) => {
(
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.compile_sub(a, b)),
(
BinaryOperator::FloorDivide | BinaryOperator::InplaceFloorDivide,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().sdiv(a, b)),
(
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide,
JitValue::Int(a),
JitValue::Int(b),
) => {
// Check if b == 0, If so trap with a division by zero error
self.builder
.ins()
Expand All @@ -467,15 +479,21 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
let b_float = self.builder.ins().fcvt_from_sint(types::F64, b);
JitValue::Float(self.builder.ins().fdiv(a_float, b_float))
}
(BinaryOperator::Multiply, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().imul(a, b))
}
(BinaryOperator::Remainder, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().srem(a, b))
}
(BinaryOperator::Power, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.compile_ipow(a, b))
}
(
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().imul(a, b)),
(
BinaryOperator::Remainder | BinaryOperator::InplaceRemainder,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().srem(a, b)),
(
BinaryOperator::Power | BinaryOperator::InplacePower,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.compile_ipow(a, b)),
(
BinaryOperator::Lshift | BinaryOperator::Rshift,
JitValue::Int(a),
Expand All @@ -489,39 +507,57 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
TrapCode::user(CustomTrapCode::NegativeShiftCount as u8).unwrap(),
);

let out = if op == BinaryOperator::Lshift {
self.builder.ins().ishl(a, b)
} else {
self.builder.ins().sshr(a, b)
};
let out =
if matches!(op, BinaryOperator::Lshift | BinaryOperator::InplaceLshift)
{
self.builder.ins().ishl(a, b)
} else {
self.builder.ins().sshr(a, b)
};
JitValue::Int(out)
}
(BinaryOperator::And, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().band(a, b))
}
(BinaryOperator::Or, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().bor(a, b))
}
(BinaryOperator::Xor, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().bxor(a, b))
}
(
BinaryOperator::And | BinaryOperator::InplaceAnd,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().band(a, b)),
(
BinaryOperator::Or | BinaryOperator::InplaceOr,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().bor(a, b)),
(
BinaryOperator::Xor | BinaryOperator::InplaceXor,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().bxor(a, b)),

// Floats
(BinaryOperator::Add, JitValue::Float(a), JitValue::Float(b)) => {
JitValue::Float(self.builder.ins().fadd(a, b))
}
(BinaryOperator::Subtract, JitValue::Float(a), JitValue::Float(b)) => {
JitValue::Float(self.builder.ins().fsub(a, b))
}
(BinaryOperator::Multiply, JitValue::Float(a), JitValue::Float(b)) => {
JitValue::Float(self.builder.ins().fmul(a, b))
}
(BinaryOperator::TrueDivide, JitValue::Float(a), JitValue::Float(b)) => {
JitValue::Float(self.builder.ins().fdiv(a, b))
}
(BinaryOperator::Power, JitValue::Float(a), JitValue::Float(b)) => {
JitValue::Float(self.compile_fpow(a, b))
}
(
BinaryOperator::Add | BinaryOperator::InplaceAdd,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fadd(a, b)),
(
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fsub(a, b)),
(
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fmul(a, b)),
(
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fdiv(a, b)),
(
BinaryOperator::Power | BinaryOperator::InplacePower,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.compile_fpow(a, b)),

// Floats and Integers
(_, JitValue::Int(a), JitValue::Float(b))
Expand All @@ -537,19 +573,19 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
};

match op {
BinaryOperator::Add => {
BinaryOperator::Add | BinaryOperator::InplaceAdd => {
JitValue::Float(self.builder.ins().fadd(operand_one, operand_two))
}
BinaryOperator::Subtract => {
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract => {
JitValue::Float(self.builder.ins().fsub(operand_one, operand_two))
}
BinaryOperator::Multiply => {
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply => {
JitValue::Float(self.builder.ins().fmul(operand_one, operand_two))
}
BinaryOperator::TrueDivide => {
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide => {
JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two))
}
BinaryOperator::Power => {
BinaryOperator::Power | BinaryOperator::InplacePower => {
JitValue::Float(self.compile_fpow(operand_one, operand_two))
}
_ => return Err(JitCompileError::NotSupported),
Expand Down
Loading