From 02227a6be74c144e8a1c096f1ab43d7f0094f147 Mon Sep 17 00:00:00 2001 From: Alec Larsen Date: Tue, 28 May 2024 10:34:33 -0700 Subject: [PATCH] Add new Thunk-based runtime --- tailcall/src/lib.rs | 59 ++--------------------- tailcall/src/slot.rs | 36 ++++++++++++++ tailcall/src/thunk.rs | 46 ++++++++++++++++++ tailcall/src/trampoline.rs | 84 +++++++++------------------------ tailcall/tests/thunk_runtime.rs | 27 +++++++++++ 5 files changed, 135 insertions(+), 117 deletions(-) create mode 100644 tailcall/src/slot.rs create mode 100644 tailcall/src/thunk.rs create mode 100644 tailcall/tests/thunk_runtime.rs diff --git a/tailcall/src/lib.rs b/tailcall/src/lib.rs index 525e6dd..45937da 100644 --- a/tailcall/src/lib.rs +++ b/tailcall/src/lib.rs @@ -1,60 +1,7 @@ #![no_std] -#![deny( - missing_docs, - missing_debug_implementations, - missing_copy_implementations, - trivial_casts, - trivial_numeric_casts, - unsafe_code, - unstable_features, - unused_import_braces, - unused_qualifications -)] - -//! Tailcall is a library that adds safe, zero-cost [tail recursion] to stable Rust. -//! Eventually, it will be superseded by the [`become` keyword]. -//! -//! # Usage -//! -//! To guarantee that recursive calls a function will reuse the same stack frame, -//! annotate it with the [`tailcall`] attribute. -//! -//! ``` -//! use tailcall::tailcall; -//! -//! fn factorial(input: u64) -> u64 { -//! #[tailcall] -//! fn factorial_inner(accumulator: u64, input: u64) -> u64 { -//! if input > 0 { -//! factorial_inner(accumulator * input, input - 1) -//! } else { -//! accumulator -//! } -//! } -//! -//! factorial_inner(1, input) -//! } -//! ``` -//! -//! Recursive calls which are not in tail form will result in a compile-time error. -//! -//! ```compile_fail -//! use tailcall::tailcall; -//! -//! #[tailcall] -//! fn factorial(input: u64) -> u64 { -//! if input > 0 { -//! input * factorial(input - 1) -//! } else { -//! 1 -//! } -//! } -//! ``` -//! -//! [tail recursion]: https://en.wikipedia.org/wiki/Tail_call -//! [`become` keyword]: https://internals.rust-lang.org/t/pre-rfc-explicit-proper-tail-calls/3797/16 -//! [`tailcall`]: attr.tailcall.html pub use tailcall_impl::tailcall; -pub mod trampoline; +pub mod slot; +pub mod thunk; +pub mod trampoline; \ No newline at end of file diff --git a/tailcall/src/slot.rs b/tailcall/src/slot.rs new file mode 100644 index 0000000..a32bb39 --- /dev/null +++ b/tailcall/src/slot.rs @@ -0,0 +1,36 @@ +use core::mem::{align_of, size_of, MaybeUninit}; + +#[repr(C, align(128))] +pub struct Slot { + bytes: MaybeUninit<[u8; SIZE]>, +} + +impl Slot { + pub const fn new() -> Self { + Self { + bytes: MaybeUninit::uninit() + } + } + + pub unsafe fn take(in_slot: &mut T) -> (T, &mut Self) { + let in_slot: *mut T = in_slot; + debug_assert!((in_slot as usize) % align_of::() == 0); + + let slot: &mut Self = &mut *in_slot.cast(); + let value = slot.cast().assume_init_read(); + + (value, slot) + } + + pub fn put(&mut self, value: T) -> &mut T { + self.cast().write(value) + } + + fn cast(&mut self) -> &mut MaybeUninit { + debug_assert!(size_of::() <= SIZE); + debug_assert!(align_of::() <= align_of::()); + + // SAFETY: We just checked the size and alignment of T. + unsafe { &mut *self.bytes.as_mut_ptr().cast() } + } +} diff --git a/tailcall/src/thunk.rs b/tailcall/src/thunk.rs new file mode 100644 index 0000000..8545f62 --- /dev/null +++ b/tailcall/src/thunk.rs @@ -0,0 +1,46 @@ +use crate::slot::Slot; + +pub struct Thunk<'slot, T> { + ptr: &'slot mut dyn ThunkFn<'slot, T>, +} + +impl<'slot, T> Thunk<'slot, T> { + #[inline(always)] + pub fn new_in(slot: &'slot mut Slot, fn_once: F) -> Self + where F: FnOnce(&'slot mut Slot) -> T + 'slot + { + let ptr = slot.put(fn_once); + + Self { + ptr, + } + } + + #[inline(always)] + pub fn call(self) -> T { + let ptr: *mut dyn ThunkFn<'slot, T> = self.ptr; + core::mem::forget(self); + + unsafe { (*ptr).call_once_in_slot() } + } +} + +impl Drop for Thunk<'_, T> { + fn drop(&mut self) { + unsafe { core::ptr::drop_in_place(self.ptr) } + } +} + +trait ThunkFn<'slot, T>: FnOnce(&'slot mut Slot) -> T { + unsafe fn call_once_in_slot(&'slot mut self) -> T; +} + +impl<'slot, T, F> ThunkFn<'slot, T> for F + where F: FnOnce(&'slot mut Slot) -> T +{ + unsafe fn call_once_in_slot(&'slot mut self) -> T { + let (fn_once, slot) = Slot::take(self); + + fn_once(slot) + } +} \ No newline at end of file diff --git a/tailcall/src/trampoline.rs b/tailcall/src/trampoline.rs index 3ab4c91..68d89ae 100644 --- a/tailcall/src/trampoline.rs +++ b/tailcall/src/trampoline.rs @@ -1,71 +1,33 @@ -//! This module provides a simple, zero-cost [trampoline]. It is designed to be used by the -//! [`tailcall`] macro, but it can also be used manually. -//! -//! # Usage -//! -//! Express the contents of a recusive function as a step function (`Fn(Input) -> Next`). -//! To guarantee that only a single stack frame will be used at all levels of optimization, annotate it -//! with `#[inline(always)]` attribute. This step function and an initial input can then be passed to -//! [`run`] which will recusively call it until it resolves to an output. -//! -//! ``` -//! // fn gcd(a: u64, b: u64) -> u64 { -//! // if b == 0 { -//! // a -//! // } else { -//! // gcd(b, a % b) -//! // } -//! // } -//! -//! #[inline(always)] -//! fn gcd_step((a, b): (u64, u64)) -> tailcall::trampoline::Next<(u64, u64), u64> { -//! if b == 0 { -//! tailcall::trampoline::Finish(a) -//! } else { -//! tailcall::trampoline::Recurse((b, a % b)) -//! } -//! } -//! -//! fn gcd(a: u64, b: u64) -> u64 { -//! -//! tailcall::trampoline::run(gcd_step, (a, b)) -//! } -//! ``` -//! -//! [trampoline]: https://en.wikipedia.org/wiki/Tail_call#Through_trampolining -//! [`tailcall`]: ../tailcall_impl/attr.tailcall.html -//! [`run`]: fn.run.html -//! +use crate::slot::Slot; +use crate::thunk::Thunk; -/// This is the output of the step function. It indicates to [run] what should happen next. -/// -/// [run]: fn.run.html -#[derive(Debug)] -pub enum Next { - /// This variant indicates that the step function should be run again with the provided input. - Recurse(Input), - - /// This variant indicates that there are no more steps to be taken and the provided output should be returned. - Finish(Output), +pub enum Action<'slot, T> { + Done(T), + Call(Thunk<'slot, Self>), } -pub use Next::*; +#[inline(always)] +pub fn done(_slot: &mut Slot, value: T) -> Action { + Action::Done(value) +} -/// Runs a step function aginast a particular input until it resolves to an output. #[inline(always)] -pub fn run(step: StepFn, mut input: Input) -> Output -where - StepFn: Fn(Input) -> Next, +pub fn call<'slot, T, F>(slot: &'slot mut Slot, fn_once: F) -> Action<'slot, T> + where F: FnOnce(&'slot mut Slot) -> Action<'slot, T> + 'slot { + Action::Call(Thunk::new_in(slot, fn_once)) +} + +#[inline(always)] +pub fn run(build_action: impl FnOnce(&mut Slot) -> Action) -> T { + let slot = &mut Slot::new(); + + let mut action = build_action(slot); + loop { - match step(input) { - Recurse(new_input) => { - input = new_input; - continue; - } - Finish(output) => { - break output; - } + match action { + Action::Done(value) => return value, + Action::Call(thunk) => action = thunk.call(), } } } diff --git a/tailcall/tests/thunk_runtime.rs b/tailcall/tests/thunk_runtime.rs new file mode 100644 index 0000000..3170dc6 --- /dev/null +++ b/tailcall/tests/thunk_runtime.rs @@ -0,0 +1,27 @@ +use tailcall::{slot, trampoline}; + +#[test] +fn factorial_in_new_runtime() { + assert!(factorial(5) == 120); +} + +fn factorial(input: u64) -> u64 { + #[inline(always)] + fn call_factorial_inner<'slot>(slot: &'slot mut slot::Slot, accumulator: u64, input: u64) -> trampoline::Action<'slot, u64> { + trampoline::call(slot, move |slot| { + if input == 0 { + return trampoline::done(slot, accumulator); + } + + return call_factorial_inner(slot, accumulator * input, input - 1); + }) + } + + fn factorial_inner(accumulator: u64, input: u64) -> u64 { + trampoline::run(move |slot| { + call_factorial_inner(slot, accumulator, input) + }) + } + + factorial_inner(1, input) +}