Skip to content

Commit

Permalink
More work on fuzzing
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Mar 16, 2024
1 parent 708cabe commit 1d941f4
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 60 deletions.
153 changes: 126 additions & 27 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for such evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{BulkEvaluator, EzShape, MathShape, Shape, ShapeVars, Vars},
Expand All @@ -18,7 +18,9 @@ macro_rules! float_slice_unary {

macro_rules! float_slice_binary {
(Context::$i:ident, $t:expr) => {
Self::test_binary(Context::$i, $t, stringify!($i));
Self::test_binary_reg_reg(Context::$i, $t, stringify!($i));
Self::test_binary_reg_imm(Context::$i, $t, stringify!($i));
Self::test_binary_imm_reg(Context::$i, $t, stringify!($i));
};
}

Expand Down Expand Up @@ -282,17 +284,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);

let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down Expand Up @@ -321,22 +313,12 @@ where
}
}

pub fn test_binary(
pub fn test_binary_reg_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);

let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down Expand Up @@ -383,6 +365,100 @@ where
}
}

fn test_binary_reg_imm(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for rhs in args.iter() {
let c = ctx.constant(*rhs as f64);
let node = f(&mut ctx, v, c).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*a, *rhs);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {a}, {rhs} (constant): \
{v} != {o} ({err})"
)
}
}
}
}
}

fn test_binary_imm_reg(
f: impl Fn(&mut Context, Node, Node) -> Result<Node, Error>,
g: impl Fn(f32, f32) -> f32,
name: &'static str,
) {
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
let inputs = [ctx.x(), ctx.y(), ctx.z()];

for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for lhs in args.iter() {
let c = ctx.constant(*lhs as f64);
let node = f(&mut ctx, c, v).unwrap();

let shape = S::new(&ctx, node).unwrap();
let mut eval = S::new_float_slice_eval();
let tape = shape.ez_float_slice_tape();

let out = match i {
0 => eval.eval(&tape, &args, &zero, &zero, &[]),
1 => eval.eval(&tape, &zero, &args, &zero, &[]),
2 => eval.eval(&tape, &zero, &zero, &args, &[]),
_ => unreachable!(),
}
.unwrap();

for (a, &o) in args.iter().zip(out.iter()) {
let v = g(*lhs, *a);
let err = (v - o).abs();
assert!(
(o == v)
|| err < 1e-6
|| (v.is_nan() && o.is_nan()),
"mismatch in '{name}' at {lhs} (constant), {a}: \
{v} != {o} ({err})"
)
}
}
}
}
}

pub fn test_f_unary_ops() {
float_slice_unary!(Context::neg, |v| -v);
float_slice_unary!(Context::recip, |v| 1.0 / v);
Expand All @@ -402,8 +478,31 @@ where
pub fn test_f_binary_ops() {
float_slice_binary!(Context::add, |a, b| a + b);
float_slice_binary!(Context::sub, |a, b| a - b);
float_slice_binary!(Context::mul, |a, b| a * b);
float_slice_binary!(Context::div, |a, b| a / b);

// Multiplication short-circuits to 0, which means that
// 0 (constant) * NaN = 0
Self::test_binary_reg_reg(Context::mul, |a, b| a * b, "mul");
Self::test_binary_reg_imm(
Context::mul,
|a, b| if b == 0.0 { b } else { a * b },
"mul",
);
Self::test_binary_imm_reg(
Context::mul,
|a, b| if a == 0.0 { a } else { a * b },
"mul",
);

// Multiplication short-circuits to 0, which means that
// 0 (constant) / NaN = 0
Self::test_binary_reg_reg(Context::div, |a, b| a / b, "div");
Self::test_binary_reg_imm(Context::div, |a, b| a / b, "div");
Self::test_binary_imm_reg(
Context::div,
|a, b| if a == 0.0 { a } else { a / b },
"div",
);

float_slice_binary!(Context::min, |a, b| if a.is_nan() || b.is_nan() {
f32::NAN
} else {
Expand Down
13 changes: 2 additions & 11 deletions fidget/src/core/eval/test/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -442,16 +442,7 @@ where
g: impl Fn(f64) -> f64,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();
let zero = vec![0.0; args.len()];

let mut ctx = Context::new();
Expand Down
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for interval evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{
Expand Down Expand Up @@ -668,15 +668,7 @@ where
g: impl Fn(f32) -> f32,
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut values =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
values.push(0.0);
values.push(1.0);
values.push(std::f32::consts::PI);
values.push(std::f32::consts::FRAC_PI_2);
values.push(std::f32::consts::FRAC_1_PI);
values.push(std::f32::consts::SQRT_2);
let values = test_args();

let mut args = vec![];
for &lower in &values {
Expand Down
13 changes: 13 additions & 0 deletions fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ pub(crate) fn build_stress_fn(n: usize) -> (Context, Node) {

(ctx, sum)
}

/// Pick a bunch of arguments, some of which are spicy
fn test_args() -> Vec<f32> {
let mut args = (-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
args
}
12 changes: 2 additions & 10 deletions fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! If the `eval-tests` feature is set, then this exposes a standard test suite
//! for point evaluators; otherwise, the module has no public exports.
use super::build_stress_fn;
use super::{build_stress_fn, test_args};
use crate::{
context::{Context, Node},
eval::{EzShape, MathShape, Shape, ShapeVars, TracingEvaluator, Vars},
Expand Down Expand Up @@ -397,15 +397,7 @@ where
name: &'static str,
) {
// Pick a bunch of arguments, some of which are spicy
let mut args =
(-32..32).map(|i| i as f32 / 32f32).collect::<Vec<f32>>();
args.push(0.0);
args.push(1.0);
args.push(std::f32::consts::PI);
args.push(std::f32::consts::FRAC_PI_2);
args.push(std::f32::consts::FRAC_1_PI);
args.push(std::f32::consts::SQRT_2);
args.push(f32::NAN);
let args = test_args();

let mut ctx = Context::new();
for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() {
Expand Down
12 changes: 10 additions & 2 deletions fidget/src/core/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,20 @@ impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
}
RegOp::MinRegImm(out, arg, imm) => {
for i in 0..size {
v[out][i] = v[arg][i].min(imm);
v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
f32::NAN
} else {
v[arg][i].min(imm)
};
}
}
RegOp::MaxRegImm(out, arg, imm) => {
for i in 0..size {
v[out][i] = v[arg][i].max(imm);
v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
f32::NAN
} else {
v[arg][i].max(imm)
};
}
}
RegOp::AddRegReg(out, lhs, rhs) => {
Expand Down

0 comments on commit 1d941f4

Please sign in to comment.