From 76ae930b90476c4a94adcd6a1ada60d2948f6e85 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Sun, 12 May 2024 18:47:04 -0400 Subject: [PATCH] Assembly implementations of `floor` / `ceil` / `round` (#104) gotta go fast --- fidget/src/core/eval/test/mod.rs | 2 ++ fidget/src/jit/aarch64/float_slice.rs | 49 ++++++++++++++++++++------- fidget/src/jit/aarch64/interval.rs | 48 +++++++++++++++++++------- fidget/src/jit/aarch64/point.rs | 49 ++++++++++++++++++++------- fidget/src/jit/x86_64/float_slice.rs | 32 ++++++++++------- fidget/src/jit/x86_64/grad_slice.rs | 34 ++++++++++++------- fidget/src/jit/x86_64/interval.rs | 32 ++++++++++------- fidget/src/jit/x86_64/point.rs | 30 +++++++++------- 8 files changed, 186 insertions(+), 90 deletions(-) diff --git a/fidget/src/core/eval/test/mod.rs b/fidget/src/core/eval/test/mod.rs index ea2d9663..ab168da9 100644 --- a/fidget/src/core/eval/test/mod.rs +++ b/fidget/src/core/eval/test/mod.rs @@ -41,6 +41,8 @@ fn test_args_n(n: i64) -> Vec { .collect::>(); args.push(1.0); args.push(5.0); + args.push(0.5); + args.push(1.5); args.push(10.0); args.push(std::f32::consts::PI); args.push(std::f32::consts::FRAC_PI_2); diff --git a/fidget/src/jit/aarch64/float_slice.rs b/fidget/src/jit/aarch64/float_slice.rs index 572bcaf4..1edc816e 100644 --- a/fidget/src/jit/aarch64/float_slice.rs +++ b/fidget/src/jit/aarch64/float_slice.rs @@ -237,24 +237,47 @@ impl Assembler for FloatSliceAssembler { ) } - // TODO optimize these three functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_floor(f: f32) -> f32 { - f.floor() - } - self.call_fn_unary(out_reg, lhs_reg, float_floor); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4 + ; mvn v6.b16, v6.b16 + + // Round, then convert back to f32 + ; fcvtms V(reg(out_reg)).s4, V(reg(lhs_reg)).s4 + ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_ceil(f: f32) -> f32 { - f.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, float_ceil); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4 + ; mvn v6.b16, v6.b16 + + // Round, then convert back to f32 + ; fcvtps V(reg(out_reg)).s4, V(reg(lhs_reg)).s4 + ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_round(f: f32) -> f32 { - f.round() - } - self.call_fn_unary(out_reg, lhs_reg, float_round); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4 + ; mvn v6.b16, v6.b16 + + // Round, then convert back to f32 + ; fcvtas V(reg(out_reg)).s4, V(reg(lhs_reg)).s4 + ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/aarch64/interval.rs b/fidget/src/jit/aarch64/interval.rs index 650b73a1..a19179c0 100644 --- a/fidget/src/jit/aarch64/interval.rs +++ b/fidget/src/jit/aarch64/interval.rs @@ -293,22 +293,46 @@ impl Assembler for IntervalAssembler { // TODO hand-write these functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn interval_floor(v: Interval) -> Interval { - v.floor() - } - self.call_fn_unary(out_reg, lhs_reg, interval_floor); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2 + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtms V(reg(out_reg)).s2, V(reg(lhs_reg)).s2 + ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn interval_ceil(v: Interval) -> Interval { - v.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, interval_ceil); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2 + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtps V(reg(out_reg)).s2, V(reg(lhs_reg)).s2 + ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn interval_round(v: Interval) -> Interval { - v.round() - } - self.call_fn_unary(out_reg, lhs_reg, interval_round); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2 + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtas V(reg(out_reg)).s2, V(reg(lhs_reg)).s2 + ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2 + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/aarch64/point.rs b/fidget/src/jit/aarch64/point.rs index cb74d34a..dc6942a4 100644 --- a/fidget/src/jit/aarch64/point.rs +++ b/fidget/src/jit/aarch64/point.rs @@ -191,24 +191,47 @@ impl Assembler for PointAssembler { dynasm!(self.0.ops ; fmul S(reg(out_reg)), S(reg(lhs_reg)), S(reg(lhs_reg))) } - // TODO optimize these three functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_floor(f: f32) -> f32 { - f.floor() - } - self.call_fn_unary(out_reg, lhs_reg, float_floor); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg)) + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtms S(reg(out_reg)), S(reg(lhs_reg)) + ; scvtf S(reg(out_reg)), S(reg(out_reg)) + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_ceil(f: f32) -> f32 { - f.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, float_ceil); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg)) + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtps S(reg(out_reg)), S(reg(lhs_reg)) + ; scvtf S(reg(out_reg)), S(reg(out_reg)) + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "C" fn float_round(f: f32) -> f32 { - f.round() - } - self.call_fn_unary(out_reg, lhs_reg, float_round); + dynasm!(self.0.ops + // Build a NAN mask + ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg)) + ; mvn v6.b8, v6.b8 + + // Round, then convert back to f32 + ; fcvtas S(reg(out_reg)), S(reg(lhs_reg)) + ; scvtf S(reg(out_reg)), S(reg(out_reg)) + + // Apply the NAN mask + ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/x86_64/float_slice.rs b/fidget/src/jit/x86_64/float_slice.rs index 2fd6bd56..db7c59e8 100644 --- a/fidget/src/jit/x86_64/float_slice.rs +++ b/fidget/src/jit/x86_64/float_slice.rs @@ -194,24 +194,30 @@ impl Assembler for FloatSliceAssembler { ); } - // TODO optimize these three functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_floor(f: f32) -> f32 { - f.floor() - } - self.call_fn_unary(out_reg, lhs_reg, float_floor); + dynasm!(self.0.ops + ; vroundps Ry(reg(out_reg)), Ry(reg(lhs_reg)), 1 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_ceil(f: f32) -> f32 { - f.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, float_ceil); + dynasm!(self.0.ops + ; vroundps Ry(reg(out_reg)), Ry(reg(lhs_reg)), 2 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_round(f: f32) -> f32 { - f.round() - } - self.call_fn_unary(out_reg, lhs_reg, float_round); + // Shenanigans figured through Godbolt + dynasm!(self.0.ops + ; mov eax, 0x80000000u32 as i32 + ; vmovd xmm1, eax + ; vbroadcastss ymm1, xmm1 + ; vandps ymm1, ymm1, Ry(reg(lhs_reg)) + ; mov eax, 0x3effffffu32 as i32 + ; vmovd xmm2, eax + ; vbroadcastss ymm2, xmm2 + ; vorps ymm1, ymm1, ymm2 + ; vaddps Ry(reg(out_reg)), ymm1, Ry(reg(lhs_reg)) + ; vroundps Ry(reg(out_reg)), Ry(reg(out_reg)), 3 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/x86_64/grad_slice.rs b/fidget/src/jit/x86_64/grad_slice.rs index 4e1534e9..7b904f7f 100644 --- a/fidget/src/jit/x86_64/grad_slice.rs +++ b/fidget/src/jit/x86_64/grad_slice.rs @@ -262,24 +262,32 @@ impl Assembler for GradSliceAssembler { ); } - // TODO hand-write these functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn grad_floor(v: Grad) -> Grad { - v.floor() - } - self.call_fn_unary(out_reg, lhs_reg, grad_floor); + dynasm!(self.0.ops + ; vroundss xmm1, Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 1 + ; vpxor Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)) + ; movss Rx(reg(out_reg)), xmm1 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn grad_ceil(v: Grad) -> Grad { - v.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, grad_ceil); + dynasm!(self.0.ops + ; vroundss xmm1, Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 2 + ; vpxor Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)) + ; movss Rx(reg(out_reg)), xmm1 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn grad_round(v: Grad) -> Grad { - v.round() - } - self.call_fn_unary(out_reg, lhs_reg, grad_round); + // Shenanigans figured through Godbolt + dynasm!(self.0.ops + ; mov eax, 0x80000000u32 as i32 + ; vmovd xmm1, eax + ; vandps xmm1, xmm1, Rx(reg(lhs_reg)) + ; mov eax, 0x3effffffu32 as i32 + ; vmovd xmm2, eax + ; vorps xmm1, xmm1, xmm2 + ; vaddss Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg)) + ; vroundss Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)), 3 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/x86_64/interval.rs b/fidget/src/jit/x86_64/interval.rs index 65d16366..2a242f23 100644 --- a/fidget/src/jit/x86_64/interval.rs +++ b/fidget/src/jit/x86_64/interval.rs @@ -290,24 +290,30 @@ impl Assembler for IntervalAssembler { ); self.0.ops.commit_local().unwrap(); } - // TODO hand-write these functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn interval_floor(v: Interval) -> Interval { - v.floor() - } - self.call_fn_unary(out_reg, lhs_reg, interval_floor); + dynasm!(self.0.ops + ; vroundps Rx(reg(out_reg)), Rx(reg(lhs_reg)), 1 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn interval_ceil(v: Interval) -> Interval { - v.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, interval_ceil); + dynasm!(self.0.ops + ; vroundps Rx(reg(out_reg)), Rx(reg(lhs_reg)), 2 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn interval_round(v: Interval) -> Interval { - v.round() - } - self.call_fn_unary(out_reg, lhs_reg, interval_round); + // Shenanigans figured through Godbolt + dynasm!(self.0.ops + ; mov eax, 0x80000000u32 as i32 + ; vmovd xmm1, eax + ; vbroadcastss xmm1, xmm1 + ; vandps xmm1, xmm1, Rx(reg(lhs_reg)) + ; mov eax, 0x3effffffu32 as i32 + ; vmovd xmm2, eax + ; vbroadcastss xmm2, xmm2 + ; vorps xmm1, xmm1, xmm2 + ; vaddps Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg)) + ; vroundps Rx(reg(out_reg)), Rx(reg(out_reg)), 3 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) { diff --git a/fidget/src/jit/x86_64/point.rs b/fidget/src/jit/x86_64/point.rs index 35057360..21140cf8 100644 --- a/fidget/src/jit/x86_64/point.rs +++ b/fidget/src/jit/x86_64/point.rs @@ -178,24 +178,28 @@ impl Assembler for PointAssembler { ); } - // TODO optimize these three functions fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_floor(f: f32) -> f32 { - f.floor() - } - self.call_fn_unary(out_reg, lhs_reg, float_floor); + dynasm!(self.0.ops + ; vroundss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 1 + ); } fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_ceil(f: f32) -> f32 { - f.ceil() - } - self.call_fn_unary(out_reg, lhs_reg, float_ceil); + dynasm!(self.0.ops + ; vroundss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 2 + ); } fn build_round(&mut self, out_reg: u8, lhs_reg: u8) { - extern "sysv64" fn float_round(f: f32) -> f32 { - f.round() - } - self.call_fn_unary(out_reg, lhs_reg, float_round); + // Shenanigans figured through Godbolt + dynasm!(self.0.ops + ; mov eax, 0x80000000u32 as i32 + ; vmovd xmm1, eax + ; vandps xmm1, xmm1, Rx(reg(lhs_reg)) + ; mov eax, 0x3effffffu32 as i32 + ; vmovd xmm2, eax + ; vorps xmm1, xmm1, xmm2 + ; vaddss Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg)) + ; vroundss Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)), 3 + ); } fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {