diff --git a/CHANGELOG.md b/CHANGELOG.md index ba1ab2df..fab0e923 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # 0.3.1 (unreleased) - Fixed a bug in the x86 JIT which could corrupt registers during gradient (`grad_slice`) evaluation +- Renamed `Context::const_value` to `Context::get_const` and tweaked its return + type to match `Context::get_var`. +- Added `impl From for Tree` to make writing tree expressions easier # 0.3.0 - Major refactoring of core evaluation traits diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 921fc291..e9b3d2cf 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -228,7 +228,7 @@ impl SsaTape { // Special case if the Node is a single constant, which isn't usually // recorded in the tape if tape.is_empty() { - let c = ctx.const_value(root).unwrap().unwrap() as f32; + let c = ctx.get_const(root).unwrap() as f32; tape.push(SsaOp::CopyImm(0, c)); } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 4643d0fb..09eda962 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -114,18 +114,18 @@ impl Context { /// /// If the node is invalid for this tree, returns an error; if the node is /// not a constant, returns `Ok(None)`. - pub fn const_value(&self, n: Node) -> Result, Error> { + pub fn get_const(&self, n: Node) -> Result { match self.get_op(n) { - Some(Op::Const(c)) => Ok(Some(c.0)), - Some(_) => Ok(None), + Some(Op::Const(c)) => Ok(c.0), + Some(_) => Err(Error::NotAConst), _ => Err(Error::BadNode), } } /// Looks up the [`Var`] associated with the given node. /// - /// If the node is invalid for this tree, returns an error; if the node is - /// not an `Op::Input`, returns `Ok(None)`. + /// If the node is invalid for this tree or not an `Op::Input`, returns an + /// error. pub fn get_var(&self, n: Node) -> Result { match self.get_op(n) { Some(Op::Input(v)) => Ok(*v), @@ -274,9 +274,9 @@ impl Context { let two = self.constant(2.0); self.mul(a, two) } else { - match (self.const_value(a)?, self.const_value(b)?) { - (Some(zero), _) if zero == 0.0 => Ok(b), - (_, Some(zero)) if zero == 0.0 => Ok(a), + match (self.get_const(a), self.get_const(b)) { + (Ok(zero), _) if zero == 0.0 => Ok(b), + (_, Ok(zero)) if zero == 0.0 => Ok(a), _ => self.op_binary_commutative(a, b, BinaryOpcode::Add), } } @@ -300,11 +300,11 @@ impl Context { if a == b { self.square(a) } else { - match (self.const_value(a)?, self.const_value(b)?) { - (Some(one), _) if one == 1.0 => Ok(b), - (_, Some(one)) if one == 1.0 => Ok(a), - (Some(zero), _) if zero == 0.0 => Ok(a), - (_, Some(zero)) if zero == 0.0 => Ok(b), + match (self.get_const(a), self.get_const(b)) { + (Ok(one), _) if one == 1.0 => Ok(b), + (_, Ok(one)) if one == 1.0 => Ok(a), + (Ok(zero), _) if zero == 0.0 => Ok(a), + (_, Ok(zero)) if zero == 0.0 => Ok(b), _ => self.op_binary_commutative(a, b, BinaryOpcode::Mul), } } @@ -627,9 +627,9 @@ impl Context { let a = a.into_node(self)?; let b = b.into_node(self)?; - match (self.const_value(a)?, self.const_value(b)?) { - (Some(zero), _) if zero == 0.0 => self.neg(b), - (_, Some(zero)) if zero == 0.0 => Ok(a), + match (self.get_const(a), self.get_const(b)) { + (Ok(zero), _) if zero == 0.0 => self.neg(b), + (_, Ok(zero)) if zero == 0.0 => Ok(a), _ => self.op_binary(a, b, BinaryOpcode::Sub), } } @@ -651,9 +651,9 @@ impl Context { let a = a.into_node(self)?; let b = b.into_node(self)?; - match (self.const_value(a)?, self.const_value(b)?) { - (Some(zero), _) if zero == 0.0 => Ok(a), - (_, Some(one)) if one == 1.0 => Ok(a), + match (self.get_const(a), self.get_const(b)) { + (Ok(zero), _) if zero == 0.0 => Ok(a), + (_, Ok(one)) if one == 1.0 => Ok(a), _ => self.op_binary(a, b, BinaryOpcode::Div), } } diff --git a/fidget/src/core/context/tree.rs b/fidget/src/core/context/tree.rs index c3dbb239..c46d8faa 100644 --- a/fidget/src/core/context/tree.rs +++ b/fidget/src/core/context/tree.rs @@ -95,6 +95,12 @@ impl From for Tree { } } +impl From for Tree { + fn from(v: i32) -> Tree { + Tree::constant(v as f64) + } +} + impl From for Tree { fn from(v: Var) -> Tree { Tree(Arc::new(TreeOp::Input(v))) @@ -458,4 +464,14 @@ mod test { {large:?} is not much larger than {small:?}" ); } + + #[test] + fn tree_from_int() { + let a = Tree::from(3); + let b = a * 5; + + let mut ctx = Context::new(); + let root = ctx.import(&b); + assert_eq!(ctx.get_const(root).unwrap(), 15.0); + } } diff --git a/fidget/src/core/mod.rs b/fidget/src/core/mod.rs index 2d66e3c3..a325db7a 100644 --- a/fidget/src/core/mod.rs +++ b/fidget/src/core/mod.rs @@ -67,6 +67,7 @@ pub mod vm; mod test { use crate::context::*; use crate::var::Var; + use crate::Error; #[test] fn it_works() { @@ -78,14 +79,14 @@ mod test { let a = ctx.constant(1.0); let b = ctx.constant(1.0); assert_eq!(a, b); - assert_eq!(ctx.const_value(a).unwrap(), Some(1.0)); - assert_eq!(ctx.const_value(x1).unwrap(), None); + assert_eq!(ctx.get_const(a).unwrap(), 1.0); + assert!(matches!(ctx.get_const(x1), Err(Error::NotAConst))); let c = ctx.add(a, b).unwrap(); - assert_eq!(ctx.const_value(c).unwrap(), Some(2.0)); + assert_eq!(ctx.get_const(c).unwrap(), 2.0); let c = ctx.neg(c).unwrap(); - assert_eq!(ctx.const_value(c).unwrap(), Some(-2.0)); + assert_eq!(ctx.get_const(c).unwrap(), -2.0); } #[test] diff --git a/fidget/src/error.rs b/fidget/src/error.rs index 92c4cc93..c27116b5 100644 --- a/fidget/src/error.rs +++ b/fidget/src/error.rs @@ -21,6 +21,10 @@ pub enum Error { #[error("node does not have an associated variable")] NotAVar, + /// The given node is not a constant + #[error("node is not a constant")] + NotAConst, + /// `Context` is empty #[error("`Context` is empty")] EmptyContext,