From 87e6568240a6e50b46a1185d0657b83299eb7a09 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 16 Mar 2024 16:14:58 -0400 Subject: [PATCH] Systematic testing of binary ops for float-slice evaluators (#29) --- fidget/src/core/eval/test/float_slice.rs | 209 +++++++++++++++++++++-- fidget/src/core/eval/test/grad_slice.rs | 13 +- fidget/src/core/eval/test/interval.rs | 12 +- fidget/src/core/eval/test/mod.rs | 13 ++ fidget/src/core/eval/test/point.rs | 12 +- fidget/src/core/vm/mod.rs | 52 +++++- fidget/src/jit/x86_64/float_slice.rs | 22 ++- 7 files changed, 280 insertions(+), 53 deletions(-) diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index 9bf09859..fd7c87b6 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -3,7 +3,7 @@ //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for such evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{BulkEvaluator, EzShape, MathShape, Shape, ShapeVars, Vars}, @@ -16,6 +16,14 @@ macro_rules! float_slice_unary { }; } +macro_rules! float_slice_binary { + (Context::$i:ident, $t:expr) => { + Self::test_binary_reg_reg(Context::$i, $t, stringify!($i)); + Self::test_binary_reg_imm(Context::$i, $t, stringify!($i)); + Self::test_binary_imm_reg(Context::$i, $t, stringify!($i)); + }; +} + /// Helper struct to put constrains on our `Shape` object pub struct TestFloatSlice(std::marker::PhantomData<*const S>); @@ -276,17 +284,7 @@ where g: impl Fn(f32) -> f32, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); - + let args = test_args(); let zero = vec![0.0; args.len()]; let mut ctx = Context::new(); @@ -315,6 +313,152 @@ where } } + pub fn test_binary_reg_reg( + f: impl Fn(&mut Context, Node, Node) -> Result, + g: impl Fn(f32, f32) -> f32, + name: &'static str, + ) { + let args = test_args(); + let zero = vec![0.0; args.len()]; + + let mut ctx = Context::new(); + let inputs = [ctx.x(), ctx.y(), ctx.z()]; + for rot in 0..args.len() { + let mut rgsa = args.clone(); + rgsa.rotate_left(rot); + for (i, &v) in inputs.iter().enumerate() { + for (j, &u) in inputs.iter().enumerate() { + let node = f(&mut ctx, v, u).unwrap(); + + let shape = S::new(&ctx, node).unwrap(); + let mut eval = S::new_float_slice_eval(); + let tape = shape.ez_float_slice_tape(); + + let out = match (i, j) { + (0, 0) => eval.eval(&tape, &args, &zero, &zero, &[]), + (0, 1) => eval.eval(&tape, &args, &rgsa, &zero, &[]), + (0, 2) => eval.eval(&tape, &args, &zero, &rgsa, &[]), + (1, 0) => eval.eval(&tape, &rgsa, &args, &zero, &[]), + (1, 1) => eval.eval(&tape, &zero, &args, &zero, &[]), + (1, 2) => eval.eval(&tape, &zero, &args, &rgsa, &[]), + (2, 0) => eval.eval(&tape, &rgsa, &zero, &args, &[]), + (2, 1) => eval.eval(&tape, &zero, &rgsa, &args, &[]), + (2, 2) => eval.eval(&tape, &zero, &zero, &args, &[]), + _ => unreachable!(), + } + .unwrap(); + + let b = if i == j { &args } else { &rgsa }; + for ((a, b), &o) in args.iter().zip(b).zip(out.iter()) { + let v = g(*a, *b); + let err = (v - o).abs(); + assert!( + (o == v) + || err < 1e-6 + || (v.is_nan() && o.is_nan()), + "mismatch in '{name}' at {a} {b}: \ + {v} != {o} ({err})" + ) + } + } + } + } + } + + fn test_binary_reg_imm( + f: impl Fn(&mut Context, Node, Node) -> Result, + g: impl Fn(f32, f32) -> f32, + name: &'static str, + ) { + let args = test_args(); + let zero = vec![0.0; args.len()]; + + let mut ctx = Context::new(); + let inputs = [ctx.x(), ctx.y(), ctx.z()]; + + for rot in 0..args.len() { + let mut args = args.clone(); + args.rotate_left(rot); + for (i, &v) in inputs.iter().enumerate() { + for rhs in args.iter() { + let c = ctx.constant(*rhs as f64); + let node = f(&mut ctx, v, c).unwrap(); + + let shape = S::new(&ctx, node).unwrap(); + let mut eval = S::new_float_slice_eval(); + let tape = shape.ez_float_slice_tape(); + + let out = match i { + 0 => eval.eval(&tape, &args, &zero, &zero, &[]), + 1 => eval.eval(&tape, &zero, &args, &zero, &[]), + 2 => eval.eval(&tape, &zero, &zero, &args, &[]), + _ => unreachable!(), + } + .unwrap(); + + for (a, &o) in args.iter().zip(out.iter()) { + let v = g(*a, *rhs); + let err = (v - o).abs(); + assert!( + (o == v) + || err < 1e-6 + || (v.is_nan() && o.is_nan()), + "mismatch in '{name}' at {a}, {rhs} (constant): \ + {v} != {o} ({err})" + ) + } + } + } + } + } + + fn test_binary_imm_reg( + f: impl Fn(&mut Context, Node, Node) -> Result, + g: impl Fn(f32, f32) -> f32, + name: &'static str, + ) { + let args = test_args(); + let zero = vec![0.0; args.len()]; + + let mut ctx = Context::new(); + let inputs = [ctx.x(), ctx.y(), ctx.z()]; + + for rot in 0..args.len() { + let mut args = args.clone(); + args.rotate_left(rot); + for (i, &v) in inputs.iter().enumerate() { + for lhs in args.iter() { + let c = ctx.constant(*lhs as f64); + let node = f(&mut ctx, c, v).unwrap(); + + let shape = S::new(&ctx, node).unwrap(); + let mut eval = S::new_float_slice_eval(); + let tape = shape.ez_float_slice_tape(); + + let out = match i { + 0 => eval.eval(&tape, &args, &zero, &zero, &[]), + 1 => eval.eval(&tape, &zero, &args, &zero, &[]), + 2 => eval.eval(&tape, &zero, &zero, &args, &[]), + _ => unreachable!(), + } + .unwrap(); + + for (a, &o) in args.iter().zip(out.iter()) { + let v = g(*lhs, *a); + let err = (v - o).abs(); + assert!( + (o == v) + || err < 1e-6 + || (v.is_nan() && o.is_nan()), + "mismatch in '{name}' at {lhs} (constant), {a}: \ + {v} != {o} ({err})" + ) + } + } + } + } + } + pub fn test_f_unary_ops() { float_slice_unary!(Context::neg, |v| -v); float_slice_unary!(Context::recip, |v| 1.0 / v); @@ -330,6 +474,46 @@ where float_slice_unary!(Context::square, |v| v * v); float_slice_unary!(Context::sqrt, |v| v.sqrt()); } + + pub fn test_f_binary_ops() { + float_slice_binary!(Context::add, |a, b| a + b); + float_slice_binary!(Context::sub, |a, b| a - b); + + // Multiplication short-circuits to 0, which means that + // 0 (constant) * NaN = 0 + Self::test_binary_reg_reg(Context::mul, |a, b| a * b, "mul"); + Self::test_binary_reg_imm( + Context::mul, + |a, b| if b == 0.0 { b } else { a * b }, + "mul", + ); + Self::test_binary_imm_reg( + Context::mul, + |a, b| if a == 0.0 { a } else { a * b }, + "mul", + ); + + // Multiplication short-circuits to 0, which means that + // 0 (constant) / NaN = 0 + Self::test_binary_reg_reg(Context::div, |a, b| a / b, "div"); + Self::test_binary_reg_imm(Context::div, |a, b| a / b, "div"); + Self::test_binary_imm_reg( + Context::div, + |a, b| if a == 0.0 { a } else { a / b }, + "div", + ); + + float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() { + f32::NAN + } else { + a.min(b) + }); + float_slice_binary!(Context::max, |a, b| if a.is_nan() || b.is_nan() { + f32::NAN + } else { + a.max(b) + }); + } } #[macro_export] @@ -351,5 +535,6 @@ macro_rules! float_slice_tests { $crate::float_slice_test!(test_f_sin, $t); $crate::float_slice_test!(test_f_stress, $t); $crate::float_slice_test!(test_f_unary_ops, $t); + $crate::float_slice_test!(test_f_binary_ops, $t); }; } diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index ad40a5f6..1479f36e 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -2,7 +2,7 @@ //! //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for interval evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{ @@ -442,16 +442,7 @@ where g: impl Fn(f64) -> f64, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); + let args = test_args(); let zero = vec![0.0; args.len()]; let mut ctx = Context::new(); diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 7a02422e..73037be9 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -3,7 +3,7 @@ //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for interval evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{ @@ -668,15 +668,7 @@ where g: impl Fn(f32) -> f32, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut values = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - values.push(0.0); - values.push(1.0); - values.push(std::f32::consts::PI); - values.push(std::f32::consts::FRAC_PI_2); - values.push(std::f32::consts::FRAC_1_PI); - values.push(std::f32::consts::SQRT_2); + let values = test_args(); let mut args = vec![]; for &lower in &values { diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs index b01363ff..3c69d0c9 100644 --- a/fidget/src/core/eval/test/mod.rs +++ b/fidget/src/core/eval/test/mod.rs @@ -33,3 +33,16 @@ pub(crate) fn build_stress_fn(n: usize) -> (Context, Node) { (ctx, sum) } + +/// Pick a bunch of arguments, some of which are spicy +fn test_args() -> Vec { + let mut args = (-32..32).map(|i| i as f32 / 32f32).collect::>(); + args.push(0.0); + args.push(1.0); + args.push(std::f32::consts::PI); + args.push(std::f32::consts::FRAC_PI_2); + args.push(std::f32::consts::FRAC_1_PI); + args.push(std::f32::consts::SQRT_2); + args.push(f32::NAN); + args +} diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index cc3af60a..a509f2f4 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -2,7 +2,7 @@ //! //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for point evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{EzShape, MathShape, Shape, ShapeVars, TracingEvaluator, Vars}, @@ -397,15 +397,7 @@ where name: &'static str, ) { // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); + let args = test_args(); let mut ctx = Context::new(); for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 810787cd..1456f6f2 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -768,12 +768,20 @@ impl BulkEvaluator for VmFloatSliceEval { } RegOp::MinRegImm(out, arg, imm) => { for i in 0..size { - v[out][i] = v[arg][i].min(imm); + v[out][i] = if v[arg][i].is_nan() || imm.is_nan() { + f32::NAN + } else { + v[arg][i].min(imm) + }; } } RegOp::MaxRegImm(out, arg, imm) => { for i in 0..size { - v[out][i] = v[arg][i].max(imm); + v[out][i] = if v[arg][i].is_nan() || imm.is_nan() { + f32::NAN + } else { + v[arg][i].max(imm) + }; } } RegOp::AddRegReg(out, lhs, rhs) => { @@ -798,12 +806,22 @@ impl BulkEvaluator for VmFloatSliceEval { } RegOp::MinRegReg(out, lhs, rhs) => { for i in 0..size { - v[out][i] = v[lhs][i].min(v[rhs][i]); + v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan() + { + f32::NAN + } else { + v[lhs][i].min(v[rhs][i]) + }; } } RegOp::MaxRegReg(out, lhs, rhs) => { for i in 0..size { - v[out][i] = v[lhs][i].max(v[rhs][i]); + v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan() + { + f32::NAN + } else { + v[lhs][i].max(v[rhs][i]) + }; } } RegOp::CopyImm(out, imm) => { @@ -981,13 +999,21 @@ impl BulkEvaluator for VmGradSliceEval { RegOp::MinRegImm(out, arg, imm) => { let imm: Grad = imm.into(); for i in 0..size { - v[out][i] = v[arg][i].min(imm); + v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() { + f32::NAN.into() + } else { + v[arg][i].min(imm) + }; } } RegOp::MaxRegImm(out, arg, imm) => { let imm: Grad = imm.into(); for i in 0..size { - v[out][i] = v[arg][i].max(imm); + v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() { + f32::NAN.into() + } else { + v[arg][i].max(imm) + }; } } RegOp::AddRegReg(out, lhs, rhs) => { @@ -1012,12 +1038,22 @@ impl BulkEvaluator for VmGradSliceEval { } RegOp::MinRegReg(out, lhs, rhs) => { for i in 0..size { - v[out][i] = v[lhs][i].min(v[rhs][i]); + v[out][i] = + if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() { + f32::NAN.into() + } else { + v[lhs][i].min(v[rhs][i]) + }; } } RegOp::MaxRegReg(out, lhs, rhs) => { for i in 0..size { - v[out][i] = v[lhs][i].max(v[rhs][i]); + v[out][i] = + if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() { + f32::NAN.into() + } else { + v[lhs][i].max(v[rhs][i]) + }; } } RegOp::CopyImm(out, imm) => { diff --git a/fidget/src/jit/x86_64/float_slice.rs b/fidget/src/jit/x86_64/float_slice.rs index 13d063ae..c2bb1c85 100644 --- a/fidget/src/jit/x86_64/float_slice.rs +++ b/fidget/src/jit/x86_64/float_slice.rs @@ -239,15 +239,33 @@ impl Assembler for FloatSliceAssembler { ); } fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { - // TODO: does this handle NaN correctly? dynasm!(self.0.ops + // Build a mask of NANs; conveniently, all 1s is a NAN + ; vcmpps ymm1, Ry(reg(lhs_reg)), Ry(reg(lhs_reg)), 3 + ; vcmpps ymm2, Ry(reg(rhs_reg)), Ry(reg(rhs_reg)), 3 + ; vorps ymm1, ymm2, ymm1 + + // Calculate the max, which ignores NANs ; vmaxps Ry(reg(out_reg)), Ry(reg(lhs_reg)), Ry(reg(rhs_reg)) + + // Set the NAN bits + ; vorps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm1 ); } fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { - // TODO: does this handle NaN correctly? dynasm!(self.0.ops + // Build a mask of NANs; conveniently, all 1s is a NAN + ; vcmpps ymm1, Ry(reg(lhs_reg)), Ry(reg(lhs_reg)), 3 + ; vcmpps ymm2, Ry(reg(rhs_reg)), Ry(reg(rhs_reg)), 3 + ; vorps ymm1, ymm2, ymm1 + + // Calculate the min, which ignores NANs ; vminps Ry(reg(out_reg)), Ry(reg(lhs_reg)), Ry(reg(rhs_reg)) + + // Set the NAN bits + // (note that we leave other bits unchanged, because it doesn't + // matter here!) + ; vorps Ry(reg(out_reg)), Ry(reg(out_reg)), ymm1 ); } fn load_imm(&mut self, imm: f32) -> u8 {