Skip to content

Commit

Permalink
Return NAN interval if NAN can be produced anywhere (#43)
Browse files Browse the repository at this point in the history
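This could make `sqrt` worse (an interval that is only partially negative now produces a NAN interval instead of `[0, sqrt(upper)]`), but is more correct.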
  • Loading branch information
mkeeter authored Mar 21, 2024
1 parent d76875e commit 99b65c0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 54 deletions.
58 changes: 40 additions & 18 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ where
let mut eval = S::new_interval_eval();
assert_eq!(eval.eval_x(&tape, [0.0, 1.0]), [0.0, 1.0].into());
assert_eq!(eval.eval_x(&tape, [0.0, 4.0]), [0.0, 2.0].into());
assert_eq!(eval.eval_x(&tape, [-2.0, 4.0]), [0.0, 2.0].into());

// Even a partial negative returns a NAN interval
let nanan = eval.eval_x(&tape, [-2.0, 4.0]);
assert!(nanan.lower().is_nan());
assert!(nanan.upper().is_nan());

// Full negatives are right out
let nanan = eval.eval_x(&tape, [-2.0, -1.0]);
assert!(nanan.lower().is_nan());
assert!(nanan.upper().is_nan());
Expand Down Expand Up @@ -701,15 +707,23 @@ where
.min(a.upper())
.max(a.lower());
let inside_value = C::eval_f32(inside);
assert!(
inside_value.is_nan()
|| o.lower().is_nan()
|| (inside_value >= o.lower()
&& inside_value <= o.upper()),
"interval failure in '{}': {inside} in {a} => \
{inside_value} not in {o}",
C::NAME,
);

if inside_value.is_nan() || inside_value.is_infinite() {
assert!(
o.has_nan(),
"interval failure in '{}': {inside} in {a} => \
{inside_value} not in {o} (should be [NaN, NaN])",
C::NAME,
);
} else if !o.has_nan() {
assert!(
inside_value >= o.lower()
&& inside_value <= o.upper(),
"interval failure in '{}': {inside} in {a} => \
{inside_value} not in {o}",
C::NAME,
);
}
}
}
tape_data = Some(tape.recycle());
Expand Down Expand Up @@ -737,14 +751,22 @@ where
.min(rhs.upper())
.max(rhs.lower());
let inside_value = g(v_lhs, v_rhs);
assert!(
inside_value.is_nan()
|| out.lower().is_nan()
|| (inside_value >= out.lower()
&& inside_value <= out.upper()),
"interval failure in '{name}': ({v_lhs}, {v_rhs}) in \
({lhs}, {rhs}) => {inside_value} not in {out}"
);

if inside_value.is_nan() || inside_value.is_infinite() {
assert!(
out.has_nan(),
"interval failure in '{name}': ({v_lhs}, {v_rhs}) in \
({lhs}, {rhs}) => {inside_value} not in {out} \
(should be [NaN, NaN])"
);
} else if !out.has_nan() {
assert!(
inside_value >= out.lower()
&& inside_value <= out.upper(),
"interval failure in '{name}': ({v_lhs}, {v_rhs}) in \
({lhs}, {rhs}) => {inside_value} not in {out}"
);
}
}
}
}
Expand Down
25 changes: 14 additions & 11 deletions fidget/src/core/eval/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,23 @@ impl Interval {
///
/// Returns a `NAN` interval if `self.has_nan()` is true; otherwise, right
/// now, this always returns the maximum range of `[-1, 1]`
pub fn sin(self) -> Self {
// TODO: make this smarter
Interval::new(-1.0, 1.0)
if self.has_nan() {
f32::NAN.into()
} else {
// TODO: make this smarter
Interval::new(-1.0, 1.0)
}
}
/// Computes the cosine of the interval
///
/// Returns a `NAN` interval if `self.has_nan()` is true; otherwise, right
/// now, this always returns the maximum range of `[-1, 1]`
pub fn cos(self) -> Self {
// TODO: make this smarter
Interval::new(-1.0, 1.0)
if self.has_nan() {
f32::NAN.into()
} else {
// TODO: make this smarter
Interval::new(-1.0, 1.0)
}
}
/// Computes the tangent of the interval
///
Expand Down Expand Up @@ -441,15 +449,10 @@ impl Interval {
}
/// Calculates the square root of the interval
///
/// If the entire interval is below 0, returns a `NAN` interval; otherwise,
/// returns the valid (positive) interval.
/// If the interval contains values below 0, returns a `NAN` interval.
pub fn sqrt(self) -> Self {
if self.lower < 0.0 {
if self.upper > 0.0 {
Interval::new(0.0, self.upper.sqrt())
} else {
f32::NAN.into()
}
f32::NAN.into()
} else {
Interval::new(self.lower.sqrt(), self.upper.sqrt())
}
Expand Down
17 changes: 3 additions & 14 deletions fidget/src/jit/aarch64/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,27 +265,16 @@ impl Assembler for IntervalAssembler {
}
fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8) {
dynasm!(self.0.ops
// Store lhs <= 0.0 in x15
; fcmle v4.s2, V(reg(lhs_reg)).s2, 0.0
// Store lhs < 0.0 in x15
; fcmlt v4.s2, V(reg(lhs_reg)).s2, 0.0
; fmov x15, d4

// Check whether lhs.upper < 0
; tst x15, 0x1_0000_0000
; b.ne 40 // -> upper_lz

; tst x15, 0x1
; b.ne 12 // -> lower_lz

// Happy path
; fsqrt V(reg(out_reg)).s2, V(reg(lhs_reg)).s2
; b 32 // -> end

// <- lower_lz
; mov v4.s[0], V(reg(lhs_reg)).s[1]
; fsqrt s4, s4
; movi D(reg(out_reg)), 0
; mov V(reg(out_reg)).s[1], v4.s[0]
; b 12
; b 12 // -> end

// <- upper_lz
; mov w9, f32::NAN.to_bits().into()
Expand Down
12 changes: 1 addition & 11 deletions fidget/src/jit/x86_64/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,25 +261,15 @@ impl Assembler for IntervalAssembler {
fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8) {
dynasm!(self.0.ops
; vpxor xmm0, xmm0, xmm0 // xmm0 = 0.0
; vpshufd xmm1, Rx(reg(lhs_reg)), 1
; vcomiss xmm0, xmm1
; ja >U // upper_lz
; vcomiss xmm0, Rx(reg(lhs_reg))
; ja >L // lower_lz

// Happy path
; vsqrtps Rx(reg(out_reg)), Rx(reg(lhs_reg))
; jmp >E

// lower < 0, upper > 0 => [0, sqrt(upper)]
// lower < 0 => [NaN, NaN]
; L:
; vpxor xmm0, xmm0, xmm0 // clear xmm0
; vsqrtss xmm0, xmm0, xmm1
; vpshufd Rx(reg(out_reg)), xmm0, 0b11110011u8 as i8
; jmp >E

// upper < 0 => [NaN, NaN]
; U:
; vpcmpeqw Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg))
; vpslld Rx(reg(out_reg)), Rx(reg(out_reg)), 23
; vpsrld Rx(reg(out_reg)), Rx(reg(out_reg)), 1
Expand Down

0 comments on commit 99b65c0

Please sign in to comment.