
Assembly implementations of floor / ceil / round (#104)
gotta go fast
mkeeter authored May 12, 2024
1 parent 4cc5f6a commit 76ae930
Showing 8 changed files with 186 additions and 90 deletions.
2 changes: 2 additions & 0 deletions fidget/src/core/eval/test/mod.rs
@@ -41,6 +41,8 @@ fn test_args_n(n: i64) -> Vec<f32> {
         .collect::<Vec<_>>();
     args.push(1.0);
     args.push(5.0);
+    args.push(0.5);
+    args.push(1.5);
     args.push(10.0);
     args.push(std::f32::consts::PI);
     args.push(std::f32::consts::FRAC_PI_2);
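The two new half-integer arguments pin down `round`'s tie-breaking, the case where a hand-written lowering is most likely to drift from Rust semantics: `f32::round` breaks ties away from zero. A minimal sketch of the values these arguments should produce (illustrative only, not part of the commit):

    // Rust's f32::round breaks ties away from zero.
    fn main() {
        assert_eq!(0.5f32.round(), 1.0);
        assert_eq!(1.5f32.round(), 2.0);
        assert_eq!((-1.5f32).round(), -2.0);
        assert_eq!(1.5f32.floor(), 1.0);
        assert_eq!(0.5f32.ceil(), 1.0);
    }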
49 changes: 36 additions & 13 deletions fidget/src/jit/aarch64/float_slice.rs
@@ -237,24 +237,47 @@ impl Assembler for FloatSliceAssembler {
         )
     }
 
-    // TODO optimize these three functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_floor(f: f32) -> f32 {
-            f.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_floor);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
+            ; mvn v6.b16, v6.b16
+
+            // Round toward -inf to an integer, then convert back to f32
+            ; fcvtms V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
+            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_ceil(f: f32) -> f32 {
-            f.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_ceil);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
+            ; mvn v6.b16, v6.b16
+
+            // Round toward +inf to an integer, then convert back to f32
+            ; fcvtps V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
+            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_round(f: f32) -> f32 {
-            f.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_round);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
+            ; mvn v6.b16, v6.b16
+
+            // Round to nearest (ties away from zero), then convert back to f32
+            ; fcvtas V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
+            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
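The new aarch64 lowering replaces an out-of-line function call with four vector instructions: `fcvtms`, `fcvtps`, and `fcvtas` convert each f32 lane to a signed integer using the rounding mode baked into the instruction (toward -inf for floor, toward +inf for ceil, to nearest with ties away from zero for round), and `scvtf` converts back to f32. The integer conversion collapses NaN to 0, so a mask built by `fcmeq` (each lane compared with itself, i.e. "not NaN") and `mvn` (bitwise invert) is OR'd into the result: an all-ones lane is a NaN bit pattern, so NaN inputs stay NaN. A scalar model of the floor sequence, glossing over saturation at the i32 range (my illustrative sketch, not the committed code):

    // Models one f32 lane of the fcmeq/mvn + fcvtms/scvtf + orr sequence.
    fn floor_lane(x: f32) -> f32 {
        // fcmeq + mvn: all-ones where x is NaN, all-zeros elsewhere
        let nan_mask: u32 = if x.is_nan() { u32::MAX } else { 0 };
        // fcvtms: convert to i32, rounding toward -inf (NaN becomes 0)
        let i: i32 = if x.is_nan() { 0 } else { x.floor() as i32 };
        // scvtf: signed integer back to f32
        let rounded = i as f32;
        // orr: OR-ing an all-ones lane into the bits yields a NaN
        f32::from_bits(rounded.to_bits() | nan_mask)
    }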
48 changes: 36 additions & 12 deletions fidget/src/jit/aarch64/interval.rs
@@ -293,22 +293,46 @@ impl Assembler for IntervalAssembler {

     // TODO hand-write these functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn interval_floor(v: Interval) -> Interval {
-            v.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_floor);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2
+            ; mvn v6.b8, v6.b8
+
+            // Round toward -inf to an integer, then convert back to f32
+            ; fcvtms V(reg(out_reg)).s2, V(reg(lhs_reg)).s2
+            ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn interval_ceil(v: Interval) -> Interval {
-            v.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_ceil);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2
+            ; mvn v6.b8, v6.b8
+
+            // Round toward +inf to an integer, then convert back to f32
+            ; fcvtps V(reg(out_reg)).s2, V(reg(lhs_reg)).s2
+            ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn interval_round(v: Interval) -> Interval {
-            v.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_round);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq v6.s2, V(reg(lhs_reg)).s2, V(reg(lhs_reg)).s2
+            ; mvn v6.b8, v6.b8
+
+            // Round to nearest (ties away from zero), then convert back to f32
+            ; fcvtas V(reg(out_reg)).s2, V(reg(lhs_reg)).s2
+            ; scvtf V(reg(out_reg)).s2, V(reg(out_reg)).s2
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
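The interval version is the same sequence on a two-lane vector (`.s2`/`.b8`), one lane per bound. Applying floor, ceil, or round to the endpoints is sound because all three functions are monotonic (non-decreasing), so the image of [lo, hi] is contained in [f(lo), f(hi)]. Sketched in scalar Rust, assuming `Interval` packs its lower and upper bounds into the two lanes (my assumption, for illustration):

    #[derive(Copy, Clone, Debug)]
    struct Interval {
        lower: f32,
        upper: f32,
    }

    // floor is non-decreasing, so rounding each endpoint yields a
    // valid enclosure of floor over the whole interval.
    fn interval_floor(i: Interval) -> Interval {
        Interval {
            lower: i.lower.floor(),
            upper: i.upper.floor(),
        }
    }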
49 changes: 36 additions & 13 deletions fidget/src/jit/aarch64/point.rs
@@ -191,24 +191,47 @@ impl Assembler for PointAssembler {
         dynasm!(self.0.ops ; fmul S(reg(out_reg)), S(reg(lhs_reg)), S(reg(lhs_reg)))
     }
 
-    // TODO optimize these three functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_floor(f: f32) -> f32 {
-            f.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_floor);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
+            ; mvn v6.b8, v6.b8
+
+            // Round toward -inf to an integer, then convert back to f32
+            ; fcvtms S(reg(out_reg)), S(reg(lhs_reg))
+            ; scvtf S(reg(out_reg)), S(reg(out_reg))
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_ceil(f: f32) -> f32 {
-            f.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_ceil);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
+            ; mvn v6.b8, v6.b8
+
+            // Round toward +inf to an integer, then convert back to f32
+            ; fcvtps S(reg(out_reg)), S(reg(lhs_reg))
+            ; scvtf S(reg(out_reg)), S(reg(out_reg))
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "C" fn float_round(f: f32) -> f32 {
-            f.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_round);
+        dynasm!(self.0.ops
+            // Build a NaN mask
+            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
+            ; mvn v6.b8, v6.b8
+
+            // Round to nearest (ties away from zero), then convert back to f32
+            ; fcvtas S(reg(out_reg)), S(reg(lhs_reg))
+            ; scvtf S(reg(out_reg)), S(reg(out_reg))
+
+            // Apply the NaN mask
+            ; orr V(reg(out_reg)).B8, V(reg(out_reg)).B8, v6.b8
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
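One quirk of the scalar variant above: `fcmeq s6, ...` writes only the low 32-bit lane (and, since scalar writes zero-extend into the full vector register, clears the rest of v6), so after `mvn` the upper lanes of v6 are all-ones and the final `orr` smears them into the upper lanes of the output register. That is presumably harmless here, since point evaluation only ever reads the low lane of each register.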
32 changes: 19 additions & 13 deletions fidget/src/jit/x86_64/float_slice.rs
@@ -194,24 +194,30 @@ impl Assembler for FloatSliceAssembler {
         );
     }
 
-    // TODO optimize these three functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_floor(f: f32) -> f32 {
-            f.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_floor);
+        dynasm!(self.0.ops
+            ; vroundps Ry(reg(out_reg)), Ry(reg(lhs_reg)), 1
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_ceil(f: f32) -> f32 {
-            f.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_ceil);
+        dynasm!(self.0.ops
+            ; vroundps Ry(reg(out_reg)), Ry(reg(lhs_reg)), 2
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_round(f: f32) -> f32 {
-            f.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_round);
+        // Shenanigans figured through Godbolt
+        dynasm!(self.0.ops
+            ; mov eax, 0x80000000u32 as i32
+            ; vmovd xmm1, eax
+            ; vbroadcastss ymm1, xmm1
+            ; vandps ymm1, ymm1, Ry(reg(lhs_reg))
+            ; mov eax, 0x3effffffu32 as i32
+            ; vmovd xmm2, eax
+            ; vbroadcastss ymm2, xmm2
+            ; vorps ymm1, ymm1, ymm2
+            ; vaddps Ry(reg(out_reg)), ymm1, Ry(reg(lhs_reg))
+            ; vroundps Ry(reg(out_reg)), Ry(reg(out_reg)), 3
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
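On x86_64, floor and ceil are single instructions: `vroundps` with immediate 1 rounds toward -inf and immediate 2 toward +inf, while immediate 3 (used below) truncates toward zero. There is no immediate for round-half-away-from-zero, so `build_round` synthesizes it: `0x80000000` masks out each lane's sign bit, `0x3effffff` is the largest f32 strictly below 0.5 (about 0.49999997), and ORing the two yields copysign(0.49999997, x); adding that to x and truncating rounds halves away from zero. Using exactly 0.5 would be wrong: for x = 0.49999997, x + 0.5 rounds up to 1.0 in f32 arithmetic, and truncation would return 1 instead of 0. A scalar model of one lane (my sketch, not the committed code):

    // One lane of the vandps/vorps/vaddps/vroundps(imm=3) sequence.
    fn round_lane(x: f32) -> f32 {
        let sign = x.to_bits() & 0x8000_0000; // vandps: keep only the sign bit
        let nudge = f32::from_bits(sign | 0x3eff_ffff); // vorps: copysign(0.49999997, x)
        (x + nudge).trunc() // vaddps, then vroundps with imm 3 (truncate)
    }

For x = 2.5 this returns 3.0, matching `f32::round`'s ties-away behavior, and for x = 0.49999997 it correctly returns 0.0.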
34 changes: 21 additions & 13 deletions fidget/src/jit/x86_64/grad_slice.rs
@@ -262,24 +262,32 @@ impl Assembler for GradSliceAssembler {
         );
     }
 
-    // TODO hand-write these functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn grad_floor(v: Grad) -> Grad {
-            v.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, grad_floor);
+        dynasm!(self.0.ops
+            ; vroundss xmm1, Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 1
+            ; vpxor Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg))
+            ; movss Rx(reg(out_reg)), xmm1
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn grad_ceil(v: Grad) -> Grad {
-            v.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, grad_ceil);
+        dynasm!(self.0.ops
+            ; vroundss xmm1, Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 2
+            ; vpxor Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg))
+            ; movss Rx(reg(out_reg)), xmm1
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn grad_round(v: Grad) -> Grad {
-            v.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, grad_round);
+        // Shenanigans figured through Godbolt
+        dynasm!(self.0.ops
+            ; mov eax, 0x80000000u32 as i32
+            ; vmovd xmm1, eax
+            ; vandps xmm1, xmm1, Rx(reg(lhs_reg))
+            ; mov eax, 0x3effffffu32 as i32
+            ; vmovd xmm2, eax
+            ; vorps xmm1, xmm1, xmm2
+            ; vaddss Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg))
+            ; vroundss Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)), 3
+        );
    }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
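Floor, ceil, and round are piecewise constant, so a `Grad` result carries zero partial derivatives (the derivative is undefined exactly at the jumps). In floor and ceil, `vpxor` zeroes the output register and `movss` writes the rounded value into the low lane; in `build_round` no explicit zeroing is needed, because `vmovd` clears the upper lanes of xmm1 and those zeros propagate through `vandps`/`vorps`/`vaddss`/`vroundss`. A sketch of the semantics, assuming `Grad` packs a value and three partials into one xmm register (my assumed layout, for illustration):

    // Assumed layout: [value, dx, dy, dz] in one 128-bit register.
    #[derive(Copy, Clone, Debug)]
    struct Grad {
        v: f32,
        dx: f32,
        dy: f32,
        dz: f32,
    }

    // floor is flat between integers, so the partials are zero
    // almost everywhere.
    fn grad_floor(g: Grad) -> Grad {
        Grad { v: g.v.floor(), dx: 0.0, dy: 0.0, dz: 0.0 }
    }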
32 changes: 19 additions & 13 deletions fidget/src/jit/x86_64/interval.rs
@@ -290,24 +290,30 @@ impl Assembler for IntervalAssembler {
         );
         self.0.ops.commit_local().unwrap();
     }
-    // TODO hand-write these functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn interval_floor(v: Interval) -> Interval {
-            v.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_floor);
+        dynasm!(self.0.ops
+            ; vroundps Rx(reg(out_reg)), Rx(reg(lhs_reg)), 1
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn interval_ceil(v: Interval) -> Interval {
-            v.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_ceil);
+        dynasm!(self.0.ops
+            ; vroundps Rx(reg(out_reg)), Rx(reg(lhs_reg)), 2
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn interval_round(v: Interval) -> Interval {
-            v.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, interval_round);
+        // Shenanigans figured through Godbolt
+        dynasm!(self.0.ops
+            ; mov eax, 0x80000000u32 as i32
+            ; vmovd xmm1, eax
+            ; vbroadcastss xmm1, xmm1
+            ; vandps xmm1, xmm1, Rx(reg(lhs_reg))
+            ; mov eax, 0x3effffffu32 as i32
+            ; vmovd xmm2, eax
+            ; vbroadcastss xmm2, xmm2
+            ; vorps xmm1, xmm1, xmm2
+            ; vaddps Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg))
+            ; vroundps Rx(reg(out_reg)), Rx(reg(out_reg)), 3
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
30 changes: 17 additions & 13 deletions fidget/src/jit/x86_64/point.rs
@@ -178,24 +178,28 @@ impl Assembler for PointAssembler {
         );
     }
 
-    // TODO optimize these three functions
     fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_floor(f: f32) -> f32 {
-            f.floor()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_floor);
+        dynasm!(self.0.ops
+            ; vroundss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 1
+        );
     }
     fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_ceil(f: f32) -> f32 {
-            f.ceil()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_ceil);
+        dynasm!(self.0.ops
+            ; vroundss Rx(reg(out_reg)), Rx(reg(lhs_reg)), Rx(reg(lhs_reg)), 2
+        );
     }
     fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
-        extern "sysv64" fn float_round(f: f32) -> f32 {
-            f.round()
-        }
-        self.call_fn_unary(out_reg, lhs_reg, float_round);
+        // Shenanigans figured through Godbolt
+        dynasm!(self.0.ops
+            ; mov eax, 0x80000000u32 as i32
+            ; vmovd xmm1, eax
+            ; vandps xmm1, xmm1, Rx(reg(lhs_reg))
+            ; mov eax, 0x3effffffu32 as i32
+            ; vmovd xmm2, eax
+            ; vorps xmm1, xmm1, xmm2
+            ; vaddss Rx(reg(out_reg)), xmm1, Rx(reg(lhs_reg))
+            ; vroundss Rx(reg(out_reg)), Rx(reg(out_reg)), Rx(reg(out_reg)), 3
+        );
     }
 
     fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
