Skip to content

Commit

Permalink
Systematic testing of binary ops for float-slice evaluators (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter authored Mar 16, 2024
1 parent 987fc63 commit 87e6568
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 53 deletions.
209 changes: 197 additions & 12 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for such evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{BulkEvaluator, EzShape, MathShape, Shape, ShapeVars, Vars},
Expand All @@ -16,6 +16,14 @@ macro_rules! float_slice_unary {
};
}

macro_rules! float_slice_binary {
(Context::$i:ident, $t:expr) => {
Self::test_binary_reg_reg(Context::$i, $t, stringify!($i));
Self::test_binary_reg_imm(Context::$i, $t, stringify!($i));
Self::test_binary_imm_reg(Context::$i, $t, stringify!($i));
};
}

/// Helper struct to put constrains on our `Shape` object
pub struct TestFloatSlice<S>(std::marker::PhantomData<*const S>);

Expand Down Expand Up @@ -276,17 +284,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);

let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down Expand Up @@ -315,6 +313,152 @@ where
}
}

pub fn test_binary_reg_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];
for rot in 0..args.len() {
let mut rgsa = args.clone();
rgsa.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for (j, &u) in inputs.iter().enumerate() {
let node = f(&mut ctx, v, u).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match (i, j) {
(0, 0) => eval.eval(&tape, &args, &zero, &zero, &[]),
(0, 1) => eval.eval(&tape, &args, &rgsa, &zero, &[]),
(0, 2) => eval.eval(&tape, &args, &zero, &rgsa, &[]),
(1, 0) => eval.eval(&tape, &rgsa, &args, &zero, &[]),
(1, 1) => eval.eval(&tape, &zero, &args, &zero, &[]),
(1, 2) => eval.eval(&tape, &zero, &args, &rgsa, &[]),
(2, 0) => eval.eval(&tape, &rgsa, &zero, &args, &[]),
(2, 1) => eval.eval(&tape, &zero, &rgsa, &args, &[]),
(2, 2) => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

let b = if i == j { &args } else { &rgsa };
for ((a, b), &o) in args.iter().zip(b).zip(out.iter()) {
let v = g(*a, *b);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a} {b}: \
{v} != {o} ({err})"
)
}
}
}
}
}

fn test_binary_reg_imm(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for rhs in args.iter() {
let c = ctx.constant(*rhs as f64);
let node = f(&mut ctx, v, c).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*a, *rhs);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a}, {rhs} (constant): \
{v} != {o} ({err})"
)
}
}
}
}
}

fn test_binary_imm_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for lhs in args.iter() {
let c = ctx.constant(*lhs as f64);
let node = f(&mut ctx, c, v).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*lhs, *a);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {lhs} (constant), {a}: \
{v} != {o} ({err})"
)
}
}
}
}
}

pub fn test_f_unary_ops() {
float_slice_unary!(Context::neg, |v| -v);
float_slice_unary!(Context::recip, |v| 1.0 / v);
Expand All @@ -330,6 +474,46 @@ where
float_slice_unary!(Context::square, |v| v * v);
float_slice_unary!(Context::sqrt, |v| v.sqrt());
}

pub fn test_f_binary_ops() {
float_slice_binary!(Context::add, |a, b| a + b);
float_slice_binary!(Context::sub, |a, b| a - b);

// Multiplication short-circuits to 0, which means that
// 0 (constant) * NaN = 0
Self::test_binary_reg_reg(Context::mul, |a, b| a * b, "mul");
Self::test_binary_reg_imm(
Context::mul,
|a, b| if b == 0.0 { b } else { a * b },
"mul",
);
Self::test_binary_imm_reg(
Context::mul,
|a, b| if a == 0.0 { a } else { a * b },
"mul",
);

// Multiplication short-circuits to 0, which means that
// 0 (constant) / NaN = 0
Self::test_binary_reg_reg(Context::div, |a, b| a / b, "div");
Self::test_binary_reg_imm(Context::div, |a, b| a / b, "div");
Self::test_binary_imm_reg(
Context::div,
|a, b| if a == 0.0 { a } else { a / b },
"div",
);

float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.min(b)
});
float_slice_binary!(Context::max, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.max(b)
});
}
}

#[macro_export]
Expand All @@ -351,5 +535,6 @@ macro_rules! float_slice_tests {
$crate::float_slice_test!(test_f_sin, $t);
$crate::float_slice_test!(test_f_stress, $t);
$crate::float_slice_test!(test_f_unary_ops, $t);
$crate::float_slice_test!(test_f_binary_ops, $t);
};
}
13 changes: 2 additions & 11 deletions fidget/src/core/eval/test/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -442,16 +442,7 @@ where
g: impl Fn(f64) -> f64,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -668,15 +668,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut values =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
values.push(0.0);
values.push(1.0);
values.push(std::f32::consts::PI);
values.push(std::f32::consts::FRAC_PI_2);
values.push(std::f32::consts::FRAC_1_PI);
values.push(std::f32::consts::SQRT_2);
let values = test_args();

let mut args = vec![];
for &lower in &values {
Expand Down
13 changes: 13 additions & 0 deletions fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ pub(crate) fn build_stress_fn(n: usize) -> (Context, Node) {

(ctx, sum)
}

/// Pick a bunch of arguments, some of which are spicy
fn test_args() -> Vec<f32> {
let mut args = (-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
args
}
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for point evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{EzShape, MathShape, Shape, ShapeVars, TracingEvaluator, Vars},
Expand Down Expand Up @@ -397,15 +397,7 @@ where
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();

let mut ctx = Context::new();
for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() {
Expand Down
Loading

0 comments on commit 87e6568

Please sign in to comment.