Skip to content

Commit

Permalink
Clean up Var implementation (#117)
Browse files Browse the repository at this point in the history
This is another step towards gracefully handling n-ary functions; see
the CHANGELOG for details.
  • Loading branch information
mkeeter authored May 25, 2024
1 parent 59f3e31 commit 8d5c0db
Show file tree
Hide file tree
Showing 27 changed files with 398 additions and 165 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
- Using the `VmShape` or `JitShape` types should be mostly the same as
before; changes are most noticeable if you're writing things that are
generic across `S: Shape`.
- Major refactoring of how variables are handled
- Removed `VarNode`; the canonical variable type is `Var`, which is its own
unique index.
- Removed named variables, to make `Var` trivially `Copy + Clone`.
- Added `vars()` method to `Function` trait, allowing users to look up the
mapping from variable to evaluation index. A `Var` now represents a
persistent identity from `Tree` to `Context` to `Function` evaluation.
- Move `Var` and `VarMap` into `fidget::vars` module, because they're no
longer specific to a `Context`.
- `Op::Input` now takes a `u32` argument, instead of a `u8`

# 0.2.7
This release brings us to opcode parity with `libfive`'s operators, adding
Expand Down
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion fidget/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ nalgebra = "0.31"
num-derive = "0.3"
num-traits = "0.2"
ordered-float = "3"
rand = "0.8.5"
static_assertions = "1"
thiserror = "1"
workspace-hack = { version = "0.1", path = "../workspace-hack" }
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive", "rc"] }

# JIT
dynasmrt = { version = "2.0", optional = true }
Expand Down
6 changes: 2 additions & 4 deletions fidget/src/core/compiler/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl<const N: usize> RegisterAllocator<N> {
#[inline(always)]
pub fn op(&mut self, op: SsaOp) {
match op {
SsaOp::Input(out, i) => self.op_input(out, i.try_into().unwrap()),
SsaOp::Input(out, i) => self.op_input(out, i),
SsaOp::CopyImm(out, imm) => self.op_copy_imm(out, imm),

SsaOp::NegReg(..)
Expand Down Expand Up @@ -673,9 +673,7 @@ impl<const N: usize> RegisterAllocator<N> {

/// Pushes an [`Input`](crate::compiler::RegOp::Input) operation to the tape
#[inline(always)]
fn op_input(&mut self, out: u32, i: u8) {
// TODO: tightly pack variables (which may be sparse) into slots
self.out.var_count = self.out.var_count.max(i as u32 + 1);
fn op_input(&mut self, out: u32, i: u32) {
self.op_out_only(out, |out| RegOp::Input(out, i));
}
}
4 changes: 2 additions & 2 deletions fidget/src/core/compiler/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ macro_rules! opcodes {
) => {
$(#[$($attrss)*])*
pub enum $name {
#[doc = "Read one of the inputs (X, Y, Z)"]
Input($t, $t),
#[doc = "Read an input variable by index"]
Input($t, u32),

#[doc = "Negate the given register"]
NegReg($t, $t),
Expand Down
9 changes: 0 additions & 9 deletions fidget/src/core/compiler/reg_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ pub struct RegTape {

/// Total allocated slots
pub(super) slot_count: u32,

/// Number of variables
pub(super) var_count: u32,
}

impl RegTape {
Expand All @@ -35,7 +32,6 @@ impl RegTape {
Self {
tape: vec![],
slot_count: 1,
var_count: 0,
}
}

Expand All @@ -50,11 +46,6 @@ impl RegTape {
pub fn slot_count(&self) -> usize {
self.slot_count as usize
}
/// Returns the number of variables (inputs) used in this tape
#[inline]
pub fn var_count(&self) -> usize {
self.var_count as usize
}
/// Returns the number of elements in the tape
#[inline]
pub fn len(&self) -> usize {
Expand Down
16 changes: 9 additions & 7 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
compiler::SsaOp,
context::{BinaryOpcode, Node, Op, UnaryOpcode},
eval::VarMap,
var::VarMap,
Context, Error,
};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -33,7 +33,10 @@ impl SsaTape {
///
/// This should always succeed unless the `root` is from a different
/// `Context`, in which case `Error::BadNode` will be returned.
pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> {
pub fn new(
ctx: &Context,
root: Node,
) -> Result<(Self, VarMap<usize>), Error> {
let mut mapping = HashMap::new();
let mut parent_count: HashMap<Node, usize> = HashMap::new();
let mut slot_count = 0;
Expand All @@ -47,7 +50,7 @@ impl SsaTape {

// Accumulate parent counts and declare all nodes
let mut seen = HashSet::new();
let mut vars = HashMap::new();
let mut vars = VarMap::new();
let mut todo = vec![root];
while let Some(node) = todo.pop() {
if !seen.insert(node) {
Expand All @@ -61,8 +64,7 @@ impl SsaTape {
_ => {
if let Op::Input(v) = op {
let next = vars.len();
let v = ctx.get_var_by_index(*v).unwrap().clone();
vars.entry(v).or_insert(next);
vars.entry(*v).or_insert(next);
}
let i = slot_count;
slot_count += 1;
Expand Down Expand Up @@ -97,8 +99,8 @@ impl SsaTape {
continue;
};
let op = match op {
Op::Input(..) => {
let arg = vars[ctx.var_name(node).unwrap().unwrap()];
Op::Input(v) => {
let arg = vars[v];
SsaOp::Input(i, arg.try_into().unwrap())
}
Op::Const(..) => {
Expand Down
86 changes: 21 additions & 65 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
//!
//! - A [`Tree`] is a free-floating math expression, which can be cloned
//! and has overloaded operators for ease of use. It is **not** deduplicated;
//! two calls to [`Tree::x`] will produce two different [`TreeOp`] objects.
//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
//! different [`TreeOp`] objects.
//! `Tree` objects are typically used when building up expressions; they
//! should be converted to `Node` objects (in a particular `Context`) after
//! they have been constructed.
Expand All @@ -22,7 +23,7 @@ use indexed::{define_index, Index, IndexMap, IndexVec};
pub use op::{BinaryOpcode, Op, UnaryOpcode};
pub use tree::{Tree, TreeOp};

use crate::Error;
use crate::{var::Var, Error};

use std::collections::{BTreeMap, HashMap};
use std::fmt::Write;
Expand All @@ -32,7 +33,6 @@ use std::sync::Arc;
use ordered_float::OrderedFloat;

define_index!(Node, "An index in the `Context::ops` map");
define_index!(VarNode, "An index in the `Context::vars` map");

/// A `Context` holds a set of deduplicated constants, variables, and
/// operations.
Expand All @@ -42,39 +42,6 @@ define_index!(VarNode, "An index in the `Context::vars` map");
#[derive(Debug, Default)]
pub struct Context {
ops: IndexMap<Op, Node>,
vars: IndexMap<Var, VarNode>,
}

/// A `Var` represents a value which can vary during evaluation
///
/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully
/// customized values.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Var {
X,
Y,
Z,
W,
T,
Static(&'static str),
Named(String),
Value(u64),
}

impl std::fmt::Display for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Var::X => write!(f, "X"),
Var::Y => write!(f, "Y"),
Var::Z => write!(f, "Z"),
Var::W => write!(f, "W"),
Var::T => write!(f, "T"),
Var::Static(s) => write!(f, "{s}"),
Var::Named(s) => write!(f, "{s}"),
Var::Value(v) => write!(f, "v_{v}"),
}
}
}

impl Context {
Expand All @@ -85,7 +52,7 @@ impl Context {

/// Clears the context
///
/// All [`Node`] and [`VarNode`] handles from this context are invalidated.
/// All [`Node`] handles from this context are invalidated.
///
/// ```
/// # use fidget::context::Context;
Expand All @@ -96,7 +63,6 @@ impl Context {
/// ```
pub fn clear(&mut self) {
self.ops.clear();
self.vars.clear();
}

/// Returns the number of [`Op`] nodes in the context
Expand Down Expand Up @@ -155,25 +121,17 @@ impl Context {
///
/// If the node is invalid for this tree, returns an error; if the node is
/// not an `Op::Input`, returns `Ok(None)`.
pub fn var_name(&self, n: Node) -> Result<Option<&Var>, Error> {
pub fn get_var(&self, n: Node) -> Result<Option<Var>, Error> {
match self.get_op(n) {
Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some),
Some(Op::Input(v)) => Ok(Some(*v)),
Some(_) => Ok(None),
_ => Err(Error::BadNode),
}
}

/// Looks up the [`Var`] associated with the given [`VarNode`]
pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> {
match self.vars.get_by_index(n) {
Some(c) => Ok(c),
None => Err(Error::BadVar),
}
}

////////////////////////////////////////////////////////////////////////////
// Primitives
/// Constructs or finds a variable node named "X"
/// Constructs or finds a [`Var::X`] node
/// ```
/// # use fidget::context::Context;
/// let mut ctx = Context::new();
Expand All @@ -182,19 +140,21 @@ impl Context {
/// assert_eq!(v, 1.0);
/// ```
pub fn x(&mut self) -> Node {
let v = self.vars.insert(Var::X);
self.ops.insert(Op::Input(v))
self.var(Var::X)
}

/// Constructs or finds a variable node named "Y"
/// Constructs or finds a [`Var::Y`] node
pub fn y(&mut self) -> Node {
let v = self.vars.insert(Var::Y);
self.ops.insert(Op::Input(v))
self.var(Var::Y)
}

/// Constructs or finds a variable node named "Z"
/// Constructs or finds a [`Var::Z`] node
pub fn z(&mut self) -> Node {
let v = self.vars.insert(Var::Z);
self.var(Var::Z)
}

/// Constructs or finds a variable input node
pub fn var(&mut self, v: Var) -> Node {
self.ops.insert(Op::Input(v))
}

Expand Down Expand Up @@ -822,10 +782,7 @@ impl Context {
}
let mut get = |n: Node| self.eval_inner(n, vars, cache);
let v = match self.get_op(node).ok_or(Error::BadNode)? {
Op::Input(v) => {
let var_name = self.vars.get_by_index(*v).unwrap();
*vars.get(var_name).unwrap()
}
Op::Input(v) => *vars.get(v).unwrap(),
Op::Const(c) => c.0,

Op::Binary(op, a, b) => {
Expand Down Expand Up @@ -995,7 +952,6 @@ impl Context {
match op {
Op::Const(c) => write!(out, "{}", c).unwrap(),
Op::Input(v) => {
let v = self.vars.get_by_index(*v).unwrap();
out += &v.to_string();
}
Op::Binary(op, ..) => match op {
Expand Down Expand Up @@ -1245,9 +1201,9 @@ mod test {
let c8 = ctx.sub(c7, r).unwrap();
let c9 = ctx.max(c8, c6).unwrap();

let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap();
let tape = VmData::<255>::new(&ctx, c9).unwrap();
assert_eq!(tape.len(), 8);
assert_eq!(vs.len(), 2);
assert_eq!(tape.vars.len(), 2);
}

#[test]
Expand All @@ -1256,8 +1212,8 @@ mod test {
let x = ctx.x();
let x_squared = ctx.mul(x, x).unwrap();

let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap();
let tape = VmData::<255>::new(&ctx, x_squared).unwrap();
assert_eq!(tape.len(), 2);
assert_eq!(vs.len(), 1);
assert_eq!(tape.vars.len(), 1);
}
}
Loading

0 comments on commit 8d5c0db

Please sign in to comment.