Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add systematic testing of binary ops for float-slice evaluators #29

Merged
merged 3 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 197 additions & 12 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for such evaluators; otherwise, the module has no public exports.

use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{BulkEvaluator, EzShape, MathShape, Shape, ShapeVars, Vars},
Expand All @@ -16,6 +16,14 @@ macro_rules! float_slice_unary {
};
}

macro_rules! float_slice_binary {
(Context::$i:ident, $t:expr) => {
Self::test_binary_reg_reg(Context::$i, $t, stringify!($i));
Self::test_binary_reg_imm(Context::$i, $t, stringify!($i));
Self::test_binary_imm_reg(Context::$i, $t, stringify!($i));
};
}

/// Helper struct to put constrains on our `Shape` object
pub struct TestFloatSlice<S>(std::marker::PhantomData<*const S>);

Expand Down Expand Up @@ -276,17 +284,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);

let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down Expand Up @@ -315,6 +313,152 @@ where
}
}

pub fn test_binary_reg_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];
for rot in 0..args.len() {
let mut rgsa = args.clone();
rgsa.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for (j, &u) in inputs.iter().enumerate() {
let node = f(&mut ctx, v, u).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match (i, j) {
(0, 0) => eval.eval(&tape, &args, &zero, &zero, &[]),
(0, 1) => eval.eval(&tape, &args, &rgsa, &zero, &[]),
(0, 2) => eval.eval(&tape, &args, &zero, &rgsa, &[]),
(1, 0) => eval.eval(&tape, &rgsa, &args, &zero, &[]),
(1, 1) => eval.eval(&tape, &zero, &args, &zero, &[]),
(1, 2) => eval.eval(&tape, &zero, &args, &rgsa, &[]),
(2, 0) => eval.eval(&tape, &rgsa, &zero, &args, &[]),
(2, 1) => eval.eval(&tape, &zero, &rgsa, &args, &[]),
(2, 2) => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

let b = if i == j { &args } else { &rgsa };
for ((a, b), &o) in args.iter().zip(b).zip(out.iter()) {
let v = g(*a, *b);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a} {b}: \
{v} != {o} ({err})"
)
}
}
}
}
}

fn test_binary_reg_imm(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for rhs in args.iter() {
let c = ctx.constant(*rhs as f64);
let node = f(&mut ctx, v, c).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*a, *rhs);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a}, {rhs} (constant): \
{v} != {o} ({err})"
)
}
}
}
}
}

fn test_binary_imm_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for lhs in args.iter() {
let c = ctx.constant(*lhs as f64);
let node = f(&mut ctx, c, v).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*lhs, *a);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {lhs} (constant), {a}: \
{v} != {o} ({err})"
)
}
}
}
}
}

pub fn test_f_unary_ops() {
float_slice_unary!(Context::neg, |v| -v);
float_slice_unary!(Context::recip, |v| 1.0 / v);
Expand All @@ -330,6 +474,46 @@ where
float_slice_unary!(Context::square, |v| v * v);
float_slice_unary!(Context::sqrt, |v| v.sqrt());
}

pub fn test_f_binary_ops() {
float_slice_binary!(Context::add, |a, b| a + b);
float_slice_binary!(Context::sub, |a, b| a - b);

// Multiplication short-circuits to 0, which means that
// 0 (constant) * NaN = 0
Self::test_binary_reg_reg(Context::mul, |a, b| a * b, "mul");
Self::test_binary_reg_imm(
Context::mul,
|a, b| if b == 0.0 { b } else { a * b },
"mul",
);
Self::test_binary_imm_reg(
Context::mul,
|a, b| if a == 0.0 { a } else { a * b },
"mul",
);

// Multiplication short-circuits to 0, which means that
// 0 (constant) / NaN = 0
Self::test_binary_reg_reg(Context::div, |a, b| a / b, "div");
Self::test_binary_reg_imm(Context::div, |a, b| a / b, "div");
Self::test_binary_imm_reg(
Context::div,
|a, b| if a == 0.0 { a } else { a / b },
"div",
);

float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.min(b)
});
float_slice_binary!(Context::max, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.max(b)
});
}
}

#[macro_export]
Expand All @@ -351,5 +535,6 @@ macro_rules! float_slice_tests {
$crate::float_slice_test!(test_f_sin, $t);
$crate::float_slice_test!(test_f_stress, $t);
$crate::float_slice_test!(test_f_unary_ops, $t);
$crate::float_slice_test!(test_f_binary_ops, $t);
};
}
13 changes: 2 additions & 11 deletions fidget/src/core/eval/test/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -442,16 +442,7 @@ where
g: impl Fn(f64) -> f64,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.

use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -668,15 +668,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut values =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
values.push(0.0);
values.push(1.0);
values.push(std::f32::consts::PI);
values.push(std::f32::consts::FRAC_PI_2);
values.push(std::f32::consts::FRAC_1_PI);
values.push(std::f32::consts::SQRT_2);
let values = test_args();

let mut args = vec![];
for &lower in &values {
Expand Down
13 changes: 13 additions & 0 deletions fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ pub(crate) fn build_stress_fn(n: usize) -> (Context, Node) {

(ctx, sum)
}

/// Pick a bunch of arguments, some of which are spicy
fn test_args() -> Vec<f32> {
let mut args = (-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
args
}
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for point evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{EzShape, MathShape, Shape, ShapeVars, TracingEvaluator, Vars},
Expand Down Expand Up @@ -397,15 +397,7 @@ where
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();

let mut ctx = Context::new();
for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() {
Expand Down
Loading