Skip to content

Commit

Permalink
Add new Thunk-based runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
alecdotninja committed May 28, 2024
1 parent cdcaa97 commit 02227a6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 117 deletions.
59 changes: 3 additions & 56 deletions tailcall/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,60 +1,7 @@
#![no_std]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications
)]

//! Tailcall is a library that adds safe, zero-cost [tail recursion] to stable Rust.
//! Eventually, it will be superseded by the [`become` keyword].
//!
//! # Usage
//!
//! To guarantee that recursive calls a function will reuse the same stack frame,
//! annotate it with the [`tailcall`] attribute.
//!
//! ```
//! use tailcall::tailcall;
//!
//! fn factorial(input: u64) -> u64 {
//! #[tailcall]
//! fn factorial_inner(accumulator: u64, input: u64) -> u64 {
//! if input > 0 {
//! factorial_inner(accumulator * input, input - 1)
//! } else {
//! accumulator
//! }
//! }
//!
//! factorial_inner(1, input)
//! }
//! ```
//!
//! Recursive calls which are not in tail form will result in a compile-time error.
//!
//! ```compile_fail
//! use tailcall::tailcall;
//!
//! #[tailcall]
//! fn factorial(input: u64) -> u64 {
//! if input > 0 {
//! input * factorial(input - 1)
//! } else {
//! 1
//! }
//! }
//! ```
//!
//! [tail recursion]: https://en.wikipedia.org/wiki/Tail_call
//! [`become` keyword]: https://internals.rust-lang.org/t/pre-rfc-explicit-proper-tail-calls/3797/16
//! [`tailcall`]: attr.tailcall.html

pub use tailcall_impl::tailcall;

pub mod trampoline;
pub mod slot;
pub mod thunk;
pub mod trampoline;
36 changes: 36 additions & 0 deletions tailcall/src/slot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use core::mem::{align_of, size_of, MaybeUninit};

#[repr(C, align(128))]
pub struct Slot<const SIZE: usize = 128> {
bytes: MaybeUninit<[u8; SIZE]>,
}

impl<const SIZE: usize> Slot<SIZE> {
pub const fn new() -> Self {
Self {
bytes: MaybeUninit::uninit()
}
}

pub unsafe fn take<T>(in_slot: &mut T) -> (T, &mut Self) {
let in_slot: *mut T = in_slot;
debug_assert!((in_slot as usize) % align_of::<Self>() == 0);

let slot: &mut Self = &mut *in_slot.cast();
let value = slot.cast().assume_init_read();

(value, slot)
}

pub fn put<T>(&mut self, value: T) -> &mut T {
self.cast().write(value)
}

fn cast<T>(&mut self) -> &mut MaybeUninit<T> {
debug_assert!(size_of::<T>() <= SIZE);
debug_assert!(align_of::<T>() <= align_of::<Self>());

// SAFETY: We just checked the size and alignment of T.
unsafe { &mut *self.bytes.as_mut_ptr().cast() }
}
}
46 changes: 46 additions & 0 deletions tailcall/src/thunk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::slot::Slot;

pub struct Thunk<'slot, T> {
ptr: &'slot mut dyn ThunkFn<'slot, T>,
}

impl<'slot, T> Thunk<'slot, T> {
#[inline(always)]
pub fn new_in<F>(slot: &'slot mut Slot, fn_once: F) -> Self
where F: FnOnce(&'slot mut Slot) -> T + 'slot
{
let ptr = slot.put(fn_once);

Self {
ptr,
}
}

#[inline(always)]
pub fn call(self) -> T {
let ptr: *mut dyn ThunkFn<'slot, T> = self.ptr;
core::mem::forget(self);

unsafe { (*ptr).call_once_in_slot() }
}
}

impl<T> Drop for Thunk<'_, T> {
fn drop(&mut self) {
unsafe { core::ptr::drop_in_place(self.ptr) }
}
}

trait ThunkFn<'slot, T>: FnOnce(&'slot mut Slot) -> T {
unsafe fn call_once_in_slot(&'slot mut self) -> T;
}

impl<'slot, T, F> ThunkFn<'slot, T> for F
where F: FnOnce(&'slot mut Slot) -> T
{
unsafe fn call_once_in_slot(&'slot mut self) -> T {
let (fn_once, slot) = Slot::take(self);

fn_once(slot)
}
}
84 changes: 23 additions & 61 deletions tailcall/src/trampoline.rs
Original file line number Diff line number Diff line change
@@ -1,71 +1,33 @@
//! This module provides a simple, zero-cost [trampoline]. It is designed to be used by the
//! [`tailcall`] macro, but it can also be used manually.
//!
//! # Usage
//!
//! Express the contents of a recusive function as a step function (`Fn(Input) -> Next<Input, Output>`).
//! To guarantee that only a single stack frame will be used at all levels of optimization, annotate it
//! with `#[inline(always)]` attribute. This step function and an initial input can then be passed to
//! [`run`] which will recusively call it until it resolves to an output.
//!
//! ```
//! // fn gcd(a: u64, b: u64) -> u64 {
//! // if b == 0 {
//! // a
//! // } else {
//! // gcd(b, a % b)
//! // }
//! // }
//!
//! #[inline(always)]
//! fn gcd_step((a, b): (u64, u64)) -> tailcall::trampoline::Next<(u64, u64), u64> {
//! if b == 0 {
//! tailcall::trampoline::Finish(a)
//! } else {
//! tailcall::trampoline::Recurse((b, a % b))
//! }
//! }
//!
//! fn gcd(a: u64, b: u64) -> u64 {
//!
//! tailcall::trampoline::run(gcd_step, (a, b))
//! }
//! ```
//!
//! [trampoline]: https://en.wikipedia.org/wiki/Tail_call#Through_trampolining
//! [`tailcall`]: ../tailcall_impl/attr.tailcall.html
//! [`run`]: fn.run.html
//!
use crate::slot::Slot;
use crate::thunk::Thunk;

/// This is the output of the step function. It indicates to [run] what should happen next.
///
/// [run]: fn.run.html
#[derive(Debug)]
pub enum Next<Input, Output> {
/// This variant indicates that the step function should be run again with the provided input.
Recurse(Input),

/// This variant indicates that there are no more steps to be taken and the provided output should be returned.
Finish(Output),
pub enum Action<'slot, T> {
Done(T),
Call(Thunk<'slot, Self>),
}

pub use Next::*;
#[inline(always)]
pub fn done<T>(_slot: &mut Slot, value: T) -> Action<T> {
Action::Done(value)
}

/// Runs a step function aginast a particular input until it resolves to an output.
#[inline(always)]
pub fn run<StepFn, Input, Output>(step: StepFn, mut input: Input) -> Output
where
StepFn: Fn(Input) -> Next<Input, Output>,
pub fn call<'slot, T, F>(slot: &'slot mut Slot, fn_once: F) -> Action<'slot, T>
where F: FnOnce(&'slot mut Slot) -> Action<'slot, T> + 'slot
{
Action::Call(Thunk::new_in(slot, fn_once))
}

#[inline(always)]
pub fn run<T>(build_action: impl FnOnce(&mut Slot) -> Action<T>) -> T {
let slot = &mut Slot::new();

let mut action = build_action(slot);

loop {
match step(input) {
Recurse(new_input) => {
input = new_input;
continue;
}
Finish(output) => {
break output;
}
match action {
Action::Done(value) => return value,
Action::Call(thunk) => action = thunk.call(),
}
}
}
27 changes: 27 additions & 0 deletions tailcall/tests/thunk_runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use tailcall::{slot, trampoline};

#[test]
fn factorial_in_new_runtime() {
assert!(factorial(5) == 120);
}

fn factorial(input: u64) -> u64 {
#[inline(always)]
fn call_factorial_inner<'slot>(slot: &'slot mut slot::Slot, accumulator: u64, input: u64) -> trampoline::Action<'slot, u64> {
trampoline::call(slot, move |slot| {
if input == 0 {
return trampoline::done(slot, accumulator);
}

return call_factorial_inner(slot, accumulator * input, input - 1);
})
}

fn factorial_inner(accumulator: u64, input: u64) -> u64 {
trampoline::run(move |slot| {
call_factorial_inner(slot, accumulator, input)
})
}

factorial_inner(1, input)
}

0 comments on commit 02227a6

Please sign in to comment.