Skip to content

Commit

Permalink
Add vars to shape evaluation (#119)
Browse files Browse the repository at this point in the history
This adds an `eval_v` function to the `Shape` API, taking a
`HashMap<VarIndex, T>` for (non-X/Y/Z) variable values.
  • Loading branch information
mkeeter authored May 26, 2024
1 parent 8dc51ee commit 184dc12
Show file tree
Hide file tree
Showing 13 changed files with 393 additions and 179 deletions.
8 changes: 2 additions & 6 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ impl SsaTape {
///
/// This should always succeed unless the `root` is from a different
/// `Context`, in which case `Error::BadNode` will be returned.
pub fn new(
ctx: &Context,
root: Node,
) -> Result<(Self, VarMap<usize>), Error> {
pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> {
let mut mapping = HashMap::new();
let mut parent_count: HashMap<Node, usize> = HashMap::new();
let mut slot_count = 0;
Expand All @@ -63,8 +60,7 @@ impl SsaTape {
}
_ => {
if let Op::Input(v) = op {
let next = vars.len();
vars.entry(*v).or_insert(next);
vars.insert(*v);
}
let i = slot_count;
slot_count += 1;
Expand Down
8 changes: 4 additions & 4 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1051,10 +1051,10 @@ impl Context {
TreeOp::Input(s) => {
let axes = axes.last().unwrap();
stack.push(match *s {
"X" => axes.0,
"Y" => axes.1,
"Z" => axes.2,
s => panic!("invalid tree input string {s:?}"),
Var::X => axes.0,
Var::Y => axes.1,
Var::Z => axes.2,
v @ Var::V(..) => self.var(v),
});
}
TreeOp::Unary(_op, arg) => {
Expand Down
17 changes: 12 additions & 5 deletions fidget/src/core/context/tree.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Context-free math trees
use super::op::{BinaryOpcode, UnaryOpcode};
use crate::var::Var;
use std::sync::Arc;

/// Opcode type for trees
Expand All @@ -9,8 +10,8 @@ use std::sync::Arc;
#[derive(Debug)]
#[allow(missing_docs)]
pub enum TreeOp {
/// Input (at the moment, limited to "X", "Y", "Z")
Input(&'static str),
/// Input (an arbitrary [`Var`])
Input(Var),
Const(f64),
Binary(BinaryOpcode, Arc<TreeOp>, Arc<TreeOp>),
Unary(UnaryOpcode, Arc<TreeOp>),
Expand Down Expand Up @@ -94,6 +95,12 @@ impl From<f32> for Tree {
}
}

impl From<Var> for Tree {
fn from(v: Var) -> Tree {
Tree(Arc::new(TreeOp::Input(v)))
}
}

/// Owned handle for a standalone math tree
#[derive(Clone, Debug)]
pub struct Tree(Arc<TreeOp>);
Expand Down Expand Up @@ -152,13 +159,13 @@ impl Tree {
#[allow(missing_docs)]
impl Tree {
pub fn x() -> Self {
Tree(Arc::new(TreeOp::Input("X")))
Tree(Arc::new(TreeOp::Input(Var::X)))
}
pub fn y() -> Self {
Tree(Arc::new(TreeOp::Input("Y")))
Tree(Arc::new(TreeOp::Input(Var::Y)))
}
pub fn z() -> Self {
Tree(Arc::new(TreeOp::Input("Z")))
Tree(Arc::new(TreeOp::Input(Var::Z)))
}
pub fn constant(f: f64) -> Self {
Tree(Arc::new(TreeOp::Const(f)))
Expand Down
8 changes: 4 additions & 4 deletions fidget/src/core/eval/bulk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ pub trait BulkEvaluator: Default {
///
/// Returns an error if the `x`, `y`, `z`, and `out` slices are of different
/// lengths.
fn eval(
fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
&mut self,
tape: &Self::Tape,
vars: &[&[Self::Data]],
vars: &[V],
) -> Result<&[Self::Data], Error>;

/// Build a new empty evaluator
Expand All @@ -44,9 +44,9 @@ pub trait BulkEvaluator: Default {
}

/// Helper function to return an error if the inputs are invalid
fn check_arguments(
fn check_arguments<V: std::ops::Deref<Target = [Self::Data]>>(
&self,
vars: &[&[Self::Data]],
vars: &[V],
var_count: usize,
) -> Result<(), Error> {
// It's fine if the caller has given us extra variables (e.g. due to
Expand Down
5 changes: 4 additions & 1 deletion fidget/src/core/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pub trait Tape {

/// Retrieves the internal storage from this tape
fn recycle(self) -> Self::Storage;

/// Returns a mapping from [`Var`](crate::var::Var) to evaluation index
fn vars(&self) -> &VarMap;
}

/// Represents the trace captured by a tracing evaluation
Expand Down Expand Up @@ -168,7 +171,7 @@ pub trait Function: Send + Sync + Clone {
fn size(&self) -> usize;

/// Returns a map from variable to index
fn vars(&self) -> &VarMap<usize>;
fn vars(&self) -> &VarMap;
}

/// A [`Function`] which can be built from a math expression
Expand Down
49 changes: 49 additions & 0 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use crate::{
context::Context,
eval::{Function, MathFunction},
shape::{EzShape, Shape},
var::{Var, VarIndex},
Error,
};
use std::collections::HashMap;

/// Helper struct to put constrains on our `Shape` object
pub struct TestFloatSlice<F>(std::marker::PhantomData<*const F>);
Expand Down Expand Up @@ -134,6 +137,51 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
);
}

pub fn test_f_var() {
let v = Var::new();
let mut ctx = Context::new();

let x = ctx.x();
let y = ctx.y();
let v_ = ctx.var(v);

let a = ctx.add(x, y).unwrap();
let a = ctx.add(a, v_).unwrap();

let s = Shape::<F>::new(&mut ctx, a).unwrap();

let mut eval = Shape::<F>::new_float_slice_eval();
let tape = s.ez_float_slice_tape();
assert!(eval
.eval(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0])
.is_err());
let mut h: HashMap<VarIndex, &[f32]> = HashMap::new();
assert!(eval
.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
.is_err());
let index = v.index().unwrap();
h.insert(index, &[4.0, 5.0]);
assert_eq!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h)
.unwrap(),
&[7.0, 10.0]
);
h.insert(index, &[4.0, 5.0, 6.0]);
assert!(matches!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
Err(Error::MismatchedSlices)
));

// Get a new var index that isn't valid for this tape
let v2 = Var::new();
h.insert(index, &[4.0, 5.0]);
h.insert(v2.index().unwrap(), &[4.0, 5.0]);
assert!(matches!(
eval.eval_v(&tape, &[1.0, 2.0], &[2.0, 3.0], &[0.0, 0.0], &h),
Err(Error::BadVarSlice(..))
));
}

pub fn test_f_stress_n(depth: usize) {
let (mut ctx, node) = build_stress_fn(depth);

Expand Down Expand Up @@ -395,6 +443,7 @@ macro_rules! float_slice_tests {
$crate::float_slice_test!(test_give_take, $t);
$crate::float_slice_test!(test_vectorized, $t);
$crate::float_slice_test!(test_f_sin, $t);
$crate::float_slice_test!(test_f_var, $t);
$crate::float_slice_test!(test_f_stress, $t);

mod f_unary {
Expand Down
30 changes: 30 additions & 0 deletions fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ use crate::{
context::Context,
eval::{Function, MathFunction},
shape::{EzShape, Shape},
var::Var,
vm::Choice,
};
use std::collections::HashMap;

/// Helper struct to put constrains on our `Shape` object
pub struct TestPoint<F>(std::marker::PhantomData<*const F>);
Expand Down Expand Up @@ -306,6 +308,33 @@ where
assert_eq!(eval.eval(&tape, 1.0, 2.0, 0.0).unwrap().0, 6.0);
}

pub fn test_p_var() {
let v = Var::new();
let mut ctx = Context::new();

let x = ctx.x();
let y = ctx.y();
let v_ = ctx.var(v);

let a = ctx.add(x, y).unwrap();
let a = ctx.add(a, v_).unwrap();

let s = Shape::<F>::new(&mut ctx, a).unwrap();

let mut eval = Shape::<F>::new_point_eval();
let tape = s.ez_point_tape();
assert!(eval.eval(&tape, 1.0, 2.0, 0.0).is_err());

let mut h = HashMap::new();
assert!(eval.eval_v(&tape, 1.0, 2.0, 0.0, &h).is_err());

let index = v.index().unwrap();
h.insert(index, 3.0);
assert_eq!(eval.eval_v(&tape, 1.0, 2.0, 0.0, &h).unwrap().0, 6.0);
h.insert(index, 4.0);
assert_eq!(eval.eval_v(&tape, 1.0, 2.0, 0.0, &h).unwrap().0, 7.0);
}

pub fn test_p_stress_n(depth: usize) {
let (mut ctx, node) = build_stress_fn(depth);

Expand Down Expand Up @@ -564,6 +593,7 @@ macro_rules! point_tests {
$crate::point_test!(basic_interpreter, $t);
$crate::point_test!(test_push, $t);
$crate::point_test!(test_basic, $t);
$crate::point_test!(test_p_var, $t);
$crate::point_test!(test_p_stress, $t);

mod p_unary {
Expand Down
Loading

0 comments on commit 184dc12

Please sign in to comment.