From c775a5eece3449fc60186e0067c2aa234f0db917 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sun, 19 May 2024 09:21:16 -0400 Subject: [PATCH 01/12] Move Shape trait into fidget::shape --- demo/src/main.rs | 10 +- fidget/benches/function_call.rs | 2 +- fidget/benches/mesh.rs | 2 +- fidget/benches/render.rs | 2 +- fidget/src/core/context/mod.rs | 4 +- fidget/src/core/eval/mod.rs | 253 +----------------- fidget/src/core/eval/test/float_slice.rs | 2 +- fidget/src/core/eval/test/grad_slice.rs | 2 +- fidget/src/core/eval/test/interval.rs | 3 +- fidget/src/core/eval/test/point.rs | 2 +- fidget/src/core/mod.rs | 2 +- fidget/src/core/{eval => shape}/bulk.rs | 0 fidget/src/core/shape/mod.rs | 259 ++++++++++++++++++- fidget/src/core/{eval => shape}/tracing.rs | 4 +- fidget/src/core/{eval => shape}/transform.rs | 5 +- fidget/src/core/vm/mod.rs | 8 +- fidget/src/jit/mod.rs | 10 +- fidget/src/lib.rs | 12 +- fidget/src/mesh/mod.rs | 2 +- fidget/src/mesh/mt/octree.rs | 2 +- fidget/src/mesh/octree.rs | 6 +- fidget/src/render/config.rs | 6 +- fidget/src/render/mod.rs | 5 +- fidget/src/render/render2d.rs | 5 +- fidget/src/render/render3d.rs | 4 +- fidget/src/rhai/mod.rs | 4 +- viewer/src/main.rs | 6 +- 27 files changed, 319 insertions(+), 303 deletions(-) rename fidget/src/core/{eval => shape}/bulk.rs (100%) rename fidget/src/core/{eval => shape}/tracing.rs (96%) rename fidget/src/core/{eval => shape}/transform.rs (98%) diff --git a/demo/src/main.rs b/demo/src/main.rs index b037fdd7..6a89fc63 100644 --- a/demo/src/main.rs +++ b/demo/src/main.rs @@ -9,7 +9,7 @@ use log::info; use fidget::{ context::Context, - eval::{BulkEvaluator, MathShape}, + shape::{BulkEvaluator, MathShape}, }; /// Simple test program @@ -112,7 +112,7 @@ struct MeshSettings { } //////////////////////////////////////////////////////////////////////////////// -fn run3d( +fn run3d( shape: S, settings: &ImageSettings, isometric: bool, @@ -168,7 +168,7 @@ fn run3d( //////////////////////////////////////////////////////////////////////////////// -fn run2d( +fn run2d( shape: S, settings: &ImageSettings, brute: bool, @@ -236,12 +236,12 @@ fn run2d( //////////////////////////////////////////////////////////////////////////////// -fn run_mesh( +fn run_mesh( shape: S, settings: &MeshSettings, ) -> fidget::mesh::Mesh where - ::TransformedShape: fidget::shape::RenderHints, + ::TransformedShape: fidget::shape::RenderHints, { let mut mesh = fidget::mesh::Mesh::new(); diff --git a/fidget/benches/function_call.rs b/fidget/benches/function_call.rs index 10f428ac..3138bf3f 100644 --- a/fidget/benches/function_call.rs +++ b/fidget/benches/function_call.rs @@ -3,7 +3,7 @@ use criterion::{ }; use fidget::{ context::{Context, Node}, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + shape::{BulkEvaluator, EzShape, MathShape, Shape}, }; pub fn run_bench( diff --git a/fidget/benches/mesh.rs b/fidget/benches/mesh.rs index 63b488f4..53a2590e 100644 --- a/fidget/benches/mesh.rs +++ b/fidget/benches/mesh.rs @@ -1,7 +1,7 @@ use criterion::{ black_box, criterion_group, criterion_main, BenchmarkId, Criterion, }; -use fidget::eval::MathShape; +use fidget::shape::MathShape; const COLONNADE: &str = include_str!("../../models/colonnade.vm"); diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index 204d46c1..0cf13a70 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -4,7 +4,7 @@ use criterion::{ const PROSPERO: &str = include_str!("../../models/prospero.vm"); -use fidget::{eval::MathShape, shape::RenderHints}; +use fidget::shape::{MathShape, RenderHints}; pub fn prospero_size_sweep(c: &mut Criterion) { let (ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index efe6deeb..8063b09f 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -733,7 +733,7 @@ impl Context { /// Evaluates the given node with the provided values for X, Y, and Z. /// /// This is extremely inefficient; consider converting the node into a - /// [`Shape`](crate::eval::Shape) and using its evaluators instead. + /// [`Shape`](crate::shape::Shape) and using its evaluators instead. /// /// ``` /// # let mut ctx = fidget::context::Context::new(); @@ -762,7 +762,7 @@ impl Context { /// Evaluates the given node with a generic set of variables /// /// This is extremely inefficient; consider converting the node into a - /// [`Shape`](crate::eval::Shape) and using its evaluators instead. + /// [`Shape`](crate::shape::Shape) and using its evaluators instead. pub fn eval( &self, root: Node, diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 1a10b1a1..d3d5b583 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,257 +1,8 @@ -//! Traits and data structures for evaluation -//! -//! There are a bunch of things in here, but the most important trait is -//! [`Shape`], followed by the evaluator traits ([`BulkEvaluator`] and -//! [`TracingEvaluator`]). -//! -//! ```rust -//! use fidget::vm::VmShape; -//! use fidget::context::Context; -//! use fidget::eval::{TracingEvaluator, Shape, MathShape, EzShape}; -//! -//! let mut ctx = Context::new(); -//! let x = ctx.x(); -//! let shape = VmShape::new(&ctx, x)?; -//! -//! // Let's build a single point evaluator: -//! let mut eval = VmShape::new_point_eval(); -//! let tape = shape.ez_point_tape(); -//! let (value, _trace) = eval.eval(&tape, 0.25, 0.0, 0.0)?; -//! assert_eq!(value, 0.25); -//! # Ok::<(), fidget::Error>(()) -//! ``` -use crate::{ - context::Node, - types::{Grad, Interval}, - Context, Error, -}; +//! Traits and data structures for function evaluation #[cfg(any(test, feature = "eval-tests"))] pub mod test; -mod bulk; -mod tracing; -mod transform; - -// Re-export a few things -pub use bulk::BulkEvaluator; -pub use tracing::TracingEvaluator; -pub use transform::TransformedShape; - -/// A shape represents an implicit surface -/// -/// It is mostly agnostic to _how_ that surface is represented; we simply -/// require that the shape can generate evaluators of various kinds. -/// -/// Shapes are shared between threads, so they should be cheap to clone. In -/// most cases, they're a thin wrapper around an `Arc<..>`. -pub trait Shape: Send + Sync + Clone { - /// Associated type traces collected during tracing evaluation - /// - /// This type must implement [`Eq`] so that traces can be compared; calling - /// [`Shape::simplify`] with traces that compare equal should produce an - /// identical result and may be cached. - type Trace: Clone + Eq + Send + Trace; - - /// Associated type for storage used by the shape itself - type Storage: Default + Send; - - /// Associated type for workspace used during shape simplification - type Workspace: Default + Send; - - /// Associated type for storage used by tapes - /// - /// For simplicity, we require that every tape use the same type for storage. - /// This could change in the future! - type TapeStorage: Default + Send; - - /// Associated type for single-point tracing evaluation - type PointEval: TracingEvaluator< - Data = f32, - Trace = Self::Trace, - TapeStorage = Self::TapeStorage, - > + Send - + Sync; - - /// Builds a new point evaluator - fn new_point_eval() -> Self::PointEval { - Self::PointEval::new() - } - - /// Associated type for single interval tracing evaluation - type IntervalEval: TracingEvaluator< - Data = Interval, - Trace = Self::Trace, - TapeStorage = Self::TapeStorage, - > + Send - + Sync; - - /// Builds a new interval evaluator - fn new_interval_eval() -> Self::IntervalEval { - Self::IntervalEval::new() - } - - /// Associated type for evaluating many points in one call - type FloatSliceEval: BulkEvaluator - + Send - + Sync; - - /// Builds a new float slice evaluator - fn new_float_slice_eval() -> Self::FloatSliceEval { - Self::FloatSliceEval::new() - } - - /// Associated type for evaluating many gradients in one call - type GradSliceEval: BulkEvaluator - + Send - + Sync; - - /// Builds a new gradient slice evaluator - fn new_grad_slice_eval() -> Self::GradSliceEval { - Self::GradSliceEval::new() - } - - /// Returns an evaluation tape for a point evaluator - fn point_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape; - - /// Returns an evaluation tape for an interval evaluator - fn interval_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn float_slice_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn grad_slice_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape; - - /// Computes a simplified tape using the given trace, and reusing storage - fn simplify( - &self, - trace: &Self::Trace, - storage: Self::Storage, - workspace: &mut Self::Workspace, - ) -> Result - where - Self: Sized; - - /// Attempt to reclaim storage from this shape - /// - /// This may fail, because shapes are `Clone` and are often implemented - /// using an `Arc` around a heavier data structure. - fn recycle(self) -> Option; - - /// Returns a size associated with this shape - /// - /// This is underspecified and only used for unit testing; for tape-based - /// shapes, it's typically the length of the tape, - fn size(&self) -> usize; - - /// Associated type returned when applying a transform - /// - /// This is normally [`TransformedShape`](TransformedShape), but if - /// `Self` is already `TransformedShape`, then the transform is stacked - /// (instead of creating a wrapped object). - type TransformedShape: Shape; - - /// Returns a shape with the given transform applied - fn apply_transform( - self, - mat: nalgebra::Matrix4, - ) -> ::TransformedShape; -} - -/// Extension trait for working with a shape without thinking much about memory -/// -/// All of the [`Shape`] functions that use significant amounts of memory -/// pedantically require you to pass in storage for reuse. This trait allows -/// you to ignore that, at the cost of performance; we require that all storage -/// types implement [`Default`], so these functions do the boilerplate for you. -/// -/// This trait is automatically implemented for every [`Shape`], but must be -/// imported separately as a speed-bump to using it everywhere. -pub trait EzShape: Shape { - /// Returns an evaluation tape for a point evaluator - fn ez_point_tape(&self) -> ::Tape; - - /// Returns an evaluation tape for an interval evaluator - fn ez_interval_tape( - &self, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn ez_float_slice_tape( - &self, - ) -> ::Tape; - - /// Returns an evaluation tape for a float slice evaluator - fn ez_grad_slice_tape( - &self, - ) -> ::Tape; - - /// Computes a simplified tape using the given trace - fn ez_simplify(&self, trace: &Self::Trace) -> Result - where - Self: Sized; -} - -impl EzShape for S { - fn ez_point_tape(&self) -> ::Tape { - self.point_tape(Default::default()) - } - - fn ez_interval_tape( - &self, - ) -> ::Tape { - self.interval_tape(Default::default()) - } - - fn ez_float_slice_tape( - &self, - ) -> ::Tape { - self.float_slice_tape(Default::default()) - } - - fn ez_grad_slice_tape( - &self, - ) -> ::Tape { - self.grad_slice_tape(Default::default()) - } - - fn ez_simplify(&self, trace: &Self::Trace) -> Result { - let mut workspace = Default::default(); - self.simplify(trace, Default::default(), &mut workspace) - } -} - -/// A [`Shape`] which can be built from a math expression -pub trait MathShape { - /// Builds a new shape from the given context and node - fn new(ctx: &Context, node: Node) -> Result - where - Self: Sized; - - /// Helper function to build a shape from a [`Tree`](crate::context::Tree) - fn from_tree(t: &crate::context::Tree) -> Self - where - Self: Sized, - { - let mut ctx = Context::new(); - let node = ctx.import(t); - Self::new(&ctx, node).unwrap() - } -} - /// A tape represents something that can be evaluated by an evaluator /// /// The only property enforced on the trait is that we must have some way to @@ -270,7 +21,7 @@ pub trait Tape { /// /// The only property enforced on the trait is that we must have a way of /// reusing trace allocations. Because [`Trace`] implies `Clone` where it's -/// used in [`Shape`], this is trivial, but we can't provide a default +/// used in [`Function`], this is trivial, but we can't provide a default /// implementation because it would fall afoul of `impl` specialization. pub trait Trace { /// Copies the contents of `other` into `self` diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index 68ca1167..f20be90b 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -6,7 +6,7 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + shape::{BulkEvaluator, EzShape, MathShape, Shape}, }; /// Helper struct to put constrains on our `Shape` object diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index 42ef1450..a60bf8c7 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -5,7 +5,7 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{BulkEvaluator, EzShape, MathShape, Shape}, + shape::{BulkEvaluator, EzShape, MathShape, Shape}, types::Grad, }; diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 52147513..93ed22ff 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -9,7 +9,8 @@ use super::{ }; use crate::{ context::Context, - eval::{EzShape, MathShape, Shape, Tape, TracingEvaluator}, + eval::Tape, + shape::{EzShape, MathShape, Shape, TracingEvaluator}, types::Interval, vm::Choice, }; diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index 737036c1..18718b28 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -5,7 +5,7 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - eval::{EzShape, MathShape, Shape, TracingEvaluator}, + shape::{EzShape, MathShape, Shape, TracingEvaluator}, vm::Choice, }; diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 1f1d8f57..8b0ad6f8 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -3,7 +3,7 @@ //! ``` //! use fidget::{ //! context::Context, -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, //! vm::VmShape //! }; //! let mut ctx = Context::new(); diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/shape/bulk.rs similarity index 100% rename from fidget/src/core/eval/bulk.rs rename to fidget/src/core/shape/bulk.rs diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index eb602598..802fc554 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -1,6 +1,263 @@ -//! Shape-specific data types +//! Traits and data structures for shape evaluation +//! +//! There are a bunch of things in here, but the most important trait is +//! [`Shape`], followed by the evaluator traits ([`BulkEvaluator`] and +//! [`TracingEvaluator`]). +//! +//! ```rust +//! use fidget::vm::VmShape; +//! use fidget::context::Context; +//! use fidget::shape::{TracingEvaluator, Shape, MathShape, EzShape}; +//! +//! let mut ctx = Context::new(); +//! let x = ctx.x(); +//! let shape = VmShape::new(&ctx, x)?; +//! +//! // Let's build a single point evaluator: +//! let mut eval = VmShape::new_point_eval(); +//! let tape = shape.ez_point_tape(); +//! let (value, _trace) = eval.eval(&tape, 0.25, 0.0, 0.0)?; +//! assert_eq!(value, 0.25); +//! # Ok::<(), fidget::Error>(()) +//! ``` +//! +//! Note that the traits here mirror the ones in ones in +//! [`fidget::eval`](crate::eval), but are specialized to operate on `x, y, z` +//! arguments (rather than taking arbitrary numbers of variables). It is +//! recommended to import the traits from either one or the other, to avoid +//! ambiguity. + +use crate::{ + context::Node, + eval::Trace, + types::{Grad, Interval}, + Context, Error, +}; + mod bounds; +mod bulk; +mod tracing; +mod transform; + +// Re-export a few things pub use bounds::Bounds; +pub use bulk::BulkEvaluator; +pub use tracing::TracingEvaluator; +pub use transform::TransformedShape; + +/// A shape represents an implicit surface +/// +/// It is mostly agnostic to _how_ that surface is represented; we simply +/// require that the shape can generate evaluators of various kinds. +/// +/// Shapes are shared between threads, so they should be cheap to clone. In +/// most cases, they're a thin wrapper around an `Arc<..>`. +pub trait Shape: Send + Sync + Clone { + /// Associated type traces collected during tracing evaluation + /// + /// This type must implement [`Eq`] so that traces can be compared; calling + /// [`Shape::simplify`] with traces that compare equal should produce an + /// identical result and may be cached. + type Trace: Clone + Eq + Send + Trace; + + /// Associated type for storage used by the shape itself + type Storage: Default + Send; + + /// Associated type for workspace used during shape simplification + type Workspace: Default + Send; + + /// Associated type for storage used by tapes + /// + /// For simplicity, we require that every tape use the same type for storage. + /// This could change in the future! + type TapeStorage: Default + Send; + + /// Associated type for single-point tracing evaluation + type PointEval: TracingEvaluator< + Data = f32, + Trace = Self::Trace, + TapeStorage = Self::TapeStorage, + > + Send + + Sync; + + /// Builds a new point evaluator + fn new_point_eval() -> Self::PointEval { + Self::PointEval::new() + } + + /// Associated type for single interval tracing evaluation + type IntervalEval: TracingEvaluator< + Data = Interval, + Trace = Self::Trace, + TapeStorage = Self::TapeStorage, + > + Send + + Sync; + + /// Builds a new interval evaluator + fn new_interval_eval() -> Self::IntervalEval { + Self::IntervalEval::new() + } + + /// Associated type for evaluating many points in one call + type FloatSliceEval: BulkEvaluator + + Send + + Sync; + + /// Builds a new float slice evaluator + fn new_float_slice_eval() -> Self::FloatSliceEval { + Self::FloatSliceEval::new() + } + + /// Associated type for evaluating many gradients in one call + type GradSliceEval: BulkEvaluator + + Send + + Sync; + + /// Builds a new gradient slice evaluator + fn new_grad_slice_eval() -> Self::GradSliceEval { + Self::GradSliceEval::new() + } + + /// Returns an evaluation tape for a point evaluator + fn point_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for an interval evaluator + fn interval_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn float_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn grad_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Computes a simplified tape using the given trace, and reusing storage + fn simplify( + &self, + trace: &Self::Trace, + storage: Self::Storage, + workspace: &mut Self::Workspace, + ) -> Result + where + Self: Sized; + + /// Attempt to reclaim storage from this shape + /// + /// This may fail, because shapes are `Clone` and are often implemented + /// using an `Arc` around a heavier data structure. + fn recycle(self) -> Option; + + /// Returns a size associated with this shape + /// + /// This is underspecified and only used for unit testing; for tape-based + /// shapes, it's typically the length of the tape, + fn size(&self) -> usize; + + /// Associated type returned when applying a transform + /// + /// This is normally [`TransformedShape`](TransformedShape), but if + /// `Self` is already `TransformedShape`, then the transform is stacked + /// (instead of creating a wrapped object). + type TransformedShape: Shape; + + /// Returns a shape with the given transform applied + fn apply_transform( + self, + mat: nalgebra::Matrix4, + ) -> ::TransformedShape; +} + +/// Extension trait for working with a shape without thinking much about memory +/// +/// All of the [`Shape`] functions that use significant amounts of memory +/// pedantically require you to pass in storage for reuse. This trait allows +/// you to ignore that, at the cost of performance; we require that all storage +/// types implement [`Default`], so these functions do the boilerplate for you. +/// +/// This trait is automatically implemented for every [`Shape`], but must be +/// imported separately as a speed-bump to using it everywhere. +pub trait EzShape: Shape { + /// Returns an evaluation tape for a point evaluator + fn ez_point_tape(&self) -> ::Tape; + + /// Returns an evaluation tape for an interval evaluator + fn ez_interval_tape( + &self, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn ez_float_slice_tape( + &self, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn ez_grad_slice_tape( + &self, + ) -> ::Tape; + + /// Computes a simplified tape using the given trace + fn ez_simplify(&self, trace: &Self::Trace) -> Result + where + Self: Sized; +} + +impl EzShape for S { + fn ez_point_tape(&self) -> ::Tape { + self.point_tape(Default::default()) + } + + fn ez_interval_tape( + &self, + ) -> ::Tape { + self.interval_tape(Default::default()) + } + + fn ez_float_slice_tape( + &self, + ) -> ::Tape { + self.float_slice_tape(Default::default()) + } + + fn ez_grad_slice_tape( + &self, + ) -> ::Tape { + self.grad_slice_tape(Default::default()) + } + + fn ez_simplify(&self, trace: &Self::Trace) -> Result { + let mut workspace = Default::default(); + self.simplify(trace, Default::default(), &mut workspace) + } +} + +/// A [`Shape`] which can be built from a math expression +pub trait MathShape { + /// Builds a new shape from the given context and node + fn new(ctx: &Context, node: Node) -> Result + where + Self: Sized; + + /// Helper function to build a shape from a [`Tree`](crate::context::Tree) + fn from_tree(t: &crate::context::Tree) -> Self + where + Self: Sized, + { + let mut ctx = Context::new(); + let node = ctx.import(t); + Self::new(&ctx, node).unwrap() + } +} /// Hints for how to render this particular type pub trait RenderHints { diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/shape/tracing.rs similarity index 96% rename from fidget/src/core/eval/tracing.rs rename to fidget/src/core/shape/tracing.rs index a18116e9..c3f7e041 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/shape/tracing.rs @@ -12,8 +12,8 @@ use crate::{eval::Tape, Error}; /// Evaluator for single values which simultaneously captures an execution trace /// -/// The trace can later be used to simplify the [`Shape`](crate::eval::Shape) -/// using [`Shape::simplify`](crate::eval::Shape::simplify). +/// The trace can later be used to simplify the [`Shape`](crate::shape::Shape) +/// using [`Shape::simplify`](crate::shape::Shape::simplify). pub trait TracingEvaluator: Default { /// Data type used during evaluation type Data: From + Copy + Clone; diff --git a/fidget/src/core/eval/transform.rs b/fidget/src/core/shape/transform.rs similarity index 98% rename from fidget/src/core/eval/transform.rs rename to fidget/src/core/shape/transform.rs index d5684b61..c9b13087 100644 --- a/fidget/src/core/eval/transform.rs +++ b/fidget/src/core/shape/transform.rs @@ -1,6 +1,7 @@ use crate::{ - eval::{BulkEvaluator, Grad, Interval, Shape, Tape, TracingEvaluator}, - shape::RenderHints, + eval::Tape, + shape::{BulkEvaluator, RenderHints, Shape, TracingEvaluator}, + types::{Grad, Interval}, Error, }; use nalgebra::{Matrix4, Point3, Vector3}; diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 78f54fb7..2a0102cc 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -2,11 +2,11 @@ use crate::{ compiler::RegOp, context::Node, - eval::{ - BulkEvaluator, MathShape, Shape, Tape, Trace, TracingEvaluator, - TransformedShape, - }, + eval::{Tape, Trace}, shape::RenderHints, + shape::{ + BulkEvaluator, MathShape, Shape, TracingEvaluator, TransformedShape, + }, types::{Grad, Interval}, Context, Error, }; diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index c183b730..df744438 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -6,7 +6,7 @@ //! ``` //! use fidget::{ //! context::Tree, -//! eval::{TracingEvaluator, Shape, MathShape, EzShape}, +//! shape::{TracingEvaluator, Shape, MathShape, EzShape}, //! jit::JitShape //! }; //! @@ -26,12 +26,12 @@ use crate::{ compiler::RegOp, context::{Context, Node}, - eval::{ - BulkEvaluator, MathShape, Shape, Tape, TracingEvaluator, - TransformedShape, - }, + eval::Tape, jit::mmap::Mmap, shape::RenderHints, + shape::{ + BulkEvaluator, MathShape, Shape, TracingEvaluator, TransformedShape, + }, types::{Grad, Interval}, vm::{Choice, GenericVmShape, VmData, VmTrace, VmWorkspace}, Error, diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index 1ecba9ca..6cf0374b 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -68,7 +68,7 @@ //! //! Evaluation is deliberately agnostic to the specific details of how we go //! from position to results. This abstraction is represented by the -//! [`Shape` trait](crate::eval::Shape), which defines how to make both +//! [`Shape` trait](crate::shape::Shape), which defines how to make both //! **evaluators** and **tapes**. //! //! An **evaluator** is an object which performs evaluation of some kind (point, @@ -85,7 +85,7 @@ //! - [`fidget::jit::JitShape`](crate::jit::JitShape) performs fast evaluation //! by compiling shapes down to native code. //! -//! The [`eval::Shape`](crate::eval::Shape) trait requires four different kinds +//! The [`Shape`](crate::shape::Shape) trait requires four different kinds //! of evaluation: //! //! - Single-point evaluation @@ -107,7 +107,7 @@ //! ``` //! use fidget::{ //! context::Tree, -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, //! vm::VmShape //! }; //! @@ -136,7 +136,7 @@ //! ``` //! use fidget::{ //! context::Tree, -//! eval::{TracingEvaluator, Shape, MathShape, EzShape}, +//! shape::{TracingEvaluator, Shape, MathShape, EzShape}, //! vm::VmShape //! }; //! @@ -159,13 +159,13 @@ //! tape from `min(x, y) → x`. //! //! Interval evaluation is a kind of -//! [tracing evaluation](crate::eval::TracingEvaluator), which returns a tuple +//! [tracing evaluation](crate::shape::TracingEvaluator), which returns a tuple //! of `(value, trace)`. The trace can be used to simplify the original shape: //! //! ``` //! # use fidget::{ //! # context::Tree, -//! # eval::{TracingEvaluator, Shape, MathShape, EzShape}, +//! # shape::{TracingEvaluator, Shape, MathShape, EzShape}, //! # vm::VmShape //! # }; //! # let tree = Tree::x().min(Tree::y()); diff --git a/fidget/src/mesh/mod.rs b/fidget/src/mesh/mod.rs index 39d4125f..7e48dff6 100644 --- a/fidget/src/mesh/mod.rs +++ b/fidget/src/mesh/mod.rs @@ -3,7 +3,7 @@ //! This module implements //! [Manifold Dual Contouring](https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf), //! to generate a triangle mesh from an implicit surface (or anything -//! implementing [`Shape`](crate::eval::Shape)). +//! implementing [`Shape`](crate::shape::Shape)). //! //! The resulting meshes should be //! - Manifold diff --git a/fidget/src/mesh/mt/octree.rs b/fidget/src/mesh/mt/octree.rs index e5e1631f..8b500fe9 100644 --- a/fidget/src/mesh/mt/octree.rs +++ b/fidget/src/mesh/mt/octree.rs @@ -1,7 +1,6 @@ //! Multithreaded octree construction use super::pool::{QueuePool, ThreadContext, ThreadPool}; use crate::{ - eval::Shape, mesh::{ cell::{Cell, CellData, CellIndex}, octree::{BranchResult, CellResult, EvalGroup, OctreeBuilder}, @@ -9,6 +8,7 @@ use crate::{ Octree, Settings, }, shape::RenderHints, + shape::Shape, }; use std::sync::{mpsc::TryRecvError, Arc}; diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 0aca834e..714b9c21 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -11,8 +11,8 @@ use super::{ Mesh, Settings, }; use crate::{ - eval::{BulkEvaluator, Shape, Tape, TracingEvaluator}, - shape::RenderHints, + eval::Tape, + shape::{BulkEvaluator, RenderHints, Shape, TracingEvaluator}, types::Grad, }; use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; @@ -1169,9 +1169,9 @@ mod test { use super::*; use crate::{ context::Tree, - eval::{EzShape, MathShape}, mesh::types::{Edge, X, Y, Z}, shape::Bounds, + shape::{EzShape, MathShape}, vm::VmShape, }; use nalgebra::Vector3; diff --git a/fidget/src/render/config.rs b/fidget/src/render/config.rs index 18ddd9fb..2a6f4e6f 100644 --- a/fidget/src/render/config.rs +++ b/fidget/src/render/config.rs @@ -1,4 +1,8 @@ -use crate::{eval::Shape, render::RenderMode, shape::Bounds, Error}; +use crate::{ + render::RenderMode, + shape::{Bounds, Shape}, + Error, +}; use nalgebra::{ allocator::Allocator, Const, DefaultAllocator, DimNameAdd, DimNameSub, DimNameSum, U1, diff --git a/fidget/src/render/mod.rs b/fidget/src/render/mod.rs index 1b6fe38d..f8fbe540 100644 --- a/fidget/src/render/mod.rs +++ b/fidget/src/render/mod.rs @@ -4,7 +4,10 @@ //! [`RenderConfig::run`](RenderConfig::run); you can also use the lower-level //! functions ([`render2d`](render2d()) and [`render3d`](render3d())) for manual //! control over the input tape. -use crate::eval::{BulkEvaluator, Shape, Tape, Trace, TracingEvaluator}; +use crate::{ + eval::{Tape, Trace}, + shape::{BulkEvaluator, Shape, TracingEvaluator}, +}; use std::sync::Arc; mod config; diff --git a/fidget/src/render/render2d.rs b/fidget/src/render/render2d.rs index 18050196..b046b40a 100644 --- a/fidget/src/render/render2d.rs +++ b/fidget/src/render/render2d.rs @@ -1,8 +1,8 @@ //! 2D bitmap rendering / rasterization use super::RenderHandle; use crate::{ - eval::{BulkEvaluator, Shape, TracingEvaluator}, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, + shape::{BulkEvaluator, Shape, TracingEvaluator}, types::Interval, }; use nalgebra::Point2; @@ -467,8 +467,7 @@ fn render_inner( mod test { use super::*; use crate::{ - eval::{MathShape, Shape}, - shape::Bounds, + shape::{Bounds, MathShape, Shape}, vm::{GenericVmShape, VmShape}, Context, }; diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index 0df2c78b..f7e83227 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -1,8 +1,8 @@ //! 3D bitmap rendering / rasterization use super::RenderHandle; use crate::{ - eval::{BulkEvaluator, Shape, TracingEvaluator}, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, + shape::{BulkEvaluator, Shape, TracingEvaluator}, types::{Grad, Interval}, }; @@ -448,7 +448,7 @@ pub fn render_inner( #[cfg(test)] mod test { use super::*; - use crate::{eval::MathShape, vm::VmShape, Context}; + use crate::{shape::MathShape, vm::VmShape, Context}; /// Make sure we don't crash if there's only a single tile #[test] diff --git a/fidget/src/rhai/mod.rs b/fidget/src/rhai/mod.rs index b1d41035..6ebe9103 100644 --- a/fidget/src/rhai/mod.rs +++ b/fidget/src/rhai/mod.rs @@ -7,7 +7,7 @@ //! //! ``` //! use fidget::{ -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, //! vm::VmShape, //! }; //! @@ -24,7 +24,7 @@ //! //! ``` //! use fidget::{ -//! eval::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, //! vm::VmShape, //! rhai::Engine //! }; diff --git a/viewer/src/main.rs b/viewer/src/main.rs index 95f0aaba..46789e3b 100644 --- a/viewer/src/main.rs +++ b/viewer/src/main.rs @@ -82,8 +82,8 @@ fn render_thread( wake: Sender<()>, ) -> Result<()> where - S: fidget::eval::Shape - + fidget::eval::MathShape + S: fidget::shape::Shape + + fidget::shape::MathShape + fidget::shape::RenderHints, { let mut config = None; @@ -150,7 +150,7 @@ where } } -fn render( +fn render( mode: &RenderMode, shape: S, image_size: usize, From fb893b5853746a622b3320757cd0c51ade26c492 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sun, 19 May 2024 09:33:31 -0400 Subject: [PATCH 02/12] Re-add stub traits for eval::{Function, BulkEvaluator, TracingEvaluator} --- fidget/src/core/eval/bulk.rs | 65 ++++++++++++++++ fidget/src/core/eval/mod.rs | 131 ++++++++++++++++++++++++++++++++ fidget/src/core/eval/tracing.rs | 59 ++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 fidget/src/core/eval/bulk.rs create mode 100644 fidget/src/core/eval/tracing.rs diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/eval/bulk.rs new file mode 100644 index 00000000..67c8e81b --- /dev/null +++ b/fidget/src/core/eval/bulk.rs @@ -0,0 +1,65 @@ +//! Evaluates many points in a single call +//! +//! Doing bulk evaluations helps limit to overhead of instruction dispatch, and +//! can take advantage of SIMD. +//! +//! It is unlikely that you'll want to use these traits or types directly; +//! they're implementation details to minimize code duplication. + +use crate::{eval::Tape, Error}; + +/// Trait for bulk evaluation returning the given type `T` +/// +/// It's uncommon to use this trait outside the library itself; it's an +/// abstraction to reduce code duplication, and is public because it's used as a +/// constraint on other public APIs. +pub trait BulkEvaluator: Default { + /// Data type used during evaluation + type Data: From + Copy + Clone; + + /// Instruction tape used during evaluation + /// + /// This may be a literal instruction tape (in the case of VM evaluation), + /// or a metaphorical instruction tape (e.g. a JIT function). + type Tape: Tape + Send + Sync; + + /// Associated type for tape storage + /// + /// This is a workaround for plumbing purposes + type TapeStorage; + + /// Evaluates many points using the given instruction tape + /// + /// Returns an error if the `x`, `y`, `z`, and `out` slices are of different + /// lengths. + fn eval( + &mut self, + tape: &Self::Tape, + vars: &[&[Self::Data]], + ) -> Result<&[Self::Data], Error>; + + /// Build a new empty evaluator + fn new() -> Self { + Self::default() + } + + /// Helper function to return an error if the inputs are invalid + fn check_arguments( + &self, + vars: &[&[Self::Data]], + var_count: usize, + ) -> Result<(), Error> { + if var_count != vars.len() { + Err(Error::BadVarSlice(vars.len(), var_count)) + } else { + let Some(n) = vars.first().map(|v| v.len()) else { + return Ok(()); + }; + if vars.iter().any(|v| v.len() == n) { + Ok(()) + } else { + Err(Error::MismatchedSlices) + } + } + } +} diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index d3d5b583..ff6376c1 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,8 +1,19 @@ //! Traits and data structures for function evaluation +use crate::{ + types::{Grad, Interval}, + Error, +}; #[cfg(any(test, feature = "eval-tests"))] pub mod test; +mod bulk; +mod tracing; + +// Reexport a few types +pub use bulk::BulkEvaluator; +pub use tracing::TracingEvaluator; + /// A tape represents something that can be evaluated by an evaluator /// /// The only property enforced on the trait is that we must have some way to @@ -34,3 +45,123 @@ impl Trace for Vec { self.copy_from_slice(other); } } + +/// A function represents something that can be evaluated +/// +/// It is mostly agnostic to _how_ that something is represented; we simply +/// require that it can generate evaluators of various kinds. +/// +/// Functions are shared between threads, so they should be cheap to clone. In +/// most cases, they're a thin wrapper around an `Arc<..>`. +pub trait Function: Send + Sync + Clone { + /// Associated type traces collected during tracing evaluation + /// + /// This type must implement [`Eq`] so that traces can be compared; calling + /// [`Function::simplify`] with traces that compare equal should produce an + /// identical result and may be cached. + type Trace: Clone + Eq + Send + Trace; + + /// Associated type for storage used by the shape itself + type Storage: Default + Send; + + /// Associated type for workspace used during shape simplification + type Workspace: Default + Send; + + /// Associated type for storage used by tapes + /// + /// For simplicity, we require that every tape use the same type for storage. + /// This could change in the future! + type TapeStorage: Default + Send; + + /// Associated type for single-point tracing evaluation + type PointEval: TracingEvaluator< + Data = f32, + Trace = Self::Trace, + TapeStorage = Self::TapeStorage, + > + Send + + Sync; + + /// Builds a new point evaluator + fn new_point_eval() -> Self::PointEval { + Self::PointEval::new() + } + + /// Associated type for single interval tracing evaluation + type IntervalEval: TracingEvaluator< + Data = Interval, + Trace = Self::Trace, + TapeStorage = Self::TapeStorage, + > + Send + + Sync; + + /// Builds a new interval evaluator + fn new_interval_eval() -> Self::IntervalEval { + Self::IntervalEval::new() + } + + /// Associated type for evaluating many points in one call + type FloatSliceEval: BulkEvaluator + + Send + + Sync; + + /// Builds a new float slice evaluator + fn new_float_slice_eval() -> Self::FloatSliceEval { + Self::FloatSliceEval::new() + } + + /// Associated type for evaluating many gradients in one call + type GradSliceEval: BulkEvaluator + + Send + + Sync; + + /// Builds a new gradient slice evaluator + fn new_grad_slice_eval() -> Self::GradSliceEval { + Self::GradSliceEval::new() + } + + /// Returns an evaluation tape for a point evaluator + fn point_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for an interval evaluator + fn interval_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn float_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Returns an evaluation tape for a float slice evaluator + fn grad_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape; + + /// Computes a simplified tape using the given trace, and reusing storage + fn simplify( + &self, + trace: &Self::Trace, + storage: Self::Storage, + workspace: &mut Self::Workspace, + ) -> Result + where + Self: Sized; + + /// Attempt to reclaim storage from this shape + /// + /// This may fail, because shapes are `Clone` and are often implemented + /// using an `Arc` around a heavier data structure. + fn recycle(self) -> Option; + + /// Returns a size associated with this shape + /// + /// This is underspecified and only used for unit testing; for tape-based + /// shapes, it's typically the length of the tape, + fn size(&self) -> usize; +} diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs new file mode 100644 index 00000000..12b78945 --- /dev/null +++ b/fidget/src/core/eval/tracing.rs @@ -0,0 +1,59 @@ +//! Capturing a trace of function evaluation for further optimization +//! +//! Tracing evaluators are run on a single data type and capture a trace of +//! execution, which is the [`Trace` associated type](TracingEvaluator::Trace). +//! +//! The resulting trace can be used to simplify the original function. +//! +//! It is unlikely that you'll want to use these traits or types directly; +//! they're implementation details to minimize code duplication. + +use crate::{eval::Tape, Error}; + +/// Evaluator for single values which simultaneously captures an execution trace +/// +/// The trace can later be used to simplify the [`Shape`](crate::eval::Shape) +/// using [`Shape::simplify`](crate::eval::Shape::simplify). +pub trait TracingEvaluator: Default { + /// Data type used during evaluation + type Data: From + Copy + Clone; + + /// Instruction tape used during evaluation + /// + /// This may be a literal instruction tape (in the case of VM evaluation), + /// or a metaphorical instruction tape (e.g. a JIT function). + type Tape: Tape + Send + Sync; + + /// Associated type for tape storage + /// + /// This is a workaround for plumbing purposes + type TapeStorage; + + /// Associated type for the trace captured during evaluation + type Trace; + + /// Evaluates the given tape at a particular position + fn eval>( + &mut self, + tape: &Self::Tape, + vars: &[F], + ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; + + /// Build a new empty evaluator + fn new() -> Self { + Self::default() + } + + /// Helper function to return an error if the inputs are invalid + fn check_arguments( + &self, + vars: &[Self::Data], + var_count: usize, + ) -> Result<(), Error> { + if vars.len() != var_count { + Err(Error::BadVarSlice(vars.len(), var_count)) + } else { + Ok(()) + } + } +} From d7c884c5524a2747ed72b8f9a54b7ec0773974c4 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sun, 19 May 2024 10:10:21 -0400 Subject: [PATCH 03/12] Add generic FunctionShape --- fidget/src/core/shape/mod.rs | 164 ++++++++++++++++++++++++++++++++++- 1 file changed, 163 insertions(+), 1 deletion(-) diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 802fc554..cf4fc74f 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -29,7 +29,7 @@ use crate::{ context::Node, - eval::Trace, + eval::{self, Trace}, types::{Grad, Interval}, Context, Error, }; @@ -277,3 +277,165 @@ pub trait RenderHints { true } } + +//////////////////////////////////////////////////////////////////////////////// + +/// Wrapper to convert a [`Function`](fidget::eval::Function) into a [`Shape`] +/// for evaluation. +#[derive(Clone)] +pub struct FunctionShape { + /// Wrapped function + f: F, + + /// Index of x, y, z axes within the function's variable list + axes: [usize; 3], +} + +impl Shape for FunctionShape { + type Trace = ::Trace; + type Storage = ::Storage; + type Workspace = ::Workspace; + type TapeStorage = ::TapeStorage; + + type PointEval = FunctionShapeTracingEval<::PointEval>; + type IntervalEval = + FunctionShapeTracingEval<::IntervalEval>; + type FloatSliceEval = + FunctionShapeBulkEval<::FloatSliceEval>; + type GradSliceEval = + FunctionShapeBulkEval<::GradSliceEval>; + + fn point_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape { + self.f.point_tape(storage) + } + + fn interval_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape { + self.f.interval_tape(storage) + } + + fn float_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape { + self.f.float_slice_tape(storage) + } + + fn grad_slice_tape( + &self, + storage: Self::TapeStorage, + ) -> ::Tape { + self.f.grad_slice_tape(storage) + } + + fn simplify( + &self, + trace: &Self::Trace, + storage: Self::Storage, + workspace: &mut Self::Workspace, + ) -> Result + where + Self: Sized, + { + let f = self.f.simplify(trace, storage, workspace)?; + Ok(Self { f, axes: self.axes }) + } + + fn recycle(self) -> Option { + self.f.recycle() + } + + fn size(&self) -> usize { + self.f.size() + } + + type TransformedShape = TransformedShape; + + fn apply_transform( + self, + mat: nalgebra::Matrix4, + ) -> ::TransformedShape { + TransformedShape::new(self, mat) + } + + // todo +} + +/// Wrapper struct to convert from [`eval::TracingEvaluator`] to +/// [`shape::TracingEvaluator`](TracingEvaluator) +#[derive(Default)] +pub struct FunctionShapeTracingEval { + eval: E, + + /// Index of x, y, z axes within the function's variable list + axes: [usize; 3], +} + +impl TracingEvaluator + for FunctionShapeTracingEval +{ + type Data = E::Data; + type Tape = E::Tape; + type TapeStorage = E::TapeStorage; + type Trace = E::Trace; + + fn eval>( + &mut self, + tape: &Self::Tape, + x: F, + y: F, + z: F, + ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { + let mut vars = [None, None, None]; + vars[self.axes[0]] = Some(x); + vars[self.axes[1]] = Some(y); + vars[self.axes[2]] = Some(z); + + // TODO make this error? Where do we maintain the `axes` invariants? + let vars = vars.map(Option::unwrap); + self.eval.eval(tape, &vars) + } + // todo +} + +/// Wrapper struct to convert from [`eval::BulkEvaluator`] to +/// [`shape::TracingEvaluator`](BulkEvaluator) +#[derive(Default)] +pub struct FunctionShapeBulkEval { + eval: E, + + /// Index of x, y, z axes within the function's variable list + axes: [usize; 3], +} + +impl BulkEvaluator for FunctionShapeBulkEval { + type Data = E::Data; + type Tape = E::Tape; + type TapeStorage = E::TapeStorage; + + fn new() -> Self { + Self::default() + } + + fn eval( + &mut self, + tape: &Self::Tape, + x: &[Self::Data], + y: &[Self::Data], + z: &[Self::Data], + ) -> Result<&[Self::Data], Error> { + let mut vars = [None, None, None]; + vars[self.axes[0]] = Some(x); + vars[self.axes[1]] = Some(y); + vars[self.axes[2]] = Some(z); + + // TODO make this error? Where do we maintain the `axes` invariants? + let vars = vars.map(Option::unwrap); + self.eval.eval(tape, &vars) + } +} From 7426d7dcd9e265882083c75ca5920e6678312ee8 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sun, 19 May 2024 15:47:45 -0400 Subject: [PATCH 04/12] Got things compiling (without jit), but many tests are failing --- demo/src/main.rs | 6 +- fidget/Cargo.toml | 2 +- fidget/benches/function_call.rs | 4 +- fidget/benches/mesh.rs | 10 +- fidget/benches/render.rs | 10 +- fidget/src/core/context/mod.rs | 5 + fidget/src/core/eval/mod.rs | 19 ++++ fidget/src/core/eval/test/float_slice.rs | 24 ++--- fidget/src/core/eval/test/grad_slice.rs | 46 ++++----- fidget/src/core/eval/test/interval.rs | 74 +++++++------- fidget/src/core/eval/test/point.rs | 44 ++++---- fidget/src/core/eval/tracing.rs | 4 +- fidget/src/core/shape/bulk.rs | 17 ---- fidget/src/core/shape/mod.rs | 88 +++++++++++----- fidget/src/core/shape/tracing.rs | 9 -- fidget/src/core/vm/mod.rs | 122 +++++++---------------- fidget/src/jit/mod.rs | 43 ++++---- fidget/src/mesh/octree.rs | 4 +- fidget/src/render/render2d.rs | 28 +++--- fidget/src/render/render3d.rs | 2 +- 20 files changed, 273 insertions(+), 288 deletions(-) diff --git a/demo/src/main.rs b/demo/src/main.rs index 6a89fc63..58bda40c 100644 --- a/demo/src/main.rs +++ b/demo/src/main.rs @@ -264,7 +264,7 @@ fn main() -> Result<()> { let now = Instant::now(); let args = Args::parse(); let mut file = std::fs::File::open(&args.input)?; - let (ctx, root) = Context::from_text(&mut file)?; + let (mut ctx, root) = Context::from_text(&mut file)?; info!("Loaded file in {:?}", now.elapsed()); match args.cmd { @@ -277,12 +277,12 @@ fn main() -> Result<()> { let buffer = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run2d(shape, &settings, brute, sdf) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run2d(shape, &settings, brute, sdf) } diff --git a/fidget/Cargo.toml b/fidget/Cargo.toml index 60e682d2..cf11ac40 100644 --- a/fidget/Cargo.toml +++ b/fidget/Cargo.toml @@ -39,7 +39,7 @@ windows = { version = "0.54.0", features = ["Win32_Foundation", "Win32_System_Me getrandom = { version = "0.2", features = ["js"] } [features] -default = ["jit", "rhai", "render", "mesh"] +default = ["rhai", "render", "mesh"] ## Enables fast evaluation via a JIT compiler. This is exposed in the ## [`fidget::jit`](crate::jit) module, and is supported on diff --git a/fidget/benches/function_call.rs b/fidget/benches/function_call.rs index 3138bf3f..3de2f12d 100644 --- a/fidget/benches/function_call.rs +++ b/fidget/benches/function_call.rs @@ -8,12 +8,12 @@ use fidget::{ pub fn run_bench( c: &mut Criterion, - ctx: Context, + mut ctx: Context, node: Node, test_name: &'static str, name: &'static str, ) { - let shape_vm = &S::new(&ctx, node).unwrap(); + let shape_vm = &S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape_vm.ez_float_slice_tape(); diff --git a/fidget/benches/mesh.rs b/fidget/benches/mesh.rs index 53a2590e..8c62dfdd 100644 --- a/fidget/benches/mesh.rs +++ b/fidget/benches/mesh.rs @@ -6,8 +6,9 @@ use fidget::shape::MathShape; const COLONNADE: &str = include_str!("../../models/colonnade.vm"); pub fn colonnade_octree_thread_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); @@ -36,8 +37,9 @@ pub fn colonnade_octree_thread_sweep(c: &mut Criterion) { } pub fn colonnade_mesh(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); let cfg = fidget::mesh::Settings { depth: 8, ..Default::default() diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index 0cf13a70..d8c00088 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -7,8 +7,9 @@ const PROSPERO: &str = include_str!("../../models/prospero.vm"); use fidget::shape::{MathShape, RenderHints}; pub fn prospero_size_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); @@ -52,8 +53,9 @@ pub fn prospero_size_sweep(c: &mut Criterion) { } pub fn prospero_thread_sweep(c: &mut Criterion) { - let (ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); - let shape_vm = &fidget::vm::VmShape::new(&ctx, root).unwrap(); + let (mut ctx, root) = + fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); + let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 8063b09f..d835ee95 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -166,6 +166,11 @@ impl Context { self.ops.insert(Op::Input(v)) } + /// Returns a 3-element array of `X`, `Y`, `Z` nodes + pub fn axes(&mut self) -> [Node; 3] { + [self.x(), self.y(), self.z()] + } + /// Returns a node representing the given constant value. /// ``` /// # let mut ctx = fidget::context::Context::new(); diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index ff6376c1..c095b5e2 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,5 +1,6 @@ //! Traits and data structures for function evaluation use crate::{ + context::{Context, Node, Tree}, types::{Grad, Interval}, Error, }; @@ -165,3 +166,21 @@ pub trait Function: Send + Sync + Clone { /// shapes, it's typically the length of the tape, fn size(&self) -> usize; } + +/// A [`Function`] which can be built from a math expression +pub trait MathFunction { + /// Builds a new shape from the given context and node + fn new(ctx: &Context, node: Node) -> Result + where + Self: Sized; + + /// Helper function to build a shape from a [`Tree`](crate::context::Tree) + fn from_tree(t: &Tree) -> Self + where + Self: Sized, + { + let mut ctx = Context::new(); + let node = ctx.import(t); + Self::new(&ctx, node).unwrap() + } +} diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index f20be90b..74890bd5 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -21,8 +21,8 @@ where let x = ctx.x(); let y = ctx.y(); - let shape_x = S::new(&ctx, x).unwrap(); - let shape_y = S::new(&ctx, y).unwrap(); + let shape_x = S::new(&mut ctx, x).unwrap(); + let shape_y = S::new(&mut ctx, y).unwrap(); // This is a fuzz test for icache issues let mut eval = S::new_float_slice_eval(); @@ -59,7 +59,7 @@ where let y = ctx.y(); let mut eval = S::new_float_slice_eval(); - let shape = S::new(&ctx, x).unwrap(); + let shape = S::new(&mut ctx, x).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -92,7 +92,7 @@ where assert_eq!(out, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); let mul = ctx.mul(y, 2.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = S::new(&mut ctx, mul).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -125,7 +125,7 @@ where let a = ctx.x(); let b = ctx.sin(a).unwrap(); - let shape = S::new(&ctx, b).unwrap(); + let shape = S::new(&mut ctx, b).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -137,7 +137,7 @@ where } pub fn test_f_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -147,7 +147,7 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -172,7 +172,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -204,7 +204,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -261,7 +261,7 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -308,7 +308,7 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); @@ -349,7 +349,7 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index a60bf8c7..ed2afe1b 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -41,7 +41,7 @@ where pub fn test_g_x() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = S::new(&ctx, x).unwrap(); + let shape = S::new(&mut ctx, x).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -53,7 +53,7 @@ where pub fn test_g_y() { let mut ctx = Context::new(); let y = ctx.y(); - let shape = S::new(&ctx, y).unwrap(); + let shape = S::new(&mut ctx, y).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -65,7 +65,7 @@ where pub fn test_g_z() { let mut ctx = Context::new(); let z = ctx.z(); - let shape = S::new(&ctx, z).unwrap(); + let shape = S::new(&mut ctx, z).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -78,7 +78,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.square(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -103,7 +103,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.abs(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -120,7 +120,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sqrt(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -137,7 +137,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]); @@ -148,7 +148,7 @@ where let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]); v[0].compare_eq(Grad::new(2f32.sin(), 0.0, 2.0 * 2f32.cos(), 0.0)); @@ -161,7 +161,7 @@ where let x = ctx.x(); let y = ctx.y(); let s = ctx.mul(x, y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -186,7 +186,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.div(x, 2.0).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -199,7 +199,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.recip(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -217,7 +217,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = S::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -237,7 +237,7 @@ where let z = ctx.z(); let min = ctx.min(x, y).unwrap(); let max = ctx.max(min, z).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = S::new(&mut ctx, max).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -259,7 +259,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = S::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -276,7 +276,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let m = ctx.not(x).unwrap(); - let shape = S::new(&ctx, m).unwrap(); + let shape = S::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -295,7 +295,7 @@ where let sum = ctx.add(x2, y2).unwrap(); let sqrt = ctx.sqrt(sum).unwrap(); let sub = ctx.sub(sqrt, 0.5).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = S::new(&mut ctx, sub).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -317,7 +317,7 @@ where } pub fn test_g_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD registers let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -327,7 +327,7 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = Self::eval_xyz(&tape, &x, &y, &z); @@ -352,7 +352,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); @@ -375,7 +375,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -528,7 +528,7 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match (i, j) { @@ -575,7 +575,7 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -616,7 +616,7 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 93ed22ff..cb178dc6 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -28,7 +28,7 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&ctx, x).unwrap(); + let shape = S::new(&mut ctx, x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!( @@ -40,7 +40,7 @@ where [1.0, 5.0].into() ); - let shape = S::new(&ctx, y).unwrap(); + let shape = S::new(&mut ctx, y).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!( @@ -58,7 +58,7 @@ where let x = ctx.x(); let abs_x = ctx.abs(x).unwrap(); - let shape = S::new(&ctx, abs_x).unwrap(); + let shape = S::new(&mut ctx, abs_x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); @@ -70,7 +70,7 @@ where let y = ctx.y(); let abs_y = ctx.abs(y).unwrap(); let sum = ctx.add(abs_x, abs_y).unwrap(); - let shape = S::new(&ctx, sum).unwrap(); + let shape = S::new(&mut ctx, sum).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!( @@ -93,7 +93,7 @@ where let v = ctx.add(x, 0.5).unwrap(); let out = ctx.abs(v).unwrap(); - let shape = S::new(&ctx, out).unwrap(); + let shape = S::new(&mut ctx, out).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); @@ -105,7 +105,7 @@ where let x = ctx.x(); let sqrt_x = ctx.sqrt(x).unwrap(); - let shape = S::new(&ctx, sqrt_x).unwrap(); + let shape = S::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); @@ -133,7 +133,7 @@ where let x = ctx.x(); let sqrt_x = ctx.square(x).unwrap(); - let shape = S::new(&ctx, sqrt_x).unwrap(); + let shape = S::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); @@ -154,7 +154,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); @@ -164,7 +164,7 @@ where let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); let s = ctx.add(x, s).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); assert_eq!(eval.eval_x(&tape, [0.0, 3.0]), [-1.0, 4.0].into()); @@ -175,7 +175,7 @@ where let x = ctx.x(); let neg_x = ctx.neg(x).unwrap(); - let shape = S::new(&ctx, neg_x).unwrap(); + let shape = S::new(&mut ctx, neg_x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-1.0, 0.0].into()); @@ -197,7 +197,7 @@ where let x = ctx.x(); let not_x = ctx.not(x).unwrap(); - let shape = S::new(&ctx, not_x).unwrap(); + let shape = S::new(&mut ctx, not_x).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [-5.0, 0.0]), [0.0, 1.0].into()); @@ -209,7 +209,7 @@ where let y = ctx.y(); let mul = ctx.mul(x, y).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = S::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!( @@ -250,14 +250,14 @@ where let mut ctx = Context::new(); let x = ctx.x(); let mul = ctx.mul(x, 2.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = S::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 2.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [2.0, 4.0].into()); let mul = ctx.mul(x, -3.0).unwrap(); - let shape = S::new(&ctx, mul).unwrap(); + let shape = S::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-3.0, 0.0].into()); @@ -270,7 +270,7 @@ where let y = ctx.y(); let sub = ctx.sub(x, y).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = S::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!( @@ -299,14 +299,14 @@ where let mut ctx = Context::new(); let x = ctx.x(); let sub = ctx.sub(x, 2.0).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = S::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-2.0, -1.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-1.0, 0.0].into()); let sub = ctx.sub(-3.0, x).unwrap(); - let shape = S::new(&ctx, sub).unwrap(); + let shape = S::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-4.0, -3.0].into()); @@ -317,7 +317,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let recip = ctx.recip(x).unwrap(); - let shape = S::new(&ctx, recip).unwrap(); + let shape = S::new(&mut ctx, recip).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); @@ -342,7 +342,7 @@ where let x = ctx.x(); let y = ctx.y(); let div = ctx.div(x, y).unwrap(); - let shape = S::new(&ctx, div).unwrap(); + let shape = S::new(&mut ctx, div).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); @@ -389,7 +389,7 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, data) = @@ -427,7 +427,7 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, data) = @@ -452,7 +452,7 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = S::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, data) = @@ -486,7 +486,7 @@ where let z = ctx.z(); let max_xy_z = ctx.max(max, z).unwrap(); - let shape = S::new(&ctx, max_xy_z).unwrap(); + let shape = S::new(&mut ctx, max_xy_z).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, data) = eval @@ -517,7 +517,7 @@ where let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = S::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, trace) = eval @@ -554,7 +554,7 @@ where let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = S::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, trace) = eval @@ -587,7 +587,7 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (out, data) = @@ -606,7 +606,7 @@ where assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = S::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (out, data) = @@ -632,7 +632,7 @@ where let z = ctx.z(); let if_else = ctx.if_nonzero_else(x, y, z).unwrap(); - let shape = S::new(&ctx, if_else).unwrap(); + let shape = S::new(&mut ctx, if_else).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); @@ -681,7 +681,7 @@ where let x = ctx.x(); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = S::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (r, data) = eval @@ -709,7 +709,7 @@ where let y = ctx.y(); let c = ctx.compare(x, y).unwrap(); - let shape = S::new(&ctx, c).unwrap(); + let shape = S::new(&mut ctx, c).unwrap(); let tape = shape.ez_interval_tape(); let mut eval = S::new_interval_eval(); let (out, _trace) = eval.eval(&tape, -5.0, -6.0, 0.0).unwrap(); @@ -717,7 +717,7 @@ where } pub fn test_i_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -729,7 +729,7 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_interval_eval(); let tape = shape.ez_interval_tape(); @@ -742,7 +742,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_interval_eval(); let tape = shape.ez_interval_tape(); @@ -785,7 +785,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); for &a in args.iter() { @@ -891,7 +891,7 @@ where continue; } - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -941,7 +941,7 @@ where let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -983,7 +983,7 @@ where let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index 18718b28..be179f4b 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -20,7 +20,7 @@ where pub fn test_constant() { let mut ctx = Context::new(); let p = ctx.constant(1.5); - let shape = S::new(&ctx, p).unwrap(); + let shape = S::new(&mut ctx, p).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, 1.5); @@ -31,7 +31,7 @@ where let a = ctx.constant(1.5); let x = ctx.x(); let min = ctx.min(a, x).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); let (r, trace) = eval.eval(&tape, 2.0, 0.0, 0.0).unwrap(); @@ -54,7 +54,7 @@ where let radius = ctx.add(x_squared, y_squared).unwrap(); let circle = ctx.sub(radius, 1.0).unwrap(); - let shape = S::new(&ctx, circle).unwrap(); + let shape = S::new(&mut ctx, circle).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, -1.0); @@ -70,7 +70,7 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); @@ -103,7 +103,7 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&ctx, max).unwrap(); + let shape = S::new(&mut ctx, max).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); @@ -137,7 +137,7 @@ where let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = S::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); @@ -174,7 +174,7 @@ where let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&ctx, v).unwrap(); + let shape = S::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); @@ -210,7 +210,7 @@ where let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); @@ -230,7 +230,7 @@ where let y = ctx.y(); let s = ctx.add(s, y).unwrap(); - let shape = S::new(&ctx, s).unwrap(); + let shape = S::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); for (x, y) in [(0.0, 1.0), (1.0, 3.0), (2.0, 8.0)] { @@ -246,7 +246,7 @@ where let y = ctx.y(); let sum = ctx.add(x, 1.0).unwrap(); let min = ctx.min(sum, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); @@ -263,7 +263,7 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); @@ -280,7 +280,7 @@ where assert_eq!(eval.eval(&tape, 3.0, 2.0, 0.0).unwrap().0, 2.0); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&ctx, min).unwrap(); + let shape = S::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 0.5, 0.0, 0.0).unwrap().0, 0.5); @@ -302,13 +302,13 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&ctx, x).unwrap(); + let shape = S::new(&mut ctx, x).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); assert_eq!(eval.eval(&tape, 3.0, 4.0, 0.0).unwrap().0, 3.0); - let shape = S::new(&ctx, y).unwrap(); + let shape = S::new(&mut ctx, y).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); @@ -317,14 +317,14 @@ where let y2 = ctx.mul(y, 2.5).unwrap(); let sum = ctx.add(x, y2).unwrap(); - let shape = S::new(&ctx, sum).unwrap(); + let shape = S::new(&mut ctx, sum).unwrap(); let tape = shape.ez_point_tape(); let mut eval = S::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 6.0); } pub fn test_p_stress_n(depth: usize) { - let (ctx, node) = build_stress_fn(depth); + let (mut ctx, node) = build_stress_fn(depth); // Pick an input slice that's guaranteed to be > 1 SIMD register let args = (0..32).map(|i| i as f32 / 32f32).collect::>(); @@ -332,7 +332,7 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_point_eval(); let tape = shape.ez_point_tape(); @@ -360,7 +360,7 @@ where // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. use crate::vm::VmShape; - let shape = VmShape::new(&ctx, node).unwrap(); + let shape = VmShape::new(&mut ctx, node).unwrap(); let mut eval = VmShape::new_point_eval(); let tape = shape.ez_point_tape(); @@ -388,7 +388,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_point_eval(); let tape = shape.ez_point_tape(); @@ -444,7 +444,7 @@ where for (j, &v) in xyz.iter().enumerate() { let node = C::build(&mut ctx, u, v); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_point_eval(); let tape = shape.ez_point_tape(); @@ -489,7 +489,7 @@ where let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_point_eval(); let tape = shape.ez_point_tape(); @@ -526,7 +526,7 @@ where let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&ctx, node).unwrap(); + let shape = S::new(&mut ctx, node).unwrap(); let mut eval = S::new_point_eval(); let tape = shape.ez_point_tape(); diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index 12b78945..6cadf6c0 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -33,10 +33,10 @@ pub trait TracingEvaluator: Default { type Trace; /// Evaluates the given tape at a particular position - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - vars: &[F], + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; /// Build a new empty evaluator diff --git a/fidget/src/core/shape/bulk.rs b/fidget/src/core/shape/bulk.rs index ae6bc083..7390c31d 100644 --- a/fidget/src/core/shape/bulk.rs +++ b/fidget/src/core/shape/bulk.rs @@ -44,21 +44,4 @@ pub trait BulkEvaluator: Default { fn new() -> Self { Self::default() } - - /// Helper function to return an error if the inputs are invalid - fn check_arguments( - &self, - xs: &[T], - ys: &[T], - zs: &[T], - var_count: usize, - ) -> Result<(), Error> { - if xs.len() != ys.len() || ys.len() != zs.len() { - Err(Error::MismatchedSlices) - } else if var_count > 3 { - Err(Error::BadVarSlice(3, var_count)) - } else { - Ok(()) - } - } } diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index cf4fc74f..8ddaa47b 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -28,10 +28,10 @@ //! ambiguity. use crate::{ - context::Node, + context::{Context, Node}, eval::{self, Trace}, types::{Grad, Interval}, - Context, Error, + Error, }; mod bounds; @@ -241,24 +241,6 @@ impl EzShape for S { } } -/// A [`Shape`] which can be built from a math expression -pub trait MathShape { - /// Builds a new shape from the given context and node - fn new(ctx: &Context, node: Node) -> Result - where - Self: Sized; - - /// Helper function to build a shape from a [`Tree`](crate::context::Tree) - fn from_tree(t: &crate::context::Tree) -> Self - where - Self: Sized, - { - let mut ctx = Context::new(); - let node = ctx.import(t); - Self::new(&ctx, node).unwrap() - } -} - /// Hints for how to render this particular type pub trait RenderHints { /// Recommended tile sizes for 3D rendering @@ -278,6 +260,39 @@ pub trait RenderHints { } } +/// A [`Shape`] which can be built from a math expression +pub trait MathShape { + /// Builds a new shape from the given node with default (X, Y, Z) axes + fn new(ctx: &mut Context, node: Node) -> Result + where + Self: Sized, + { + let axes = ctx.axes(); + Self::new_with_axes(ctx, node, axes) + } + + /// Builds a new shape from the given context, node, and axes + fn new_with_axes( + ctx: &Context, + node: Node, + axes: [Node; 3], + ) -> Result + where + Self: Sized; + + /// Helper function to build a shape from a [`Tree`](crate::context::Tree) + /// + /// This function uses the default (X, Y, Z) axes + fn from_tree(t: &crate::context::Tree) -> Self + where + Self: Sized, + { + let mut ctx = Context::new(); + let node = ctx.import(t); + Self::new(&mut ctx, node).unwrap() + } +} + //////////////////////////////////////////////////////////////////////////////// /// Wrapper to convert a [`Function`](fidget::eval::Function) into a [`Shape`] @@ -366,6 +381,31 @@ impl Shape for FunctionShape { // todo } +impl MathShape for FunctionShape { + fn new_with_axes( + ctx: &Context, + node: Node, + axes: [Node; 3], + ) -> Result { + let f = F::new(ctx, node)?; // TODO get a varmap here + Ok(Self { f, axes: [0, 1, 2] }) + } +} + +impl RenderHints for FunctionShape { + fn tile_sizes_3d() -> &'static [usize] { + F::tile_sizes_3d() + } + + fn tile_sizes_2d() -> &'static [usize] { + F::tile_sizes_2d() + } + + fn simplify_tree_during_meshing(d: usize) -> bool { + F::simplify_tree_during_meshing(d) + } +} + /// Wrapper struct to convert from [`eval::TracingEvaluator`] to /// [`shape::TracingEvaluator`](TracingEvaluator) #[derive(Default)] @@ -392,13 +432,13 @@ impl TracingEvaluator z: F, ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { let mut vars = [None, None, None]; - vars[self.axes[0]] = Some(x); - vars[self.axes[1]] = Some(y); - vars[self.axes[2]] = Some(z); + vars[self.axes[0]] = Some(x.into()); + vars[self.axes[1]] = Some(y.into()); + vars[self.axes[2]] = Some(z.into()); // TODO make this error? Where do we maintain the `axes` invariants? let vars = vars.map(Option::unwrap); - self.eval.eval(tape, &vars) + self.eval.eval(tape, vars.as_slice()) } // todo } diff --git a/fidget/src/core/shape/tracing.rs b/fidget/src/core/shape/tracing.rs index c3f7e041..b9689c08 100644 --- a/fidget/src/core/shape/tracing.rs +++ b/fidget/src/core/shape/tracing.rs @@ -46,15 +46,6 @@ pub trait TracingEvaluator: Default { Self::default() } - /// Helper function to return an error if the inputs are invalid - fn check_arguments(&self, var_count: usize) -> Result<(), Error> { - if var_count > 3 { - Err(Error::BadVarSlice(3, var_count)) - } else { - Ok(()) - } - } - #[cfg(test)] fn eval_x>( &mut self, diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 2a0102cc..7ae5f43b 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -2,15 +2,12 @@ use crate::{ compiler::RegOp, context::Node, + eval::{BulkEvaluator, Function, MathFunction, TracingEvaluator}, eval::{Tape, Trace}, - shape::RenderHints, - shape::{ - BulkEvaluator, MathShape, Shape, TracingEvaluator, TransformedShape, - }, + shape::{FunctionShape, RenderHints}, types::{Grad, Interval}, Context, Error, }; -use nalgebra::Matrix4; use std::sync::Arc; mod choice; @@ -28,10 +25,9 @@ pub use data::{VmData, VmWorkspace}; /// /// All of the associated [`Tape`] types simply clone the internal `Arc`; /// there's no separate planning required to generate a tape. -/// -pub type VmShape = GenericVmShape<{ u8::MAX as usize }>; +pub type VmShape = FunctionShape>; -impl Tape for GenericVmShape { +impl Tape for GenericVmFunction { type Storage = (); fn recycle(self) -> Self::Storage { // nothing to do here @@ -93,15 +89,15 @@ impl AsRef<[Choice]> for VmTrace { /// You are unlikely to use this directly; [`VmShape`] should be used for /// VM-based evaluation. #[derive(Clone)] -pub struct GenericVmShape(Arc>); +pub struct GenericVmFunction(Arc>); -impl From> for GenericVmShape { +impl From> for GenericVmFunction { fn from(d: VmData) -> Self { Self(d.into()) } } -impl GenericVmShape { +impl GenericVmFunction { pub(crate) fn simplify_inner( &self, choices: &[Choice], @@ -137,7 +133,7 @@ impl GenericVmShape { } } -impl Shape for GenericVmShape { +impl Function for GenericVmFunction { type FloatSliceEval = VmFloatSliceEval; type Storage = VmData; type Workspace = VmWorkspace; @@ -170,20 +166,15 @@ impl Shape for GenericVmShape { } fn recycle(self) -> Option { - GenericVmShape::recycle(self) + GenericVmFunction::recycle(self) } fn size(&self) -> usize { - GenericVmShape::size(self) - } - - type TransformedShape = TransformedShape; - fn apply_transform(self, mat: Matrix4) -> Self::TransformedShape { - TransformedShape::new(self, mat) + GenericVmFunction::size(self) } } -impl RenderHints for GenericVmShape { +impl RenderHints for GenericVmFunction { fn tile_sizes_3d() -> &'static [usize] { &[256, 128, 64, 32, 16, 8] } @@ -193,7 +184,7 @@ impl RenderHints for GenericVmShape { } } -impl MathShape for GenericVmShape { +impl MathFunction for GenericVmFunction { fn new(ctx: &Context, node: Node) -> Result { let d = VmData::new(ctx, node)?; Ok(Self(Arc::new(d))) @@ -257,22 +248,17 @@ impl + Clone> TracingVmEval { pub struct VmIntervalEval(TracingVmEval); impl TracingEvaluator for VmIntervalEval { type Data = Interval; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type Trace = VmTrace; type TapeStorage = (); - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[Interval], ) -> Result<(Interval, Option<&VmTrace>), Error> { - let x = x.into(); - let y = y.into(); - let z = z.into(); let tape = tape.0.as_ref(); - self.check_arguments(tape.var_count())?; + self.check_arguments(vars, tape.var_count())?; self.0.resize_slots(tape); let mut simplify = false; @@ -281,12 +267,7 @@ impl TracingEvaluator for VmIntervalEval { for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out] = match i { - 0 => x, - 1 => y, - 2 => z, - _ => panic!("Invalid input: {}", i), - } + v[out] = vars[i as usize]; } RegOp::NegReg(out, arg) => { v[out] = -v[arg]; @@ -496,22 +477,17 @@ impl TracingEvaluator for VmIntervalEval { pub struct VmPointEval(TracingVmEval); impl TracingEvaluator for VmPointEval { type Data = f32; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type Trace = VmTrace; type TapeStorage = (); - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + vars: &[f32], ) -> Result<(f32, Option<&VmTrace>), Error> { - let x = x.into(); - let y = y.into(); - let z = z.into(); let tape = tape.0.as_ref(); - self.check_arguments(tape.var_count())?; + self.check_arguments(vars, tape.var_count())?; self.0.resize_slots(tape); let mut choices = self.0.choices.as_mut_slice().iter_mut(); @@ -520,12 +496,7 @@ impl TracingEvaluator for VmPointEval { for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out] = match i { - 0 => x, - 1 => y, - 2 => z, - _ => panic!("Invalid input: {}", i), - } + v[out] = vars[i as usize]; } RegOp::NegReg(out, arg) => { v[out] = -v[arg]; @@ -593,7 +564,7 @@ impl TracingEvaluator for VmPointEval { v[out] = imm / v[arg]; } RegOp::AtanRegImm(out, arg, imm) => { - v[out] = v[arg].atan2(imm.into()); + v[out] = v[arg].atan2(imm); } RegOp::AtanImmReg(out, arg, imm) => { v[out] = imm.atan2(v[arg]); @@ -821,34 +792,24 @@ impl + Clone> BulkVmEval { pub struct VmFloatSliceEval(BulkVmEval); impl BulkEvaluator for VmFloatSliceEval { type Data = f32; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type TapeStorage = (); fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + vars: &[&[f32]], ) -> Result<&[f32], Error> { let tape = tape.0.as_ref(); - self.check_arguments(xs, ys, zs, tape.var_count())?; - self.0.resize_slots(tape, xs.len()); - assert_eq!(xs.len(), ys.len()); - assert_eq!(ys.len(), zs.len()); - - let size = xs.len(); + self.check_arguments(vars, tape.var_count())?; + let size = vars.first().map(|v| v.len()).unwrap_or(0); + self.0.resize_slots(tape, size); let mut v = SlotArray(&mut self.0.slots); for op in tape.iter_asm() { match op { RegOp::Input(out, i) => { - v[out][0..size].copy_from_slice(match i { - 0 => xs, - 1 => ys, - 2 => zs, - _ => panic!("Invalid input: {}", i), - }) + v[out][0..size].copy_from_slice(vars[i as usize]); } RegOp::NegReg(out, arg) => { for i in 0..size { @@ -1139,35 +1100,24 @@ impl BulkEvaluator for VmFloatSliceEval { pub struct VmGradSliceEval(BulkVmEval); impl BulkEvaluator for VmGradSliceEval { type Data = Grad; - type Tape = GenericVmShape; + type Tape = GenericVmFunction; type TapeStorage = (); fn eval( &mut self, tape: &Self::Tape, - xs: &[Grad], - ys: &[Grad], - zs: &[Grad], + vars: &[&[Grad]], ) -> Result<&[Grad], Error> { let tape = tape.0.as_ref(); - self.check_arguments(xs, ys, zs, tape.var_count())?; - self.0.resize_slots(tape, xs.len()); - assert_eq!(xs.len(), ys.len()); - assert_eq!(ys.len(), zs.len()); + self.check_arguments(vars, tape.var_count())?; + let size = vars.first().map(|v| v.len()).unwrap_or(0); + self.0.resize_slots(tape, size); - let size = xs.len(); let mut v = SlotArray(&mut self.0.slots); for op in tape.iter_asm() { match op { - RegOp::Input(out, j) => { - for i in 0..size { - v[out][i] = match j { - 0 => xs[i], - 1 => ys[i], - 2 => zs[i], - _ => panic!("Invalid input: {}", i), - } - } + RegOp::Input(out, i) => { + v[out][0..size].copy_from_slice(vars[i as usize]); } RegOp::NegReg(out, arg) => { for i in 0..size { diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index df744438..1e2668bf 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -26,14 +26,12 @@ use crate::{ compiler::RegOp, context::{Context, Node}, - eval::Tape, + eval::{MathFunction, Tape}, jit::mmap::Mmap, shape::RenderHints, - shape::{ - BulkEvaluator, MathShape, Shape, TracingEvaluator, TransformedShape, - }, + shape::{BulkEvaluator, Shape, TracingEvaluator, TransformedShape}, types::{Grad, Interval}, - vm::{Choice, GenericVmShape, VmData, VmTrace, VmWorkspace}, + vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace}, Error, }; use dynasmrt::{ @@ -829,7 +827,7 @@ fn build_asm_fn_with_storage( /// Shape for use with a JIT evaluator #[derive(Clone)] -pub struct JitShape(GenericVmShape); +pub struct JitShape(GenericVmFunction); impl JitShape { fn tracing_tape( @@ -989,16 +987,13 @@ unsafe impl Sync for JitTracingFn {} impl JitTracingEval { /// Evaluates a single point, capturing an evaluation trace - fn eval, F: Into>( + fn eval( &mut self, tape: &JitTracingFn, - x: F, - y: F, - z: F, + x: T, + y: T, + z: T, ) -> (T, Option<&VmTrace>) { - let x = x.into(); - let y = y.into(); - let z = z.into(); let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); assert!(tape.var_count <= 3); @@ -1031,12 +1026,12 @@ impl TracingEvaluator for JitIntervalEval { type Trace = VmTrace; type TapeStorage = Mmap; - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + x: Self::Data, + y: Self::Data, + z: Self::Data, ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { Ok(self.0.eval(tape, x, y, z)) } @@ -1051,12 +1046,12 @@ impl TracingEvaluator for JitPointEval { type Trace = VmTrace; type TapeStorage = Mmap; - fn eval>( + fn eval( &mut self, tape: &Self::Tape, - x: F, - y: F, - z: F, + x: Self::Data, + y: Self::Data, + z: Self::Data, ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { Ok(self.0.eval(tape, x, y, z)) } @@ -1190,7 +1185,6 @@ impl BulkEvaluator for JitFloatSliceEval { ys: &[f32], zs: &[f32], ) -> Result<&[Self::Data], Error> { - self.check_arguments(xs, ys, zs, tape.var_count)?; Ok(self.0.eval(tape, xs, ys, zs)) } } @@ -1210,14 +1204,13 @@ impl BulkEvaluator for JitGradSliceEval { ys: &[Self::Data], zs: &[Self::Data], ) -> Result<&[Self::Data], Error> { - self.check_arguments(xs, ys, zs, tape.var_count)?; Ok(self.0.eval(tape, xs, ys, zs)) } } -impl MathShape for JitShape { +impl MathFunction for JitShape { fn new(ctx: &Context, node: Node) -> Result { - GenericVmShape::new(ctx, node).map(JitShape) + GenericVmFunction::new(ctx, node).map(JitShape) } } diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 714b9c21..7dde8cb9 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -1585,9 +1585,9 @@ mod test { #[test] fn test_colonnade_manifold() { const COLONNADE: &str = include_str!("../../../models/colonnade.vm"); - let (ctx, root) = + let (mut ctx, root) = crate::Context::from_text(COLONNADE.as_bytes()).unwrap(); - let tape = VmShape::new(&ctx, root).unwrap(); + let tape = VmShape::new(&mut ctx, root).unwrap(); for threads in [1, 8] { let settings = Settings { depth: 5, diff --git a/fidget/src/render/render2d.rs b/fidget/src/render/render2d.rs index b046b40a..6ffb91c4 100644 --- a/fidget/src/render/render2d.rs +++ b/fidget/src/render/render2d.rs @@ -467,8 +467,8 @@ fn render_inner( mod test { use super::*; use crate::{ - shape::{Bounds, MathShape, Shape}, - vm::{GenericVmShape, VmShape}, + shape::{Bounds, FunctionShape, MathShape, Shape}, + vm::{GenericVmFunction, VmShape}, Context, }; @@ -513,8 +513,8 @@ mod test { } fn check_hi() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = S::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .................X.............. .................X.............. @@ -552,8 +552,8 @@ mod test { } fn check_hi_transformed() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = S::new(&mut ctx, root).unwrap(); let mut mat = nalgebra::Matrix4::::identity(); mat.prepend_translation_mut(&nalgebra::Vector3::new(0.5, 0.5, 0.0)); mat.prepend_scaling_mut(0.5); @@ -595,8 +595,8 @@ mod test { } fn check_hi_bounded() { - let (ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); + let shape = S::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .XXX............................ .XXX............................ @@ -641,8 +641,8 @@ mod test { } fn check_quarter() { - let (ctx, root) = Context::from_text(QUARTER.as_bytes()).unwrap(); - let shape = S::new(&ctx, root).unwrap(); + let (mut ctx, root) = Context::from_text(QUARTER.as_bytes()).unwrap(); + let shape = S::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " ................................ ................................ @@ -686,7 +686,7 @@ mod test { #[test] fn render_hi_vm3() { - check_hi::>(); + check_hi::>>(); } #[cfg(feature = "jit")] @@ -702,7 +702,7 @@ mod test { #[test] fn render_hi_transformed_vm3() { - check_hi_transformed::>(); + check_hi_transformed::>>(); } #[cfg(feature = "jit")] @@ -718,7 +718,7 @@ mod test { #[test] fn render_hi_bounded_vm3() { - check_hi_bounded::>(); + check_hi_bounded::>>(); } #[cfg(feature = "jit")] @@ -734,7 +734,7 @@ mod test { #[test] fn render_quarter_vm3() { - check_quarter::>(); + check_quarter::>>(); } #[cfg(feature = "jit")] diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index f7e83227..79aa488d 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -455,7 +455,7 @@ mod test { fn test_tile_queues() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = VmShape::new(&ctx, x).unwrap(); + let shape = VmShape::new(&mut ctx, x).unwrap(); let cfg = RenderConfig::<3> { image_size: 128, // very small! From b07560d95462a518cc6cdc19cfc663852725830d Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Tue, 21 May 2024 08:45:19 -0400 Subject: [PATCH 05/12] Add `enum Var` --- fidget/src/core/compiler/ssa_tape.rs | 9 +++-- fidget/src/core/context/mod.rs | 57 +++++++++++++++++++++------- fidget/src/core/eval/mod.rs | 5 ++- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 6f58c1b4..4fc8293f 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -1,7 +1,7 @@ //use crate::vm::{RegisterAllocator, Tape as VmTape}; use crate::{ compiler::SsaOp, - context::{BinaryOpcode, Node, Op, UnaryOpcode}, + context::{BinaryOpcode, Node, Op, UnaryOpcode, Var}, Context, Error, }; use serde::{Deserialize, Serialize}; @@ -92,10 +92,11 @@ impl SsaTape { let op = match op { Op::Input(..) => { let arg = match ctx.var_name(node).unwrap().unwrap() { - "X" => 0, - "Y" => 1, - "Z" => 2, + Var::X => 0, + Var::Y => 1, + Var::Z => 2, i => panic!("Unexpected input index: {i}"), + // TODO make this work for _n_ vars }; SsaOp::Input(i, arg) } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index d835ee95..04bbcf2b 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -42,7 +42,39 @@ define_index!(VarNode, "An index in the `Context::vars` map"); #[derive(Debug, Default)] pub struct Context { ops: IndexMap, - vars: IndexMap, + vars: IndexMap, +} + +/// A `Var` represents a value which can vary during evaluation +/// +/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully +/// customized values. +#[allow(missing_docs)] +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum Var { + X, + Y, + Z, + W, + T, + Static(&'static str), + Named(String), + Value(u64), +} + +impl std::fmt::Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Var::X => write!(f, "X"), + Var::Y => write!(f, "Y"), + Var::Z => write!(f, "Z"), + Var::W => write!(f, "W"), + Var::T => write!(f, "T"), + Var::Static(s) => write!(f, "{s}"), + Var::Named(s) => write!(f, "{s}"), + Var::Value(v) => write!(f, "v_{v}"), + } + } } impl Context { @@ -119,11 +151,11 @@ impl Context { } } - /// Looks up the variable name associated with the given node. + /// Looks up the [`Var`] associated with the given node. /// /// If the node is invalid for this tree, returns an error; if the node is /// not an `Op::Input`, returns `Ok(None)`. - pub fn var_name(&self, n: Node) -> Result, Error> { + pub fn var_name(&self, n: Node) -> Result, Error> { match self.get_op(n) { Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some), Some(_) => Ok(None), @@ -131,8 +163,8 @@ impl Context { } } - /// Looks up the variable name associated with the given `VarNode` - pub fn get_var_by_index(&self, n: VarNode) -> Result<&str, Error> { + /// Looks up the [`Var`] associated with the given [`VarNode`] + pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> { match self.vars.get_by_index(n) { Some(c) => Ok(c), None => Err(Error::BadVar), @@ -150,19 +182,19 @@ impl Context { /// assert_eq!(v, 1.0); /// ``` pub fn x(&mut self) -> Node { - let v = self.vars.insert(String::from("X")); + let v = self.vars.insert(Var::X); self.ops.insert(Op::Input(v)) } /// Constructs or finds a variable node named "Y" pub fn y(&mut self) -> Node { - let v = self.vars.insert(String::from("Y")); + let v = self.vars.insert(Var::Y); self.ops.insert(Op::Input(v)) } /// Constructs or finds a variable node named "Z" pub fn z(&mut self) -> Node { - let v = self.vars.insert(String::from("Z")); + let v = self.vars.insert(Var::Z); self.ops.insert(Op::Input(v)) } @@ -757,9 +789,8 @@ impl Context { y: f64, z: f64, ) -> Result { - let vars = [("X", x), ("Y", y), ("Z", z)] + let vars = [(Var::X, x), (Var::Y, y), (Var::Z, z)] .into_iter() - .map(|(a, b)| (a.to_string(), b)) .collect(); self.eval(root, &vars) } @@ -771,7 +802,7 @@ impl Context { pub fn eval( &self, root: Node, - vars: &BTreeMap, + vars: &BTreeMap, ) -> Result { let mut cache = vec![None; self.ops.len()].into(); self.eval_inner(root, vars, &mut cache) @@ -780,7 +811,7 @@ impl Context { fn eval_inner( &self, node: Node, - vars: &BTreeMap, + vars: &BTreeMap, cache: &mut IndexVec, Node>, ) -> Result { if node.0 >= cache.len() { @@ -965,7 +996,7 @@ impl Context { Op::Const(c) => write!(out, "{}", c).unwrap(), Op::Input(v) => { let v = self.vars.get_by_index(*v).unwrap(); - out += v; + out += &v.to_string(); } Op::Binary(op, ..) => match op { BinaryOpcode::Add => out += "add", diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index c095b5e2..e6e5654e 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,6 +1,6 @@ //! Traits and data structures for function evaluation use crate::{ - context::{Context, Node, Tree}, + context::{Context, Node, Tree, Var}, types::{Grad, Interval}, Error, }; @@ -167,6 +167,9 @@ pub trait Function: Send + Sync + Clone { fn size(&self) -> usize; } +/// Map from variable (from a particular [`Context`]) to index +pub type VarMap = std::collections::HashMap; + /// A [`Function`] which can be built from a math expression pub trait MathFunction { /// Builds a new shape from the given context and node From ebff1873ba0eed3f92564dd8be6a0a8d07aaba78 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Tue, 21 May 2024 09:21:10 -0400 Subject: [PATCH 06/12] Add VarMap and get --lib tests passing --- demo/src/main.rs | 8 +- fidget/Cargo.toml | 2 +- fidget/benches/mesh.rs | 2 +- fidget/benches/render.rs | 4 +- fidget/src/core/compiler/ssa_tape.rs | 29 ++--- fidget/src/core/context/mod.rs | 6 +- fidget/src/core/eval/bulk.rs | 4 +- fidget/src/core/eval/mod.rs | 24 +++-- fidget/src/core/eval/tracing.rs | 4 +- fidget/src/core/mod.rs | 11 +- fidget/src/core/shape/bulk.rs | 5 - fidget/src/core/shape/mod.rs | 135 ++++++++++++++++------- fidget/src/core/shape/tracing.rs | 5 - fidget/src/core/vm/data.rs | 15 +-- fidget/src/core/vm/mod.rs | 13 ++- fidget/src/error.rs | 4 + fidget/src/jit/mod.rs | 155 ++++++++++++++------------- fidget/src/lib.rs | 6 +- fidget/src/mesh/mod.rs | 2 +- fidget/src/render/render2d.rs | 4 +- fidget/src/render/render3d.rs | 6 +- wasm-demo/Cargo.lock | 2 +- wasm-demo/src/lib.rs | 12 ++- 23 files changed, 265 insertions(+), 193 deletions(-) diff --git a/demo/src/main.rs b/demo/src/main.rs index 58bda40c..c5503c8d 100644 --- a/demo/src/main.rs +++ b/demo/src/main.rs @@ -314,12 +314,12 @@ fn main() -> Result<()> { let buffer = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run3d(shape, &settings, isometric, color) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run3d(shape, &settings, isometric, color) } @@ -348,12 +348,12 @@ fn main() -> Result<()> { let mesh = match settings.eval { #[cfg(feature = "jit")] EvalMode::Jit => { - let shape = fidget::jit::JitShape::new(&ctx, root)?; + let shape = fidget::jit::JitShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run_mesh(shape, &settings) } EvalMode::Vm => { - let shape = fidget::vm::VmShape::new(&ctx, root)?; + let shape = fidget::vm::VmShape::new(&mut ctx, root)?; info!("Built shape in {:?}", start.elapsed()); run_mesh(shape, &settings) } diff --git a/fidget/Cargo.toml b/fidget/Cargo.toml index cf11ac40..60e682d2 100644 --- a/fidget/Cargo.toml +++ b/fidget/Cargo.toml @@ -39,7 +39,7 @@ windows = { version = "0.54.0", features = ["Win32_Foundation", "Win32_System_Me getrandom = { version = "0.2", features = ["js"] } [features] -default = ["rhai", "render", "mesh"] +default = ["jit", "rhai", "render", "mesh"] ## Enables fast evaluation via a JIT compiler. This is exposed in the ## [`fidget::jit`](crate::jit) module, and is supported on diff --git a/fidget/benches/mesh.rs b/fidget/benches/mesh.rs index 8c62dfdd..3d57c38e 100644 --- a/fidget/benches/mesh.rs +++ b/fidget/benches/mesh.rs @@ -10,7 +10,7 @@ pub fn colonnade_octree_thread_sweep(c: &mut Criterion) { fidget::Context::from_text(COLONNADE.as_bytes()).unwrap(); let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs threads (colonnade, octree) (depth 6)"); diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index d8c00088..dfec74f2 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -12,7 +12,7 @@ pub fn prospero_size_sweep(c: &mut Criterion) { let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs image size (prospero, 2d) (8 threads)"); @@ -58,7 +58,7 @@ pub fn prospero_thread_sweep(c: &mut Criterion) { let shape_vm = &fidget::vm::VmShape::new(&mut ctx, root).unwrap(); #[cfg(feature = "jit")] - let shape_jit = &fidget::jit::JitShape::new(&ctx, root).unwrap(); + let shape_jit = &fidget::jit::JitShape::new(&mut ctx, root).unwrap(); let mut group = c.benchmark_group("speed vs threads (prospero, 2d) (1024 x 1024)"); diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 4fc8293f..80e4fc55 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -1,7 +1,8 @@ //use crate::vm::{RegisterAllocator, Tape as VmTape}; use crate::{ compiler::SsaOp, - context::{BinaryOpcode, Node, Op, UnaryOpcode, Var}, + context::{BinaryOpcode, Node, Op, UnaryOpcode}, + eval::VarMap, Context, Error, }; use serde::{Deserialize, Serialize}; @@ -32,7 +33,7 @@ impl SsaTape { /// /// This should always succeed unless the `root` is from a different /// `Context`, in which case `Error::BadNode` will be returned. - pub fn new(ctx: &Context, root: Node) -> Result { + pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> { let mut mapping = HashMap::new(); let mut parent_count: HashMap = HashMap::new(); let mut slot_count = 0; @@ -46,6 +47,7 @@ impl SsaTape { // Accumulate parent counts and declare all nodes let mut seen = HashSet::new(); + let mut vars = HashMap::new(); let mut todo = vec![root]; while let Some(node) = todo.pop() { if !seen.insert(node) { @@ -57,6 +59,11 @@ impl SsaTape { mapping.insert(node, Slot::Immediate(c.0 as f32)) } _ => { + if let Op::Input(v) = op { + let next = vars.len(); + let v = ctx.get_var_by_index(*v).unwrap().clone(); + vars.entry(v).or_insert(next); + } let i = slot_count; slot_count += 1; mapping.insert(node, Slot::Reg(i)) @@ -91,14 +98,8 @@ impl SsaTape { }; let op = match op { Op::Input(..) => { - let arg = match ctx.var_name(node).unwrap().unwrap() { - Var::X => 0, - Var::Y => 1, - Var::Z => 2, - i => panic!("Unexpected input index: {i}"), - // TODO make this work for _n_ vars - }; - SsaOp::Input(i, arg) + let arg = vars[ctx.var_name(node).unwrap().unwrap()]; + SsaOp::Input(i, arg.try_into().unwrap()) } Op::Const(..) => { unreachable!("skipped above") @@ -233,7 +234,7 @@ impl SsaTape { tape.push(SsaOp::CopyImm(0, c)); } - Ok(SsaTape { tape, choice_count }) + Ok((SsaTape { tape, choice_count }, vars)) } /// Checks whether the tape is empty @@ -405,8 +406,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let tape = SsaTape::new(&ctx, c9).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); + assert_eq!(vs.len(), 2); } #[test] @@ -415,7 +417,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let tape = SsaTape::new(&ctx, x_squared).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); + assert_eq!(vs.len(), 1); } } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 04bbcf2b..b10e5424 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -1245,8 +1245,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let tape = VmData::<255>::new(&ctx, c9).unwrap(); + let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); + assert_eq!(vs.len(), 2); } #[test] @@ -1255,7 +1256,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let tape = VmData::<255>::new(&ctx, x_squared).unwrap(); + let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); + assert_eq!(vs.len(), 1); } } diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/eval/bulk.rs index 67c8e81b..ede24e96 100644 --- a/fidget/src/core/eval/bulk.rs +++ b/fidget/src/core/eval/bulk.rs @@ -49,7 +49,9 @@ pub trait BulkEvaluator: Default { vars: &[&[Self::Data]], var_count: usize, ) -> Result<(), Error> { - if var_count != vars.len() { + // It's fine if the caller has given us extra variables (e.g. due to + // tape simplification), but it must have given us enough. + if vars.len() < var_count { Err(Error::BadVarSlice(vars.len(), var_count)) } else { let Some(n) = vars.first().map(|v| v.len()) else { diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index e6e5654e..680d58f3 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -62,10 +62,10 @@ pub trait Function: Send + Sync + Clone { /// identical result and may be cached. type Trace: Clone + Eq + Send + Trace; - /// Associated type for storage used by the shape itself + /// Associated type for storage used by the function itself type Storage: Default + Send; - /// Associated type for workspace used during shape simplification + /// Associated type for workspace used during function simplification type Workspace: Default + Send; /// Associated type for storage used by tapes @@ -154,16 +154,16 @@ pub trait Function: Send + Sync + Clone { where Self: Sized; - /// Attempt to reclaim storage from this shape + /// Attempt to reclaim storage from this function /// - /// This may fail, because shapes are `Clone` and are often implemented + /// This may fail, because functions are `Clone` and are often implemented /// using an `Arc` around a heavier data structure. fn recycle(self) -> Option; - /// Returns a size associated with this shape + /// Returns a size associated with this function /// /// This is underspecified and only used for unit testing; for tape-based - /// shapes, it's typically the length of the tape, + /// functions, it's typically the length of the tape, fn size(&self) -> usize; } @@ -172,13 +172,17 @@ pub type VarMap = std::collections::HashMap; /// A [`Function`] which can be built from a math expression pub trait MathFunction { - /// Builds a new shape from the given context and node - fn new(ctx: &Context, node: Node) -> Result + /// Builds a new function from the given context and node + /// + /// Returns a tuple of the [`MathFunction`] and a [`VarMap`] representing a + /// mapping from variables (in the [`Context`]) to indices used during + /// evaluation. + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> where Self: Sized; - /// Helper function to build a shape from a [`Tree`](crate::context::Tree) - fn from_tree(t: &Tree) -> Self + /// Helper function to build a function from a [`Tree`] + fn from_tree(t: &Tree) -> (Self, VarMap) where Self: Sized, { diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index 6cadf6c0..68e54080 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -50,7 +50,9 @@ pub trait TracingEvaluator: Default { vars: &[Self::Data], var_count: usize, ) -> Result<(), Error> { - if vars.len() != var_count { + // It's fine if the caller has given us extra variables (e.g. due to + // tape simplification), but it must have given us enough. + if vars.len() < var_count { Err(Error::BadVarSlice(vars.len(), var_count)) } else { Ok(()) diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 8b0ad6f8..42247470 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -14,7 +14,7 @@ //! let radius = ctx.add(x_squared, y_squared)?; //! let circle = ctx.sub(radius, 1.0)?; //! -//! let shape = VmShape::new(&ctx, circle)?; +//! let shape = VmShape::new(&mut ctx, circle)?; //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! @@ -109,13 +109,8 @@ mod test { let v = ctx.add(x, y).unwrap(); assert_eq!( - ctx.eval( - v, - &[("X".to_string(), 1.0), ("Y".to_string(), 2.0)] - .into_iter() - .collect() - ) - .unwrap(), + ctx.eval(v, &[(Var::X, 1.0), (Var::Y, 2.0)].into_iter().collect()) + .unwrap(), 3.0 ); assert_eq!(ctx.eval_xyz(v, 2.0, 3.0, 0.0).unwrap(), 5.0); diff --git a/fidget/src/core/shape/bulk.rs b/fidget/src/core/shape/bulk.rs index 7390c31d..bf80c7a8 100644 --- a/fidget/src/core/shape/bulk.rs +++ b/fidget/src/core/shape/bulk.rs @@ -39,9 +39,4 @@ pub trait BulkEvaluator: Default { y: &[Self::Data], z: &[Self::Data], ) -> Result<&[Self::Data], Error>; - - /// Build a new empty evaluator - fn new() -> Self { - Self::default() - } } diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 8ddaa47b..e11ebe8f 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -11,7 +11,7 @@ //! //! let mut ctx = Context::new(); //! let x = ctx.x(); -//! let shape = VmShape::new(&ctx, x)?; +//! let shape = VmShape::new(&mut ctx, x)?; //! //! // Let's build a single point evaluator: //! let mut eval = VmShape::new_point_eval(); @@ -82,7 +82,7 @@ pub trait Shape: Send + Sync + Clone { /// Builds a new point evaluator fn new_point_eval() -> Self::PointEval { - Self::PointEval::new() + Self::PointEval::default() } /// Associated type for single interval tracing evaluation @@ -95,7 +95,7 @@ pub trait Shape: Send + Sync + Clone { /// Builds a new interval evaluator fn new_interval_eval() -> Self::IntervalEval { - Self::IntervalEval::new() + Self::IntervalEval::default() } /// Associated type for evaluating many points in one call @@ -105,7 +105,7 @@ pub trait Shape: Send + Sync + Clone { /// Builds a new float slice evaluator fn new_float_slice_eval() -> Self::FloatSliceEval { - Self::FloatSliceEval::new() + Self::FloatSliceEval::default() } /// Associated type for evaluating many gradients in one call @@ -115,7 +115,7 @@ pub trait Shape: Send + Sync + Clone { /// Builds a new gradient slice evaluator fn new_grad_slice_eval() -> Self::GradSliceEval { - Self::GradSliceEval::new() + Self::GradSliceEval::default() } /// Returns an evaluation tape for a point evaluator @@ -302,8 +302,8 @@ pub struct FunctionShape { /// Wrapped function f: F, - /// Index of x, y, z axes within the function's variable list - axes: [usize; 3], + /// Index of x, y, z axes within the function's variable list (if present) + axes: [Option; 3], } impl Shape for FunctionShape { @@ -324,28 +324,40 @@ impl Shape for FunctionShape { &self, storage: Self::TapeStorage, ) -> ::Tape { - self.f.point_tape(storage) + FunctionShapeTape { + tape: self.f.point_tape(storage), + axes: self.axes, + } } fn interval_tape( &self, storage: Self::TapeStorage, ) -> ::Tape { - self.f.interval_tape(storage) + FunctionShapeTape { + tape: self.f.interval_tape(storage), + axes: self.axes, + } } fn float_slice_tape( &self, storage: Self::TapeStorage, ) -> ::Tape { - self.f.float_slice_tape(storage) + FunctionShapeTape { + tape: self.f.float_slice_tape(storage), + axes: self.axes, + } } fn grad_slice_tape( &self, storage: Self::TapeStorage, ) -> ::Tape { - self.f.grad_slice_tape(storage) + FunctionShapeTape { + tape: self.f.grad_slice_tape(storage), + axes: self.axes, + } } fn simplify( @@ -387,8 +399,31 @@ impl MathShape for FunctionShape { node: Node, axes: [Node; 3], ) -> Result { - let f = F::new(ctx, node)?; // TODO get a varmap here - Ok(Self { f, axes: [0, 1, 2] }) + let (f, vs) = F::new(ctx, node)?; + let x = ctx.var_name(axes[0])?.ok_or(Error::NotAVar)?; + let y = ctx.var_name(axes[1])?.ok_or(Error::NotAVar)?; + let z = ctx.var_name(axes[2])?.ok_or(Error::NotAVar)?; + Ok(Self { + f, + axes: [x, y, z].map(|v| vs.get(v).cloned()), + }) + } +} + +impl FunctionShape { + /// Borrows the inner [`Function`](eval::Function) object + pub fn inner(&self) -> &F { + &self.f + } + + /// Borrows the inner axis mapping + pub fn axes(&self) -> &[Option; 3] { + &self.axes + } + + /// Raw constructor + pub fn new_raw(f: F, axes: [Option; 3]) -> Self { + Self { f, axes } } } @@ -406,21 +441,33 @@ impl RenderHints for FunctionShape { } } +/// Wrapper struct to bind a generic tape to particular X, Y, Z axes +pub struct FunctionShapeTape { + tape: T, + + /// Index of the X, Y, Z axes in the variables array + axes: [Option; 3], +} + +impl eval::Tape for FunctionShapeTape { + type Storage = ::Storage; + fn recycle(self) -> Self::Storage { + self.tape.recycle() + } +} + /// Wrapper struct to convert from [`eval::TracingEvaluator`] to /// [`shape::TracingEvaluator`](TracingEvaluator) #[derive(Default)] pub struct FunctionShapeTracingEval { eval: E, - - /// Index of x, y, z axes within the function's variable list - axes: [usize; 3], } impl TracingEvaluator for FunctionShapeTracingEval { type Data = E::Data; - type Tape = E::Tape; + type Tape = FunctionShapeTape; type TapeStorage = E::TapeStorage; type Trace = E::Trace; @@ -432,13 +479,18 @@ impl TracingEvaluator z: F, ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { let mut vars = [None, None, None]; - vars[self.axes[0]] = Some(x.into()); - vars[self.axes[1]] = Some(y.into()); - vars[self.axes[2]] = Some(z.into()); - - // TODO make this error? Where do we maintain the `axes` invariants? - let vars = vars.map(Option::unwrap); - self.eval.eval(tape, vars.as_slice()) + if let Some(a) = tape.axes[0] { + vars[a] = Some(x.into()); + } + if let Some(b) = tape.axes[1] { + vars[b] = Some(y.into()); + } + if let Some(c) = tape.axes[2] { + vars[c] = Some(z.into()); + } + let n = vars.iter().position(|v| Option::is_none(v)).unwrap_or(3); + let vars = vars.map(|v| v.unwrap_or(0f32.into())); + self.eval.eval(&tape.tape, &vars[..n]) } // todo } @@ -448,20 +500,13 @@ impl TracingEvaluator #[derive(Default)] pub struct FunctionShapeBulkEval { eval: E, - - /// Index of x, y, z axes within the function's variable list - axes: [usize; 3], } impl BulkEvaluator for FunctionShapeBulkEval { type Data = E::Data; - type Tape = E::Tape; + type Tape = FunctionShapeTape; type TapeStorage = E::TapeStorage; - fn new() -> Self { - Self::default() - } - fn eval( &mut self, tape: &Self::Tape, @@ -470,12 +515,24 @@ impl BulkEvaluator for FunctionShapeBulkEval { z: &[Self::Data], ) -> Result<&[Self::Data], Error> { let mut vars = [None, None, None]; - vars[self.axes[0]] = Some(x); - vars[self.axes[1]] = Some(y); - vars[self.axes[2]] = Some(z); - - // TODO make this error? Where do we maintain the `axes` invariants? - let vars = vars.map(Option::unwrap); - self.eval.eval(tape, &vars) + if let Some(a) = tape.axes[0] { + vars[a] = Some(x); + } + if let Some(b) = tape.axes[1] { + vars[b] = Some(y); + } + if let Some(c) = tape.axes[2] { + vars[c] = Some(z); + } + let n = vars.iter().position(|v| v.is_none()).unwrap_or(3); + let vars = if vars.iter().all(Option::is_some) { + vars.map(Option::unwrap) + } else if let Some(q) = vars.iter().find(|v| v.is_some()) { + vars.map(|v| if v.is_some() { v.unwrap() } else { q.unwrap() }) + } else { + [[].as_slice(); 3] + }; + + self.eval.eval(&tape.tape, &vars[..n]) } } diff --git a/fidget/src/core/shape/tracing.rs b/fidget/src/core/shape/tracing.rs index b9689c08..173494f7 100644 --- a/fidget/src/core/shape/tracing.rs +++ b/fidget/src/core/shape/tracing.rs @@ -41,11 +41,6 @@ pub trait TracingEvaluator: Default { z: F, ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; - /// Build a new empty evaluator - fn new() -> Self { - Self::default() - } - #[cfg(test)] fn eval_x>( &mut self, diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 421b8aba..d60375c0 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -2,6 +2,7 @@ use crate::{ compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape}, context::{Context, Node}, + eval::VarMap, vm::Choice, Error, }; @@ -38,19 +39,19 @@ use serde::{Deserialize, Serialize}; /// ``` /// use fidget::{ /// compiler::RegOp, -/// context::{Context, Tree}, +/// context::{Context, Tree, Var}, /// vm::VmData, /// }; /// /// let tree = Tree::x() + Tree::y(); /// let mut ctx = Context::new(); /// let sum = ctx.import(&tree); -/// let data = VmData::<255>::new(&ctx, sum)?; +/// let (data, vars) = VmData::<255>::new(&ctx, sum)?; /// assert_eq!(data.len(), 3); // X, Y, and (X + Y) /// /// let mut iter = data.iter_asm(); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, 0)); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, 1)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u8)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u8)); /// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1)); /// # Ok::<(), fidget::Error>(()) /// ``` @@ -66,10 +67,10 @@ pub struct VmData { impl VmData { /// Builds a new tape for the given node - pub fn new(context: &Context, node: Node) -> Result { - let ssa = SsaTape::new(context, node)?; + pub fn new(context: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (ssa, vs) = SsaTape::new(context, node)?; let asm = RegTape::new::(&ssa); - Ok(Self { ssa, asm }) + Ok((Self { ssa, asm }, vs)) } /// Returns the length of the internal VM tape diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 7ae5f43b..a9cc5e6c 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -2,8 +2,10 @@ use crate::{ compiler::RegOp, context::Node, - eval::{BulkEvaluator, Function, MathFunction, TracingEvaluator}, - eval::{Tape, Trace}, + eval::{ + BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, + VarMap, + }, shape::{FunctionShape, RenderHints}, types::{Grad, Interval}, Context, Error, @@ -185,9 +187,9 @@ impl RenderHints for GenericVmFunction { } impl MathFunction for GenericVmFunction { - fn new(ctx: &Context, node: Node) -> Result { - let d = VmData::new(ctx, node)?; - Ok(Self(Arc::new(d))) + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (d, vs) = VmData::new(ctx, node)?; + Ok((Self(Arc::new(d)), vs)) } } @@ -802,6 +804,7 @@ impl BulkEvaluator for VmFloatSliceEval { ) -> Result<&[f32], Error> { let tape = tape.0.as_ref(); self.check_arguments(vars, tape.var_count())?; + let size = vars.first().map(|v| v.len()).unwrap_or(0); self.0.resize_slots(tape, size); diff --git a/fidget/src/error.rs b/fidget/src/error.rs index 7603000d..953daa96 100644 --- a/fidget/src/error.rs +++ b/fidget/src/error.rs @@ -11,6 +11,10 @@ pub enum Error { #[error("variable is not present in this `Context`")] BadVar, + /// The given node does not have an associated variable + #[error("node does not have an associated variable")] + NotAVar, + /// `Context` is empty #[error("`Context` is empty")] EmptyContext, diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 1e2668bf..951b9fe7 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -1,13 +1,13 @@ //! Compilation down to native machine code //! -//! Users are unlikely to use anything in this module other than [`JitShape`], -//! which is a [`Shape`] that uses JIT evaluation. +//! Users are unlikely to use anything in this module other than [`JitFunction`], +//! which is a [`Function`] that uses JIT evaluation. //! //! ``` //! use fidget::{ //! context::Tree, //! shape::{TracingEvaluator, Shape, MathShape, EzShape}, -//! jit::JitShape +//! jit::JitShape, //! }; //! //! let tree = Tree::x() + Tree::y(); @@ -26,10 +26,11 @@ use crate::{ compiler::RegOp, context::{Context, Node}, - eval::{MathFunction, Tape}, + eval::{ + BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator, VarMap, + }, jit::mmap::Mmap, shape::RenderHints, - shape::{BulkEvaluator, Shape, TracingEvaluator, TransformedShape}, types::{Grad, Interval}, vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace}, Error, @@ -38,7 +39,6 @@ use dynasmrt::{ components::PatchLoc, dynasm, AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi, TargetKind, }; -use nalgebra::Matrix4; mod mmap; @@ -825,11 +825,11 @@ fn build_asm_fn_with_storage( // JIT execute mode is restored here when the _guard is dropped } -/// Shape for use with a JIT evaluator +/// Function for use with a JIT evaluator #[derive(Clone)] -pub struct JitShape(GenericVmFunction); +pub struct JitFunction(GenericVmFunction); -impl JitShape { +impl JitFunction { fn tracing_tape( &self, storage: Mmap, @@ -854,7 +854,7 @@ impl JitShape { } } -impl Shape for JitShape { +impl Function for JitFunction { type Trace = VmTrace; type Storage = VmData; type Workspace = VmWorkspace; @@ -890,7 +890,7 @@ impl Shape for JitShape { ) -> Result { self.0 .simplify_inner(trace.as_slice(), storage, workspace) - .map(JitShape) + .map(JitFunction) } fn recycle(self) -> Option { @@ -900,14 +900,9 @@ impl Shape for JitShape { fn size(&self) -> usize { self.0.size() } - - type TransformedShape = TransformedShape; - fn apply_transform(self, mat: Matrix4) -> Self::TransformedShape { - TransformedShape::new(self, mat) - } } -impl RenderHints for JitShape { +impl RenderHints for JitFunction { fn tile_sizes_3d() -> &'static [usize] { &[64, 16, 8] } @@ -952,7 +947,7 @@ macro_rules! jit_fn { /// Evaluator for a JIT-compiled tracing function /// /// Users are unlikely to use this directly, but it's public because it's an -/// associated type on [`JitShape`]. +/// associated type on [`JitFunction`]. #[derive(Default)] struct JitTracingEval { choices: VmTrace, @@ -990,15 +985,12 @@ impl JitTracingEval { fn eval( &mut self, tape: &JitTracingFn, - x: T, - y: T, - z: T, + vars: &[T], ) -> (T, Option<&VmTrace>) { let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); assert!(tape.var_count <= 3); self.choices.fill(Choice::Unknown); - let vars = [x, y, z]; let out = unsafe { (tape.fn_trace)( vars.as_ptr(), @@ -1029,11 +1021,9 @@ impl TracingEvaluator for JitIntervalEval { fn eval( &mut self, tape: &Self::Tape, - x: Self::Data, - y: Self::Data, - z: Self::Data, + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { - Ok(self.0.eval(tape, x, y, z)) + Ok(self.0.eval(tape, vars)) } } @@ -1049,11 +1039,9 @@ impl TracingEvaluator for JitPointEval { fn eval( &mut self, tape: &Self::Tape, - x: Self::Data, - y: Self::Data, - z: Self::Data, + vars: &[Self::Data], ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { - Ok(self.0.eval(tape, x, y, z)) + Ok(self.0.eval(tape, vars)) } } @@ -1080,15 +1068,38 @@ impl Tape for JitBulkFn { } } +/// Maximum SIMD width for any type, checked at runtime (alas) +/// +/// We can't use T::SIMD_SIZE directly here due to Rust limitations. Instead we +/// hard-code a maximum SIMD size along with an assertion that should be +/// optimized out; we can't use a constant assertion here due to the same +/// compiler limitations. +const MAX_SIMD_WIDTH: usize = 8; + /// Bulk evaluator for JIT functions struct JitBulkEval { + /// Array of pointers used when calling into the JIT function + ptrs: Vec<*const T>, + + /// Scratch array for evaluation of less-than-SIMD-size slices + scratch: Vec<[T; MAX_SIMD_WIDTH]>, + /// Output array that's written to during evaluation out: Vec, } +// SAFETY: the pointers in `JitBulkEval` are transient and only scoped to a +// single evaluation. +unsafe impl Sync for JitBulkEval {} +unsafe impl Send for JitBulkEval {} + impl Default for JitBulkEval { fn default() -> Self { - Self { out: vec![] } + Self { + out: vec![], + scratch: vec![], + ptrs: vec![], + } } } @@ -1099,15 +1110,9 @@ unsafe impl Sync for JitBulkFn {} impl + Copy + SimdSize> JitBulkEval { /// Evaluate multiple points - fn eval( - &mut self, - tape: &JitBulkFn, - xs: &[T], - ys: &[T], - zs: &[T], - ) -> &[T] { + fn eval(&mut self, tape: &JitBulkFn, vars: &[&[T]]) -> &[T] { assert!(tape.var_count <= 3); - let n = xs.len(); + let n = vars.first().map(|v| v.len()).unwrap_or(0); self.out.resize(n, f32::NAN.into()); self.out.fill(f32::NAN.into()); @@ -1115,51 +1120,51 @@ impl + Copy + SimdSize> JitBulkEval { // in which case the input slices can't be used as workspace (because // they are not valid for the entire range of values read in assembly) if n < T::SIMD_SIZE { - // We can't use T::SIMD_SIZE directly here due to Rust limitations. - // Instead we hard-code a maximum SIMD size along with an assertion - // that should be optimized out; we can't use a constant assertion - // here due to the same compiler limitations. - const MAX_SIMD_WIDTH: usize = 8; - let mut x = [T::from(0.0); MAX_SIMD_WIDTH]; - let mut y = [T::from(0.0); MAX_SIMD_WIDTH]; - let mut z = [T::from(0.0); MAX_SIMD_WIDTH]; assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH); - x[0..n].copy_from_slice(xs); - y[0..n].copy_from_slice(ys); - z[0..n].copy_from_slice(zs); + // TODO reuse allocation here allocation here? + self.scratch.resize(n, [T::from(0.0); MAX_SIMD_WIDTH]); + for (v, t) in vars.iter().zip(self.scratch.iter_mut()) { + t[0..n].copy_from_slice(v); + } + + self.ptrs.clear(); + self.ptrs.extend(self.scratch.iter().map(|t| t.as_ptr())); - let mut tmp = [f32::NAN.into(); MAX_SIMD_WIDTH]; - let vars = [x.as_ptr(), y.as_ptr(), z.as_ptr()]; + let mut out = [f32::NAN.into(); MAX_SIMD_WIDTH]; unsafe { (tape.fn_bulk)( - vars.as_ptr(), - tmp.as_mut_ptr(), + self.ptrs.as_ptr(), + out.as_mut_ptr(), T::SIMD_SIZE as u64, ); } - self.out.copy_from_slice(&tmp[0..n]); + self.out.copy_from_slice(&out[0..n]); } else { // Our vectorized function only accepts sets of a particular width, // so we'll find the biggest multiple, then do an extra operation to // process any remainders. let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE; // Round down - let vars = [xs.as_ptr(), ys.as_ptr(), zs.as_ptr()]; + self.ptrs.clear(); + self.ptrs.extend(vars.iter().map(|v| v.as_ptr())); unsafe { - (tape.fn_bulk)(vars.as_ptr(), self.out.as_mut_ptr(), m as u64); + (tape.fn_bulk)( + self.ptrs.as_ptr(), + self.out.as_mut_ptr(), + m as u64, + ); } // If we weren't given an even multiple of vector width, then we'll // handle the remaining items by simply evaluating the *last* full // vector in the array again. if n != m { + self.ptrs.clear(); unsafe { - let vars = [ - xs.as_ptr().add(n - T::SIMD_SIZE), - ys.as_ptr().add(n - T::SIMD_SIZE), - zs.as_ptr().add(n - T::SIMD_SIZE), - ]; + self.ptrs.extend( + vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)), + ); (tape.fn_bulk)( - vars.as_ptr(), + self.ptrs.as_ptr(), self.out.as_mut_ptr().add(n - T::SIMD_SIZE), T::SIMD_SIZE as u64, ); @@ -1181,11 +1186,10 @@ impl BulkEvaluator for JitFloatSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[f32], - ys: &[f32], - zs: &[f32], + vars: &[&[Self::Data]], ) -> Result<&[Self::Data], Error> { - Ok(self.0.eval(tape, xs, ys, zs)) + self.check_arguments(vars, tape.var_count)?; + Ok(self.0.eval(tape, vars)) } } @@ -1200,20 +1204,23 @@ impl BulkEvaluator for JitGradSliceEval { fn eval( &mut self, tape: &Self::Tape, - xs: &[Self::Data], - ys: &[Self::Data], - zs: &[Self::Data], + vars: &[&[Self::Data]], ) -> Result<&[Self::Data], Error> { - Ok(self.0.eval(tape, xs, ys, zs)) + self.check_arguments(vars, tape.var_count)?; + Ok(self.0.eval(tape, vars)) } } -impl MathFunction for JitShape { - fn new(ctx: &Context, node: Node) -> Result { - GenericVmFunction::new(ctx, node).map(JitShape) +impl MathFunction for JitFunction { + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + let (f, vars) = GenericVmFunction::new(ctx, node)?; + Ok((JitFunction(f), vars)) } } +/// A [`Shape`](crate::shape::Shape) which uses the JIT evaluator +pub type JitShape = crate::shape::FunctionShape; + //////////////////////////////////////////////////////////////////////////////// #[cfg(test)] diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index 6cf0374b..b3cecf2e 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -170,6 +170,7 @@ //! # }; //! # let tree = Tree::x().min(Tree::y()); //! # let shape = VmShape::from_tree(&tree); +//! assert_eq!(shape.size(), 3); // min, X, Y //! # let mut interval_eval = VmShape::new_interval_eval(); //! # let tape = shape.ez_interval_tape(); //! # let (out, trace) = interval_eval.eval( @@ -179,9 +180,8 @@ //! # [0.0, 0.0], // Z //! # )?; //! // (same code as above) -//! assert_eq!(tape.size(), 3); //! let new_shape = shape.ez_simplify(trace.unwrap())?; -//! assert_eq!(new_shape.ez_interval_tape().size(), 1); // just the 'X' term +//! assert_eq!(new_shape.size(), 1); // just the X term //! # Ok::<(), fidget::Error>(()) //! ``` //! @@ -197,7 +197,7 @@ //! ``` //! use fidget::{ //! context::{Tree, Context}, -//! eval::MathShape, +//! shape::MathShape, //! render::{BitRenderMode, RenderConfig}, //! vm::VmShape, //! }; diff --git a/fidget/src/mesh/mod.rs b/fidget/src/mesh/mod.rs index 7e48dff6..4d558427 100644 --- a/fidget/src/mesh/mod.rs +++ b/fidget/src/mesh/mod.rs @@ -19,7 +19,7 @@ //! //! ``` //! use fidget::{ -//! eval::MathShape, +//! shape::MathShape, //! mesh::{Octree, Settings}, //! vm::VmShape //! }; diff --git a/fidget/src/render/render2d.rs b/fidget/src/render/render2d.rs index 6ffb91c4..62161817 100644 --- a/fidget/src/render/render2d.rs +++ b/fidget/src/render/render2d.rs @@ -358,8 +358,8 @@ fn worker( scratch, image: vec![], config, - eval_float_slice: S::FloatSliceEval::new(), - eval_interval: S::IntervalEval::new(), + eval_float_slice: S::FloatSliceEval::default(), + eval_interval: S::IntervalEval::default(), tape_storage: vec![], shape_storage: vec![], workspace: Default::default(), diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index 79aa488d..8ed63b5e 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -298,9 +298,9 @@ fn worker( color: vec![], config, - eval_float_slice: S::FloatSliceEval::new(), - eval_interval: S::IntervalEval::new(), - eval_grad_slice: S::GradSliceEval::new(), + eval_float_slice: S::FloatSliceEval::default(), + eval_interval: S::IntervalEval::default(), + eval_grad_slice: S::GradSliceEval::default(), tape_storage: vec![], shape_storage: vec![], diff --git a/wasm-demo/Cargo.lock b/wasm-demo/Cargo.lock index ff34b275..5a55a4da 100644 --- a/wasm-demo/Cargo.lock +++ b/wasm-demo/Cargo.lock @@ -255,7 +255,7 @@ dependencies = [ [[package]] name = "fidget" -version = "0.2.7" +version = "0.2.8" dependencies = [ "arrayvec", "bimap", diff --git a/wasm-demo/src/lib.rs b/wasm-demo/src/lib.rs index 3ada219a..6fd1fde4 100644 --- a/wasm-demo/src/lib.rs +++ b/wasm-demo/src/lib.rs @@ -1,8 +1,8 @@ use fidget::{ context::{Context, Tree}, - eval::MathShape, render::{BitRenderMode, RenderConfig}, shape::Bounds, + shape::MathShape, vm::{VmData, VmShape}, Error, }; @@ -29,16 +29,18 @@ pub fn eval_script(s: &str) -> Result { pub fn serialize_into_tape(t: JsTree) -> Result, String> { let mut ctx = Context::new(); let root = ctx.import(&t.0); - let shape = VmShape::new(&ctx, root).map_err(|e| format!("{e}"))?; - bincode::serialize(shape.data()).map_err(|e| format!("{e}")) + let shape = VmShape::new(&mut ctx, root).map_err(|e| format!("{e}"))?; + let vm_data = shape.inner().data(); + let axes = shape.axes(); + bincode::serialize(&(vm_data, axes)).map_err(|e| format!("{e}")) } /// Deserialize a `bincode`-packed `VmData` into a `VmShape` #[wasm_bindgen] pub fn deserialize_tape(data: Vec) -> Result { - let d: VmData<255> = + let (d, axes): (VmData<255>, [Option; 3]) = bincode::deserialize(&data).map_err(|e| format!("{e}"))?; - Ok(JsVmShape(VmShape::from(d))) + Ok(JsVmShape(VmShape::new_raw(d.into(), axes))) } /// Renders a subregion of an image, for webworker-based multithreading From 02f8a5c385b40166a83a1d14ce54f32823e86a53 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Thu, 23 May 2024 08:03:30 -0400 Subject: [PATCH 07/12] Remove Shape trait --- fidget/benches/function_call.rs | 25 +- fidget/benches/mesh.rs | 1 - fidget/benches/render.rs | 2 - fidget/src/core/eval/test/float_slice.rs | 46 +-- fidget/src/core/eval/test/grad_slice.rs | 61 +-- fidget/src/core/eval/test/interval.rs | 160 ++++---- fidget/src/core/eval/test/point.rs | 121 +++--- fidget/src/core/shape/bulk.rs | 42 -- fidget/src/core/shape/mod.rs | 490 +++++++++-------------- fidget/src/core/shape/tracing.rs | 65 --- fidget/src/core/vm/mod.rs | 14 +- fidget/src/jit/mod.rs | 10 +- fidget/src/mesh/mt/octree.rs | 48 +-- fidget/src/mesh/octree.rs | 155 +++---- fidget/src/render/config.rs | 13 +- fidget/src/render/mod.rs | 49 ++- fidget/src/render/render2d.rs | 97 ++--- fidget/src/render/render3d.rs | 49 +-- 18 files changed, 604 insertions(+), 844 deletions(-) delete mode 100644 fidget/src/core/shape/bulk.rs delete mode 100644 fidget/src/core/shape/tracing.rs diff --git a/fidget/benches/function_call.rs b/fidget/benches/function_call.rs index 3de2f12d..34c0d7b6 100644 --- a/fidget/benches/function_call.rs +++ b/fidget/benches/function_call.rs @@ -3,19 +3,20 @@ use criterion::{ }; use fidget::{ context::{Context, Node}, - shape::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, }; -pub fn run_bench( +pub fn run_bench( c: &mut Criterion, mut ctx: Context, node: Node, test_name: &'static str, name: &'static str, ) { - let shape_vm = &S::new(&mut ctx, node).unwrap(); + let shape_vm = &Shape::::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape_vm.ez_float_slice_tape(); let mut group = c.benchmark_group(test_name); @@ -30,7 +31,7 @@ pub fn run_bench( } } -pub fn test_single_fn( +pub fn test_single_fn( c: &mut Criterion, name: &'static str, ) { @@ -38,10 +39,10 @@ pub fn test_single_fn( let x = ctx.x(); let f = ctx.sin(x).unwrap(); - run_bench::(c, ctx, f, "single function", name); + run_bench::(c, ctx, f, "single function", name); } -pub fn test_many_fn( +pub fn test_many_fn( c: &mut Criterion, name: &'static str, ) { @@ -56,19 +57,19 @@ pub fn test_many_fn( let out = ctx.add(f, g).unwrap(); let out = ctx.add(out, h).unwrap(); - run_bench::(c, ctx, out, "many functions", name); + run_bench::(c, ctx, out, "many functions", name); } pub fn test_single_fns(c: &mut Criterion) { - test_single_fn::(c, "vm"); + test_single_fn::(c, "vm"); #[cfg(feature = "jit")] - test_single_fn::(c, "jit"); + test_single_fn::(c, "jit"); } pub fn test_many_fns(c: &mut Criterion) { - test_many_fn::(c, "vm"); + test_many_fn::(c, "vm"); #[cfg(feature = "jit")] - test_many_fn::(c, "jit"); + test_many_fn::(c, "jit"); } criterion_group!(benches, test_single_fns, test_many_fns); diff --git a/fidget/benches/mesh.rs b/fidget/benches/mesh.rs index 3d57c38e..000aa22b 100644 --- a/fidget/benches/mesh.rs +++ b/fidget/benches/mesh.rs @@ -1,7 +1,6 @@ use criterion::{ black_box, criterion_group, criterion_main, BenchmarkId, Criterion, }; -use fidget::shape::MathShape; const COLONNADE: &str = include_str!("../../models/colonnade.vm"); diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index dfec74f2..c8a13d93 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -4,8 +4,6 @@ use criterion::{ const PROSPERO: &str = include_str!("../../models/prospero.vm"); -use fidget::shape::{MathShape, RenderHints}; - pub fn prospero_size_sweep(c: &mut Criterion) { let (mut ctx, root) = fidget::Context::from_text(PROSPERO.as_bytes()).unwrap(); diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index 74890bd5..cf9914fc 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -6,26 +6,24 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - shape::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestFloatSlice(std::marker::PhantomData<*const S>); +pub struct TestFloatSlice(std::marker::PhantomData<*const F>); -impl TestFloatSlice -where - S: Shape + MathShape, -{ +impl TestFloatSlice { pub fn test_give_take() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); - let shape_x = S::new(&mut ctx, x).unwrap(); - let shape_y = S::new(&mut ctx, y).unwrap(); + let shape_x = Shape::::new(&mut ctx, x).unwrap(); + let shape_y = Shape::::new(&mut ctx, y).unwrap(); // This is a fuzz test for icache issues - let mut eval = S::new_float_slice_eval(); + let mut eval = Shape::::new_float_slice_eval(); for _ in 0..10000 { let tape = shape_x.ez_float_slice_tape(); let out = eval @@ -58,8 +56,8 @@ where let x = ctx.x(); let y = ctx.y(); - let mut eval = S::new_float_slice_eval(); - let shape = S::new(&mut ctx, x).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -92,7 +90,7 @@ where assert_eq!(out, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); let mul = ctx.mul(y, 2.0).unwrap(); - let shape = S::new(&mut ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_float_slice_tape(); let out = eval .eval( @@ -125,8 +123,8 @@ where let a = ctx.x(); let b = ctx.sin(a).unwrap(); - let shape = S::new(&mut ctx, b).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, b).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let args = [0.0, 1.0, 2.0, std::f32::consts::PI / 2.0]; @@ -147,8 +145,8 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = eval.eval(&tape, &x, &y, &z).unwrap(); @@ -204,8 +202,8 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { @@ -261,8 +259,8 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match (i, j) { @@ -308,8 +306,8 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { @@ -349,8 +347,8 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_float_slice_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_float_slice_eval(); let tape = shape.ez_float_slice_tape(); let out = match i { diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index ed2afe1b..7e40b0c3 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -5,22 +5,23 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - shape::{BulkEvaluator, EzShape, MathShape, Shape}, + eval::{BulkEvaluator, Function, MathFunction}, + shape::{EzShape, Shape, ShapeTape}, types::Grad, + vm::VmFunction, }; /// Epsilon for gradient estimates const EPSILON: f64 = 1e-8; /// Helper struct to put constrains on our `Shape` object -pub struct TestGradSlice(std::marker::PhantomData<*const S>); +pub struct TestGradSlice(std::marker::PhantomData<*const F>); -impl TestGradSlice -where - S: Shape + MathShape, -{ +impl TestGradSlice { fn eval_xyz( - tape: &<::GradSliceEval as BulkEvaluator>::Tape, + tape: &ShapeTape< + <::GradSliceEval as BulkEvaluator>::Tape, + >, xs: &[f32], ys: &[f32], zs: &[f32], @@ -34,14 +35,14 @@ where let zs: Vec<_> = zs.iter().map(|z| Grad::new(*z, 0.0, 0.0, 1.0)).collect(); - let mut eval = S::new_grad_slice_eval(); + let mut eval = Shape::::new_grad_slice_eval(); eval.eval(tape, &xs, &ys, &zs).unwrap().to_owned() } pub fn test_g_x() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = S::new(&mut ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -53,7 +54,7 @@ where pub fn test_g_y() { let mut ctx = Context::new(); let y = ctx.y(); - let shape = S::new(&mut ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -65,7 +66,7 @@ where pub fn test_g_z() { let mut ctx = Context::new(); let z = ctx.z(); - let shape = S::new(&mut ctx, z).unwrap(); + let shape = Shape::::new(&mut ctx, z).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -78,7 +79,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.square(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -103,7 +104,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.abs(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -120,7 +121,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sqrt(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -137,7 +138,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]); @@ -148,7 +149,7 @@ where let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); let v = Self::eval_xyz(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]); v[0].compare_eq(Grad::new(2f32.sin(), 0.0, 2.0 * 2f32.cos(), 0.0)); @@ -161,7 +162,7 @@ where let x = ctx.x(); let y = ctx.y(); let s = ctx.mul(x, y).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -186,7 +187,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.div(x, 2.0).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -199,7 +200,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.recip(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -217,7 +218,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.min(x, y).unwrap(); - let shape = S::new(&mut ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -237,7 +238,7 @@ where let z = ctx.z(); let min = ctx.min(x, y).unwrap(); let max = ctx.max(min, z).unwrap(); - let shape = S::new(&mut ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -259,7 +260,7 @@ where let x = ctx.x(); let y = ctx.y(); let m = ctx.max(x, y).unwrap(); - let shape = S::new(&mut ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -276,7 +277,7 @@ where let mut ctx = Context::new(); let x = ctx.x(); let m = ctx.not(x).unwrap(); - let shape = S::new(&mut ctx, m).unwrap(); + let shape = Shape::::new(&mut ctx, m).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -295,7 +296,7 @@ where let sum = ctx.add(x2, y2).unwrap(); let sqrt = ctx.sqrt(sum).unwrap(); let sub = ctx.sub(sqrt, 0.5).unwrap(); - let shape = S::new(&mut ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_grad_slice_tape(); assert_eq!( @@ -327,7 +328,7 @@ where let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = Self::eval_xyz(&tape, &x, &y, &z); @@ -355,7 +356,7 @@ where let shape = VmShape::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); - let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); + let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); for (a, b) in out.iter().zip(cmp.iter()) { a.compare_eq(*b) } @@ -375,7 +376,7 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -528,7 +529,7 @@ where for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match (i, j) { @@ -575,7 +576,7 @@ where let c = ctx.constant(*rhs as f64); let node = C::build(&mut ctx, v, c); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { @@ -616,7 +617,7 @@ where let c = ctx.constant(*lhs as f64); let node = C::build(&mut ctx, c, v); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.ez_grad_slice_tape(); let out = match i { diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index cb178dc6..513b828c 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -9,28 +9,28 @@ use super::{ }; use crate::{ context::Context, - eval::Tape, - shape::{EzShape, MathShape, Shape, TracingEvaluator}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, types::Interval, vm::Choice, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestInterval(std::marker::PhantomData<*const S>); +pub struct TestInterval(std::marker::PhantomData<*const F>); -impl TestInterval +impl TestInterval where - for<'a> S: Shape + MathShape, - ::Trace: AsRef<[Choice]>, + for<'a> F: Function + MathFunction, + ::Trace: AsRef<[Choice]>, { pub fn test_interval() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&mut ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [2.0, 3.0]), [0.0, 1.0].into() @@ -40,9 +40,9 @@ where [1.0, 5.0].into() ); - let shape = S::new(&mut ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [2.0, 3.0]), [2.0, 3.0].into() @@ -58,9 +58,9 @@ where let x = ctx.x(); let abs_x = ctx.abs(x).unwrap(); - let shape = S::new(&mut ctx, abs_x).unwrap(); + let shape = Shape::::new(&mut ctx, abs_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 5.0]), [1.0, 5.0].into()); assert_eq!(eval.eval_x(&tape, [-2.0, 5.0]), [0.0, 5.0].into()); @@ -70,9 +70,9 @@ where let y = ctx.y(); let abs_y = ctx.abs(y).unwrap(); let sum = ctx.add(abs_x, abs_y).unwrap(); - let shape = S::new(&mut ctx, sum).unwrap(); + let shape = Shape::::new(&mut ctx, sum).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [0.0, 2.0].into() @@ -93,9 +93,9 @@ where let v = ctx.add(x, 0.5).unwrap(); let out = ctx.abs(v).unwrap(); - let shape = S::new(&mut ctx, out).unwrap(); + let shape = Shape::::new(&mut ctx, out).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [-1.0, 1.0]), [0.0, 1.5].into()); } @@ -105,9 +105,9 @@ where let x = ctx.x(); let sqrt_x = ctx.sqrt(x).unwrap(); - let shape = S::new(&mut ctx, sqrt_x).unwrap(); + let shape = Shape::::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [0.0, 2.0].into()); @@ -133,9 +133,9 @@ where let x = ctx.x(); let sqrt_x = ctx.square(x).unwrap(); - let shape = S::new(&mut ctx, sqrt_x).unwrap(); + let shape = Shape::::new(&mut ctx, sqrt_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [0.0, 16.0].into()); assert_eq!(eval.eval_x(&tape, [2.0, 4.0]), [4.0, 16.0].into()); @@ -154,17 +154,17 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-1.0, 1.0].into()); let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); let s = ctx.add(x, s).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_interval_tape(); assert_eq!(eval.eval_x(&tape, [0.0, 3.0]), [-1.0, 4.0].into()); @@ -175,9 +175,9 @@ where let x = ctx.x(); let neg_x = ctx.neg(x).unwrap(); - let shape = S::new(&mut ctx, neg_x).unwrap(); + let shape = Shape::::new(&mut ctx, neg_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-1.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [-4.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [2.0, 4.0]), [-4.0, -2.0].into()); @@ -197,9 +197,9 @@ where let x = ctx.x(); let not_x = ctx.not(x).unwrap(); - let shape = S::new(&mut ctx, not_x).unwrap(); + let shape = Shape::::new(&mut ctx, not_x).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [-5.0, 0.0]), [0.0, 1.0].into()); } @@ -209,9 +209,9 @@ where let y = ctx.y(); let mul = ctx.mul(x, y).unwrap(); - let shape = S::new(&mut ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [0.0, 1.0].into() @@ -250,16 +250,16 @@ where let mut ctx = Context::new(); let x = ctx.x(); let mul = ctx.mul(x, 2.0).unwrap(); - let shape = S::new(&mut ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 2.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [2.0, 4.0].into()); let mul = ctx.mul(x, -3.0).unwrap(); - let shape = S::new(&mut ctx, mul).unwrap(); + let shape = Shape::::new(&mut ctx, mul).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-3.0, 0.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-6.0, -3.0].into()); } @@ -270,9 +270,9 @@ where let y = ctx.y(); let sub = ctx.sub(x, y).unwrap(); - let shape = S::new(&mut ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!( eval.eval_xy(&tape, [0.0, 1.0], [0.0, 1.0]), [-1.0, 1.0].into() @@ -299,16 +299,16 @@ where let mut ctx = Context::new(); let x = ctx.x(); let sub = ctx.sub(x, 2.0).unwrap(); - let shape = S::new(&mut ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-2.0, -1.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-1.0, 0.0].into()); let sub = ctx.sub(-3.0, x).unwrap(); - let shape = S::new(&mut ctx, sub).unwrap(); + let shape = Shape::::new(&mut ctx, sub).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [-4.0, -3.0].into()); assert_eq!(eval.eval_x(&tape, [1.0, 2.0]), [-5.0, -4.0].into()); } @@ -317,9 +317,9 @@ where let mut ctx = Context::new(); let x = ctx.x(); let recip = ctx.recip(x).unwrap(); - let shape = S::new(&mut ctx, recip).unwrap(); + let shape = Shape::::new(&mut ctx, recip).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let nanan = eval.eval_x(&tape, [0.0, 1.0]); assert!(nanan.lower().is_nan()); @@ -342,9 +342,9 @@ where let x = ctx.x(); let y = ctx.y(); let div = ctx.div(x, y).unwrap(); - let shape = S::new(&mut ctx, div).unwrap(); + let shape = Shape::::new(&mut ctx, div).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let nanan = eval.eval_xy(&tape, [0.0, 1.0], [-1.0, 1.0]); assert!(nanan.lower().is_nan()); @@ -389,9 +389,9 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.5, 1.5], [0.0; 2]).unwrap(); assert_eq!(r, [0.0, 1.0].into()); @@ -427,9 +427,9 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(r, [0.0, 1.0].into()); @@ -452,9 +452,9 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&mut ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval.eval(&tape, [0.0, 1.0], [0.5, 1.5], [0.0; 2]).unwrap(); assert_eq!(r, [0.5, 1.5].into()); @@ -486,9 +486,9 @@ where let z = ctx.z(); let max_xy_z = ctx.max(max, z).unwrap(); - let shape = S::new(&mut ctx, max_xy_z).unwrap(); + let shape = Shape::::new(&mut ctx, max_xy_z).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval .eval(&tape, [2.0, 3.0], [0.0, 1.0], [4.0, 5.0]) .unwrap(); @@ -508,18 +508,15 @@ where assert_eq!(data.unwrap().as_ref(), &[Choice::Left, Choice::Left]); } - pub fn test_i_and() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_i_and() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&mut ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, trace) = eval .eval(&tape, [0.0, 0.0], [-1.0, 3.0], [0.0, 0.0]) .unwrap(); @@ -545,18 +542,15 @@ where assert!(trace.is_none()); // can't simplify } - pub fn test_i_or() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_i_or() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&mut ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, trace) = eval .eval(&tape, [0.0, 0.0], [-1.0, 3.0], [0.0, 0.0]) .unwrap(); @@ -587,9 +581,9 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval.eval(&tape, [0.0, 2.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(out, [0.0, 1.0].into()); @@ -606,9 +600,9 @@ where assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&mut ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval.eval(&tape, [0.0, 2.0], [0.0; 2], [0.0; 2]).unwrap(); assert_eq!(out, [1.0, 2.0].into()); @@ -632,10 +626,10 @@ where let z = ctx.z(); let if_else = ctx.if_nonzero_else(x, y, z).unwrap(); - let shape = S::new(&mut ctx, if_else).unwrap(); + let shape = Shape::::new(&mut ctx, if_else).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, data) = eval .eval(&tape, [-1.0, 2.0], [1.0, 2.0], [3.0, 4.0]) .unwrap(); @@ -681,9 +675,9 @@ where let x = ctx.x(); let max = ctx.max(x, 1.0).unwrap(); - let shape = S::new(&mut ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (r, data) = eval .eval(&tape, [0.0, 2.0], [0.0, 0.0], [0.0, 0.0]) .unwrap(); @@ -709,9 +703,9 @@ where let y = ctx.y(); let c = ctx.compare(x, y).unwrap(); - let shape = S::new(&mut ctx, c).unwrap(); + let shape = Shape::::new(&mut ctx, c).unwrap(); let tape = shape.ez_interval_tape(); - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); let (out, _trace) = eval.eval(&tape, -5.0, -6.0, 0.0).unwrap(); assert_eq!(out, Interval::from(1f32)); } @@ -729,8 +723,8 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_interval_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_interval_eval(); let tape = shape.ez_interval_tape(); let mut out = vec![]; @@ -781,11 +775,11 @@ where let mut ctx = Context::new(); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); for &a in args.iter() { @@ -877,7 +871,7 @@ where let name = format!("{}(reg, reg)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in args.iter() { for &rhs in args.iter() { for (i, &u) in xyz.iter().enumerate() { @@ -891,7 +885,7 @@ where continue; } - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -934,14 +928,14 @@ where let name = format!("{}(reg, imm)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in args.iter() { for &rhs in values.iter() { for (i, &u) in xyz.iter().enumerate() { let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -976,14 +970,14 @@ where let name = format!("{}(imm, reg)", C::NAME); let zero = Interval::new(0.0, 0.0); let mut tape_data = None; - let mut eval = S::new_interval_eval(); + let mut eval = Shape::::new_interval_eval(); for &lhs in values.iter() { for &rhs in args.iter() { for (i, &u) in xyz.iter().enumerate() { let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&mut ctx, node).unwrap(); + let shape = Shape::::new(&mut ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index be179f4b..764d8ca6 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -5,24 +5,25 @@ use super::{build_stress_fn, test_args, CanonicalBinaryOp, CanonicalUnaryOp}; use crate::{ context::Context, - shape::{EzShape, MathShape, Shape, TracingEvaluator}, + eval::{Function, MathFunction}, + shape::{EzShape, Shape}, vm::Choice, }; /// Helper struct to put constrains on our `Shape` object -pub struct TestPoint(std::marker::PhantomData<*const S>); -impl TestPoint +pub struct TestPoint(std::marker::PhantomData<*const F>); +impl TestPoint where - S: Shape + MathShape, - ::Trace: AsRef<[Choice]>, - ::Trace: From>, + F: Function + MathFunction, + ::Trace: AsRef<[Choice]>, + ::Trace: From>, { pub fn test_constant() { let mut ctx = Context::new(); let p = ctx.constant(1.5); - let shape = S::new(&mut ctx, p).unwrap(); + let shape = Shape::::new(&mut ctx, p).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, 1.5); } @@ -31,9 +32,9 @@ where let a = ctx.constant(1.5); let x = ctx.x(); let min = ctx.min(a, x).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 2.0, 0.0, 0.0).unwrap(); assert_eq!(r, 1.5); @@ -54,25 +55,22 @@ where let radius = ctx.add(x_squared, y_squared).unwrap(); let circle = ctx.sub(radius, 1.0).unwrap(); - let shape = S::new(&mut ctx, circle).unwrap(); + let shape = Shape::::new(&mut ctx, circle).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.0, 0.0, 0.0).unwrap().0, -1.0); assert_eq!(eval.eval(&tape, 1.0, 0.0, 0.0).unwrap().0, 0.0); } - pub fn test_p_min() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_min() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert!(trace.is_none()); @@ -94,18 +92,15 @@ where assert!(trace.is_none()); } - pub fn test_p_max() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_max() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = S::new(&mut ctx, max).unwrap(); + let shape = Shape::::new(&mut ctx, max).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); @@ -128,18 +123,15 @@ where assert!(trace.is_none()); } - pub fn test_p_and() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_and() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = S::new(&mut ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); @@ -165,18 +157,15 @@ where assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); } - pub fn test_p_or() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_or() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = S::new(&mut ctx, v).unwrap(); + let shape = Shape::::new(&mut ctx, v).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); let (r, trace) = eval.eval(&tape, 0.0, 0.0, 0.0).unwrap(); assert_eq!(r, 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); @@ -202,17 +191,14 @@ where assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); } - pub fn test_p_sin() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_p_sin() { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); for x in [0.0, 1.0, 2.0] { let (r, trace) = eval.eval(&tape, x, 0.0, 0.0).unwrap(); @@ -230,7 +216,7 @@ where let y = ctx.y(); let s = ctx.add(s, y).unwrap(); - let shape = S::new(&mut ctx, s).unwrap(); + let shape = Shape::::new(&mut ctx, s).unwrap(); let tape = shape.ez_point_tape(); for (x, y) in [(0.0, 1.0), (1.0, 3.0), (2.0, 8.0)] { @@ -246,26 +232,23 @@ where let y = ctx.y(); let sum = ctx.add(x, 1.0).unwrap(); let min = ctx.min(sum, y).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 1.0, 3.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 3.0, 3.5, 0.0).unwrap().0, 3.5); } - pub fn test_push() - where - ::Trace: AsRef<[Choice]>, - { + pub fn test_push() { let mut ctx = Context::new(); let x = ctx.x(); let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); assert_eq!(eval.eval(&tape, 3.0, 2.0, 0.0).unwrap().0, 2.0); @@ -280,9 +263,9 @@ where assert_eq!(eval.eval(&tape, 3.0, 2.0, 0.0).unwrap().0, 2.0); let min = ctx.min(x, 1.0).unwrap(); - let shape = S::new(&mut ctx, min).unwrap(); + let shape = Shape::::new(&mut ctx, min).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 0.5, 0.0, 0.0).unwrap().0, 0.5); assert_eq!(eval.eval(&tape, 3.0, 0.0, 0.0).unwrap().0, 1.0); @@ -302,24 +285,24 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = S::new(&mut ctx, x).unwrap(); + let shape = Shape::::new(&mut ctx, x).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 1.0); assert_eq!(eval.eval(&tape, 3.0, 4.0, 0.0).unwrap().0, 3.0); - let shape = S::new(&mut ctx, y).unwrap(); + let shape = Shape::::new(&mut ctx, y).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 2.0); assert_eq!(eval.eval(&tape, 3.0, 4.0, 0.0).unwrap().0, 4.0); let y2 = ctx.mul(y, 2.5).unwrap(); let sum = ctx.add(x, y2).unwrap(); - let shape = S::new(&mut ctx, sum).unwrap(); + let shape = Shape::::new(&mut ctx, sum).unwrap(); let tape = shape.ez_point_tape(); - let mut eval = S::new_point_eval(); + let mut eval = Shape::::new_point_eval(); assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 6.0); } @@ -332,8 +315,8 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let mut out = vec![]; @@ -388,8 +371,8 @@ where for (i, v) in [ctx.x(), ctx.y(), ctx.z()].into_iter().enumerate() { let node = C::build(&mut ctx, v); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); for &a in args.iter() { @@ -444,8 +427,8 @@ where for (j, &v) in xyz.iter().enumerate() { let node = C::build(&mut ctx, u, v); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match (i, j) { @@ -489,8 +472,8 @@ where let c = ctx.constant(rhs as f64); let node = C::build(&mut ctx, u, c); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match i { @@ -526,8 +509,8 @@ where let c = ctx.constant(lhs as f64); let node = C::build(&mut ctx, c, u); - let shape = S::new(&mut ctx, node).unwrap(); - let mut eval = S::new_point_eval(); + let shape = Shape::::new(&mut ctx, node).unwrap(); + let mut eval = Shape::::new_point_eval(); let tape = shape.ez_point_tape(); let (out, _trace) = match i { diff --git a/fidget/src/core/shape/bulk.rs b/fidget/src/core/shape/bulk.rs deleted file mode 100644 index bf80c7a8..00000000 --- a/fidget/src/core/shape/bulk.rs +++ /dev/null @@ -1,42 +0,0 @@ -//! Evaluates many points in a single call -//! -//! Doing bulk evaluations helps limit to overhead of instruction dispatch, and -//! can take advantage of SIMD. -//! -//! It is unlikely that you'll want to use these traits or types directly; -//! they're implementation details to minimize code duplication. - -use crate::{eval::Tape, Error}; - -/// Trait for bulk evaluation returning the given type `T` -/// -/// It's uncommon to use this trait outside the library itself; it's an -/// abstraction to reduce code duplication, and is public because it's used as a -/// constraint on other public APIs. -pub trait BulkEvaluator: Default { - /// Data type used during evaluation - type Data: From + Copy + Clone; - - /// Instruction tape used during evaluation - /// - /// This may be a literal instruction tape (in the case of VM evaluation), - /// or a metaphorical instruction tape (e.g. a JIT function). - type Tape: Tape + Send + Sync; - - /// Associated type for tape storage - /// - /// This is a workaround for plumbing purposes - type TapeStorage; - - /// Evaluates many points using the given instruction tape - /// - /// Returns an error if the `x`, `y`, `z`, and `out` slices are of different - /// lengths. - fn eval( - &mut self, - tape: &Self::Tape, - x: &[Self::Data], - y: &[Self::Data], - z: &[Self::Data], - ) -> Result<&[Self::Data], Error>; -} diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index e11ebe8f..a7f7f8b7 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -28,154 +28,169 @@ //! ambiguity. use crate::{ - context::{Context, Node}, - eval::{self, Trace}, - types::{Grad, Interval}, + context::{Context, Node, Tree}, + eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, Error, }; mod bounds; -mod bulk; -mod tracing; -mod transform; // Re-export a few things pub use bounds::Bounds; -pub use bulk::BulkEvaluator; -pub use tracing::TracingEvaluator; -pub use transform::TransformedShape; /// A shape represents an implicit surface /// -/// It is mostly agnostic to _how_ that surface is represented; we simply -/// require that the shape can generate evaluators of various kinds. +/// It is mostly agnostic to _how_ that surface is represented, wrapping a +/// [`Function`](Function) and a set of axes. /// /// Shapes are shared between threads, so they should be cheap to clone. In /// most cases, they're a thin wrapper around an `Arc<..>`. -pub trait Shape: Send + Sync + Clone { - /// Associated type traces collected during tracing evaluation - /// - /// This type must implement [`Eq`] so that traces can be compared; calling - /// [`Shape::simplify`] with traces that compare equal should produce an - /// identical result and may be cached. - type Trace: Clone + Eq + Send + Trace; - - /// Associated type for storage used by the shape itself - type Storage: Default + Send; - - /// Associated type for workspace used during shape simplification - type Workspace: Default + Send; +#[derive(Clone)] +pub struct Shape { + /// Wrapped function + f: F, - /// Associated type for storage used by tapes - /// - /// For simplicity, we require that every tape use the same type for storage. - /// This could change in the future! - type TapeStorage: Default + Send; - - /// Associated type for single-point tracing evaluation - type PointEval: TracingEvaluator< - Data = f32, - Trace = Self::Trace, - TapeStorage = Self::TapeStorage, - > + Send - + Sync; + /// Index of x, y, z axes within the function's variable list (if present) + axes: [Option; 3], +} +impl Shape { /// Builds a new point evaluator - fn new_point_eval() -> Self::PointEval { - Self::PointEval::default() + pub fn new_point_eval() -> ShapeTracingEval { + ShapeTracingEval { + eval: F::PointEval::default(), + } } - /// Associated type for single interval tracing evaluation - type IntervalEval: TracingEvaluator< - Data = Interval, - Trace = Self::Trace, - TapeStorage = Self::TapeStorage, - > + Send - + Sync; - /// Builds a new interval evaluator - fn new_interval_eval() -> Self::IntervalEval { - Self::IntervalEval::default() + pub fn new_interval_eval() -> ShapeTracingEval { + ShapeTracingEval { + eval: F::IntervalEval::default(), + } } - /// Associated type for evaluating many points in one call - type FloatSliceEval: BulkEvaluator - + Send - + Sync; - /// Builds a new float slice evaluator - fn new_float_slice_eval() -> Self::FloatSliceEval { - Self::FloatSliceEval::default() + pub fn new_float_slice_eval() -> ShapeBulkEval { + ShapeBulkEval { + eval: F::FloatSliceEval::default(), + } } - /// Associated type for evaluating many gradients in one call - type GradSliceEval: BulkEvaluator - + Send - + Sync; - /// Builds a new gradient slice evaluator - fn new_grad_slice_eval() -> Self::GradSliceEval { - Self::GradSliceEval::default() + pub fn new_grad_slice_eval() -> ShapeBulkEval { + ShapeBulkEval { + eval: F::GradSliceEval::default(), + } } /// Returns an evaluation tape for a point evaluator - fn point_tape( + pub fn point_tape( &self, - storage: Self::TapeStorage, - ) -> ::Tape; + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.point_tape(storage), + axes: self.axes, + } + } - /// Returns an evaluation tape for an interval evaluator - fn interval_tape( + /// Returns an evaluation tape for a interval evaluator + pub fn interval_tape( &self, - storage: Self::TapeStorage, - ) -> ::Tape; + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.interval_tape(storage), + axes: self.axes, + } + } /// Returns an evaluation tape for a float slice evaluator - fn float_slice_tape( + pub fn float_slice_tape( &self, - storage: Self::TapeStorage, - ) -> ::Tape; + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.float_slice_tape(storage), + axes: self.axes, + } + } - /// Returns an evaluation tape for a float slice evaluator - fn grad_slice_tape( + /// Returns an evaluation tape for a gradient slice evaluator + pub fn grad_slice_tape( &self, - storage: Self::TapeStorage, - ) -> ::Tape; + storage: F::TapeStorage, + ) -> ShapeTape<::Tape> { + ShapeTape { + tape: self.f.grad_slice_tape(storage), + axes: self.axes, + } + } /// Computes a simplified tape using the given trace, and reusing storage - fn simplify( + pub fn simplify( &self, - trace: &Self::Trace, - storage: Self::Storage, - workspace: &mut Self::Workspace, + trace: &F::Trace, + storage: F::Storage, + workspace: &mut F::Workspace, ) -> Result where - Self: Sized; + Self: Sized, + { + let f = self.f.simplify(trace, storage, workspace)?; + Ok(Self { f, axes: self.axes }) + } /// Attempt to reclaim storage from this shape /// /// This may fail, because shapes are `Clone` and are often implemented /// using an `Arc` around a heavier data structure. - fn recycle(self) -> Option; + pub fn recycle(self) -> Option { + self.f.recycle() + } /// Returns a size associated with this shape /// /// This is underspecified and only used for unit testing; for tape-based /// shapes, it's typically the length of the tape, - fn size(&self) -> usize; + pub fn size(&self) -> usize { + self.f.size() + } - /// Associated type returned when applying a transform - /// - /// This is normally [`TransformedShape`](TransformedShape), but if - /// `Self` is already `TransformedShape`, then the transform is stacked - /// (instead of creating a wrapped object). - type TransformedShape: Shape; - - /// Returns a shape with the given transform applied - fn apply_transform( - self, - mat: nalgebra::Matrix4, - ) -> ::TransformedShape; + /// Borrows the inner [`Function`](Function) object + pub fn inner(&self) -> &F { + &self.f + } + + /// Borrows the inner axis mapping + pub fn axes(&self) -> &[Option; 3] { + &self.axes + } + + /// Raw constructor + pub fn new_raw(f: F, axes: [Option; 3]) -> Self { + Self { f, axes } + } +} + +impl Shape { + pub fn tile_sizes_3d() -> &'static [usize] { + F::tile_sizes_3d() + } + + pub fn tile_sizes_2d() -> &'static [usize] { + F::tile_sizes_2d() + } + + pub fn simplify_tree_during_meshing(d: usize) -> bool { + F::simplify_tree_during_meshing(d) + } +} + +impl Shape { + pub fn apply_transform(&self, mat: nalgebra::Matrix4) -> Self { + todo!(); + } } /// Extension trait for working with a shape without thinking much about memory @@ -187,55 +202,59 @@ pub trait Shape: Send + Sync + Clone { /// /// This trait is automatically implemented for every [`Shape`], but must be /// imported separately as a speed-bump to using it everywhere. -pub trait EzShape: Shape { +pub trait EzShape { /// Returns an evaluation tape for a point evaluator - fn ez_point_tape(&self) -> ::Tape; + fn ez_point_tape( + &self, + ) -> ShapeTape<::Tape>; /// Returns an evaluation tape for an interval evaluator fn ez_interval_tape( &self, - ) -> ::Tape; + ) -> ShapeTape<::Tape>; /// Returns an evaluation tape for a float slice evaluator fn ez_float_slice_tape( &self, - ) -> ::Tape; + ) -> ShapeTape<::Tape>; /// Returns an evaluation tape for a float slice evaluator fn ez_grad_slice_tape( &self, - ) -> ::Tape; + ) -> ShapeTape<::Tape>; /// Computes a simplified tape using the given trace - fn ez_simplify(&self, trace: &Self::Trace) -> Result + fn ez_simplify(&self, trace: &F::Trace) -> Result where Self: Sized; } -impl EzShape for S { - fn ez_point_tape(&self) -> ::Tape { +impl EzShape for Shape { + fn ez_point_tape( + &self, + ) -> ShapeTape<::Tape> { self.point_tape(Default::default()) } fn ez_interval_tape( &self, - ) -> ::Tape { + ) -> ShapeTape<::Tape> { self.interval_tape(Default::default()) } fn ez_float_slice_tape( &self, - ) -> ::Tape { + ) -> ShapeTape<::Tape> { self.float_slice_tape(Default::default()) } fn ez_grad_slice_tape( &self, - ) -> ::Tape { + ) -> ShapeTape<::Tape> { self.grad_slice_tape(Default::default()) } - fn ez_simplify(&self, trace: &Self::Trace) -> Result { + fn ez_simplify(&self, trace: &F::Trace) -> Result { let mut workspace = Default::default(); self.simplify(trace, Default::default(), &mut workspace) } @@ -260,141 +279,8 @@ pub trait RenderHints { } } -/// A [`Shape`] which can be built from a math expression -pub trait MathShape { - /// Builds a new shape from the given node with default (X, Y, Z) axes - fn new(ctx: &mut Context, node: Node) -> Result - where - Self: Sized, - { - let axes = ctx.axes(); - Self::new_with_axes(ctx, node, axes) - } - - /// Builds a new shape from the given context, node, and axes - fn new_with_axes( - ctx: &Context, - node: Node, - axes: [Node; 3], - ) -> Result - where - Self: Sized; - - /// Helper function to build a shape from a [`Tree`](crate::context::Tree) - /// - /// This function uses the default (X, Y, Z) axes - fn from_tree(t: &crate::context::Tree) -> Self - where - Self: Sized, - { - let mut ctx = Context::new(); - let node = ctx.import(t); - Self::new(&mut ctx, node).unwrap() - } -} - -//////////////////////////////////////////////////////////////////////////////// - -/// Wrapper to convert a [`Function`](fidget::eval::Function) into a [`Shape`] -/// for evaluation. -#[derive(Clone)] -pub struct FunctionShape { - /// Wrapped function - f: F, - - /// Index of x, y, z axes within the function's variable list (if present) - axes: [Option; 3], -} - -impl Shape for FunctionShape { - type Trace = ::Trace; - type Storage = ::Storage; - type Workspace = ::Workspace; - type TapeStorage = ::TapeStorage; - - type PointEval = FunctionShapeTracingEval<::PointEval>; - type IntervalEval = - FunctionShapeTracingEval<::IntervalEval>; - type FloatSliceEval = - FunctionShapeBulkEval<::FloatSliceEval>; - type GradSliceEval = - FunctionShapeBulkEval<::GradSliceEval>; - - fn point_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape { - FunctionShapeTape { - tape: self.f.point_tape(storage), - axes: self.axes, - } - } - - fn interval_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape { - FunctionShapeTape { - tape: self.f.interval_tape(storage), - axes: self.axes, - } - } - - fn float_slice_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape { - FunctionShapeTape { - tape: self.f.float_slice_tape(storage), - axes: self.axes, - } - } - - fn grad_slice_tape( - &self, - storage: Self::TapeStorage, - ) -> ::Tape { - FunctionShapeTape { - tape: self.f.grad_slice_tape(storage), - axes: self.axes, - } - } - - fn simplify( - &self, - trace: &Self::Trace, - storage: Self::Storage, - workspace: &mut Self::Workspace, - ) -> Result - where - Self: Sized, - { - let f = self.f.simplify(trace, storage, workspace)?; - Ok(Self { f, axes: self.axes }) - } - - fn recycle(self) -> Option { - self.f.recycle() - } - - fn size(&self) -> usize { - self.f.size() - } - - type TransformedShape = TransformedShape; - - fn apply_transform( - self, - mat: nalgebra::Matrix4, - ) -> ::TransformedShape { - TransformedShape::new(self, mat) - } - - // todo -} - -impl MathShape for FunctionShape { - fn new_with_axes( +impl Shape { + pub fn new_with_axes( ctx: &Context, node: Node, axes: [Node; 3], @@ -408,76 +294,55 @@ impl MathShape for FunctionShape { axes: [x, y, z].map(|v| vs.get(v).cloned()), }) } -} -impl FunctionShape { - /// Borrows the inner [`Function`](eval::Function) object - pub fn inner(&self) -> &F { - &self.f - } - - /// Borrows the inner axis mapping - pub fn axes(&self) -> &[Option; 3] { - &self.axes - } - - /// Raw constructor - pub fn new_raw(f: F, axes: [Option; 3]) -> Self { - Self { f, axes } + /// Builds a new shape from the given node with default (X, Y, Z) axes + pub fn new(ctx: &mut Context, node: Node) -> Result + where + Self: Sized, + { + let axes = ctx.axes(); + Self::new_with_axes(ctx, node, axes) } } -impl RenderHints for FunctionShape { - fn tile_sizes_3d() -> &'static [usize] { - F::tile_sizes_3d() - } - - fn tile_sizes_2d() -> &'static [usize] { - F::tile_sizes_2d() - } - - fn simplify_tree_during_meshing(d: usize) -> bool { - F::simplify_tree_during_meshing(d) +impl From for Shape { + fn from(t: Tree) -> Self { + let mut ctx = Context::new(); + let node = ctx.import(&t); + Self::new(&mut ctx, node).unwrap() } } /// Wrapper struct to bind a generic tape to particular X, Y, Z axes -pub struct FunctionShapeTape { +pub struct ShapeTape { tape: T, /// Index of the X, Y, Z axes in the variables array axes: [Option; 3], } -impl eval::Tape for FunctionShapeTape { - type Storage = ::Storage; - fn recycle(self) -> Self::Storage { +impl ShapeTape { + /// Recycles the inner tape's storage for reuse + pub fn recycle(self) -> T::Storage { self.tape.recycle() } } -/// Wrapper struct to convert from [`eval::TracingEvaluator`] to +/// Wrapper struct to convert from [`TracingEvaluator`] to /// [`shape::TracingEvaluator`](TracingEvaluator) -#[derive(Default)] -pub struct FunctionShapeTracingEval { +#[derive(Debug, Default)] +pub struct ShapeTracingEval { eval: E, } -impl TracingEvaluator - for FunctionShapeTracingEval -{ - type Data = E::Data; - type Tape = FunctionShapeTape; - type TapeStorage = E::TapeStorage; - type Trace = E::Trace; - - fn eval>( +impl ShapeTracingEval { + pub fn eval>( &mut self, - tape: &Self::Tape, + tape: &ShapeTape, x: F, y: F, z: F, - ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { + ) -> Result<(E::Data, Option<&E::Trace>), Error> { let mut vars = [None, None, None]; if let Some(a) = tape.axes[0] { vars[a] = Some(x.into()); @@ -488,32 +353,49 @@ impl TracingEvaluator if let Some(c) = tape.axes[2] { vars[c] = Some(z.into()); } - let n = vars.iter().position(|v| Option::is_none(v)).unwrap_or(3); + let n = vars.iter().position(Option::is_none).unwrap_or(3); let vars = vars.map(|v| v.unwrap_or(0f32.into())); self.eval.eval(&tape.tape, &vars[..n]) } - // todo + + #[cfg(test)] + pub fn eval_x>( + &mut self, + tape: &ShapeTape, + x: J, + ) -> E::Data { + self.eval(tape, x.into(), E::Data::from(0.0), E::Data::from(0.0)) + .unwrap() + .0 + } + #[cfg(test)] + pub fn eval_xy>( + &mut self, + tape: &ShapeTape, + x: J, + y: J, + ) -> E::Data { + self.eval(tape, x.into(), y.into(), E::Data::from(0.0)) + .unwrap() + .0 + } } -/// Wrapper struct to convert from [`eval::BulkEvaluator`] to +/// Wrapper struct to convert from [`BulkEvaluator`] to /// [`shape::TracingEvaluator`](BulkEvaluator) -#[derive(Default)] -pub struct FunctionShapeBulkEval { +#[derive(Debug, Default)] +pub struct ShapeBulkEval { eval: E, } -impl BulkEvaluator for FunctionShapeBulkEval { - type Data = E::Data; - type Tape = FunctionShapeTape; - type TapeStorage = E::TapeStorage; - - fn eval( +impl ShapeBulkEval { + pub fn eval( &mut self, - tape: &Self::Tape, - x: &[Self::Data], - y: &[Self::Data], - z: &[Self::Data], - ) -> Result<&[Self::Data], Error> { + tape: &ShapeTape, + x: &[E::Data], + y: &[E::Data], + z: &[E::Data], + ) -> Result<&[E::Data], Error> { let mut vars = [None, None, None]; if let Some(a) = tape.axes[0] { vars[a] = Some(x); @@ -528,7 +410,7 @@ impl BulkEvaluator for FunctionShapeBulkEval { let vars = if vars.iter().all(Option::is_some) { vars.map(Option::unwrap) } else if let Some(q) = vars.iter().find(|v| v.is_some()) { - vars.map(|v| if v.is_some() { v.unwrap() } else { q.unwrap() }) + vars.map(|v| v.unwrap_or_else(|| q.unwrap())) } else { [[].as_slice(); 3] }; diff --git a/fidget/src/core/shape/tracing.rs b/fidget/src/core/shape/tracing.rs deleted file mode 100644 index 173494f7..00000000 --- a/fidget/src/core/shape/tracing.rs +++ /dev/null @@ -1,65 +0,0 @@ -//! Capturing a trace of function evaluation for further optimization -//! -//! Tracing evaluators are run on a single data type and capture a trace of -//! execution, which is the [`Trace` associated type](TracingEvaluator::Trace). -//! -//! The resulting trace can be used to simplify the original shape. -//! -//! It is unlikely that you'll want to use these traits or types directly; -//! they're implementation details to minimize code duplication. - -use crate::{eval::Tape, Error}; - -/// Evaluator for single values which simultaneously captures an execution trace -/// -/// The trace can later be used to simplify the [`Shape`](crate::shape::Shape) -/// using [`Shape::simplify`](crate::shape::Shape::simplify). -pub trait TracingEvaluator: Default { - /// Data type used during evaluation - type Data: From + Copy + Clone; - - /// Instruction tape used during evaluation - /// - /// This may be a literal instruction tape (in the case of VM evaluation), - /// or a metaphorical instruction tape (e.g. a JIT function). - type Tape: Tape + Send + Sync; - - /// Associated type for tape storage - /// - /// This is a workaround for plumbing purposes - type TapeStorage; - - /// Associated type for the trace captured during evaluation - type Trace; - - /// Evaluates the given tape at a particular position - fn eval>( - &mut self, - tape: &Self::Tape, - x: F, - y: F, - z: F, - ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; - - #[cfg(test)] - fn eval_x>( - &mut self, - tape: &Self::Tape, - x: J, - ) -> Self::Data { - self.eval(tape, x.into(), Self::Data::from(0.0), Self::Data::from(0.0)) - .unwrap() - .0 - } - #[cfg(test)] - fn eval_xy>( - &mut self, - tape: &Self::Tape, - x: J, - y: J, - ) -> Self::Data { - self.eval(tape, x.into(), y.into(), Self::Data::from(0.0)) - .unwrap() - .0 - } -} diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index a9cc5e6c..e950009e 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -6,7 +6,7 @@ use crate::{ BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, VarMap, }, - shape::{FunctionShape, RenderHints}, + shape::{RenderHints, Shape}, types::{Grad, Interval}, Context, Error, }; @@ -20,6 +20,8 @@ pub use data::{VmData, VmWorkspace}; //////////////////////////////////////////////////////////////////////////////// +pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>; + /// Shape that use a VM backend for evaluation /// /// Internally, the [`VmShape`] stores an [`Arc`](VmData), and @@ -27,7 +29,7 @@ pub use data::{VmData, VmWorkspace}; /// /// All of the associated [`Tape`] types simply clone the internal `Arc`; /// there's no separate planning required to generate a tape. -pub type VmShape = FunctionShape>; +pub type VmShape = Shape; impl Tape for GenericVmFunction { type Storage = (); @@ -1429,8 +1431,8 @@ impl BulkEvaluator for VmGradSliceEval { #[cfg(test)] mod test { use super::*; - crate::grad_slice_tests!(VmShape); - crate::interval_tests!(VmShape); - crate::float_slice_tests!(VmShape); - crate::point_tests!(VmShape); + crate::grad_slice_tests!(VmFunction); + crate::interval_tests!(VmFunction); + crate::float_slice_tests!(VmFunction); + crate::point_tests!(VmFunction); } diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 951b9fe7..bf63dbe3 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -1219,15 +1219,15 @@ impl MathFunction for JitFunction { } /// A [`Shape`](crate::shape::Shape) which uses the JIT evaluator -pub type JitShape = crate::shape::FunctionShape; +pub type JitShape = crate::shape::Shape; //////////////////////////////////////////////////////////////////////////////// #[cfg(test)] mod test { use super::*; - crate::grad_slice_tests!(JitShape); - crate::interval_tests!(JitShape); - crate::float_slice_tests!(JitShape); - crate::point_tests!(JitShape); + crate::grad_slice_tests!(JitFunction); + crate::interval_tests!(JitFunction); + crate::float_slice_tests!(JitFunction); + crate::point_tests!(JitFunction); } diff --git a/fidget/src/mesh/mt/octree.rs b/fidget/src/mesh/mt/octree.rs index 8b500fe9..c9331ab7 100644 --- a/fidget/src/mesh/mt/octree.rs +++ b/fidget/src/mesh/mt/octree.rs @@ -1,6 +1,7 @@ //! Multithreaded octree construction use super::pool::{QueuePool, ThreadContext, ThreadPool}; use crate::{ + eval::Function, mesh::{ cell::{Cell, CellData, CellIndex}, octree::{BranchResult, CellResult, EvalGroup, OctreeBuilder}, @@ -8,7 +9,6 @@ use crate::{ Octree, Settings, }, shape::RenderHints, - shape::Shape, }; use std::sync::{mpsc::TryRecvError, Arc}; @@ -18,22 +18,22 @@ use std::sync::{mpsc::TryRecvError, Arc}; /// octants, sending results back to the parent (which is numbered implicitly /// based on what queue we stole this from). #[derive(Clone)] -struct Task { - data: Arc>, +struct Task { + data: Arc>, } -impl std::ops::Deref for Task { - type Target = TaskData; +impl std::ops::Deref for Task { + type Target = TaskData; fn deref(&self) -> &Self::Target { &self.data } } -impl Task { +impl Task { /// Builds a new root task /// /// The root task is from worker 0 with the default cell index - fn new(eval: Arc>) -> Self { + fn new(eval: Arc>) -> Self { Self { data: Arc::new(TaskData { eval, @@ -46,7 +46,7 @@ impl Task { fn child( &self, - eval: Arc>, + eval: Arc>, target_cell: CellIndex, assigned_by: usize, ) -> Self { @@ -61,8 +61,8 @@ impl Task { } } -struct TaskData { - eval: Arc>, +struct TaskData { + eval: Arc>, /// Thread in which the parent cell lives assigned_by: usize, @@ -70,12 +70,12 @@ struct TaskData { /// Parent cell, which must be an `Invalid` cell waiting for population target_cell: CellIndex, - parent: Option>>, + parent: Option>>, } -struct Done { +struct Done { /// The task that we have finished evaluating - task: Task, + task: Task, /// The resulting cell /// @@ -89,7 +89,7 @@ struct Done { completed_by: usize, } -pub struct OctreeWorker { +pub struct OctreeWorker { /// Global index of this worker thread /// /// For example, this is the thread's own index in `friend_queue` and @@ -101,24 +101,24 @@ pub struct OctreeWorker { /// This octree may not be complete; worker 0 is guaranteed to contain the /// root, and other works may contain fragmentary branches that point to /// each other in a tree structure. - octree: OctreeBuilder, + octree: OctreeBuilder, /// Incoming completed tasks from other threads - done: std::sync::mpsc::Receiver>, + done: std::sync::mpsc::Receiver>, /// Our queue of tasks - queue: QueuePool>, + queue: QueuePool>, /// When a worker finishes a task, it returns it through these queues /// /// Like `friend_queue`, there's one per thread, including the worker's own /// thread; it would be silly to send stuff back to your own thread via the /// queue (rather than storing it directly). - friend_done: Vec>>, + friend_done: Vec>>, } -impl OctreeWorker { - pub fn scheduler(eval: Arc>, settings: Settings) -> Octree { +impl OctreeWorker { + pub fn scheduler(eval: Arc>, settings: Settings) -> Octree { let task_queues = QueuePool::new(settings.threads()); let done_queues = std::iter::repeat_with(std::sync::mpsc::channel) .take(settings.threads()) @@ -250,13 +250,13 @@ impl OctreeWorker { self.octree.into() } - fn reclaim(&mut self, task: Task) { + fn reclaim(&mut self, task: Task) { if let Ok(t) = Arc::try_unwrap(task.data) { self.reclaim_inner(t) } } - fn reclaim_inner(&mut self, mut t: TaskData) { + fn reclaim_inner(&mut self, mut t: TaskData) { // Try recycling the tapes, if no one else is using them if let Ok(e) = Arc::try_unwrap(t.eval) { self.octree.reclaim(e); @@ -271,7 +271,7 @@ impl OctreeWorker { fn on_done( &mut self, result: BranchResult, - task: &Arc>, + task: &Arc>, completed_by: usize, ctx: &mut ThreadContext, ) { @@ -306,7 +306,7 @@ impl OctreeWorker { &mut self, index: usize, cell: CellData, - parent_task: &Arc>, + parent_task: &Arc>, ctx: &mut ThreadContext, ) { self.octree.record(index, cell); diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 7dde8cb9..26f5d760 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -11,8 +11,8 @@ use super::{ Mesh, Settings, }; use crate::{ - eval::Tape, - shape::{BulkEvaluator, RenderHints, Shape, TracingEvaluator}, + eval::{BulkEvaluator, Function, TracingEvaluator}, + shape::{RenderHints, Shape, ShapeBulkEval, ShapeTape, ShapeTracingEval}, types::Grad, }; use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; @@ -20,22 +20,26 @@ use std::{num::NonZeroUsize, sync::Arc, sync::OnceLock}; #[cfg(not(target_arch = "wasm32"))] use super::mt::{DcWorker, OctreeWorker}; +// TODO use fidget::render::RenderHandle here instead? /// Helper struct to contain a set of matched evaluators /// /// Note that this is `Send + Sync` and can be used with shared references! -pub struct EvalGroup { - pub shape: S, +pub struct EvalGroup { + pub shape: Shape, // TODO: passing around an `Arc` ends up with two layers of // indirection (since the tapes also contain `Arc`); could we flatten // them out? (same with the shape, which is usually an `Arc`) - pub interval: OnceLock<::Tape>, - pub float_slice: OnceLock<::Tape>, - pub grad_slice: OnceLock<::Tape>, + pub interval: + OnceLock::Tape>>, + pub float_slice: + OnceLock::Tape>>, + pub grad_slice: + OnceLock::Tape>>, } -impl EvalGroup { - fn new(shape: S) -> Self { +impl EvalGroup { + fn new(shape: Shape) -> Self { Self { shape, interval: OnceLock::new(), @@ -45,16 +49,16 @@ impl EvalGroup { } fn interval_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.interval.get_or_init(|| { self.shape.interval_tape(storage.pop().unwrap_or_default()) }) } fn float_slice_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.float_slice.get_or_init(|| { self.shape .float_slice_tape(storage.pop().unwrap_or_default()) @@ -62,8 +66,8 @@ impl EvalGroup { } fn grad_slice_tape( &self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.grad_slice.get_or_init(|| { self.shape .grad_slice_tape(storage.pop().unwrap_or_default()) @@ -91,13 +95,10 @@ impl Octree { /// Builds an octree to the given depth /// /// The shape is evaluated on the region specified by `settings.bounds`. - pub fn build( - shape: &S, + pub fn build( + shape: &Shape, settings: Settings, - ) -> Self - where - ::TransformedShape: RenderHints, - { + ) -> Self { // Transform the shape given our bounds let t = settings.bounds.transform(); if t == nalgebra::Transform::identity() { @@ -116,8 +117,8 @@ impl Octree { } } - fn build_inner( - shape: &S, + fn build_inner( + shape: &Shape, settings: Settings, ) -> Self { let eval = Arc::new(EvalGroup::new(shape.clone())); @@ -238,7 +239,7 @@ impl std::ops::IndexMut for Octree { /// Data structure for an under-construction octree #[derive(Debug)] -pub(crate) struct OctreeBuilder { +pub(crate) struct OctreeBuilder { /// Internal octree /// /// Note that in this internal octree, the `index` field of leaf nodes @@ -262,23 +263,23 @@ pub(crate) struct OctreeBuilder { /// Available slots in the `hermite` array hermite_slots: Vec, - eval_float_slice: S::FloatSliceEval, - eval_interval: S::IntervalEval, - eval_grad_slice: S::GradSliceEval, + eval_float_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, + eval_grad_slice: ShapeBulkEval, - pub tape_storage: Vec, - pub shape_storage: Vec, - workspace: S::Workspace, + pub tape_storage: Vec, + pub shape_storage: Vec, + workspace: F::Workspace, } -impl Default for OctreeBuilder { +impl Default for OctreeBuilder { fn default() -> Self { Self::new() } } -impl From> for Octree { - fn from(o: OctreeBuilder) -> Self { +impl From> for Octree { + fn from(o: OctreeBuilder) -> Self { // Convert from "leaf index into self.leafs" (in the builder) to // "leaf index into self.verts" (in the resulting Octree) let cells = @@ -303,7 +304,7 @@ impl From> for Octree { } } -impl OctreeBuilder { +impl OctreeBuilder { /// Builds a new octree, which allocates data for 8 root cells pub(crate) fn new() -> Self { Self { @@ -314,9 +315,9 @@ impl OctreeBuilder { leafs: vec![], hermite: vec![LeafHermiteData::default()], hermite_slots: vec![], - eval_float_slice: S::new_float_slice_eval(), - eval_grad_slice: S::new_grad_slice_eval(), - eval_interval: S::new_interval_eval(), + eval_float_slice: Shape::::new_float_slice_eval(), + eval_grad_slice: Shape::::new_grad_slice_eval(), + eval_interval: Shape::::new_interval_eval(), tape_storage: vec![], shape_storage: vec![], workspace: Default::default(), @@ -347,10 +348,10 @@ impl OctreeBuilder { /// octree (e.g. on another thread). pub(crate) fn eval_cell( &mut self, - eval: &Arc>, + eval: &Arc>, cell: CellIndex, settings: Settings, - ) -> CellResult { + ) -> CellResult { let (i, r) = self .eval_interval .eval( @@ -365,16 +366,19 @@ impl OctreeBuilder { } else if i.lower() > 0.0 { CellResult::Done(Cell::Empty) } else { - let sub_tape = if S::simplify_tree_during_meshing(cell.depth) { - let s = self.shape_storage.pop().unwrap_or_default(); - r.map(|r| { - Arc::new(EvalGroup::new( - eval.shape.simplify(r, s, &mut self.workspace).unwrap(), - )) - }) - } else { - None - }; + let sub_tape = + if Shape::::simplify_tree_during_meshing(cell.depth) { + let s = self.shape_storage.pop().unwrap_or_default(); + r.map(|r| { + Arc::new(EvalGroup::new( + eval.shape + .simplify(r, s, &mut self.workspace) + .unwrap(), + )) + }) + } else { + None + }; if cell.depth == settings.depth as usize { let eval = sub_tape.unwrap_or_else(|| eval.clone()); let out = CellResult::Done(self.leaf(&eval, cell)); @@ -426,7 +430,7 @@ impl OctreeBuilder { /// Recurse down the octree, building the given cell fn recurse( &mut self, - eval: &Arc>, + eval: &Arc>, cell: CellIndex, settings: Settings, ) { @@ -468,7 +472,7 @@ impl OctreeBuilder { /// Writes the leaf vertex to `self.o.verts`, hermite data to /// `self.hermite`, and the leaf data to `self.leafs`. Does **not** write /// anything to `self.o.cells`; the cell is returned instead. - fn leaf(&mut self, eval: &EvalGroup, cell: CellIndex) -> Cell { + fn leaf(&mut self, eval: &EvalGroup, cell: CellIndex) -> Cell { let mut xs = [0.0; 8]; let mut ys = [0.0; 8]; let mut zs = [0.0; 8]; @@ -895,7 +899,7 @@ impl OctreeBuilder { CELL_TO_VERT_TO_EDGES[mask as usize].len() == 1 } - pub(crate) fn reclaim(&mut self, mut e: EvalGroup) { + pub(crate) fn reclaim(&mut self, mut e: EvalGroup) { if let Some(s) = e.shape.recycle() { self.shape_storage.push(s); } @@ -913,7 +917,7 @@ impl OctreeBuilder { /// `OctreeBuilder` functions which are only used during multithreaded rendering #[cfg(not(target_arch = "wasm32"))] -impl OctreeBuilder { +impl OctreeBuilder { /// Builds a new empty octree /// /// This still allocates data to reserve the lowest slot in `hermite` @@ -927,9 +931,9 @@ impl OctreeBuilder { hermite: vec![LeafHermiteData::default()], hermite_slots: vec![], - eval_float_slice: S::new_float_slice_eval(), - eval_grad_slice: S::new_grad_slice_eval(), - eval_interval: S::new_interval_eval(), + eval_float_slice: Shape::::new_float_slice_eval(), + eval_grad_slice: Shape::::new_grad_slice_eval(), + eval_interval: Shape::::new_interval_eval(), tape_storage: vec![], shape_storage: vec![], @@ -951,9 +955,9 @@ impl OctreeBuilder { } /// Result of a single cell evaluation -pub enum CellResult { +pub enum CellResult { Done(Cell), - Recurse(Arc>), + Recurse(Arc>), } /// Result of a branch evaluation (8-fold division) @@ -1170,9 +1174,8 @@ mod test { use crate::{ context::Tree, mesh::types::{Edge, X, Y, Z}, - shape::Bounds, - shape::{EzShape, MathShape}, - vm::VmShape, + shape::{Bounds, EzShape}, + vm::{VmFunction, VmShape}, }; use nalgebra::Vector3; use std::collections::BTreeMap; @@ -1215,7 +1218,7 @@ mod test { fn test_cube_edge() { const EPSILON: f32 = 1e-3; let f = 2.0; - let shape = VmShape::from_tree(&cube([-f, f], [-f, 0.3], [-f, 0.6])); + let shape = VmShape::from(cube([-f, f], [-f, 0.3], [-f, 0.6])); // This should be a cube with a single edge running through the root // node of the octree, with an edge vertex at [0, 0.3, 0.6] let octree = Octree::build(&shape, DEPTH0_SINGLE_THREAD); @@ -1264,7 +1267,7 @@ mod test { #[test] fn test_mesh_basic() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.2)); + let shape = VmShape::from(sphere([0.0; 3], 0.2)); // If we only build a depth-0 octree, then it's a leaf without any // vertices (since all the corners are empty) @@ -1307,7 +1310,7 @@ mod test { #[test] fn test_sphere_verts() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.2)); + let shape = VmShape::from(sphere([0.0; 3], 0.2)); let octree = Octree::build(&shape, DEPTH1_SINGLE_THREAD); let sphere_mesh = octree.walk_dual(DEPTH1_SINGLE_THREAD); @@ -1343,7 +1346,7 @@ mod test { #[test] fn test_sphere_manifold() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.85)); + let shape = VmShape::from(sphere([0.0; 3], 0.85)); for threads in [1, 8] { let settings = Settings { @@ -1371,8 +1374,7 @@ mod test { #[test] fn test_cube_verts() { - let shape = - VmShape::from_tree(&cube([-0.1, 0.6], [-0.2, 0.75], [-0.3, 0.4])); + let shape = VmShape::from(cube([-0.1, 0.6], [-0.2, 0.75], [-0.3, 0.4])); let octree = Octree::build(&shape, DEPTH1_SINGLE_THREAD); let mesh = octree.walk_dual(DEPTH1_SINGLE_THREAD); @@ -1422,7 +1424,7 @@ mod test { for offset in [0.0, -0.2, 0.2] { let (x, y, z) = Tree::axes(); let f = x * dx + y * dy + z + offset; - let shape = VmShape::from_tree(&f); + let shape = VmShape::from(f); let octree = Octree::build(&shape, DEPTH0_SINGLE_THREAD); assert_eq!(octree.cells.len(), 8); @@ -1457,7 +1459,7 @@ mod test { nalgebra::Vector3::new(1.2, 1.3, 1.4), ] { let corner = nalgebra::Vector3::new(-1.0, -1.0, -1.0); - let shape = VmShape::from_tree(&cone(corner, tip, 0.1)); + let shape = VmShape::from(cone(corner, tip, 0.1)); let mut eval = VmShape::new_point_eval(); let tape = shape.ez_point_tape(); @@ -1498,7 +1500,7 @@ mod test { // Now, we have our shape, which is 0-8 spheres placed at the // corners of the cell spanning [0, 0.25] - let shape = VmShape::from_tree(&shape); + let shape = VmShape::from(shape); let settings = Settings { depth: 2, threads: threads.try_into().unwrap(), @@ -1536,8 +1538,11 @@ mod test { #[test] fn test_collapsible() { - fn builder(shape: Tree, settings: Settings) -> OctreeBuilder { - let shape = VmShape::from_tree(&shape); + fn builder( + shape: Tree, + settings: Settings, + ) -> OctreeBuilder { + let shape = VmShape::from(shape); let eval = Arc::new(EvalGroup::new(shape)); let mut out = OctreeBuilder::new(); out.recurse(&eval, CellIndex::default(), settings); @@ -1566,7 +1571,7 @@ mod test { #[test] fn test_empty_collapse() { // Make a very smol sphere that won't be sampled - let shape = VmShape::from_tree(&sphere([0.1; 3], 0.05)); + let shape = VmShape::from(sphere([0.1; 3], 0.05)); for threads in [1, 4] { let settings = Settings { depth: 1, @@ -1682,7 +1687,7 @@ mod test { #[test] fn test_qef_near_planar() { - let shape = VmShape::from_tree(&sphere([0.0; 3], 0.75)); + let shape = VmShape::from(sphere([0.0; 3], 0.75)); let settings = Settings { depth: 4, @@ -1699,7 +1704,7 @@ mod test { #[test] fn test_octree_bounds() { - let shape = VmShape::from_tree(&sphere([1.0; 3], 0.25)); + let shape = VmShape::from(sphere([1.0; 3], 0.25)); let center = Vector3::new(1.0, 1.0, 1.0); let settings = Settings { diff --git a/fidget/src/render/config.rs b/fidget/src/render/config.rs index 2a6f4e6f..5fd964c1 100644 --- a/fidget/src/render/config.rs +++ b/fidget/src/render/config.rs @@ -1,4 +1,5 @@ use crate::{ + eval::Function, render::RenderMode, shape::{Bounds, Shape}, Error, @@ -214,11 +215,11 @@ impl RenderConfig<2> { /// /// Under the hood, this delegates to /// [`fidget::render::render2d`](crate::render::render2d()) - pub fn run( + pub fn run( &self, - shape: S, + shape: Shape, ) -> Result::Output>, Error> { - Ok(crate::render::render2d::(shape, self)) + Ok(crate::render::render2d::(shape, self)) } } @@ -229,11 +230,11 @@ impl RenderConfig<3> { /// [`fidget::render::render3d`](crate::render::render3d()) /// /// Returns a tuple of heightmap, RGB image. - pub fn run( + pub fn run( &self, - shape: S, + shape: Shape, ) -> Result<(Vec, Vec<[u8; 3]>), Error> { - Ok(crate::render::render3d::(shape, self)) + Ok(crate::render::render3d::(shape, self)) } } diff --git a/fidget/src/render/mod.rs b/fidget/src/render/mod.rs index f8fbe540..b4e67091 100644 --- a/fidget/src/render/mod.rs +++ b/fidget/src/render/mod.rs @@ -5,8 +5,8 @@ //! functions ([`render2d`](render2d()) and [`render3d`](render3d())) for manual //! control over the input tape. use crate::{ - eval::{Tape, Trace}, - shape::{BulkEvaluator, Shape, TracingEvaluator}, + eval::{BulkEvaluator, Function, Trace, TracingEvaluator}, + shape::{Shape, ShapeTape}, }; use std::sync::Arc; @@ -28,17 +28,17 @@ pub use render2d::{ /// The tapes are stored as `Arc<..>`, so it can be cheaply cloned. /// /// The most recent simplification is cached for reuse (if the trace matches). -pub struct RenderHandle { - shape: S, +pub struct RenderHandle { + shape: Shape, - i_tape: Option::Tape>>, - f_tape: Option::Tape>>, - g_tape: Option::Tape>>, + i_tape: Option::Tape>>>, + f_tape: Option::Tape>>>, + g_tape: Option::Tape>>>, - next: Option<(S::Trace, Box)>, + next: Option<(F::Trace, Box)>, } -impl Clone for RenderHandle { +impl Clone for RenderHandle { fn clone(&self) -> Self { Self { shape: self.shape.clone(), @@ -50,14 +50,11 @@ impl Clone for RenderHandle { } } -impl RenderHandle -where - S: Shape, -{ +impl RenderHandle { /// Build a new [`RenderHandle`] for the given shape /// /// None of the tapes are populated here. - pub fn new(shape: S) -> Self { + pub fn new(shape: Shape) -> Self { Self { shape, i_tape: None, @@ -70,8 +67,8 @@ where /// Returns a tape for tracing interval evaluation pub fn i_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.i_tape.get_or_insert_with(|| { Arc::new( self.shape.interval_tape(storage.pop().unwrap_or_default()), @@ -82,8 +79,8 @@ where /// Returns a tape for bulk float evaluation pub fn f_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.f_tape.get_or_insert_with(|| { Arc::new( self.shape @@ -95,8 +92,8 @@ where /// Returns a tape for bulk gradient evaluation pub fn g_tape( &mut self, - storage: &mut Vec, - ) -> &::Tape { + storage: &mut Vec, + ) -> &ShapeTape<::Tape> { self.g_tape.get_or_insert_with(|| { Arc::new( self.shape @@ -111,10 +108,10 @@ where /// the trace matches. pub fn simplify( &mut self, - trace: &S::Trace, - workspace: &mut S::Workspace, - shape_storage: &mut Vec, - tape_storage: &mut Vec, + trace: &F::Trace, + workspace: &mut F::Workspace, + shape_storage: &mut Vec, + tape_storage: &mut Vec, ) -> &mut Self { // Free self.next if it doesn't match our new set of choices let mut trace_storage = if let Some(neighbor) = &self.next { @@ -168,8 +165,8 @@ where /// Recycles the entire handle into the given storage vectors pub fn recycle( mut self, - shape_storage: &mut Vec, - tape_storage: &mut Vec, + shape_storage: &mut Vec, + tape_storage: &mut Vec, ) { // Recycle the child first, in case it borrowed from us if let Some((_trace, shape)) = self.next.take() { diff --git a/fidget/src/render/render2d.rs b/fidget/src/render/render2d.rs index 62161817..7f2d177a 100644 --- a/fidget/src/render/render2d.rs +++ b/fidget/src/render/render2d.rs @@ -1,8 +1,9 @@ //! 2D bitmap rendering / rasterization use super::RenderHandle; use crate::{ + eval::Function, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, - shape::{BulkEvaluator, Shape, TracingEvaluator}, + shape::{Shape, ShapeBulkEval, ShapeTracingEval}, types::Interval, }; use nalgebra::Point2; @@ -200,29 +201,29 @@ impl Scratch { //////////////////////////////////////////////////////////////////////////////// /// Per-thread worker -struct Worker<'a, S: Shape, M: RenderMode> { +struct Worker<'a, F: Function, M: RenderMode> { config: &'a AlignedRenderConfig<2>, scratch: Scratch, - eval_float_slice: S::FloatSliceEval, - eval_interval: S::IntervalEval, + eval_float_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, /// Spare tape storage for reuse - tape_storage: Vec, + tape_storage: Vec, /// Spare shape storage for reuse - shape_storage: Vec, + shape_storage: Vec, /// Workspace for shape simplification - workspace: S::Workspace, + workspace: F::Workspace, image: Vec, } -impl Worker<'_, S, M> { +impl Worker<'_, F, M> { fn render_tile_recurse( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, depth: usize, tile: Tile<2>, ) { @@ -310,7 +311,7 @@ impl Worker<'_, S, M> { fn render_tile_pixels( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, tile_size: usize, tile: Tile<2>, ) { @@ -346,20 +347,20 @@ impl Worker<'_, S, M> { //////////////////////////////////////////////////////////////////////////////// -fn worker( - mut shape: RenderHandle, +fn worker( + mut shape: RenderHandle, queue: &Queue<2>, config: &AlignedRenderConfig<2>, ) -> Vec<(Tile<2>, Vec)> { let mut out = vec![]; let scratch = Scratch::new(config.tile_sizes.last().unwrap_or(&0).pow(2)); - let mut w: Worker = Worker { + let mut w: Worker = Worker { scratch, image: vec![], config, - eval_float_slice: S::FloatSliceEval::default(), - eval_interval: S::IntervalEval::default(), + eval_float_slice: Default::default(), + eval_interval: Default::default(), tape_storage: vec![], shape_storage: vec![], workspace: Default::default(), @@ -384,8 +385,8 @@ fn worker( /// This function is parameterized by both shape type (which determines how we /// perform evaluation) and render mode (which tells us how to color in the /// resulting pixels). -pub fn render( - shape: S, +pub fn render( + shape: Shape, config: &RenderConfig<2>, ) -> Vec { let (config, mat) = config.align(); @@ -402,8 +403,8 @@ pub fn render( render_inner::<_, M>(shape, config) } -fn render_inner( - shape: S, +fn render_inner( + shape: Shape, config: AlignedRenderConfig<2>, ) -> Vec { let mut tiles = vec![]; @@ -423,7 +424,7 @@ fn render_inner( let _ = rh.i_tape(&mut vec![]); // populate i_tape before cloning let out: Vec<_> = if threads == 1 { - worker::(rh, &queue, &config).into_iter().collect() + worker::(rh, &queue, &config).into_iter().collect() } else { #[cfg(target_arch = "wasm32")] unreachable!("multithreaded rendering is not supported on wasm32"); @@ -433,7 +434,7 @@ fn render_inner( let mut handles = vec![]; for _ in 0..threads { let rh = rh.clone(); - handles.push(s.spawn(|| worker::(rh, &queue, &config))); + handles.push(s.spawn(|| worker::(rh, &queue, &config))); } let mut out = vec![]; for h in handles { @@ -467,8 +468,9 @@ fn render_inner( mod test { use super::*; use crate::{ - shape::{Bounds, FunctionShape, MathShape, Shape}, - vm::{GenericVmFunction, VmShape}, + eval::{Function, MathFunction}, + shape::{Bounds, Shape}, + vm::{GenericVmFunction, VmFunction}, Context, }; @@ -479,8 +481,8 @@ mod test { "/../models/quarter.vm" )); - fn render_and_compare_with_bounds( - shape: S, + fn render_and_compare_with_bounds( + shape: Shape, expected: &'static str, bounds: Bounds<2>, ) { @@ -508,13 +510,16 @@ mod test { } } - fn render_and_compare(shape: S, expected: &'static str) { + fn render_and_compare( + shape: Shape, + expected: &'static str, + ) { render_and_compare_with_bounds(shape, expected, Bounds::default()) } - fn check_hi() { + fn check_hi() { let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&mut ctx, root).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .................X.............. .................X.............. @@ -551,9 +556,9 @@ mod test { render_and_compare(shape, EXPECTED); } - fn check_hi_transformed() { + fn check_hi_transformed() { let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&mut ctx, root).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); let mut mat = nalgebra::Matrix4::::identity(); mat.prepend_translation_mut(&nalgebra::Vector3::new(0.5, 0.5, 0.0)); mat.prepend_scaling_mut(0.5); @@ -594,9 +599,9 @@ mod test { render_and_compare(shape, EXPECTED); } - fn check_hi_bounded() { + fn check_hi_bounded() { let (mut ctx, root) = Context::from_text(HI.as_bytes()).unwrap(); - let shape = S::new(&mut ctx, root).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " .XXX............................ .XXX............................ @@ -640,9 +645,9 @@ mod test { ); } - fn check_quarter() { + fn check_quarter() { let (mut ctx, root) = Context::from_text(QUARTER.as_bytes()).unwrap(); - let shape = S::new(&mut ctx, root).unwrap(); + let shape = Shape::::new(&mut ctx, root).unwrap(); const EXPECTED: &str = " ................................ ................................ @@ -681,65 +686,65 @@ mod test { #[test] fn render_hi_vm() { - check_hi::(); + check_hi::(); } #[test] fn render_hi_vm3() { - check_hi::>>(); + check_hi::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_jit() { - check_hi::(); + check_hi::(); } #[test] fn render_hi_transformed_vm() { - check_hi_transformed::(); + check_hi_transformed::(); } #[test] fn render_hi_transformed_vm3() { - check_hi_transformed::>>(); + check_hi_transformed::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_transformed_jit() { - check_hi_transformed::(); + check_hi_transformed::(); } #[test] fn render_hi_bounded_vm() { - check_hi_bounded::(); + check_hi_bounded::(); } #[test] fn render_hi_bounded_vm3() { - check_hi_bounded::>>(); + check_hi_bounded::>(); } #[cfg(feature = "jit")] #[test] fn render_hi_bounded_jit() { - check_hi_bounded::(); + check_hi_bounded::(); } #[test] fn render_quarter_vm() { - check_quarter::(); + check_quarter::(); } #[test] fn render_quarter_vm3() { - check_quarter::>>(); + check_quarter::>(); } #[cfg(feature = "jit")] #[test] fn render_quarter_jit() { - check_quarter::(); + check_quarter::(); } } diff --git a/fidget/src/render/render3d.rs b/fidget/src/render/render3d.rs index 8ed63b5e..b4e8751c 100644 --- a/fidget/src/render/render3d.rs +++ b/fidget/src/render/render3d.rs @@ -1,8 +1,9 @@ //! 3D bitmap rendering / rasterization use super::RenderHandle; use crate::{ + eval::Function, render::config::{AlignedRenderConfig, Queue, RenderConfig, Tile}, - shape::{BulkEvaluator, Shape, TracingEvaluator}, + shape::{Shape, ShapeBulkEval, ShapeTracingEval}, types::{Grad, Interval}, }; @@ -44,29 +45,29 @@ impl Scratch { //////////////////////////////////////////////////////////////////////////////// -struct Worker<'a, S: Shape> { +struct Worker<'a, F: Function> { config: &'a AlignedRenderConfig<3>, /// Reusable workspace for evaluation, to minimize allocation scratch: Scratch, - eval_float_slice: S::FloatSliceEval, - eval_grad_slice: S::GradSliceEval, - eval_interval: S::IntervalEval, + eval_float_slice: ShapeBulkEval, + eval_grad_slice: ShapeBulkEval, + eval_interval: ShapeTracingEval, - tape_storage: Vec, - shape_storage: Vec, - workspace: S::Workspace, + tape_storage: Vec, + shape_storage: Vec, + workspace: F::Workspace, /// Output images for this specific tile depth: Vec, color: Vec<[u8; 3]>, } -impl Worker<'_, S> { +impl Worker<'_, F> { fn render_tile_recurse( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, depth: usize, tile: Tile<3>, ) { @@ -143,7 +144,7 @@ impl Worker<'_, S> { fn render_tile_pixels( &mut self, - shape: &mut RenderHandle, + shape: &mut RenderHandle, tile_size: usize, tile: Tile<3>, ) { @@ -281,8 +282,8 @@ impl Image { //////////////////////////////////////////////////////////////////////////////// -fn worker( - mut shape: RenderHandle, +fn worker( + mut shape: RenderHandle, queues: &[Queue<3>], mut index: usize, config: &AlignedRenderConfig<3>, @@ -292,15 +293,15 @@ fn worker( // Calculate maximum evaluation buffer size let buf_size = *config.tile_sizes.last().unwrap(); let scratch = Scratch::new(buf_size); - let mut w: Worker = Worker { + let mut w: Worker = Worker { scratch, depth: vec![], color: vec![], config, - eval_float_slice: S::FloatSliceEval::default(), - eval_interval: S::IntervalEval::default(), - eval_grad_slice: S::GradSliceEval::default(), + eval_float_slice: Default::default(), + eval_interval: Default::default(), + eval_grad_slice: Default::default(), tape_storage: vec![], shape_storage: vec![], @@ -351,8 +352,8 @@ fn worker( /// /// This function is parameterized by shape type, which determines how we /// perform evaluation. -pub fn render( - shape: S, +pub fn render( + shape: Shape, config: &RenderConfig<3>, ) -> (Vec, Vec<[u8; 3]>) { let (config, mat) = config.align(); @@ -365,8 +366,8 @@ pub fn render( render_inner(shape, config) } -pub fn render_inner( - shape: S, +pub fn render_inner( + shape: Shape, config: AlignedRenderConfig<3>, ) -> (Vec, Vec<[u8; 3]>) { let mut tiles = vec![]; @@ -396,7 +397,7 @@ pub fn render_inner( // Special-case for single-threaded operation, to give simpler backtraces let out: Vec<_> = if threads == 1 { - worker::(rh, tile_queues.as_slice(), 0, &config) + worker::(rh, tile_queues.as_slice(), 0, &config) .into_iter() .collect() } else { @@ -411,7 +412,7 @@ pub fn render_inner( for i in 0..threads { let rh = rh.clone(); handles - .push(s.spawn(move || worker::(rh, queues, i, config))); + .push(s.spawn(move || worker::(rh, queues, i, config))); } let mut out = vec![]; for h in handles { @@ -448,7 +449,7 @@ pub fn render_inner( #[cfg(test)] mod test { use super::*; - use crate::{shape::MathShape, vm::VmShape, Context}; + use crate::{vm::VmShape, Context}; /// Make sure we don't crash if there's only a single tile #[test] From 005a22bf4f7084e7a9916dce74145723f65afb93 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Fri, 24 May 2024 08:19:08 -0400 Subject: [PATCH 08/12] Everything is now working --- demo/src/main.rs | 28 +++---- fidget/src/core/mod.rs | 2 +- fidget/src/core/shape/mod.rs | 147 +++++++++++++++++++++++++++++++---- fidget/src/jit/mod.rs | 4 +- fidget/src/lib.rs | 15 ++-- fidget/src/mesh/mod.rs | 3 +- fidget/src/rhai/mod.rs | 10 +-- viewer/src/main.rs | 30 ++++--- wasm-demo/src/lib.rs | 1 - 9 files changed, 174 insertions(+), 66 deletions(-) diff --git a/demo/src/main.rs b/demo/src/main.rs index c5503c8d..5537019e 100644 --- a/demo/src/main.rs +++ b/demo/src/main.rs @@ -7,10 +7,7 @@ use clap::{Parser, Subcommand, ValueEnum}; use env_logger::Env; use log::info; -use fidget::{ - context::Context, - shape::{BulkEvaluator, MathShape}, -}; +use fidget::context::Context; /// Simple test program #[derive(Parser)] @@ -112,8 +109,8 @@ struct MeshSettings { } //////////////////////////////////////////////////////////////////////////////// -fn run3d( - shape: S, +fn run3d( + shape: fidget::shape::Shape, settings: &ImageSettings, isometric: bool, mode_color: bool, @@ -124,7 +121,7 @@ fn run3d( } let cfg = fidget::render::RenderConfig { image_size: settings.size as usize, - tile_sizes: S::tile_sizes_3d().to_vec(), + tile_sizes: F::tile_sizes_3d().to_vec(), threads: settings.threads, ..Default::default() }; @@ -168,15 +165,15 @@ fn run3d( //////////////////////////////////////////////////////////////////////////////// -fn run2d( - shape: S, +fn run2d( + shape: fidget::shape::Shape, settings: &ImageSettings, brute: bool, sdf: bool, ) -> Vec { if brute { let tape = shape.float_slice_tape(Default::default()); - let mut eval = S::new_float_slice_eval(); + let mut eval = fidget::shape::Shape::::new_float_slice_eval(); let mut out: Vec = vec![]; for _ in 0..settings.n { let mut xs = vec![]; @@ -202,7 +199,7 @@ fn run2d( } else { let cfg = fidget::render::RenderConfig { image_size: settings.size as usize, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), threads: settings.threads, ..Default::default() }; @@ -236,13 +233,10 @@ fn run2d( //////////////////////////////////////////////////////////////////////////////// -fn run_mesh( - shape: S, +fn run_mesh( + shape: fidget::shape::Shape, settings: &MeshSettings, -) -> fidget::mesh::Mesh -where - ::TransformedShape: fidget::shape::RenderHints, -{ +) -> fidget::mesh::Mesh { let mut mesh = fidget::mesh::Mesh::new(); for _ in 0..settings.n { diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 42247470..37ed2934 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -3,7 +3,7 @@ //! ``` //! use fidget::{ //! context::Context, -//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape //! }; //! let mut ctx = Context::new(); diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index a7f7f8b7..562cfb7b 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -7,7 +7,7 @@ //! ```rust //! use fidget::vm::VmShape; //! use fidget::context::Context; -//! use fidget::shape::{TracingEvaluator, Shape, MathShape, EzShape}; +//! use fidget::shape::EzShape; //! //! let mut ctx = Context::new(); //! let x = ctx.x(); @@ -30,8 +30,10 @@ use crate::{ context::{Context, Node, Tree}, eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, + types::{Grad, Interval}, Error, }; +use nalgebra::{Matrix4, Point3}; mod bounds; @@ -52,6 +54,9 @@ pub struct Shape { /// Index of x, y, z axes within the function's variable list (if present) axes: [Option; 3], + + /// Optional transform to apply to the shape + transform: Option>, } impl Shape { @@ -73,6 +78,9 @@ impl Shape { pub fn new_float_slice_eval() -> ShapeBulkEval { ShapeBulkEval { eval: F::FloatSliceEval::default(), + xs: vec![], + ys: vec![], + zs: vec![], } } @@ -80,6 +88,9 @@ impl Shape { pub fn new_grad_slice_eval() -> ShapeBulkEval { ShapeBulkEval { eval: F::GradSliceEval::default(), + xs: vec![], + ys: vec![], + zs: vec![], } } @@ -91,6 +102,7 @@ impl Shape { ShapeTape { tape: self.f.point_tape(storage), axes: self.axes, + transform: self.transform, } } @@ -102,6 +114,7 @@ impl Shape { ShapeTape { tape: self.f.interval_tape(storage), axes: self.axes, + transform: self.transform, } } @@ -113,6 +126,7 @@ impl Shape { ShapeTape { tape: self.f.float_slice_tape(storage), axes: self.axes, + transform: self.transform, } } @@ -124,6 +138,7 @@ impl Shape { ShapeTape { tape: self.f.grad_slice_tape(storage), axes: self.axes, + transform: self.transform, } } @@ -138,7 +153,11 @@ impl Shape { Self: Sized, { let f = self.f.simplify(trace, storage, workspace)?; - Ok(Self { f, axes: self.axes }) + Ok(Self { + f, + axes: self.axes, + transform: self.transform, + }) } /// Attempt to reclaim storage from this shape @@ -169,7 +188,11 @@ impl Shape { /// Raw constructor pub fn new_raw(f: F, axes: [Option; 3]) -> Self { - Self { f, axes } + Self { + f, + axes, + transform: None, + } } } @@ -188,8 +211,13 @@ impl Shape { } impl Shape { - pub fn apply_transform(&self, mat: nalgebra::Matrix4) -> Self { - todo!(); + pub fn apply_transform(mut self, mat: Matrix4) -> Self { + if let Some(prev) = self.transform.as_mut() { + *prev *= mat; + } else { + self.transform = Some(mat); + } + self } } @@ -292,6 +320,7 @@ impl Shape { Ok(Self { f, axes: [x, y, z].map(|v| vs.get(v).cloned()), + transform: None, }) } @@ -319,6 +348,9 @@ pub struct ShapeTape { /// Index of the X, Y, Z axes in the variables array axes: [Option; 3], + + /// Optional transform + transform: Option>, } impl ShapeTape { @@ -335,7 +367,10 @@ pub struct ShapeTracingEval { eval: E, } -impl ShapeTracingEval { +impl ShapeTracingEval +where + ::Data: Transformable, +{ pub fn eval>( &mut self, tape: &ShapeTape, @@ -343,6 +378,15 @@ impl ShapeTracingEval { y: F, z: F, ) -> Result<(E::Data, Option<&E::Trace>), Error> { + let x = x.into(); + let y = y.into(); + let z = z.into(); + let (x, y, z) = if let Some(mat) = tape.transform { + Transformable::transform(x, y, z, mat) + } else { + (x, y, z) + }; + let mut vars = [None, None, None]; if let Some(a) = tape.axes[0] { vars[a] = Some(x.into()); @@ -381,14 +425,23 @@ impl ShapeTracingEval { } } -/// Wrapper struct to convert from [`BulkEvaluator`] to -/// [`shape::TracingEvaluator`](BulkEvaluator) +/// Bulk evaluator for a shape +/// +/// This wraps a generic [`BulkEvaluator`] and exposes an API that takes +/// `(x, y, z)` arguments instead. In addition, it applies the transform +/// associated with the [`ShapeTape`]. #[derive(Debug, Default)] -pub struct ShapeBulkEval { +pub struct ShapeBulkEval { eval: E, + xs: Vec, + ys: Vec, + zs: Vec, } -impl ShapeBulkEval { +impl ShapeBulkEval +where + E::Data: From + Transformable, +{ pub fn eval( &mut self, tape: &ShapeTape, @@ -396,15 +449,33 @@ impl ShapeBulkEval { y: &[E::Data], z: &[E::Data], ) -> Result<&[E::Data], Error> { + let (xs, ys, zs) = if let Some(mat) = tape.transform { + if x.len() != y.len() || x.len() != z.len() { + return Err(Error::MismatchedSlices); + } + let n = x.len(); + self.xs.resize(n, 0.0.into()); + self.ys.resize(n, 0.0.into()); + self.zs.resize(n, 0.0.into()); + for i in 0..n { + let (x, y, z) = Transformable::transform(x[i], y[i], z[i], mat); + self.xs[i] = x; + self.ys[i] = y; + self.zs[i] = z; + } + (self.xs.as_slice(), self.ys.as_slice(), self.zs.as_slice()) + } else { + (x, y, z) + }; let mut vars = [None, None, None]; if let Some(a) = tape.axes[0] { - vars[a] = Some(x); + vars[a] = Some(xs); } if let Some(b) = tape.axes[1] { - vars[b] = Some(y); + vars[b] = Some(ys); } if let Some(c) = tape.axes[2] { - vars[c] = Some(z); + vars[c] = Some(zs); } let n = vars.iter().position(|v| v.is_none()).unwrap_or(3); let vars = if vars.iter().all(Option::is_some) { @@ -418,3 +489,53 @@ impl ShapeBulkEval { self.eval.eval(&tape.tape, &vars[..n]) } } + +pub trait Transformable { + fn transform( + x: Self, + y: Self, + z: Self, + mat: Matrix4, + ) -> (Self, Self, Self) + where + Self: Sized; +} + +impl Transformable for f32 { + fn transform(x: f32, y: f32, z: f32, mat: Matrix4) -> (f32, f32, f32) { + let out = mat.transform_point(&Point3::new(x, y, z)); + (out.x, out.y, out.z) + } +} + +impl Transformable for Interval { + fn transform( + x: Interval, + y: Interval, + z: Interval, + mat: Matrix4, + ) -> (Interval, Interval, Interval) { + let out = [0, 1, 2, 3].map(|i| { + let row = mat.row(i); + x * row[0] + y * row[1] + z * row[2] + Interval::from(row[3]) + }); + + (out[0] / out[3], out[1] / out[3], out[2] / out[3]) + } +} + +impl Transformable for Grad { + fn transform( + x: Grad, + y: Grad, + z: Grad, + mat: Matrix4, + ) -> (Grad, Grad, Grad) { + let out = [0, 1, 2, 3].map(|i| { + let row = mat.row(i); + x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3]) + }); + + (out[0] / out[3], out[1] / out[3], out[2] / out[3]) + } +} diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index bf63dbe3..9e55576c 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -6,12 +6,12 @@ //! ``` //! use fidget::{ //! context::Tree, -//! shape::{TracingEvaluator, Shape, MathShape, EzShape}, +//! shape::EzShape, //! jit::JitShape, //! }; //! //! let tree = Tree::x() + Tree::y(); -//! let shape = JitShape::from_tree(&tree); +//! let shape = JitShape::from(tree); //! //! // Generate machine code to execute the tape //! let tape = shape.ez_point_tape(); diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index b3cecf2e..bfc05596 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -107,12 +107,12 @@ //! ``` //! use fidget::{ //! context::Tree, -//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::{Shape, EzShape}, //! vm::VmShape //! }; //! //! let tree = Tree::x() + Tree::y(); -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut interval_eval = VmShape::new_interval_eval(); //! let tape = shape.ez_interval_tape(); //! let (out, _trace) = interval_eval.eval( @@ -136,12 +136,12 @@ //! ``` //! use fidget::{ //! context::Tree, -//! shape::{TracingEvaluator, Shape, MathShape, EzShape}, +//! shape::EzShape, //! vm::VmShape //! }; //! //! let tree = Tree::x().min(Tree::y()); -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut interval_eval = VmShape::new_interval_eval(); //! let tape = shape.ez_interval_tape(); //! let (out, trace) = interval_eval.eval( @@ -165,11 +165,11 @@ //! ``` //! # use fidget::{ //! # context::Tree, -//! # shape::{TracingEvaluator, Shape, MathShape, EzShape}, +//! # shape::EzShape, //! # vm::VmShape //! # }; //! # let tree = Tree::x().min(Tree::y()); -//! # let shape = VmShape::from_tree(&tree); +//! # let shape = VmShape::from(tree); //! assert_eq!(shape.size(), 3); // min, X, Y //! # let mut interval_eval = VmShape::new_interval_eval(); //! # let tape = shape.ez_interval_tape(); @@ -197,7 +197,6 @@ //! ``` //! use fidget::{ //! context::{Tree, Context}, -//! shape::MathShape, //! render::{BitRenderMode, RenderConfig}, //! vm::VmShape, //! }; @@ -209,7 +208,7 @@ //! image_size: 32, //! ..RenderConfig::default() //! }; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let out = cfg.run::<_, BitRenderMode>(shape)?; //! let mut iter = out.iter(); //! for y in 0..cfg.image_size { diff --git a/fidget/src/mesh/mod.rs b/fidget/src/mesh/mod.rs index 4d558427..d77fd467 100644 --- a/fidget/src/mesh/mod.rs +++ b/fidget/src/mesh/mod.rs @@ -19,13 +19,12 @@ //! //! ``` //! use fidget::{ -//! shape::MathShape, //! mesh::{Octree, Settings}, //! vm::VmShape //! }; //! //! let tree = fidget::rhai::eval("sphere(0, 0, 0, 0.6)")?; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let settings = Settings { //! depth: 4, //! ..Default::default() diff --git a/fidget/src/rhai/mod.rs b/fidget/src/rhai/mod.rs index 6ebe9103..1a66d64f 100644 --- a/fidget/src/rhai/mod.rs +++ b/fidget/src/rhai/mod.rs @@ -7,12 +7,12 @@ //! //! ``` //! use fidget::{ -//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape, //! }; //! //! let tree = fidget::rhai::eval("x + y")?; -//! let shape = VmShape::from_tree(&tree); +//! let shape = VmShape::from(tree); //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0)?.0, 3.0); @@ -24,16 +24,16 @@ //! //! ``` //! use fidget::{ -//! shape::{Shape, MathShape, EzShape, TracingEvaluator}, +//! shape::EzShape, //! vm::VmShape, //! rhai::Engine //! }; //! //! let mut engine = Engine::new(); -//! let out = engine.run("draw(x + y - 1)")?; +//! let mut out = engine.run("draw(x + y - 1)")?; //! //! assert_eq!(out.shapes.len(), 1); -//! let shape = VmShape::from_tree(&out.shapes[0].tree); +//! let shape = VmShape::from(out.shapes.pop().unwrap().tree); //! let mut eval = VmShape::new_point_eval(); //! let tape = shape.ez_point_tape(); //! assert_eq!(eval.eval(&tape, 0.5, 2.0, 0.0)?.0, 1.5); diff --git a/viewer/src/main.rs b/viewer/src/main.rs index 46789e3b..5836be7c 100644 --- a/viewer/src/main.rs +++ b/viewer/src/main.rs @@ -75,15 +75,15 @@ struct RenderResult { image_size: usize, } -fn render_thread( +fn render_thread( cfg: Receiver, rx: Receiver>, tx: Sender>, wake: Sender<()>, ) -> Result<()> where - S: fidget::shape::Shape - + fidget::shape::MathShape + F: fidget::eval::Function + + fidget::eval::MathFunction + fidget::shape::RenderHints, { let mut config = None; @@ -128,7 +128,7 @@ where ); let render_start = std::time::Instant::now(); for s in out.shapes.iter() { - let tape = S::from_tree(&s.tree); + let tape = fidget::shape::Shape::::from(s.tree.clone()); render( &render_config.mode, tape, @@ -150,9 +150,9 @@ where } } -fn render( +fn render( mode: &RenderMode, - shape: S, + shape: fidget::shape::Shape, image_size: usize, color: [u8; 3], pixels: &mut [egui::Color32], @@ -161,7 +161,7 @@ fn render( RenderMode::TwoD(camera, mode) => { let config = RenderConfig { image_size, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), bounds: fidget::shape::Bounds { center: Vector2::new(camera.offset.x, camera.offset.y), size: camera.scale, @@ -213,7 +213,7 @@ fn render( RenderMode::ThreeD(camera, mode) => { let config = RenderConfig { image_size, - tile_sizes: S::tile_sizes_2d().to_vec(), + tile_sizes: F::tile_sizes_2d().to_vec(), bounds: fidget::shape::Bounds { center: Vector3::new(camera.offset.x, camera.offset.y, 0.0), size: camera.scale, @@ -281,17 +281,13 @@ fn main() -> Result<(), Box> { }); std::thread::spawn(move || { #[cfg(feature = "jit")] - type Shape = fidget::jit::JitShape; + type F = fidget::jit::JitFunction; #[cfg(not(feature = "jit"))] - type Shape = fidget::vm::VmShape; - - let _ = render_thread::( - config_rx, - rhai_result_rx, - render_tx, - wake_tx, - ); + type F = fidget::vm::VmFunction; + + let _ = + render_thread::(config_rx, rhai_result_rx, render_tx, wake_tx); info!("render thread is done"); }); diff --git a/wasm-demo/src/lib.rs b/wasm-demo/src/lib.rs index 6fd1fde4..7e0bfc10 100644 --- a/wasm-demo/src/lib.rs +++ b/wasm-demo/src/lib.rs @@ -2,7 +2,6 @@ use fidget::{ context::{Context, Tree}, render::{BitRenderMode, RenderConfig}, shape::Bounds, - shape::MathShape, vm::{VmData, VmShape}, Error, }; From 71e67d46f66ea51872c69b9b6c7f8673ed495d2e Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Fri, 24 May 2024 10:09:52 -0400 Subject: [PATCH 09/12] Use Function's RenderHints directly --- fidget/benches/render.rs | 9 +++++---- fidget/src/core/shape/mod.rs | 14 -------------- fidget/src/mesh/octree.rs | 23 ++++++++++------------- 3 files changed, 15 insertions(+), 31 deletions(-) diff --git a/fidget/benches/render.rs b/fidget/benches/render.rs index c8a13d93..a6199b6f 100644 --- a/fidget/benches/render.rs +++ b/fidget/benches/render.rs @@ -1,6 +1,7 @@ use criterion::{ black_box, criterion_group, criterion_main, BenchmarkId, Criterion, }; +use fidget::shape::RenderHints; const PROSPERO: &str = include_str!("../../models/prospero.vm"); @@ -17,7 +18,7 @@ pub fn prospero_size_sweep(c: &mut Criterion) { for size in [256, 512, 768, 1024, 1280, 1546, 1792, 2048] { let cfg = &fidget::render::RenderConfig { image_size: size, - tile_sizes: fidget::vm::VmShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::vm::VmFunction::tile_sizes_2d().to_vec(), ..Default::default() }; group.bench_function(BenchmarkId::new("vm", size), move |b| { @@ -34,7 +35,7 @@ pub fn prospero_size_sweep(c: &mut Criterion) { { let cfg = &fidget::render::RenderConfig { image_size: size, - tile_sizes: fidget::jit::JitShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::jit::JitFunction::tile_sizes_2d().to_vec(), ..Default::default() }; group.bench_function(BenchmarkId::new("jit", size), move |b| { @@ -63,7 +64,7 @@ pub fn prospero_thread_sweep(c: &mut Criterion) { for threads in [1, 2, 4, 8, 16] { let cfg = &fidget::render::RenderConfig { image_size: 1024, - tile_sizes: fidget::vm::VmShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::vm::VmFunction::tile_sizes_2d().to_vec(), threads: threads.try_into().unwrap(), ..Default::default() }; @@ -80,7 +81,7 @@ pub fn prospero_thread_sweep(c: &mut Criterion) { { let cfg = &fidget::render::RenderConfig { image_size: 1024, - tile_sizes: fidget::jit::JitShape::tile_sizes_2d().to_vec(), + tile_sizes: fidget::jit::JitFunction::tile_sizes_2d().to_vec(), threads: threads.try_into().unwrap(), ..Default::default() }; diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 562cfb7b..acbe1bd9 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -196,20 +196,6 @@ impl Shape { } } -impl Shape { - pub fn tile_sizes_3d() -> &'static [usize] { - F::tile_sizes_3d() - } - - pub fn tile_sizes_2d() -> &'static [usize] { - F::tile_sizes_2d() - } - - pub fn simplify_tree_during_meshing(d: usize) -> bool { - F::simplify_tree_during_meshing(d) - } -} - impl Shape { pub fn apply_transform(mut self, mat: Matrix4) -> Self { if let Some(prev) = self.transform.as_mut() { diff --git a/fidget/src/mesh/octree.rs b/fidget/src/mesh/octree.rs index 26f5d760..7b922d8a 100644 --- a/fidget/src/mesh/octree.rs +++ b/fidget/src/mesh/octree.rs @@ -366,19 +366,16 @@ impl OctreeBuilder { } else if i.lower() > 0.0 { CellResult::Done(Cell::Empty) } else { - let sub_tape = - if Shape::::simplify_tree_during_meshing(cell.depth) { - let s = self.shape_storage.pop().unwrap_or_default(); - r.map(|r| { - Arc::new(EvalGroup::new( - eval.shape - .simplify(r, s, &mut self.workspace) - .unwrap(), - )) - }) - } else { - None - }; + let sub_tape = if F::simplify_tree_during_meshing(cell.depth) { + let s = self.shape_storage.pop().unwrap_or_default(); + r.map(|r| { + Arc::new(EvalGroup::new( + eval.shape.simplify(r, s, &mut self.workspace).unwrap(), + )) + }) + } else { + None + }; if cell.depth == settings.depth as usize { let eval = sub_tape.unwrap_or_else(|| eval.clone()); let out = CellResult::Done(self.leaf(&eval, cell)); From fac8d80536a5cd17915ed8988e3d72f90d7a31fc Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 09:24:00 -0400 Subject: [PATCH 10/12] Add docs --- fidget/src/core/eval/tracing.rs | 5 +-- fidget/src/core/shape/mod.rs | 58 +++++++++++++++++++-------------- fidget/src/core/vm/mod.rs | 9 ++--- fidget/src/lib.rs | 2 +- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index 68e54080..c3c0a873 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -12,8 +12,9 @@ use crate::{eval::Tape, Error}; /// Evaluator for single values which simultaneously captures an execution trace /// -/// The trace can later be used to simplify the [`Shape`](crate::eval::Shape) -/// using [`Shape::simplify`](crate::eval::Shape::simplify). +/// The trace can later be used to simplify the +/// [`Function`](crate::eval::Function) +/// using [`Function::simplify`](crate::eval::Function::simplify). pub trait TracingEvaluator: Default { /// Data type used during evaluation type Data: From + Copy + Clone; diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index acbe1bd9..6240171a 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -1,8 +1,13 @@ -//! Traits and data structures for shape evaluation +//! Data structures for shape evaluation //! -//! There are a bunch of things in here, but the most important trait is -//! [`Shape`], followed by the evaluator traits ([`BulkEvaluator`] and -//! [`TracingEvaluator`]). +//! Types in this module are typically thin (generic) wrappers around objects +//! that implement traits in [`fidget::eval`](crate::eval). The wraper types +//! are specialized to operate on `x, y, z` arguments, rather than taking +//! arbitrary numbers of variables. +//! +//! For example, a [`Shape`] is a wrapper which makes it easier to treat a +//! [`Function`] as an implicit surface (with X, Y, Z axes and an optional +//! transform matrix). //! //! ```rust //! use fidget::vm::VmShape; @@ -20,12 +25,6 @@ //! assert_eq!(value, 0.25); //! # Ok::<(), fidget::Error>(()) //! ``` -//! -//! Note that the traits here mirror the ones in ones in -//! [`fidget::eval`](crate::eval), but are specialized to operate on `x, y, z` -//! arguments (rather than taking arbitrary numbers of variables). It is -//! recommended to import the traits from either one or the other, to avoid -//! ambiguity. use crate::{ context::{Context, Node, Tree}, @@ -36,8 +35,6 @@ use crate::{ use nalgebra::{Matrix4, Point3}; mod bounds; - -// Re-export a few things pub use bounds::Bounds; /// A shape represents an implicit surface @@ -175,7 +172,9 @@ impl Shape { pub fn size(&self) -> usize { self.f.size() } +} +impl Shape { /// Borrows the inner [`Function`](Function) object pub fn inner(&self) -> &F { &self.f @@ -194,9 +193,7 @@ impl Shape { transform: None, } } -} - -impl Shape { + /// Returns a shape with the given transform applied pub fn apply_transform(mut self, mat: Matrix4) -> Self { if let Some(prev) = self.transform.as_mut() { *prev *= mat; @@ -294,6 +291,7 @@ pub trait RenderHints { } impl Shape { + /// Builds a new shape from a math expression with the given axes pub fn new_with_axes( ctx: &Context, node: Node, @@ -320,6 +318,7 @@ impl Shape { } } +/// Converts a [`Tree`] to a [`Shape`] with the default axes impl From for Shape { fn from(t: Tree) -> Self { let mut ctx = Context::new(); @@ -328,7 +327,7 @@ impl From for Shape { } } -/// Wrapper struct to bind a generic tape to particular X, Y, Z axes +/// Wrapper around a function tape, with axes and an optional transform matrix pub struct ShapeTape { tape: T, @@ -346,8 +345,10 @@ impl ShapeTape { } } -/// Wrapper struct to convert from [`TracingEvaluator`] to -/// [`shape::TracingEvaluator`](TracingEvaluator) +/// Wrapper around a [`TracingEvaluator`] +/// +/// Unlike the raw tracing evaluator, a [`ShapeTracingEval`] knows about the +/// tape's X, Y, Z axes and optional transform matrix. #[derive(Debug, Default)] pub struct ShapeTracingEval { eval: E, @@ -357,6 +358,9 @@ impl ShapeTracingEval where ::Data: Transformable, { + /// Tracing evaluation of a single sample + /// + /// Before evaluation, the tape's transform matrix is applied (if present). pub fn eval>( &mut self, tape: &ShapeTape, @@ -375,13 +379,13 @@ where let mut vars = [None, None, None]; if let Some(a) = tape.axes[0] { - vars[a] = Some(x.into()); + vars[a] = Some(x); } if let Some(b) = tape.axes[1] { - vars[b] = Some(y.into()); + vars[b] = Some(y); } if let Some(c) = tape.axes[2] { - vars[c] = Some(z.into()); + vars[c] = Some(z); } let n = vars.iter().position(Option::is_none).unwrap_or(3); let vars = vars.map(|v| v.unwrap_or(0f32.into())); @@ -411,11 +415,10 @@ where } } -/// Bulk evaluator for a shape +/// Wrapper around a [`BulkEvaluator`] /// -/// This wraps a generic [`BulkEvaluator`] and exposes an API that takes -/// `(x, y, z)` arguments instead. In addition, it applies the transform -/// associated with the [`ShapeTape`]. +/// Unlike the raw bulk evaluator, a [`ShapeBulkEval`] knows about the +/// tape's X, Y, Z axes and optional transform matrix. #[derive(Debug, Default)] pub struct ShapeBulkEval { eval: E, @@ -428,6 +431,9 @@ impl ShapeBulkEval where E::Data: From + Transformable, { + /// Bulk evaluation of many samples + /// + /// Before evaluation, the tape's transform matrix is applied (if present). pub fn eval( &mut self, tape: &ShapeTape, @@ -476,7 +482,9 @@ where } } +/// Trait for types that can be transformed by a 4x4 homogenous transform matrix pub trait Transformable { + /// Apply the given transform to an `(x, y, z)` position fn transform( x: Self, y: Self, diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index e950009e..aee2363a 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -20,15 +20,16 @@ pub use data::{VmData, VmWorkspace}; //////////////////////////////////////////////////////////////////////////////// -pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>; - -/// Shape that use a VM backend for evaluation +/// Function which uses the VM backend for evaluation /// -/// Internally, the [`VmShape`] stores an [`Arc`](VmData), and +/// Internally, the [`VmFunction`] stores an [`Arc`](VmData), and /// iterates over a [`Vec`](RegOp) to perform evaluation. /// /// All of the associated [`Tape`] types simply clone the internal `Arc`; /// there's no separate planning required to generate a tape. +pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>; + +/// Shape that use a the [`VmFunction`] backend for evaluation pub type VmShape = Shape; impl Tape for GenericVmFunction { diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index bfc05596..f062725a 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -159,7 +159,7 @@ //! tape from `min(x, y) → x`. //! //! Interval evaluation is a kind of -//! [tracing evaluation](crate::shape::TracingEvaluator), which returns a tuple +//! [tracing evaluation](crate::eval::TracingEvaluator), which returns a tuple //! of `(value, trace)`. The trace can be used to simplify the original shape: //! //! ``` From 903e6d7c505cab8e338e4368a4360d95bd361c22 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 09:48:49 -0400 Subject: [PATCH 11/12] Update CHANGELOG --- CHANGELOG.md | 13 ++++++++++++- fidget/src/core/context/mod.rs | 4 ++-- fidget/src/lib.rs | 28 ++++++++++++++++++---------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aada9fa3..4a24920a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,15 @@ -# 0.2.8 (unreleased) +# 0.2.8 +- Major refactoring of core evaluation traits + - The lowest-level "thing that can be evaluated" trait has changed from + `Shape` (taking `(x, y, z)` inputs) to `Function` (taking an arbitrary + number of variables). + - `Shape` is now a wrapper around a `F: Function` instead of a trait. + - Shape evaluators are now wrappers around `E: BulkEvaluator` or `E: + TracingEvaluator`, which convert `(x, y, z)` arguments into + list-of-variables arguments. + - Using the `VmShape` or `JitShape` types should be mostly the same as + before; changes are most noticeable if you're writing things that are + generic across `S: Shape`. # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index b10e5424..b7659c67 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -10,10 +10,10 @@ //! they have been constructed. //! - A [`Context`] is an arena for unique (deduplicated) math expressions, //! which are represented as [`Node`] handles. Each `Node` is specific to a -//! particular context. Only `Node` objects can be converted into `Shape` +//! particular context. Only `Node` objects can be converted into `Function` //! objects for evaluation. //! -//! In other words, the typical workflow is `Tree → (Context, Node) → Shape`. +//! In other words, the typical workflow is `Tree → (Context, Node) → Function`. mod indexed; mod op; mod tree; diff --git a/fidget/src/lib.rs b/fidget/src/lib.rs index f062725a..4c19bd96 100644 --- a/fidget/src/lib.rs +++ b/fidget/src/lib.rs @@ -68,7 +68,7 @@ //! //! Evaluation is deliberately agnostic to the specific details of how we go //! from position to results. This abstraction is represented by the -//! [`Shape` trait](crate::shape::Shape), which defines how to make both +//! [`Function` trait](crate::eval::Function), which defines how to make both //! **evaluators** and **tapes**. //! //! An **evaluator** is an object which performs evaluation of some kind (point, @@ -77,22 +77,22 @@ //! //! A **tape** contains instructions for an evaluator. //! -//! At the moment, Fidget implements two kinds of shapes: +//! At the moment, Fidget implements two kinds of functions: //! -//! - [`fidget::vm::VmShape`](crate::vm::VmShape) evaluates a list of opcodes -//! using an interpreter. This is slower, but can run in more situations -//! (e.g. in WebAssembly). -//! - [`fidget::jit::JitShape`](crate::jit::JitShape) performs fast evaluation -//! by compiling shapes down to native code. +//! - [`fidget::vm::VmFunction`](crate::vm::VmFunction) evaluates a list of +//! opcodes using an interpreter. This is slower, but can run in more +//! situations (e.g. in WebAssembly). +//! - [`fidget::jit::JitFunction`](crate::jit::JitFunction) performs fast +//! evaluation by compiling expressions down to native code. //! -//! The [`Shape`](crate::shape::Shape) trait requires four different kinds +//! The [`Function`](crate::eval::Function) trait requires four different kinds //! of evaluation: //! //! - Single-point evaluation //! - Interval evaluation //! - Evaluation on an array of points, returning `f32` values //! - Evaluation on an array of points, returning partial derivatives with -//! respect to `x, y, z` +//! respect to input variables //! //! These evaluation flavors are used in rendering: //! - Interval evaluation can conservatively prove large regions of space to be @@ -103,7 +103,15 @@ //! - At the surface of the model, partial derivatives represent normals and //! can be used for shading. //! -//! Here's a simple example of interval evaluation: +//! # Functions and shapes +//! The [`Function`](crate::eval::Function) trait supports arbitrary numbers of +//! varibles; when using it for implicit surfaces, it's common to wrap it in a +//! [`Shape`](crate::shape::Shape), which binds `(x, y, z)` axes to specific +//! variables. +//! +//! Here's a simple example of interval evaluation, using a `Shape` to wrap a +//! function and evaluate it at a particular `(x, y, z)` position: +//! //! ``` //! use fidget::{ //! context::Tree, From ca25dae6b27dba063218fc37e8885c1ff2a6cd81 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 10:11:12 -0400 Subject: [PATCH 12/12] Fix comment --- fidget/src/jit/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 9e55576c..f619aa86 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -1122,7 +1122,6 @@ impl + Copy + SimdSize> JitBulkEval { if n < T::SIMD_SIZE { assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH); - // TODO reuse allocation here allocation here? self.scratch.resize(n, [T::from(0.0); MAX_SIMD_WIDTH]); for (v, t) in vars.iter().zip(self.scratch.iter_mut()) { t[0..n].copy_from_slice(v);