Skip to content

Commit

Permalink
Tweak constant handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed May 31, 2024
1 parent bffe0c2 commit bf4465a
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# 0.3.1 (unreleased)
- Fixed a bug in the x86 JIT which could corrupt registers during gradient
(`grad_slice`) evaluation
- Renamed `Context::const_value` to `Context::get_const` and tweaked its return
type to match `Context::get_var`.
- Added `impl From<i32> for Tree` to make writing tree expressions easier

# 0.3.0
- Major refactoring of core evaluation traits
Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl SsaTape {
// Special case if the Node is a single constant, which isn't usually
// recorded in the tape
if tape.is_empty() {
let c = ctx.const_value(root).unwrap().unwrap() as f32;
let c = ctx.get_const(root).unwrap() as f32;
tape.push(SsaOp::CopyImm(0, c));
}

Expand Down
38 changes: 19 additions & 19 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,18 @@ impl Context {
///
/// If the node is invalid for this tree, returns an error; if the node is
/// not a constant, returns `Ok(None)`.
pub fn const_value(&self, n: Node) -> Result<Option<f64>, Error> {
pub fn get_const(&self, n: Node) -> Result<f64, Error> {
match self.get_op(n) {
Some(Op::Const(c)) => Ok(Some(c.0)),
Some(_) => Ok(None),
Some(Op::Const(c)) => Ok(c.0),
Some(_) => Err(Error::NotAConst),
_ => Err(Error::BadNode),
}
}

/// Looks up the [`Var`] associated with the given node.
///
/// If the node is invalid for this tree, returns an error; if the node is
/// not an `Op::Input`, returns `Ok(None)`.
/// If the node is invalid for this tree or not an `Op::Input`, returns an
/// error.
pub fn get_var(&self, n: Node) -> Result<Var, Error> {
match self.get_op(n) {
Some(Op::Input(v)) => Ok(*v),
Expand Down Expand Up @@ -274,9 +274,9 @@ impl Context {
let two = self.constant(2.0);
self.mul(a, two)
} else {
match (self.const_value(a)?, self.const_value(b)?) {
(Some(zero), _) if zero == 0.0 => Ok(b),
(_, Some(zero)) if zero == 0.0 => Ok(a),
match (self.get_const(a), self.get_const(b)) {
(Ok(zero), _) if zero == 0.0 => Ok(b),
(_, Ok(zero)) if zero == 0.0 => Ok(a),
_ => self.op_binary_commutative(a, b, BinaryOpcode::Add),
}
}
Expand All @@ -300,11 +300,11 @@ impl Context {
if a == b {
self.square(a)
} else {
match (self.const_value(a)?, self.const_value(b)?) {
(Some(one), _) if one == 1.0 => Ok(b),
(_, Some(one)) if one == 1.0 => Ok(a),
(Some(zero), _) if zero == 0.0 => Ok(a),
(_, Some(zero)) if zero == 0.0 => Ok(b),
match (self.get_const(a), self.get_const(b)) {
(Ok(one), _) if one == 1.0 => Ok(b),
(_, Ok(one)) if one == 1.0 => Ok(a),
(Ok(zero), _) if zero == 0.0 => Ok(a),
(_, Ok(zero)) if zero == 0.0 => Ok(b),
_ => self.op_binary_commutative(a, b, BinaryOpcode::Mul),
}
}
Expand Down Expand Up @@ -627,9 +627,9 @@ impl Context {
let a = a.into_node(self)?;
let b = b.into_node(self)?;

match (self.const_value(a)?, self.const_value(b)?) {
(Some(zero), _) if zero == 0.0 => self.neg(b),
(_, Some(zero)) if zero == 0.0 => Ok(a),
match (self.get_const(a), self.get_const(b)) {
(Ok(zero), _) if zero == 0.0 => self.neg(b),
(_, Ok(zero)) if zero == 0.0 => Ok(a),
_ => self.op_binary(a, b, BinaryOpcode::Sub),
}
}
Expand All @@ -651,9 +651,9 @@ impl Context {
let a = a.into_node(self)?;
let b = b.into_node(self)?;

match (self.const_value(a)?, self.const_value(b)?) {
(Some(zero), _) if zero == 0.0 => Ok(a),
(_, Some(one)) if one == 1.0 => Ok(a),
match (self.get_const(a), self.get_const(b)) {
(Ok(zero), _) if zero == 0.0 => Ok(a),
(_, Ok(one)) if one == 1.0 => Ok(a),
_ => self.op_binary(a, b, BinaryOpcode::Div),
}
}
Expand Down
16 changes: 16 additions & 0 deletions fidget/src/core/context/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ impl From<f32> for Tree {
}
}

impl From<i32> for Tree {
fn from(v: i32) -> Tree {
Tree::constant(v as f64)
}
}

impl From<Var> for Tree {
fn from(v: Var) -> Tree {
Tree(Arc::new(TreeOp::Input(v)))
Expand Down Expand Up @@ -458,4 +464,14 @@ mod test {
{large:?} is not much larger than {small:?}"
);
}

#[test]
fn tree_from_int() {
let a = Tree::from(3);
let b = a * 5;

let mut ctx = Context::new();
let root = ctx.import(&b);
assert_eq!(ctx.get_const(root).unwrap(), 15.0);
}
}
9 changes: 5 additions & 4 deletions fidget/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub mod vm;
mod test {
use crate::context::*;
use crate::var::Var;
use crate::Error;

#[test]
fn it_works() {
Expand All @@ -78,14 +79,14 @@ mod test {
let a = ctx.constant(1.0);
let b = ctx.constant(1.0);
assert_eq!(a, b);
assert_eq!(ctx.const_value(a).unwrap(), Some(1.0));
assert_eq!(ctx.const_value(x1).unwrap(), None);
assert_eq!(ctx.get_const(a).unwrap(), 1.0);
assert!(matches!(ctx.get_const(x1), Err(Error::NotAConst)));

let c = ctx.add(a, b).unwrap();
assert_eq!(ctx.const_value(c).unwrap(), Some(2.0));
assert_eq!(ctx.get_const(c).unwrap(), 2.0);

let c = ctx.neg(c).unwrap();
assert_eq!(ctx.const_value(c).unwrap(), Some(-2.0));
assert_eq!(ctx.get_const(c).unwrap(), -2.0);
}

#[test]
Expand Down
4 changes: 4 additions & 0 deletions fidget/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ pub enum Error {
#[error("node does not have an associated variable")]
NotAVar,

/// The given node is not a constant
#[error("node is not a constant")]
NotAConst,

/// `Context` is empty
#[error("`Context` is empty")]
EmptyContext,
Expand Down

0 comments on commit bf4465a

Please sign in to comment.