diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index 125ef029..fd7c87b6 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -3,7 +3,7 @@ //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for such evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{BulkEvaluator, EzShape, MathShape, Shape, ShapeVars, Vars}, @@ -18,7 +18,9 @@ macro_rules! float_slice_unary { macro_rules! float_slice_binary { (Context::$i:ident, $t:expr) => { - Self::test_binary(Context::$i, $t, stringify!($i)); + Self::test_binary_reg_reg(Context::$i, $t, stringify!($i)); + Self::test_binary_reg_imm(Context::$i, $t, stringify!($i)); + Self::test_binary_imm_reg(Context::$i, $t, stringify!($i)); }; } @@ -282,17 +284,7 @@ where g: impl Fn(f32) -> f32, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); - + let args = test_args(); let zero = vec![0.0; args.len()]; let mut ctx = Context::new(); @@ -321,22 +313,12 @@ where } } - pub fn test_binary( + pub fn test_binary_reg_reg( f: impl Fn(&mut Context, Node, Node) -> Result, g: impl Fn(f32, f32) -> f32, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); - + let args = test_args(); let zero = vec![0.0; args.len()]; let mut ctx = Context::new(); @@ -383,6 +365,100 @@ where } } + fn test_binary_reg_imm( + f: impl Fn(&mut Context, Node, Node) -> Result, + g: impl Fn(f32, f32) -> f32, + name: &'static str, + ) { + let args = test_args(); + let zero = vec![0.0; args.len()]; + + let mut ctx = Context::new(); + let inputs = [ctx.x(), ctx.y(), ctx.z()]; + + for rot in 0..args.len() { + let mut args = args.clone(); + args.rotate_left(rot); + for (i, &v) in inputs.iter().enumerate() { + for rhs in args.iter() { + let c = ctx.constant(*rhs as f64); + let node = f(&mut ctx, v, c).unwrap(); + + let shape = S::new(&ctx, node).unwrap(); + let mut eval = S::new_float_slice_eval(); + let tape = shape.ez_float_slice_tape(); + + let out = match i { + 0 => eval.eval(&tape, &args, &zero, &zero, &[]), + 1 => eval.eval(&tape, &zero, &args, &zero, &[]), + 2 => eval.eval(&tape, &zero, &zero, &args, &[]), + _ => unreachable!(), + } + .unwrap(); + + for (a, &o) in args.iter().zip(out.iter()) { + let v = g(*a, *rhs); + let err = (v - o).abs(); + assert!( + (o == v) + || err < 1e-6 + || (v.is_nan() && o.is_nan()), + "mismatch in '{name}' at {a}, {rhs} (constant): \ + {v} != {o} ({err})" + ) + } + } + } + } + } + + fn test_binary_imm_reg( + f: impl Fn(&mut Context, Node, Node) -> Result, + g: impl Fn(f32, f32) -> f32, + name: &'static str, + ) { + let args = test_args(); + let zero = vec![0.0; args.len()]; + + let mut ctx = Context::new(); + let inputs = [ctx.x(), ctx.y(), ctx.z()]; + + for rot in 0..args.len() { + let mut args = args.clone(); + args.rotate_left(rot); + for (i, &v) in inputs.iter().enumerate() { + for lhs in args.iter() { + let c = ctx.constant(*lhs as f64); + let node = f(&mut ctx, c, v).unwrap(); + + let shape = S::new(&ctx, node).unwrap(); + let mut eval = S::new_float_slice_eval(); + let tape = shape.ez_float_slice_tape(); + + let out = match i { + 0 => eval.eval(&tape, &args, &zero, &zero, &[]), + 1 => eval.eval(&tape, &zero, &args, &zero, &[]), + 2 => eval.eval(&tape, &zero, &zero, &args, &[]), + _ => unreachable!(), + } + .unwrap(); + + for (a, &o) in args.iter().zip(out.iter()) { + let v = g(*lhs, *a); + let err = (v - o).abs(); + assert!( + (o == v) + || err < 1e-6 + || (v.is_nan() && o.is_nan()), + "mismatch in '{name}' at {lhs} (constant), {a}: \ + {v} != {o} ({err})" + ) + } + } + } + } + } + pub fn test_f_unary_ops() { float_slice_unary!(Context::neg, |v| -v); float_slice_unary!(Context::recip, |v| 1.0 / v); @@ -402,8 +478,31 @@ where pub fn test_f_binary_ops() { float_slice_binary!(Context::add, |a, b| a + b); float_slice_binary!(Context::sub, |a, b| a - b); - float_slice_binary!(Context::mul, |a, b| a * b); - float_slice_binary!(Context::div, |a, b| a / b); + + // Multiplication short-circuits to 0, which means that + // 0 (constant) * NaN = 0 + Self::test_binary_reg_reg(Context::mul, |a, b| a * b, "mul"); + Self::test_binary_reg_imm( + Context::mul, + |a, b| if b == 0.0 { b } else { a * b }, + "mul", + ); + Self::test_binary_imm_reg( + Context::mul, + |a, b| if a == 0.0 { a } else { a * b }, + "mul", + ); + + // Multiplication short-circuits to 0, which means that + // 0 (constant) / NaN = 0 + Self::test_binary_reg_reg(Context::div, |a, b| a / b, "div"); + Self::test_binary_reg_imm(Context::div, |a, b| a / b, "div"); + Self::test_binary_imm_reg( + Context::div, + |a, b| if a == 0.0 { a } else { a / b }, + "div", + ); + float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() { f32::NAN } else { diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index ad40a5f6..1479f36e 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -2,7 +2,7 @@ //! //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for interval evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{ @@ -442,16 +442,7 @@ where g: impl Fn(f64) -> f64, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); + let args = test_args(); let zero = vec![0.0; args.len()]; let mut ctx = Context::new(); diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 7a02422e..73037be9 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -3,7 +3,7 @@ //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for interval evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{ @@ -668,15 +668,7 @@ where g: impl Fn(f32) -> f32, name: &'static str, ) { - // Pick a bunch of arguments, some of which are spicy - let mut values = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - values.push(0.0); - values.push(1.0); - values.push(std::f32::consts::PI); - values.push(std::f32::consts::FRAC_PI_2); - values.push(std::f32::consts::FRAC_1_PI); - values.push(std::f32::consts::SQRT_2); + let values = test_args(); let mut args = vec![]; for &lower in &values { diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs index b01363ff..3c69d0c9 100644 --- a/fidget/src/core/eval/test/mod.rs +++ b/fidget/src/core/eval/test/mod.rs @@ -33,3 +33,16 @@ pub(crate) fn build_stress_fn(n: usize) -> (Context, Node) { (ctx, sum) } + +/// Pick a bunch of arguments, some of which are spicy +fn test_args() -> Vec { + let mut args = (-32..32).map(|i| i as f32 / 32f32).collect::>(); + args.push(0.0); + args.push(1.0); + args.push(std::f32::consts::PI); + args.push(std::f32::consts::FRAC_PI_2); + args.push(std::f32::consts::FRAC_1_PI); + args.push(std::f32::consts::SQRT_2); + args.push(f32::NAN); + args +} diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index cc3af60a..a509f2f4 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -2,7 +2,7 @@ //! //! If the `eval-tests` feature is set, then this exposes a standard test suite //! for point evaluators; otherwise, the module has no public exports. -use super::build_stress_fn; +use super::{build_stress_fn, test_args}; use crate::{ context::{Context, Node}, eval::{EzShape, MathShape, Shape, ShapeVars, TracingEvaluator, Vars}, @@ -397,15 +397,7 @@ where name: &'static str, ) { // Pick a bunch of arguments, some of which are spicy - let mut args = - (-32..32).map(|i| i as f32 / 32f32).collect::>(); - args.push(0.0); - args.push(1.0); - args.push(std::f32::consts::PI); - args.push(std::f32::consts::FRAC_PI_2); - args.push(std::f32::consts::FRAC_1_PI); - args.push(std::f32::consts::SQRT_2); - args.push(f32::NAN); + let args = test_args(); let mut ctx = Context::new(); for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 92f926c9..1456f6f2 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -768,12 +768,20 @@ impl BulkEvaluator for VmFloatSliceEval { } RegOp::MinRegImm(out, arg, imm) => { for i in 0..size { - v[out][i] = v[arg][i].min(imm); + v[out][i] = if v[arg][i].is_nan() || imm.is_nan() { + f32::NAN + } else { + v[arg][i].min(imm) + }; } } RegOp::MaxRegImm(out, arg, imm) => { for i in 0..size { - v[out][i] = v[arg][i].max(imm); + v[out][i] = if v[arg][i].is_nan() || imm.is_nan() { + f32::NAN + } else { + v[arg][i].max(imm) + }; } } RegOp::AddRegReg(out, lhs, rhs) => {