diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a24920a..058c2b4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,16 @@ - Using the `VmShape` or `JitShape` types should be mostly the same as before; changes are most noticeable if you're writing things that are generic across `S: Shape`. +- Major refactoring of how variables are handled + - Removed `VarNode`; the canonical variable type is `Var`, which is its own + unique index. + - Removed named variables, to make `Var` trivially `Copy + Clone`. + - Added `vars()` method to `Function` trait, allowing users to look up the + mapping from variable to evaluation index. A `Var` now represents a + persistent identity from `Tree` to `Context` to `Function` evaluation. + - Move `Var` and `VarMap` into `fidget::vars` module, because they're no + longer specific to a `Context`. + - `Op::Input` now takes a `u32` argument, instead of a `u8` # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/Cargo.lock b/Cargo.lock index d5a99c76..c155e5be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -892,6 +892,7 @@ dependencies = [ "num-derive", "num-traits", "ordered-float", + "rand", "rhai", "serde", "static_assertions", @@ -1777,6 +1778,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -1838,6 +1845,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "raw-window-handle" version = "0.5.2" @@ -2908,6 +2945,7 @@ dependencies = [ "clap", "clap_builder", "crossbeam-utils", + "getrandom", "libc", "log", "memchr", diff --git a/fidget/Cargo.toml b/fidget/Cargo.toml index 21327cf8..204ef22a 100644 --- a/fidget/Cargo.toml +++ b/fidget/Cargo.toml @@ -17,10 +17,11 @@ nalgebra = "0.31" num-derive = "0.3" num-traits = "0.2" ordered-float = "3" +rand = "0.8.5" static_assertions = "1" thiserror = "1" workspace-hack = { version = "0.1", path = "../workspace-hack" } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } # JIT dynasmrt = { version = "2.0", optional = true } diff --git a/fidget/src/core/compiler/alloc.rs b/fidget/src/core/compiler/alloc.rs index e4ca6f60..d5eb418c 100644 --- a/fidget/src/core/compiler/alloc.rs +++ b/fidget/src/core/compiler/alloc.rs @@ -273,7 +273,7 @@ impl RegisterAllocator { #[inline(always)] pub fn op(&mut self, op: SsaOp) { match op { - SsaOp::Input(out, i) => self.op_input(out, i.try_into().unwrap()), + SsaOp::Input(out, i) => self.op_input(out, i), SsaOp::CopyImm(out, imm) => self.op_copy_imm(out, imm), SsaOp::NegReg(..) 
@@ -673,9 +673,7 @@ impl RegisterAllocator { /// Pushes an [`Input`](crate::compiler::RegOp::Input) operation to the tape #[inline(always)] - fn op_input(&mut self, out: u32, i: u8) { - // TODO: tightly pack variables (which may be sparse) into slots - self.out.var_count = self.out.var_count.max(i as u32 + 1); + fn op_input(&mut self, out: u32, i: u32) { self.op_out_only(out, |out| RegOp::Input(out, i)); } } diff --git a/fidget/src/core/compiler/op.rs b/fidget/src/core/compiler/op.rs index d003da23..73369312 100644 --- a/fidget/src/core/compiler/op.rs +++ b/fidget/src/core/compiler/op.rs @@ -23,8 +23,8 @@ macro_rules! opcodes { ) => { $(#[$($attrss)*])* pub enum $name { - #[doc = "Read one of the inputs (X, Y, Z)"] - Input($t, $t), + #[doc = "Read an input variable by index"] + Input($t, u32), #[doc = "Negate the given register"] NegReg($t, $t), diff --git a/fidget/src/core/compiler/reg_tape.rs b/fidget/src/core/compiler/reg_tape.rs index 4887019b..77e54245 100644 --- a/fidget/src/core/compiler/reg_tape.rs +++ b/fidget/src/core/compiler/reg_tape.rs @@ -10,9 +10,6 @@ pub struct RegTape { /// Total allocated slots pub(super) slot_count: u32, - - /// Number of variables - pub(super) var_count: u32, } impl RegTape { @@ -35,7 +32,6 @@ impl RegTape { Self { tape: vec![], slot_count: 1, - var_count: 0, } } @@ -50,11 +46,6 @@ impl RegTape { pub fn slot_count(&self) -> usize { self.slot_count as usize } - /// Returns the number of variables (inputs) used in this tape - #[inline] - pub fn var_count(&self) -> usize { - self.var_count as usize - } /// Returns the number of elements in the tape #[inline] pub fn len(&self) -> usize { diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 80e4fc55..1bea9a29 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -2,7 +2,7 @@ use crate::{ compiler::SsaOp, context::{BinaryOpcode, Node, Op, UnaryOpcode}, - eval::VarMap, + var::VarMap, Context, Error, }; use serde::{Deserialize, Serialize}; @@ -33,7 +33,10 @@ impl SsaTape { /// /// This should always succeed unless the `root` is from a different /// `Context`, in which case `Error::BadNode` will be returned. - pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> { + pub fn new( + ctx: &Context, + root: Node, + ) -> Result<(Self, VarMap), Error> { let mut mapping = HashMap::new(); let mut parent_count: HashMap = HashMap::new(); let mut slot_count = 0; @@ -47,7 +50,7 @@ impl SsaTape { // Accumulate parent counts and declare all nodes let mut seen = HashSet::new(); - let mut vars = HashMap::new(); + let mut vars = VarMap::new(); let mut todo = vec![root]; while let Some(node) = todo.pop() { if !seen.insert(node) { @@ -61,8 +64,7 @@ impl SsaTape { _ => { if let Op::Input(v) = op { let next = vars.len(); - let v = ctx.get_var_by_index(*v).unwrap().clone(); - vars.entry(v).or_insert(next); + vars.entry(*v).or_insert(next); } let i = slot_count; slot_count += 1; @@ -97,8 +99,8 @@ impl SsaTape { continue; }; let op = match op { - Op::Input(..) => { - let arg = vars[ctx.var_name(node).unwrap().unwrap()]; + Op::Input(v) => { + let arg = vars[v]; SsaOp::Input(i, arg.try_into().unwrap()) } Op::Const(..) => { diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index b7659c67..0dc8eeeb 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -4,7 +4,8 @@ //! //! - A [`Tree`] is a free-floating math expression, which can be cloned //! 
and has overloaded operators for ease of use. It is **not** deduplicated; -//! two calls to [`Tree::x`] will produce two different [`TreeOp`] objects. +//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two +//! different [`TreeOp`] objects. //! `Tree` objects are typically used when building up expressions; they //! should be converted to `Node` objects (in a particular `Context`) after //! they have been constructed. @@ -22,7 +23,7 @@ use indexed::{define_index, Index, IndexMap, IndexVec}; pub use op::{BinaryOpcode, Op, UnaryOpcode}; pub use tree::{Tree, TreeOp}; -use crate::Error; +use crate::{var::Var, Error}; use std::collections::{BTreeMap, HashMap}; use std::fmt::Write; @@ -32,7 +33,6 @@ use std::sync::Arc; use ordered_float::OrderedFloat; define_index!(Node, "An index in the `Context::ops` map"); -define_index!(VarNode, "An index in the `Context::vars` map"); /// A `Context` holds a set of deduplicated constants, variables, and /// operations. @@ -42,39 +42,6 @@ define_index!(VarNode, "An index in the `Context::vars` map"); #[derive(Debug, Default)] pub struct Context { ops: IndexMap, - vars: IndexMap, -} - -/// A `Var` represents a value which can vary during evaluation -/// -/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully -/// customized values. -#[allow(missing_docs)] -#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub enum Var { - X, - Y, - Z, - W, - T, - Static(&'static str), - Named(String), - Value(u64), -} - -impl std::fmt::Display for Var { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Var::X => write!(f, "X"), - Var::Y => write!(f, "Y"), - Var::Z => write!(f, "Z"), - Var::W => write!(f, "W"), - Var::T => write!(f, "T"), - Var::Static(s) => write!(f, "{s}"), - Var::Named(s) => write!(f, "{s}"), - Var::Value(v) => write!(f, "v_{v}"), - } - } } impl Context { @@ -85,7 +52,7 @@ impl Context { /// Clears the context /// - /// All [`Node`] and [`VarNode`] handles from this context are invalidated. + /// All [`Node`] handles from this context are invalidated. /// /// ``` /// # use fidget::context::Context; @@ -96,7 +63,6 @@ impl Context { /// ``` pub fn clear(&mut self) { self.ops.clear(); - self.vars.clear(); } /// Returns the number of [`Op`] nodes in the context @@ -155,25 +121,17 @@ impl Context { /// /// If the node is invalid for this tree, returns an error; if the node is /// not an `Op::Input`, returns `Ok(None)`. 
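A short sketch of how the renamed lookup is meant to be used, assuming the `Context::var` and `Context::get_var` methods introduced in the hunks just below; `Context::add` is the pre-existing builder API and is not part of this diff.

```rust
// Sketch only: bind a custom `Var` into a `Context`, then read it back.
use fidget::{context::Context, var::Var};

fn main() -> Result<(), fidget::Error> {
    let mut ctx = Context::new();
    let t = Var::new();            // fresh variable with a random 64-bit identity
    let n = ctx.var(t);            // deduplicated `Op::Input(t)` node
    assert_eq!(ctx.get_var(n)?, Some(t));

    let x = ctx.x();
    let sum = ctx.add(n, x)?;
    assert_eq!(ctx.get_var(sum)?, None); // not an `Op::Input`, so no `Var`
    Ok(())
}
```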
- pub fn var_name(&self, n: Node) -> Result, Error> { + pub fn get_var(&self, n: Node) -> Result, Error> { match self.get_op(n) { - Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some), + Some(Op::Input(v)) => Ok(Some(*v)), Some(_) => Ok(None), _ => Err(Error::BadNode), } } - /// Looks up the [`Var`] associated with the given [`VarNode`] - pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> { - match self.vars.get_by_index(n) { - Some(c) => Ok(c), - None => Err(Error::BadVar), - } - } - //////////////////////////////////////////////////////////////////////////// // Primitives - /// Constructs or finds a variable node named "X" + /// Constructs or finds a [`Var::X`] node /// ``` /// # use fidget::context::Context; /// let mut ctx = Context::new(); @@ -182,19 +140,21 @@ impl Context { /// assert_eq!(v, 1.0); /// ``` pub fn x(&mut self) -> Node { - let v = self.vars.insert(Var::X); - self.ops.insert(Op::Input(v)) + self.var(Var::X) } - /// Constructs or finds a variable node named "Y" + /// Constructs or finds a [`Var::Y`] node pub fn y(&mut self) -> Node { - let v = self.vars.insert(Var::Y); - self.ops.insert(Op::Input(v)) + self.var(Var::Y) } - /// Constructs or finds a variable node named "Z" + /// Constructs or finds a [`Var::Z`] node pub fn z(&mut self) -> Node { - let v = self.vars.insert(Var::Z); + self.var(Var::Z) + } + + /// Constructs or finds a variable input node + pub fn var(&mut self, v: Var) -> Node { self.ops.insert(Op::Input(v)) } @@ -822,10 +782,7 @@ impl Context { } let mut get = |n: Node| self.eval_inner(n, vars, cache); let v = match self.get_op(node).ok_or(Error::BadNode)? { - Op::Input(v) => { - let var_name = self.vars.get_by_index(*v).unwrap(); - *vars.get(var_name).unwrap() - } + Op::Input(v) => *vars.get(v).unwrap(), Op::Const(c) => c.0, Op::Binary(op, a, b) => { @@ -995,7 +952,6 @@ impl Context { match op { Op::Const(c) => write!(out, "{}", c).unwrap(), Op::Input(v) => { - let v = self.vars.get_by_index(*v).unwrap(); out += &v.to_string(); } Op::Binary(op, ..) => match op { @@ -1245,9 +1201,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap(); + let tape = VmData::<255>::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); - assert_eq!(vs.len(), 2); + assert_eq!(tape.vars.len(), 2); } #[test] @@ -1256,8 +1212,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap(); + let tape = VmData::<255>::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); - assert_eq!(vs.len(), 1); + assert_eq!(tape.vars.len(), 1); } } diff --git a/fidget/src/core/context/op.rs b/fidget/src/core/context/op.rs index e8ce6b74..bd77c63f 100644 --- a/fidget/src/core/context/op.rs +++ b/fidget/src/core/context/op.rs @@ -1,4 +1,7 @@ -use crate::context::{indexed::Index, Node, VarNode}; +use crate::{ + context::{indexed::Index, Node}, + var::Var, +}; use ordered_float::OrderedFloat; /// A one-argument math operation @@ -53,7 +56,7 @@ pub enum BinaryOpcode { #[allow(missing_docs)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] pub enum Op { - Input(VarNode), + Input(Var), Const(OrderedFloat), Binary(BinaryOpcode, Node, Node), Unary(UnaryOpcode, Node), diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 680d58f3..88ce2c08 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,7 +1,8 @@ //! 
Traits and data structures for function evaluation use crate::{ - context::{Context, Node, Tree, Var}, + context::{Context, Node}, types::{Grad, Interval}, + var::VarMap, Error, }; @@ -165,29 +166,15 @@ pub trait Function: Send + Sync + Clone { /// This is underspecified and only used for unit testing; for tape-based /// functions, it's typically the length of the tape, fn size(&self) -> usize; -} -/// Map from variable (from a particular [`Context`]) to index -pub type VarMap = std::collections::HashMap; + /// Returns a map from variable to index + fn vars(&self) -> &VarMap; +} /// A [`Function`] which can be built from a math expression -pub trait MathFunction { +pub trait MathFunction: Function { /// Builds a new function from the given context and node - /// - /// Returns a tuple of the [`MathFunction`] and a [`VarMap`] representing a - /// mapping from variables (in the [`Context`]) to indices used during - /// evaluation. - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> + fn new(ctx: &Context, node: Node) -> Result where Self: Sized; - - /// Helper function to build a function from a [`Tree`] - fn from_tree(t: &Tree) -> (Self, VarMap) - where - Self: Sized, - { - let mut ctx = Context::new(); - let node = ctx.import(t); - Self::new(&ctx, node).unwrap() - } } diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index c3c0a873..690f736f 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -51,9 +51,7 @@ pub trait TracingEvaluator: Default { vars: &[Self::Data], var_count: usize, ) -> Result<(), Error> { - // It's fine if the caller has given us extra variables (e.g. due to - // tape simplification), but it must have given us enough. - if vars.len() < var_count { + if vars.len() != var_count { Err(Error::BadVarSlice(vars.len(), var_count)) } else { Ok(()) diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 37ed2934..25ff28c4 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -60,11 +60,13 @@ pub mod compiler; pub mod eval; pub mod shape; pub mod types; +pub mod var; pub mod vm; #[cfg(test)] mod test { use crate::context::*; + use crate::var::Var; #[test] fn it_works() { diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 6240171a..d37c4f30 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -30,6 +30,7 @@ use crate::{ context::{Context, Node, Tree}, eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, types::{Grad, Interval}, + var::Var, Error, }; use nalgebra::{Matrix4, Point3}; @@ -50,6 +51,10 @@ pub struct Shape { f: F, /// Index of x, y, z axes within the function's variable list (if present) + /// + /// We could instead store an array of [`Var`]s and look them in with + /// [`self.f.vars()[&v]`](Function::vars), but it's more efficient to cache + /// them upon construction (because they never change). 
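The field comment above explains why the axis indices are cached at construction time. The sketch below does the equivalent lookup by hand against `VmData` from this diff (its public `vars` map is what `Function::vars()` exposes); the `usize` value type of the map is inferred from `SsaTape::new`.

```rust
// Sketch of the one-time axis lookup that `Shape::new_with_axes` performs.
use fidget::{
    context::{Context, Tree},
    var::Var,
    vm::VmData,
};

fn main() -> Result<(), fidget::Error> {
    let mut ctx = Context::new();
    let node = ctx.import(&(Tree::x() + Tree::y()));
    let data = VmData::<255>::new(&ctx, node)?;

    // Axes that appear in the expression get an evaluation slot; Z is absent here.
    let axes = [Var::X, Var::Y, Var::Z].map(|v| data.vars.get(&v).cloned());
    assert!(axes[0].is_some() && axes[1].is_some() && axes[2].is_none());
    Ok(())
}
```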
axes: [Option; 3], /// Optional transform to apply to the shape @@ -295,15 +300,14 @@ impl Shape { pub fn new_with_axes( ctx: &Context, node: Node, - axes: [Node; 3], + axes: [Var; 3], ) -> Result { - let (f, vs) = F::new(ctx, node)?; - let x = ctx.var_name(axes[0])?.ok_or(Error::NotAVar)?; - let y = ctx.var_name(axes[1])?.ok_or(Error::NotAVar)?; - let z = ctx.var_name(axes[2])?.ok_or(Error::NotAVar)?; + let f = F::new(ctx, node)?; + let vars = f.vars(); + let axes = axes.map(|v| vars.get(&v).cloned()); Ok(Self { f, - axes: [x, y, z].map(|v| vs.get(v).cloned()), + axes, transform: None, }) } @@ -313,8 +317,7 @@ impl Shape { where Self: Sized, { - let axes = ctx.axes(); - Self::new_with_axes(ctx, node, axes) + Self::new_with_axes(ctx, node, [Var::X, Var::Y, Var::Z]) } } diff --git a/fidget/src/core/var/mod.rs b/fidget/src/core/var/mod.rs new file mode 100644 index 00000000..b9a23f50 --- /dev/null +++ b/fidget/src/core/var/mod.rs @@ -0,0 +1,184 @@ +//! Input variables to math expressions +//! +//! A [`Var`] maintains a persistent identity from +//! [`Tree`](crate::context::Tree) to [`Context`](crate::context::Node) (where +//! it is wrapped in a [`Op::Input`](crate::context::Op::Input)) to evaluation +//! (where [`Function::vars`](crate::eval::Function::vars) maps from `Var` to +//! index in the argument list). +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The [`Var`] type is an input to a math expression +/// +/// We pre-define common variables (e.g. `X`, `Y`, `Z`) but also allow for fully +/// customized values (using [`Var::V`]). +/// +/// Variables are "global", in that every instance of `Var::X` represents the +/// same thing. To generate a "local" variable, [`Var::new`] picks a random +/// 64-bit value, which is very unlikely to collide with anything else. +#[allow(missing_docs)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum Var { + X, + Y, + Z, + V(u64), +} + +impl Var { + /// Returns a new variable, with a random 64-bit index + /// + /// The odds of collision with any previous variable are infintesimally + /// small; if you are generating billions of random variables, something + /// else in the system is likely to break before collisions become an issue. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let v: u64 = rand::random(); + Var::V(v) + } +} + +impl std::fmt::Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Var::X => write!(f, "X"), + Var::Y => write!(f, "Y"), + Var::Z => write!(f, "Z"), + Var::V(v) if *v < 256 => write!(f, "v_{v}"), + Var::V(v) => write!(f, "V({v:x})"), + } + } +} + +/// Map from [`Var`] to a particular value +/// +/// This is equivalent to a +/// [`HashMap`](std::collections::HashMap) and as such does not include +/// per-function documentation. +/// +/// The advantage over a `HashMap` is that for common variables (`X`, `Y`, `Z`), +/// no allocation is required. 
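The module documentation above makes two claims worth seeing concretely: pre-defined variables are global, and `X`/`Y`/`Z` entries in a `VarMap` never touch the inner hash map. A minimal sketch, with `usize` chosen as an example value type:

```rust
// Sketch of `Var` identity and basic `VarMap` use, per the docs above.
use fidget::var::{Var, VarMap};

fn main() {
    assert_eq!(Var::X, Var::X);     // every `Var::X` is the same variable
    let a = Var::new();
    let b = Var::new();
    assert_ne!(a, b);               // fresh random 64-bit identities

    let mut m: VarMap<usize> = VarMap::new();
    m.entry(Var::X).or_insert(0);   // stored inline, no hashing
    m.entry(a).or_insert(1);        // custom vars go into the inner HashMap
    assert_eq!(m.len(), 2);
    assert_eq!(m[&Var::X], 0);
    assert_eq!(m[&a], 1);
}
```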
+#[derive(Serialize, Deserialize)] +pub struct VarMap { + x: Option, + y: Option, + z: Option, + v: HashMap, +} + +impl Default for VarMap { + fn default() -> Self { + Self { + x: None, + y: None, + z: None, + v: HashMap::default(), + } + } +} + +#[allow(missing_docs)] +impl VarMap { + pub fn new() -> Self { + Self::default() + } + pub fn len(&self) -> usize { + self.x.is_some() as usize + + self.y.is_some() as usize + + self.z.is_some() as usize + + self.v.len() + } + pub fn is_empty(&self) -> bool { + self.x.is_none() + && self.y.is_none() + && self.z.is_none() + && self.v.is_empty() + } + pub fn get(&self, v: &Var) -> Option<&T> { + match v { + Var::X => self.x.as_ref(), + Var::Y => self.y.as_ref(), + Var::Z => self.z.as_ref(), + Var::V(v) => self.v.get(v), + } + } + + pub fn get_mut(&mut self, v: &Var) -> Option<&mut T> { + match v { + Var::X => self.x.as_mut(), + Var::Y => self.y.as_mut(), + Var::Z => self.z.as_mut(), + Var::V(v) => self.v.get_mut(v), + } + } + + pub fn entry(&mut self, v: Var) -> VarMapEntry { + match v { + Var::X => VarMapEntry::Option(&mut self.x), + Var::Y => VarMapEntry::Option(&mut self.y), + Var::Z => VarMapEntry::Option(&mut self.z), + Var::V(v) => VarMapEntry::Hash(self.v.entry(v)), + } + } +} + +/// Entry into a [`VarMap`]; equivalent to [`std::collections::hash_map::Entry`] +/// +/// The implementation has just enough functions to be useful; if you find +/// yourself wanting the rest of the entry API, it could easily be expanded. +#[allow(missing_docs)] +pub enum VarMapEntry<'a, T> { + Option(&'a mut Option), + Hash(std::collections::hash_map::Entry<'a, u64, T>), +} + +#[allow(missing_docs)] +impl<'a, T> VarMapEntry<'a, T> { + pub fn or_insert(self, default: T) -> &'a mut T { + match self { + VarMapEntry::Option(o) => match o { + Some(v) => v, + None => { + *o = Some(default); + o.as_mut().unwrap() + } + }, + VarMapEntry::Hash(e) => e.or_insert(default), + } + } +} + +impl std::ops::Index<&Var> for VarMap { + type Output = T; + fn index(&self, v: &Var) -> &Self::Output { + match v { + Var::X => self.x.as_ref().unwrap(), + Var::Y => self.y.as_ref().unwrap(), + Var::Z => self.z.as_ref().unwrap(), + Var::V(v) => &self.v[v], + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn var_identity() { + let v1 = Var::new(); + let v2 = Var::new(); + assert_ne!(v1, v2); + } + + #[test] + fn var_map() { + let v = Var::new(); + let mut m = VarMap::new(); + assert!(m.get(&v).is_none()); + let p = m.entry(v).or_insert(123); + assert_eq!(*p, 123); + let p = m.entry(v).or_insert(456); + assert_eq!(*p, 123); + } +} diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index d60375c0..9523e997 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -2,11 +2,12 @@ use crate::{ compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape}, context::{Context, Node}, - eval::VarMap, + var::VarMap, vm::Choice, Error, }; use serde::{Deserialize, Serialize}; +use std::sync::Arc; /// A flattened math expression, ready for evaluation or further compilation. 
/// @@ -39,19 +40,21 @@ use serde::{Deserialize, Serialize}; /// ``` /// use fidget::{ /// compiler::RegOp, -/// context::{Context, Tree, Var}, +/// context::{Context, Tree}, /// vm::VmData, +/// var::Var, /// }; /// /// let tree = Tree::x() + Tree::y(); /// let mut ctx = Context::new(); /// let sum = ctx.import(&tree); -/// let (data, vars) = VmData::<255>::new(&ctx, sum)?; +/// let data = VmData::<255>::new(&ctx, sum)?; /// assert_eq!(data.len(), 3); // X, Y, and (X + Y) /// /// let mut iter = data.iter_asm(); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u8)); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u8)); +/// let vars = &data.vars; // map from var to index +/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u32)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u32)); /// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1)); /// # Ok::<(), fidget::Error>(()) /// ``` @@ -63,14 +66,24 @@ use serde::{Deserialize, Serialize}; pub struct VmData { ssa: SsaTape, asm: RegTape, + + /// Mapping from variables to indices during evaluation + /// + /// This member is stored in a shared pointer because it's passed down to + /// children (constructed with [`VmData::simplify`]). + pub vars: Arc>, } impl VmData { /// Builds a new tape for the given node - pub fn new(context: &Context, node: Node) -> Result<(Self, VarMap), Error> { - let (ssa, vs) = SsaTape::new(context, node)?; + pub fn new(context: &Context, node: Node) -> Result { + let (ssa, vars) = SsaTape::new(context, node)?; let asm = RegTape::new::(&ssa); - Ok((Self { ssa, asm }, vs)) + Ok(Self { + ssa, + asm, + vars: vars.into(), + }) } /// Returns the length of the internal VM tape @@ -96,9 +109,12 @@ impl VmData { self.asm.slot_count() } - /// Returns the number of variables (inputs) in the inner VM tape + /// Returns the number of variables that may be used + /// + /// Note that this can sometimes be an overestimate, if the inner tape has + /// been simplified. 
pub fn var_count(&self) -> usize { - self.asm.var_count() + self.vars.len() } /// Simplifies both inner tapes, using the provided choice array @@ -287,6 +303,7 @@ impl VmData { choice_count, }, asm: asm_tape, + vars: self.vars.clone(), }) } diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index aee2363a..5b4bbf1c 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -4,10 +4,10 @@ use crate::{ context::Node, eval::{ BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, - VarMap, }, shape::{RenderHints, Shape}, types::{Grad, Interval}, + var::VarMap, Context, Error, }; use std::sync::Arc; @@ -177,6 +177,10 @@ impl Function for GenericVmFunction { fn size(&self) -> usize { GenericVmFunction::size(self) } + + fn vars(&self) -> &VarMap { + &self.0.vars + } } impl RenderHints for GenericVmFunction { @@ -190,9 +194,9 @@ impl RenderHints for GenericVmFunction { } impl MathFunction for GenericVmFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { - let (d, vs) = VmData::new(ctx, node)?; - Ok((Self(Arc::new(d)), vs)) + fn new(ctx: &Context, node: Node) -> Result { + let d = VmData::new(ctx, node)?; + Ok(Self(d.into())) } } diff --git a/fidget/src/jit/aarch64/float_slice.rs b/fidget/src/jit/aarch64/float_slice.rs index c7bb3121..2e8a6e96 100644 --- a/fidget/src/jit/aarch64/float_slice.rs +++ b/fidget/src/jit/aarch64/float_slice.rs @@ -156,10 +156,10 @@ impl Assembler for FloatSliceAssembler { ) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr x4, [x0, src_arg as u32 * 8] + ; ldr x4, [x0, src_arg * 8] ; add x4, x4, x3 // apply array offset ; ldr Q(reg(out_reg)), [x4] ); diff --git a/fidget/src/jit/aarch64/grad_slice.rs b/fidget/src/jit/aarch64/grad_slice.rs index 2a53db33..73fc54cd 100644 --- a/fidget/src/jit/aarch64/grad_slice.rs +++ b/fidget/src/jit/aarch64/grad_slice.rs @@ -158,10 +158,10 @@ impl Assembler for GradSliceAssembler { ) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr x4, [x0, src_arg as u32 * 8] + ; ldr x4, [x0, src_arg * 8] ; add x4, x4, x3 // apply array offset ; eor V(reg(out_reg)).b16, V(reg(out_reg)).b16, V(reg(out_reg)).b16 ; ldr Q(reg(out_reg)), [x4] diff --git a/fidget/src/jit/aarch64/interval.rs b/fidget/src/jit/aarch64/interval.rs index a19179c0..8834bb5d 100644 --- a/fidget/src/jit/aarch64/interval.rs +++ b/fidget/src/jit/aarch64/interval.rs @@ -120,10 +120,10 @@ impl Assembler for IntervalAssembler { dynasm!(self.0.ops ; str D(reg(src_reg)), [sp, sp_offset]) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr D(reg(out_reg)), [x0, src_arg as u32 * 8] + ; ldr D(reg(out_reg)), [x0, src_arg * 8] ); } fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) { diff --git a/fidget/src/jit/aarch64/point.rs b/fidget/src/jit/aarch64/point.rs index dc6942a4..8d3704a5 100644 --- a/fidget/src/jit/aarch64/point.rs +++ b/fidget/src/jit/aarch64/point.rs @@ -115,10 +115,10 @@ impl Assembler for PointAssembler { 
dynasm!(self.0.ops ; str S(reg(src_reg)), [sp, sp_offset]) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 4 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 4); dynasm!(self.0.ops - ; ldr S(reg(out_reg)), [x0, src_arg as u32 * 4] + ; ldr S(reg(out_reg)), [x0, src_arg * 4] ); } fn build_copy(&mut self, out_reg: u8, lhs_reg: u8) { diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index f619aa86..85199c56 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -26,12 +26,11 @@ use crate::{ compiler::RegOp, context::{Context, Node}, - eval::{ - BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator, VarMap, - }, + eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, jit::mmap::Mmap, shape::RenderHints, types::{Grad, Interval}, + var::VarMap, vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace}, Error, }; @@ -132,7 +131,7 @@ trait Assembler { fn build_store(&mut self, dst_mem: u32, src_reg: u8); /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8); + fn build_input(&mut self, out_reg: u8, src_arg: u32); /// Copies a register fn build_copy(&mut self, out_reg: u8, lhs_reg: u8); @@ -900,6 +899,10 @@ impl Function for JitFunction { fn size(&self) -> usize { self.0.size() } + + fn vars(&self) -> &VarMap { + self.0.vars() + } } impl RenderHints for JitFunction { @@ -989,7 +992,6 @@ impl JitTracingEval { ) -> (T, Option<&VmTrace>) { let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); - assert!(tape.var_count <= 3); self.choices.fill(Choice::Unknown); let out = unsafe { (tape.fn_trace)( @@ -1111,7 +1113,6 @@ unsafe impl Sync for JitBulkFn {} impl + Copy + SimdSize> JitBulkEval { /// Evaluate multiple points fn eval(&mut self, tape: &JitBulkFn, vars: &[&[T]]) -> &[T] { - assert!(tape.var_count <= 3); let n = vars.first().map(|v| v.len()).unwrap_or(0); self.out.resize(n, f32::NAN.into()); self.out.fill(f32::NAN.into()); @@ -1211,9 +1212,8 @@ impl BulkEvaluator for JitGradSliceEval { } impl MathFunction for JitFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { - let (f, vars) = GenericVmFunction::new(ctx, node)?; - Ok((JitFunction(f), vars)) + fn new(ctx: &Context, node: Node) -> Result { + GenericVmFunction::new(ctx, node).map(JitFunction) } } diff --git a/fidget/src/jit/x86_64/float_slice.rs b/fidget/src/jit/x86_64/float_slice.rs index db7c59e8..c1a4e758 100644 --- a/fidget/src/jit/x86_64/float_slice.rs +++ b/fidget/src/jit/x86_64/float_slice.rs @@ -99,8 +99,8 @@ impl Assembler for FloatSliceAssembler { ; vmovups [rsp + sp_offset], Ry(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; mov r8, [rdi + pos] // read the *const float from the array ; add r8, rcx // offset it by array position diff --git a/fidget/src/jit/x86_64/grad_slice.rs b/fidget/src/jit/x86_64/grad_slice.rs index a33b09c8..3f7a19c5 100644 --- a/fidget/src/jit/x86_64/grad_slice.rs +++ b/fidget/src/jit/x86_64/grad_slice.rs @@ -92,8 +92,8 @@ impl Assembler for GradSliceAssembler { ; vmovups [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); // offset within the pointer array + fn build_input(&mut self, 
out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; mov r8, [rdi + pos] // read the *const float from the array ; add r8, rcx // offset it by array position diff --git a/fidget/src/jit/x86_64/interval.rs b/fidget/src/jit/x86_64/interval.rs index 2a242f23..65b002f0 100644 --- a/fidget/src/jit/x86_64/interval.rs +++ b/fidget/src/jit/x86_64/interval.rs @@ -85,8 +85,8 @@ impl Assembler for IntervalAssembler { ; movq [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; vmovq Rx(reg(out_reg)), [rdi + pos] ); diff --git a/fidget/src/jit/x86_64/point.rs b/fidget/src/jit/x86_64/point.rs index 21140cf8..a8bfef7b 100644 --- a/fidget/src/jit/x86_64/point.rs +++ b/fidget/src/jit/x86_64/point.rs @@ -83,8 +83,8 @@ impl Assembler for PointAssembler { ; vmovss [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 4 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 4 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops // Pull the input from the rdi array ; vmovss Rx(reg(out_reg)), [rdi + pos] diff --git a/wasm-demo/Cargo.lock b/wasm-demo/Cargo.lock index 5a55a4da..61aad802 100644 --- a/wasm-demo/Cargo.lock +++ b/wasm-demo/Cargo.lock @@ -267,6 +267,7 @@ dependencies = [ "num-derive", "num-traits", "ordered-float", + "rand", "rhai", "serde", "static_assertions", @@ -463,6 +464,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.81" @@ -481,6 +488,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "rawpointer" version = "0.2.1" diff --git a/workspace-hack/Cargo.toml b/workspace-hack/Cargo.toml index cd563f22..f9f4b5d8 100644 --- a/workspace-hack/Cargo.toml +++ b/workspace-hack/Cargo.toml @@ -19,12 +19,14 @@ bytemuck = { version = "1", default-features = false, features = ["derive", "ext clap = { version = "4", features = ["derive"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } +getrandom = { version = "0.2", default-features = false, features = ["std"] } once_cell = { version = "1" } regex = { version = "1", default-features = false, features = ["perf", "std"] 
} regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal"] } -serde = { version = "1", features = ["alloc", "derive"] } +serde = { version = "1", features = ["alloc", "derive", "rc"] } [build-dependencies] +getrandom = { version = "0.2", default-features = false, features = ["std"] } once_cell = { version = "1" } proc-macro2 = { version = "1" } quote = { version = "1" }
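Tying the CHANGELOG entry at the top of this diff together, here is a hedged end-to-end sketch of the new variable workflow: the same `Var` value is the key at every stage, from building nodes in a `Context` to finding that variable's evaluation slot through `Function::vars()`. `GenericVmFunction::<255>` stands in for any `MathFunction`; `Context::add` is the pre-existing builder API.

```rust
use fidget::{
    context::Context,
    eval::{Function, MathFunction},
    var::Var,
    vm::GenericVmFunction,
};

fn main() -> Result<(), fidget::Error> {
    // A custom variable, used alongside the built-in X axis.
    let w = Var::new();
    let mut ctx = Context::new();
    let x = ctx.x();
    let wn = ctx.var(w);
    let root = ctx.add(x, wn)?;

    // The identical `Var` is later used to look up its index in the argument list.
    let f = <GenericVmFunction<255> as MathFunction>::new(&ctx, root)?;
    let vars = f.vars();
    assert_eq!(vars.len(), 2);
    assert!(vars.get(&Var::X).is_some());
    assert!(vars.get(&w).is_some());
    Ok(())
}
```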