Skip to content

Commit

Permalink
Fix stack overflows when handling deep Tree objects (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter authored Apr 21, 2024
1 parent 6c1ab2f commit b21918d
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 51 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# 0.2.6 (unreleased)
- Added `VmShape` serialization (using `serde`), specifically
- `#[derive(Serialize, Deserialize)}` on `VmData`
- `impl From<VmData<255>> for VmShape { .. }`
- `#[derive(Serialize, Deserialize)}` on `VmData`
- `impl From<VmData<255>> for VmShape { .. }`
- Fixed stack overflows when handling very deep `Tree` objects
- Added a non-recursive `Drop` implementation
- Rewrote `Context::import` to use the heap instead of stack

# 0.2.5
The highlight of this release is native Windows support (including JIT
Expand Down
106 changes: 71 additions & 35 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,43 +937,79 @@ impl Context {

/// Imports the given tree, deduplicating and returning the root
pub fn import(&mut self, tree: &Tree) -> Node {
// TODO make this non-recursive to avoid blowing up the stack
let x = self.x();
let y = self.y();
let z = self.z();
self.import_inner(tree, x, y, z)
}

fn import_inner(&mut self, tree: &Tree, x: Node, y: Node, z: Node) -> Node {
match &**tree {
TreeOp::Input(s) => match *s {
"X" => x,
"Y" => y,
"Z" => z,
s => panic!("invalid tree input string {s:?}"),
},
TreeOp::Const(c) => self.constant(*c),
TreeOp::Unary(op, t) => {
let t = self.import_inner(t, x, y, z);
self.op_unary(t, *op).unwrap()
}
TreeOp::Binary(op, a, b) => {
let a = self.import_inner(a, x, y, z);
let b = self.import_inner(b, x, y, z);
self.op_binary(a, b, *op).unwrap()
}
TreeOp::RemapAxes {
target,
x: tx,
y: ty,
z: tz,
} => {
let x_ = self.import_inner(tx, x, y, z);
let y_ = self.import_inner(ty, x, y, z);
let z_ = self.import_inner(tz, x, y, z);
self.import_inner(target, x_, y_, z_)
// A naive remapping implementation would use recursion. A naive
// remapping implementation would blow up the stack given any
// significant tree size.
//
// Instead, we maintain our own pseudo-stack here in a pair of Vecs (one
// stack for actions, and a second stack for return values).
enum Action<'a> {
/// Pushes `Up(op)` followed by `Down(c)` for each child
Down(&'a TreeOp),
/// Consumes imported trees from the stack and pushes a new tree
Up(&'a TreeOp),
/// Pops the latest axis frame
Pop,
}
let mut axes = vec![(self.x(), self.y(), self.z())];
let mut todo = vec![Action::Down(tree)];
let mut stack = vec![];

while let Some(t) = todo.pop() {
match t {
Action::Down(t) => {
todo.push(Action::Up(t));
match t {
TreeOp::Const(..) | TreeOp::Input(..) => (),
TreeOp::Unary(_op, arg) => todo.push(Action::Down(arg)),
TreeOp::Binary(_op, lhs, rhs) => {
todo.push(Action::Down(lhs));
todo.push(Action::Down(rhs));
}
TreeOp::RemapAxes { target: _, x, y, z } => {
// Action::Up(t) does the remapping and target eval
todo.push(Action::Down(x));
todo.push(Action::Down(y));
todo.push(Action::Down(z));
}
}
}
Action::Up(t) => match t {
TreeOp::Const(c) => stack.push(self.constant(*c)),
TreeOp::Input(s) => {
let axes = axes.last().unwrap();
stack.push(match *s {
"X" => axes.0,
"Y" => axes.1,
"Z" => axes.2,
s => panic!("invalid tree input string {s:?}"),
});
}
TreeOp::Unary(op, ..) => {
let arg = stack.pop().unwrap();
stack.push(self.op_unary(arg, *op).unwrap());
}
TreeOp::Binary(op, ..) => {
let lhs = stack.pop().unwrap();
let rhs = stack.pop().unwrap();
stack.push(self.op_binary(lhs, rhs, *op).unwrap());
}
TreeOp::RemapAxes { target, .. } => {
let x = stack.pop().unwrap();
let y = stack.pop().unwrap();
let z = stack.pop().unwrap();
axes.push((x, y, z));
todo.push(Action::Pop);
todo.push(Action::Down(target));
}
},
Action::Pop => {
axes.pop().unwrap();
}
}
}
assert_eq!(stack.len(), 1);
stack.pop().unwrap()
}
}

Expand Down
177 changes: 163 additions & 14 deletions fidget/src/core/context/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,76 @@ pub enum TreeOp {
/// Input (at the moment, limited to "X", "Y", "Z")
Input(&'static str),
Const(f64),
Binary(BinaryOpcode, Tree, Tree),
Unary(UnaryOpcode, Tree),
Binary(BinaryOpcode, Arc<TreeOp>, Arc<TreeOp>),
Unary(UnaryOpcode, Arc<TreeOp>),
/// Lazy remapping of trees
///
/// When imported into a `Context`, all `x/y/z` clauses within `target` will
/// be replaced with the provided `x/y/z` trees.
RemapAxes {
target: Tree,
x: Tree,
y: Tree,
z: Tree,
target: Arc<TreeOp>,
x: Arc<TreeOp>,
y: Arc<TreeOp>,
z: Arc<TreeOp>,
},
}

impl Drop for TreeOp {
fn drop(&mut self) {
// Early exit for TreeOps which have limited recursion
if self.fast_drop() {
return;
}

let mut todo = vec![std::mem::replace(self, TreeOp::Const(0.0))];
let empty = Arc::new(TreeOp::Const(0.0));
while let Some(mut t) = todo.pop() {
for t in t.iter_children() {
let arg = std::mem::replace(t, empty.clone());
todo.extend(Arc::into_inner(arg));
}
drop(t);
}
}
}

impl TreeOp {
/// Checks whether the given tree is eligible for fast dropping
///
/// Fast dropping uses the normal `Drop` implementation, which recurses on
/// the stack and can overflow for deep trees. A recursive tree is only
/// eligible for fast dropping if all of its children are `TreeOp::Const`.
fn fast_drop(&self) -> bool {
match self {
TreeOp::Const(..) | TreeOp::Input(..) => true,
TreeOp::Unary(_op, arg) => matches!(**arg, TreeOp::Const(..)),
TreeOp::Binary(_op, lhs, rhs) => {
matches!(**lhs, TreeOp::Const(..))
&& matches!(**rhs, TreeOp::Const(..))
}
TreeOp::RemapAxes { target, x, y, z } => {
matches!(**target, TreeOp::Const(..))
&& matches!(**x, TreeOp::Const(..))
&& matches!(**y, TreeOp::Const(..))
&& matches!(**z, TreeOp::Const(..))
}
}
}

fn iter_children(&mut self) -> impl Iterator<Item = &mut Arc<TreeOp>> {
match self {
TreeOp::Const(..) | TreeOp::Input(..) => [None, None, None, None],
TreeOp::Unary(_op, arg) => [Some(arg), None, None, None],
TreeOp::Binary(_op, lhs, rhs) => [Some(lhs), Some(rhs), None, None],
TreeOp::RemapAxes { target, x, y, z } => {
[Some(target), Some(x), Some(y), Some(z)]
}
}
.into_iter()
.flatten()
}
}

impl From<f64> for Tree {
fn from(v: f64) -> Tree {
Tree::constant(v)
Expand Down Expand Up @@ -79,10 +135,10 @@ impl Tree {
/// into a `Context`.
pub fn remap_xyz(&self, x: Tree, y: Tree, z: Tree) -> Tree {
Self(Arc::new(TreeOp::RemapAxes {
target: self.clone(),
x,
y,
z,
target: self.0.clone(),
x: x.0,
y: y.0,
z: z.0,
}))
}
}
Expand All @@ -103,10 +159,10 @@ impl Tree {
Tree(Arc::new(TreeOp::Const(f)))
}
fn op_unary(a: Tree, op: UnaryOpcode) -> Self {
Tree(Arc::new(TreeOp::Unary(op, a)))
Tree(Arc::new(TreeOp::Unary(op, a.0)))
}
fn op_binary(a: Tree, b: Tree, op: BinaryOpcode) -> Self {
Tree(Arc::new(TreeOp::Binary(op, a, b)))
Tree(Arc::new(TreeOp::Binary(op, a.0, b.0)))
}
pub fn square(&self) -> Self {
Self::op_unary(self.clone(), UnaryOpcode::Square)
Expand Down Expand Up @@ -179,7 +235,8 @@ macro_rules! impl_binary {
impl<A: Into<Tree>> std::ops::$op_assign<A> for Tree {
fn $assign_fn(&mut self, other: A) {
use std::ops::$op;
self.0 = self.clone().$base_fn(other.into()).0
let mut next = self.clone().$base_fn(other.into());
std::mem::swap(self, &mut next);
}
}
impl std::ops::$op<Tree> for f32 {
Expand Down Expand Up @@ -221,16 +278,108 @@ mod test {

#[test]
fn test_remap_xyz() {
// Remapping X
let s = Tree::x() + 1.0;

let v = s.remap_xyz(Tree::y(), Tree::y(), Tree::z());
let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

// Remapping Y
let s = Tree::y() + 1.0;

let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

// Remapping Z
let s = Tree::z() + 1.0;

let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
let mut ctx = Context::new();
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

// Test remapping to a constant
let s = Tree::x() + 1.0;
let one = Tree::constant(3.0);
let v = s.remap_xyz(one, Tree::y(), Tree::z());
let v_ = ctx.import(&v);
assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 4.0);
}

#[test]
fn deep_recursion_drop() {
let mut x = Tree::x();
for _ in 0..1_000_000 {
x += 1.0;
}
drop(x);
// we should not panic here!
}

#[test]
fn deep_recursion_import() {
let mut x = Tree::x();
for _ in 0..1_000_000 {
x += 1.0;
}
let mut ctx = Context::new();
ctx.import(&x);
// we should not panic here!
}

#[test]
fn tree_remap_multi() {
let mut ctx = Context::new();

let out = Tree::x() + Tree::y() + Tree::z();
let out =
out.remap_xyz(Tree::x() * 2.0, Tree::y() * 3.0, Tree::z() * 5.0);

let v_ = ctx.import(&out);
assert_eq!(ctx.eval_xyz(v_, 1.0, 1.0, 1.0).unwrap(), 10.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 1.0, 1.0).unwrap(), 12.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 1.0).unwrap(), 15.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 2.0).unwrap(), 20.0);

let out = out.remap_xyz(Tree::y(), Tree::z(), Tree::x());
let v_ = ctx.import(&out);
assert_eq!(ctx.eval_xyz(v_, 1.0, 1.0, 1.0).unwrap(), 10.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 1.0, 1.0).unwrap(), 15.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 1.0).unwrap(), 17.0);
assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 2.0).unwrap(), 20.0);
}
}

0 comments on commit b21918d

Please sign in to comment.