diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e4f0e76..189d26fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # 0.3.2 (unreleased) - Added `impl IntoNode for Var`, to make handling `Var` values in a context easier. +- Added `impl From for Tree` for convenience +- Added `Context::export(&self, n: Node) -> Tree` to make a freestanding `Tree` + given a context-specific `Node`. - Fix possible corruption of `x24` during AArch64 float slice JIT evaluation, due to incorrect stack alignment. diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 09eda962..32fdefba 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -1149,6 +1149,84 @@ impl Context { assert_eq!(stack.len(), 1); stack.pop().unwrap() } + + /// Converts from a context-specific node into a standalone [`Tree`] + pub fn export(&self, n: Node) -> Result { + if self.get_op(n).is_none() { + return Err(Error::BadNode); + } + + // Do recursion on the heap to avoid stack overflows for deep trees + enum Action { + /// Pushes `Up(n)` followed by `Down(n)` for each child + Down(Node), + /// Consumes trees from the stack and pushes a new tree + Up(Node, Op), + } + let mut todo = vec![Action::Down(n)]; + let mut stack = vec![]; + + // Cache of Node -> Tree mapping, for Tree deduplication + let mut seen: HashMap = HashMap::new(); + + while let Some(t) = todo.pop() { + match t { + Action::Down(n) => { + // If we've already seen this TreeOp with these axes, then + // we can return the previous Node. + if let Some(p) = seen.get(&n) { + stack.push(p.clone()); + continue; + } + let op = self.get_op(n).unwrap(); + match op { + Op::Const(c) => { + let t = Tree::from(c.0); + seen.insert(n, t.clone()); + stack.push(t); + } + Op::Input(v) => { + let t = Tree::from(*v); + seen.insert(n, t.clone()); + stack.push(t); + } + Op::Unary(_op, arg) => { + todo.push(Action::Up(n, *op)); + todo.push(Action::Down(*arg)); + } + Op::Binary(_op, lhs, rhs) => { + todo.push(Action::Up(n, *op)); + todo.push(Action::Down(*lhs)); + todo.push(Action::Down(*rhs)); + } + } + } + Action::Up(n, op) => match op { + Op::Const(..) | Op::Input(..) => unreachable!(), + Op::Unary(op, ..) => { + let arg = stack.pop().unwrap(); + let out = + Tree::from(TreeOp::Unary(op, arg.arc().clone())); + seen.insert(n, out.clone()); + stack.push(out); + } + Op::Binary(op, ..) => { + let lhs = stack.pop().unwrap(); + let rhs = stack.pop().unwrap(); + let out = Tree::from(TreeOp::Binary( + op, + lhs.arc().clone(), + rhs.arc().clone(), + )); + seen.insert(n, out.clone()); + stack.push(out); + } + }, + } + } + assert_eq!(stack.len(), 1); + Ok(stack.pop().unwrap()) + } } //////////////////////////////////////////////////////////////////////////////// @@ -1237,4 +1315,30 @@ mod test { assert_eq!(tape.len(), 2); assert_eq!(tape.vars.len(), 1); } + + #[test] + fn test_export() { + let mut ctx = Context::new(); + let x = ctx.x(); + let s = ctx.sin(x).unwrap(); + let c = ctx.cos(x).unwrap(); + let sum = ctx.add(s, c).unwrap(); + let t = ctx.export(sum).unwrap(); + if let TreeOp::Binary(BinaryOpcode::Add, lhs, rhs) = &*t { + match (&**lhs, &**rhs) { + ( + TreeOp::Unary(UnaryOpcode::Sin, x1), + TreeOp::Unary(UnaryOpcode::Cos, x2), + ) => { + assert_eq!(Arc::as_ptr(x1), Arc::as_ptr(x2)); + let TreeOp::Input(Var::X) = &**x1 else { + panic!("invalid X: {x1:?}"); + }; + } + _ => panic!("invalid lhs / rhs: {lhs:?} {rhs:?}"), + } + } else { + panic!("unexpected opcode {t:?}"); + } + } } diff --git a/fidget/src/core/context/tree.rs b/fidget/src/core/context/tree.rs index 3dc3d6f3..8be0fc72 100644 --- a/fidget/src/core/context/tree.rs +++ b/fidget/src/core/context/tree.rs @@ -107,6 +107,12 @@ impl From for Tree { } } +impl From for Tree { + fn from(t: TreeOp) -> Tree { + Tree(Arc::new(t)) + } +} + /// Owned handle for a standalone math tree #[derive(Clone, Debug)] pub struct Tree(Arc);