diff --git a/CHANGELOG.md b/CHANGELOG.md index aada9fa3..4a24920a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,15 @@ -# 0.2.8 (unreleased) +# 0.2.8 +- Major refactoring of core evaluation traits + - The lowest-level "thing that can be evaluated" trait has changed from + `Shape` (taking `(x, y, z)` inputs) to `Function` (taking an arbitrary + number of variables). + - `Shape` is now a wrapper around a `F: Function` instead of a trait. + - Shape evaluators are now wrappers around `E: BulkEvaluator` or `E: + TracingEvaluator`, which convert `(x, y, z)` arguments into + list-of-variables arguments. + - Using the `VmShape` or `JitShape` types should be mostly the same as + before; changes are most noticeable if you're writing things that are + generic across `S: Shape`. # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/demo/src/main.rs b/demo/src/main.rs index b037fdd7..5537019e 100644 --- a/demo/src/main.rs +++ b/demo/src/main.rs @@ -7,10 +7,7 @@ use clap::{Parser, Subcommand, ValueEnum}; use env_logger::Env; use log::info; -use fidget::{ - context::Context, - eval::{BulkEvaluator, MathShape}, -}; +use fidget::context::Context; /// Simple test program #[derive(Parser)] @@ -112,8 +109,8 @@ struct MeshSettings { } //////////////////////////////////////////////////////////////////////////////// -fn run3d( - shape: S, +fn run3d( + shape: fidget::shape::Shape, settings: &ImageSettings, isometric: bool, mode_color: bool, @@ -124,7 +121,7 @@ fn run3d( } let cfg = fidget::render::RenderConfig { image_size: settings.size as usize, - tile_sizes: S::tile_sizes_3d().to_vec(), + tile_sizes: F::tile_sizes_3d().to_vec(), threads: settings.threads, ..Default::default() }; @@ -168,15 +165,15 @@ fn run3d( //////////////////////////////////////////////////////////////////////////////// -fn run2d( - shape: S, +fn run2d( + shape: fidget::shape::Shape, settings: &ImageSettings, brute: bool, sdf: bool, ) -> Vec { if brute { let tape = shape.float_slice_tape(Default::default()); - let mut eval = S::new_float_slice_eval(); + let mut eval = fidget::shape::Shape::::new_float_slice_eval(); let mut out: Vec = vec![]; for _ in 0..settings.n { let mut xs = vec![]; @@ -202,7 +199,7 @@ fn run2d( } else { let cfg = fidget::render::RenderConfig { image_size: settings.size as usize, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), threads: settings.threads, ..Default::default() }; @@ -236,13 +233,10 @@ fn run2d( //////////////////////////////////////////////////////////////////////////////// -fn run_mesh( - shape: S, +fn run_mesh( + shape: fidget::shape::Shape, settings: &MeshSettings, -) -> fidget::mesh::Mesh -where - ::TransformedShape: fidget::shape::RenderHints, -{ +) -> fidget::mesh::Mesh { let mut mesh = fidget::mesh::Mesh::new(); for _ in 0..settings.n { @@ -264,7 +258,7 @@ fn main() -> Result<()> { let now = Instant::now(); let args = Args::parse(); let mut file = std::fs::File::open(&args.input)?; - let (ctx, root) = Context::from_text(&mut file)?; + let (mut ctx, root) = Context::from_text(&mut file)?; info!("Loaded file in {:?}", now.elapsed()); match args.cmd { @@ -277,12 +271,12 @@ fn main() -> Result<()> { let buffer = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run2d(shape, &settings, brute, sdf) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run2d(shape, &settings, brute, sdf) } @@ -314,12 +308,12 @@ fn main() -> Result<()> { let buffer = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run3d(shape, &settings, isometric, color) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run3d(shape, &settings, isometric, color) } @@ -348,12 +342,12 @@ fn main() -> Result<()> { let mesh = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run_mesh(shape, &settings) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run_mesh(shape, &settings) } diff --git a/fidget/benches/function_call.rs b/fidget/benches/function_call.rs index 10f428ac..34c0d7b6 100644 --- a/fidget/benches/function_call.rs +++ b/fidget/benches/function_call.rs @@ -3,19 +3,20 @@ use criterion::{ }; use fidget::{ context::{Context, Node}, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, }; -pub fn run_bench( +pub fn run_bench( c: &mut Criterion, - ctx: Context, + mut ctx: Context, node: Node, test_name: &'static str, name: &'static str, ) { - let shape_vm = &S::new(&ctx, node).unwrap(); + let shape_vm = &Shape::::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape_vm.ez_float_slice_tape(); let mut group = c.benchmark_group(test_name); @@ -30,7 +31,7 @@ pub fn run_bench( } } -pub fn test_single_fn( +pub fn test_single_fn( c: &mut Criterion, name: &'static str, ) { @@ -38,10 +39,10 @@ pub fn test_single_fn( let x = ctx.x(); let f = ctx.sin(x).unwrap(); - run_bench::(c, ctx, f, "single function", name); + run_bench::(c, ctx, f, "single function", name); } -pub fn test_many_fn( +pub fn test_many_fn( c: &mut Criterion, name: &'static str, ) { @@ -56,19 +57,19 @@ pub fn test_many_fn( let out = ctx.add(f, g).unwrap(); let out = ctx.add(out, h).unwrap(); - run_bench::(c, ctx, out, "many functions", name); + run_bench::(c, ctx, out, "many functions", name); } pub fn test_single_fns(c: &mut Criterion) { - test_single_fn::(c, "vm"); + test_single_fn::(c, "vm"); #[cfg(feature = "jit")] - test_single_fn::(c, "jit"); + test_single_fn::(c, "jit"); } pub fn test_many_fns(c: &mut Criterion) { - test_many_fn::(c, "vm"); + test_many_fn::(c, "vm"); #[cfg(feature = "jit")] - test_many_fn::(c, "jit"); + test_many_fn::(c, "jit"); } criterion_group!(benches, test_single_fns, test_many_fns); diff --git a/fidget/benches/mesh.rs b/fidget/benches/mesh.rs index 63b488f4..000aa22b 100644 --- a/fidget/benches/mesh.rs +++ b/fidget/benches/mesh.rs @@ -1,15 +1,15 @@ use criterion::{ black_box, criterion_group, criterion_main, BenchmarkId, Criterion, }; -use fidget::eval::MathShape; const COLONNADE: &str = include_str!("../../models/colonnade.vm"); pub fn colonnade_octree_thread_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs threads (colonnade, octree) (depth 6)"); @@ -36,8 +36,9 @@ pub fn colonnade_octree_thread_sweep(c: &mut Criterion) { } pub fn colonnade_mesh(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); let cfg = fidget::mesh::Settings { depth: 8, ..Default::default() diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index 204d46c1..a6199b6f 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -1,24 +1,24 @@ use criterion::{ black_box, criterion_group, criterion_main, BenchmarkId, Criterion, }; +use fidget::shape::RenderHints; const PROSPERO: &str = include_str!("../../models/prospero.vm"); -use fidget::{eval::MathShape, shape::RenderHints}; - pub fn prospero_size_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs image size (prospero, 2d) (8 threads)"); for size in [256, 512, 768, 1024, 1280, 1546, 1792, 2048] { let cfg = &fidget::render::RenderConfig { image_size: size, - tile_sizes: fidget::vm::VmShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::vm::VmFunction::tile_sizes_2d().to_vec(), ..Default::default() }; group.bench_function(BenchmarkId::new("vm", size), move |b| { @@ -35,7 +35,7 @@ pub fn prospero_size_sweep(c: &mut Criterion) { { let cfg = &fidget::render::RenderConfig { image_size: size, - tile_sizes: fidget::jit::JitShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::jit::JitFunction::tile_sizes_2d().to_vec(), ..Default::default() }; group.bench_function(BenchmarkId::new("jit", size), move |b| { @@ -52,18 +52,19 @@ pub fn prospero_size_sweep(c: &mut Criterion) { } pub fn prospero_thread_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs threads (prospero, 2d) (1024 x 1024)"); for threads in [1, 2, 4, 8, 16] { let cfg = &fidget::render::RenderConfig { image_size: 1024, - tile_sizes: fidget::vm::VmShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::vm::VmFunction::tile_sizes_2d().to_vec(), threads: threads.try_into().unwrap(), ..Default::default() }; @@ -80,7 +81,7 @@ pub fn prospero_thread_sweep(c: &mut Criterion) { { let cfg = &fidget::render::RenderConfig { image_size: 1024, - tile_sizes: fidget::jit::JitShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::jit::JitFunction::tile_sizes_2d().to_vec(), threads: threads.try_into().unwrap(), ..Default::default() }; diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 6f58c1b4..80e4fc55 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -2,6 +2,7 @@ use crate::{ compiler::SsaOp, context::{BinaryOpcode, Node, Op, UnaryOpcode}, + eval::VarMap, Context, Error, }; use serde::{Deserialize, Serialize}; @@ -32,7 +33,7 @@ impl SsaTape { /// /// This should always succeed unless the `root` is from a different /// `Context`, in which case `Error::BadNode` will be returned. - pub fn new(ctx: &Context, root: Node) -> Result { + pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> { let mut mapping = HashMap::new(); let mut parent_count: HashMap = HashMap::new(); let mut slot_count = 0; @@ -46,6 +47,7 @@ impl SsaTape { // Accumulate parent counts and declare all nodes let mut seen = HashSet::new(); + let mut vars = HashMap::new(); let mut todo = vec![root]; while let Some(node) = todo.pop() { if !seen.insert(node) { @@ -57,6 +59,11 @@ impl SsaTape { mapping.insert(node, Slot::Immediate(c.0 as f32)) } _ => { + if let Op::Input(v) = op { + let next = vars.len(); + let v = ctx.get_var_by_index(*v).unwrap().clone(); + vars.entry(v).or_insert(next); + } let i = slot_count; slot_count += 1; mapping.insert(node, Slot::Reg(i)) @@ -91,13 +98,8 @@ impl SsaTape { }; let op = match op { Op::Input(..) => { - let arg = match ctx.var_name(node).unwrap().unwrap() { - "X" => 0, - "Y" => 1, - "Z" => 2, - i => panic!("Unexpected input index: {i}"), - }; - SsaOp::Input(i, arg) + let arg = vars[ctx.var_name(node).unwrap().unwrap()]; + SsaOp::Input(i, arg.try_into().unwrap()) } Op::Const(..) => { unreachable!("skipped above") @@ -232,7 +234,7 @@ impl SsaTape { tape.push(SsaOp::CopyImm(0, c)); } - Ok(SsaTape { tape, choice_count }) + Ok((SsaTape { tape, choice_count }, vars)) } /// Checks whether the tape is empty @@ -404,8 +406,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let tape = SsaTape::new(&ctx, c9).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); + assert_eq!(vs.len(), 2); } #[test] @@ -414,7 +417,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let tape = SsaTape::new(&ctx, x_squared).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); + assert_eq!(vs.len(), 1); } } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index efe6deeb..b7659c67 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -10,10 +10,10 @@ //! they have been constructed. //! - A [`Context`] is an arena for unique (deduplicated) math expressions, //! which are represented as [`Node`] handles. Each `Node` is specific to a -//! particular context. Only `Node` objects can be converted into `Shape` +//! particular context. Only `Node` objects can be converted into `Function` //! objects for evaluation. //! -//! In other words, the typical workflow is `Tree → (Context, Node) → Shape`. +//! In other words, the typical workflow is `Tree → (Context, Node) → Function`. mod indexed; mod op; mod tree; @@ -42,7 +42,39 @@ define_index!(VarNode, "An index in the `Context::vars` map"); #[derive(Debug, Default)] pub struct Context { ops: IndexMap, - vars: IndexMap, + vars: IndexMap, +} + +/// A `Var` represents a value which can vary during evaluation +/// +/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully +/// customized values. +#[allow(missing_docs)] +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum Var { + X, + Y, + Z, + W, + T, + Static(&'static str), + Named(String), + Value(u64), +} + +impl std::fmt::Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Var::X => write!(f, "X"), + Var::Y => write!(f, "Y"), + Var::Z => write!(f, "Z"), + Var::W => write!(f, "W"), + Var::T => write!(f, "T"), + Var::Static(s) => write!(f, "{s}"), + Var::Named(s) => write!(f, "{s}"), + Var::Value(v) => write!(f, "v_{v}"), + } + } } impl Context { @@ -119,11 +151,11 @@ impl Context { } } - /// Looks up the variable name associated with the given node. + /// Looks up the [`Var`] associated with the given node. /// /// If the node is invalid for this tree, returns an error; if the node is /// not an `Op::Input`, returns `Ok(None)`. - pub fn var_name(&self, n: Node) -> Result, Error> { + pub fn var_name(&self, n: Node) -> Result, Error> { match self.get_op(n) { Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some), Some(_) => Ok(None), @@ -131,8 +163,8 @@ impl Context { } } - /// Looks up the variable name associated with the given `VarNode` - pub fn get_var_by_index(&self, n: VarNode) -> Result<&str, Error> { + /// Looks up the [`Var`] associated with the given [`VarNode`] + pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> { match self.vars.get_by_index(n) { Some(c) => Ok(c), None => Err(Error::BadVar), @@ -150,22 +182,27 @@ impl Context { /// assert_eq!(v, 1.0); /// ``` pub fn x(&mut self) -> Node { - let v = self.vars.insert(String::from("X")); + let v = self.vars.insert(Var::X); self.ops.insert(Op::Input(v)) } /// Constructs or finds a variable node named "Y" pub fn y(&mut self) -> Node { - let v = self.vars.insert(String::from("Y")); + let v = self.vars.insert(Var::Y); self.ops.insert(Op::Input(v)) } /// Constructs or finds a variable node named "Z" pub fn z(&mut self) -> Node { - let v = self.vars.insert(String::from("Z")); + let v = self.vars.insert(Var::Z); self.ops.insert(Op::Input(v)) } + /// Returns a 3-element array of `X`, `Y`, `Z` nodes + pub fn axes(&mut self) -> [Node; 3] { + [self.x(), self.y(), self.z()] + } + /// Returns a node representing the given constant value. /// ``` /// # let mut ctx = fidget::context::Context::new(); @@ -733,7 +770,7 @@ impl Context { /// Evaluates the given node with the provided values for X, Y, and Z. /// /// This is extremely inefficient; consider converting the node into a - /// [`Shape`](crate::eval::Shape) and using its evaluators instead. + /// [`Shape`](crate::shape::Shape) and using its evaluators instead. /// /// ``` /// # let mut ctx = fidget::context::Context::new(); @@ -752,9 +789,8 @@ impl Context { y: f64, z: f64, ) -> Result { - let vars = [("X", x), ("Y", y), ("Z", z)] + let vars = [(Var::X, x), (Var::Y, y), (Var::Z, z)] .into_iter() - .map(|(a, b)| (a.to_string(), b)) .collect(); self.eval(root, &vars) } @@ -762,11 +798,11 @@ impl Context { /// Evaluates the given node with a generic set of variables /// /// This is extremely inefficient; consider converting the node into a - /// [`Shape`](crate::eval::Shape) and using its evaluators instead. + /// [`Shape`](crate::shape::Shape) and using its evaluators instead. pub fn eval( &self, root: Node, - vars: &BTreeMap, + vars: &BTreeMap, ) -> Result { let mut cache = vec![None; self.ops.len()].into(); self.eval_inner(root, vars, &mut cache) @@ -775,7 +811,7 @@ impl Context { fn eval_inner( &self, node: Node, - vars: &BTreeMap, + vars: &BTreeMap, cache: &mut IndexVec, Node>, ) -> Result { if node.0 >= cache.len() { @@ -960,7 +996,7 @@ impl Context { Op::Const(c) => write!(out, "{}", c).unwrap(), Op::Input(v) => { let v = self.vars.get_by_index(*v).unwrap(); - out += v; + out += &v.to_string(); } Op::Binary(op, ..) => match op { BinaryOpcode::Add => out += "add", @@ -1209,8 +1245,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let tape = VmData::<255>::new(&ctx, c9).unwrap(); + let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); + assert_eq!(vs.len(), 2); } #[test] @@ -1219,7 +1256,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let tape = VmData::<255>::new(&ctx, x_squared).unwrap(); + let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); + assert_eq!(vs.len(), 1); } } diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/eval/bulk.rs index ae6bc083..ede24e96 100644 --- a/fidget/src/core/eval/bulk.rs +++ b/fidget/src/core/eval/bulk.rs @@ -35,9 +35,7 @@ pub trait BulkEvaluator: Default { fn eval( &mut self, tape: &Self::Tape, - x: &[Self::Data], - y: &[Self::Data], - z: &[Self::Data], + vars: &[&[Self::Data]], ) -> Result<&[Self::Data], Error>; /// Build a new empty evaluator @@ -46,19 +44,24 @@ pub trait BulkEvaluator: Default { } /// Helper function to return an error if the inputs are invalid - fn check_arguments( + fn check_arguments( &self, - xs: &[T], - ys: &[T], - zs: &[T], + vars: &[&[Self::Data]], var_count: usize, ) -> Result<(), Error> { - if xs.len() != ys.len() || ys.len() != zs.len() { - Err(Error::MismatchedSlices) - } else if var_count > 3 { - Err(Error::BadVarSlice(3, var_count)) + // It's fine if the caller has given us extra variables (e.g. due to + // tape simplification), but it must have given us enough. + if vars.len() < var_count { + Err(Error::BadVarSlice(vars.len(), var_count)) } else { - Ok(()) + let Some(n) = vars.first().map(|v| v.len()) else { + return Ok(()); + }; + if vars.iter().any(|v| v.len() == n) { + Ok(()) + } else { + Err(Error::MismatchedSlices) + } } } } diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 1a10b1a1..680d58f3 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,29 +1,8 @@ -//! Traits and data structures for evaluation -//! -//! There are a bunch of things in here, but the most important trait is -//! [`Shape`], followed by the evaluator traits ([`BulkEvaluator`] and -//! [`TracingEvaluator`]). -//! -//! ```rust -//! use fidget::vm::VmShape; -//! use fidget::context::Context; -//! use fidget::eval::{TracingEvaluator, Shape, MathShape, EzShape}; -//! -//! let mut ctx = Context::new(); -//! let x = ctx.x(); -//! let shape = VmShape::new(&ctx, x)?; -//! -//! // Let's build a single point evaluator: -//! let mut eval = VmShape::new_point_eval(); -//! let tape = shape.ez_point_tape(); -//! let (value, _trace) = eval.eval(&tape, 0.25, 0.0, 0.0)?; -//! assert_eq!(value, 0.25); -//! # Ok::<(), fidget::Error>(()) -//! ``` +//! Traits and data structures for function evaluation use crate::{ - context::Node, + context::{Context, Node, Tree, Var}, types::{Grad, Interval}, - Context, Error, + Error, }; #[cfg(any(test, feature = "eval-tests"))] @@ -31,32 +10,62 @@ pub mod test; mod bulk; mod tracing; -mod transform; -// Re-export a few things +// Reexport a few types pub use bulk::BulkEvaluator; pub use tracing::TracingEvaluator; -pub use transform::TransformedShape; -/// A shape represents an implicit surface +/// A tape represents something that can be evaluated by an evaluator +/// +/// The only property enforced on the trait is that we must have some way to +/// recycle its internal storage. This matters most for JIT evaluators, whose +/// tapes are regions of executable memory-mapped RAM (which is expensive to map +/// and unmap). +pub trait Tape { + /// Associated type for this tape's data storage + type Storage: Default; + + /// Retrieves the internal storage from this tape + fn recycle(self) -> Self::Storage; +} + +/// Represents the trace captured by a tracing evaluation +/// +/// The only property enforced on the trait is that we must have a way of +/// reusing trace allocations. Because [`Trace`] implies `Clone` where it's +/// used in [`Function`], this is trivial, but we can't provide a default +/// implementation because it would fall afoul of `impl` specialization. +pub trait Trace { + /// Copies the contents of `other` into `self` + fn copy_from(&mut self, other: &Self); +} + +impl Trace for Vec { + fn copy_from(&mut self, other: &Self) { + self.resize(other.len(), T::default()); + self.copy_from_slice(other); + } +} + +/// A function represents something that can be evaluated /// -/// It is mostly agnostic to _how_ that surface is represented; we simply -/// require that the shape can generate evaluators of various kinds. +/// It is mostly agnostic to _how_ that something is represented; we simply +/// require that it can generate evaluators of various kinds. /// -/// Shapes are shared between threads, so they should be cheap to clone. In +/// Functions are shared between threads, so they should be cheap to clone. In /// most cases, they're a thin wrapper around an `Arc<..>`. -pub trait Shape: Send + Sync + Clone { +pub trait Function: Send + Sync + Clone { /// Associated type traces collected during tracing evaluation /// /// This type must implement [`Eq`] so that traces can be compared; calling - /// [`Shape::simplify`] with traces that compare equal should produce an + /// [`Function::simplify`] with traces that compare equal should produce an /// identical result and may be cached. type Trace: Clone + Eq + Send + Trace; - /// Associated type for storage used by the shape itself + /// Associated type for storage used by the function itself type Storage: Default + Send; - /// Associated type for workspace used during shape simplification + /// Associated type for workspace used during function simplification type Workspace: Default + Send; /// Associated type for storage used by tapes @@ -145,104 +154,35 @@ pub trait Shape: Send + Sync + Clone { where Self: Sized; - /// Attempt to reclaim storage from this shape + /// Attempt to reclaim storage from this function /// - /// This may fail, because shapes are `Clone` and are often implemented + /// This may fail, because functions are `Clone` and are often implemented /// using an `Arc` around a heavier data structure. fn recycle(self) -> Option; - /// Returns a size associated with this shape + /// Returns a size associated with this function /// /// This is underspecified and only used for unit testing; for tape-based - /// shapes, it's typically the length of the tape, + /// functions, it's typically the length of the tape, fn size(&self) -> usize; - - /// Associated type returned when applying a transform - /// - /// This is normally [`TransformedShape`](TransformedShape), but if - /// `Self` is already `TransformedShape`, then the transform is stacked - /// (instead of creating a wrapped object). - type TransformedShape: Shape; - - /// Returns a shape with the given transform applied - fn apply_transform( - self, - mat: nalgebra::Matrix4, - ) -> ::TransformedShape; } -/// Extension trait for working with a shape without thinking much about memory -/// -/// All of the [`Shape`] functions that use significant amounts of memory -/// pedantically require you to pass in storage for reuse. This trait allows -/// you to ignore that, at the cost of performance; we require that all storage -/// types implement [`Default`], so these functions do the boilerplate for you. -/// -/// This trait is automatically implemented for every [`Shape`], but must be -/// imported separately as a speed-bump to using it everywhere. -pub trait EzShape: Shape { - /// Returns an evaluation tape for a point evaluator - fn ez_point_tape(&self) -> ::Tape; +/// Map from variable (from a particular [`Context`]) to index +pub type VarMap = std::collections::HashMap; - /// Returns an evaluation tape for an interval evaluator - fn ez_interval_tape( - &self, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn ez_float_slice_tape( - &self, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn ez_grad_slice_tape( - &self, - ) -> ::Tape; - - /// Computes a simplified tape using the given trace - fn ez_simplify(&self, trace: &Self::Trace) -> Result - where - Self: Sized; -} - -impl EzShape for S { - fn ez_point_tape(&self) -> ::Tape { - self.point_tape(Default::default()) - } - - fn ez_interval_tape( - &self, - ) -> ::Tape { - self.interval_tape(Default::default()) - } - - fn ez_float_slice_tape( - &self, - ) -> ::Tape { - self.float_slice_tape(Default::default()) - } - - fn ez_grad_slice_tape( - &self, - ) -> ::Tape { - self.grad_slice_tape(Default::default()) - } - - fn ez_simplify(&self, trace: &Self::Trace) -> Result { - let mut workspace = Default::default(); - self.simplify(trace, Default::default(), &mut workspace) - } -} - -/// A [`Shape`] which can be built from a math expression -pub trait MathShape { - /// Builds a new shape from the given context and node - fn new(ctx: &Context, node: Node) -> Result +/// A [`Function`] which can be built from a math expression +pub trait MathFunction { + /// Builds a new function from the given context and node + /// + /// Returns a tuple of the [`MathFunction`] and a [`VarMap`] representing a + /// mapping from variables (in the [`Context`]) to indices used during + /// evaluation. + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> where Self: Sized; - /// Helper function to build a shape from a [`Tree`](crate::context::Tree) - fn from_tree(t: &crate::context::Tree) -> Self + /// Helper function to build a function from a [`Tree`] + fn from_tree(t: &Tree) -> (Self, VarMap) where Self: Sized, { @@ -251,35 +191,3 @@ pub trait MathShape { Self::new(&ctx, node).unwrap() } } - -/// A tape represents something that can be evaluated by an evaluator -/// -/// The only property enforced on the trait is that we must have some way to -/// recycle its internal storage. This matters most for JIT evaluators, whose -/// tapes are regions of executable memory-mapped RAM (which is expensive to map -/// and unmap). -pub trait Tape { - /// Associated type for this tape's data storage - type Storage: Default; - - /// Retrieves the internal storage from this tape - fn recycle(self) -> Self::Storage; -} - -/// Represents the trace captured by a tracing evaluation -/// -/// The only property enforced on the trait is that we must have a way of -/// reusing trace allocations. Because [`Trace`] implies `Clone` where it's -/// used in [`Shape`], this is trivial, but we can't provide a default -/// implementation because it would fall afoul of `impl` specialization. -pub trait Trace { - /// Copies the contents of `other` into `self` - fn copy_from(&mut self, other: &Self); -} - -impl Trace for Vec { - fn copy_from(&mut self, other: &Self) { - self.resize(other.len(), T::default()); - self.copy_from_slice(other); - } -} diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index 68ca1167..cf9914fc 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -6,26 +6,24 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestFloatSlice(std::marker::PhantomData<*const S>); +pub struct TestFloatSlice(std::marker::PhantomData<*const F>); -impl TestFloatSlice -where - S: Shape + MathShape, -{ +impl TestFloatSlice { pub fn test_give_take() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); - let shape_x = S::new(&ctx, x).unwrap(); - let shape_y = S::new(&ctx, y).unwrap(); + let shape_x = Shape::::new(&mut ctx, x).unwrap(); + let shape_y = Shape::::new(&mut ctx, y).unwrap(); // This is a fuzz test for icache issues - let mut eval = S::new_float_slice_eval(); + let mut eval = Shape::::new_float_slice_eval(); for _ in 0..10000 { let tape = shape_x.ez_float_slice_tape(); let out = eval @@ -58,8 +56,8 @@ where let x = ctx.x(); let y = ctx.y(); - let mut eval = S::new_float_slice_eval(); - let shape = S::new(&ctx, x).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -92,7 +90,7 @@ where assert_eq!(out, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); let mul = ctx.mul(y, 2.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -125,8 +123,8 @@ where let a = ctx.x(); let b = ctx.sin(a).unwrap(); - let shape = S::new(&ctx, b).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, b).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let args = [0.0, 1.0, 2.0, std::f32::consts::PI / 2.0]; @@ -137,7 +135,7 @@ where } pub fn test_f_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -147,8 +145,8 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = eval.eval(&tape, &x, &y, &z).unwrap(); @@ -172,7 +170,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -204,8 +202,8 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { @@ -261,8 +259,8 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match (i, j) { @@ -308,8 +306,8 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { @@ -349,8 +347,8 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index 42ef1450..7e40b0c3 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -5,22 +5,23 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{BulkEvaluator, Function, MathFunction}, + shape::{EzShape, Shape, ShapeTape}, types::Grad, + vm::VmFunction, }; /// Epsilon for gradient estimates const EPSILON: f64 = 1e-8; /// Helper struct to put constrains on our `Shape` object -pub struct TestGradSlice(std::marker::PhantomData<*const S>); +pub struct TestGradSlice(std::marker::PhantomData<*const F>); -impl TestGradSlice -where - S: Shape + MathShape, -{ +impl TestGradSlice { fn eval_xyz( - tape: &<::GradSliceEval as BulkEvaluator>::Tape, + tape: &ShapeTape< + <::GradSliceEval as BulkEvaluator>::Tape, + >, xs: &[f32], ys: &[f32], zs: &[f32], @@ -34,14 +35,14 @@ where let zs: Vec<_> = zs.iter().map(|z| Grad::new(*z, 0.0, 0.0, 1.0)).collect(); - let mut eval = S::new_grad_slice_eval(); + let mut eval = Shape::::new_grad_slice_eval(); eval.eval(tape, &xs, &ys, &zs).unwrap().to_owned() } pub fn test_g_x() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = S::new(&ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -53,7 +54,7 @@ where pub fn test_g_y() { let mut ctx = Context::new(); let y = ctx.y(); - let shape = S::new(&ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -65,7 +66,7 @@ where pub fn test_g_z() { let mut ctx = Context::new(); let z = ctx.z(); - let shape = S::new(&ctx, z).unwrap(); + let shape = Shape::::new(&mut ctx, z).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -78,7 +79,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.square(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -103,7 +104,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.abs(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -120,7 +121,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sqrt(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -137,7 +138,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]); @@ -148,7 +149,7 @@ where let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]); v[0].compare_eq(Grad::new(2f32.sin(), 0.0, 2.0 * 2f32.cos(), 0.0)); @@ -161,7 +162,7 @@ where let x = ctx.x(); let y = ctx.y(); let s = ctx.mul(x, y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -186,7 +187,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.div(x, 2.0).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -199,7 +200,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.recip(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -217,7 +218,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -237,7 +238,7 @@ where let z = ctx.z(); let min = ctx.min(x, y).unwrap(); let max = ctx.max(min, z).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -259,7 +260,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -276,7 +277,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let m = ctx.not(x).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -295,7 +296,7 @@ where let sum = ctx.add(x2, y2).unwrap(); let sqrt = ctx.sqrt(sum).unwrap(); let sub = ctx.sub(sqrt, 0.5).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -317,7 +318,7 @@ where } pub fn test_g_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD registers let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -327,7 +328,7 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = Self::eval_xyz(&tape, &x, &y, &z); @@ -352,10 +353,10 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); - let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); + let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); for (a, b) in out.iter().zip(cmp.iter()) { a.compare_eq(*b) } @@ -375,7 +376,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -528,7 +529,7 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match (i, j) { @@ -575,7 +576,7 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -616,7 +617,7 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 52147513..513b828c 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -9,27 +9,28 @@ use super::{ }; use crate::{ context::Context, - eval::{EzShape, MathShape, Shape, Tape, TracingEvaluator}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, types::Interval, vm::Choice, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestInterval(std::marker::PhantomData<*const S>); +pub struct TestInterval(std::marker::PhantomData<*const F>); -impl TestInterval +impl TestInterval where - for<'a> S: Shape + MathShape, - ::Trace: AsRef<[Choice]>, + for<'a> F: Function + MathFunction, + ::Trace: AsRef<[Choice]>, { pub fn test_interval() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [2.0, 3.0]), [0.0, 1.0].into() @@ -39,9 +40,9 @@ where [1.0, 5.0].into() ); - let shape = S::new(&ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [2.0, 3.0]), [2.0, 3.0].into() @@ -57,9 +58,9 @@ where let x = ctx.x(); let abs_x = ctx.abs(x).unwrap(); - let shape = S::new(&ctx, abs_x).unwrap(); + let shape = Shape::::new(&mut ctx, abs_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 5.0]), [1.0, 5.0].into()); assert_eq!(eval.eval_x(&tape, [-2.0, 5.0]), [0.0, 5.0].into()); @@ -69,9 +70,9 @@ where let y = ctx.y(); let abs_y = ctx.abs(y).unwrap(); let sum = ctx.add(abs_x, abs_y).unwrap(); - let shape = S::new(&ctx, sum).unwrap(); + let shape = Shape::::new(&mut ctx, sum).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [0.0, 2.0].into() @@ -92,9 +93,9 @@ where let v = ctx.add(x, 0.5).unwrap(); let out = ctx.abs(v).unwrap(); - let shape = S::new(&ctx, out).unwrap(); + let shape = Shape::::new(&mut ctx, out).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [-1.0, 1.0]), [0.0, 1.5].into()); } @@ -104,9 +105,9 @@ where let x = ctx.x(); let sqrt_x = ctx.sqrt(x).unwrap(); - let shape = S::new(&ctx, sqrt_x).unwrap(); + let shape = Shape::::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [0.0, 2.0].into()); @@ -132,9 +133,9 @@ where let x = ctx.x(); let sqrt_x = ctx.square(x).unwrap(); - let shape = S::new(&ctx, sqrt_x).unwrap(); + let shape = Shape::::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [0.0, 16.0].into()); assert_eq!(eval.eval_x(&tape, [2.0, 4.0]), [4.0, 16.0].into()); @@ -153,17 +154,17 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-1.0, 1.0].into()); let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); let s = ctx.add(x, s).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); assert_eq!(eval.eval_x(&tape, [0.0, 3.0]), [-1.0, 4.0].into()); @@ -174,9 +175,9 @@ where let x = ctx.x(); let neg_x = ctx.neg(x).unwrap(); - let shape = S::new(&ctx, neg_x).unwrap(); + let shape = Shape::::new(&mut ctx, neg_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-1.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [-4.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [2.0, 4.0]), [-4.0, -2.0].into()); @@ -196,9 +197,9 @@ where let x = ctx.x(); let not_x = ctx.not(x).unwrap(); - let shape = S::new(&ctx, not_x).unwrap(); + let shape = Shape::::new(&mut ctx, not_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [-5.0, 0.0]), [0.0, 1.0].into()); } @@ -208,9 +209,9 @@ where let y = ctx.y(); let mul = ctx.mul(x, y).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [0.0, 1.0].into() @@ -249,16 +250,16 @@ where let mut ctx = Context::new(); let x = ctx.x(); let mul = ctx.mul(x, 2.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 2.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [2.0, 4.0].into()); let mul = ctx.mul(x, -3.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-3.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-6.0, -3.0].into()); } @@ -269,9 +270,9 @@ where let y = ctx.y(); let sub = ctx.sub(x, y).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [-1.0, 1.0].into() @@ -298,16 +299,16 @@ where let mut ctx = Context::new(); let x = ctx.x(); let sub = ctx.sub(x, 2.0).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-2.0, -1.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-1.0, 0.0].into()); let sub = ctx.sub(-3.0, x).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-4.0, -3.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-5.0, -4.0].into()); } @@ -316,9 +317,9 @@ where let mut ctx = Context::new(); let x = ctx.x(); let recip = ctx.recip(x).unwrap(); - let shape = S::new(&ctx, recip).unwrap(); + let shape = Shape::::new(&mut ctx, recip).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let nanan = eval.eval_x(&tape, [0.0, 1.0]); assert!(nanan.lower().is_nan()); @@ -341,9 +342,9 @@ where let x = ctx.x(); let y = ctx.y(); let div = ctx.div(x, y).unwrap(); - let shape = S::new(&ctx, div).unwrap(); + let shape = Shape::::new(&mut ctx, div).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let nanan = eval.eval_xy(&tape, [0.0, 1.0], [-1.0, 1.0]); assert!(nanan.lower().is_nan()); @@ -388,9 +389,9 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.5, 1.5], [0.0; 2]).unwrap(); assert_eq!(r, [0.0, 1.0].into()); @@ -426,9 +427,9 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(r, [0.0, 1.0].into()); @@ -451,9 +452,9 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.5, 1.5], [0.0; 2]).unwrap(); assert_eq!(r, [0.5, 1.5].into()); @@ -485,9 +486,9 @@ where let z = ctx.z(); let max_xy_z = ctx.max(max, z).unwrap(); - let shape = S::new(&ctx, max_xy_z).unwrap(); + let shape = Shape::::new(&mut ctx, max_xy_z).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval .eval(&tape, [2.0, 3.0], [0.0, 1.0], [4.0, 5.0]) .unwrap(); @@ -507,18 +508,15 @@ where assert_eq!(data.unwrap().as_ref(), &[Choice::Left, Choice::Left]); } - pub fn test_i_and() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_i_and() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, trace) = eval .eval(&tape, [0.0, 0.0], [-1.0, 3.0], [0.0, 0.0]) .unwrap(); @@ -544,18 +542,15 @@ where assert!(trace.is_none()); // can't simplify } - pub fn test_i_or() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_i_or() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, trace) = eval .eval(&tape, [0.0, 0.0], [-1.0, 3.0], [0.0, 0.0]) .unwrap(); @@ -586,9 +581,9 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval.eval(&tape, [0.0, 2.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(out, [0.0, 1.0].into()); @@ -605,9 +600,9 @@ where assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval.eval(&tape, [0.0, 2.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(out, [1.0, 2.0].into()); @@ -631,10 +626,10 @@ where let z = ctx.z(); let if_else = ctx.if_nonzero_else(x, y, z).unwrap(); - let shape = S::new(&ctx, if_else).unwrap(); + let shape = Shape::::new(&mut ctx, if_else).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval .eval(&tape, [-1.0, 2.0], [1.0, 2.0], [3.0, 4.0]) .unwrap(); @@ -680,9 +675,9 @@ where let x = ctx.x(); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval .eval(&tape, [0.0, 2.0], [0.0, 0.0], [0.0, 0.0]) .unwrap(); @@ -708,15 +703,15 @@ where let y = ctx.y(); let c = ctx.compare(x, y).unwrap(); - let shape = S::new(&ctx, c).unwrap(); + let shape = Shape::::new(&mut ctx, c).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, _trace) = eval.eval(&tape, -5.0, -6.0, 0.0).unwrap(); assert_eq!(out, Interval::from(1f32)); } pub fn test_i_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -728,8 +723,8 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_interval_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_interval_eval(); let tape = shape.ez_interval_tape(); let mut out = vec![]; @@ -741,7 +736,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_interval_eval(); let tape = shape.ez_interval_tape(); @@ -780,11 +775,11 @@ where let mut ctx = Context::new(); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); for &a in args.iter() { @@ -876,7 +871,7 @@ where let name = format!("{}(reg, reg)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in args.iter() { for &rhs in args.iter() { for (i, &u) in xyz.iter().enumerate() { @@ -890,7 +885,7 @@ where continue; } - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -933,14 +928,14 @@ where let name = format!("{}(reg, imm)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in args.iter() { for &rhs in values.iter() { for (i, &u) in xyz.iter().enumerate() { let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -975,14 +970,14 @@ where let name = format!("{}(imm, reg)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in values.iter() { for &rhs in args.iter() { for (i, &u) in xyz.iter().enumerate() { let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index 737036c1..764d8ca6 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -5,24 +5,25 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{EzShape, MathShape, Shape, TracingEvaluator}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, vm::Choice, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestPoint(std::marker::PhantomData<*const S>); -impl TestPoint +pub struct TestPoint(std::marker::PhantomData<*const F>); +impl TestPoint where - S: Shape + MathShape, - ::Trace: AsRef<[Choice]>, - ::Trace: From>, + F: Function + MathFunction, + ::Trace: AsRef<[Choice]>, + ::Trace: From>, { pub fn test_constant() { let mut ctx = Context::new(); let p = ctx.constant(1.5); - let shape = S::new(&ctx, p).unwrap(); + let shape = Shape::::new(&mut ctx, p).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, 1.5); } @@ -31,9 +32,9 @@ where let a = ctx.constant(1.5); let x = ctx.x(); let min = ctx.min(a, x).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 2.0, 0.0, 0.0).unwrap(); assert_eq!(r, 1.5); @@ -54,25 +55,22 @@ where let radius = ctx.add(x_squared, y_squared).unwrap(); let circle = ctx.sub(radius, 1.0).unwrap(); - let shape = S::new(&ctx, circle).unwrap(); + let shape = Shape::::new(&mut ctx, circle).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, -1.0); assert_eq!(eval.eval(&tape, 1.0, 0.0, 0.0).unwrap().0, 0.0); } - pub fn test_p_min() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_min() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert!(trace.is_none()); @@ -94,18 +92,15 @@ where assert!(trace.is_none()); } - pub fn test_p_max() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_max() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); @@ -128,18 +123,15 @@ where assert!(trace.is_none()); } - pub fn test_p_and() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_and() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); @@ -165,18 +157,15 @@ where assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); } - pub fn test_p_or() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_or() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); @@ -202,17 +191,14 @@ where assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); } - pub fn test_p_sin() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_sin() { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); for x in [0.0, 1.0, 2.0] { let (r, trace) = eval.eval(&tape, x, 0.0, 0.0).unwrap(); @@ -230,7 +216,7 @@ where let y = ctx.y(); let s = ctx.add(s, y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); for (x, y) in [(0.0, 1.0), (1.0, 3.0), (2.0, 8.0)] { @@ -246,26 +232,23 @@ where let y = ctx.y(); let sum = ctx.add(x, 1.0).unwrap(); let min = ctx.min(sum, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 1.0, 3.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 3.0, 3.5, 0.0).unwrap().0, 3.5); } - pub fn test_push() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_push() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); assert_eq!(eval.eval(&tape, 3.0, 2.0, 0.0).unwrap().0, 2.0); @@ -280,9 +263,9 @@ where assert_eq!(eval.eval(&tape, 3.0, 2.0, 0.0).unwrap().0, 2.0); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.5, 0.0, 0.0).unwrap().0, 0.5); assert_eq!(eval.eval(&tape, 3.0, 0.0, 0.0).unwrap().0, 1.0); @@ -302,29 +285,29 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); assert_eq!(eval.eval(&tape, 3.0, 4.0, 0.0).unwrap().0, 3.0); - let shape = S::new(&ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 3.0, 4.0, 0.0).unwrap().0, 4.0); let y2 = ctx.mul(y, 2.5).unwrap(); let sum = ctx.add(x, y2).unwrap(); - let shape = S::new(&ctx, sum).unwrap(); + let shape = Shape::::new(&mut ctx, sum).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 6.0); } pub fn test_p_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -332,8 +315,8 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let mut out = vec![]; @@ -360,7 +343,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_point_eval(); let tape = shape.ez_point_tape(); @@ -388,8 +371,8 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); for &a in args.iter() { @@ -444,8 +427,8 @@ where for (j, &v) in xyz.iter().enumerate() { let node = C::build(&mut ctx, u, v); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match (i, j) { @@ -489,8 +472,8 @@ where let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match i { @@ -526,8 +509,8 @@ where let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match i { diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index a18116e9..c3c0a873 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -3,7 +3,7 @@ //! Tracing evaluators are run on a single data type and capture a trace of //! execution, which is the [`Trace` associated type](TracingEvaluator::Trace). //! -//! The resulting trace can be used to simplify the original shape. +//! The resulting trace can be used to simplify the original function. //! //! It is unlikely that you'll want to use these traits or types directly; //! they're implementation details to minimize code duplication. @@ -12,8 +12,9 @@ use crate::{eval::Tape, Error}; /// Evaluator for single values which simultaneously captures an execution trace /// -/// The trace can later be used to simplify the [`Shape`](crate::eval::Shape) -/// using [`Shape::simplify`](crate::eval::Shape::simplify). +/// The trace can later be used to simplify the +/// [`Function`](crate::eval::Function) +/// using [`Function::simplify`](crate::eval::Function::simplify). pub trait TracingEvaluator: Default { /// Data type used during evaluation type Data: From + Copy + Clone; @@ -33,12 +34,10 @@ pub trait TracingEvaluator: Default { type Trace; /// Evaluates the given tape at a particular position - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; /// Build a new empty evaluator @@ -47,33 +46,17 @@ pub trait TracingEvaluator: Default { } /// Helper function to return an error if the inputs are invalid - fn check_arguments(&self, var_count: usize) -> Result<(), Error> { - if var_count > 3 { - Err(Error::BadVarSlice(3, var_count)) + fn check_arguments( + &self, + vars: &[Self::Data], + var_count: usize, + ) -> Result<(), Error> { + // It's fine if the caller has given us extra variables (e.g. due to + // tape simplification), but it must have given us enough. + if vars.len() < var_count { + Err(Error::BadVarSlice(vars.len(), var_count)) } else { Ok(()) } } - - #[cfg(test)] - fn eval_x>( - &mut self, - tape: &Self::Tape, - x: J, - ) -> Self::Data { - self.eval(tape, x.into(), Self::Data::from(0.0), Self::Data::from(0.0)) - .unwrap() - .0 - } - #[cfg(test)] - fn eval_xy>( - &mut self, - tape: &Self::Tape, - x: J, - y: J, - ) -> Self::Data { - self.eval(tape, x.into(), y.into(), Self::Data::from(0.0)) - .unwrap() - .0 - } } diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 1f1d8f57..37ed2934 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -3,7 +3,7 @@ //! ``` //! use fidget::{ //! context::Context, -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape //! }; //! let mut ctx = Context::new(); @@ -14,7 +14,7 @@ //! let radius = ctx.add(x_squared, y_squared)?; //! let circle = ctx.sub(radius, 1.0)?; //! -//! let shape = VmShape::new(&ctx, circle)?; +//! let shape = VmShape::new(&mut ctx, circle)?; //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! @@ -109,13 +109,8 @@ mod test { let v = ctx.add(x, y).unwrap(); assert_eq!( - ctx.eval( - v, - &[("X".to_string(), 1.0), ("Y".to_string(), 2.0)] - .into_iter() - .collect() - ) - .unwrap(), + ctx.eval(v, &[(Var::X, 1.0), (Var::Y, 2.0)].into_iter().collect()) + .unwrap(), 3.0 ); assert_eq!(ctx.eval_xyz(v, 2.0, 3.0, 0.0).unwrap(), 5.0); diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index eb602598..6240171a 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -1,7 +1,276 @@ -//! Shape-specific data types +//! Data structures for shape evaluation +//! +//! Types in this module are typically thin (generic) wrappers around objects +//! that implement traits in [`fidget::eval`](crate::eval). The wraper types +//! are specialized to operate on `x, y, z` arguments, rather than taking +//! arbitrary numbers of variables. +//! +//! For example, a [`Shape`] is a wrapper which makes it easier to treat a +//! [`Function`] as an implicit surface (with X, Y, Z axes and an optional +//! transform matrix). +//! +//! ```rust +//! use fidget::vm::VmShape; +//! use fidget::context::Context; +//! use fidget::shape::EzShape; +//! +//! let mut ctx = Context::new(); +//! let x = ctx.x(); +//! let shape = VmShape::new(&mut ctx, x)?; +//! +//! // Let's build a single point evaluator: +//! let mut eval = VmShape::new_point_eval(); +//! let tape = shape.ez_point_tape(); +//! let (value, _trace) = eval.eval(&tape, 0.25, 0.0, 0.0)?; +//! assert_eq!(value, 0.25); +//! # Ok::<(), fidget::Error>(()) +//! ``` + +use crate::{ + context::{Context, Node, Tree}, + eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, + types::{Grad, Interval}, + Error, +}; +use nalgebra::{Matrix4, Point3}; + mod bounds; pub use bounds::Bounds; +/// A shape represents an implicit surface +/// +/// It is mostly agnostic to _how_ that surface is represented, wrapping a +/// [`Function`](Function) and a set of axes. +/// +/// Shapes are shared between threads, so they should be cheap to clone. In +/// most cases, they're a thin wrapper around an `Arc<..>`. +#[derive(Clone)] +pub struct Shape { + /// Wrapped function + f: F, + + /// Index of x, y, z axes within the function's variable list (if present) + axes: [Option; 3], + + /// Optional transform to apply to the shape + transform: Option>, +} + +impl Shape { + /// Builds a new point evaluator + pub fn new_point_eval() -> ShapeTracingEval { + ShapeTracingEval { + eval: F::PointEval::default(), + } + } + + /// Builds a new interval evaluator + pub fn new_interval_eval() -> ShapeTracingEval { + ShapeTracingEval { + eval: F::IntervalEval::default(), + } + } + + /// Builds a new float slice evaluator + pub fn new_float_slice_eval() -> ShapeBulkEval { + ShapeBulkEval { + eval: F::FloatSliceEval::default(), + xs: vec![], + ys: vec![], + zs: vec![], + } + } + + /// Builds a new gradient slice evaluator + pub fn new_grad_slice_eval() -> ShapeBulkEval { + ShapeBulkEval { + eval: F::GradSliceEval::default(), + xs: vec![], + ys: vec![], + zs: vec![], + } + } + + /// Returns an evaluation tape for a point evaluator + pub fn point_tape( + &self, + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.point_tape(storage), + axes: self.axes, + transform: self.transform, + } + } + + /// Returns an evaluation tape for a interval evaluator + pub fn interval_tape( + &self, + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.interval_tape(storage), + axes: self.axes, + transform: self.transform, + } + } + + /// Returns an evaluation tape for a float slice evaluator + pub fn float_slice_tape( + &self, + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.float_slice_tape(storage), + axes: self.axes, + transform: self.transform, + } + } + + /// Returns an evaluation tape for a gradient slice evaluator + pub fn grad_slice_tape( + &self, + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.grad_slice_tape(storage), + axes: self.axes, + transform: self.transform, + } + } + + /// Computes a simplified tape using the given trace, and reusing storage + pub fn simplify( + &self, + trace: &F::Trace, + storage: F::Storage, + workspace: &mut F::Workspace, + ) -> Result + where + Self: Sized, + { + let f = self.f.simplify(trace, storage, workspace)?; + Ok(Self { + f, + axes: self.axes, + transform: self.transform, + }) + } + + /// Attempt to reclaim storage from this shape + /// + /// This may fail, because shapes are `Clone` and are often implemented + /// using an `Arc` around a heavier data structure. + pub fn recycle(self) -> Option { + self.f.recycle() + } + + /// Returns a size associated with this shape + /// + /// This is underspecified and only used for unit testing; for tape-based + /// shapes, it's typically the length of the tape, + pub fn size(&self) -> usize { + self.f.size() + } +} + +impl Shape { + /// Borrows the inner [`Function`](Function) object + pub fn inner(&self) -> &F { + &self.f + } + + /// Borrows the inner axis mapping + pub fn axes(&self) -> &[Option; 3] { + &self.axes + } + + /// Raw constructor + pub fn new_raw(f: F, axes: [Option; 3]) -> Self { + Self { + f, + axes, + transform: None, + } + } + /// Returns a shape with the given transform applied + pub fn apply_transform(mut self, mat: Matrix4) -> Self { + if let Some(prev) = self.transform.as_mut() { + *prev *= mat; + } else { + self.transform = Some(mat); + } + self + } +} + +/// Extension trait for working with a shape without thinking much about memory +/// +/// All of the [`Shape`] functions that use significant amounts of memory +/// pedantically require you to pass in storage for reuse. This trait allows +/// you to ignore that, at the cost of performance; we require that all storage +/// types implement [`Default`], so these functions do the boilerplate for you. +/// +/// This trait is automatically implemented for every [`Shape`], but must be +/// imported separately as a speed-bump to using it everywhere. +pub trait EzShape { + /// Returns an evaluation tape for a point evaluator + fn ez_point_tape( + &self, + ) -> ShapeTape<::Tape>; + + /// Returns an evaluation tape for an interval evaluator + fn ez_interval_tape( + &self, + ) -> ShapeTape<::Tape>; + + /// Returns an evaluation tape for a float slice evaluator + fn ez_float_slice_tape( + &self, + ) -> ShapeTape<::Tape>; + + /// Returns an evaluation tape for a float slice evaluator + fn ez_grad_slice_tape( + &self, + ) -> ShapeTape<::Tape>; + + /// Computes a simplified tape using the given trace + fn ez_simplify(&self, trace: &F::Trace) -> Result + where + Self: Sized; +} + +impl EzShape for Shape { + fn ez_point_tape( + &self, + ) -> ShapeTape<::Tape> { + self.point_tape(Default::default()) + } + + fn ez_interval_tape( + &self, + ) -> ShapeTape<::Tape> { + self.interval_tape(Default::default()) + } + + fn ez_float_slice_tape( + &self, + ) -> ShapeTape<::Tape> { + self.float_slice_tape(Default::default()) + } + + fn ez_grad_slice_tape( + &self, + ) -> ShapeTape<::Tape> { + self.grad_slice_tape(Default::default()) + } + + fn ez_simplify(&self, trace: &F::Trace) -> Result { + let mut workspace = Default::default(); + self.simplify(trace, Default::default(), &mut workspace) + } +} + /// Hints for how to render this particular type pub trait RenderHints { /// Recommended tile sizes for 3D rendering @@ -20,3 +289,247 @@ pub trait RenderHints { true } } + +impl Shape { + /// Builds a new shape from a math expression with the given axes + pub fn new_with_axes( + ctx: &Context, + node: Node, + axes: [Node; 3], + ) -> Result { + let (f, vs) = F::new(ctx, node)?; + let x = ctx.var_name(axes[0])?.ok_or(Error::NotAVar)?; + let y = ctx.var_name(axes[1])?.ok_or(Error::NotAVar)?; + let z = ctx.var_name(axes[2])?.ok_or(Error::NotAVar)?; + Ok(Self { + f, + axes: [x, y, z].map(|v| vs.get(v).cloned()), + transform: None, + }) + } + + /// Builds a new shape from the given node with default (X, Y, Z) axes + pub fn new(ctx: &mut Context, node: Node) -> Result + where + Self: Sized, + { + let axes = ctx.axes(); + Self::new_with_axes(ctx, node, axes) + } +} + +/// Converts a [`Tree`] to a [`Shape`] with the default axes +impl From for Shape { + fn from(t: Tree) -> Self { + let mut ctx = Context::new(); + let node = ctx.import(&t); + Self::new(&mut ctx, node).unwrap() + } +} + +/// Wrapper around a function tape, with axes and an optional transform matrix +pub struct ShapeTape { + tape: T, + + /// Index of the X, Y, Z axes in the variables array + axes: [Option; 3], + + /// Optional transform + transform: Option>, +} + +impl ShapeTape { + /// Recycles the inner tape's storage for reuse + pub fn recycle(self) -> T::Storage { + self.tape.recycle() + } +} + +/// Wrapper around a [`TracingEvaluator`] +/// +/// Unlike the raw tracing evaluator, a [`ShapeTracingEval`] knows about the +/// tape's X, Y, Z axes and optional transform matrix. +#[derive(Debug, Default)] +pub struct ShapeTracingEval { + eval: E, +} + +impl ShapeTracingEval +where + ::Data: Transformable, +{ + /// Tracing evaluation of a single sample + /// + /// Before evaluation, the tape's transform matrix is applied (if present). + pub fn eval>( + &mut self, + tape: &ShapeTape, + x: F, + y: F, + z: F, + ) -> Result<(E::Data, Option<&E::Trace>), Error> { + let x = x.into(); + let y = y.into(); + let z = z.into(); + let (x, y, z) = if let Some(mat) = tape.transform { + Transformable::transform(x, y, z, mat) + } else { + (x, y, z) + }; + + let mut vars = [None, None, None]; + if let Some(a) = tape.axes[0] { + vars[a] = Some(x); + } + if let Some(b) = tape.axes[1] { + vars[b] = Some(y); + } + if let Some(c) = tape.axes[2] { + vars[c] = Some(z); + } + let n = vars.iter().position(Option::is_none).unwrap_or(3); + let vars = vars.map(|v| v.unwrap_or(0f32.into())); + self.eval.eval(&tape.tape, &vars[..n]) + } + + #[cfg(test)] + pub fn eval_x>( + &mut self, + tape: &ShapeTape, + x: J, + ) -> E::Data { + self.eval(tape, x.into(), E::Data::from(0.0), E::Data::from(0.0)) + .unwrap() + .0 + } + #[cfg(test)] + pub fn eval_xy>( + &mut self, + tape: &ShapeTape, + x: J, + y: J, + ) -> E::Data { + self.eval(tape, x.into(), y.into(), E::Data::from(0.0)) + .unwrap() + .0 + } +} + +/// Wrapper around a [`BulkEvaluator`] +/// +/// Unlike the raw bulk evaluator, a [`ShapeBulkEval`] knows about the +/// tape's X, Y, Z axes and optional transform matrix. +#[derive(Debug, Default)] +pub struct ShapeBulkEval { + eval: E, + xs: Vec, + ys: Vec, + zs: Vec, +} + +impl ShapeBulkEval +where + E::Data: From + Transformable, +{ + /// Bulk evaluation of many samples + /// + /// Before evaluation, the tape's transform matrix is applied (if present). + pub fn eval( + &mut self, + tape: &ShapeTape, + x: &[E::Data], + y: &[E::Data], + z: &[E::Data], + ) -> Result<&[E::Data], Error> { + let (xs, ys, zs) = if let Some(mat) = tape.transform { + if x.len() != y.len() || x.len() != z.len() { + return Err(Error::MismatchedSlices); + } + let n = x.len(); + self.xs.resize(n, 0.0.into()); + self.ys.resize(n, 0.0.into()); + self.zs.resize(n, 0.0.into()); + for i in 0..n { + let (x, y, z) = Transformable::transform(x[i], y[i], z[i], mat); + self.xs[i] = x; + self.ys[i] = y; + self.zs[i] = z; + } + (self.xs.as_slice(), self.ys.as_slice(), self.zs.as_slice()) + } else { + (x, y, z) + }; + let mut vars = [None, None, None]; + if let Some(a) = tape.axes[0] { + vars[a] = Some(xs); + } + if let Some(b) = tape.axes[1] { + vars[b] = Some(ys); + } + if let Some(c) = tape.axes[2] { + vars[c] = Some(zs); + } + let n = vars.iter().position(|v| v.is_none()).unwrap_or(3); + let vars = if vars.iter().all(Option::is_some) { + vars.map(Option::unwrap) + } else if let Some(q) = vars.iter().find(|v| v.is_some()) { + vars.map(|v| v.unwrap_or_else(|| q.unwrap())) + } else { + [[].as_slice(); 3] + }; + + self.eval.eval(&tape.tape, &vars[..n]) + } +} + +/// Trait for types that can be transformed by a 4x4 homogenous transform matrix +pub trait Transformable { + /// Apply the given transform to an `(x, y, z)` position + fn transform( + x: Self, + y: Self, + z: Self, + mat: Matrix4, + ) -> (Self, Self, Self) + where + Self: Sized; +} + +impl Transformable for f32 { + fn transform(x: f32, y: f32, z: f32, mat: Matrix4) -> (f32, f32, f32) { + let out = mat.transform_point(&Point3::new(x, y, z)); + (out.x, out.y, out.z) + } +} + +impl Transformable for Interval { + fn transform( + x: Interval, + y: Interval, + z: Interval, + mat: Matrix4, + ) -> (Interval, Interval, Interval) { + let out = [0, 1, 2, 3].map(|i| { + let row = mat.row(i); + x * row[0] + y * row[1] + z * row[2] + Interval::from(row[3]) + }); + + (out[0] / out[3], out[1] / out[3], out[2] / out[3]) + } +} + +impl Transformable for Grad { + fn transform( + x: Grad, + y: Grad, + z: Grad, + mat: Matrix4, + ) -> (Grad, Grad, Grad) { + let out = [0, 1, 2, 3].map(|i| { + let row = mat.row(i); + x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3]) + }); + + (out[0] / out[3], out[1] / out[3], out[2] / out[3]) + } +} diff --git a/fidget/src/core/eval/transform.rs b/fidget/src/core/shape/transform.rs similarity index 98% rename from fidget/src/core/eval/transform.rs rename to fidget/src/core/shape/transform.rs index d5684b61..c9b13087 100644 --- a/fidget/src/core/eval/transform.rs +++ b/fidget/src/core/shape/transform.rs @@ -1,6 +1,7 @@ use crate::{ - eval::{BulkEvaluator, Grad, Interval, Shape, Tape, TracingEvaluator}, - shape::RenderHints, + eval::Tape, + shape::{BulkEvaluator, RenderHints, Shape, TracingEvaluator}, + types::{Grad, Interval}, Error, }; use nalgebra::{Matrix4, Point3, Vector3}; diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 421b8aba..d60375c0 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -2,6 +2,7 @@ use crate::{ compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape}, context::{Context, Node}, + eval::VarMap, vm::Choice, Error, }; @@ -38,19 +39,19 @@ use serde::{Deserialize, Serialize}; /// ``` /// use fidget::{ /// compiler::RegOp, -/// context::{Context, Tree}, +/// context::{Context, Tree, Var}, /// vm::VmData, /// }; /// /// let tree = Tree::x() + Tree::y(); /// let mut ctx = Context::new(); /// let sum = ctx.import(&tree); -/// let data = VmData::<255>::new(&ctx, sum)?; +/// let (data, vars) = VmData::<255>::new(&ctx, sum)?; /// assert_eq!(data.len(), 3); // X, Y, and (X + Y) /// /// let mut iter = data.iter_asm(); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, 0)); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, 1)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u8)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u8)); /// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1)); /// # Ok::<(), fidget::Error>(()) /// ``` @@ -66,10 +67,10 @@ pub struct VmData { impl VmData { /// Builds a new tape for the given node - pub fn new(context: &Context, node: Node) -> Result { - let ssa = SsaTape::new(context, node)?; + pub fn new(context: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (ssa, vs) = SsaTape::new(context, node)?; let asm = RegTape::new::(&ssa); - Ok(Self { ssa, asm }) + Ok((Self { ssa, asm }, vs)) } /// Returns the length of the internal VM tape diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 78f54fb7..aee2363a 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -3,14 +3,13 @@ use crate::{ compiler::RegOp, context::Node, eval::{ - BulkEvaluator, MathShape, Shape, Tape, Trace, TracingEvaluator, - TransformedShape, + BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, + VarMap, }, - shape::RenderHints, + shape::{RenderHints, Shape}, types::{Grad, Interval}, Context, Error, }; -use nalgebra::Matrix4; use std::sync::Arc; mod choice; @@ -21,17 +20,19 @@ pub use data::{VmData, VmWorkspace}; //////////////////////////////////////////////////////////////////////////////// -/// Shape that use a VM backend for evaluation +/// Function which uses the VM backend for evaluation /// -/// Internally, the [`VmShape`] stores an [`Arc`](VmData), and +/// Internally, the [`VmFunction`] stores an [`Arc`](VmData), and /// iterates over a [`Vec`](RegOp) to perform evaluation. /// /// All of the associated [`Tape`] types simply clone the internal `Arc`; /// there's no separate planning required to generate a tape. -/// -pub type VmShape = GenericVmShape<{ u8::MAX as usize }>; +pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>; + +/// Shape that use a the [`VmFunction`] backend for evaluation +pub type VmShape = Shape; -impl Tape for GenericVmShape { +impl Tape for GenericVmFunction { type Storage = (); fn recycle(self) -> Self::Storage { // nothing to do here @@ -93,15 +94,15 @@ impl AsRef<[Choice]> for VmTrace { /// You are unlikely to use this directly; [`VmShape`] should be used for /// VM-based evaluation. #[derive(Clone)] -pub struct GenericVmShape(Arc>); +pub struct GenericVmFunction(Arc>); -impl From> for GenericVmShape { +impl From> for GenericVmFunction { fn from(d: VmData) -> Self { Self(d.into()) } } -impl GenericVmShape { +impl GenericVmFunction { pub(crate) fn simplify_inner( &self, choices: &[Choice], @@ -137,7 +138,7 @@ impl GenericVmShape { } } -impl Shape for GenericVmShape { +impl Function for GenericVmFunction { type FloatSliceEval = VmFloatSliceEval; type Storage = VmData; type Workspace = VmWorkspace; @@ -170,20 +171,15 @@ impl Shape for GenericVmShape { } fn recycle(self) -> Option { - GenericVmShape::recycle(self) + GenericVmFunction::recycle(self) } fn size(&self) -> usize { - GenericVmShape::size(self) - } - - type TransformedShape = TransformedShape; - fn apply_transform(self, mat: Matrix4) -> Self::TransformedShape { - TransformedShape::new(self, mat) + GenericVmFunction::size(self) } } -impl RenderHints for GenericVmShape { +impl RenderHints for GenericVmFunction { fn tile_sizes_3d() -> &'static [usize] { &[256, 128, 64, 32, 16, 8] } @@ -193,10 +189,10 @@ impl RenderHints for GenericVmShape { } } -impl MathShape for GenericVmShape { - fn new(ctx: &Context, node: Node) -> Result { - let d = VmData::new(ctx, node)?; - Ok(Self(Arc::new(d))) +impl MathFunction for GenericVmFunction { + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (d, vs) = VmData::new(ctx, node)?; + Ok((Self(Arc::new(d)), vs)) } } @@ -257,22 +253,17 @@ impl + Clone> TracingVmEval { pub struct VmIntervalEval(TracingVmEval); impl TracingEvaluator for VmIntervalEval { type Data = Interval; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type Trace = VmTrace; type TapeStorage = (); - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[Interval], ) -> Result<(Interval, Option<&VmTrace>), Error> { - let x = x.into(); - let y = y.into(); - let z = z.into(); let tape = tape.0.as_ref(); - self.check_arguments(tape.var_count())?; + self.check_arguments(vars, tape.var_count())?; self.0.resize_slots(tape); let mut simplify = false; @@ -281,12 +272,7 @@ impl TracingEvaluator for VmIntervalEval { for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out] = match i { - 0 => x, - 1 => y, - 2 => z, - _ => panic!("Invalid input: {}", i), - } + v[out] = vars[i as usize]; } RegOp::NegReg(out, arg) => { v[out] = -v[arg]; @@ -496,22 +482,17 @@ impl TracingEvaluator for VmIntervalEval { pub struct VmPointEval(TracingVmEval); impl TracingEvaluator for VmPointEval { type Data = f32; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type Trace = VmTrace; type TapeStorage = (); - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[f32], ) -> Result<(f32, Option<&VmTrace>), Error> { - let x = x.into(); - let y = y.into(); - let z = z.into(); let tape = tape.0.as_ref(); - self.check_arguments(tape.var_count())?; + self.check_arguments(vars, tape.var_count())?; self.0.resize_slots(tape); let mut choices = self.0.choices.as_mut_slice().iter_mut(); @@ -520,12 +501,7 @@ impl TracingEvaluator for VmPointEval { for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out] = match i { - 0 => x, - 1 => y, - 2 => z, - _ => panic!("Invalid input: {}", i), - } + v[out] = vars[i as usize]; } RegOp::NegReg(out, arg) => { v[out] = -v[arg]; @@ -593,7 +569,7 @@ impl TracingEvaluator for VmPointEval { v[out] = imm / v[arg]; } RegOp::AtanRegImm(out, arg, imm) => { - v[out] = v[arg].atan2(imm.into()); + v[out] = v[arg].atan2(imm); } RegOp::AtanImmReg(out, arg, imm) => { v[out] = imm.atan2(v[arg]); @@ -821,34 +797,25 @@ impl + Clone> BulkVmEval { pub struct VmFloatSliceEval(BulkVmEval); impl BulkEvaluator for VmFloatSliceEval { type Data = f32; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type TapeStorage = (); fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + vars: &[&[f32]], ) -> Result<&[f32], Error> { let tape = tape.0.as_ref(); - self.check_arguments(xs, ys, zs, tape.var_count())?; - self.0.resize_slots(tape, xs.len()); - assert_eq!(xs.len(), ys.len()); - assert_eq!(ys.len(), zs.len()); + self.check_arguments(vars, tape.var_count())?; - let size = xs.len(); + let size = vars.first().map(|v| v.len()).unwrap_or(0); + self.0.resize_slots(tape, size); let mut v = SlotArray(&mut self.0.slots); for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out][0..size].copy_from_slice(match i { - 0 => xs, - 1 => ys, - 2 => zs, - _ => panic!("Invalid input: {}", i), - }) + v[out][0..size].copy_from_slice(vars[i as usize]); } RegOp::NegReg(out, arg) => { for i in 0..size { @@ -1139,35 +1106,24 @@ impl BulkEvaluator for VmFloatSliceEval { pub struct VmGradSliceEval(BulkVmEval); impl BulkEvaluator for VmGradSliceEval { type Data = Grad; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type TapeStorage = (); fn eval( &mut self, tape: &Self::Tape, - xs: &[Grad], - ys: &[Grad], - zs: &[Grad], + vars: &[&[Grad]], ) -> Result<&[Grad], Error> { let tape = tape.0.as_ref(); - self.check_arguments(xs, ys, zs, tape.var_count())?; - self.0.resize_slots(tape, xs.len()); - assert_eq!(xs.len(), ys.len()); - assert_eq!(ys.len(), zs.len()); + self.check_arguments(vars, tape.var_count())?; + let size = vars.first().map(|v| v.len()).unwrap_or(0); + self.0.resize_slots(tape, size); - let size = xs.len(); let mut v = SlotArray(&mut self.0.slots); for op in tape.iter_asm() { match op { - RegOp::Input(out, j) => { - for i in 0..size { - v[out][i] = match j { - 0 => xs[i], - 1 => ys[i], - 2 => zs[i], - _ => panic!("Invalid input: {}", i), - } - } + RegOp::Input(out, i) => { + v[out][0..size].copy_from_slice(vars[i as usize]); } RegOp::NegReg(out, arg) => { for i in 0..size { @@ -1476,8 +1432,8 @@ impl BulkEvaluator for VmGradSliceEval { #[cfg(test)] mod test { use super::*; - crate::grad_slice_tests!(VmShape); - crate::interval_tests!(VmShape); - crate::float_slice_tests!(VmShape); - crate::point_tests!(VmShape); + crate::grad_slice_tests!(VmFunction); + crate::interval_tests!(VmFunction); + crate::float_slice_tests!(VmFunction); + crate::point_tests!(VmFunction); } diff --git a/fidget/src/error.rs b/fidget/src/error.rs index 7603000d..953daa96 100644 --- a/fidget/src/error.rs +++ b/fidget/src/error.rs @@ -11,6 +11,10 @@ pub enum Error { #[error("variable is not present in this `Context`")] BadVar, + /// The given node does not have an associated variable + #[error("node does not have an associated variable")] + NotAVar, + /// `Context` is empty #[error("`Context` is empty")] EmptyContext, diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index c183b730..f619aa86 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -1,17 +1,17 @@ //! Compilation down to native machine code //! -//! Users are unlikely to use anything in this module other than [`JitShape`], -//! which is a [`Shape`] that uses JIT evaluation. +//! Users are unlikely to use anything in this module other than [`JitFunction`], +//! which is a [`Function`] that uses JIT evaluation. //! //! ``` //! use fidget::{ //! context::Tree, -//! eval::{TracingEvaluator, Shape, MathShape, EzShape}, -//! jit::JitShape +//! shape::EzShape, +//! jit::JitShape, //! }; //! //! let tree = Tree::x() + Tree::y(); -//! let shape = JitShape::from_tree(&tree); +//! let shape = JitShape::from(tree); //! //! // Generate machine code to execute the tape //! let tape = shape.ez_point_tape(); @@ -27,20 +27,18 @@ use crate::{ compiler::RegOp, context::{Context, Node}, eval::{ - BulkEvaluator, MathShape, Shape, Tape, TracingEvaluator, - TransformedShape, + BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator, VarMap, }, jit::mmap::Mmap, shape::RenderHints, types::{Grad, Interval}, - vm::{Choice, GenericVmShape, VmData, VmTrace, VmWorkspace}, + vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace}, Error, }; use dynasmrt::{ components::PatchLoc, dynasm, AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi, TargetKind, }; -use nalgebra::Matrix4; mod mmap; @@ -827,11 +825,11 @@ fn build_asm_fn_with_storage( // JIT execute mode is restored here when the _guard is dropped } -/// Shape for use with a JIT evaluator +/// Function for use with a JIT evaluator #[derive(Clone)] -pub struct JitShape(GenericVmShape); +pub struct JitFunction(GenericVmFunction); -impl JitShape { +impl JitFunction { fn tracing_tape( &self, storage: Mmap, @@ -856,7 +854,7 @@ impl JitShape { } } -impl Shape for JitShape { +impl Function for JitFunction { type Trace = VmTrace; type Storage = VmData; type Workspace = VmWorkspace; @@ -892,7 +890,7 @@ impl Shape for JitShape { ) -> Result { self.0 .simplify_inner(trace.as_slice(), storage, workspace) - .map(JitShape) + .map(JitFunction) } fn recycle(self) -> Option { @@ -902,14 +900,9 @@ impl Shape for JitShape { fn size(&self) -> usize { self.0.size() } - - type TransformedShape = TransformedShape; - fn apply_transform(self, mat: Matrix4) -> Self::TransformedShape { - TransformedShape::new(self, mat) - } } -impl RenderHints for JitShape { +impl RenderHints for JitFunction { fn tile_sizes_3d() -> &'static [usize] { &[64, 16, 8] } @@ -954,7 +947,7 @@ macro_rules! jit_fn { /// Evaluator for a JIT-compiled tracing function /// /// Users are unlikely to use this directly, but it's public because it's an -/// associated type on [`JitShape`]. +/// associated type on [`JitFunction`]. #[derive(Default)] struct JitTracingEval { choices: VmTrace, @@ -989,21 +982,15 @@ unsafe impl Sync for JitTracingFn {} impl JitTracingEval { /// Evaluates a single point, capturing an evaluation trace - fn eval, F: Into>( + fn eval( &mut self, tape: &JitTracingFn, - x: F, - y: F, - z: F, + vars: &[T], ) -> (T, Option<&VmTrace>) { - let x = x.into(); - let y = y.into(); - let z = z.into(); let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); assert!(tape.var_count <= 3); self.choices.fill(Choice::Unknown); - let vars = [x, y, z]; let out = unsafe { (tape.fn_trace)( vars.as_ptr(), @@ -1031,14 +1018,12 @@ impl TracingEvaluator for JitIntervalEval { type Trace = VmTrace; type TapeStorage = Mmap; - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { - Ok(self.0.eval(tape, x, y, z)) + Ok(self.0.eval(tape, vars)) } } @@ -1051,14 +1036,12 @@ impl TracingEvaluator for JitPointEval { type Trace = VmTrace; type TapeStorage = Mmap; - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { - Ok(self.0.eval(tape, x, y, z)) + Ok(self.0.eval(tape, vars)) } } @@ -1085,15 +1068,38 @@ impl Tape for JitBulkFn { } } +/// Maximum SIMD width for any type, checked at runtime (alas) +/// +/// We can't use T::SIMD_SIZE directly here due to Rust limitations. Instead we +/// hard-code a maximum SIMD size along with an assertion that should be +/// optimized out; we can't use a constant assertion here due to the same +/// compiler limitations. +const MAX_SIMD_WIDTH: usize = 8; + /// Bulk evaluator for JIT functions struct JitBulkEval { + /// Array of pointers used when calling into the JIT function + ptrs: Vec<*const T>, + + /// Scratch array for evaluation of less-than-SIMD-size slices + scratch: Vec<[T; MAX_SIMD_WIDTH]>, + /// Output array that's written to during evaluation out: Vec, } +// SAFETY: the pointers in `JitBulkEval` are transient and only scoped to a +// single evaluation. +unsafe impl Sync for JitBulkEval {} +unsafe impl Send for JitBulkEval {} + impl Default for JitBulkEval { fn default() -> Self { - Self { out: vec![] } + Self { + out: vec![], + scratch: vec![], + ptrs: vec![], + } } } @@ -1104,15 +1110,9 @@ unsafe impl Sync for JitBulkFn {} impl + Copy + SimdSize> JitBulkEval { /// Evaluate multiple points - fn eval( - &mut self, - tape: &JitBulkFn, - xs: &[T], - ys: &[T], - zs: &[T], - ) -> &[T] { + fn eval(&mut self, tape: &JitBulkFn, vars: &[&[T]]) -> &[T] { assert!(tape.var_count <= 3); - let n = xs.len(); + let n = vars.first().map(|v| v.len()).unwrap_or(0); self.out.resize(n, f32::NAN.into()); self.out.fill(f32::NAN.into()); @@ -1120,51 +1120,50 @@ impl + Copy + SimdSize> JitBulkEval { // in which case the input slices can't be used as workspace (because // they are not valid for the entire range of values read in assembly) if n < T::SIMD_SIZE { - // We can't use T::SIMD_SIZE directly here due to Rust limitations. - // Instead we hard-code a maximum SIMD size along with an assertion - // that should be optimized out; we can't use a constant assertion - // here due to the same compiler limitations. - const MAX_SIMD_WIDTH: usize = 8; - let mut x = [T::from(0.0); MAX_SIMD_WIDTH]; - let mut y = [T::from(0.0); MAX_SIMD_WIDTH]; - let mut z = [T::from(0.0); MAX_SIMD_WIDTH]; assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH); - x[0..n].copy_from_slice(xs); - y[0..n].copy_from_slice(ys); - z[0..n].copy_from_slice(zs); + self.scratch.resize(n, [T::from(0.0); MAX_SIMD_WIDTH]); + for (v, t) in vars.iter().zip(self.scratch.iter_mut()) { + t[0..n].copy_from_slice(v); + } + + self.ptrs.clear(); + self.ptrs.extend(self.scratch.iter().map(|t| t.as_ptr())); - let mut tmp = [f32::NAN.into(); MAX_SIMD_WIDTH]; - let vars = [x.as_ptr(), y.as_ptr(), z.as_ptr()]; + let mut out = [f32::NAN.into(); MAX_SIMD_WIDTH]; unsafe { (tape.fn_bulk)( - vars.as_ptr(), - tmp.as_mut_ptr(), + self.ptrs.as_ptr(), + out.as_mut_ptr(), T::SIMD_SIZE as u64, ); } - self.out.copy_from_slice(&tmp[0..n]); + self.out.copy_from_slice(&out[0..n]); } else { // Our vectorized function only accepts sets of a particular width, // so we'll find the biggest multiple, then do an extra operation to // process any remainders. let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE; // Round down - let vars = [xs.as_ptr(), ys.as_ptr(), zs.as_ptr()]; + self.ptrs.clear(); + self.ptrs.extend(vars.iter().map(|v| v.as_ptr())); unsafe { - (tape.fn_bulk)(vars.as_ptr(), self.out.as_mut_ptr(), m as u64); + (tape.fn_bulk)( + self.ptrs.as_ptr(), + self.out.as_mut_ptr(), + m as u64, + ); } // If we weren't given an even multiple of vector width, then we'll // handle the remaining items by simply evaluating the *last* full // vector in the array again. if n != m { + self.ptrs.clear(); unsafe { - let vars = [ - xs.as_ptr().add(n - T::SIMD_SIZE), - ys.as_ptr().add(n - T::SIMD_SIZE), - zs.as_ptr().add(n - T::SIMD_SIZE), - ]; + self.ptrs.extend( + vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)), + ); (tape.fn_bulk)( - vars.as_ptr(), + self.ptrs.as_ptr(), self.out.as_mut_ptr().add(n - T::SIMD_SIZE), T::SIMD_SIZE as u64, ); @@ -1186,12 +1185,10 @@ impl BulkEvaluator for JitFloatSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + vars: &[&[Self::Data]], ) -> Result<&[Self::Data], Error> { - self.check_arguments(xs, ys, zs, tape.var_count)?; - Ok(self.0.eval(tape, xs, ys, zs)) + self.check_arguments(vars, tape.var_count)?; + Ok(self.0.eval(tape, vars)) } } @@ -1206,28 +1203,30 @@ impl BulkEvaluator for JitGradSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[Self::Data], - ys: &[Self::Data], - zs: &[Self::Data], + vars: &[&[Self::Data]], ) -> Result<&[Self::Data], Error> { - self.check_arguments(xs, ys, zs, tape.var_count)?; - Ok(self.0.eval(tape, xs, ys, zs)) + self.check_arguments(vars, tape.var_count)?; + Ok(self.0.eval(tape, vars)) } } -impl MathShape for JitShape { - fn new(ctx: &Context, node: Node) -> Result { - GenericVmShape::new(ctx, node).map(JitShape) +impl MathFunction for JitFunction { + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (f, vars) = GenericVmFunction::new(ctx, node)?; + Ok((JitFunction(f), vars)) } } +/// A [`Shape`](crate::shape::Shape) which uses the JIT evaluator +pub type JitShape = crate::shape::Shape; + //////////////////////////////////////////////////////////////////////////////// #[cfg(test)] mod test { use super::*; - crate::grad_slice_tests!(JitShape); - crate::interval_tests!(JitShape); - crate::float_slice_tests!(JitShape); - crate::point_tests!(JitShape); + crate::grad_slice_tests!(JitFunction); + crate::interval_tests!(JitFunction); + crate::float_slice_tests!(JitFunction); + crate::point_tests!(JitFunction); } diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index 1ecba9ca..4c19bd96 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -68,7 +68,7 @@ //! //! Evaluation is deliberately agnostic to the specific details of how we go //! from position to results. This abstraction is represented by the -//! [`Shape` trait](crate::eval::Shape), which defines how to make both +//! [`Function` trait](crate::eval::Function), which defines how to make both //! **evaluators** and **tapes**. //! //! An **evaluator** is an object which performs evaluation of some kind (point, @@ -77,22 +77,22 @@ //! //! A **tape** contains instructions for an evaluator. //! -//! At the moment, Fidget implements two kinds of shapes: +//! At the moment, Fidget implements two kinds of functions: //! -//! - [`fidget::vm::VmShape`](crate::vm::VmShape) evaluates a list of opcodes -//! using an interpreter. This is slower, but can run in more situations -//! (e.g. in WebAssembly). -//! - [`fidget::jit::JitShape`](crate::jit::JitShape) performs fast evaluation -//! by compiling shapes down to native code. +//! - [`fidget::vm::VmFunction`](crate::vm::VmFunction) evaluates a list of +//! opcodes using an interpreter. This is slower, but can run in more +//! situations (e.g. in WebAssembly). +//! - [`fidget::jit::JitFunction`](crate::jit::JitFunction) performs fast +//! evaluation by compiling expressions down to native code. //! -//! The [`eval::Shape`](crate::eval::Shape) trait requires four different kinds +//! The [`Function`](crate::eval::Function) trait requires four different kinds //! of evaluation: //! //! - Single-point evaluation //! - Interval evaluation //! - Evaluation on an array of points, returning `f32` values //! - Evaluation on an array of points, returning partial derivatives with -//! respect to `x, y, z` +//! respect to input variables //! //! These evaluation flavors are used in rendering: //! - Interval evaluation can conservatively prove large regions of space to be @@ -103,16 +103,24 @@ //! - At the surface of the model, partial derivatives represent normals and //! can be used for shading. //! -//! Here's a simple example of interval evaluation: +//! # Functions and shapes +//! The [`Function`](crate::eval::Function) trait supports arbitrary numbers of +//! varibles; when using it for implicit surfaces, it's common to wrap it in a +//! [`Shape`](crate::shape::Shape), which binds `(x, y, z)` axes to specific +//! variables. +//! +//! Here's a simple example of interval evaluation, using a `Shape` to wrap a +//! function and evaluate it at a particular `(x, y, z)` position: +//! //! ``` //! use fidget::{ //! context::Tree, -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, EzShape}, //! vm::VmShape //! }; //! //! let tree = Tree::x() + Tree::y(); -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut interval_eval = VmShape::new_interval_eval(); //! let tape = shape.ez_interval_tape(); //! let (out, _trace) = interval_eval.eval( @@ -136,12 +144,12 @@ //! ``` //! use fidget::{ //! context::Tree, -//! eval::{TracingEvaluator, Shape, MathShape, EzShape}, +//! shape::EzShape, //! vm::VmShape //! }; //! //! let tree = Tree::x().min(Tree::y()); -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut interval_eval = VmShape::new_interval_eval(); //! let tape = shape.ez_interval_tape(); //! let (out, trace) = interval_eval.eval( @@ -165,11 +173,12 @@ //! ``` //! # use fidget::{ //! # context::Tree, -//! # eval::{TracingEvaluator, Shape, MathShape, EzShape}, +//! # shape::EzShape, //! # vm::VmShape //! # }; //! # let tree = Tree::x().min(Tree::y()); -//! # let shape = VmShape::from_tree(&tree); +//! # let shape = VmShape::from(tree); +//! assert_eq!(shape.size(), 3); // min, X, Y //! # let mut interval_eval = VmShape::new_interval_eval(); //! # let tape = shape.ez_interval_tape(); //! # let (out, trace) = interval_eval.eval( @@ -179,9 +188,8 @@ //! # [0.0, 0.0], // Z //! # )?; //! // (same code as above) -//! assert_eq!(tape.size(), 3); //! let new_shape = shape.ez_simplify(trace.unwrap())?; -//! assert_eq!(new_shape.ez_interval_tape().size(), 1); // just the 'X' term +//! assert_eq!(new_shape.size(), 1); // just the X term //! # Ok::<(), fidget::Error>(()) //! ``` //! @@ -197,7 +205,6 @@ //! ``` //! use fidget::{ //! context::{Tree, Context}, -//! eval::MathShape, //! render::{BitRenderMode, RenderConfig}, //! vm::VmShape, //! }; @@ -209,7 +216,7 @@ //! image_size: 32, //! ..RenderConfig::default() //! }; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let out = cfg.run::<_, BitRenderMode>(shape)?; //! let mut iter = out.iter(); //! for y in 0..cfg.image_size { diff --git a/fidget/src/mesh/mod.rs b/fidget/src/mesh/mod.rs index 39d4125f..d77fd467 100644 --- a/fidget/src/mesh/mod.rs +++ b/fidget/src/mesh/mod.rs @@ -3,7 +3,7 @@ //! This module implements //! [Manifold Dual Contouring](https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf), //! to generate a triangle mesh from an implicit surface (or anything -//! implementing [`Shape`](crate::eval::Shape)). +//! implementing [`Shape`](crate::shape::Shape)). //! //! The resulting meshes should be //! - Manifold @@ -19,13 +19,12 @@ //! //! ``` //! use fidget::{ -//! eval::MathShape, //! mesh::{Octree, Settings}, //! vm::VmShape //! }; //! //! let tree = fidget::rhai::eval("sphere(0, 0, 0, 0.6)")?; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let settings = Settings { //! depth: 4, //! ..Default::default() diff --git a/fidget/src/mesh/mt/octree.rs b/fidget/src/mesh/mt/octree.rs index e5e1631f..c9331ab7 100644 --- a/fidget/src/mesh/mt/octree.rs +++ b/fidget/src/mesh/mt/octree.rs @@ -1,7 +1,7 @@ //! Multithreaded octree construction use super::pool::{QueuePool, ThreadContext, ThreadPool}; use crate::{ - eval::Shape, + eval::Function, mesh::{ cell::{Cell, CellData, CellIndex}, octree::{BranchResult, CellResult, EvalGroup, OctreeBuilder}, @@ -18,22 +18,22 @@ use std::sync::{mpsc::TryRecvError, Arc}; /// octants, sending results back to the parent (which is numbered implicitly /// based on what queue we stole this from). #[derive(Clone)] -struct Task { - data: Arc>, +struct Task { + data: Arc>, } -impl std::ops::Deref for Task { - type Target = TaskData; +impl std::ops::Deref for Task { + type Target = TaskData; fn deref(&self) -> &Self::Target { &self.data } } -impl Task { +impl Task { /// Builds a new root task /// /// The root task is from worker 0 with the default cell index - fn new(eval: Arc>) -> Self { + fn new(eval: Arc>) -> Self { Self { data: Arc::new(TaskData { eval, @@ -46,7 +46,7 @@ impl Task { fn child( &self, - eval: Arc>, + eval: Arc>, target_cell: CellIndex, assigned_by: usize, ) -> Self { @@ -61,8 +61,8 @@ impl Task { } } -struct TaskData { - eval: Arc>, +struct TaskData { + eval: Arc>, /// Thread in which the parent cell lives assigned_by: usize, @@ -70,12 +70,12 @@ struct TaskData { /// Parent cell, which must be an `Invalid` cell waiting for population target_cell: CellIndex, - parent: Option>>, + parent: Option>>, } -struct Done { +struct Done { /// The task that we have finished evaluating - task: Task, + task: Task, /// The resulting cell /// @@ -89,7 +89,7 @@ struct Done { completed_by: usize, } -pub struct OctreeWorker { +pub struct OctreeWorker { /// Global index of this worker thread /// /// For example, this is the thread's own index in `friend_queue` and @@ -101,24 +101,24 @@ pub struct OctreeWorker { /// This octree may not be complete; worker 0 is guaranteed to contain the /// root, and other works may contain fragmentary branches that point to /// each other in a tree structure. - octree: OctreeBuilder, + octree: OctreeBuilder, /// Incoming completed tasks from other threads - done: std::sync::mpsc::Receiver>, + done: std::sync::mpsc::Receiver>, /// Our queue of tasks - queue: QueuePool>, + queue: QueuePool>, /// When a worker finishes a task, it returns it through these queues /// /// Like `friend_queue`, there's one per thread, including the worker's own /// thread; it would be silly to send stuff back to your own thread via the /// queue (rather than storing it directly). - friend_done: Vec>>, + friend_done: Vec>>, } -impl OctreeWorker { - pub fn scheduler(eval: Arc>, settings: Settings) -> Octree { +impl OctreeWorker { + pub fn scheduler(eval: Arc>, settings: Settings) -> Octree { let task_queues = QueuePool::new(settings.threads()); let done_queues = std::iter::repeat_with(std::sync::mpsc::channel) .take(settings.threads()) @@ -250,13 +250,13 @@ impl OctreeWorker { self.octree.into() } - fn reclaim(&mut self, task: Task) { + fn reclaim(&mut self, task: Task) { if let Ok(t) = Arc::try_unwrap(task.data) { self.reclaim_inner(t) } } - fn reclaim_inner(&mut self, mut t: TaskData) { + fn reclaim_inner(&mut self, mut t: TaskData) { // Try recycling the tapes, if no one else is using them if let Ok(e) = Arc::try_unwrap(t.eval) { self.octree.reclaim(e); @@ -271,7 +271,7 @@ impl OctreeWorker { fn on_done( &mut self, result: BranchResult, - task: &Arc>, + task: &Arc>, completed_by: usize, ctx: &mut ThreadContext, ) { @@ -306,7 +306,7 @@ impl OctreeWorker { &mut self, index: usize, cell: CellData, - parent_task: &Arc>, + parent_task: &Arc>, ctx: &mut ThreadContext, ) { self.octree.record(index, cell); diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 0aca834e..7b922d8a 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -11,8 +11,8 @@ use super::{ Mesh, Settings, }; use crate::{ - eval::{BulkEvaluator, Shape, Tape, TracingEvaluator}, - shape::RenderHints, + eval::{BulkEvaluator, Function, TracingEvaluator}, + shape::{RenderHints, Shape, ShapeBulkEval, ShapeTape, ShapeTracingEval}, types::Grad, }; use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; @@ -20,22 +20,26 @@ use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; #[cfg(not(target_arch = "wasm32"))] use super::mt::{DcWorker, OctreeWorker}; +// TODO use fidget::render::RenderHandle here instead? /// Helper struct to contain a set of matched evaluators /// /// Note that this is `Send + Sync` and can be used with shared references! -pub struct EvalGroup { - pub shape: S, +pub struct EvalGroup { + pub shape: Shape, // TODO: passing around an `Arc` ends up with two layers of // indirection (since the tapes also contain `Arc`); could we flatten // them out? (same with the shape, which is usually an `Arc`) - pub interval: OnceLock<::Tape>, - pub float_slice: OnceLock<::Tape>, - pub grad_slice: OnceLock<::Tape>, + pub interval: + OnceLock::Tape>>, + pub float_slice: + OnceLock::Tape>>, + pub grad_slice: + OnceLock::Tape>>, } -impl EvalGroup { - fn new(shape: S) -> Self { +impl EvalGroup { + fn new(shape: Shape) -> Self { Self { shape, interval: OnceLock::new(), @@ -45,16 +49,16 @@ impl EvalGroup { } fn interval_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.interval.get_or_init(|| { self.shape.interval_tape(storage.pop().unwrap_or_default()) }) } fn float_slice_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.float_slice.get_or_init(|| { self.shape .float_slice_tape(storage.pop().unwrap_or_default()) @@ -62,8 +66,8 @@ impl EvalGroup { } fn grad_slice_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.grad_slice.get_or_init(|| { self.shape .grad_slice_tape(storage.pop().unwrap_or_default()) @@ -91,13 +95,10 @@ impl Octree { /// Builds an octree to the given depth /// /// The shape is evaluated on the region specified by `settings.bounds`. - pub fn build( - shape: &S, + pub fn build( + shape: &Shape, settings: Settings, - ) -> Self - where - ::TransformedShape: RenderHints, - { + ) -> Self { // Transform the shape given our bounds let t = settings.bounds.transform(); if t == nalgebra::Transform::identity() { @@ -116,8 +117,8 @@ impl Octree { } } - fn build_inner( - shape: &S, + fn build_inner( + shape: &Shape, settings: Settings, ) -> Self { let eval = Arc::new(EvalGroup::new(shape.clone())); @@ -238,7 +239,7 @@ impl std::ops::IndexMut for Octree { /// Data structure for an under-construction octree #[derive(Debug)] -pub(crate) struct OctreeBuilder { +pub(crate) struct OctreeBuilder { /// Internal octree /// /// Note that in this internal octree, the `index` field of leaf nodes @@ -262,23 +263,23 @@ pub(crate) struct OctreeBuilder { /// Available slots in the `hermite` array hermite_slots: Vec, - eval_float_slice: S::FloatSliceEval, - eval_interval: S::IntervalEval, - eval_grad_slice: S::GradSliceEval, + eval_float_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, + eval_grad_slice: ShapeBulkEval, - pub tape_storage: Vec, - pub shape_storage: Vec, - workspace: S::Workspace, + pub tape_storage: Vec, + pub shape_storage: Vec, + workspace: F::Workspace, } -impl Default for OctreeBuilder { +impl Default for OctreeBuilder { fn default() -> Self { Self::new() } } -impl From> for Octree { - fn from(o: OctreeBuilder) -> Self { +impl From> for Octree { + fn from(o: OctreeBuilder) -> Self { // Convert from "leaf index into self.leafs" (in the builder) to // "leaf index into self.verts" (in the resulting Octree) let cells = @@ -303,7 +304,7 @@ impl From> for Octree { } } -impl OctreeBuilder { +impl OctreeBuilder { /// Builds a new octree, which allocates data for 8 root cells pub(crate) fn new() -> Self { Self { @@ -314,9 +315,9 @@ impl OctreeBuilder { leafs: vec![], hermite: vec![LeafHermiteData::default()], hermite_slots: vec![], - eval_float_slice: S::new_float_slice_eval(), - eval_grad_slice: S::new_grad_slice_eval(), - eval_interval: S::new_interval_eval(), + eval_float_slice: Shape::::new_float_slice_eval(), + eval_grad_slice: Shape::::new_grad_slice_eval(), + eval_interval: Shape::::new_interval_eval(), tape_storage: vec![], shape_storage: vec![], workspace: Default::default(), @@ -347,10 +348,10 @@ impl OctreeBuilder { /// octree (e.g. on another thread). pub(crate) fn eval_cell( &mut self, - eval: &Arc>, + eval: &Arc>, cell: CellIndex, settings: Settings, - ) -> CellResult { + ) -> CellResult { let (i, r) = self .eval_interval .eval( @@ -365,7 +366,7 @@ impl OctreeBuilder { } else if i.lower() > 0.0 { CellResult::Done(Cell::Empty) } else { - let sub_tape = if S::simplify_tree_during_meshing(cell.depth) { + let sub_tape = if F::simplify_tree_during_meshing(cell.depth) { let s = self.shape_storage.pop().unwrap_or_default(); r.map(|r| { Arc::new(EvalGroup::new( @@ -426,7 +427,7 @@ impl OctreeBuilder { /// Recurse down the octree, building the given cell fn recurse( &mut self, - eval: &Arc>, + eval: &Arc>, cell: CellIndex, settings: Settings, ) { @@ -468,7 +469,7 @@ impl OctreeBuilder { /// Writes the leaf vertex to `self.o.verts`, hermite data to /// `self.hermite`, and the leaf data to `self.leafs`. Does **not** write /// anything to `self.o.cells`; the cell is returned instead. - fn leaf(&mut self, eval: &EvalGroup, cell: CellIndex) -> Cell { + fn leaf(&mut self, eval: &EvalGroup, cell: CellIndex) -> Cell { let mut xs = [0.0; 8]; let mut ys = [0.0; 8]; let mut zs = [0.0; 8]; @@ -895,7 +896,7 @@ impl OctreeBuilder { CELL_TO_VERT_TO_EDGES[mask as usize].len() == 1 } - pub(crate) fn reclaim(&mut self, mut e: EvalGroup) { + pub(crate) fn reclaim(&mut self, mut e: EvalGroup) { if let Some(s) = e.shape.recycle() { self.shape_storage.push(s); } @@ -913,7 +914,7 @@ impl OctreeBuilder { /// `OctreeBuilder` functions which are only used during multithreaded rendering #[cfg(not(target_arch = "wasm32"))] -impl OctreeBuilder { +impl OctreeBuilder { /// Builds a new empty octree /// /// This still allocates data to reserve the lowest slot in `hermite` @@ -927,9 +928,9 @@ impl OctreeBuilder { hermite: vec![LeafHermiteData::default()], hermite_slots: vec![], - eval_float_slice: S::new_float_slice_eval(), - eval_grad_slice: S::new_grad_slice_eval(), - eval_interval: S::new_interval_eval(), + eval_float_slice: Shape::::new_float_slice_eval(), + eval_grad_slice: Shape::::new_grad_slice_eval(), + eval_interval: Shape::::new_interval_eval(), tape_storage: vec![], shape_storage: vec![], @@ -951,9 +952,9 @@ impl OctreeBuilder { } /// Result of a single cell evaluation -pub enum CellResult { +pub enum CellResult { Done(Cell), - Recurse(Arc>), + Recurse(Arc>), } /// Result of a branch evaluation (8-fold division) @@ -1169,10 +1170,9 @@ mod test { use super::*; use crate::{ context::Tree, - eval::{EzShape, MathShape}, mesh::types::{Edge, X, Y, Z}, - shape::Bounds, - vm::VmShape, + shape::{Bounds, EzShape}, + vm::{VmFunction, VmShape}, }; use nalgebra::Vector3; use std::collections::BTreeMap; @@ -1215,7 +1215,7 @@ mod test { fn test_cube_edge() { const EPSILON: f32 = 1e-3; let f = 2.0; - let shape = VmShape::from_tree(&cube([-f, f], [-f, 0.3], [-f, 0.6])); + let shape = VmShape::from(cube([-f, f], [-f, 0.3], [-f, 0.6])); // This should be a cube with a single edge running through the root // node of the octree, with an edge vertex at [0, 0.3, 0.6] let octree = Octree::build(&shape, DEPTH0_SINGLE_THREAD); @@ -1264,7 +1264,7 @@ mod test { #[test] fn test_mesh_basic() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.2)); + let shape = VmShape::from(sphere([0.0; 3], 0.2)); // If we only build a depth-0 octree, then it's a leaf without any // vertices (since all the corners are empty) @@ -1307,7 +1307,7 @@ mod test { #[test] fn test_sphere_verts() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.2)); + let shape = VmShape::from(sphere([0.0; 3], 0.2)); let octree = Octree::build(&shape, DEPTH1_SINGLE_THREAD); let sphere_mesh = octree.walk_dual(DEPTH1_SINGLE_THREAD); @@ -1343,7 +1343,7 @@ mod test { #[test] fn test_sphere_manifold() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.85)); + let shape = VmShape::from(sphere([0.0; 3], 0.85)); for threads in [1, 8] { let settings = Settings { @@ -1371,8 +1371,7 @@ mod test { #[test] fn test_cube_verts() { - let shape = - VmShape::from_tree(&cube([-0.1, 0.6], [-0.2, 0.75], [-0.3, 0.4])); + let shape = VmShape::from(cube([-0.1, 0.6], [-0.2, 0.75], [-0.3, 0.4])); let octree = Octree::build(&shape, DEPTH1_SINGLE_THREAD); let mesh = octree.walk_dual(DEPTH1_SINGLE_THREAD); @@ -1422,7 +1421,7 @@ mod test { for offset in [0.0, -0.2, 0.2] { let (x, y, z) = Tree::axes(); let f = x * dx + y * dy + z + offset; - let shape = VmShape::from_tree(&f); + let shape = VmShape::from(f); let octree = Octree::build(&shape, DEPTH0_SINGLE_THREAD); assert_eq!(octree.cells.len(), 8); @@ -1457,7 +1456,7 @@ mod test { nalgebra::Vector3::new(1.2, 1.3, 1.4), ] { let corner = nalgebra::Vector3::new(-1.0, -1.0, -1.0); - let shape = VmShape::from_tree(&cone(corner, tip, 0.1)); + let shape = VmShape::from(cone(corner, tip, 0.1)); let mut eval = VmShape::new_point_eval(); let tape = shape.ez_point_tape(); @@ -1498,7 +1497,7 @@ mod test { // Now, we have our shape, which is 0-8 spheres placed at the // corners of the cell spanning [0, 0.25] - let shape = VmShape::from_tree(&shape); + let shape = VmShape::from(shape); let settings = Settings { depth: 2, threads: threads.try_into().unwrap(), @@ -1536,8 +1535,11 @@ mod test { #[test] fn test_collapsible() { - fn builder(shape: Tree, settings: Settings) -> OctreeBuilder { - let shape = VmShape::from_tree(&shape); + fn builder( + shape: Tree, + settings: Settings, + ) -> OctreeBuilder { + let shape = VmShape::from(shape); let eval = Arc::new(EvalGroup::new(shape)); let mut out = OctreeBuilder::new(); out.recurse(&eval, CellIndex::default(), settings); @@ -1566,7 +1568,7 @@ mod test { #[test] fn test_empty_collapse() { // Make a very smol sphere that won't be sampled - let shape = VmShape::from_tree(&sphere([0.1; 3], 0.05)); + let shape = VmShape::from(sphere([0.1; 3], 0.05)); for threads in [1, 4] { let settings = Settings { depth: 1, @@ -1585,9 +1587,9 @@ mod test { #[test] fn test_colonnade_manifold() { const COLONNADE: &str = include_str!("../../../models/colonnade.vm"); - let (ctx, root) = + let (mut ctx, root) = crate::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let tape = VmShape::new(&ctx, root).unwrap(); + let tape = VmShape::new(&mut ctx, root).unwrap(); for threads in [1, 8] { let settings = Settings { depth: 5, @@ -1682,7 +1684,7 @@ mod test { #[test] fn test_qef_near_planar() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.75)); + let shape = VmShape::from(sphere([0.0; 3], 0.75)); let settings = Settings { depth: 4, @@ -1699,7 +1701,7 @@ mod test { #[test] fn test_octree_bounds() { - let shape = VmShape::from_tree(&sphere([1.0; 3], 0.25)); + let shape = VmShape::from(sphere([1.0; 3], 0.25)); let center = Vector3::new(1.0, 1.0, 1.0); let settings = Settings { diff --git a/fidget/src/render/config.rs b/fidget/src/render/config.rs index 18ddd9fb..5fd964c1 100644 --- a/fidget/src/render/config.rs +++ b/fidget/src/render/config.rs @@ -1,4 +1,9 @@ -use crate::{eval::Shape, render::RenderMode, shape::Bounds, Error}; +use crate::{ + eval::Function, + render::RenderMode, + shape::{Bounds, Shape}, + Error, +}; use nalgebra::{ allocator::Allocator, Const, DefaultAllocator, DimNameAdd, DimNameSub, DimNameSum, U1, @@ -210,11 +215,11 @@ impl RenderConfig<2> { /// /// Under the hood, this delegates to /// [`fidget::render::render2d`](crate::render::render2d()) - pub fn run( + pub fn run( &self, - shape: S, + shape: Shape, ) -> Result::Output>, Error> { - Ok(crate::render::render2d::(shape, self)) + Ok(crate::render::render2d::(shape, self)) } } @@ -225,11 +230,11 @@ impl RenderConfig<3> { /// [`fidget::render::render3d`](crate::render::render3d()) /// /// Returns a tuple of heightmap, RGB image. - pub fn run( + pub fn run( &self, - shape: S, + shape: Shape, ) -> Result<(Vec, Vec<[u8; 3]>), Error> { - Ok(crate::render::render3d::(shape, self)) + Ok(crate::render::render3d::(shape, self)) } } diff --git a/fidget/src/render/mod.rs b/fidget/src/render/mod.rs index 1b6fe38d..b4e67091 100644 --- a/fidget/src/render/mod.rs +++ b/fidget/src/render/mod.rs @@ -4,7 +4,10 @@ //! [`RenderConfig::run`](RenderConfig::run); you can also use the lower-level //! functions ([`render2d`](render2d()) and [`render3d`](render3d())) for manual //! control over the input tape. -use crate::eval::{BulkEvaluator, Shape, Tape, Trace, TracingEvaluator}; +use crate::{ + eval::{BulkEvaluator, Function, Trace, TracingEvaluator}, + shape::{Shape, ShapeTape}, +}; use std::sync::Arc; mod config; @@ -25,17 +28,17 @@ pub use render2d::{ /// The tapes are stored as `Arc<..>`, so it can be cheaply cloned. /// /// The most recent simplification is cached for reuse (if the trace matches). -pub struct RenderHandle { - shape: S, +pub struct RenderHandle { + shape: Shape, - i_tape: Option::Tape>>, - f_tape: Option::Tape>>, - g_tape: Option::Tape>>, + i_tape: Option::Tape>>>, + f_tape: Option::Tape>>>, + g_tape: Option::Tape>>>, - next: Option<(S::Trace, Box)>, + next: Option<(F::Trace, Box)>, } -impl Clone for RenderHandle { +impl Clone for RenderHandle { fn clone(&self) -> Self { Self { shape: self.shape.clone(), @@ -47,14 +50,11 @@ impl Clone for RenderHandle { } } -impl RenderHandle -where - S: Shape, -{ +impl RenderHandle { /// Build a new [`RenderHandle`] for the given shape /// /// None of the tapes are populated here. - pub fn new(shape: S) -> Self { + pub fn new(shape: Shape) -> Self { Self { shape, i_tape: None, @@ -67,8 +67,8 @@ where /// Returns a tape for tracing interval evaluation pub fn i_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.i_tape.get_or_insert_with(|| { Arc::new( self.shape.interval_tape(storage.pop().unwrap_or_default()), @@ -79,8 +79,8 @@ where /// Returns a tape for bulk float evaluation pub fn f_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.f_tape.get_or_insert_with(|| { Arc::new( self.shape @@ -92,8 +92,8 @@ where /// Returns a tape for bulk gradient evaluation pub fn g_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.g_tape.get_or_insert_with(|| { Arc::new( self.shape @@ -108,10 +108,10 @@ where /// the trace matches. pub fn simplify( &mut self, - trace: &S::Trace, - workspace: &mut S::Workspace, - shape_storage: &mut Vec, - tape_storage: &mut Vec, + trace: &F::Trace, + workspace: &mut F::Workspace, + shape_storage: &mut Vec, + tape_storage: &mut Vec, ) -> &mut Self { // Free self.next if it doesn't match our new set of choices let mut trace_storage = if let Some(neighbor) = &self.next { @@ -165,8 +165,8 @@ where /// Recycles the entire handle into the given storage vectors pub fn recycle( mut self, - shape_storage: &mut Vec, - tape_storage: &mut Vec, + shape_storage: &mut Vec, + tape_storage: &mut Vec, ) { // Recycle the child first, in case it borrowed from us if let Some((_trace, shape)) = self.next.take() { diff --git a/fidget/src/render/render2d.rs b/fidget/src/render/render2d.rs index 18050196..7f2d177a 100644 --- a/fidget/src/render/render2d.rs +++ b/fidget/src/render/render2d.rs @@ -1,8 +1,9 @@ //! 2D bitmap rendering / rasterization use super::RenderHandle; use crate::{ - eval::{BulkEvaluator, Shape, TracingEvaluator}, + eval::Function, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, + shape::{Shape, ShapeBulkEval, ShapeTracingEval}, types::Interval, }; use nalgebra::Point2; @@ -200,29 +201,29 @@ impl Scratch { //////////////////////////////////////////////////////////////////////////////// /// Per-thread worker -struct Worker<'a, S: Shape, M: RenderMode> { +struct Worker<'a, F: Function, M: RenderMode> { config: &'a AlignedRenderConfig<2>, scratch: Scratch, - eval_float_slice: S::FloatSliceEval, - eval_interval: S::IntervalEval, + eval_float_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, /// Spare tape storage for reuse - tape_storage: Vec, + tape_storage: Vec, /// Spare shape storage for reuse - shape_storage: Vec, + shape_storage: Vec, /// Workspace for shape simplification - workspace: S::Workspace, + workspace: F::Workspace, image: Vec, } -impl Worker<'_, S, M> { +impl Worker<'_, F, M> { fn render_tile_recurse( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, depth: usize, tile: Tile<2>, ) { @@ -310,7 +311,7 @@ impl Worker<'_, S, M> { fn render_tile_pixels( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, tile_size: usize, tile: Tile<2>, ) { @@ -346,20 +347,20 @@ impl Worker<'_, S, M> { //////////////////////////////////////////////////////////////////////////////// -fn worker( - mut shape: RenderHandle, +fn worker( + mut shape: RenderHandle, queue: &Queue<2>, config: &AlignedRenderConfig<2>, ) -> Vec<(Tile<2>, Vec)> { let mut out = vec![]; let scratch = Scratch::new(config.tile_sizes.last().unwrap_or(&0).pow(2)); - let mut w: Worker = Worker { + let mut w: Worker = Worker { scratch, image: vec![], config, - eval_float_slice: S::FloatSliceEval::new(), - eval_interval: S::IntervalEval::new(), + eval_float_slice: Default::default(), + eval_interval: Default::default(), tape_storage: vec![], shape_storage: vec![], workspace: Default::default(), @@ -384,8 +385,8 @@ fn worker( /// This function is parameterized by both shape type (which determines how we /// perform evaluation) and render mode (which tells us how to color in the /// resulting pixels). -pub fn render( - shape: S, +pub fn render( + shape: Shape, config: &RenderConfig<2>, ) -> Vec { let (config, mat) = config.align(); @@ -402,8 +403,8 @@ pub fn render( render_inner::<_, M>(shape, config) } -fn render_inner( - shape: S, +fn render_inner( + shape: Shape, config: AlignedRenderConfig<2>, ) -> Vec { let mut tiles = vec![]; @@ -423,7 +424,7 @@ fn render_inner( let _ = rh.i_tape(&mut vec![]); // populate i_tape before cloning let out: Vec<_> = if threads == 1 { - worker::(rh, &queue, &config).into_iter().collect() + worker::(rh, &queue, &config).into_iter().collect() } else { #[cfg(target_arch = "wasm32")] unreachable!("multithreaded rendering is not supported on wasm32"); @@ -433,7 +434,7 @@ fn render_inner( let mut handles = vec![]; for _ in 0..threads { let rh = rh.clone(); - handles.push(s.spawn(|| worker::(rh, &queue, &config))); + handles.push(s.spawn(|| worker::(rh, &queue, &config))); } let mut out = vec![]; for h in handles { @@ -467,9 +468,9 @@ fn render_inner( mod test { use super::*; use crate::{ - eval::{MathShape, Shape}, - shape::Bounds, - vm::{GenericVmShape, VmShape}, + eval::{Function, MathFunction}, + shape::{Bounds, Shape}, + vm::{GenericVmFunction, VmFunction}, Context, }; @@ -480,8 +481,8 @@ mod test { "/../models/quarter.vm" )); - fn render_and_compare_with_bounds( - shape: S, + fn render_and_compare_with_bounds( + shape: Shape, expected: &'static str, bounds: Bounds<2>, ) { @@ -509,13 +510,16 @@ mod test { } } - fn render_and_compare(shape: S, expected: &'static str) { + fn render_and_compare( + shape: Shape, + expected: &'static str, + ) { render_and_compare_with_bounds(shape, expected, Bounds::default()) } - fn check_hi() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + fn check_hi() { + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .................X.............. .................X.............. @@ -552,9 +556,9 @@ mod test { render_and_compare(shape, EXPECTED); } - fn check_hi_transformed() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + fn check_hi_transformed() { + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); let mut mat = nalgebra::Matrix4::::identity(); mat.prepend_translation_mut(&nalgebra::Vector3::new(0.5, 0.5, 0.0)); mat.prepend_scaling_mut(0.5); @@ -595,9 +599,9 @@ mod test { render_and_compare(shape, EXPECTED); } - fn check_hi_bounded() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + fn check_hi_bounded() { + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .XXX............................ .XXX............................ @@ -641,9 +645,9 @@ mod test { ); } - fn check_quarter() { - let (ctx, root) = Context::from_text(QUARTER.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + fn check_quarter() { + let (mut ctx, root) = Context::from_text(QUARTER.as_bytes()).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " ................................ ................................ @@ -682,65 +686,65 @@ mod test { #[test] fn render_hi_vm() { - check_hi::(); + check_hi::(); } #[test] fn render_hi_vm3() { - check_hi::>(); + check_hi::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_jit() { - check_hi::(); + check_hi::(); } #[test] fn render_hi_transformed_vm() { - check_hi_transformed::(); + check_hi_transformed::(); } #[test] fn render_hi_transformed_vm3() { - check_hi_transformed::>(); + check_hi_transformed::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_transformed_jit() { - check_hi_transformed::(); + check_hi_transformed::(); } #[test] fn render_hi_bounded_vm() { - check_hi_bounded::(); + check_hi_bounded::(); } #[test] fn render_hi_bounded_vm3() { - check_hi_bounded::>(); + check_hi_bounded::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_bounded_jit() { - check_hi_bounded::(); + check_hi_bounded::(); } #[test] fn render_quarter_vm() { - check_quarter::(); + check_quarter::(); } #[test] fn render_quarter_vm3() { - check_quarter::>(); + check_quarter::>(); } #[cfg(feature = "jit")] #[test] fn render_quarter_jit() { - check_quarter::(); + check_quarter::(); } } diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index 0df2c78b..b4e8751c 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -1,8 +1,9 @@ //! 3D bitmap rendering / rasterization use super::RenderHandle; use crate::{ - eval::{BulkEvaluator, Shape, TracingEvaluator}, + eval::Function, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, + shape::{Shape, ShapeBulkEval, ShapeTracingEval}, types::{Grad, Interval}, }; @@ -44,29 +45,29 @@ impl Scratch { //////////////////////////////////////////////////////////////////////////////// -struct Worker<'a, S: Shape> { +struct Worker<'a, F: Function> { config: &'a AlignedRenderConfig<3>, /// Reusable workspace for evaluation, to minimize allocation scratch: Scratch, - eval_float_slice: S::FloatSliceEval, - eval_grad_slice: S::GradSliceEval, - eval_interval: S::IntervalEval, + eval_float_slice: ShapeBulkEval, + eval_grad_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, - tape_storage: Vec, - shape_storage: Vec, - workspace: S::Workspace, + tape_storage: Vec, + shape_storage: Vec, + workspace: F::Workspace, /// Output images for this specific tile depth: Vec, color: Vec<[u8; 3]>, } -impl Worker<'_, S> { +impl Worker<'_, F> { fn render_tile_recurse( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, depth: usize, tile: Tile<3>, ) { @@ -143,7 +144,7 @@ impl Worker<'_, S> { fn render_tile_pixels( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, tile_size: usize, tile: Tile<3>, ) { @@ -281,8 +282,8 @@ impl Image { //////////////////////////////////////////////////////////////////////////////// -fn worker( - mut shape: RenderHandle, +fn worker( + mut shape: RenderHandle, queues: &[Queue<3>], mut index: usize, config: &AlignedRenderConfig<3>, @@ -292,15 +293,15 @@ fn worker( // Calculate maximum evaluation buffer size let buf_size = *config.tile_sizes.last().unwrap(); let scratch = Scratch::new(buf_size); - let mut w: Worker = Worker { + let mut w: Worker = Worker { scratch, depth: vec![], color: vec![], config, - eval_float_slice: S::FloatSliceEval::new(), - eval_interval: S::IntervalEval::new(), - eval_grad_slice: S::GradSliceEval::new(), + eval_float_slice: Default::default(), + eval_interval: Default::default(), + eval_grad_slice: Default::default(), tape_storage: vec![], shape_storage: vec![], @@ -351,8 +352,8 @@ fn worker( /// /// This function is parameterized by shape type, which determines how we /// perform evaluation. -pub fn render( - shape: S, +pub fn render( + shape: Shape, config: &RenderConfig<3>, ) -> (Vec, Vec<[u8; 3]>) { let (config, mat) = config.align(); @@ -365,8 +366,8 @@ pub fn render( render_inner(shape, config) } -pub fn render_inner( - shape: S, +pub fn render_inner( + shape: Shape, config: AlignedRenderConfig<3>, ) -> (Vec, Vec<[u8; 3]>) { let mut tiles = vec![]; @@ -396,7 +397,7 @@ pub fn render_inner( // Special-case for single-threaded operation, to give simpler backtraces let out: Vec<_> = if threads == 1 { - worker::(rh, tile_queues.as_slice(), 0, &config) + worker::(rh, tile_queues.as_slice(), 0, &config) .into_iter() .collect() } else { @@ -411,7 +412,7 @@ pub fn render_inner( for i in 0..threads { let rh = rh.clone(); handles - .push(s.spawn(move || worker::(rh, queues, i, config))); + .push(s.spawn(move || worker::(rh, queues, i, config))); } let mut out = vec![]; for h in handles { @@ -448,14 +449,14 @@ pub fn render_inner( #[cfg(test)] mod test { use super::*; - use crate::{eval::MathShape, vm::VmShape, Context}; + use crate::{vm::VmShape, Context}; /// Make sure we don't crash if there's only a single tile #[test] fn test_tile_queues() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = VmShape::new(&ctx, x).unwrap(); + let shape = VmShape::new(&mut ctx, x).unwrap(); let cfg = RenderConfig::<3> { image_size: 128, // very small! diff --git a/fidget/src/rhai/mod.rs b/fidget/src/rhai/mod.rs index b1d41035..1a66d64f 100644 --- a/fidget/src/rhai/mod.rs +++ b/fidget/src/rhai/mod.rs @@ -7,12 +7,12 @@ //! //! ``` //! use fidget::{ -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape, //! }; //! //! let tree = fidget::rhai::eval("x + y")?; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0)?.0, 3.0); @@ -24,16 +24,16 @@ //! //! ``` //! use fidget::{ -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape, //! rhai::Engine //! }; //! //! let mut engine = Engine::new(); -//! let out = engine.run("draw(x + y - 1)")?; +//! let mut out = engine.run("draw(x + y - 1)")?; //! //! assert_eq!(out.shapes.len(), 1); -//! let shape = VmShape::from_tree(&out.shapes[0].tree); +//! let shape = VmShape::from(out.shapes.pop().unwrap().tree); //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! assert_eq!(eval.eval(&tape, 0.5, 2.0, 0.0)?.0, 1.5); diff --git a/viewer/src/main.rs b/viewer/src/main.rs index 95f0aaba..5836be7c 100644 --- a/viewer/src/main.rs +++ b/viewer/src/main.rs @@ -75,15 +75,15 @@ struct RenderResult { image_size: usize, } -fn render_thread( +fn render_thread( cfg: Receiver, rx: Receiver>, tx: Sender>, wake: Sender<()>, ) -> Result<()> where - S: fidget::eval::Shape - + fidget::eval::MathShape + F: fidget::eval::Function + + fidget::eval::MathFunction + fidget::shape::RenderHints, { let mut config = None; @@ -128,7 +128,7 @@ where ); let render_start = std::time::Instant::now(); for s in out.shapes.iter() { - let tape = S::from_tree(&s.tree); + let tape = fidget::shape::Shape::::from(s.tree.clone()); render( &render_config.mode, tape, @@ -150,9 +150,9 @@ where } } -fn render( +fn render( mode: &RenderMode, - shape: S, + shape: fidget::shape::Shape, image_size: usize, color: [u8; 3], pixels: &mut [egui::Color32], @@ -161,7 +161,7 @@ fn render( RenderMode::TwoD(camera, mode) => { let config = RenderConfig { image_size, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), bounds: fidget::shape::Bounds { center: Vector2::new(camera.offset.x, camera.offset.y), size: camera.scale, @@ -213,7 +213,7 @@ fn render( RenderMode::ThreeD(camera, mode) => { let config = RenderConfig { image_size, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), bounds: fidget::shape::Bounds { center: Vector3::new(camera.offset.x, camera.offset.y, 0.0), size: camera.scale, @@ -281,17 +281,13 @@ fn main() -> Result<(), Box> { }); std::thread::spawn(move || { #[cfg(feature = "jit")] - type Shape = fidget::jit::JitShape; + type F = fidget::jit::JitFunction; #[cfg(not(feature = "jit"))] - type Shape = fidget::vm::VmShape; - - let _ = render_thread::( - config_rx, - rhai_result_rx, - render_tx, - wake_tx, - ); + type F = fidget::vm::VmFunction; + + let _ = + render_thread::(config_rx, rhai_result_rx, render_tx, wake_tx); info!("render thread is done"); }); diff --git a/wasm-demo/Cargo.lock b/wasm-demo/Cargo.lock index ff34b275..5a55a4da 100644 --- a/wasm-demo/Cargo.lock +++ b/wasm-demo/Cargo.lock @@ -255,7 +255,7 @@ dependencies = [ [[package]] name = "fidget" -version = "0.2.7" +version = "0.2.8" dependencies = [ "arrayvec", "bimap", diff --git a/wasm-demo/src/lib.rs b/wasm-demo/src/lib.rs index 3ada219a..7e0bfc10 100644 --- a/wasm-demo/src/lib.rs +++ b/wasm-demo/src/lib.rs @@ -1,6 +1,5 @@ use fidget::{ context::{Context, Tree}, - eval::MathShape, render::{BitRenderMode, RenderConfig}, shape::Bounds, vm::{VmData, VmShape}, @@ -29,16 +28,18 @@ pub fn eval_script(s: &str) -> Result { pub fn serialize_into_tape(t: JsTree) -> Result, String> { let mut ctx = Context::new(); let root = ctx.import(&t.0); - let shape = VmShape::new(&ctx, root).map_err(|e| format!("{e}"))?; - bincode::serialize(shape.data()).map_err(|e| format!("{e}")) + let shape = VmShape::new(&mut ctx, root).map_err(|e| format!("{e}"))?; + let vm_data = shape.inner().data(); + let axes = shape.axes(); + bincode::serialize(&(vm_data, axes)).map_err(|e| format!("{e}")) } /// Deserialize a `bincode`-packed `VmData` into a `VmShape` #[wasm_bindgen] pub fn deserialize_tape(data: Vec) -> Result { - let d: VmData<255> = + let (d, axes): (VmData<255>, [Option; 3]) = bincode::deserialize(&data).map_err(|e| format!("{e}"))?; - Ok(JsVmShape(VmShape::from(d))) + Ok(JsVmShape(VmShape::new_raw(d.into(), axes))) } /// Renders a subregion of an image, for webworker-based multithreading