Skip to content

Commit

Permalink
Use IntoNode to simplify a few tests (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter authored Jun 5, 2024
1 parent 97dda1a commit ef544ca
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 39 deletions.
15 changes: 4 additions & 11 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,9 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {

let x = ctx.x();
let y = ctx.y();
let v_ = ctx.var(v);

let a = ctx.add(x, y).unwrap();
let a = ctx.add(a, v_).unwrap();
let a = ctx.add(a, v).unwrap();

let s = Shape::<F>::new(&ctx, a).unwrap();

Expand Down Expand Up @@ -265,14 +264,12 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
let mut ctx = Context::new();
let va = Var::new();
let vb = Var::new();
let a = ctx.var(va);
let b = ctx.var(vb);

let name = format!("{}(reg, reg)", C::NAME);
for rot in 0..args.len() {
let mut rgsa = args.clone();
rgsa.rotate_left(rot);
let node = C::build(&mut ctx, a, b);
let node = C::build(&mut ctx, va, vb);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_float_slice_eval();
Expand Down Expand Up @@ -304,15 +301,13 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {

let mut ctx = Context::new();
let va = Var::new();
let a = ctx.var(va);

let name = format!("{}(reg, imm)", C::NAME);
for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for rhs in args.iter() {
let c = ctx.constant(*rhs as f64);
let node = C::build(&mut ctx, a, c);
let node = C::build(&mut ctx, va, *rhs);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_float_slice_eval();
Expand All @@ -337,15 +332,13 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {

let mut ctx = Context::new();
let va = Var::new();
let a = ctx.var(va);

let name = format!("{}(imm, reg)", C::NAME);
for rot in 0..args.len() {
let mut args = args.clone();
args.rotate_left(rot);
for lhs in args.iter() {
let c = ctx.constant(*lhs as f64);
let node = C::build(&mut ctx, c, a);
let node = C::build(&mut ctx, *lhs, va);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_float_slice_eval();
Expand Down
6 changes: 2 additions & 4 deletions fidget/src/core/eval/test/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,7 @@ impl<F: Function + MathFunction> TestGradSlice<F> {
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for rhs in args.iter() {
let c = ctx.constant(*rhs as f64);
let node = C::build(&mut ctx, v, c);
let node = C::build(&mut ctx, v, *rhs);

let shape = F::new(&ctx, node).unwrap();
let tape = shape.grad_slice_tape(Default::default());
Expand Down Expand Up @@ -631,8 +630,7 @@ impl<F: Function + MathFunction> TestGradSlice<F> {
args.rotate_left(rot);
for (i, &v) in inputs.iter().enumerate() {
for lhs in args.iter() {
let c = ctx.constant(*lhs as f64);
let node = C::build(&mut ctx, c, v);
let node = C::build(&mut ctx, *lhs, v);

let shape = F::new(&ctx, node).unwrap();
let tape = shape.grad_slice_tape(Default::default());
Expand Down
7 changes: 2 additions & 5 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,7 @@ where
let mut eval = F::new_interval_eval();
for &lhs in args.iter() {
for &rhs in values.iter() {
let c = ctx.constant(rhs as f64);
let node = C::build(&mut ctx, a, c);
let node = C::build(&mut ctx, a, rhs);

let shape = F::new(&ctx, node).unwrap();
let tape = shape.interval_tape(tape_data.unwrap_or_default());
Expand All @@ -1051,15 +1050,13 @@ where

let mut ctx = Context::new();
let va = Var::new();
let a = ctx.var(va);

let name = format!("{}(imm, reg)", C::NAME);
let mut tape_data = None;
let mut eval = F::new_interval_eval();
for &lhs in values.iter() {
for &rhs in args.iter() {
let c = ctx.constant(lhs as f64);
let node = C::build(&mut ctx, c, a);
let node = C::build(&mut ctx, lhs, va);

let shape = F::new(&ctx, node).unwrap();
let tape = shape.interval_tape(tape_data.unwrap_or_default());
Expand Down
24 changes: 20 additions & 4 deletions fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod interval;
pub mod point;

use crate::{
context::{Context, Node},
context::{Context, IntoNode, Node},
eval::Tape,
var::Var,
};
Expand Down Expand Up @@ -111,7 +111,11 @@ pub trait CanonicalUnaryOp {
/// Trait for canonical evaluation testing of binary operations
pub trait CanonicalBinaryOp {
const NAME: &'static str;
fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node;
fn build<A: IntoNode, B: IntoNode>(
ctx: &mut Context,
lhs: A,
rhs: B,
) -> Node;
fn eval_reg_reg_f32(lhs: f32, rhs: f32) -> f32;
fn eval_reg_imm_f32(lhs: f32, rhs: f32) -> f32;
fn eval_imm_reg_f32(lhs: f32, rhs: f32) -> f32;
Expand Down Expand Up @@ -157,7 +161,13 @@ macro_rules! declare_canonical_binary {
pub struct $i;
impl CanonicalBinaryOp for $i {
const NAME: &'static str = stringify!($i);
fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node {
fn build<A: IntoNode, B: IntoNode>(
ctx: &mut Context,
lhs: A,
rhs: B,
) -> Node {
let lhs = lhs.into_node(ctx).unwrap();
let rhs = rhs.into_node(ctx).unwrap();
Context::$i(ctx, lhs, rhs).unwrap()
}
fn eval_reg_reg_f32($lhs: f32, $rhs: f32) -> f32 {
Expand Down Expand Up @@ -197,7 +207,13 @@ macro_rules! declare_canonical_binary_full {
pub struct $i;
impl CanonicalBinaryOp for $i {
const NAME: &'static str = stringify!($i);
fn build(ctx: &mut Context, lhs: Node, rhs: Node) -> Node {
fn build<A: IntoNode, B: IntoNode>(
ctx: &mut Context,
lhs: A,
rhs: B,
) -> Node {
let lhs = lhs.into_node(ctx).unwrap();
let rhs = rhs.into_node(ctx).unwrap();
Context::$i(ctx, lhs, rhs).unwrap()
}
fn eval_reg_reg_f32($lhs_reg_reg: f32, $rhs_reg_reg: f32) -> f32 {
Expand Down
21 changes: 6 additions & 15 deletions fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ where

pub fn test_constant_push() {
let mut ctx = Context::new();
let a = ctx.constant(1.5);
let x = ctx.x();
let min = ctx.min(a, x).unwrap();
let min = ctx.min(1.5, Var::X).unwrap();
let shape = F::new(&ctx, min).unwrap();
let tape = shape.point_tape(Default::default());
let mut eval = F::new_point_eval();
Expand Down Expand Up @@ -361,10 +359,9 @@ where

let x = ctx.x();
let y = ctx.y();
let v_ = ctx.var(v);

let a = ctx.add(x, y).unwrap();
let a = ctx.add(a, v_).unwrap();
let a = ctx.add(a, v).unwrap();

let s = Shape::<F>::new(&ctx, a).unwrap();

Expand Down Expand Up @@ -491,13 +488,11 @@ where
let mut ctx = Context::new();
let va = Var::new();
let vb = Var::new();
let a = ctx.var(va);
let b = ctx.var(vb);

let name = format!("{}(reg, reg)", C::NAME);
for &lhs in args.iter() {
for &rhs in args.iter() {
let node = C::build(&mut ctx, a, b);
let node = C::build(&mut ctx, va, vb);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_point_eval();
Expand Down Expand Up @@ -525,7 +520,7 @@ where
}

for &lhs in args.iter() {
let node = C::build(&mut ctx, a, a);
let node = C::build(&mut ctx, va, va);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_point_eval();
Expand Down Expand Up @@ -553,13 +548,11 @@ where

let mut ctx = Context::new();
let va = Var::new();
let a = ctx.var(va);

let name = format!("{}(reg, imm)", C::NAME);
for &lhs in args.iter() {
for &rhs in args.iter() {
let c = ctx.constant(rhs as f64);
let node = C::build(&mut ctx, a, c);
let node = C::build(&mut ctx, va, rhs);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_point_eval();
Expand All @@ -583,13 +576,11 @@ where

let mut ctx = Context::new();
let va = Var::new();
let a = ctx.var(va);

let name = format!("{}(imm, reg)", C::NAME);
for &lhs in args.iter() {
for &rhs in args.iter() {
let c = ctx.constant(lhs as f64);
let node = C::build(&mut ctx, c, a);
let node = C::build(&mut ctx, lhs, va);

let shape = F::new(&ctx, node).unwrap();
let mut eval = F::new_point_eval();
Expand Down

0 comments on commit ef544ca

Please sign in to comment.