diff --git a/CHANGELOG.md b/CHANGELOG.md index 52e87772..46df27ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ - Change signature of `fidget::render::render2d` to pass the mode only as a generic parameter, instead of an argument - Add new operations: `floor`, `ceil`, `round`, `atan2` +- Change `BulkEvaluator::eval` signature to take x, y, z arguments as `&[T]` + instead of `&[f32]`. This is more flexible for gradient evaluation, because + it allows the caller to specify up to three gradients, without pinning them to + specific arguments. # 0.2.6 This is a relatively small release; there are a few features to improve the diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/eval/bulk.rs index e0e519bd..8fd9b3b1 100644 --- a/fidget/src/core/eval/bulk.rs +++ b/fidget/src/core/eval/bulk.rs @@ -40,9 +40,9 @@ pub trait BulkEvaluator: Default { fn eval( &mut self, tape: &Self::Tape, - x: &[f32], - y: &[f32], - z: &[f32], + x: &[Self::Data], + y: &[Self::Data], + z: &[Self::Data], ) -> Result<&[Self::Data], Error>; /// Build a new empty evaluator diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index 36f78e4e..42ef1450 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -19,15 +19,33 @@ impl TestGradSlice where S: Shape + MathShape, { + fn eval_xyz( + tape: &<::GradSliceEval as BulkEvaluator>::Tape, + xs: &[f32], + ys: &[f32], + zs: &[f32], + ) -> Vec { + assert_eq!(xs.len(), ys.len()); + assert_eq!(ys.len(), zs.len()); + let xs: Vec<_> = + xs.iter().map(|x| Grad::new(*x, 1.0, 0.0, 0.0)).collect(); + let ys: Vec<_> = + ys.iter().map(|y| Grad::new(*y, 0.0, 1.0, 0.0)).collect(); + let zs: Vec<_> = + zs.iter().map(|z| Grad::new(*z, 0.0, 0.0, 1.0)).collect(); + + let mut eval = S::new_grad_slice_eval(); + eval.eval(tape, &xs, &ys, &zs).unwrap().to_owned() + } + pub fn test_g_x() { let mut ctx = Context::new(); let x = ctx.x(); let shape = S::new(&ctx, 
x).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[4.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[3.0], &[4.0])[0], Grad::new(2.0, 1.0, 0.0, 0.0) ); } @@ -37,10 +55,9 @@ where let y = ctx.y(); let shape = S::new(&ctx, y).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[4.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[3.0], &[4.0])[0], Grad::new(3.0, 0.0, 1.0, 0.0) ); } @@ -50,10 +67,9 @@ where let z = ctx.z(); let shape = S::new(&ctx, z).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[4.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[3.0], &[4.0])[0], Grad::new(4.0, 0.0, 0.0, 1.0) ); } @@ -64,22 +80,21 @@ where let s = ctx.square(x).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[0.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[0.0], &[0.0], &[0.0])[0], Grad::new(0.0, 0.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(1.0, 2.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[2.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[0.0], &[0.0])[0], Grad::new(4.0, 4.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[3.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[3.0], &[0.0], &[0.0])[0], Grad::new(9.0, 6.0, 0.0, 0.0) ); } @@ -90,14 +105,13 @@ where let s = ctx.abs(x).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[0.0], &[0.0])[0], Grad::new(2.0, 1.0, 0.0, 
0.0) ); assert_eq!( - eval.eval(&tape, &[-2.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[-2.0], &[0.0], &[0.0])[0], Grad::new(2.0, -1.0, 0.0, 0.0) ); } @@ -108,14 +122,13 @@ where let s = ctx.sqrt(x).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(1.0, 0.5, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[0.0], &[0.0])[0], Grad::new(2.0, 0.25, 0.0, 0.0) ); } @@ -127,10 +140,7 @@ where let shape = S::new(&ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); - let mut eval = S::new_grad_slice_eval(); - let v = eval - .eval(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]) - .unwrap(); + let v = Self::eval_xyz(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]); v[0].compare_eq(Grad::new(1f32.sin(), 1f32.cos(), 0.0, 0.0)); v[1].compare_eq(Grad::new(2f32.sin(), 2f32.cos(), 0.0, 0.0)); v[2].compare_eq(Grad::new(3f32.sin(), 3f32.cos(), 0.0, 0.0)); @@ -140,9 +150,7 @@ where let s = ctx.sin(y).unwrap(); let shape = S::new(&ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); - let v = eval - .eval(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]) - .unwrap(); + let v = Self::eval_xyz(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]); v[0].compare_eq(Grad::new(2f32.sin(), 0.0, 2.0 * 2f32.cos(), 0.0)); v[1].compare_eq(Grad::new(4f32.sin(), 0.0, 2.0 * 4f32.cos(), 0.0)); v[2].compare_eq(Grad::new(6f32.sin(), 0.0, 2.0 * 6f32.cos(), 0.0)); @@ -155,22 +163,21 @@ where let s = ctx.mul(x, y).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(0.0, 0.0, 1.0, 0.0) ); assert_eq!( - eval.eval(&tape, 
&[0.0], &[1.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[0.0], &[1.0], &[0.0])[0], Grad::new(0.0, 1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[1.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[1.0], &[0.0])[0], Grad::new(4.0, 1.0, 4.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[2.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[2.0], &[0.0])[0], Grad::new(8.0, 2.0, 4.0, 0.0) ); } @@ -181,10 +188,9 @@ where let s = ctx.div(x, 2.0).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(0.5, 0.5, 0.0, 0.0) ); } @@ -195,14 +201,13 @@ where let s = ctx.recip(x).unwrap(); let shape = S::new(&ctx, s).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(1.0, -1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[2.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[0.0], &[0.0])[0], Grad::new(0.5, -0.25, 0.0, 0.0) ); } @@ -214,14 +219,13 @@ where let m = ctx.min(x, y).unwrap(); let shape = S::new(&ctx, m).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[3.0], &[0.0])[0], Grad::new(2.0, 1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[3.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[3.0], &[0.0])[0], Grad::new(3.0, 0.0, 1.0, 0.0) ); } @@ -235,18 +239,17 @@ where let max = ctx.max(min, z).unwrap(); let shape = S::new(&ctx, max).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[0.0]).unwrap()[0], + 
Self::eval_xyz(&tape, &[2.0], &[3.0], &[0.0])[0], Grad::new(2.0, 1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[3.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[3.0], &[0.0])[0], Grad::new(3.0, 0.0, 1.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[3.0], &[5.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[3.0], &[5.0])[0], Grad::new(5.0, 0.0, 0.0, 1.0) ); } @@ -258,14 +261,13 @@ where let m = ctx.max(x, y).unwrap(); let shape = S::new(&ctx, m).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[2.0], &[3.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[3.0], &[0.0])[0], Grad::new(3.0, 0.0, 1.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[4.0], &[3.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[4.0], &[3.0], &[0.0])[0], Grad::new(4.0, 1.0, 0.0, 0.0) ); } @@ -276,10 +278,9 @@ where let m = ctx.not(x).unwrap(); let shape = S::new(&ctx, m).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[0.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[0.0], &[0.0], &[0.0])[0], Grad::new(1.0, 0.0, 0.0, 0.0) ); } @@ -296,22 +297,21 @@ where let sub = ctx.sub(sqrt, 0.5).unwrap(); let shape = S::new(&ctx, sub).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); assert_eq!( - eval.eval(&tape, &[1.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[1.0], &[0.0], &[0.0])[0], Grad::new(0.5, 1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[0.0], &[1.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[0.0], &[1.0], &[0.0])[0], Grad::new(0.5, 0.0, 1.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[2.0], &[0.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[2.0], &[0.0], &[0.0])[0], Grad::new(1.5, 1.0, 0.0, 0.0) ); assert_eq!( - eval.eval(&tape, &[0.0], &[2.0], &[0.0]).unwrap()[0], + Self::eval_xyz(&tape, &[0.0], &[2.0], &[0.0])[0], Grad::new(1.5, 0.0, 
1.0, 0.0) ); } @@ -328,10 +328,9 @@ where args[2..].iter().chain(&args[0..2]).cloned().collect(); let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); - let out = eval.eval(&tape, &x, &y, &z).unwrap(); + let out = Self::eval_xyz(&tape, &x, &y, &z); // Compare values (the `.v` term) with the context's evaluator for (i, v) in out.iter().cloned().enumerate() { @@ -354,10 +353,9 @@ where // expensive, so we'll do it regardless. use crate::vm::VmShape; let shape = VmShape::new(&ctx, node).unwrap(); - let mut eval = VmShape::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); - let cmp = eval.eval(&tape, &x, &y, &z).unwrap(); + let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); for (a, b) in out.iter().zip(cmp.iter()) { a.compare_eq(*b) } @@ -378,16 +376,14 @@ where let node = C::build(&mut ctx, v); let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); let out = match i { - 0 => eval.eval(&tape, &args, &zero, &zero), - 1 => eval.eval(&tape, &zero, &args, &zero), - 2 => eval.eval(&tape, &zero, &zero, &args), + 0 => Self::eval_xyz(&tape, &args, &zero, &zero), + 1 => Self::eval_xyz(&tape, &zero, &args, &zero), + 2 => Self::eval_xyz(&tape, &zero, &zero, &args), _ => unreachable!(), - } - .unwrap(); + }; for (a, &o) in args.iter().zip(out.iter()) { let v = C::eval_f64(*a as f64); let err = (v as f32 - o.v).abs(); @@ -533,22 +529,20 @@ where let node = C::build(&mut ctx, v, u); let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); let out = match (i, j) { - (0, 0) => eval.eval(&tape, &args, &zero, &zero), - (0, 1) => eval.eval(&tape, &args, &rgsa, &zero), - (0, 2) => eval.eval(&tape, &args, &zero, &rgsa), - (1, 0) => eval.eval(&tape, &rgsa, &args, &zero), - (1, 1) => eval.eval(&tape, &zero, &args, &zero), - (1, 2) => eval.eval(&tape, &zero, &args, &rgsa), - 
(2, 0) => eval.eval(&tape, &rgsa, &zero, &args), - (2, 1) => eval.eval(&tape, &zero, &rgsa, &args), - (2, 2) => eval.eval(&tape, &zero, &zero, &args), + (0, 0) => Self::eval_xyz(&tape, &args, &zero, &zero), + (0, 1) => Self::eval_xyz(&tape, &args, &rgsa, &zero), + (0, 2) => Self::eval_xyz(&tape, &args, &zero, &rgsa), + (1, 0) => Self::eval_xyz(&tape, &rgsa, &args, &zero), + (1, 1) => Self::eval_xyz(&tape, &zero, &args, &zero), + (1, 2) => Self::eval_xyz(&tape, &zero, &args, &rgsa), + (2, 0) => Self::eval_xyz(&tape, &rgsa, &zero, &args), + (2, 1) => Self::eval_xyz(&tape, &zero, &rgsa, &args), + (2, 2) => Self::eval_xyz(&tape, &zero, &zero, &args), _ => unreachable!(), - } - .unwrap(); + }; let rhs = if i == j { &args } else { &rgsa }; Self::compare_grad_results::( @@ -556,7 +550,7 @@ where j, &args, rhs, - out, + &out, C::eval_reg_reg_f64, &name, ); @@ -582,16 +576,14 @@ where let node = C::build(&mut ctx, v, c); let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); let out = match i { - 0 => eval.eval(&tape, &args, &zero, &zero), - 1 => eval.eval(&tape, &zero, &args, &zero), - 2 => eval.eval(&tape, &zero, &zero, &args), + 0 => Self::eval_xyz(&tape, &args, &zero, &zero), + 1 => Self::eval_xyz(&tape, &zero, &args, &zero), + 2 => Self::eval_xyz(&tape, &zero, &zero, &args), _ => unreachable!(), - } - .unwrap(); + }; let rhs = vec![*rhs; out.len()]; Self::compare_grad_results::( @@ -599,7 +591,7 @@ where 3, &args, &rhs, - out, + &out, C::eval_reg_imm_f64, &name, ); @@ -625,16 +617,14 @@ where let node = C::build(&mut ctx, c, v); let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_grad_slice_eval(); let tape = shape.ez_grad_slice_tape(); let out = match i { - 0 => eval.eval(&tape, &args, &zero, &zero), - 1 => eval.eval(&tape, &zero, &args, &zero), - 2 => eval.eval(&tape, &zero, &zero, &args), + 0 => Self::eval_xyz(&tape, &args, &zero, &zero), + 1 => Self::eval_xyz(&tape, &zero, &args, 
&zero), + 2 => Self::eval_xyz(&tape, &zero, &zero, &args), _ => unreachable!(), - } - .unwrap(); + }; let lhs = vec![*lhs; out.len()]; Self::compare_grad_results::( @@ -642,7 +632,7 @@ where i, &lhs, &args, - out, + &out, C::eval_imm_reg_f64, &name, ); diff --git a/fidget/src/core/eval/transform.rs b/fidget/src/core/eval/transform.rs index 5103ec75..882cf71d 100644 --- a/fidget/src/core/eval/transform.rs +++ b/fidget/src/core/eval/transform.rs @@ -1,5 +1,5 @@ use crate::{ - eval::{BulkEvaluator, Interval, Shape, Tape, TracingEvaluator}, + eval::{BulkEvaluator, Grad, Interval, Shape, Tape, TracingEvaluator}, Error, }; use nalgebra::{Matrix4, Point3, Vector3}; @@ -91,6 +91,22 @@ impl Transformable for Interval { } } +impl Transformable for Grad { + fn transform( + x: Grad, + y: Grad, + z: Grad, + mat: Matrix4, + ) -> (Grad, Grad, Grad) { + let out = [0, 1, 2, 3].map(|i| { + let row = mat.row(i); + x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3]) + }); + + (out[0] / out[3], out[1] / out[3], out[2] / out[3]) + } +} + impl TracingEvaluator for TransformedTracingEval where ::Data: Transformable, @@ -115,37 +131,51 @@ where } /// A generic [`BulkEvaluator`] which applies a transform matrix -#[derive(Default)] -pub struct TransformedBulkEval { +pub struct TransformedBulkEval { eval: E, - xs: Vec, - ys: Vec, - zs: Vec, + xs: Vec, + ys: Vec, + zs: Vec, +} + +impl Default for TransformedBulkEval { + fn default() -> Self { + Self { + eval: E::default(), + xs: vec![], + ys: vec![], + zs: vec![], + } + } } -impl BulkEvaluator for TransformedBulkEval { - type Data = ::Data; - type Tape = TransformedTape<::Tape>; - type TapeStorage = ::TapeStorage; +impl BulkEvaluator for TransformedBulkEval +where + ::Data: Transformable, +{ + type Data = ::Data; + type Tape = TransformedTape<::Tape>; + type TapeStorage = ::TapeStorage; fn eval( &mut self, tape: &Self::Tape, - x: &[f32], - y: &[f32], - z: &[f32], + x: &[E::Data], + y: &[E::Data], + z: &[E::Data], ) -> 
Result<&[Self::Data], Error> { if x.len() != y.len() || x.len() != z.len() { return Err(Error::MismatchedSlices); } let n = x.len(); - self.xs.resize(n, 0.0); - self.ys.resize(n, 0.0); - self.zs.resize(n, 0.0); + self.xs.resize(n, E::Data::from(0.0)); + self.ys.resize(n, E::Data::from(0.0)); + self.zs.resize(n, E::Data::from(0.0)); for i in 0..x.len() { - let p = tape.mat.transform_point(&Point3::new(x[i], y[i], z[i])); - self.xs[i] = p.x; - self.ys[i] = p.y; - self.zs[i] = p.z; + let (x, y, z) = + Transformable::transform(x[i], y[i], z[i], tape.mat); + self.xs[i] = x; + self.ys[i] = y; + self.zs[i] = z; } self.eval.eval(&tape.tape, &self.xs, &self.ys, &self.zs) } diff --git a/fidget/src/core/types/grad.rs b/fidget/src/core/types/grad.rs index 3858a981..d7ca5f6e 100644 --- a/fidget/src/core/types/grad.rs +++ b/fidget/src/core/types/grad.rs @@ -295,6 +295,18 @@ impl std::ops::Mul for Grad { } } +impl std::ops::Mul for Grad { + type Output = Self; + fn mul(self, rhs: f32) -> Self { + Self { + v: self.v * rhs, + dx: self.dx * rhs, + dy: self.dy * rhs, + dz: self.dz * rhs, + } + } +} + impl std::ops::Div for Grad { type Output = Self; fn div(self, rhs: Self) -> Self { diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 8e060e5d..c37a7a83 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -1142,9 +1142,9 @@ impl BulkEvaluator for VmGradSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + xs: &[Grad], + ys: &[Grad], + zs: &[Grad], ) -> Result<&[Grad], Error> { let tape = tape.0.as_ref(); self.check_arguments(xs, ys, zs, tape.var_count())?; @@ -1159,9 +1159,9 @@ impl BulkEvaluator for VmGradSliceEval { RegOp::Input(out, j) => { for i in 0..size { v[out][i] = match j { - 0 => Grad::new(xs[i], 1.0, 0.0, 0.0), - 1 => Grad::new(ys[i], 0.0, 1.0, 0.0), - 2 => Grad::new(zs[i], 0.0, 0.0, 1.0), + 0 => xs[i], + 1 => ys[i], + 2 => zs[i], _ => panic!("Invalid input: {}", i), } } @@ -1265,7 +1265,7 
@@ impl BulkEvaluator for VmGradSliceEval { } RegOp::MulRegImm(out, arg, imm) => { for i in 0..size { - v[out][i] = v[arg][i] * imm.into(); + v[out][i] = v[arg][i] * imm; } } RegOp::DivRegImm(out, arg, imm) => { diff --git a/fidget/src/jit/aarch64/float_slice.rs b/fidget/src/jit/aarch64/float_slice.rs index 1edc816e..c7bb3121 100644 --- a/fidget/src/jit/aarch64/float_slice.rs +++ b/fidget/src/jit/aarch64/float_slice.rs @@ -8,11 +8,11 @@ pub const SIMD_WIDTH: usize = 4; /// Assembler for SIMD point-wise evaluation on `aarch64` /// -/// | Argument | Register | Type | -/// | ---------|----------|--------------------------| -/// | `vars` | `x0` | `*mut *const [f32; 4]` | -/// | out | `x1` | `*mut [f32; 4]` | -/// | size | `x2` | `u64` | +/// | Argument | Register | Type | +/// | ---------|----------|----------------------------| +/// | `vars` | `x0` | `*const *const [f32; 4]` | +/// | `out` | `x1` | `*mut [f32; 4]` | +/// | `count` | `x2` | `u64` | /// /// The arrays must be an even multiple of 4 floats, since we're using NEON and /// 128-bit wide operations for everything. diff --git a/fidget/src/jit/aarch64/grad_slice.rs b/fidget/src/jit/aarch64/grad_slice.rs index e439f36c..2a53db33 100644 --- a/fidget/src/jit/aarch64/grad_slice.rs +++ b/fidget/src/jit/aarch64/grad_slice.rs @@ -12,14 +12,15 @@ use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi}; /// /// Registers as pased in as follows: /// -/// | Variable | Register | Type | -/// |------------|----------|--------------------| -/// | X | `x0` | `*const f32` | -/// | `out` | `x1` | `*const [f32; 4]` | -/// | `count` | `x2` | `u64` | +/// | Variable | Register | Type | +/// |------------|----------|---------------------------| +/// | `vars` | `x0` | `*const *const [f32; 4]` | +/// | `out` | `x1` | `*mut [f32; 4]` | +/// | `count` | `x2` | `u64` | /// -/// During evaluation, X, Y, and Z are stored in `V0-3.S4`. Each SIMD register -/// is in the order `[value, dx, dy, dz]`, e.g. the value for X is in `V0.S0`. 
+/// During evaluation, variables are loaded into SIMD registers in the order +/// `[value, dx, dy, dz]`, e.g. if we load `vars[0]` into `V0`, its value would +/// be in `V0.S0` (and the three partial derivatives would be in `V0.S{1,2,3}`). /// /// In addition to the registers above (`x0-5`), the following extra registers /// are used during evaluation: @@ -41,7 +42,7 @@ use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi}; /// |----------|--------------|---------------------------------------------| /// | 0x220 | ... | Register spills live up here | /// |----------|--------------|---------------------------------------------| -/// | 0x218 | `x23` | Backup for callee-saved register | +/// | 0x218 | `x23` | Backup for callee-saved registers | /// | 0x210 | `x22` | | /// | 0x208 | `x21` | | /// | 0x200 | `x20` | | @@ -163,16 +164,8 @@ impl Assembler for GradSliceAssembler { ; ldr x4, [x0, src_arg as u32 * 8] ; add x4, x4, x3 // apply array offset ; eor V(reg(out_reg)).b16, V(reg(out_reg)).b16, V(reg(out_reg)).b16 - ; ldr S(reg(out_reg)), [x4] - ; fmov s6, 1.0 + ; ldr Q(reg(out_reg)), [x4] ); - // Load the gradient, which is a 1.0 - match src_arg % 3 { - 0 => dynasm!(self.0.ops ; mov V(reg(out_reg)).S[1], v6.S[0]), - 1 => dynasm!(self.0.ops ; mov V(reg(out_reg)).S[2], v6.S[0]), - 2 => dynasm!(self.0.ops ; mov V(reg(out_reg)).S[3], v6.S[0]), - _ => unreachable!(), - } } fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) { extern "C" fn grad_sin(v: Grad) -> Grad { @@ -499,7 +492,7 @@ impl Assembler for GradSliceAssembler { ; sub x2, x2, 1 // We handle 1 item at a time // Adjust the array offset pointer - ; add x3, x3, 4 // 1 item = 4 bytes + ; add x3, x3, 16 // 1 item = 16 bytes // Prepare our return value, writing to the pointer in x1 ; str Q(reg(out_reg)), [x1], 16 diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 6ce20f81..4424d1e4 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -1068,9 +1068,9 @@ pub struct JitBulkFn { var_count: usize, 
fn_bulk: jit_fn!( unsafe fn( - *const *const f32, // vars - *mut T, // out - u64, // size + *const *const T, // vars + *mut T, // out + u64, // size ) -> T ), } @@ -1104,9 +1104,9 @@ impl + Copy + SimdSize> JitBulkEval { fn eval( &mut self, tape: &JitBulkFn, - xs: &[f32], - ys: &[f32], - zs: &[f32], + xs: &[T], + ys: &[T], + zs: &[T], ) -> &[T] { assert!(tape.var_count <= 3); let n = xs.len(); @@ -1122,9 +1122,9 @@ impl + Copy + SimdSize> JitBulkEval { // that should be optimized out; we can't use a constant assertion // here due to the same compiler limitations. const MAX_SIMD_WIDTH: usize = 8; - let mut x = [0.0; MAX_SIMD_WIDTH]; - let mut y = [0.0; MAX_SIMD_WIDTH]; - let mut z = [0.0; MAX_SIMD_WIDTH]; + let mut x = [T::from(0.0); MAX_SIMD_WIDTH]; + let mut y = [T::from(0.0); MAX_SIMD_WIDTH]; + let mut z = [T::from(0.0); MAX_SIMD_WIDTH]; assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH); x[0..n].copy_from_slice(xs); @@ -1203,9 +1203,9 @@ impl BulkEvaluator for JitGradSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + xs: &[Self::Data], + ys: &[Self::Data], + zs: &[Self::Data], ) -> Result<&[Self::Data], Error> { self.check_arguments(xs, ys, zs, tape.var_count)?; Ok(self.0.eval(tape, xs, ys, zs)) diff --git a/fidget/src/jit/x86_64/grad_slice.rs b/fidget/src/jit/x86_64/grad_slice.rs index 7b904f7f..a33b09c8 100644 --- a/fidget/src/jit/x86_64/grad_slice.rs +++ b/fidget/src/jit/x86_64/grad_slice.rs @@ -12,11 +12,11 @@ use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi}; /// /// Registers as pased in as follows: /// -/// | Variable | Register | Type | -/// |------------|----------|--------------------| -/// | vars | `rdi` | `*const f32` | -/// | `out` | `rsi` | `*const [f32; 4]` | -/// | `count` | `rdx` | `u64` | +/// | Variable | Register | Type | +/// |------------|----------|--------------------------| +/// | `vars` | `rdi` | `*const *const [f32; 4]` | +/// | `out` | `rsi` | `*mut [f32; 4]` | +/// | `count` | `rdx` | `u64` | /// /// 
During evaluation, `rcx` is used to track offset within `vars`. /// @@ -87,27 +87,17 @@ impl Assembler for GradSliceAssembler { + STACK_SIZE_LOWER as u32) .try_into() .unwrap(); + // XXX could we use vmovaps here instead? dynasm!(self.0.ops ; vmovups [rsp + sp_offset], Rx(reg(src_reg)) ); } fn build_input(&mut self, out_reg: u8, src_arg: u8) { - // upper 2 bits are insert position (COUNT_D), lower 4 are ZMASK - let imm = match src_arg % 3 { - 0 => 0b01_1100, - 1 => 0b10_1010, - 2 => 0b11_0110, - _ => unreachable!(), - }; let pos = 8 * (src_arg as i32); // offset within the pointer array dynasm!(self.0.ops ; mov r8, [rdi + pos] // read the *const float from the array ; add r8, rcx // offset it by array position - ; vmovss Rx(reg(out_reg)), [r8] - - ; mov eax, 1.0f32.to_bits() as i32 - ; movd xmm1, eax - ; vinsertps Rx(reg(out_reg)), Rx(reg(out_reg)), xmm1, imm + ; vmovaps Rx(reg(out_reg)), [r8] ); } fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) { @@ -482,7 +472,7 @@ impl Assembler for GradSliceAssembler { ; vmovups [rsi], Rx(reg(out_reg)) ; add rsi, 16 // 4x float ; sub rdx, 1 // we process one element at a time - ; add rcx, 4 // input is array is single floats + ; add rcx, 16 // input is array is Grad (f32 x 4) ; jmp ->L // Finalization code, which happens after all evaluation is complete diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 01cb9312..4fcdcd35 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -10,7 +10,10 @@ use super::{ types::{Axis, Corner, Edge}, Mesh, Settings, }; -use crate::eval::{BulkEvaluator, Shape, Tape, TracingEvaluator}; +use crate::{ + eval::{BulkEvaluator, Shape, Tape, TracingEvaluator}, + types::Grad, +}; use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; #[cfg(not(target_arch = "wasm32"))] @@ -538,11 +541,11 @@ impl OctreeBuilder { const EDGE_SEARCH_SIZE: usize = 16; const EDGE_SEARCH_DEPTH: usize = 4; let xs = - &mut [0f32; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; 
+ &mut [0.0; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; let ys = - &mut [0f32; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; + &mut [0.0; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; let zs = - &mut [0f32; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; + &mut [0.0; 12 * EDGE_SEARCH_SIZE][..edge_count * EDGE_SEARCH_SIZE]; // This part looks hairy, but it's just doing an N-ary search along each // edge to find the intersection point. @@ -618,11 +621,18 @@ impl OctreeBuilder { }) .collect(); + let xs = &mut [Grad::from(0.0); 12 * EDGE_SEARCH_SIZE] + [..intersections.len()]; + let ys = &mut [Grad::from(0.0); 12 * EDGE_SEARCH_SIZE] + [..intersections.len()]; + let zs = &mut [Grad::from(0.0); 12 * EDGE_SEARCH_SIZE] + [..intersections.len()]; + for (i, xyz) in intersections.iter().enumerate() { let pos = cell.pos(*xyz); - xs[i] = pos.x; - ys[i] = pos.y; - zs[i] = pos.z; + xs[i] = Grad::new(pos.x, 1.0, 0.0, 0.0); + ys[i] = Grad::new(pos.y, 0.0, 1.0, 0.0); + zs[i] = Grad::new(pos.z, 0.0, 0.0, 1.0); } // TODO: special case for cells with multiple gradients ("features") @@ -638,7 +648,7 @@ impl OctreeBuilder { for vs in CELL_TO_VERT_TO_EDGES[mask as usize].iter() { let mut qef = QuadraticErrorSolver::new(); for e in vs.iter() { - let pos = nalgebra::Vector3::new(xs[i], ys[i], zs[i]); + let pos = nalgebra::Vector3::new(xs[i].v, ys[i].v, zs[i].v); let grad: nalgebra::Vector4 = grads[i].into(); qef.add_intersection(pos, grad); diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index 9ac12a60..0df2c78b 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -3,7 +3,7 @@ use super::RenderHandle; use crate::{ eval::{BulkEvaluator, Shape, TracingEvaluator}, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, - types::Interval, + types::{Grad, Interval}, }; use nalgebra::Point3; @@ -16,6 +16,10 @@ struct Scratch { y: Vec, z: Vec, + xg: Vec, + yg: Vec, + zg: Vec, + /// Depth 
of each column columns: Vec, } @@ -29,6 +33,10 @@ impl Scratch { y: vec![0.0; size3], z: vec![0.0; size3], + xg: vec![Grad::from(0.0); size2], + yg: vec![Grad::from(0.0); size2], + zg: vec![Grad::from(0.0); size2], + columns: vec![0; size2], } } @@ -222,9 +230,12 @@ impl Worker<'_, S> { // We step one voxel above the surface to reduce // glitchiness on edges and corners, where rendering // inside the surface could pick the wrong normal. - self.scratch.x[grad] = (tile.corner[0] + i) as f32; - self.scratch.y[grad] = (tile.corner[1] + j) as f32; - self.scratch.z[grad] = (tile.corner[2] + k) as f32; + self.scratch.xg[grad] = + Grad::new((tile.corner[0] + i) as f32, 1.0, 0.0, 0.0); + self.scratch.yg[grad] = + Grad::new((tile.corner[1] + j) as f32, 0.0, 1.0, 0.0); + self.scratch.zg[grad] = + Grad::new((tile.corner[2] + k) as f32, 0.0, 0.0, 1.0); // This can only be called once per iteration, so we'll // never overwrite parts of columns that are still used @@ -238,9 +249,9 @@ impl Worker<'_, S> { .eval_grad_slice .eval( shape.g_tape(&mut self.tape_storage), - &self.scratch.x[..grad], - &self.scratch.y[..grad], - &self.scratch.z[..grad], + &self.scratch.xg[..grad], + &self.scratch.yg[..grad], + &self.scratch.zg[..grad], ) .unwrap();