Release 0.3.0 #124

Merged 7 commits on May 29, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -1,4 +1,4 @@
# 0.2.8
# 0.3.0
- Major refactoring of core evaluation traits
- The lowest-level "thing that can be evaluated" trait has changed from
`Shape` (taking `(x, y, z)` inputs) to `Function` (taking an arbitrary
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion demo/src/main.rs
@@ -258,7 +258,7 @@ fn main() -> Result<()> {
let now = Instant::now();
let args = Args::parse();
let mut file = std::fs::File::open(&args.input)?;
let (mut ctx, root) = Context::from_text(&mut file)?;
let (ctx, root) = Context::from_text(&mut file)?;
info!("Loaded file in {:?}", now.elapsed());

match args.cmd {
4 changes: 2 additions & 2 deletions fidget/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fidget"
version = "0.2.8"
version = "0.3.0"
edition = "2021"
license = "MPL-2.0"
repository = "https://github.com/mkeeter/fidget"
@@ -60,7 +60,7 @@ render = []
## Enable 3D meshing, in the [`fidget::mesh`](crate::mesh) module
mesh = ["dep:crossbeam-deque"]

## Enable `eval-tests` if you're writing your own Shape / evaluators and want to
## Enable `eval-tests` if you're writing your own evaluators and want to
## unit-test them. When enabled, the crate exports a set of macros to test each
## evaluator type, e.g. `float_slice_tests!(...)`.
eval-tests = []
41 changes: 31 additions & 10 deletions fidget/src/core/context/mod.rs
@@ -5,7 +5,7 @@
//! - A [`Tree`] is a free-floating math expression, which can be cloned
//! and has overloaded operators for ease of use. It is **not** deduplicated;
//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
//! different [`TreeOp`] objects.
//! different objects.
//! `Tree` objects are typically used when building up expressions; they
//! should be converted to `Node` objects (in a particular `Context`) after
//! they have been constructed.
@@ -38,7 +38,12 @@ define_index!(Node, "An index in the `Context::ops` map");
/// operations.
///
/// It should be used like an arena allocator: it grows over time, then frees
/// all of its contents when dropped.
/// all of its contents when dropped. There is no reference counting within the
/// context.
///
/// Items in the context are accessed with [`Node`] keys, which are simple
/// handles into an internal map. Inside the context, operations are
/// represented with the [`Op`] type.
#[derive(Debug, Default)]
pub struct Context {
ops: IndexMap<Op, Node>,
@@ -121,10 +126,10 @@ impl Context {
///
/// If the node is invalid for this tree, returns an error; if the node is
/// not an `Op::Input`, returns `Ok(None)`.
pub fn get_var(&self, n: Node) -> Result<Option<Var>, Error> {
pub fn get_var(&self, n: Node) -> Result<Var, Error> {
match self.get_op(n) {
Some(Op::Input(v)) => Ok(Some(*v)),
Some(_) => Ok(None),
Some(Op::Input(v)) => Ok(*v),
Some(..) => Err(Error::NotAVar),
_ => Err(Error::BadNode),
}
}
@@ -154,6 +159,22 @@ impl Context {
}

/// Constructs or finds a variable input node
///
/// To make an anonymous variable, call this function with [`Var::new()`]:
///
/// ```
/// # use fidget::{context::Context, var::Var};
/// # use std::collections::HashMap;
/// let mut ctx = Context::new();
/// let v1 = ctx.var(Var::new());
/// let v2 = ctx.var(Var::new());
/// assert_ne!(v1, v2);
///
/// let mut vars = HashMap::new();
/// vars.insert(ctx.get_var(v1).unwrap(), 3.0);
/// assert_eq!(ctx.eval(v1, &vars).unwrap(), 3.0);
/// assert!(ctx.eval(v2, &vars).is_err()); // v2 isn't in the map
/// ```
pub fn var(&mut self, v: Var) -> Node {
self.ops.insert(Op::Input(v))
}
@@ -181,7 +202,7 @@ impl Context {
let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
let n = self.ops.insert(Op::Unary(op, a));
let out = if matches!(op_a, Op::Const(_)) {
let v = self.eval(n, &BTreeMap::new())?;
let v = self.eval(n, &Default::default())?;
self.pop().unwrap(); // removes `n`
self.constant(v)
} else {
@@ -214,7 +235,7 @@ impl Context {
// constant-folded (indeed, we pop the node right afterwards)
let n = self.ops.insert(f(a, b));
let out = if matches!((op_a, op_b), (Op::Const(_), Op::Const(_))) {
let v = self.eval(n, &BTreeMap::new())?;
let v = self.eval(n, &Default::default())?;
self.pop().unwrap(); // removes `n`
self.constant(v)
} else {
@@ -762,7 +783,7 @@ impl Context {
pub fn eval(
&self,
root: Node,
vars: &BTreeMap<Var, f64>,
vars: &HashMap<Var, f64>,
) -> Result<f64, Error> {
let mut cache = vec![None; self.ops.len()].into();
self.eval_inner(root, vars, &mut cache)
@@ -771,7 +792,7 @@
fn eval_inner(
&self,
node: Node,
vars: &BTreeMap<Var, f64>,
vars: &HashMap<Var, f64>,
cache: &mut IndexVec<Option<f64>, Node>,
) -> Result<f64, Error> {
if node.0 >= cache.len() {
@@ -782,7 +803,7 @@
}
let mut get = |n: Node| self.eval_inner(n, vars, cache);
let v = match self.get_op(node).ok_or(Error::BadNode)? {
Op::Input(v) => *vars.get(v).unwrap(),
Op::Input(v) => *vars.get(v).ok_or(Error::MissingVar(*v))?,
Op::Const(c) => c.0,

Op::Binary(op, a, b) => {
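The changes above swap `BTreeMap` for `HashMap` in `Context::eval`, make `get_var` return the `Var` directly, and report missing variables as `Error::MissingVar`. A minimal sketch (not part of this diff) of how those pieces fit together; `Context::add` is assumed from the crate's existing public API rather than shown in this PR:

```rust
use std::collections::HashMap;
use fidget::{context::Context, var::Var};

fn main() -> Result<(), fidget::Error> {
    let mut ctx = Context::new();
    let x = ctx.var(Var::X);
    let v = ctx.var(Var::new()); // anonymous variable with a fresh identity
    let sum = ctx.add(x, v)?; // `Context::add` is assumed here, not shown in this diff

    // `get_var` now returns the `Var` itself, or `Error::NotAVar` / `Error::BadNode`
    let key = ctx.get_var(v)?;

    // `eval` now takes a `HashMap<Var, f64>` instead of a `BTreeMap`
    let mut vars = HashMap::new();
    vars.insert(Var::X, 1.0);
    vars.insert(key, 2.0);
    assert_eq!(ctx.eval(sum, &vars)?, 3.0);

    // A variable missing from the map is reported as `Error::MissingVar`
    assert!(ctx.eval(sum, &HashMap::new()).is_err());
    Ok(())
}
```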
34 changes: 6 additions & 28 deletions fidget/src/core/eval/bulk.rs
@@ -9,10 +9,6 @@
use crate::{eval::Tape, Error};

/// Trait for bulk evaluation returning the given type `T`
///
/// It's uncommon to use this trait outside the library itself; it's an
/// abstraction to reduce code duplication, and is public because it's used as a
/// constraint on other public APIs.
pub trait BulkEvaluator: Default {
/// Data type used during evaluation
type Data: From<f32> + Copy + Clone;
@@ -30,8 +26,12 @@ pub trait BulkEvaluator: Default {

/// Evaluates many points using the given instruction tape
///
/// Returns an error if the `x`, `y`, `z`, and `out` slices are of different
/// lengths.
/// `vars` should be a slice-of-slices (or a slice-of-`Vec`s) representing
/// input arguments for each of the tape's variables; use [`Tape::vars`] to
/// map from [`Var`](crate::var::Var) to position in the list.
///
/// Returns an error if any of the `var` slices are of different lengths, or
/// if all variables aren't present.
fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
&mut self,
tape: &Self::Tape,
@@ -42,26 +42,4 @@
fn new() -> Self {
Self::default()
}

/// Helper function to return an error if the inputs are invalid
fn check_arguments<V: std::ops::Deref<Target = [Self::Data]>>(
&self,
vars: &[V],
var_count: usize,
) -> Result<(), Error> {
// It's fine if the caller has given us extra variables (e.g. due to
// tape simplification), but it must have given us enough.
if vars.len() < var_count {
Err(Error::BadVarSlice(vars.len(), var_count))
} else {
let Some(n) = vars.first().map(|v| v.len()) else {
return Ok(());
};
if vars.iter().any(|v| v.len() == n) {
Ok(())
} else {
Err(Error::MismatchedSlices)
}
}
}
}
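The new bulk `eval` docs describe a slice-of-slices layout ordered by the tape's variable map. A hedged sketch (not from the diff) of building that layout from a `VarMap`, assuming `VarMap::insert` takes the `Var` by value, `VarMap::len` is public, and indexing with `&Var` returns the variable's position as a `usize`; the result would be passed as `vars` to `BulkEvaluator::eval`:

```rust
use fidget::var::{Var, VarMap};

/// One argument slice per variable, in `VarMap` order; all slices must share a
/// length, or the evaluator reports `Error::MismatchedSlices`
fn build_bulk_args(vm: &VarMap, x: &[f32], y: &[f32], z: &[f32]) -> Vec<Vec<f32>> {
    let mut out = vec![Vec::new(); vm.len()];
    out[vm[&Var::X]] = x.to_vec();
    out[vm[&Var::Y]] = y.to_vec();
    out[vm[&Var::Z]] = z.to_vec();
    out
}

fn main() {
    let mut vm = VarMap::default();
    vm.insert(Var::X);
    vm.insert(Var::Y);
    vm.insert(Var::Z);

    let args = build_bulk_args(&vm, &[0.0, 1.0], &[2.0, 3.0], &[4.0, 5.0]);
    assert_eq!(args.len(), 3);
    assert!(args.iter().all(|v| v.len() == 2));
}
```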
6 changes: 2 additions & 4 deletions fidget/src/core/eval/mod.rs
@@ -18,10 +18,8 @@ pub use tracing::TracingEvaluator;

/// A tape represents something that can be evaluated by an evaluator
///
/// It includes some kind of storage and the ability to look up variable
/// mapping. The variable mapping should be identical to calling
/// [`Function::vars`] on the `Function` which produced this tape, but it's
/// convenient to be able to look up vars locally.
/// It includes some kind of storage (which could be empty) and the ability to
/// look up variable mapping.
pub trait Tape {
/// Associated type for this tape's data storage
type Storage: Default;
1 change: 0 additions & 1 deletion fidget/src/core/eval/test/grad_slice.rs
@@ -367,7 +367,6 @@ impl<F: Function + MathFunction> TestGradSlice<F> {
// that S is also a VmShape, but this comparison isn't particularly
// expensive, so we'll do it regardless.
let shape = VmFunction::new(&ctx, node).unwrap();
#[allow(clippy::unit_arg)]
let tape = shape.grad_slice_tape(Default::default());

let cmp = TestGradSlice::<VmFunction>::eval_xyz(&tape, &x, &y, &z);
21 changes: 6 additions & 15 deletions fidget/src/core/eval/tracing.rs
@@ -34,6 +34,12 @@ pub trait TracingEvaluator: Default {
type Trace;

/// Evaluates the given tape at a particular position
///
/// `vars` should be a slice of values representing input arguments for each
/// of the tape's variables; use [`Tape::vars`] to map from
/// [`Var`](crate::var::Var) to position in the list.
///
/// Returns an error if the `var` slice is not of sufficient length.
fn eval(
&mut self,
tape: &Self::Tape,
@@ -44,19 +50,4 @@
fn new() -> Self {
Self::default()
}

/// Helper function to return an error if the inputs are invalid
fn check_arguments(
&self,
vars: &[Self::Data],
var_count: usize,
) -> Result<(), Error> {
if vars.len() < var_count {
// It's okay to be passed extra vars, because expressions may have
// been simplified.
Err(Error::BadVarSlice(vars.len(), var_count))
} else {
Ok(())
}
}
}
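For single-point (tracing) evaluation, the same `VarMap` mapping applies, but `vars` is a flat slice with one value per variable. A short sketch under the same assumptions as the bulk example above:

```rust
use fidget::var::{Var, VarMap};

/// Flat argument list for point evaluation: one value per variable, in
/// `VarMap` order
fn build_point_args(vm: &VarMap, x: f32, y: f32, z: f32) -> Vec<f32> {
    let mut out = vec![0.0; vm.len()];
    out[vm[&Var::X]] = x;
    out[vm[&Var::Y]] = y;
    out[vm[&Var::Z]] = z;
    out
}

fn main() {
    let mut vm = VarMap::default();
    vm.insert(Var::X);
    vm.insert(Var::Y);
    vm.insert(Var::Z);
    assert_eq!(build_point_args(&vm, 1.0, 2.0, 3.0), vec![1.0, 2.0, 3.0]);
}
```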
10 changes: 7 additions & 3 deletions fidget/src/core/shape/mod.rs
@@ -1,7 +1,7 @@
//! Data structures for shape evaluation
//!
//! Types in this module are typically thin (generic) wrappers around objects
//! that implement traits in [`fidget::eval`](crate::eval). The wraper types
//! that implement traits in [`fidget::eval`](crate::eval). The wrapper types
//! are specialized to operate on `x, y, z` arguments, rather than taking
//! arbitrary numbers of variables.
//!
@@ -467,7 +467,11 @@ impl<E: BulkEvaluator> ShapeBulkEval<E>
where
E::Data: From<f32> + Transformable,
{
/// Bulk evaluation of many samples
/// Bulk evaluation of many samples, without any variables
///
/// If the shape includes variables other than `X`, `Y`, `Z`,
/// [`eval_v`](Self::eval_v) should be used instead (and this function will
/// return an error).
///
/// Before evaluation, the tape's transform matrix is applied (if present).
pub fn eval(
@@ -481,7 +485,7 @@
self.eval_v(tape, x, y, z, &h)
}

/// Bulk evaluation of many samples
/// Bulk evaluation of many samples, with variables
///
/// Before evaluation, the tape's transform matrix is applied (if present).
pub fn eval_v<V: std::ops::Deref<Target = [E::Data]>>(
51 changes: 45 additions & 6 deletions fidget/src/core/var/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
//! A [`Var`] maintains a persistent identity from
//! [`Tree`](crate::context::Tree) to [`Context`](crate::context::Node) (where
//! it is wrapped in a [`Op::Input`](crate::context::Op::Input)) to evaluation
//! (where [`Function::vars`](crate::eval::Function::vars) maps from `Var` to
//! index in the argument list).
//! (where [`Tape::vars`](crate::eval::Tape::vars) maps from `Var` to index in
//! the argument list).
use crate::Error;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

@@ -16,7 +17,6 @@ use std::collections::HashMap;
/// Variables are "global", in that every instance of `Var::X` represents the
/// same thing. To generate a "local" variable, [`Var::new`] picks a random
/// 64-bit value, which is very unlikely to collide with anything else.
#[allow(missing_docs)]
#[derive(
Copy,
Clone,
@@ -30,13 +30,17 @@ use std::collections::HashMap;
Deserialize,
)]
pub enum Var {
/// Variable representing the X axis for 2D / 3D shapes
X,
/// Variable representing the Y axis for 2D / 3D shapes
Y,
/// Variable representing the Z axis for 3D shapes
Z,
/// Generic variable
V(VarIndex),
}

/// Type for a variable index (implemented as a `u64`)
/// Type for a variable index (implemented as a `u64`), used in [`Var::V`]
#[derive(
Copy,
Clone,
@@ -91,8 +95,10 @@ impl std::fmt::Display for Var {
/// Variable indexes are automatically assigned the first time
/// [`VarMap::insert`] is called on that variable.
///
/// Indexes are guaranteed to be tightly packed, i.e. contains values from
/// `0..vars.len()`.
/// Indexes are guaranteed to be tightly packed, i.e. a map `vars` will contain
/// values from `0..vars.len()`.
///
/// For efficiency, this type does not allocate heap memory for `Var::X/Y/Z`.
#[derive(Default, Serialize, Deserialize)]
pub struct VarMap {
x: Option<usize>,
@@ -138,6 +144,39 @@ impl VarMap {
Var::V(v) => self.v.entry(v).or_insert(next),
};
}

pub(crate) fn check_tracing_arguments<T>(
&self,
vars: &[T],
) -> Result<(), Error> {
if vars.len() < self.len() {
// It's okay to be passed extra vars, because expressions may have
// been simplified.
Err(Error::BadVarSlice(vars.len(), self.len()))
} else {
Ok(())
}
}

pub(crate) fn check_bulk_arguments<T, V: std::ops::Deref<Target = [T]>>(
&self,
vars: &[V],
) -> Result<(), Error> {
// It's fine if the caller has given us extra variables (e.g. due to
// tape simplification), but it must have given us enough.
if vars.len() < self.len() {
Err(Error::BadVarSlice(vars.len(), self.len()))
} else {
let Some(n) = vars.first().map(|v| v.len()) else {
return Ok(());
};
if vars.iter().any(|v| v.len() == n) {
Ok(())
} else {
Err(Error::MismatchedSlices)
}
}
}
}

impl std::ops::Index<&Var> for VarMap {
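A short sketch (not from the diff) of the `Var` / `VarMap` behavior documented above: `Var::new()` picks a fresh random identity, and indexes are assigned on first insertion and stay tightly packed. It assumes `insert` takes the `Var` by value and that indexing returns a `usize`:

```rust
use fidget::var::{Var, VarMap};

fn main() {
    // Two anonymous variables get distinct random 64-bit identities
    let a = Var::new();
    let b = Var::new();
    assert_ne!(a, b);

    let mut vm = VarMap::default();
    vm.insert(Var::X);
    vm.insert(a);
    vm.insert(Var::X); // re-inserting keeps the original index

    // Indexes are assigned in first-insertion order and stay tightly packed
    assert_eq!(vm[&Var::X], 0);
    assert_eq!(vm[&a], 1);
}
```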
8 changes: 0 additions & 8 deletions fidget/src/core/vm/data.rs
@@ -109,14 +109,6 @@ impl<const N: usize> VmData<N> {
self.asm.slot_count()
}

/// Returns the number of variables that may be used
///
/// Note that this can sometimes be an overestimate, if the inner tape has
/// been simplified.
pub fn var_count(&self) -> usize {
self.vars.len()
}

/// Simplifies both inner tapes, using the provided choice array
///
/// To minimize allocations, this function takes a [`VmWorkspace`] and