Skip to content

Commit

Permalink
Split 'eval with variables' into two flavors
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Nov 17, 2024
1 parent c0ef85d commit 4491184
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 28 deletions.
11 changes: 6 additions & 5 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,28 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
.is_err());
let mut h: ShapeVars<&[f32]> = ShapeVars::new();
assert!(eval
.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
.eval_vs(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
.is_err());
let index = v.index().unwrap();
h.insert(index, &[4.0, 5.0]);
assert_eq!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
eval.eval_vs(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
.unwrap(),
&[7.0, 10.0]
);

h.insert(index, &[4.0, 5.0, 6.0]);
assert!(matches!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
Err(Error::MismatchedSlices)
eval.eval_vs(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
Err(Error::MismatchedSlices),
));

// Get a new var index that isn't valid for this tape
let v2 = Var::new();
h.insert(index, &[4.0, 5.0]);
h.insert(v2.index().unwrap(), &[4.0, 5.0]);
assert!(matches!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
eval.eval_vs(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
Err(Error::BadVarSlice(..))
));
}
Expand Down
68 changes: 59 additions & 9 deletions fidget/src/core/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,20 +514,18 @@ where
z: &[E::Data],
) -> Result<&[E::Data], Error> {
let h: ShapeVars<&[E::Data]> = Default::default();
self.eval_v(tape, x, y, z, &h)
self.eval_vs(tape, x, y, z, &h)
}

/// Bulk evaluation of many samples, with variables
///
/// Before evaluation, the tape's transform matrix is applied (if present).
pub fn eval_v<V: std::ops::Deref<Target = [G]>, G: Into<E::Data> + Copy>(
/// Helper function to do common setup
fn setup<V>(
&mut self,
tape: &ShapeTape<E::Tape>,
x: &[E::Data],
y: &[E::Data],
z: &[E::Data],
vars: &ShapeVars<V>,
) -> Result<&[E::Data], Error> {
) -> Result<usize, Error> {
assert_eq!(
tape.tape.output_count(),
1,
Expand All @@ -539,9 +537,7 @@ where
return Err(Error::MismatchedSlices);
}
let n = x.len();
if vars.values().any(|vs| vs.len() != n) {
return Err(Error::MismatchedSlices);
}

let vs = tape.vars();
let expected_vars = vs.len()
- vs.get(&Var::X).is_some() as usize
Expand Down Expand Up @@ -586,6 +582,31 @@ where
// TODO fast path if there are no extra vars, reusing slices
};

Ok(n)
}
/// Bulk evaluation of many samples, with slices of variables
///
/// Each variable slice must be the same length as our x, y, z slices
///
/// Before evaluation, the tape's transform matrix is applied (if present).
pub fn eval_vs<
V: std::ops::Deref<Target = [G]>,
G: Into<E::Data> + Copy,
>(
&mut self,
tape: &ShapeTape<E::Tape>,
x: &[E::Data],
y: &[E::Data],
z: &[E::Data],
vars: &ShapeVars<V>,
) -> Result<&[E::Data], Error> {
let n = self.setup(tape, x, y, z, vars)?;

if vars.values().any(|vs| vs.len() != n) {
return Err(Error::MismatchedSlices);
}

let vs = tape.vars();
for (var, value) in vars {
if let Some(i) = vs.get(&Var::V(*var)) {
if i < self.scratch.len() {
Expand All @@ -606,6 +627,35 @@ where
let out = self.eval.eval(&tape.tape, &self.scratch)?;
Ok(out.borrow(0))
}

/// Bulk evaluation of many samples, with fixed variables
///
/// Before evaluation, the tape's transform matrix is applied (if present).
pub fn eval_v<G: Into<E::Data> + Copy>(
&mut self,
tape: &ShapeTape<E::Tape>,
x: &[E::Data],
y: &[E::Data],
z: &[E::Data],
vars: &ShapeVars<G>,
) -> Result<&[E::Data], Error> {
self.setup(tape, x, y, z, vars)?;
let vs = tape.vars();
for (var, value) in vars {
if let Some(i) = vs.get(&Var::V(*var)) {
if i < self.scratch.len() {
self.scratch[i].fill((*value).into());
} else {
return Err(Error::BadVarIndex(i, self.scratch.len()));
}
} else {
// Passing in Bonus Variables is allowed (for now)
}
}

let out = self.eval.eval(&tape.tape, &self.scratch)?;
Ok(out.borrow(0))
}
}

/// Trait for types that can be transformed by a 4x4 homogeneous transform matrix
Expand Down
13 changes: 3 additions & 10 deletions fidget/src/render/render2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ struct Scratch {
x: Vec<f32>,
y: Vec<f32>,
z: Vec<f32>,
vs: ShapeVars<Vec<f32>>,
}

impl Scratch {
Expand All @@ -196,7 +195,6 @@ impl Scratch {
x: vec![0.0; size],
y: vec![0.0; size],
z: vec![0.0; size],
vs: ShapeVars::new(),
}
}
}
Expand Down Expand Up @@ -320,13 +318,14 @@ impl<F: Function, M: RenderMode> Worker<'_, F, M> {
}
}
} else {
self.render_tile_pixels(sub_tape, tile_size, tile);
self.render_tile_pixels(sub_tape, vars, tile_size, tile);
}
}

fn render_tile_pixels(
&mut self,
shape: &mut RenderHandle<F>,
vars: &ShapeVars<f32>,
tile_size: usize,
tile: Tile<2>,
) {
Expand All @@ -346,7 +345,7 @@ impl<F: Function, M: RenderMode> Worker<'_, F, M> {
&self.scratch.x,
&self.scratch.y,
&self.scratch.z,
&self.scratch.vs,
vars,
)
.unwrap();

Expand Down Expand Up @@ -386,11 +385,6 @@ fn worker<F: Function, M: RenderMode>(
workspace: Default::default(),
};

// Copy vars into scratch, expanding from single values to arrays
let last_tile_size = config.tile_sizes.last();
for (k, v) in vars {
w.scratch.vs.insert(*k, vec![*v; last_tile_size.pow(2)]);
}
while let Some(tile) = queue.next() {
w.image = vec![M::Output::default(); config.tile_sizes[0].pow(2)];
w.render_tile_recurse(&mut shape, vars, 0, tile);
Expand Down Expand Up @@ -844,7 +838,6 @@ mod test {
.test(shape, EXPECTED_05);
}

#[macro_export]
macro_rules! render_tests {
($i:ident) => {
mod $i {
Expand Down
91 changes: 87 additions & 4 deletions fidget/src/render/render3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl Scratch {
fn new(tile_size: usize) -> Self {
let size2 = tile_size.pow(2);
let size3 = tile_size.pow(3);

Self {
x: vec![0.0; size3],
y: vec![0.0; size3],
Expand Down Expand Up @@ -138,14 +139,15 @@ impl<F: Function> Worker<'_, F> {
}
}
} else {
self.render_tile_pixels(sub_tape, tile_size, tile);
self.render_tile_pixels(sub_tape, vars, tile_size, tile);
};
// TODO recycle something here?
}

fn render_tile_pixels(
&mut self,
shape: &mut RenderHandle<F>,
vars: &ShapeVars<f32>,
tile_size: usize,
tile: Tile<3>,
) {
Expand Down Expand Up @@ -195,11 +197,12 @@ impl<F: Function> Worker<'_, F> {

let out = self
.eval_float_slice
.eval(
.eval_v(
shape.f_tape(&mut self.tape_storage),
&self.scratch.x[..index],
&self.scratch.y[..index],
&self.scratch.z[..index],
vars,
)
.unwrap();

Expand Down Expand Up @@ -256,11 +259,12 @@ impl<F: Function> Worker<'_, F> {
if grad > 0 {
let out = self
.eval_grad_slice
.eval(
.eval_v(
shape.g_tape(&mut self.tape_storage),
&self.scratch.xg[..grad],
&self.scratch.yg[..grad],
&self.scratch.zg[..grad],
vars,
)
.unwrap();

Expand Down Expand Up @@ -448,7 +452,9 @@ pub fn render<F: Function>(
#[cfg(test)]
mod test {
use super::*;
use crate::{render::VoxelSize, vm::VmShape, Context};
use crate::{
eval::MathFunction, render::VoxelSize, var::Var, vm::VmShape, Context,
};

/// Make sure we don't crash if there's only a single tile
#[test]
Expand All @@ -465,4 +471,81 @@ mod test {
assert_eq!(depth.len(), 128 * 128);
assert_eq!(rgb.len(), 128 * 128);
}

fn sphere_var<F: Function + MathFunction>() {
let mut ctx = Context::new();
let x = ctx.x();
let y = ctx.y();
let z = ctx.z();
let x2 = ctx.square(x).unwrap();
let y2 = ctx.square(y).unwrap();
let z2 = ctx.square(z).unwrap();
let x2y2 = ctx.add(x2, y2).unwrap();
let r2 = ctx.add(x2y2, z2).unwrap();
let r = ctx.sqrt(r2).unwrap();
let v = Var::new();
let c = ctx.var(v);
let root = ctx.sub(r, c).unwrap();
let shape = Shape::<F>::new(&ctx, root).unwrap();

let size = 32;
let cfg = VoxelRenderConfig {
image_size: VoxelSize::from(size),
..Default::default()
};

for r in [0.5, 0.75] {
let mut vars = ShapeVars::new();
vars.insert(v.index().unwrap(), r);
let (depth, _normal) = cfg.run_with_vars::<_>(shape.clone(), &vars);

let epsilon = 0.08;
for (i, p) in depth.iter().enumerate() {
let size = size as i32;
let i = i as i32;
let x = (((i % size) - size / 2) as f32 / size as f32) * 2.0;
let y = (((i / size) - size / 2) as f32 / size as f32) * 2.0;
let z = (*p as i32 - size / 2) as f32 / size as f32 * 2.0;
if *p == 0 {
let v = (x.powi(2) + y.powi(2)).sqrt();
assert!(
v + epsilon > r,
"got z = 0 inside the sphere ({x}, {y}, {z}); \
radius is {v}"
);
} else {
let v = (x.powi(2) + y.powi(2) + z.powi(2)).sqrt();
let err = (r - v).abs();
assert!(
err < epsilon,
"too much error {err} at ({x}, {y}, {z}); \
radius is {v}, expected {r}"
);
}
}
}
}

macro_rules! render_tests {
($i:ident) => {
mod $i {
use super::*;
#[test]
fn vm() {
$i::<$crate::vm::VmFunction>();
}
#[test]
fn vm3() {
$i::<$crate::vm::GenericVmFunction<3>>();
}
#[cfg(feature = "jit")]
#[test]
fn jit() {
$i::<$crate::jit::JitFunction>();
}
}
};
}

render_tests!(sphere_var);
}

0 comments on commit 4491184

Please sign in to comment.