Skip to content

Commit

Permalink
Implement compare opcode (#46)
Browse files Browse the repository at this point in the history
This is the `<=>` / `partial_cmp` operation!
  • Loading branch information
mkeeter authored Mar 24, 2024
1 parent 6cac755 commit 82c75f5
Show file tree
Hide file tree
Showing 19 changed files with 554 additions and 11 deletions.
16 changes: 14 additions & 2 deletions fidget/src/core/compiler/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,17 @@ impl<const N: usize> RegisterAllocator<N> {
| SsaOp::DivRegImm(..)
| SsaOp::DivImmReg(..)
| SsaOp::MinRegImm(..)
| SsaOp::MaxRegImm(..) => self.op_reg_imm(op),
| SsaOp::MaxRegImm(..)
| SsaOp::CompareRegImm(..)
| SsaOp::CompareImmReg(..) => self.op_reg_imm(op),

SsaOp::AddRegReg(..)
| SsaOp::SubRegReg(..)
| SsaOp::MulRegReg(..)
| SsaOp::DivRegReg(..)
| SsaOp::MinRegReg(..)
| SsaOp::MaxRegReg(..) => self.op_reg_reg(op),
| SsaOp::MaxRegReg(..)
| SsaOp::CompareRegReg(..) => self.op_reg_reg(op),
}
}

Expand Down Expand Up @@ -485,6 +488,9 @@ impl<const N: usize> RegisterAllocator<N> {
SsaOp::MaxRegReg(out, lhs, rhs) => {
(out, lhs, rhs, RegOp::MaxRegReg)
}
SsaOp::CompareRegReg(out, lhs, rhs) => {
(out, lhs, rhs, RegOp::CompareRegReg)
}
_ => panic!("Bad opcode: {op:?}"),
};
let r_x = self.get_out_reg(out);
Expand Down Expand Up @@ -597,6 +603,12 @@ impl<const N: usize> RegisterAllocator<N> {
SsaOp::MaxRegImm(out, arg, imm) => {
(out, arg, imm, RegOp::MaxRegImm)
}
SsaOp::CompareRegImm(out, arg, imm) => {
(out, arg, imm, RegOp::CompareRegImm)
}
SsaOp::CompareImmReg(out, arg, imm) => {
(out, arg, imm, RegOp::CompareImmReg)
}
_ => panic!("Bad opcode: {op:?}"),
};
self.op_reg_fn(out, arg, |out, arg| op(out, arg, imm));
Expand Down
16 changes: 14 additions & 2 deletions fidget/src/core/compiler/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ macro_rules! opcodes {
MinRegImm($t, $t, f32),
#[doc = "Compute the maximum of a register and an immediate"]
MaxRegImm($t, $t, f32),
#[doc = "Compares a register with an immediate"]
CompareRegImm($t, $t, f32),
#[doc = "Compares an immediate with a register"]
CompareImmReg($t, $t, f32),

#[doc = "Add two registers"]
AddRegReg($t, $t, $t),
Expand All @@ -98,6 +102,8 @@ macro_rules! opcodes {
MinRegReg($t, $t, $t),
#[doc = "Take the maximum of two registers"]
MaxRegReg($t, $t, $t),
#[doc = "Compares two registers"]
CompareRegReg($t, $t, $t),

#[doc = "Copy an immediate to a register"]
CopyImm($t, f32),
Expand Down Expand Up @@ -159,7 +165,10 @@ impl SsaOp {
| SsaOp::MinRegImm(out, ..)
| SsaOp::MaxRegImm(out, ..)
| SsaOp::MinRegReg(out, ..)
| SsaOp::MaxRegReg(out, ..) => *out,
| SsaOp::MaxRegReg(out, ..)
| SsaOp::CompareRegReg(out, ..)
| SsaOp::CompareRegImm(out, ..)
| SsaOp::CompareImmReg(out, ..) => *out,
}
}
/// Returns true if the given opcode is associated with a choice
Expand Down Expand Up @@ -191,7 +200,10 @@ impl SsaOp {
| SsaOp::SubRegReg(..)
| SsaOp::DivRegReg(..)
| SsaOp::DivRegImm(..)
| SsaOp::DivImmReg(..) => false,
| SsaOp::DivImmReg(..)
| SsaOp::CompareRegReg(..)
| SsaOp::CompareRegImm(..)
| SsaOp::CompareImmReg(..) => false,
SsaOp::MinRegImm(..)
| SsaOp::MaxRegImm(..)
| SsaOp::MinRegReg(..)
Expand Down
14 changes: 14 additions & 0 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ impl SsaTape {
SsaOp::MaxRegImm,
SsaOp::MaxRegImm,
),
BinaryOpcode::Compare => (
SsaOp::CompareRegReg,
SsaOp::CompareRegImm,
SsaOp::CompareImmReg,
),
};

if matches!(op, BinaryOpcode::Min | BinaryOpcode::Max) {
Expand Down Expand Up @@ -331,6 +336,15 @@ impl SsaTape {
println!("${out} = {op} ${arg} {imm}");
}
}
SsaOp::CompareRegReg(out, lhs, rhs) => {
println!("${out} = COMPARE {lhs} {rhs}")
}
SsaOp::CompareRegImm(out, arg, imm) => {
println!("${out} = COMPARE {arg} {imm}")
}
SsaOp::CompareImmReg(out, arg, imm) => {
println!("${out} = COMPARE {imm} {arg}")
}
SsaOp::CopyImm(out, imm) => {
println!("${out} = COPY {imm}");
}
Expand Down
30 changes: 30 additions & 0 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,31 @@ impl Context {
}
}

/// Builds a node that compares two values
///
/// The result is -1 if `a < b`, +1 if `a > b`, 0 if `a == b`, and `NaN` if
/// either side is `NaN`.
/// ```
/// # let mut ctx = fidget::context::Context::new();
/// let x = ctx.x();
/// let op = ctx.compare(x, 1.0).unwrap();
/// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
/// assert_eq!(v, -1.0);
/// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
/// assert_eq!(v, 1.0);
/// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
/// assert_eq!(v, 0.0);
/// ```
pub fn compare<A: IntoNode, B: IntoNode>(
&mut self,
a: A,
b: B,
) -> Result<Node, Error> {
let a = a.into_node(self)?;
let b = b.into_node(self)?;
self.op_binary(a, b, BinaryOpcode::Compare)
}

////////////////////////////////////////////////////////////////////////////

/// Remaps the X, Y, Z nodes to the given values
Expand Down Expand Up @@ -643,6 +668,10 @@ impl Context {
BinaryOpcode::Div => a / b,
BinaryOpcode::Min => a.min(b),
BinaryOpcode::Max => a.max(b),
BinaryOpcode::Compare => a
.partial_cmp(&b)
.map(|i| i as i8 as f64)
.unwrap_or(f64::NAN),
}
}

Expand Down Expand Up @@ -769,6 +798,7 @@ impl Context {
BinaryOpcode::Div => out += "div",
BinaryOpcode::Min => out += "min",
BinaryOpcode::Max => out += "max",
BinaryOpcode::Compare => out += "less-than",
},
Op::Unary(op, ..) => match op {
UnaryOpcode::Neg => out += "neg",
Expand Down
1 change: 1 addition & 0 deletions fidget/src/core/context/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum BinaryOpcode {
Div,
Min,
Max,
Compare,
}

/// An operation in a math expression.
Expand Down
12 changes: 8 additions & 4 deletions fidget/src/core/eval/test/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ where
}
}

pub fn compare_grad_results(
pub fn compare_grad_results<C: CanonicalBinaryOp>(
i: usize,
j: usize,
lhs: &[f32],
Expand All @@ -503,6 +503,10 @@ where
continue;
}

if C::discontinuous_at(*a, *b) {
continue;
}

const EPSILON: f64 = 1e-8;
let a = *a as f64;
let b = *b as f64;
Expand Down Expand Up @@ -599,7 +603,7 @@ where
.unwrap();

let rhs = if i == j { &args } else { &rgsa };
Self::compare_grad_results(
Self::compare_grad_results::<C>(
i,
j,
&args,
Expand Down Expand Up @@ -642,7 +646,7 @@ where
.unwrap();

let rhs = vec![*rhs; out.len()];
Self::compare_grad_results(
Self::compare_grad_results::<C>(
i,
3,
&args,
Expand Down Expand Up @@ -685,7 +689,7 @@ where
.unwrap();

let lhs = vec![*lhs; out.len()];
Self::compare_grad_results(
Self::compare_grad_results::<C>(
3,
i,
&lhs,
Expand Down
14 changes: 14 additions & 0 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,19 @@ where
);
}

pub fn test_i_compare() {
let mut ctx = Context::new();
let x = ctx.x();
let y = ctx.y();
let c = ctx.compare(x, y).unwrap();

let shape = S::new(&ctx, c).unwrap();
let tape = shape.ez_interval_tape();
let mut eval = S::new_interval_eval();
let (out, _trace) = eval.eval(&tape, -5.0, -6.0, 0.0, &[]).unwrap();
assert_eq!(out, Interval::from(1f32));
}

pub fn test_i_stress_n(depth: usize) {
let (ctx, node) = build_stress_fn(depth);

Expand Down Expand Up @@ -947,6 +960,7 @@ macro_rules! interval_tests {
$crate::interval_test!(test_i_min_imm, $t);
$crate::interval_test!(test_i_max, $t);
$crate::interval_test!(test_i_max_imm, $t);
$crate::interval_test!(test_i_compare, $t);
$crate::interval_test!(test_i_simplify, $t);
$crate::interval_test!(test_i_var, $t);
$crate::interval_test!(test_i_stress, $t);
Expand Down
25 changes: 24 additions & 1 deletion fidget/src/core/eval/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ pub trait CanonicalBinaryOp {
fn eval_reg_reg_f64(lhs: f64, rhs: f64) -> f64;
fn eval_reg_imm_f64(lhs: f64, rhs: f64) -> f64;
fn eval_imm_reg_f64(lhs: f64, rhs: f64) -> f64;

/// Returns true if there is a bidirectional discontinuity at a position
///
/// This means that we should skip gradient checking, because we can't
/// accurately estimate the gradient on either side.
fn discontinuous_at(_lhs: f32, _rhs: f32) -> bool {
false
}
}

macro_rules! declare_canonical_unary {
Expand All @@ -93,7 +101,7 @@ macro_rules! declare_canonical_unary {
}

macro_rules! declare_canonical_binary {
(Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr) => {
(Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr, |$lhs2:ident, $rhs2: ident| $d:expr) => {
pub struct $i;
impl CanonicalBinaryOp for $i {
const NAME: &'static str = stringify!($i);
Expand All @@ -118,8 +126,14 @@ macro_rules! declare_canonical_binary {
fn eval_imm_reg_f64($lhs: f64, $rhs: f64) -> f64 {
$t
}
fn discontinuous_at($lhs2: f32, $rhs2: f32) -> bool {
$d
}
}
};
(Context::$i:ident, |$lhs:ident, $rhs:ident| $t:expr) => {
declare_canonical_binary!(Context::$i, |$lhs, $rhs| $t, |_a, _b| false);
};
}

macro_rules! declare_canonical_binary_full {
Expand Down Expand Up @@ -204,6 +218,14 @@ pub mod canonical {
a.max(b)
}
);
declare_canonical_binary!(
Context::compare,
|a, b| match a.partial_cmp(&b) {
None => f32::NAN.into(),
Some(v) => (v as i8).into(),
},
|a, b| a == b
);
}

#[macro_export]
Expand Down Expand Up @@ -254,5 +276,6 @@ macro_rules! all_binary_tests {
$crate::one_binary_test!($tester, div);
$crate::one_binary_test!($tester, min);
$crate::one_binary_test!($tester, max);
$crate::one_binary_test!($tester, compare);
};
}
7 changes: 5 additions & 2 deletions fidget/src/core/vm/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ impl<const N: usize> VmData<N> {
SsaOp::AddRegReg(index, lhs, rhs)
| SsaOp::MulRegReg(index, lhs, rhs)
| SsaOp::SubRegReg(index, lhs, rhs)
| SsaOp::DivRegReg(index, lhs, rhs) => {
| SsaOp::DivRegReg(index, lhs, rhs)
| SsaOp::CompareRegReg(index, lhs, rhs) => {
*index = new_index;
*lhs = workspace.get_or_insert_active(*lhs);
*rhs = workspace.get_or_insert_active(*rhs);
Expand All @@ -249,7 +250,9 @@ impl<const N: usize> VmData<N> {
| SsaOp::SubRegImm(index, arg, _imm)
| SsaOp::SubImmReg(index, arg, _imm)
| SsaOp::DivRegImm(index, arg, _imm)
| SsaOp::DivImmReg(index, arg, _imm) => {
| SsaOp::DivImmReg(index, arg, _imm)
| SsaOp::CompareRegImm(index, arg, _imm)
| SsaOp::CompareImmReg(index, arg, _imm) => {
*index = new_index;
*arg = workspace.get_or_insert_active(*arg);
}
Expand Down
Loading

0 comments on commit 82c75f5

Please sign in to comment.