Skip to content

Commit

Permalink
Add symbolic differentiation (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter authored Jul 19, 2024
1 parent a08d3b7 commit c2697ac
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
given a context-specific `Node`.
- Fix possible corruption of `x24` during AArch64 float slice JIT evaluation,
due to incorrect stack alignment.
- Added `Context::deriv` and `Tree::deriv` to do symbolic differentiation of
math expressions.

# 0.3.1
The highlight of this release is the `fidget::solver` module, which implements
Expand Down
266 changes: 266 additions & 0 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,57 @@ impl Context {
self.op_binary(a, b, BinaryOpcode::Compare)
}

/// Builds a node that is 1 if `lhs < rhs` and 0 otherwise
///
/// ```
/// # let mut ctx = fidget::context::Context::new();
/// let x = ctx.x();
/// let y = ctx.y();
/// let op = ctx.less_than(x, y).unwrap();
/// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 1.0);
/// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 0.0);
/// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 0.0);
/// ```
pub fn less_than<A: IntoNode, B: IntoNode>(
&mut self,
lhs: A,
rhs: B,
) -> Result<Node, Error> {
let lhs = lhs.into_node(self)?;
let rhs = rhs.into_node(self)?;
let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
self.max(cmp, 0.0)
}

/// Builds a node that is 1 if `lhs <= rhs` and 0 otherwise
///
/// ```
/// # let mut ctx = fidget::context::Context::new();
/// let x = ctx.x();
/// let y = ctx.y();
/// let op = ctx.less_than_or_equal(x, y).unwrap();
/// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 1.0);
/// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 1.0);
/// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
/// assert_eq!(v, 0.0);
/// ```
pub fn less_than_or_equal<A: IntoNode, B: IntoNode>(
&mut self,
lhs: A,
rhs: B,
) -> Result<Node, Error> {
let lhs = lhs.into_node(self)?;
let rhs = rhs.into_node(self)?;
let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
let shift = self.add(cmp, 1.0)?;
self.min(shift, 1.0)
}

/// Builds a node that takes the modulo (least non-negative remainder)
pub fn modulo<A: IntoNode, B: IntoNode>(
&mut self,
Expand Down Expand Up @@ -1227,6 +1278,221 @@ impl Context {
assert_eq!(stack.len(), 1);
Ok(stack.pop().unwrap())
}

/// Takes the symbolic derivative of a node with respect to a variable
pub fn deriv(&mut self, n: Node, v: Var) -> Result<Node, Error> {
if self.get_op(n).is_none() {
return Err(Error::BadNode);
}

// Do recursion on the heap to avoid stack overflows for deep trees
enum Action {
/// Pushes `Up(n)` followed by `Down(n)` for each child
Down(Node),
/// Consumes trees from the stack and pushes a new tree
Up(Node, Op),
}
let mut todo = vec![Action::Down(n)];
let mut stack = vec![];
let zero = self.constant(0.0);

// Cache of Node -> Node mapping, for deduplication
let mut seen: HashMap<Node, Node> = HashMap::new();

while let Some(t) = todo.pop() {
match t {
Action::Down(n) => {
// If we've already seen this TreeOp with these axes, then
// we can return the previous Node.
if let Some(p) = seen.get(&n) {
stack.push(*p);
continue;
}
let op = *self.get_op(n).unwrap();
match op {
Op::Const(_c) => {
seen.insert(n, zero);
stack.push(zero);
}
Op::Input(u) => {
let z =
if v == u { self.constant(1.0) } else { zero };
seen.insert(n, z);
stack.push(z);
}
Op::Unary(_op, arg) => {
todo.push(Action::Up(n, op));
todo.push(Action::Down(arg));
}
Op::Binary(_op, lhs, rhs) => {
todo.push(Action::Up(n, op));
todo.push(Action::Down(lhs));
todo.push(Action::Down(rhs));
}
}
}
Action::Up(n, op) => match op {
Op::Const(..) | Op::Input(..) => unreachable!(),
Op::Unary(op, v_arg) => {
let d_arg = stack.pop().unwrap();
let out = match op {
UnaryOpcode::Neg => self.neg(d_arg),
UnaryOpcode::Abs => {
let cond = self.less_than(v_arg, zero).unwrap();
let pos = d_arg;
let neg = self.neg(d_arg).unwrap();
self.if_nonzero_else(cond, neg, pos)
}
UnaryOpcode::Recip => {
let a = self.square(v_arg).unwrap();
let b = self.neg(d_arg).unwrap();
self.div(b, a)
}
UnaryOpcode::Sqrt => {
let v = self.mul(n, 2.0).unwrap();
self.div(d_arg, v)
}
UnaryOpcode::Square => {
let v = self.mul(d_arg, v_arg).unwrap();
self.mul(2.0, v)
}
// Discontinuous constants don't have Dirac deltas
UnaryOpcode::Floor
| UnaryOpcode::Ceil
| UnaryOpcode::Round => Ok(zero),

UnaryOpcode::Sin => {
let c = self.cos(v_arg).unwrap();
self.mul(c, d_arg)
}

UnaryOpcode::Cos => {
let s = self.sin(v_arg).unwrap();
let s = self.neg(s).unwrap();
self.mul(s, d_arg)
}

UnaryOpcode::Tan => {
let c = self.cos(v_arg).unwrap();
let c = self.square(c).unwrap();
self.div(d_arg, c)
}

UnaryOpcode::Asin => {
let v = self.square(v_arg).unwrap();
let v = self.sub(1.0, v).unwrap();
let v = self.sqrt(v).unwrap();
self.div(d_arg, v)
}
UnaryOpcode::Acos => {
let v = self.square(v_arg).unwrap();
let v = self.sub(1.0, v).unwrap();
let v = self.sqrt(v).unwrap();
let v = self.neg(v).unwrap();
self.div(d_arg, v)
}
UnaryOpcode::Atan => {
let v = self.square(v_arg).unwrap();
let v = self.add(1.0, v).unwrap();
self.div(d_arg, v)
}
UnaryOpcode::Exp => self.mul(n, d_arg),
UnaryOpcode::Ln => self.div(d_arg, v_arg),
UnaryOpcode::Not => Ok(zero),
}
.unwrap();
seen.insert(n, out);
stack.push(out);
}
Op::Binary(op, v_lhs, v_rhs) => {
let d_lhs = stack.pop().unwrap();
let d_rhs = stack.pop().unwrap();
let out = match op {
BinaryOpcode::Add => self.add(d_lhs, d_rhs),
BinaryOpcode::Sub => self.sub(d_lhs, d_rhs),
BinaryOpcode::Mul => {
let a = self.mul(d_lhs, v_rhs).unwrap();
let b = self.mul(v_lhs, d_rhs).unwrap();
self.add(a, b)
}
BinaryOpcode::Div => {
let v = self.square(v_rhs).unwrap();
let a = self.mul(v_rhs, d_lhs).unwrap();
let b = self.mul(v_lhs, d_rhs).unwrap();
let c = self.sub(a, b).unwrap();
self.div(c, v)
}
BinaryOpcode::Atan => {
let a = self.square(v_lhs).unwrap();
let b = self.square(v_rhs).unwrap();
let d = self.add(a, b).unwrap();

let a = self.mul(v_rhs, d_lhs).unwrap();
let b = self.mul(v_lhs, d_rhs).unwrap();
let v = self.sub(a, b).unwrap();
self.div(v, d)
}
BinaryOpcode::Min => {
let cond =
self.less_than(v_lhs, v_rhs).unwrap();
self.if_nonzero_else(cond, d_lhs, d_rhs)
}
BinaryOpcode::Max => {
let cond =
self.less_than(v_rhs, v_lhs).unwrap();
self.if_nonzero_else(cond, d_lhs, d_rhs)
}
BinaryOpcode::Compare => Ok(zero),
BinaryOpcode::Mod => {
let e = self.div(v_lhs, v_rhs).unwrap();
let q = self.floor(e).unwrap();

// XXX
// (we don't actually have %, so hack it from
// `modulo`, which is actually `rem_euclid`)
// ???
let m = self.modulo(q, v_rhs).unwrap();
let cond = self.less_than(q, zero).unwrap();
let offset = self
.if_nonzero_else(cond, v_rhs, zero)
.unwrap();
let m = self.sub(m, offset).unwrap();

// Torn from the div_euclid implementation
let outer = self.less_than(m, zero).unwrap();
let inner =
self.less_than(zero, v_rhs).unwrap();
let qa = self.sub(q, 1.0).unwrap();
let qb = self.add(q, 1.0).unwrap();
let inner = self
.if_nonzero_else(inner, qa, qb)
.unwrap();
let e = self
.if_nonzero_else(outer, inner, q)
.unwrap();

let v = self.mul(d_rhs, e).unwrap();
self.sub(d_lhs, v)
}
BinaryOpcode::And => {
let cond = self.compare(v_lhs, zero).unwrap();
self.if_nonzero_else(cond, d_rhs, d_lhs)
}
BinaryOpcode::Or => {
let cond = self.compare(v_lhs, zero).unwrap();
self.if_nonzero_else(cond, d_lhs, d_rhs)
}
}
.unwrap();
seen.insert(n, out);
stack.push(out);
}
},
}
}
assert_eq!(stack.len(), 1);
Ok(stack.pop().unwrap())
}
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
26 changes: 26 additions & 0 deletions fidget/src/core/context/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ impl Tree {
None
}
}

/// Performs symbolic differentiation with respect to the given variable
pub fn deriv(&self, v: Var) -> Tree {
let mut ctx = crate::Context::new();
let node = ctx.import(self);
ctx.deriv(node, v).and_then(|d| ctx.export(d)).unwrap()
}
}

impl TryFrom<Tree> for Var {
Expand Down Expand Up @@ -496,4 +503,23 @@ mod test {
let root = ctx.import(&b);
assert_eq!(ctx.get_const(root).unwrap(), 15.0);
}

#[test]
fn tree_deriv() {
// dx/dx = 1
let x = Tree::x();
let vx = x.var().unwrap();
let d = x.deriv(vx);
let TreeOp::Const(v) = *d else {
panic!("invalid deriv {d:?}")
};
assert_eq!(v, 1.0);

// dx/dv = 0
let d = x.deriv(Var::new());
let TreeOp::Const(v) = *d else {
panic!("invalid deriv {d:?}")
};
assert_eq!(v, 0.0);
}
}
3 changes: 3 additions & 0 deletions fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ pub mod grad_slice;
pub mod interval;
pub mod point;

// Internal-only tests
mod symbolic_deriv;

use crate::{
context::{Context, IntoNode, Node},
eval::Tape,
Expand Down
Loading

0 comments on commit c2697ac

Please sign in to comment.