diff --git a/CHANGELOG.md b/CHANGELOG.md
index 189d26fd..4ed6b27b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,8 @@
   given a context-specific `Node`.
 - Fix possible corruption of `x24` during AArch64 float slice JIT evaluation,
   due to incorrect stack alignment.
+- Added `Context::deriv` and `Tree::deriv` to do symbolic differentiation of
+  math expressions.
 
 # 0.3.1
 The highlight of this release is the `fidget::solver` module, which implements
diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs
index 32fdefba..5c1c6884 100644
--- a/fidget/src/core/context/mod.rs
+++ b/fidget/src/core/context/mod.rs
@@ -703,6 +703,57 @@ impl Context {
         self.op_binary(a, b, BinaryOpcode::Compare)
     }
 
+    /// Builds a node that is 1 if `lhs < rhs` and 0 otherwise
+    ///
+    /// ```
+    /// # let mut ctx = fidget::context::Context::new();
+    /// let x = ctx.x();
+    /// let y = ctx.y();
+    /// let op = ctx.less_than(x, y).unwrap();
+    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 1.0);
+    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 0.0);
+    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 0.0);
+    /// ```
+    pub fn less_than<A: IntoNode, B: IntoNode>(
+        &mut self,
+        lhs: A,
+        rhs: B,
+    ) -> Result<Node, Error> {
+        let lhs = lhs.into_node(self)?;
+        let rhs = rhs.into_node(self)?;
+        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
+        self.max(cmp, 0.0)
+    }
+
+    /// Builds a node that is 1 if `lhs <= rhs` and 0 otherwise
+    ///
+    /// ```
+    /// # let mut ctx = fidget::context::Context::new();
+    /// let x = ctx.x();
+    /// let y = ctx.y();
+    /// let op = ctx.less_than_or_equal(x, y).unwrap();
+    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 1.0);
+    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 1.0);
+    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
+    /// assert_eq!(v, 0.0);
+    /// ```
+    pub fn less_than_or_equal<A: IntoNode, B: IntoNode>(
+        &mut self,
+        lhs: A,
+        rhs: B,
+    ) -> Result<Node, Error> {
+        let lhs = lhs.into_node(self)?;
+        let rhs = rhs.into_node(self)?;
+        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
+        let shift = self.add(cmp, 1.0)?;
+        self.min(shift, 1.0)
+    }
+
     /// Builds a node that takes the modulo (least non-negative remainder)
     pub fn modulo<A: IntoNode, B: IntoNode>(
         &mut self,
@@ -1227,6 +1278,228 @@ impl Context {
         assert_eq!(stack.len(), 1);
         Ok(stack.pop().unwrap())
     }
+
+    /// Takes the symbolic derivative of a node with respect to a variable
+    ///
+    /// Each node's derivative is cached, so shared subexpressions are only
+    /// differentiated once.
+    ///
+    /// # Errors
+    /// Returns [`Error::BadNode`] if `get_op(n)` finds no operation for `n`.
+    pub fn deriv(&mut self, n: Node, v: Var) -> Result<Node, Error> {
+        if self.get_op(n).is_none() {
+            return Err(Error::BadNode);
+        }
+
+        // Do recursion on the heap to avoid stack overflows for deep trees
+        enum Action {
+            /// Pushes `Up(n)` followed by `Down(n)` for each child
+            Down(Node),
+            /// Consumes trees from the stack and pushes a new tree
+            Up(Node, Op),
+        }
+        let mut todo = vec![Action::Down(n)];
+        let mut stack = vec![];
+        let zero = self.constant(0.0);
+
+        // Cache of Node -> Node mapping, for deduplication
+        let mut seen: HashMap<Node, Node> = HashMap::new();
+
+        while let Some(t) = todo.pop() {
+            match t {
+                Action::Down(n) => {
+                    // If we've already differentiated this node, then
+                    // we can reuse the cached derivative.
+                    if let Some(p) = seen.get(&n) {
+                        stack.push(*p);
+                        continue;
+                    }
+                    let op = *self.get_op(n).unwrap();
+                    match op {
+                        Op::Const(_c) => {
+                            seen.insert(n, zero);
+                            stack.push(zero);
+                        }
+                        Op::Input(u) => {
+                            let z =
+                                if v == u { self.constant(1.0) } else { zero };
+                            seen.insert(n, z);
+                            stack.push(z);
+                        }
+                        Op::Unary(_op, arg) => {
+                            todo.push(Action::Up(n, op));
+                            todo.push(Action::Down(arg));
+                        }
+                        Op::Binary(_op, lhs, rhs) => {
+                            todo.push(Action::Up(n, op));
+                            todo.push(Action::Down(lhs));
+                            todo.push(Action::Down(rhs));
+                        }
+                    }
+                }
+                Action::Up(n, op) => match op {
+                    Op::Const(..) | Op::Input(..) => unreachable!(),
+                    Op::Unary(op, v_arg) => {
+                        let d_arg = stack.pop().unwrap();
+                        let out = match op {
+                            UnaryOpcode::Neg => self.neg(d_arg),
+                            UnaryOpcode::Abs => {
+                                let cond = self.less_than(v_arg, zero).unwrap();
+                                let pos = d_arg;
+                                let neg = self.neg(d_arg).unwrap();
+                                self.if_nonzero_else(cond, neg, pos)
+                            }
+                            UnaryOpcode::Recip => {
+                                let a = self.square(v_arg).unwrap();
+                                let b = self.neg(d_arg).unwrap();
+                                self.div(b, a)
+                            }
+                            UnaryOpcode::Sqrt => {
+                                let v = self.mul(n, 2.0).unwrap();
+                                self.div(d_arg, v)
+                            }
+                            UnaryOpcode::Square => {
+                                let v = self.mul(d_arg, v_arg).unwrap();
+                                self.mul(2.0, v)
+                            }
+                            // Discontinuous constants don't have Dirac deltas
+                            UnaryOpcode::Floor
+                            | UnaryOpcode::Ceil
+                            | UnaryOpcode::Round => Ok(zero),
+
+                            UnaryOpcode::Sin => {
+                                let c = self.cos(v_arg).unwrap();
+                                self.mul(c, d_arg)
+                            }
+
+                            UnaryOpcode::Cos => {
+                                let s = self.sin(v_arg).unwrap();
+                                let s = self.neg(s).unwrap();
+                                self.mul(s, d_arg)
+                            }
+
+                            UnaryOpcode::Tan => {
+                                let c = self.cos(v_arg).unwrap();
+                                let c = self.square(c).unwrap();
+                                self.div(d_arg, c)
+                            }
+
+                            UnaryOpcode::Asin => {
+                                let v = self.square(v_arg).unwrap();
+                                let v = self.sub(1.0, v).unwrap();
+                                let v = self.sqrt(v).unwrap();
+                                self.div(d_arg, v)
+                            }
+                            UnaryOpcode::Acos => {
+                                let v = self.square(v_arg).unwrap();
+                                let v = self.sub(1.0, v).unwrap();
+                                let v = self.sqrt(v).unwrap();
+                                let v = self.neg(v).unwrap();
+                                self.div(d_arg, v)
+                            }
+                            UnaryOpcode::Atan => {
+                                let v = self.square(v_arg).unwrap();
+                                let v = self.add(1.0, v).unwrap();
+                                self.div(d_arg, v)
+                            }
+                            UnaryOpcode::Exp => self.mul(n, d_arg),
+                            UnaryOpcode::Ln => self.div(d_arg, v_arg),
+                            UnaryOpcode::Not => Ok(zero),
+                        }
+                        .unwrap();
+                        seen.insert(n, out);
+                        stack.push(out);
+                    }
+                    Op::Binary(op, v_lhs, v_rhs) => {
+                        // rhs was visited first, so d_lhs is on top
+                        let d_lhs = stack.pop().unwrap();
+                        let d_rhs = stack.pop().unwrap();
+                        let out = match op {
+                            BinaryOpcode::Add => self.add(d_lhs, d_rhs),
+                            BinaryOpcode::Sub => self.sub(d_lhs, d_rhs),
+                            BinaryOpcode::Mul => {
+                                let a = self.mul(d_lhs, v_rhs).unwrap();
+                                let b = self.mul(v_lhs, d_rhs).unwrap();
+                                self.add(a, b)
+                            }
+                            BinaryOpcode::Div => {
+                                let v = self.square(v_rhs).unwrap();
+                                let a = self.mul(v_rhs, d_lhs).unwrap();
+                                let b = self.mul(v_lhs, d_rhs).unwrap();
+                                let c = self.sub(a, b).unwrap();
+                                self.div(c, v)
+                            }
+                            BinaryOpcode::Atan => {
+                                let a = self.square(v_lhs).unwrap();
+                                let b = self.square(v_rhs).unwrap();
+                                let d = self.add(a, b).unwrap();
+
+                                let a = self.mul(v_rhs, d_lhs).unwrap();
+                                let b = self.mul(v_lhs, d_rhs).unwrap();
+                                let v = self.sub(a, b).unwrap();
+                                self.div(v, d)
+                            }
+                            BinaryOpcode::Min => {
+                                let cond =
+                                    self.less_than(v_lhs, v_rhs).unwrap();
+                                self.if_nonzero_else(cond, d_lhs, d_rhs)
+                            }
+                            BinaryOpcode::Max => {
+                                let cond =
+                                    self.less_than(v_rhs, v_lhs).unwrap();
+                                self.if_nonzero_else(cond, d_lhs, d_rhs)
+                            }
+                            BinaryOpcode::Compare => Ok(zero),
+                            BinaryOpcode::Mod => {
+                                let e = self.div(v_lhs, v_rhs).unwrap();
+                                let q = self.floor(e).unwrap();
+
+                                // NOTE(review): approximates `div_euclid`;
+                                // (we don't actually have %, so hack it from
+                                // `modulo`, which is actually `rem_euclid`)
+                                // TODO confirm `q` (not `v_lhs`) is intended
+                                let m = self.modulo(q, v_rhs).unwrap();
+                                let cond = self.less_than(q, zero).unwrap();
+                                let offset = self
+                                    .if_nonzero_else(cond, v_rhs, zero)
+                                    .unwrap();
+                                let m = self.sub(m, offset).unwrap();
+
+                                // Torn from the div_euclid implementation
+                                let outer = self.less_than(m, zero).unwrap();
+                                let inner =
+                                    self.less_than(zero, v_rhs).unwrap();
+                                let qa = self.sub(q, 1.0).unwrap();
+                                let qb = self.add(q, 1.0).unwrap();
+                                let inner = self
+                                    .if_nonzero_else(inner, qa, qb)
+                                    .unwrap();
+                                let e = self
+                                    .if_nonzero_else(outer, inner, q)
+                                    .unwrap();
+
+                                let v = self.mul(d_rhs, e).unwrap();
+                                self.sub(d_lhs, v)
+                            }
+                            BinaryOpcode::And => {
+                                let cond = self.compare(v_lhs, zero).unwrap();
+                                self.if_nonzero_else(cond, d_rhs, d_lhs)
+                            }
+                            BinaryOpcode::Or => {
+                                let cond = self.compare(v_lhs, zero).unwrap();
+                                self.if_nonzero_else(cond, d_lhs, d_rhs)
+                            }
+                        }
+                        .unwrap();
+                        seen.insert(n, out);
+                        stack.push(out);
+                    }
+                },
+            }
+        }
+        assert_eq!(stack.len(), 1);
+        Ok(stack.pop().unwrap())
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/fidget/src/core/context/tree.rs b/fidget/src/core/context/tree.rs
index 8be0fc72..e64a6cc6 100644
--- a/fidget/src/core/context/tree.rs
+++ b/fidget/src/core/context/tree.rs
@@ -174,6 +174,16 @@ impl Tree {
             None
         }
     }
+
+    /// Performs symbolic differentiation with respect to the given variable
+    ///
+    /// The tree is imported into a temporary [`Context`](crate::Context),
+    /// differentiated there, and exported back as a new `Tree`.
+    pub fn deriv(&self, v: Var) -> Tree {
+        let mut ctx = crate::Context::new();
+        let node = ctx.import(self);
+        ctx.deriv(node, v).and_then(|d| ctx.export(d)).unwrap()
+    }
 }
 
 impl TryFrom<Tree> for Var {
@@ -496,4 +506,23 @@ mod test {
         let root = ctx.import(&b);
         assert_eq!(ctx.get_const(root).unwrap(), 15.0);
     }
+
+    #[test]
+    fn tree_deriv() {
+        // dx/dx = 1
+        let x = Tree::x();
+        let vx = x.var().unwrap();
+        let d = x.deriv(vx);
+        let TreeOp::Const(v) = *d else {
+            panic!("invalid deriv {d:?}")
+        };
+        assert_eq!(v, 1.0);
+
+        // dx/dv = 0
+        let d = x.deriv(Var::new());
+        let TreeOp::Const(v) = *d else {
+            panic!("invalid deriv {d:?}")
+        };
+        assert_eq!(v, 0.0);
+    }
 }
diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs
index ee4a8b42..d868d23d 100644
--- a/fidget/src/core/eval/test/mod.rs
+++ b/fidget/src/core/eval/test/mod.rs
@@ -4,6 +4,9 @@ pub mod grad_slice;
 pub mod interval;
 pub mod point;
 
+// Internal-only tests
+mod symbolic_deriv;
+
 use crate::{
     context::{Context, IntoNode, Node},
     eval::Tape,
diff --git a/fidget/src/core/eval/test/symbolic_deriv.rs b/fidget/src/core/eval/test/symbolic_deriv.rs
new file mode 100644
index 00000000..3a652414
--- /dev/null
+++ b/fidget/src/core/eval/test/symbolic_deriv.rs
@@ -0,0 +1,165 @@
+use super::{test_args, CanonicalBinaryOp, CanonicalUnaryOp};
+use crate::{
+    context::Context,
+    eval::{BulkEvaluator, Function, MathFunction, Tape},
+    types::Grad,
+    var::Var,
+    vm::VmFunction,
+};
+
+/// Helper struct to test symbolic differentiation
+pub struct TestSymbolicDerivs;
+
+impl TestSymbolicDerivs {
+    /// Checks `d/dv` of a unary op: symbolic `Context::deriv` evaluated
+    /// with the float-slice evaluator vs. `dx` from the grad-slice evaluator
+    pub fn test_unary<C: CanonicalUnaryOp>() {
+        let args = test_args();
+
+        let mut ctx = Context::new();
+        let v = ctx.var(Var::new());
+        let node = C::build(&mut ctx, v);
+        let shape = VmFunction::new(&ctx, node).unwrap();
+        let tape = shape.grad_slice_tape(Default::default());
+        let mut eval = VmFunction::new_grad_slice_eval();
+
+        let node_deriv = ctx.deriv(node, ctx.get_var(v).unwrap()).unwrap();
+        let shape_deriv = VmFunction::new(&ctx, node_deriv).unwrap();
+        let tape_deriv = shape_deriv.float_slice_tape(Default::default());
+        let mut eval_deriv = VmFunction::new_float_slice_eval();
+
+        let args_g = args
+            .iter()
+            .map(|&v| Grad::new(v, 1.0, 0.0, 0.0))
+            .collect::<Vec<_>>();
+        let out = eval.eval(&tape, &[args_g.as_slice()]).unwrap();
+
+        // Check symbolic differentiation results
+        let out_deriv =
+            eval_deriv.eval(&tape_deriv, &[args.as_slice()]).unwrap();
+        for (v, (a, b)) in args.iter().zip(out.iter().zip(out_deriv)) {
+            let a = a.dx;
+            let err = a - b;
+            let err_frac = err / a.abs().max(b.abs());
+            // Exact, absolute, or relative agreement (or NaNs)
+            assert!(
+                a == *b
+                    || err.abs() < 1e-6
+                    || err_frac.abs() < 1e-6
+                    || (a.is_nan() && b.is_nan())
+                    || v.is_nan(),
+                "mismatch in '{}' at {v}: {a} != {b} ({err})",
+                C::NAME,
+            );
+        }
+    }
+
+    /// Checks both partial derivatives of a binary op against the `dx` / `dy`
+    /// values from the grad-slice evaluator, over every rotation of the args
+    pub fn test_binary<C: CanonicalBinaryOp>() {
+        let args = test_args();
+
+        let mut ctx = Context::new();
+        let va = Var::new();
+        let vb = Var::new();
+        let a = ctx.var(va);
+        let b = ctx.var(vb);
+
+        let mut eval = VmFunction::new_grad_slice_eval();
+        let mut eval_deriv = VmFunction::new_float_slice_eval();
+
+        let node = C::build(&mut ctx, a, b);
+        let shape = VmFunction::new(&ctx, node).unwrap();
+        let tape = shape.grad_slice_tape(Default::default());
+
+        let node_a_deriv = ctx.deriv(node, va).unwrap();
+        let shape_a_deriv = VmFunction::new(&ctx, node_a_deriv).unwrap();
+        let tape_a_deriv = shape_a_deriv.float_slice_tape(Default::default());
+
+        let node_b_deriv = ctx.deriv(node, vb).unwrap();
+        let shape_b_deriv = VmFunction::new(&ctx, node_b_deriv).unwrap();
+        let tape_b_deriv = shape_b_deriv.float_slice_tape(Default::default());
+
+        for rot in 0..args.len() {
+            let mut rgsa = args.clone();
+            rgsa.rotate_left(rot);
+
+            let args_g = args
+                .iter()
+                .map(|v| Grad::new(*v, 1.0, 0.0, 0.0))
+                .collect::<Vec<_>>();
+            let rgsa_g = rgsa
+                .iter()
+                .map(|v| Grad::new(*v, 0.0, 1.0, 0.0))
+                .collect::<Vec<_>>();
+
+            let ia = shape.vars().get(&va).unwrap();
+            let ib = shape.vars().get(&vb).unwrap();
+            let mut vs = [[].as_slice(), [].as_slice()];
+            vs[ia] = args_g.as_slice();
+            vs[ib] = rgsa_g.as_slice();
+            let out = eval.eval(&tape, &vs).unwrap();
+
+            // Check symbolic differentiation results
+            let mut vs = [args.as_slice(), args.as_slice()];
+            if let Some(ia) = shape_a_deriv.vars().get(&va) {
+                vs[ia] = args.as_slice();
+            }
+            if let Some(ib) = shape_a_deriv.vars().get(&vb) {
+                vs[ib] = rgsa.as_slice();
+            }
+            let out_a_deriv =
+                eval_deriv.eval(&tape_a_deriv, &vs).unwrap().to_vec();
+
+            let mut vs = [args.as_slice(), args.as_slice()];
+            if let Some(ia) = shape_b_deriv.vars().get(&va) {
+                vs[ia] = args.as_slice();
+            }
+            if let Some(ib) = shape_b_deriv.vars().get(&vb) {
+                vs[ib] = rgsa.as_slice();
+            }
+            let out_b_deriv = eval_deriv.eval(&tape_b_deriv, &vs).unwrap();
+
+            for i in 0..out.len() {
+                let v = out[i];
+                let da = out_a_deriv[i];
+
+                let a = args[i];
+                let b = rgsa[i];
+
+                let err = v.dx - da;
+                let err_frac = err / da.abs().max(v.dx.abs());
+                // Exact, absolute, or relative agreement (or NaNs)
+                assert!(
+                    v.dx == da
+                        || err.abs() < 1e-6
+                        || err_frac.abs() < 1e-6
+                        || (v.dx.is_nan() && da.is_nan())
+                        || v.v.is_nan(),
+                    "mismatch in 'd {}(a, b) / da' at ({a}, {b}): \
+                     {} != {da} ({err})",
+                    C::NAME,
+                    v.dx
+                );
+
+                let db = out_b_deriv[i];
+                let err = v.dy - db;
+                let err_frac = err / db.abs().max(v.dy.abs());
+                assert!(
+                    v.dy == db
+                        || err.abs() < 1e-6
+                        || err_frac.abs() < 1e-6
+                        || (v.dy.is_nan() && db.is_nan())
+                        || v.v.is_nan(),
+                    "mismatch in 'd {}(a, b) / db' at ({a}, {b}): \
+                     {} != {db} ({err})",
+                    C::NAME,
+                    v.dy
+                );
+            }
+        }
+    }
+}
+
+crate::all_unary_tests!(TestSymbolicDerivs);
+crate::all_binary_tests!(TestSymbolicDerivs);