From 82c75f59450a9478f94b4008dcb9ac279e34b9c0 Mon Sep 17 00:00:00 2001
From: Matt Keeter
Date: Sun, 24 Mar 2024 10:05:02 -0400
Subject: [PATCH] Implement `compare` opcode (#46)

This is the `<=>` / `partial_cmp` operation!
---
 fidget/src/core/compiler/alloc.rs       |  16 +++-
 fidget/src/core/compiler/op.rs          |  16 +++-
 fidget/src/core/compiler/ssa_tape.rs    |  14 ++++
 fidget/src/core/context/mod.rs          |  30 +++++++
 fidget/src/core/context/op.rs           |   1 +
 fidget/src/core/eval/test/grad_slice.rs |  12 ++-
 fidget/src/core/eval/test/interval.rs   |  14 ++++
 fidget/src/core/eval/test/mod.rs        |  25 +++++-
 fidget/src/core/vm/data.rs              |   7 +-
 fidget/src/core/vm/mod.rs               | 104 ++++++++++++++++++++++++
 fidget/src/jit/aarch64/float_slice.rs   |  36 ++++++++
 fidget/src/jit/aarch64/grad_slice.rs    |  35 ++++++++
 fidget/src/jit/aarch64/interval.rs      |  56 +++++++++++++
 fidget/src/jit/aarch64/point.rs         |  34 ++++++++
 fidget/src/jit/mod.rs                   |  14 ++++
 fidget/src/jit/x86_64/float_slice.rs    |  32 ++++++++
 fidget/src/jit/x86_64/grad_slice.rs     |  32 ++++++++
 fidget/src/jit/x86_64/interval.rs       |  57 +++++++++++++
 fidget/src/jit/x86_64/point.rs          |  30 +++++++
 19 files changed, 554 insertions(+), 11 deletions(-)

diff --git a/fidget/src/core/compiler/alloc.rs b/fidget/src/core/compiler/alloc.rs
index 325156b7..de51d09c 100644
--- a/fidget/src/core/compiler/alloc.rs
+++ b/fidget/src/core/compiler/alloc.rs
@@ -295,14 +295,17 @@ impl RegisterAllocator {
             | SsaOp::DivRegImm(..)
             | SsaOp::DivImmReg(..)
             | SsaOp::MinRegImm(..)
-            | SsaOp::MaxRegImm(..) => self.op_reg_imm(op),
+            | SsaOp::MaxRegImm(..)
+            | SsaOp::CompareRegImm(..)
+            | SsaOp::CompareImmReg(..) => self.op_reg_imm(op),

             SsaOp::AddRegReg(..)
             | SsaOp::SubRegReg(..)
             | SsaOp::MulRegReg(..)
             | SsaOp::DivRegReg(..)
             | SsaOp::MinRegReg(..)
-            | SsaOp::MaxRegReg(..) => self.op_reg_reg(op),
+            | SsaOp::MaxRegReg(..)
+            | SsaOp::CompareRegReg(..) => self.op_reg_reg(op),
         }
     }

@@ -485,6 +488,9 @@ impl RegisterAllocator {
             SsaOp::MaxRegReg(out, lhs, rhs) => {
                 (out, lhs, rhs, RegOp::MaxRegReg)
             }
+            SsaOp::CompareRegReg(out, lhs, rhs) => {
+                (out, lhs, rhs, RegOp::CompareRegReg)
+            }
             _ => panic!("Bad opcode: {op:?}"),
         };
         let r_x = self.get_out_reg(out);
@@ -597,6 +603,12 @@ impl RegisterAllocator {
             SsaOp::MaxRegImm(out, arg, imm) => {
                 (out, arg, imm, RegOp::MaxRegImm)
             }
+            SsaOp::CompareRegImm(out, arg, imm) => {
+                (out, arg, imm, RegOp::CompareRegImm)
+            }
+            SsaOp::CompareImmReg(out, arg, imm) => {
+                (out, arg, imm, RegOp::CompareImmReg)
+            }
             _ => panic!("Bad opcode: {op:?}"),
         };
         self.op_reg_fn(out, arg, |out, arg| op(out, arg, imm));
diff --git a/fidget/src/core/compiler/op.rs b/fidget/src/core/compiler/op.rs
index 146eeda5..8cdfc539 100644
--- a/fidget/src/core/compiler/op.rs
+++ b/fidget/src/core/compiler/op.rs
@@ -85,6 +85,10 @@ macro_rules! opcodes {
             MinRegImm($t, $t, f32),
             #[doc = "Compute the maximum of a register and an immediate"]
             MaxRegImm($t, $t, f32),
+            #[doc = "Compares a register with an immediate"]
+            CompareRegImm($t, $t, f32),
+            #[doc = "Compares an immediate with a register"]
+            CompareImmReg($t, $t, f32),

             #[doc = "Add two registers"]
             AddRegReg($t, $t, $t),
@@ -98,6 +102,8 @@ macro_rules! opcodes {
             MinRegReg($t, $t, $t),
             #[doc = "Take the maximum of two registers"]
             MaxRegReg($t, $t, $t),
+            #[doc = "Compares two registers"]
+            CompareRegReg($t, $t, $t),

             #[doc = "Copy an immediate to a register"]
             CopyImm($t, f32),
@@ -159,7 +165,10 @@ impl SsaOp {
             | SsaOp::MinRegImm(out, ..)
             | SsaOp::MaxRegImm(out, ..)
             | SsaOp::MinRegReg(out, ..)
-            | SsaOp::MaxRegReg(out, ..) => *out,
+            | SsaOp::MaxRegReg(out, ..)
+            | SsaOp::CompareRegReg(out, ..)
+            | SsaOp::CompareRegImm(out, ..)
+            | SsaOp::CompareImmReg(out, ..) => *out,
         }
     }

     /// Returns true if the given opcode is associated with a choice
@@ -191,7 +200,10 @@ impl SsaOp {
             | SsaOp::SubRegReg(..)
             | SsaOp::DivRegReg(..)
             | SsaOp::DivRegImm(..)
-            | SsaOp::DivImmReg(..) => false,
+            | SsaOp::DivImmReg(..)
+            | SsaOp::CompareRegReg(..)
+            | SsaOp::CompareRegImm(..)
+            | SsaOp::CompareImmReg(..) => false,
             SsaOp::MinRegImm(..)
             | SsaOp::MaxRegImm(..)
             | SsaOp::MinRegReg(..)
diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs
index d78cc4a7..b9970aac 100644
--- a/fidget/src/core/compiler/ssa_tape.rs
+++ b/fidget/src/core/compiler/ssa_tape.rs
@@ -161,6 +161,11 @@ impl SsaTape {
                     SsaOp::MaxRegImm,
                     SsaOp::MaxRegImm,
                 ),
+                BinaryOpcode::Compare => (
+                    SsaOp::CompareRegReg,
+                    SsaOp::CompareRegImm,
+                    SsaOp::CompareImmReg,
+                ),
             };

             if matches!(op, BinaryOpcode::Min | BinaryOpcode::Max) {
@@ -331,6 +336,15 @@ impl SsaTape {
                         println!("${out} = {op} ${arg} {imm}");
                     }
                 }
+                SsaOp::CompareRegReg(out, lhs, rhs) => {
+                    println!("${out} = COMPARE ${lhs} ${rhs}");
+                }
+                SsaOp::CompareRegImm(out, arg, imm) => {
+                    println!("${out} = COMPARE ${arg} {imm}");
+                }
+                SsaOp::CompareImmReg(out, arg, imm) => {
+                    println!("${out} = COMPARE {imm} ${arg}");
+                }
                 SsaOp::CopyImm(out, imm) => {
                     println!("${out} = COPY {imm}");
                 }
diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs
index d388d4ab..0bb03137 100644
--- a/fidget/src/core/context/mod.rs
+++ b/fidget/src/core/context/mod.rs
@@ -511,6 +511,31 @@ impl Context {
         }
     }

+    /// Builds a node that compares two values
+    ///
+    /// The result is -1 if `a < b`, +1 if `a > b`, 0 if `a == b`, and `NaN` if
+    /// either side is `NaN`.
+    /// ```
+    /// # let mut ctx = fidget::context::Context::new();
+    /// let x = ctx.x();
+    /// let op = ctx.compare(x, 1.0).unwrap();
+    /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
+    /// assert_eq!(v, -1.0);
+    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
+    /// assert_eq!(v, 1.0);
+    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
+    /// assert_eq!(v, 0.0);
+    /// ```
+    pub fn compare<A: IntoNode, B: IntoNode>(
+        &mut self,
+        a: A,
+        b: B,
+    ) -> Result<Node, Error> {
+        let a = a.into_node(self)?;
+        let b = b.into_node(self)?;
+        self.op_binary(a, b, BinaryOpcode::Compare)
+    }
+
     ////////////////////////////////////////////////////////////////////////////

     /// Remaps the X, Y, Z nodes to the given values
@@ -643,6 +668,10 @@ impl Context {
             BinaryOpcode::Div => a / b,
             BinaryOpcode::Min => a.min(b),
             BinaryOpcode::Max => a.max(b),
+            BinaryOpcode::Compare => a
+                .partial_cmp(&b)
+                .map(|i| i as i8 as f64)
+                .unwrap_or(f64::NAN),
         }
     }

@@ -769,6 +798,7 @@ impl Context {
                 BinaryOpcode::Div => out += "div",
                 BinaryOpcode::Min => out += "min",
                 BinaryOpcode::Max => out += "max",
+                BinaryOpcode::Compare => out += "compare",
             },
             Op::Unary(op, ..) => match op {
                 UnaryOpcode::Neg => out += "neg",
diff --git a/fidget/src/core/context/op.rs b/fidget/src/core/context/op.rs
index fd9e61b1..66551999 100644
--- a/fidget/src/core/context/op.rs
+++ b/fidget/src/core/context/op.rs
@@ -30,6 +30,7 @@ pub enum BinaryOpcode {
     Div,
     Min,
     Max,
+    Compare,
 }

 /// An operation in a math expression.
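
The `BinaryOpcode::Compare` fold above works because `std::cmp::Ordering` is a
`#[repr(i8)]` enum with `Less = -1`, `Equal = 0`, and `Greater = 1`, so
`ord as i8 as f64` yields exactly the -1/0/+1 encoding, while `None` (the
unordered result when either operand is `NaN`) falls through to `f64::NAN`.
A minimal standalone sketch of the same mapping, independent of Fidget's types:

    /// Three-way compare with the same semantics as `BinaryOpcode::Compare`:
    /// -1.0 if a < b, 0.0 if a == b, +1.0 if a > b, NaN if either input is NaN
    fn compare(a: f64, b: f64) -> f64 {
        a.partial_cmp(&b)
            .map(|ord| ord as i8 as f64) // Ordering is #[repr(i8)]: -1, 0, +1
            .unwrap_or(f64::NAN) // partial_cmp returns None for NaN inputs
    }

    fn main() {
        assert_eq!(compare(0.0, 1.0), -1.0);
        assert_eq!(compare(2.0, 1.0), 1.0);
        assert_eq!(compare(1.0, 1.0), 0.0);
        assert!(compare(f64::NAN, 1.0).is_nan());
    }
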
diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs
index 2969fcdb..e9e35edd 100644
--- a/fidget/src/core/eval/test/grad_slice.rs
+++ b/fidget/src/core/eval/test/grad_slice.rs
@@ -478,7 +478,7 @@ where
         }
     }

-    pub fn compare_grad_results(
+    pub fn compare_grad_results<C: CanonicalBinaryOp>(
         i: usize,
         j: usize,
         lhs: &[f32],
@@ -503,6 +503,10 @@ where
                 continue;
             }

+            if C::discontinuous_at(*a, *b) {
+                continue;
+            }
+
             const EPSILON: f64 = 1e-8;
             let a = *a as f64;
             let b = *b as f64;
@@ -599,7 +603,7 @@ where
                 .unwrap();

             let rhs = if i == j { &args } else { &rgsa };
-            Self::compare_grad_results(
+            Self::compare_grad_results::<C>(
                 i,
                 j,
                 &args,
@@ -642,7 +646,7 @@ where
                 .unwrap();

             let rhs = vec![*rhs; out.len()];
-            Self::compare_grad_results(
+            Self::compare_grad_results::<C>(
                 i,
                 3,
                 &args,
@@ -685,7 +689,7 @@ where
                 .unwrap();

             let lhs = vec![*lhs; out.len()];
-            Self::compare_grad_results(
+            Self::compare_grad_results::<C>(
                 3,
                 i,
                 &lhs,
diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs
index 17f5f76d..07229e01 100644
--- a/fidget/src/core/eval/test/interval.rs
+++ b/fidget/src/core/eval/test/interval.rs
@@ -619,6 +619,19 @@ where
         );
     }

+    pub fn test_i_compare() {
+        let mut ctx = Context::new();
+        let x = ctx.x();
+        let y = ctx.y();
+        let c = ctx.compare(x, y).unwrap();
+
+        let shape = S::new(&ctx, c).unwrap();
+        let tape = shape.ez_interval_tape();
+        let mut eval = S::new_interval_eval();
+        let (out, _trace) = eval.eval(&tape, -5.0, -6.0, 0.0, &[]).unwrap();
+        assert_eq!(out, Interval::from(1f32));
+    }
+
     pub fn test_i_stress_n(depth: usize) {
         let (ctx, node) = build_stress_fn(depth);

@@ -947,6 +960,7 @@ macro_rules! interval_tests {
         $crate::interval_test!(test_i_min_imm, $t);
         $crate::interval_test!(test_i_max, $t);
         $crate::interval_test!(test_i_max_imm, $t);
+        $crate::interval_test!(test_i_compare, $t);
         $crate::interval_test!(test_i_simplify, $t);
         $crate::interval_test!(test_i_var, $t);
         $crate::interval_test!(test_i_stress, $t);
diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs
index 2c744662..d4b656f5 100644
--- a/fidget/src/core/eval/test/mod.rs
+++ b/fidget/src/core/eval/test/mod.rs
@@ -72,6 +72,14 @@ pub trait CanonicalBinaryOp {
     fn eval_reg_reg_f64(lhs: f64, rhs: f64) -> f64;
     fn eval_reg_imm_f64(lhs: f64, rhs: f64) -> f64;
     fn eval_imm_reg_f64(lhs: f64, rhs: f64) -> f64;
+
+    /// Returns true if there is a bidirectional discontinuity at a position
+    ///
+    /// This means that we should skip gradient checking, because we can't
+    /// accurately estimate the gradient on either side.
+    fn discontinuous_at(_lhs: f32, _rhs: f32) -> bool {
+        false
+    }
 }

 macro_rules! declare_canonical_unary {
@@ -93,7 +101,7 @@ macro_rules! declare_canonical_unary {
     }
 }

 macro_rules! declare_canonical_binary {
-    (Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr) => {
+    (Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr, |$lhs2:ident, $rhs2:ident| $d:expr) => {
         pub struct $i;
         impl CanonicalBinaryOp for $i {
             const NAME: &'static str = stringify!($i);
@@ -118,8 +126,14 @@
             fn eval_imm_reg_f64($lhs: f64, $rhs: f64) -> f64 {
                 $t
             }
+            fn discontinuous_at($lhs2: f32, $rhs2: f32) -> bool {
+                $d
+            }
         }
     };
+    (Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr) => {
+        declare_canonical_binary!(Context::$i, |$lhs, $rhs| $t, |_a, _b| false);
+    };
 }

 macro_rules! declare_canonical_binary_full {
@@ -204,6 +218,14 @@ pub mod canonical {
             a.max(b)
         }
     );
+    declare_canonical_binary!(
+        Context::compare,
+        |a, b| match a.partial_cmp(&b) {
+            None => f32::NAN.into(),
+            Some(v) => (v as i8).into(),
+        },
+        |a, b| a == b
+    );
 }

 #[macro_export]
@@ -254,5 +276,6 @@ macro_rules! all_binary_tests {
         $crate::one_binary_test!($tester, div);
         $crate::one_binary_test!($tester, min);
         $crate::one_binary_test!($tester, max);
+        $crate::one_binary_test!($tester, compare);
     };
 }
diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs
index bcc6786e..f23a099a 100644
--- a/fidget/src/core/vm/data.rs
+++ b/fidget/src/core/vm/data.rs
@@ -239,7 +239,8 @@ impl VmData {
             SsaOp::AddRegReg(index, lhs, rhs)
             | SsaOp::MulRegReg(index, lhs, rhs)
             | SsaOp::SubRegReg(index, lhs, rhs)
-            | SsaOp::DivRegReg(index, lhs, rhs) => {
+            | SsaOp::DivRegReg(index, lhs, rhs)
+            | SsaOp::CompareRegReg(index, lhs, rhs) => {
                 *index = new_index;
                 *lhs = workspace.get_or_insert_active(*lhs);
                 *rhs = workspace.get_or_insert_active(*rhs);
@@ -249,7 +250,9 @@ impl VmData {
             | SsaOp::SubRegImm(index, arg, _imm)
             | SsaOp::SubImmReg(index, arg, _imm)
             | SsaOp::DivRegImm(index, arg, _imm)
-            | SsaOp::DivImmReg(index, arg, _imm) => {
+            | SsaOp::DivImmReg(index, arg, _imm)
+            | SsaOp::CompareRegImm(index, arg, _imm)
+            | SsaOp::CompareImmReg(index, arg, _imm) => {
                 *index = new_index;
                 *arg = workspace.get_or_insert_active(*arg);
             }
diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs
index 1456f6f2..ba0f998c 100644
--- a/fidget/src/core/vm/mod.rs
+++ b/fidget/src/core/vm/mod.rs
@@ -359,6 +359,39 @@ impl TracingEvaluator for VmIntervalEval {
             RegOp::MulRegReg(out, lhs, rhs) => v[out] = v[lhs] * v[rhs],
             RegOp::DivRegReg(out, lhs, rhs) => v[out] = v[lhs] / v[rhs],
             RegOp::SubRegReg(out, lhs, rhs) => v[out] = v[lhs] - v[rhs],
+            RegOp::CompareRegReg(out, lhs, rhs) => {
+                v[out] = if v[lhs].has_nan() || v[rhs].has_nan() {
+                    f32::NAN.into()
+                } else if v[lhs].upper() < v[rhs].lower() {
+                    Interval::from(-1.0)
+                } else if v[lhs].lower() > v[rhs].upper() {
+                    Interval::from(1.0)
+                } else {
+                    Interval::new(-1.0, 1.0)
+                };
+            }
+            RegOp::CompareRegImm(out, arg, imm) => {
+                v[out] = if v[arg].has_nan() || imm.is_nan() {
+                    f32::NAN.into()
+                } else if v[arg].upper() < imm {
+                    Interval::from(-1.0)
+                } else if v[arg].lower() > imm {
+                    Interval::from(1.0)
+                } else {
+                    Interval::new(-1.0, 1.0)
+                };
+            }
+            RegOp::CompareImmReg(out, arg, imm) => {
+                v[out] = if v[arg].has_nan() || imm.is_nan() {
+                    f32::NAN.into()
+                } else if imm < v[arg].lower() {
+                    Interval::from(-1.0)
+                } else if imm > v[arg].upper() {
+                    Interval::from(1.0)
+                } else {
+                    Interval::new(-1.0, 1.0)
+                };
+            }
             RegOp::MinRegReg(out, lhs, rhs) => {
                 let (value, choice) = v[lhs].min_choice(v[rhs]);
                 v[out] = value;
@@ -541,6 +574,24 @@ impl TracingEvaluator for VmPointEval {
             RegOp::DivRegReg(out, lhs, rhs) => {
                 v[out] = v[lhs] / v[rhs];
             }
+            RegOp::CompareRegReg(out, lhs, rhs) => {
+                v[out] = v[lhs]
+                    .partial_cmp(&v[rhs])
+                    .map(|c| c as i8 as f32)
+                    .unwrap_or(f32::NAN)
+            }
+            RegOp::CompareRegImm(out, arg, imm) => {
+                v[out] = v[arg]
+                    .partial_cmp(&imm)
+                    .map(|c| c as i8 as f32)
+                    .unwrap_or(f32::NAN)
+            }
+            RegOp::CompareImmReg(out, arg, imm) => {
+                v[out] = imm
+                    .partial_cmp(&v[arg])
+                    .map(|c| c as i8 as f32)
+                    .unwrap_or(f32::NAN)
+            }
             RegOp::SubRegReg(out, lhs, rhs) => {
                 v[out] = v[lhs] - v[rhs];
             }
@@ -766,6 +817,22 @@ impl BulkEvaluator for VmFloatSliceEval {
                         v[out][i] = v[arg][i] - imm;
                     }
                 }
+                RegOp::CompareImmReg(out, arg, imm) => {
+                    for i in 0..size {
+                        v[out][i] = imm
+                            .partial_cmp(&v[arg][i])
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN)
+                    }
+                }
+                RegOp::CompareRegImm(out, arg, imm) => {
+                    for i in 0..size {
+                        v[out][i] = v[arg][i]
+                            .partial_cmp(&imm)
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN)
+                    }
+                }
                 RegOp::MinRegImm(out, arg, imm) => {
                     for i in 0..size {
                         v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
@@ -804,6 +871,14 @@ impl BulkEvaluator for VmFloatSliceEval {
                         v[out][i] = v[lhs][i] - v[rhs][i];
                     }
                 }
+                RegOp::CompareRegReg(out, lhs, rhs) => {
+                    for i in 0..size {
+                        v[out][i] = v[lhs][i]
+                            .partial_cmp(&v[rhs][i])
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN)
+                    }
+                }
                 RegOp::MinRegReg(out, lhs, rhs) => {
                     for i in 0..size {
                         v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
@@ -996,6 +1071,25 @@ impl BulkEvaluator for VmGradSliceEval {
                         v[out][i] = v[arg][i] - imm;
                     }
                 }
+                RegOp::CompareImmReg(out, arg, imm) => {
+                    for i in 0..size {
+                        let p = imm
+                            .partial_cmp(&v[arg][i].v)
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN);
+                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
+                    }
+                }
+                RegOp::CompareRegImm(out, arg, imm) => {
+                    for i in 0..size {
+                        let p = v[arg][i]
+                            .v
+                            .partial_cmp(&imm)
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN);
+                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
+                    }
+                }
                 RegOp::MinRegImm(out, arg, imm) => {
                     let imm: Grad = imm.into();
                     for i in 0..size {
@@ -1036,6 +1130,16 @@ impl BulkEvaluator for VmGradSliceEval {
                         v[out][i] = v[lhs][i] - v[rhs][i];
                     }
                 }
+                RegOp::CompareRegReg(out, lhs, rhs) => {
+                    for i in 0..size {
+                        let p = v[lhs][i]
+                            .v
+                            .partial_cmp(&v[rhs][i].v)
+                            .map(|c| c as i8 as f32)
+                            .unwrap_or(f32::NAN);
+                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
+                    }
+                }
                 RegOp::MinRegReg(out, lhs, rhs) => {
                     for i in 0..size {
                         v[out][i] =
diff --git a/fidget/src/jit/aarch64/float_slice.rs b/fidget/src/jit/aarch64/float_slice.rs
index 9ec3e0a0..840f5f04 100644
--- a/fidget/src/jit/aarch64/float_slice.rs
+++ b/fidget/src/jit/aarch64/float_slice.rs
@@ -292,6 +292,42 @@ impl Assembler for FloatSliceAssembler {
         )
     }

+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            // Build a mask of valid positions (not NAN)
+            ; fcmeq v6.S4, V(reg(lhs_reg)).S4, V(reg(lhs_reg)).S4
+            ; fcmeq v7.S4, V(reg(rhs_reg)).S4, V(reg(rhs_reg)).S4
+            ; and v6.b16, v6.b16, v7.b16
+
+            // Invert to get a mask of NAN positions
+            ; mvn v6.b16, v6.b16
+
+            // Note the swap here, from LT -> GT
+            ; fcmgt v4.S4, V(reg(rhs_reg)).S4, V(reg(lhs_reg)).S4
+            ; fcmgt v5.S4, V(reg(lhs_reg)).S4, V(reg(rhs_reg)).S4
+            // At this point, v4 is all 1s where lhs < rhs, and v5 is all 1s
+            // where lhs > rhs
+
+            // Build a map of -1.0 positions
+            ; fmov s7, -1.0
+            ; dup v7.s4, v7.s[0]
+            ; and V(reg(out_reg)).B16, v4.B16, v7.B16
+
+            // Build a map of +1.0 positions
+            ; fmov s7, 1.0
+            ; dup v7.s4, v7.s[0]
+            ; and v5.B16, v5.B16, v7.B16
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v5.B16
+
+            // Build a NAN mask
+            ; mov w9, f32::NAN.to_bits().into()
+            ; dup v7.s4, w9
+            ; and v7.b16, v7.b16, v6.b16
+
+            // Apply NAN mask to NAN positions
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v7.b16
+        )
+    }
+
     /// Loads an immediate into register V4, using W9 as an intermediary
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
diff --git a/fidget/src/jit/aarch64/grad_slice.rs b/fidget/src/jit/aarch64/grad_slice.rs
index fe67e162..ce43ee36 100644
--- a/fidget/src/jit/aarch64/grad_slice.rs
+++ b/fidget/src/jit/aarch64/grad_slice.rs
@@ -403,6 +403,35 @@ impl Assembler for GradSliceAssembler {
         )
     }

+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            // Check whether either argument is NAN
+            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
+            ; dup v6.s4, v6.s[0]
+            ; fcmeq s7, S(reg(rhs_reg)), S(reg(rhs_reg))
+            ; dup v7.s4, v7.s[0]
+            ; and v6.b16, v6.b16, v7.b16
+            ; mvn v6.b16, v6.b16
+            // At this point, v6 is all 1s if either argument is NAN
+
+            // Build masks of all 1s / all 0s for the two comparisons
+            ; fcmgt s4, S(reg(rhs_reg)), S(reg(lhs_reg))
+            ; dup v4.s4, v4.s[0]
+            ; fcmgt s5, S(reg(lhs_reg)), S(reg(rhs_reg))
+            ; dup v5.s4, v5.s[0]
+
+            // (lhs < rhs) & [-1.0, 0.0, 0.0, 0.0]
+            ; fmov s7, -1.0
+            ; and V(reg(out_reg)).b16, v4.b16, v7.b16
+
+            // (lhs > rhs) & [1.0, 0.0, 0.0, 0.0]
+            ; fmov s7, 1.0
+            ; and v5.B16, v5.B16, v7.B16
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v5.B16
+
+            // Build NAN mask
+            ; mov w9, f32::NAN.to_bits().into()
+            ; fmov s7, w9
+            ; and v7.b16, v7.b16, v6.b16
+
+            // Write NAN to the output
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v7.b16
+        )
+    }
+
     /// Loads an immediate into register S4, using W9 as an intermediary
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
diff --git a/fidget/src/jit/aarch64/interval.rs b/fidget/src/jit/aarch64/interval.rs
index 50e10045..9b14f7f3 100644
--- a/fidget/src/jit/aarch64/interval.rs
+++ b/fidget/src/jit/aarch64/interval.rs
@@ -482,6 +482,62 @@ impl Assembler for IntervalAssembler {
         )
     }

+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            // Very similar to build_min, but without writing choices
+            // (and producing different output)
+
+            // Build a !NAN mask
+            ; fcmeq v4.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2
+            ; fcmeq v5.s2, V(reg(rhs_reg)).s2, V(reg(rhs_reg)).s2
+            ; and v4.b8, v4.b8, v5.b8
+            ; fmov x15, d4
+            ; cmp x15, 0
+            ; b.ne 16 // -> skip over NAN handling into main logic
+
+            // NAN case
+            ; mov w15, f32::NAN.to_bits().into()
+            ; dup V(reg(out_reg)).s2, w15
+            ; b 76 // -> end
+
+            // v4 = [lhs.upper, rhs.upper]
+            // v5 = [rhs.lower, lhs.lower]
+            // This lets us do two comparisons simultaneously
+            ; zip2 v4.s2, V(reg(lhs_reg)).s2, V(reg(rhs_reg)).s2
+            ; zip1 v5.s2, V(reg(rhs_reg)).s2, V(reg(lhs_reg)).s2
+
+            // v5 = [rhs.lower > lhs.upper, lhs.lower > rhs.upper]
+            ; fcmgt v5.s2, v5.s2, v4.s2
+            ; fmov x15, d5
+
+            ; tst x15, 0x1_0000_0000
+            ; b.ne 24 // -> rhs
+
+            ; tst x15, 0x1
+            ; b.eq 28 // -> both
+
+            // Fallthrough: LHS < RHS => [-1, -1]
+            ; fmov S(reg(out_reg)), -1.0
+            ; dup V(reg(out_reg)).s2, V(reg(out_reg)).s[0]
+            ; b 32 // -> end
+
+            // <- rhs (for when RHS < LHS) => [1, 1]
+            ; fmov S(reg(out_reg)), 1.0
+            ; dup V(reg(out_reg)).s2, V(reg(out_reg)).s[0]
+            ; b 20 // -> end
+
+            // <- both => [-1, 1]
+            ; fmov S(reg(out_reg)), 1.0
+            ; dup V(reg(out_reg)).s2, V(reg(out_reg)).s[0]
+            ; fmov s5, -1.0
+            ; mov V(reg(out_reg)).s[0], v5.s[0]
+
+            // TODO handle the case where LHS == RHS with no ambiguity
+
+            // <- end
+        );
+    }
+
     /// Loads an immediate into register S4, using W9 as an intermediary
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
diff --git a/fidget/src/jit/aarch64/point.rs b/fidget/src/jit/aarch64/point.rs
index a098dec9..2dd5dcab 100644
--- a/fidget/src/jit/aarch64/point.rs
+++ b/fidget/src/jit/aarch64/point.rs
@@ -277,6 +277,40 @@ impl Assembler for PointAssembler {
         )
     }

+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        // This uses SIMD instructions to avoid branching; it's not clearly
+        // faster, but it lets us share very similar code with the float /
+        // grad slice evaluators.
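+        //
+        // The branch-free recipe: build all-1s masks for (lhs < rhs),
+        // (lhs > rhs), and NaN inputs; AND each mask with its constant
+        // (-1.0, +1.0, NaN); then OR the results together.  At most one
+        // mask is set, and if none is (the equal case), the output is all
+        // zeros, i.e. 0.0.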
+        dynasm!(self.0.ops
+            // Build a mask of NAN positions in v6
+            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
+            ; fcmeq s7, S(reg(rhs_reg)), S(reg(rhs_reg))
+            ; and v6.b8, v6.b8, v7.b8
+            ; mvn v6.b8, v6.b8
+
+            // Build our two comparisons
+            ; fcmgt s4, S(reg(rhs_reg)), S(reg(lhs_reg))
+            ; fcmgt s5, S(reg(lhs_reg)), S(reg(rhs_reg))
+
+            // Apply -1 value (if relevant)
+            ; fmov s7, -1.0
+            ; and V(reg(out_reg)).B8, v4.B8, v7.B8
+
+            // Apply +1 value (if relevant)
+            ; fmov s7, 1.0
+            ; and v5.B8, v5.B8, v7.B8
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v5.B8
+
+            // Apply NAN value
+            ; mov w9, f32::NAN.to_bits().into()
+            ; fmov s7, w9
+            ; and v7.b8, v7.b8, v6.b8
+
+            // Apply NAN mask to NAN positions
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v7.b8
+        );
+    }
+
     /// Loads an immediate into register S4, using W9 as an intermediary
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs
index 389ea45a..2978628e 100644
--- a/fidget/src/jit/mod.rs
+++ b/fidget/src/jit/mod.rs
@@ -171,6 +171,9 @@ trait Assembler {
     /// Natural log
     fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);

+    /// Three-way comparison, producing -1, 0, +1, or `NaN`
+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
+
     /// Square
     ///
     /// This has a default implementation, but can be overloaded for efficiency;
@@ -734,6 +737,17 @@ fn build_asm_fn_with_storage(
                 let reg = asm.load_imm(imm);
                 asm.build_copy(out, reg);
             }
+            RegOp::CompareRegReg(out, lhs, rhs) => {
+                asm.build_compare(out, lhs, rhs);
+            }
+            RegOp::CompareRegImm(out, arg, imm) => {
+                let reg = asm.load_imm(imm);
+                asm.build_compare(out, arg, reg);
+            }
+            RegOp::CompareImmReg(out, arg, imm) => {
+                let reg = asm.load_imm(imm);
+                asm.build_compare(out, reg, arg);
+            }
         }
     }

diff --git a/fidget/src/jit/x86_64/float_slice.rs b/fidget/src/jit/x86_64/float_slice.rs
index 29578501..8f9aff1e 100644
--- a/fidget/src/jit/x86_64/float_slice.rs
+++ b/fidget/src/jit/x86_64/float_slice.rs
@@ -268,6 +268,38 @@ impl Assembler for FloatSliceAssembler {
             ; vorps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm1
         );
     }
+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            // Build a mask of NANs; conveniently, all 1s is a NAN
+            ; vcmpunordps ymm1, Ry(reg(lhs_reg)), Ry(reg(lhs_reg))
+            ; vcmpunordps ymm2, Ry(reg(rhs_reg)), Ry(reg(rhs_reg))
+            ; vorps ymm1, ymm2, ymm1
+
+            // Calculate the less-than mask in ymm2
+            ; vcmpltps ymm2, Ry(reg(lhs_reg)), Ry(reg(rhs_reg))
+
+            // Calculate the greater-than mask in ymm3
+            ; vcmpgtps ymm3, Ry(reg(lhs_reg)), Ry(reg(rhs_reg))
+
+            // Put [-1.0 x N] into the output register
+            ; mov eax, (-1f32).to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            ; vbroadcastss Ry(reg(out_reg)), Rx(reg(out_reg))
+
+            // Apply the less-than mask to the [-1.0 x N] reg
+            ; vandps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm2
+
+            // Build and apply [1.0 x N] & greater-than
+            ; mov eax, 1f32.to_bits() as i32
+            ; vmovd xmm2, eax
+            ; vbroadcastss ymm2, xmm2
+            ; vandps ymm2, ymm2, ymm3
+            ; vorps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm2
+
+            // Set the NAN bits
+            ; vorps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm1
+        );
+    }
     fn load_imm(&mut self, imm: f32) -> u8 {
         dynasm!(self.0.ops
             ; mov eax, imm.to_bits() as i32
diff --git a/fidget/src/jit/x86_64/grad_slice.rs b/fidget/src/jit/x86_64/grad_slice.rs
index 3bcacffd..43a4f997 100644
--- a/fidget/src/jit/x86_64/grad_slice.rs
+++ b/fidget/src/jit/x86_64/grad_slice.rs
@@ -387,6 +387,38 @@ impl Assembler for GradSliceAssembler {
         );
         self.0.ops.commit_local().unwrap();
     }
+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            ; vcomiss Rx(reg(lhs_reg)), Rx(reg(rhs_reg))
+            ; jp >N
+            ; ja >R
+            ; jb >L
+
+            // Fall-through for equal
+            ; mov eax, 0f32.to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            ; jmp >O
+
+            // Less than
+            ; L:
+            ; mov eax, (-1f32).to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            ; jmp >O
+
+            ; N:
+            // TODO: this can't be the best way to make a NAN
+            ; vaddss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(rhs_reg))
+            ; jmp >O
+
+            ; R:
+            ; mov eax, 1f32.to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            // fallthrough to out
+
+            ; O:
+        );
+        self.0.ops.commit_local().unwrap();
+    }
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
         dynasm!(self.0.ops
diff --git a/fidget/src/jit/x86_64/interval.rs b/fidget/src/jit/x86_64/interval.rs
index 2c04f896..7fe23990 100644
--- a/fidget/src/jit/x86_64/interval.rs
+++ b/fidget/src/jit/x86_64/interval.rs
@@ -498,6 +498,63 @@ impl Assembler for IntervalAssembler {
         );
         self.0.ops.commit_local().unwrap();
     }
+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        // TODO: Godbolt uses unpcklps?
+        dynasm!(self.0.ops
+            // if lhs.has_nan || rhs.has_nan
+            //      out = [NAN, NAN]
+            // elif lhs.upper < rhs.lower
+            //      out = [-1, -1]
+            // elif rhs.upper < lhs.lower
+            //      out = [1, 1]
+            // else
+            //      out = [-1, 1]
+
+            // TODO: use cmpltss to do both comparisons?
+            // xmm1 = lhs.upper
+            ; vpshufd xmm1, Rx(reg(lhs_reg)), 0b11111101u8 as i8
+            ; vcomiss xmm1, Rx(reg(rhs_reg)) // compare lhs.upper and rhs.lower
+            ; jp >N
+            ; jb >L
+
+            // xmm1 = rhs.upper
+            ; vpshufd xmm1, Rx(reg(rhs_reg)), 0b11111101u8 as i8
+            ; vcomiss xmm1, Rx(reg(lhs_reg))
+            ; jp >N
+            ; jb >R
+
+            // Fallthrough: ambiguous case, so load [-1, 1]
+            ; mov eax, (-1f32).to_bits() as i32
+            ; vpinsrd Rx(reg(out_reg)), Rx(reg(out_reg)), eax, 0
+            ; mov eax, 1f32.to_bits() as i32
+            ; vpinsrd Rx(reg(out_reg)), Rx(reg(out_reg)), eax, 1
+            ; jmp >E
+
+            ; N:
+            // Load NAN into out_reg
+            ; vpcmpeqw Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg))
+            ; vpslld Rx(reg(out_reg)), Rx(reg(out_reg)), 23
+            ; vpsrld Rx(reg(out_reg)), Rx(reg(out_reg)), 1
+            ; jmp >E
+
+            // lhs.upper < rhs.lower
+            ; L:
+            ; mov eax, (-1f32).to_bits() as i32
+            ; vmovd xmm1, eax
+            ; vbroadcastss Rx(reg(out_reg)), xmm1
+            ; jmp >E
+
+            // rhs.upper < lhs.lower
+            ; R:
+            ; mov eax, 1f32.to_bits() as i32
+            ; vmovd xmm1, eax
+            ; vbroadcastss Rx(reg(out_reg)), xmm1
+            // Fallthrough
+
+            ; E:
+        );
+        self.0.ops.commit_local().unwrap();
+    }
     fn load_imm(&mut self, imm: f32) -> u8 {
         let imm_u32 = imm.to_bits();
         dynasm!(self.0.ops
diff --git a/fidget/src/jit/x86_64/point.rs b/fidget/src/jit/x86_64/point.rs
index 97c2a545..bb7065eb 100644
--- a/fidget/src/jit/x86_64/point.rs
+++ b/fidget/src/jit/x86_64/point.rs
@@ -286,6 +286,36 @@ impl Assembler for PointAssembler {
         );
         self.0.ops.commit_local().unwrap()
     }
+    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
+        dynasm!(self.0.ops
+            ; vcomiss Rx(reg(lhs_reg)), Rx(reg(rhs_reg))
+            ; jp >N
+            ; ja >R
+            ; jb >L
+
+            // Fall-through for equal
+            ; mov eax, 0f32.to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            ; jmp >O
+
+            ; L:
+            ; mov eax, (-1f32).to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            ; jmp >O
+
+            ; N:
+            // TODO: this can't be the best way to make a NAN
+            ; vaddss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(rhs_reg))
+            ; jmp >O
+
+            ; R:
+            ; mov eax, 1f32.to_bits() as i32
+            ; vmovd Rx(reg(out_reg)), eax
+            // fallthrough to out
+
+            ; O:
+        );
+        self.0.ops.commit_local().unwrap()
+    }
     fn load_imm(&mut self, imm: f32) -> u8 {
        let imm_u32 = imm.to_bits();
        dynasm!(self.0.ops
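
For reference, the interval semantics that the VM and JIT interval evaluators
above implement can be written out in plain Rust. This is a sketch against a
minimal stand-in `Interval` type (the `lower`/`upper`/`has_nan` members here
are assumptions for illustration, not Fidget's actual API):

    #[derive(Copy, Clone, Debug, PartialEq)]
    struct Interval {
        lower: f32,
        upper: f32,
    }

    impl Interval {
        fn has_nan(&self) -> bool {
            self.lower.is_nan() || self.upper.is_nan()
        }
    }

    /// Interval `compare`: the output must contain `compare(x, y)` for every
    /// `x` in `lhs` and every `y` in `rhs`
    fn compare(lhs: Interval, rhs: Interval) -> Interval {
        if lhs.has_nan() || rhs.has_nan() {
            // Any NaN input may produce a NaN output
            Interval { lower: f32::NAN, upper: f32::NAN }
        } else if lhs.upper < rhs.lower {
            // Every value in lhs is below every value in rhs
            Interval { lower: -1.0, upper: -1.0 }
        } else if lhs.lower > rhs.upper {
            // Every value in lhs is above every value in rhs
            Interval { lower: 1.0, upper: 1.0 }
        } else {
            // The intervals overlap, so the result may be -1, 0, or +1
            Interval { lower: -1.0, upper: 1.0 }
        }
    }

As the `TODO` in the aarch64 interval assembler notes, the degenerate case
where both inputs are the same single point could return the tighter [0, 0];
returning the ambiguous [-1, 1] is still sound, just conservative.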