From 8dc2122ff955235406bd1536af5837b4e5697686 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 10:17:20 -0400 Subject: [PATCH 1/8] Remove VarNode, simplify Var --- Cargo.lock | 37 +++++++ fidget/Cargo.toml | 1 + fidget/src/core/compiler/ssa_tape.rs | 17 +-- fidget/src/core/context/mod.rs | 66 ++---------- fidget/src/core/context/op.rs | 4 +- fidget/src/core/context/var.rs | 148 +++++++++++++++++++++++++++ fidget/src/core/eval/mod.rs | 9 +- fidget/src/core/shape/mod.rs | 12 +-- fidget/src/core/vm/data.rs | 8 +- fidget/src/core/vm/mod.rs | 5 +- fidget/src/jit/mod.rs | 8 +- 11 files changed, 223 insertions(+), 92 deletions(-) create mode 100644 fidget/src/core/context/var.rs diff --git a/Cargo.lock b/Cargo.lock index d5a99c76..c22e3e29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -892,6 +892,7 @@ dependencies = [ "num-derive", "num-traits", "ordered-float", + "rand", "rhai", "serde", "static_assertions", @@ -1777,6 +1778,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -1838,6 +1845,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "raw-window-handle" version = "0.5.2" diff --git a/fidget/Cargo.toml b/fidget/Cargo.toml index 21327cf8..b6b04ff3 100644 --- a/fidget/Cargo.toml +++ b/fidget/Cargo.toml @@ -17,6 +17,7 @@ nalgebra = "0.31" num-derive = "0.3" num-traits = "0.2" ordered-float = "3" +rand = "0.8.5" static_assertions = "1" thiserror = "1" workspace-hack = { version = "0.1", path = "../workspace-hack" } diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 80e4fc55..e4d0c5e5 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -1,8 +1,7 @@ //use crate::vm::{RegisterAllocator, Tape as VmTape}; use crate::{ compiler::SsaOp, - context::{BinaryOpcode, Node, Op, UnaryOpcode}, - eval::VarMap, + context::{BinaryOpcode, Node, Op, UnaryOpcode, VarMap}, Context, Error, }; use serde::{Deserialize, Serialize}; @@ -33,7 +32,10 @@ impl SsaTape { /// /// This should always succeed unless the `root` is from a different /// `Context`, in which case `Error::BadNode` will be returned. - pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> { + pub fn new( + ctx: &Context, + root: Node, + ) -> Result<(Self, VarMap), Error> { let mut mapping = HashMap::new(); let mut parent_count: HashMap = HashMap::new(); let mut slot_count = 0; @@ -47,7 +49,7 @@ impl SsaTape { // Accumulate parent counts and declare all nodes let mut seen = HashSet::new(); - let mut vars = HashMap::new(); + let mut vars = VarMap::new(); let mut todo = vec![root]; while let Some(node) = todo.pop() { if !seen.insert(node) { @@ -61,8 +63,7 @@ impl SsaTape { _ => { if let Op::Input(v) = op { let next = vars.len(); - let v = ctx.get_var_by_index(*v).unwrap().clone(); - vars.entry(v).or_insert(next); + vars.entry(*v).or_insert(next); } let i = slot_count; slot_count += 1; @@ -97,8 +98,8 @@ impl SsaTape { continue; }; let op = match op { - Op::Input(..) => { - let arg = vars[ctx.var_name(node).unwrap().unwrap()]; + Op::Input(v) => { + let arg = vars[v]; SsaOp::Input(i, arg.try_into().unwrap()) } Op::Const(..) => { diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index b7659c67..016c690b 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -17,10 +17,12 @@ mod indexed; mod op; mod tree; +mod var; use indexed::{define_index, Index, IndexMap, IndexVec}; pub use op::{BinaryOpcode, Op, UnaryOpcode}; pub use tree::{Tree, TreeOp}; +pub use var::{Var, VarMap}; use crate::Error; @@ -32,7 +34,6 @@ use std::sync::Arc; use ordered_float::OrderedFloat; define_index!(Node, "An index in the `Context::ops` map"); -define_index!(VarNode, "An index in the `Context::vars` map"); /// A `Context` holds a set of deduplicated constants, variables, and /// operations. @@ -42,39 +43,6 @@ define_index!(VarNode, "An index in the `Context::vars` map"); #[derive(Debug, Default)] pub struct Context { ops: IndexMap, - vars: IndexMap, -} - -/// A `Var` represents a value which can vary during evaluation -/// -/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully -/// customized values. -#[allow(missing_docs)] -#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub enum Var { - X, - Y, - Z, - W, - T, - Static(&'static str), - Named(String), - Value(u64), -} - -impl std::fmt::Display for Var { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Var::X => write!(f, "X"), - Var::Y => write!(f, "Y"), - Var::Z => write!(f, "Z"), - Var::W => write!(f, "W"), - Var::T => write!(f, "T"), - Var::Static(s) => write!(f, "{s}"), - Var::Named(s) => write!(f, "{s}"), - Var::Value(v) => write!(f, "v_{v}"), - } - } } impl Context { @@ -85,7 +53,7 @@ impl Context { /// Clears the context /// - /// All [`Node`] and [`VarNode`] handles from this context are invalidated. + /// All [`Node`] handles from this context are invalidated. /// /// ``` /// # use fidget::context::Context; @@ -96,7 +64,6 @@ impl Context { /// ``` pub fn clear(&mut self) { self.ops.clear(); - self.vars.clear(); } /// Returns the number of [`Op`] nodes in the context @@ -155,22 +122,14 @@ impl Context { /// /// If the node is invalid for this tree, returns an error; if the node is /// not an `Op::Input`, returns `Ok(None)`. - pub fn var_name(&self, n: Node) -> Result, Error> { + pub fn get_var(&self, n: Node) -> Result, Error> { match self.get_op(n) { - Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some), + Some(Op::Input(v)) => Ok(Some(*v)), Some(_) => Ok(None), _ => Err(Error::BadNode), } } - /// Looks up the [`Var`] associated with the given [`VarNode`] - pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> { - match self.vars.get_by_index(n) { - Some(c) => Ok(c), - None => Err(Error::BadVar), - } - } - //////////////////////////////////////////////////////////////////////////// // Primitives /// Constructs or finds a variable node named "X" @@ -182,20 +141,17 @@ impl Context { /// assert_eq!(v, 1.0); /// ``` pub fn x(&mut self) -> Node { - let v = self.vars.insert(Var::X); - self.ops.insert(Op::Input(v)) + self.ops.insert(Op::Input(Var::X)) } /// Constructs or finds a variable node named "Y" pub fn y(&mut self) -> Node { - let v = self.vars.insert(Var::Y); - self.ops.insert(Op::Input(v)) + self.ops.insert(Op::Input(Var::Y)) } /// Constructs or finds a variable node named "Z" pub fn z(&mut self) -> Node { - let v = self.vars.insert(Var::Z); - self.ops.insert(Op::Input(v)) + self.ops.insert(Op::Input(Var::Z)) } /// Returns a 3-element array of `X`, `Y`, `Z` nodes @@ -822,10 +778,7 @@ impl Context { } let mut get = |n: Node| self.eval_inner(n, vars, cache); let v = match self.get_op(node).ok_or(Error::BadNode)? { - Op::Input(v) => { - let var_name = self.vars.get_by_index(*v).unwrap(); - *vars.get(var_name).unwrap() - } + Op::Input(v) => *vars.get(v).unwrap(), Op::Const(c) => c.0, Op::Binary(op, a, b) => { @@ -995,7 +948,6 @@ impl Context { match op { Op::Const(c) => write!(out, "{}", c).unwrap(), Op::Input(v) => { - let v = self.vars.get_by_index(*v).unwrap(); out += &v.to_string(); } Op::Binary(op, ..) => match op { diff --git a/fidget/src/core/context/op.rs b/fidget/src/core/context/op.rs index e8ce6b74..218b2ece 100644 --- a/fidget/src/core/context/op.rs +++ b/fidget/src/core/context/op.rs @@ -1,4 +1,4 @@ -use crate::context::{indexed::Index, Node, VarNode}; +use crate::context::{indexed::Index, Node, Var}; use ordered_float::OrderedFloat; /// A one-argument math operation @@ -53,7 +53,7 @@ pub enum BinaryOpcode { #[allow(missing_docs)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] pub enum Op { - Input(VarNode), + Input(Var), Const(OrderedFloat), Binary(BinaryOpcode, Node, Node), Unary(UnaryOpcode, Node), diff --git a/fidget/src/core/context/var.rs b/fidget/src/core/context/var.rs new file mode 100644 index 00000000..c7b52748 --- /dev/null +++ b/fidget/src/core/context/var.rs @@ -0,0 +1,148 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// A `Var` represents a value which can vary during evaluation +/// +/// We pre-define common variables (e.g. `X`, `Y`, `Z`) but also allow for fully +/// customized values. +/// +/// Variables are "global", in that every instance of `Var::X` represents the +/// same thing. To generate a "local" variable, [`Var::new`] picks a random +/// 64-bit value, which is very unlikely to collide with anything else. +#[allow(missing_docs)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum Var { + X, + Y, + Z, + V(u64), +} + +impl Var { + /// Returns a new variable, with a random 64-bit index + /// + /// The odds of collision with any previous variable are infintesimally + /// small; if you are generating billions of random variables, something + /// else in the system is likely to break before collisions become an issue. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let v: u64 = rand::random(); + Var::V(v) + } +} + +impl std::fmt::Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Var::X => write!(f, "X"), + Var::Y => write!(f, "Y"), + Var::Z => write!(f, "Z"), + Var::V(v) if *v < 256 => write!(f, "v_{v}"), + Var::V(v) => write!(f, "V({v:x})"), + } + } +} + +/// Map from [`Var`] to a particular value +/// +/// This is equivalent to a +/// [`HashMap`](std::collections::HashMap) and as such does not include +/// per-function documentation. +/// +/// The advantage over a `HashMap` is that for common variables (`X`, `Y`, `Z`), +/// no allocation is required. +#[derive(Serialize, Deserialize)] +pub struct VarMap { + x: Option, + y: Option, + z: Option, + v: HashMap, +} + +impl Default for VarMap { + fn default() -> Self { + Self { + x: None, + y: None, + z: None, + v: HashMap::default(), + } + } +} + +#[allow(missing_docs)] +impl VarMap { + pub fn new() -> Self { + Self::default() + } + pub fn len(&self) -> usize { + self.x.is_some() as usize + + self.y.is_some() as usize + + self.z.is_some() as usize + + self.v.len() + } + pub fn is_empty(&self) -> bool { + self.x.is_none() + && self.y.is_none() + && self.z.is_none() + && self.v.is_empty() + } + pub fn get(&self, v: &Var) -> Option<&T> { + match v { + Var::X => self.x.as_ref(), + Var::Y => self.y.as_ref(), + Var::Z => self.z.as_ref(), + Var::V(v) => self.v.get(v), + } + } + + pub fn get_mut(&mut self, v: &Var) -> Option<&mut T> { + match v { + Var::X => self.x.as_mut(), + Var::Y => self.y.as_mut(), + Var::Z => self.z.as_mut(), + Var::V(v) => self.v.get_mut(v), + } + } + + pub fn entry(&mut self, v: Var) -> VarMapEntry { + match v { + Var::X => VarMapEntry::Option(&mut self.x), + Var::Y => VarMapEntry::Option(&mut self.y), + Var::Z => VarMapEntry::Option(&mut self.z), + Var::V(v) => VarMapEntry::Hash(self.v.entry(v)), + } + } +} + +pub enum VarMapEntry<'a, T> { + Option(&'a mut Option), + Hash(std::collections::hash_map::Entry<'a, u64, T>), +} + +impl<'a, T> VarMapEntry<'a, T> { + pub fn or_insert(self, default: T) -> &'a mut T { + match self { + VarMapEntry::Option(o) => match o { + Some(v) => v, + None => { + *o = Some(default); + o.as_mut().unwrap() + } + }, + VarMapEntry::Hash(e) => e.or_insert(default), + } + } +} + +impl std::ops::Index<&Var> for VarMap { + type Output = T; + fn index(&self, v: &Var) -> &Self::Output { + match v { + Var::X => self.x.as_ref().unwrap(), + Var::Y => self.y.as_ref().unwrap(), + Var::Z => self.z.as_ref().unwrap(), + Var::V(v) => &self.v[v], + } + } +} diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 680d58f3..9f249697 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,6 +1,6 @@ //! Traits and data structures for function evaluation use crate::{ - context::{Context, Node, Tree, Var}, + context::{Context, Node, Tree, VarMap}, types::{Grad, Interval}, Error, }; @@ -167,9 +167,6 @@ pub trait Function: Send + Sync + Clone { fn size(&self) -> usize; } -/// Map from variable (from a particular [`Context`]) to index -pub type VarMap = std::collections::HashMap; - /// A [`Function`] which can be built from a math expression pub trait MathFunction { /// Builds a new function from the given context and node @@ -177,12 +174,12 @@ pub trait MathFunction { /// Returns a tuple of the [`MathFunction`] and a [`VarMap`] representing a /// mapping from variables (in the [`Context`]) to indices used during /// evaluation. - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> where Self: Sized; /// Helper function to build a function from a [`Tree`] - fn from_tree(t: &Tree) -> (Self, VarMap) + fn from_tree(t: &Tree) -> (Self, VarMap) where Self: Sized, { diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 6240171a..1b4b7f26 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -27,7 +27,7 @@ //! ``` use crate::{ - context::{Context, Node, Tree}, + context::{Context, Node, Tree, Var}, eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, types::{Grad, Interval}, Error, @@ -295,15 +295,12 @@ impl Shape { pub fn new_with_axes( ctx: &Context, node: Node, - axes: [Node; 3], + axes: [Var; 3], ) -> Result { let (f, vs) = F::new(ctx, node)?; - let x = ctx.var_name(axes[0])?.ok_or(Error::NotAVar)?; - let y = ctx.var_name(axes[1])?.ok_or(Error::NotAVar)?; - let z = ctx.var_name(axes[2])?.ok_or(Error::NotAVar)?; Ok(Self { f, - axes: [x, y, z].map(|v| vs.get(v).cloned()), + axes: axes.map(|v| vs.get(&v).cloned()), transform: None, }) } @@ -313,8 +310,7 @@ impl Shape { where Self: Sized, { - let axes = ctx.axes(); - Self::new_with_axes(ctx, node, axes) + Self::new_with_axes(ctx, node, [Var::X, Var::Y, Var::Z]) } } diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index d60375c0..5996c1d3 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -1,8 +1,7 @@ //! General-purpose tapes for use during evaluation or further compilation use crate::{ compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape}, - context::{Context, Node}, - eval::VarMap, + context::{Context, Node, VarMap}, vm::Choice, Error, }; @@ -67,7 +66,10 @@ pub struct VmData { impl VmData { /// Builds a new tape for the given node - pub fn new(context: &Context, node: Node) -> Result<(Self, VarMap), Error> { + pub fn new( + context: &Context, + node: Node, + ) -> Result<(Self, VarMap), Error> { let (ssa, vs) = SsaTape::new(context, node)?; let asm = RegTape::new::(&ssa); Ok((Self { ssa, asm }, vs)) diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index aee2363a..61cbe2dd 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -1,10 +1,9 @@ //! Simple virtual machine for shape evaluation use crate::{ compiler::RegOp, - context::Node, + context::{Node, VarMap}, eval::{ BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, - VarMap, }, shape::{RenderHints, Shape}, types::{Grad, Interval}, @@ -190,7 +189,7 @@ impl RenderHints for GenericVmFunction { } impl MathFunction for GenericVmFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { let (d, vs) = VmData::new(ctx, node)?; Ok((Self(Arc::new(d)), vs)) } diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index f619aa86..edbb6f54 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -25,10 +25,8 @@ use crate::{ compiler::RegOp, - context::{Context, Node}, - eval::{ - BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator, VarMap, - }, + context::{Context, Node, VarMap}, + eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, jit::mmap::Mmap, shape::RenderHints, types::{Grad, Interval}, @@ -1211,7 +1209,7 @@ impl BulkEvaluator for JitGradSliceEval { } impl MathFunction for JitFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { + fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { let (f, vars) = GenericVmFunction::new(ctx, node)?; Ok((JitFunction(f), vars)) } From 1dbb7f5872add0381e69ea6cfd600a31b22766f2 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 10:56:35 -0400 Subject: [PATCH 2/8] Add vars() to Function trait --- fidget/Cargo.toml | 2 +- fidget/src/core/context/mod.rs | 28 +++++++++++++++++----------- fidget/src/core/eval/mod.rs | 13 ++++++------- fidget/src/core/shape/mod.rs | 10 ++++++++-- fidget/src/core/vm/data.rs | 27 ++++++++++++++++++--------- fidget/src/core/vm/mod.rs | 10 +++++++--- fidget/src/jit/mod.rs | 9 ++++++--- 7 files changed, 63 insertions(+), 36 deletions(-) diff --git a/fidget/Cargo.toml b/fidget/Cargo.toml index b6b04ff3..204ef22a 100644 --- a/fidget/Cargo.toml +++ b/fidget/Cargo.toml @@ -21,7 +21,7 @@ rand = "0.8.5" static_assertions = "1" thiserror = "1" workspace-hack = { version = "0.1", path = "../workspace-hack" } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } # JIT dynasmrt = { version = "2.0", optional = true } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 016c690b..cb161dff 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -4,7 +4,8 @@ //! //! - A [`Tree`] is a free-floating math expression, which can be cloned //! and has overloaded operators for ease of use. It is **not** deduplicated; -//! two calls to [`Tree::x`] will produce two different [`TreeOp`] objects. +//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two +//! different [`TreeOp`] objects. //! `Tree` objects are typically used when building up expressions; they //! should be converted to `Node` objects (in a particular `Context`) after //! they have been constructed. @@ -132,7 +133,7 @@ impl Context { //////////////////////////////////////////////////////////////////////////// // Primitives - /// Constructs or finds a variable node named "X" + /// Constructs or finds a [`Var::X`] node /// ``` /// # use fidget::context::Context; /// let mut ctx = Context::new(); @@ -141,17 +142,22 @@ impl Context { /// assert_eq!(v, 1.0); /// ``` pub fn x(&mut self) -> Node { - self.ops.insert(Op::Input(Var::X)) + self.var(Var::X) } - /// Constructs or finds a variable node named "Y" + /// Constructs or finds a [`Var::Y`] node pub fn y(&mut self) -> Node { - self.ops.insert(Op::Input(Var::Y)) + self.var(Var::Y) } - /// Constructs or finds a variable node named "Z" + /// Constructs or finds a [`Var::Z`] node pub fn z(&mut self) -> Node { - self.ops.insert(Op::Input(Var::Z)) + self.var(Var::Z) + } + + /// Constructs or finds a variable input node + pub fn var(&mut self, v: Var) -> Node { + self.ops.insert(Op::Input(v)) } /// Returns a 3-element array of `X`, `Y`, `Z` nodes @@ -1197,9 +1203,9 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap(); + let tape = VmData::<255>::new(&ctx, c9).unwrap(); assert_eq!(tape.len(), 8); - assert_eq!(vs.len(), 2); + assert_eq!(tape.vars.len(), 2); } #[test] @@ -1208,8 +1214,8 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap(); + let tape = VmData::<255>::new(&ctx, x_squared).unwrap(); assert_eq!(tape.len(), 2); - assert_eq!(vs.len(), 1); + assert_eq!(tape.vars.len(), 1); } } diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 9f249697..2be88d53 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -165,21 +165,20 @@ pub trait Function: Send + Sync + Clone { /// This is underspecified and only used for unit testing; for tape-based /// functions, it's typically the length of the tape, fn size(&self) -> usize; + + /// Returns a map from variable to index + fn vars(&self) -> &VarMap; } /// A [`Function`] which can be built from a math expression -pub trait MathFunction { +pub trait MathFunction: Function { /// Builds a new function from the given context and node - /// - /// Returns a tuple of the [`MathFunction`] and a [`VarMap`] representing a - /// mapping from variables (in the [`Context`]) to indices used during - /// evaluation. - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> + fn new(ctx: &Context, node: Node) -> Result where Self: Sized; /// Helper function to build a function from a [`Tree`] - fn from_tree(t: &Tree) -> (Self, VarMap) + fn from_tree(t: &Tree) -> Self where Self: Sized, { diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index 1b4b7f26..a92625c9 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -50,6 +50,10 @@ pub struct Shape { f: F, /// Index of x, y, z axes within the function's variable list (if present) + /// + /// We could instead store an array of [`Var`]s and look them in with + /// [`self.f.vars()[&v]`](Function::vars), but it's more efficient to cache + /// them upon construction (because they never change). axes: [Option; 3], /// Optional transform to apply to the shape @@ -297,10 +301,12 @@ impl Shape { node: Node, axes: [Var; 3], ) -> Result { - let (f, vs) = F::new(ctx, node)?; + let f = F::new(ctx, node)?; + let vars = f.vars(); + let axes = axes.map(|v| vars.get(&v).cloned()); Ok(Self { f, - axes: axes.map(|v| vs.get(&v).cloned()), + axes, transform: None, }) } diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 5996c1d3..6743ab90 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -6,6 +6,7 @@ use crate::{ Error, }; use serde::{Deserialize, Serialize}; +use std::sync::Arc; /// A flattened math expression, ready for evaluation or further compilation. /// @@ -45,12 +46,12 @@ use serde::{Deserialize, Serialize}; /// let tree = Tree::x() + Tree::y(); /// let mut ctx = Context::new(); /// let sum = ctx.import(&tree); -/// let (data, vars) = VmData::<255>::new(&ctx, sum)?; +/// let data = VmData::<255>::new(&ctx, sum)?; /// assert_eq!(data.len(), 3); // X, Y, and (X + Y) /// /// let mut iter = data.iter_asm(); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u8)); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u8)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, data.vars[&Var::X] as u8)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, data.vars[&Var::Y] as u8)); /// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1)); /// # Ok::<(), fidget::Error>(()) /// ``` @@ -62,17 +63,24 @@ use serde::{Deserialize, Serialize}; pub struct VmData { ssa: SsaTape, asm: RegTape, + + /// Mapping from variables to indices during evaluation + /// + /// This member is stored in a shared pointer because it's passed down to + /// children (constructed with [`VmData::simplify`]). + pub vars: Arc>, } impl VmData { /// Builds a new tape for the given node - pub fn new( - context: &Context, - node: Node, - ) -> Result<(Self, VarMap), Error> { - let (ssa, vs) = SsaTape::new(context, node)?; + pub fn new(context: &Context, node: Node) -> Result { + let (ssa, vars) = SsaTape::new(context, node)?; let asm = RegTape::new::(&ssa); - Ok((Self { ssa, asm }, vs)) + Ok(Self { + ssa, + asm, + vars: vars.into(), + }) } /// Returns the length of the internal VM tape @@ -289,6 +297,7 @@ impl VmData { choice_count, }, asm: asm_tape, + vars: self.vars.clone(), }) } diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 61cbe2dd..d6c87b27 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -176,6 +176,10 @@ impl Function for GenericVmFunction { fn size(&self) -> usize { GenericVmFunction::size(self) } + + fn vars(&self) -> &VarMap { + &self.0.vars + } } impl RenderHints for GenericVmFunction { @@ -189,9 +193,9 @@ impl RenderHints for GenericVmFunction { } impl MathFunction for GenericVmFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { - let (d, vs) = VmData::new(ctx, node)?; - Ok((Self(Arc::new(d)), vs)) + fn new(ctx: &Context, node: Node) -> Result { + let d = VmData::new(ctx, node)?; + Ok(Self(d.into())) } } diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index edbb6f54..4024b757 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -898,6 +898,10 @@ impl Function for JitFunction { fn size(&self) -> usize { self.0.size() } + + fn vars(&self) -> &VarMap { + self.0.vars() + } } impl RenderHints for JitFunction { @@ -1209,9 +1213,8 @@ impl BulkEvaluator for JitGradSliceEval { } impl MathFunction for JitFunction { - fn new(ctx: &Context, node: Node) -> Result<(Self, VarMap), Error> { - let (f, vars) = GenericVmFunction::new(ctx, node)?; - Ok((JitFunction(f), vars)) + fn new(ctx: &Context, node: Node) -> Result { + GenericVmFunction::new(ctx, node).map(JitFunction) } } From 83aabdcd679c95ce46a06495c4f94776b9bd39e7 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:12:10 -0400 Subject: [PATCH 3/8] Remove unused function --- fidget/src/core/eval/mod.rs | 12 +----------- wasm-demo/Cargo.lock | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 2be88d53..85801dfd 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,6 +1,6 @@ //! Traits and data structures for function evaluation use crate::{ - context::{Context, Node, Tree, VarMap}, + context::{Context, Node, VarMap}, types::{Grad, Interval}, Error, }; @@ -176,14 +176,4 @@ pub trait MathFunction: Function { fn new(ctx: &Context, node: Node) -> Result where Self: Sized; - - /// Helper function to build a function from a [`Tree`] - fn from_tree(t: &Tree) -> Self - where - Self: Sized, - { - let mut ctx = Context::new(); - let node = ctx.import(t); - Self::new(&ctx, node).unwrap() - } } diff --git a/wasm-demo/Cargo.lock b/wasm-demo/Cargo.lock index 5a55a4da..61aad802 100644 --- a/wasm-demo/Cargo.lock +++ b/wasm-demo/Cargo.lock @@ -267,6 +267,7 @@ dependencies = [ "num-derive", "num-traits", "ordered-float", + "rand", "rhai", "serde", "static_assertions", @@ -463,6 +464,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.81" @@ -481,6 +488,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "rawpointer" version = "0.2.1" From 3a4a6eb38966c0f073e0a111913a6b7a875279fe Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:17:15 -0400 Subject: [PATCH 4/8] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ fidget/src/core/var/mod.rs | 0 2 files changed, 6 insertions(+) create mode 100644 fidget/src/core/var/mod.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a24920a..65aea4c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ - Using the `VmShape` or `JitShape` types should be mostly the same as before; changes are most noticeable if you're writing things that are generic across `S: Shape`. +- Major refactoring of how variables are handled + - Removed `VarNode`; the canonical variable type is `Var`, which is its own + unique index. + - Removed named variables, to make `Var` trivially `Copy + Clone`. + - Added `vars()` method to `Function` trait, allowing users to look up the + mapping from variable to evaluation index. # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/fidget/src/core/var/mod.rs b/fidget/src/core/var/mod.rs new file mode 100644 index 00000000..e69de29b From 8cb3c4ea3f11eb5542319a0dce35a382b5747581 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:21:05 -0400 Subject: [PATCH 5/8] Move vars module; write more docs --- CHANGELOG.md | 5 +- fidget/src/core/compiler/ssa_tape.rs | 3 +- fidget/src/core/context/mod.rs | 4 +- fidget/src/core/context/op.rs | 5 +- fidget/src/core/context/var.rs | 148 --------------------- fidget/src/core/eval/mod.rs | 3 +- fidget/src/core/mod.rs | 2 + fidget/src/core/shape/mod.rs | 3 +- fidget/src/core/var/mod.rs | 184 +++++++++++++++++++++++++++ fidget/src/core/vm/data.rs | 3 +- fidget/src/core/vm/mod.rs | 3 +- fidget/src/jit/mod.rs | 3 +- 12 files changed, 207 insertions(+), 159 deletions(-) delete mode 100644 fidget/src/core/context/var.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 65aea4c5..9c77d2b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,10 @@ unique index. - Removed named variables, to make `Var` trivially `Copy + Clone`. - Added `vars()` method to `Function` trait, allowing users to look up the - mapping from variable to evaluation index. + mapping from variable to evaluation index. A `Var` now represents a + persistent identity from `Tree` to `Context` to `Function` evaluation. + - Move `Var` and `VarMap` into `fidget::vars` module, because they're no + longer specific to a `Context`. # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index e4d0c5e5..1bea9a29 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -1,7 +1,8 @@ //use crate::vm::{RegisterAllocator, Tape as VmTape}; use crate::{ compiler::SsaOp, - context::{BinaryOpcode, Node, Op, UnaryOpcode, VarMap}, + context::{BinaryOpcode, Node, Op, UnaryOpcode}, + var::VarMap, Context, Error, }; use serde::{Deserialize, Serialize}; diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index cb161dff..0dc8eeeb 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -18,14 +18,12 @@ mod indexed; mod op; mod tree; -mod var; use indexed::{define_index, Index, IndexMap, IndexVec}; pub use op::{BinaryOpcode, Op, UnaryOpcode}; pub use tree::{Tree, TreeOp}; -pub use var::{Var, VarMap}; -use crate::Error; +use crate::{var::Var, Error}; use std::collections::{BTreeMap, HashMap}; use std::fmt::Write; diff --git a/fidget/src/core/context/op.rs b/fidget/src/core/context/op.rs index 218b2ece..bd77c63f 100644 --- a/fidget/src/core/context/op.rs +++ b/fidget/src/core/context/op.rs @@ -1,4 +1,7 @@ -use crate::context::{indexed::Index, Node, Var}; +use crate::{ + context::{indexed::Index, Node}, + var::Var, +}; use ordered_float::OrderedFloat; /// A one-argument math operation diff --git a/fidget/src/core/context/var.rs b/fidget/src/core/context/var.rs deleted file mode 100644 index c7b52748..00000000 --- a/fidget/src/core/context/var.rs +++ /dev/null @@ -1,148 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -/// A `Var` represents a value which can vary during evaluation -/// -/// We pre-define common variables (e.g. `X`, `Y`, `Z`) but also allow for fully -/// customized values. -/// -/// Variables are "global", in that every instance of `Var::X` represents the -/// same thing. To generate a "local" variable, [`Var::new`] picks a random -/// 64-bit value, which is very unlikely to collide with anything else. -#[allow(missing_docs)] -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub enum Var { - X, - Y, - Z, - V(u64), -} - -impl Var { - /// Returns a new variable, with a random 64-bit index - /// - /// The odds of collision with any previous variable are infintesimally - /// small; if you are generating billions of random variables, something - /// else in the system is likely to break before collisions become an issue. - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - let v: u64 = rand::random(); - Var::V(v) - } -} - -impl std::fmt::Display for Var { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Var::X => write!(f, "X"), - Var::Y => write!(f, "Y"), - Var::Z => write!(f, "Z"), - Var::V(v) if *v < 256 => write!(f, "v_{v}"), - Var::V(v) => write!(f, "V({v:x})"), - } - } -} - -/// Map from [`Var`] to a particular value -/// -/// This is equivalent to a -/// [`HashMap`](std::collections::HashMap) and as such does not include -/// per-function documentation. -/// -/// The advantage over a `HashMap` is that for common variables (`X`, `Y`, `Z`), -/// no allocation is required. -#[derive(Serialize, Deserialize)] -pub struct VarMap { - x: Option, - y: Option, - z: Option, - v: HashMap, -} - -impl Default for VarMap { - fn default() -> Self { - Self { - x: None, - y: None, - z: None, - v: HashMap::default(), - } - } -} - -#[allow(missing_docs)] -impl VarMap { - pub fn new() -> Self { - Self::default() - } - pub fn len(&self) -> usize { - self.x.is_some() as usize - + self.y.is_some() as usize - + self.z.is_some() as usize - + self.v.len() - } - pub fn is_empty(&self) -> bool { - self.x.is_none() - && self.y.is_none() - && self.z.is_none() - && self.v.is_empty() - } - pub fn get(&self, v: &Var) -> Option<&T> { - match v { - Var::X => self.x.as_ref(), - Var::Y => self.y.as_ref(), - Var::Z => self.z.as_ref(), - Var::V(v) => self.v.get(v), - } - } - - pub fn get_mut(&mut self, v: &Var) -> Option<&mut T> { - match v { - Var::X => self.x.as_mut(), - Var::Y => self.y.as_mut(), - Var::Z => self.z.as_mut(), - Var::V(v) => self.v.get_mut(v), - } - } - - pub fn entry(&mut self, v: Var) -> VarMapEntry { - match v { - Var::X => VarMapEntry::Option(&mut self.x), - Var::Y => VarMapEntry::Option(&mut self.y), - Var::Z => VarMapEntry::Option(&mut self.z), - Var::V(v) => VarMapEntry::Hash(self.v.entry(v)), - } - } -} - -pub enum VarMapEntry<'a, T> { - Option(&'a mut Option), - Hash(std::collections::hash_map::Entry<'a, u64, T>), -} - -impl<'a, T> VarMapEntry<'a, T> { - pub fn or_insert(self, default: T) -> &'a mut T { - match self { - VarMapEntry::Option(o) => match o { - Some(v) => v, - None => { - *o = Some(default); - o.as_mut().unwrap() - } - }, - VarMapEntry::Hash(e) => e.or_insert(default), - } - } -} - -impl std::ops::Index<&Var> for VarMap { - type Output = T; - fn index(&self, v: &Var) -> &Self::Output { - match v { - Var::X => self.x.as_ref().unwrap(), - Var::Y => self.y.as_ref().unwrap(), - Var::Z => self.z.as_ref().unwrap(), - Var::V(v) => &self.v[v], - } - } -} diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index 85801dfd..88ce2c08 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -1,7 +1,8 @@ //! Traits and data structures for function evaluation use crate::{ - context::{Context, Node, VarMap}, + context::{Context, Node}, types::{Grad, Interval}, + var::VarMap, Error, }; diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 37ed2934..25ff28c4 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -60,11 +60,13 @@ pub mod compiler; pub mod eval; pub mod shape; pub mod types; +pub mod var; pub mod vm; #[cfg(test)] mod test { use crate::context::*; + use crate::var::Var; #[test] fn it_works() { diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index a92625c9..d37c4f30 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -27,9 +27,10 @@ //! ``` use crate::{ - context::{Context, Node, Tree, Var}, + context::{Context, Node, Tree}, eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, types::{Grad, Interval}, + var::Var, Error, }; use nalgebra::{Matrix4, Point3}; diff --git a/fidget/src/core/var/mod.rs b/fidget/src/core/var/mod.rs index e69de29b..b9a23f50 100644 --- a/fidget/src/core/var/mod.rs +++ b/fidget/src/core/var/mod.rs @@ -0,0 +1,184 @@ +//! Input variables to math expressions +//! +//! A [`Var`] maintains a persistent identity from +//! [`Tree`](crate::context::Tree) to [`Context`](crate::context::Node) (where +//! it is wrapped in a [`Op::Input`](crate::context::Op::Input)) to evaluation +//! (where [`Function::vars`](crate::eval::Function::vars) maps from `Var` to +//! index in the argument list). +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The [`Var`] type is an input to a math expression +/// +/// We pre-define common variables (e.g. `X`, `Y`, `Z`) but also allow for fully +/// customized values (using [`Var::V`]). +/// +/// Variables are "global", in that every instance of `Var::X` represents the +/// same thing. To generate a "local" variable, [`Var::new`] picks a random +/// 64-bit value, which is very unlikely to collide with anything else. +#[allow(missing_docs)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum Var { + X, + Y, + Z, + V(u64), +} + +impl Var { + /// Returns a new variable, with a random 64-bit index + /// + /// The odds of collision with any previous variable are infintesimally + /// small; if you are generating billions of random variables, something + /// else in the system is likely to break before collisions become an issue. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let v: u64 = rand::random(); + Var::V(v) + } +} + +impl std::fmt::Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Var::X => write!(f, "X"), + Var::Y => write!(f, "Y"), + Var::Z => write!(f, "Z"), + Var::V(v) if *v < 256 => write!(f, "v_{v}"), + Var::V(v) => write!(f, "V({v:x})"), + } + } +} + +/// Map from [`Var`] to a particular value +/// +/// This is equivalent to a +/// [`HashMap`](std::collections::HashMap) and as such does not include +/// per-function documentation. +/// +/// The advantage over a `HashMap` is that for common variables (`X`, `Y`, `Z`), +/// no allocation is required. +#[derive(Serialize, Deserialize)] +pub struct VarMap { + x: Option, + y: Option, + z: Option, + v: HashMap, +} + +impl Default for VarMap { + fn default() -> Self { + Self { + x: None, + y: None, + z: None, + v: HashMap::default(), + } + } +} + +#[allow(missing_docs)] +impl VarMap { + pub fn new() -> Self { + Self::default() + } + pub fn len(&self) -> usize { + self.x.is_some() as usize + + self.y.is_some() as usize + + self.z.is_some() as usize + + self.v.len() + } + pub fn is_empty(&self) -> bool { + self.x.is_none() + && self.y.is_none() + && self.z.is_none() + && self.v.is_empty() + } + pub fn get(&self, v: &Var) -> Option<&T> { + match v { + Var::X => self.x.as_ref(), + Var::Y => self.y.as_ref(), + Var::Z => self.z.as_ref(), + Var::V(v) => self.v.get(v), + } + } + + pub fn get_mut(&mut self, v: &Var) -> Option<&mut T> { + match v { + Var::X => self.x.as_mut(), + Var::Y => self.y.as_mut(), + Var::Z => self.z.as_mut(), + Var::V(v) => self.v.get_mut(v), + } + } + + pub fn entry(&mut self, v: Var) -> VarMapEntry { + match v { + Var::X => VarMapEntry::Option(&mut self.x), + Var::Y => VarMapEntry::Option(&mut self.y), + Var::Z => VarMapEntry::Option(&mut self.z), + Var::V(v) => VarMapEntry::Hash(self.v.entry(v)), + } + } +} + +/// Entry into a [`VarMap`]; equivalent to [`std::collections::hash_map::Entry`] +/// +/// The implementation has just enough functions to be useful; if you find +/// yourself wanting the rest of the entry API, it could easily be expanded. +#[allow(missing_docs)] +pub enum VarMapEntry<'a, T> { + Option(&'a mut Option), + Hash(std::collections::hash_map::Entry<'a, u64, T>), +} + +#[allow(missing_docs)] +impl<'a, T> VarMapEntry<'a, T> { + pub fn or_insert(self, default: T) -> &'a mut T { + match self { + VarMapEntry::Option(o) => match o { + Some(v) => v, + None => { + *o = Some(default); + o.as_mut().unwrap() + } + }, + VarMapEntry::Hash(e) => e.or_insert(default), + } + } +} + +impl std::ops::Index<&Var> for VarMap { + type Output = T; + fn index(&self, v: &Var) -> &Self::Output { + match v { + Var::X => self.x.as_ref().unwrap(), + Var::Y => self.y.as_ref().unwrap(), + Var::Z => self.z.as_ref().unwrap(), + Var::V(v) => &self.v[v], + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn var_identity() { + let v1 = Var::new(); + let v2 = Var::new(); + assert_ne!(v1, v2); + } + + #[test] + fn var_map() { + let v = Var::new(); + let mut m = VarMap::new(); + assert!(m.get(&v).is_none()); + let p = m.entry(v).or_insert(123); + assert_eq!(*p, 123); + let p = m.entry(v).or_insert(456); + assert_eq!(*p, 123); + } +} diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 6743ab90..eb93d816 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -1,7 +1,8 @@ //! General-purpose tapes for use during evaluation or further compilation use crate::{ compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape}, - context::{Context, Node, VarMap}, + context::{Context, Node}, + var::VarMap, vm::Choice, Error, }; diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index d6c87b27..5b4bbf1c 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -1,12 +1,13 @@ //! Simple virtual machine for shape evaluation use crate::{ compiler::RegOp, - context::{Node, VarMap}, + context::Node, eval::{ BulkEvaluator, Function, MathFunction, Tape, Trace, TracingEvaluator, }, shape::{RenderHints, Shape}, types::{Grad, Interval}, + var::VarMap, Context, Error, }; use std::sync::Arc; diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 4024b757..0c7f356e 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -25,11 +25,12 @@ use crate::{ compiler::RegOp, - context::{Context, Node, VarMap}, + context::{Context, Node}, eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator}, jit::mmap::Mmap, shape::RenderHints, types::{Grad, Interval}, + var::VarMap, vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace}, Error, }; From a91452e32d00b56e32a353c08af2441b07a7c568 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:31:59 -0400 Subject: [PATCH 6/8] Make Op::Input take a u32 --- CHANGELOG.md | 1 + fidget/src/core/compiler/alloc.rs | 2 +- fidget/src/core/compiler/op.rs | 4 ++-- fidget/src/core/vm/data.rs | 8 +++++--- fidget/src/jit/aarch64/float_slice.rs | 6 +++--- fidget/src/jit/aarch64/grad_slice.rs | 6 +++--- fidget/src/jit/aarch64/interval.rs | 6 +++--- fidget/src/jit/aarch64/point.rs | 6 +++--- fidget/src/jit/mod.rs | 2 +- fidget/src/jit/x86_64/float_slice.rs | 4 ++-- fidget/src/jit/x86_64/grad_slice.rs | 4 ++-- fidget/src/jit/x86_64/interval.rs | 4 ++-- fidget/src/jit/x86_64/point.rs | 4 ++-- 13 files changed, 30 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c77d2b0..058c2b4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ persistent identity from `Tree` to `Context` to `Function` evaluation. - Move `Var` and `VarMap` into `fidget::vars` module, because they're no longer specific to a `Context`. + - `Op::Input` now takes a `u32` argument, instead of a `u8` # 0.2.7 This release brings us to opcode parity with `libfive`'s operators, adding diff --git a/fidget/src/core/compiler/alloc.rs b/fidget/src/core/compiler/alloc.rs index e4ca6f60..e235a3ec 100644 --- a/fidget/src/core/compiler/alloc.rs +++ b/fidget/src/core/compiler/alloc.rs @@ -673,7 +673,7 @@ impl RegisterAllocator { /// Pushes an [`Input`](crate::compiler::RegOp::Input) operation to the tape #[inline(always)] - fn op_input(&mut self, out: u32, i: u8) { + fn op_input(&mut self, out: u32, i: u32) { // TODO: tightly pack variables (which may be sparse) into slots self.out.var_count = self.out.var_count.max(i as u32 + 1); self.op_out_only(out, |out| RegOp::Input(out, i)); diff --git a/fidget/src/core/compiler/op.rs b/fidget/src/core/compiler/op.rs index d003da23..73369312 100644 --- a/fidget/src/core/compiler/op.rs +++ b/fidget/src/core/compiler/op.rs @@ -23,8 +23,8 @@ macro_rules! opcodes { ) => { $(#[$($attrss)*])* pub enum $name { - #[doc = "Read one of the inputs (X, Y, Z)"] - Input($t, $t), + #[doc = "Read an input variable by index"] + Input($t, u32), #[doc = "Negate the given register"] NegReg($t, $t), diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index eb93d816..59e00697 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -40,8 +40,9 @@ use std::sync::Arc; /// ``` /// use fidget::{ /// compiler::RegOp, -/// context::{Context, Tree, Var}, +/// context::{Context, Tree}, /// vm::VmData, +/// var::Var, /// }; /// /// let tree = Tree::x() + Tree::y(); @@ -51,8 +52,9 @@ use std::sync::Arc; /// assert_eq!(data.len(), 3); // X, Y, and (X + Y) /// /// let mut iter = data.iter_asm(); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, data.vars[&Var::X] as u8)); -/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, data.vars[&Var::Y] as u8)); +/// let vars = &data.vars; // map from var to index +/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u32)); +/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u32)); /// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1)); /// # Ok::<(), fidget::Error>(()) /// ``` diff --git a/fidget/src/jit/aarch64/float_slice.rs b/fidget/src/jit/aarch64/float_slice.rs index c7bb3121..2e8a6e96 100644 --- a/fidget/src/jit/aarch64/float_slice.rs +++ b/fidget/src/jit/aarch64/float_slice.rs @@ -156,10 +156,10 @@ impl Assembler for FloatSliceAssembler { ) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr x4, [x0, src_arg as u32 * 8] + ; ldr x4, [x0, src_arg * 8] ; add x4, x4, x3 // apply array offset ; ldr Q(reg(out_reg)), [x4] ); diff --git a/fidget/src/jit/aarch64/grad_slice.rs b/fidget/src/jit/aarch64/grad_slice.rs index 2a53db33..73fc54cd 100644 --- a/fidget/src/jit/aarch64/grad_slice.rs +++ b/fidget/src/jit/aarch64/grad_slice.rs @@ -158,10 +158,10 @@ impl Assembler for GradSliceAssembler { ) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr x4, [x0, src_arg as u32 * 8] + ; ldr x4, [x0, src_arg * 8] ; add x4, x4, x3 // apply array offset ; eor V(reg(out_reg)).b16, V(reg(out_reg)).b16, V(reg(out_reg)).b16 ; ldr Q(reg(out_reg)), [x4] diff --git a/fidget/src/jit/aarch64/interval.rs b/fidget/src/jit/aarch64/interval.rs index a19179c0..8834bb5d 100644 --- a/fidget/src/jit/aarch64/interval.rs +++ b/fidget/src/jit/aarch64/interval.rs @@ -120,10 +120,10 @@ impl Assembler for IntervalAssembler { dynasm!(self.0.ops ; str D(reg(src_reg)), [sp, sp_offset]) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 8 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 8); dynasm!(self.0.ops - ; ldr D(reg(out_reg)), [x0, src_arg as u32 * 8] + ; ldr D(reg(out_reg)), [x0, src_arg * 8] ); } fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) { diff --git a/fidget/src/jit/aarch64/point.rs b/fidget/src/jit/aarch64/point.rs index dc6942a4..8d3704a5 100644 --- a/fidget/src/jit/aarch64/point.rs +++ b/fidget/src/jit/aarch64/point.rs @@ -115,10 +115,10 @@ impl Assembler for PointAssembler { dynasm!(self.0.ops ; str S(reg(src_reg)), [sp, sp_offset]) } /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - assert!(src_arg as u32 * 4 < 16384); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + assert!(src_arg < 16384 / 4); dynasm!(self.0.ops - ; ldr S(reg(out_reg)), [x0, src_arg as u32 * 4] + ; ldr S(reg(out_reg)), [x0, src_arg * 4] ); } fn build_copy(&mut self, out_reg: u8, lhs_reg: u8) { diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index 0c7f356e..ca3f1fba 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -131,7 +131,7 @@ trait Assembler { fn build_store(&mut self, dst_mem: u32, src_reg: u8); /// Copies the given input to `out_reg` - fn build_input(&mut self, out_reg: u8, src_arg: u8); + fn build_input(&mut self, out_reg: u8, src_arg: u32); /// Copies a register fn build_copy(&mut self, out_reg: u8, lhs_reg: u8); diff --git a/fidget/src/jit/x86_64/float_slice.rs b/fidget/src/jit/x86_64/float_slice.rs index db7c59e8..c1a4e758 100644 --- a/fidget/src/jit/x86_64/float_slice.rs +++ b/fidget/src/jit/x86_64/float_slice.rs @@ -99,8 +99,8 @@ impl Assembler for FloatSliceAssembler { ; vmovups [rsp + sp_offset], Ry(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; mov r8, [rdi + pos] // read the *const float from the array ; add r8, rcx // offset it by array position diff --git a/fidget/src/jit/x86_64/grad_slice.rs b/fidget/src/jit/x86_64/grad_slice.rs index a33b09c8..3f7a19c5 100644 --- a/fidget/src/jit/x86_64/grad_slice.rs +++ b/fidget/src/jit/x86_64/grad_slice.rs @@ -92,8 +92,8 @@ impl Assembler for GradSliceAssembler { ; vmovups [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); // offset within the pointer array + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; mov r8, [rdi + pos] // read the *const float from the array ; add r8, rcx // offset it by array position diff --git a/fidget/src/jit/x86_64/interval.rs b/fidget/src/jit/x86_64/interval.rs index 2a242f23..65b002f0 100644 --- a/fidget/src/jit/x86_64/interval.rs +++ b/fidget/src/jit/x86_64/interval.rs @@ -85,8 +85,8 @@ impl Assembler for IntervalAssembler { ; movq [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 8 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 8 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops ; vmovq Rx(reg(out_reg)), [rdi + pos] ); diff --git a/fidget/src/jit/x86_64/point.rs b/fidget/src/jit/x86_64/point.rs index 21140cf8..a8bfef7b 100644 --- a/fidget/src/jit/x86_64/point.rs +++ b/fidget/src/jit/x86_64/point.rs @@ -83,8 +83,8 @@ impl Assembler for PointAssembler { ; vmovss [rsp + sp_offset], Rx(reg(src_reg)) ); } - fn build_input(&mut self, out_reg: u8, src_arg: u8) { - let pos = 4 * (src_arg as i32); + fn build_input(&mut self, out_reg: u8, src_arg: u32) { + let pos = 4 * i32::try_from(src_arg).unwrap(); dynasm!(self.0.ops // Pull the input from the rdi array ; vmovss Rx(reg(out_reg)), [rdi + pos] From fda50ba764b1821bc820a87ec28a1c461df03dc4 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:49:39 -0400 Subject: [PATCH 7/8] Clean up var_count implementation --- fidget/src/core/compiler/alloc.rs | 4 +--- fidget/src/core/compiler/reg_tape.rs | 9 --------- fidget/src/core/eval/tracing.rs | 4 +--- fidget/src/core/vm/data.rs | 7 +++++-- fidget/src/jit/mod.rs | 2 -- 5 files changed, 7 insertions(+), 19 deletions(-) diff --git a/fidget/src/core/compiler/alloc.rs b/fidget/src/core/compiler/alloc.rs index e235a3ec..d5eb418c 100644 --- a/fidget/src/core/compiler/alloc.rs +++ b/fidget/src/core/compiler/alloc.rs @@ -273,7 +273,7 @@ impl RegisterAllocator { #[inline(always)] pub fn op(&mut self, op: SsaOp) { match op { - SsaOp::Input(out, i) => self.op_input(out, i.try_into().unwrap()), + SsaOp::Input(out, i) => self.op_input(out, i), SsaOp::CopyImm(out, imm) => self.op_copy_imm(out, imm), SsaOp::NegReg(..) @@ -674,8 +674,6 @@ impl RegisterAllocator { /// Pushes an [`Input`](crate::compiler::RegOp::Input) operation to the tape #[inline(always)] fn op_input(&mut self, out: u32, i: u32) { - // TODO: tightly pack variables (which may be sparse) into slots - self.out.var_count = self.out.var_count.max(i as u32 + 1); self.op_out_only(out, |out| RegOp::Input(out, i)); } } diff --git a/fidget/src/core/compiler/reg_tape.rs b/fidget/src/core/compiler/reg_tape.rs index 4887019b..77e54245 100644 --- a/fidget/src/core/compiler/reg_tape.rs +++ b/fidget/src/core/compiler/reg_tape.rs @@ -10,9 +10,6 @@ pub struct RegTape { /// Total allocated slots pub(super) slot_count: u32, - - /// Number of variables - pub(super) var_count: u32, } impl RegTape { @@ -35,7 +32,6 @@ impl RegTape { Self { tape: vec![], slot_count: 1, - var_count: 0, } } @@ -50,11 +46,6 @@ impl RegTape { pub fn slot_count(&self) -> usize { self.slot_count as usize } - /// Returns the number of variables (inputs) used in this tape - #[inline] - pub fn var_count(&self) -> usize { - self.var_count as usize - } /// Returns the number of elements in the tape #[inline] pub fn len(&self) -> usize { diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index c3c0a873..690f736f 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -51,9 +51,7 @@ pub trait TracingEvaluator: Default { vars: &[Self::Data], var_count: usize, ) -> Result<(), Error> { - // It's fine if the caller has given us extra variables (e.g. due to - // tape simplification), but it must have given us enough. - if vars.len() < var_count { + if vars.len() != var_count { Err(Error::BadVarSlice(vars.len(), var_count)) } else { Ok(()) diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 59e00697..9523e997 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -109,9 +109,12 @@ impl VmData { self.asm.slot_count() } - /// Returns the number of variables (inputs) in the inner VM tape + /// Returns the number of variables that may be used + /// + /// Note that this can sometimes be an overestimate, if the inner tape has + /// been simplified. pub fn var_count(&self) -> usize { - self.asm.var_count() + self.vars.len() } /// Simplifies both inner tapes, using the provided choice array diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index ca3f1fba..85199c56 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -992,7 +992,6 @@ impl JitTracingEval { ) -> (T, Option<&VmTrace>) { let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); - assert!(tape.var_count <= 3); self.choices.fill(Choice::Unknown); let out = unsafe { (tape.fn_trace)( @@ -1114,7 +1113,6 @@ unsafe impl Sync for JitBulkFn {} impl + Copy + SimdSize> JitBulkEval { /// Evaluate multiple points fn eval(&mut self, tape: &JitBulkFn, vars: &[&[T]]) -> &[T] { - assert!(tape.var_count <= 3); let n = vars.first().map(|v| v.len()).unwrap_or(0); self.out.resize(n, f32::NAN.into()); self.out.fill(f32::NAN.into()); From 5a3e76368390f319cc3b7cf795241e6830e18f7a Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sat, 25 May 2024 11:54:44 -0400 Subject: [PATCH 8/8] hakari update --- Cargo.lock | 1 + workspace-hack/Cargo.toml | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index c22e3e29..c155e5be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2945,6 +2945,7 @@ dependencies = [ "clap", "clap_builder", "crossbeam-utils", + "getrandom", "libc", "log", "memchr", diff --git a/workspace-hack/Cargo.toml b/workspace-hack/Cargo.toml index cd563f22..f9f4b5d8 100644 --- a/workspace-hack/Cargo.toml +++ b/workspace-hack/Cargo.toml @@ -19,12 +19,14 @@ bytemuck = { version = "1", default-features = false, features = ["derive", "ext clap = { version = "4", features = ["derive"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } +getrandom = { version = "0.2", default-features = false, features = ["std"] } once_cell = { version = "1" } regex = { version = "1", default-features = false, features = ["perf", "std"] } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal"] } -serde = { version = "1", features = ["alloc", "derive"] } +serde = { version = "1", features = ["alloc", "derive", "rc"] } [build-dependencies] +getrandom = { version = "0.2", default-features = false, features = ["std"] } once_cell = { version = "1" } proc-macro2 = { version = "1" } quote = { version = "1" }