Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Var implementation #117

Merged
merged 8 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
- Using the `VmShape` or `JitShape` types should be mostly the same as
before; changes are most noticeable if you're writing things that are
generic across `S: Shape`.
- Major refactoring of how variables are handled
- Removed `VarNode`; the canonical variable type is `Var`, which is its own
unique index.
- Removed named variables, to make `Var` trivially `Copy + Clone`.
- Added `vars()` method to `Function` trait, allowing users to look up the
mapping from variable to evaluation index. A `Var` now represents a
persistent identity from `Tree` to `Context` to `Function` evaluation.
- Move `Var` and `VarMap` into `fidget::vars` module, because they're no
longer specific to a `Context`.
- `Op::Input` now takes a `u32` argument, instead of a `u8`

# 0.2.7
This release brings us to opcode parity with `libfive`'s operators, adding
Expand Down
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion fidget/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ nalgebra = "0.31"
num-derive = "0.3"
num-traits = "0.2"
ordered-float = "3"
rand = "0.8.5"
static_assertions = "1"
thiserror = "1"
workspace-hack = { version = "0.1", path = "../workspace-hack" }
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive", "rc"] }

# JIT
dynasmrt = { version = "2.0", optional = true }
Expand Down
6 changes: 2 additions & 4 deletions fidget/src/core/compiler/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl<const N: usize> RegisterAllocator<N> {
#[inline(always)]
pub fn op(&mut self, op: SsaOp) {
match op {
SsaOp::Input(out, i) => self.op_input(out, i.try_into().unwrap()),
SsaOp::Input(out, i) => self.op_input(out, i),
SsaOp::CopyImm(out, imm) => self.op_copy_imm(out, imm),

SsaOp::NegReg(..)
Expand Down Expand Up @@ -673,9 +673,7 @@ impl<const N: usize> RegisterAllocator<N> {

/// Pushes an [`Input`](crate::compiler::RegOp::Input) operation to the tape
#[inline(always)]
fn op_input(&mut self, out: u32, i: u8) {
// TODO: tightly pack variables (which may be sparse) into slots
self.out.var_count = self.out.var_count.max(i as u32 + 1);
fn op_input(&mut self, out: u32, i: u32) {
self.op_out_only(out, |out| RegOp::Input(out, i));
}
}
4 changes: 2 additions & 2 deletions fidget/src/core/compiler/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ macro_rules! opcodes {
) => {
$(#[$($attrss)*])*
pub enum $name {
#[doc = "Read one of the inputs (X, Y, Z)"]
Input($t, $t),
#[doc = "Read an input variable by index"]
Input($t, u32),

#[doc = "Negate the given register"]
NegReg($t, $t),
Expand Down
9 changes: 0 additions & 9 deletions fidget/src/core/compiler/reg_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ pub struct RegTape {

/// Total allocated slots
pub(super) slot_count: u32,

/// Number of variables
pub(super) var_count: u32,
}

impl RegTape {
Expand All @@ -35,7 +32,6 @@ impl RegTape {
Self {
tape: vec![],
slot_count: 1,
var_count: 0,
}
}

Expand All @@ -50,11 +46,6 @@ impl RegTape {
pub fn slot_count(&self) -> usize {
self.slot_count as usize
}
/// Returns the number of variables (inputs) used in this tape
#[inline]
pub fn var_count(&self) -> usize {
self.var_count as usize
}
/// Returns the number of elements in the tape
#[inline]
pub fn len(&self) -> usize {
Expand Down
16 changes: 9 additions & 7 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
compiler::SsaOp,
context::{BinaryOpcode, Node, Op, UnaryOpcode},
eval::VarMap,
var::VarMap,
Context, Error,
};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -33,7 +33,10 @@ impl SsaTape {
///
/// This should always succeed unless the `root` is from a different
/// `Context`, in which case `Error::BadNode` will be returned.
pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> {
pub fn new(
ctx: &Context,
root: Node,
) -> Result<(Self, VarMap<usize>), Error> {
let mut mapping = HashMap::new();
let mut parent_count: HashMap<Node, usize> = HashMap::new();
let mut slot_count = 0;
Expand All @@ -47,7 +50,7 @@ impl SsaTape {

// Accumulate parent counts and declare all nodes
let mut seen = HashSet::new();
let mut vars = HashMap::new();
let mut vars = VarMap::new();
let mut todo = vec![root];
while let Some(node) = todo.pop() {
if !seen.insert(node) {
Expand All @@ -61,8 +64,7 @@ impl SsaTape {
_ => {
if let Op::Input(v) = op {
let next = vars.len();
let v = ctx.get_var_by_index(*v).unwrap().clone();
vars.entry(v).or_insert(next);
vars.entry(*v).or_insert(next);
}
let i = slot_count;
slot_count += 1;
Expand Down Expand Up @@ -97,8 +99,8 @@ impl SsaTape {
continue;
};
let op = match op {
Op::Input(..) => {
let arg = vars[ctx.var_name(node).unwrap().unwrap()];
Op::Input(v) => {
let arg = vars[v];
SsaOp::Input(i, arg.try_into().unwrap())
}
Op::Const(..) => {
Expand Down
86 changes: 21 additions & 65 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
//!
//! - A [`Tree`] is a free-floating math expression, which can be cloned
//! and has overloaded operators for ease of use. It is **not** deduplicated;
//! two calls to [`Tree::x`] will produce two different [`TreeOp`] objects.
//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
//! different [`TreeOp`] objects.
//! `Tree` objects are typically used when building up expressions; they
//! should be converted to `Node` objects (in a particular `Context`) after
//! they have been constructed.
Expand All @@ -22,7 +23,7 @@ use indexed::{define_index, Index, IndexMap, IndexVec};
pub use op::{BinaryOpcode, Op, UnaryOpcode};
pub use tree::{Tree, TreeOp};

use crate::Error;
use crate::{var::Var, Error};

use std::collections::{BTreeMap, HashMap};
use std::fmt::Write;
Expand All @@ -32,7 +33,6 @@ use std::sync::Arc;
use ordered_float::OrderedFloat;

define_index!(Node, "An index in the `Context::ops` map");
define_index!(VarNode, "An index in the `Context::vars` map");

/// A `Context` holds a set of deduplicated constants, variables, and
/// operations.
Expand All @@ -42,39 +42,6 @@ define_index!(VarNode, "An index in the `Context::vars` map");
#[derive(Debug, Default)]
pub struct Context {
ops: IndexMap<Op, Node>,
vars: IndexMap<Var, VarNode>,
}

/// A `Var` represents a value which can vary during evaluation
///
/// We pre-define common variables (e.g. X, Y, Z) but also allow for fully
/// customized values.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Var {
X,
Y,
Z,
W,
T,
Static(&'static str),
Named(String),
Value(u64),
}

impl std::fmt::Display for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Var::X => write!(f, "X"),
Var::Y => write!(f, "Y"),
Var::Z => write!(f, "Z"),
Var::W => write!(f, "W"),
Var::T => write!(f, "T"),
Var::Static(s) => write!(f, "{s}"),
Var::Named(s) => write!(f, "{s}"),
Var::Value(v) => write!(f, "v_{v}"),
}
}
}

impl Context {
Expand All @@ -85,7 +52,7 @@ impl Context {

/// Clears the context
///
/// All [`Node`] and [`VarNode`] handles from this context are invalidated.
/// All [`Node`] handles from this context are invalidated.
///
/// ```
/// # use fidget::context::Context;
Expand All @@ -96,7 +63,6 @@ impl Context {
/// ```
pub fn clear(&mut self) {
self.ops.clear();
self.vars.clear();
}

/// Returns the number of [`Op`] nodes in the context
Expand Down Expand Up @@ -155,25 +121,17 @@ impl Context {
///
/// If the node is invalid for this tree, returns an error; if the node is
/// not an `Op::Input`, returns `Ok(None)`.
pub fn var_name(&self, n: Node) -> Result<Option<&Var>, Error> {
pub fn get_var(&self, n: Node) -> Result<Option<Var>, Error> {
match self.get_op(n) {
Some(Op::Input(c)) => self.get_var_by_index(*c).map(Some),
Some(Op::Input(v)) => Ok(Some(*v)),
Some(_) => Ok(None),
_ => Err(Error::BadNode),
}
}

/// Looks up the [`Var`] associated with the given [`VarNode`]
pub fn get_var_by_index(&self, n: VarNode) -> Result<&Var, Error> {
match self.vars.get_by_index(n) {
Some(c) => Ok(c),
None => Err(Error::BadVar),
}
}

////////////////////////////////////////////////////////////////////////////
// Primitives
/// Constructs or finds a variable node named "X"
/// Constructs or finds a [`Var::X`] node
/// ```
/// # use fidget::context::Context;
/// let mut ctx = Context::new();
Expand All @@ -182,19 +140,21 @@ impl Context {
/// assert_eq!(v, 1.0);
/// ```
pub fn x(&mut self) -> Node {
let v = self.vars.insert(Var::X);
self.ops.insert(Op::Input(v))
self.var(Var::X)
}

/// Constructs or finds a variable node named "Y"
/// Constructs or finds a [`Var::Y`] node
pub fn y(&mut self) -> Node {
let v = self.vars.insert(Var::Y);
self.ops.insert(Op::Input(v))
self.var(Var::Y)
}

/// Constructs or finds a variable node named "Z"
/// Constructs or finds a [`Var::Z`] node
pub fn z(&mut self) -> Node {
let v = self.vars.insert(Var::Z);
self.var(Var::Z)
}

/// Constructs or finds a variable input node
pub fn var(&mut self, v: Var) -> Node {
self.ops.insert(Op::Input(v))
}

Expand Down Expand Up @@ -822,10 +782,7 @@ impl Context {
}
let mut get = |n: Node| self.eval_inner(n, vars, cache);
let v = match self.get_op(node).ok_or(Error::BadNode)? {
Op::Input(v) => {
let var_name = self.vars.get_by_index(*v).unwrap();
*vars.get(var_name).unwrap()
}
Op::Input(v) => *vars.get(v).unwrap(),
Op::Const(c) => c.0,

Op::Binary(op, a, b) => {
Expand Down Expand Up @@ -995,7 +952,6 @@ impl Context {
match op {
Op::Const(c) => write!(out, "{}", c).unwrap(),
Op::Input(v) => {
let v = self.vars.get_by_index(*v).unwrap();
out += &v.to_string();
}
Op::Binary(op, ..) => match op {
Expand Down Expand Up @@ -1245,9 +1201,9 @@ mod test {
let c8 = ctx.sub(c7, r).unwrap();
let c9 = ctx.max(c8, c6).unwrap();

let (tape, vs) = VmData::<255>::new(&ctx, c9).unwrap();
let tape = VmData::<255>::new(&ctx, c9).unwrap();
assert_eq!(tape.len(), 8);
assert_eq!(vs.len(), 2);
assert_eq!(tape.vars.len(), 2);
}

#[test]
Expand All @@ -1256,8 +1212,8 @@ mod test {
let x = ctx.x();
let x_squared = ctx.mul(x, x).unwrap();

let (tape, vs) = VmData::<255>::new(&ctx, x_squared).unwrap();
let tape = VmData::<255>::new(&ctx, x_squared).unwrap();
assert_eq!(tape.len(), 2);
assert_eq!(vs.len(), 1);
assert_eq!(tape.vars.len(), 1);
}
}
Loading