Skip to content

Commit

Permalink
More work on float fuzzing, fix an error
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Mar 16, 2024
1 parent 987fc63 commit 708cabe
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 6 deletions.
86 changes: 86 additions & 0 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ macro_rules! float_slice_unary {
};
}

macro_rules! float_slice_binary {
(Context::$i:ident, $t:expr) => {
Self::test_binary(Context::$i, $t, stringify!($i));
};
}

/// Helper struct to put constrains on our `Shape` object
pub struct TestFloatSlice<S>(std::marker::PhantomData<*const S>);

Expand Down Expand Up @@ -315,6 +321,68 @@ where
}
}

pub fn test_binary(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);

let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];
for rot in 0..args.len() {
let mut rgsa = args.clone();
rgsa.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for (j, &u) in inputs.iter().enumerate() {
let node = f(&mut ctx, v, u).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match (i, j) {
(0, 0) => eval.eval(&tape, &args, &zero, &zero, &[]),
(0, 1) => eval.eval(&tape, &args, &rgsa, &zero, &[]),
(0, 2) => eval.eval(&tape, &args, &zero, &rgsa, &[]),
(1, 0) => eval.eval(&tape, &rgsa, &args, &zero, &[]),
(1, 1) => eval.eval(&tape, &zero, &args, &zero, &[]),
(1, 2) => eval.eval(&tape, &zero, &args, &rgsa, &[]),
(2, 0) => eval.eval(&tape, &rgsa, &zero, &args, &[]),
(2, 1) => eval.eval(&tape, &zero, &rgsa, &args, &[]),
(2, 2) => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

let b = if i == j { &args } else { &rgsa };
for ((a, b), &o) in args.iter().zip(b).zip(out.iter()) {
let v = g(*a, *b);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a} {b}: \
{v} != {o} ({err})"
)
}
}
}
}
}

pub fn test_f_unary_ops() {
float_slice_unary!(Context::neg, |v| -v);
float_slice_unary!(Context::recip, |v| 1.0 / v);
Expand All @@ -330,6 +398,23 @@ where
float_slice_unary!(Context::square, |v| v * v);
float_slice_unary!(Context::sqrt, |v| v.sqrt());
}

pub fn test_f_binary_ops() {
float_slice_binary!(Context::add, |a, b| a + b);
float_slice_binary!(Context::sub, |a, b| a - b);
float_slice_binary!(Context::mul, |a, b| a * b);
float_slice_binary!(Context::div, |a, b| a / b);
float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.min(b)
});
float_slice_binary!(Context::max, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
a.max(b)
});
}
}

#[macro_export]
Expand All @@ -351,5 +436,6 @@ macro_rules! float_slice_tests {
$crate::float_slice_test!(test_f_sin, $t);
$crate::float_slice_test!(test_f_stress, $t);
$crate::float_slice_test!(test_f_unary_ops, $t);
$crate::float_slice_test!(test_f_binary_ops, $t);
};
}
40 changes: 34 additions & 6 deletions fidget/src/core/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,12 +798,22 @@ impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
}
RegOp::MinRegReg(out, lhs, rhs) => {
for i in 0..size {
v[out][i] = v[lhs][i].min(v[rhs][i]);
v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
{
f32::NAN
} else {
v[lhs][i].min(v[rhs][i])
};
}
}
RegOp::MaxRegReg(out, lhs, rhs) => {
for i in 0..size {
v[out][i] = v[lhs][i].max(v[rhs][i]);
v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
{
f32::NAN
} else {
v[lhs][i].max(v[rhs][i])
};
}
}
RegOp::CopyImm(out, imm) => {
Expand Down Expand Up @@ -981,13 +991,21 @@ impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
RegOp::MinRegImm(out, arg, imm) => {
let imm: Grad = imm.into();
for i in 0..size {
v[out][i] = v[arg][i].min(imm);
v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
f32::NAN.into()
} else {
v[arg][i].min(imm)
};
}
}
RegOp::MaxRegImm(out, arg, imm) => {
let imm: Grad = imm.into();
for i in 0..size {
v[out][i] = v[arg][i].max(imm);
v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
f32::NAN.into()
} else {
v[arg][i].max(imm)
};
}
}
RegOp::AddRegReg(out, lhs, rhs) => {
Expand All @@ -1012,12 +1030,22 @@ impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
}
RegOp::MinRegReg(out, lhs, rhs) => {
for i in 0..size {
v[out][i] = v[lhs][i].min(v[rhs][i]);
v[out][i] =
if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
f32::NAN.into()
} else {
v[lhs][i].min(v[rhs][i])
};
}
}
RegOp::MaxRegReg(out, lhs, rhs) => {
for i in 0..size {
v[out][i] = v[lhs][i].max(v[rhs][i]);
v[out][i] =
if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
f32::NAN.into()
} else {
v[lhs][i].max(v[rhs][i])
};
}
}
RegOp::CopyImm(out, imm) => {
Expand Down

0 comments on commit 708cabe

Please sign in to comment.