From ef544cae032861b25c2e5d22eabdbb10ea9d43a1 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Wed, 5 Jun 2024 09:12:59 -0400 Subject: [PATCH] Use `IntoNode` to simplify a few tests (#133) --- fidget/src/core/eval/test/float_slice.rs | 15 ++++----------- fidget/src/core/eval/test/grad_slice.rs | 6 ++---- fidget/src/core/eval/test/interval.rs | 7 ++----- fidget/src/core/eval/test/mod.rs | 24 ++++++++++++++++++++---- fidget/src/core/eval/test/point.rs | 21 ++++++--------------- 5 files changed, 34 insertions(+), 39 deletions(-) diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index c3e16f4a..04c0f42a 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -115,10 +115,9 @@ impl TestFloatSlice { let x = ctx.x(); let y = ctx.y(); - let v_ = ctx.var(v); let a = ctx.add(x, y).unwrap(); - let a = ctx.add(a, v_).unwrap(); + let a = ctx.add(a, v).unwrap(); let s = Shape::::new(&ctx, a).unwrap(); @@ -265,14 +264,12 @@ impl TestFloatSlice { let mut ctx = Context::new(); let va = Var::new(); let vb = Var::new(); - let a = ctx.var(va); - let b = ctx.var(vb); let name = format!("{}(reg, reg)", C::NAME); for rot in 0..args.len() { let mut rgsa = args.clone(); rgsa.rotate_left(rot); - let node = C::build(&mut ctx, a, b); + let node = C::build(&mut ctx, va, vb); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_float_slice_eval(); @@ -304,15 +301,13 @@ impl TestFloatSlice { let mut ctx = Context::new(); let va = Var::new(); - let a = ctx.var(va); let name = format!("{}(reg, imm)", C::NAME); for rot in 0..args.len() { let mut args = args.clone(); args.rotate_left(rot); for rhs in args.iter() { - let c = ctx.constant(*rhs as f64); - let node = C::build(&mut ctx, a, c); + let node = C::build(&mut ctx, va, *rhs); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_float_slice_eval(); @@ -337,15 +332,13 @@ impl TestFloatSlice { let mut ctx = Context::new(); let va = Var::new(); - let a = ctx.var(va); let name = format!("{}(imm, reg)", C::NAME); for rot in 0..args.len() { let mut args = args.clone(); args.rotate_left(rot); for lhs in args.iter() { - let c = ctx.constant(*lhs as f64); - let node = C::build(&mut ctx, c, a); + let node = C::build(&mut ctx, *lhs, va); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_float_slice_eval(); diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index 9b35a3d9..816d3e2d 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -590,8 +590,7 @@ impl TestGradSlice { args.rotate_left(rot); for (i, &v) in inputs.iter().enumerate() { for rhs in args.iter() { - let c = ctx.constant(*rhs as f64); - let node = C::build(&mut ctx, v, c); + let node = C::build(&mut ctx, v, *rhs); let shape = F::new(&ctx, node).unwrap(); let tape = shape.grad_slice_tape(Default::default()); @@ -631,8 +630,7 @@ impl TestGradSlice { args.rotate_left(rot); for (i, &v) in inputs.iter().enumerate() { for lhs in args.iter() { - let c = ctx.constant(*lhs as f64); - let node = C::build(&mut ctx, c, v); + let node = C::build(&mut ctx, *lhs, v); let shape = F::new(&ctx, node).unwrap(); let tape = shape.grad_slice_tape(Default::default()); diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index aa185a66..6cb6db29 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -1025,8 +1025,7 @@ where let mut eval = F::new_interval_eval(); for &lhs in args.iter() { for &rhs in values.iter() { - let c = ctx.constant(rhs as f64); - let node = C::build(&mut ctx, a, c); + let node = C::build(&mut ctx, a, rhs); let shape = F::new(&ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); @@ -1051,15 +1050,13 @@ where let mut ctx = Context::new(); let va = Var::new(); - let a = ctx.var(va); let name = format!("{}(imm, reg)", C::NAME); let mut tape_data = None; let mut eval = F::new_interval_eval(); for &lhs in values.iter() { for &rhs in args.iter() { - let c = ctx.constant(lhs as f64); - let node = C::build(&mut ctx, c, a); + let node = C::build(&mut ctx, lhs, va); let shape = F::new(&ctx, node).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs index a3b7b0e0..ee4a8b42 100644 --- a/fidget/src/core/eval/test/mod.rs +++ b/fidget/src/core/eval/test/mod.rs @@ -5,7 +5,7 @@ pub mod interval; pub mod point; use crate::{ - context::{Context, Node}, + context::{Context, IntoNode, Node}, eval::Tape, var::Var, }; @@ -111,7 +111,11 @@ pub trait CanonicalUnaryOp { /// Trait for canonical evaluation testing of binary operations pub trait CanonicalBinaryOp { const NAME: &'static str; - fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node; + fn build( + ctx: &mut Context, + lhs: A, + rhs: B, + ) -> Node; fn eval_reg_reg_f32(lhs: f32, rhs: f32) -> f32; fn eval_reg_imm_f32(lhs: f32, rhs: f32) -> f32; fn eval_imm_reg_f32(lhs: f32, rhs: f32) -> f32; @@ -157,7 +161,13 @@ macro_rules! declare_canonical_binary { pub struct $i; impl CanonicalBinaryOp for $i { const NAME: &'static str = stringify!($i); - fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node { + fn build( + ctx: &mut Context, + lhs: A, + rhs: B, + ) -> Node { + let lhs = lhs.into_node(ctx).unwrap(); + let rhs = rhs.into_node(ctx).unwrap(); Context::$i(ctx, lhs, rhs).unwrap() } fn eval_reg_reg_f32($lhs: f32, $rhs: f32) -> f32 { @@ -197,7 +207,13 @@ macro_rules! declare_canonical_binary_full { pub struct $i; impl CanonicalBinaryOp for $i { const NAME: &'static str = stringify!($i); - fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node { + fn build( + ctx: &mut Context, + lhs: A, + rhs: B, + ) -> Node { + let lhs = lhs.into_node(ctx).unwrap(); + let rhs = rhs.into_node(ctx).unwrap(); Context::$i(ctx, lhs, rhs).unwrap() } fn eval_reg_reg_f32($lhs_reg_reg: f32, $rhs_reg_reg: f32) -> f32 { diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index f0ad2ea1..e6d266c3 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -34,9 +34,7 @@ where pub fn test_constant_push() { let mut ctx = Context::new(); - let a = ctx.constant(1.5); - let x = ctx.x(); - let min = ctx.min(a, x).unwrap(); + let min = ctx.min(1.5, Var::X).unwrap(); let shape = F::new(&ctx, min).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); @@ -361,10 +359,9 @@ where let x = ctx.x(); let y = ctx.y(); - let v_ = ctx.var(v); let a = ctx.add(x, y).unwrap(); - let a = ctx.add(a, v_).unwrap(); + let a = ctx.add(a, v).unwrap(); let s = Shape::::new(&ctx, a).unwrap(); @@ -491,13 +488,11 @@ where let mut ctx = Context::new(); let va = Var::new(); let vb = Var::new(); - let a = ctx.var(va); - let b = ctx.var(vb); let name = format!("{}(reg, reg)", C::NAME); for &lhs in args.iter() { for &rhs in args.iter() { - let node = C::build(&mut ctx, a, b); + let node = C::build(&mut ctx, va, vb); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_point_eval(); @@ -525,7 +520,7 @@ where } for &lhs in args.iter() { - let node = C::build(&mut ctx, a, a); + let node = C::build(&mut ctx, va, va); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_point_eval(); @@ -553,13 +548,11 @@ where let mut ctx = Context::new(); let va = Var::new(); - let a = ctx.var(va); let name = format!("{}(reg, imm)", C::NAME); for &lhs in args.iter() { for &rhs in args.iter() { - let c = ctx.constant(rhs as f64); - let node = C::build(&mut ctx, a, c); + let node = C::build(&mut ctx, va, rhs); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_point_eval(); @@ -583,13 +576,11 @@ where let mut ctx = Context::new(); let va = Var::new(); - let a = ctx.var(va); let name = format!("{}(imm, reg)", C::NAME); for &lhs in args.iter() { for &rhs in args.iter() { - let c = ctx.constant(lhs as f64); - let node = C::build(&mut ctx, c, a); + let node = C::build(&mut ctx, lhs, va); let shape = F::new(&ctx, node).unwrap(); let mut eval = F::new_point_eval();