From 53bff1281ee8582fd8cfbefca585a8074ba70bc3 Mon Sep 17 00:00:00 2001 From: Alec Larsen Date: Thu, 13 Jun 2024 19:15:07 -0700 Subject: [PATCH] Handle drop --- tailcall/src/lib.rs | 1 + tailcall/src/slot.rs | 69 ++++++++++++++++++++ tailcall/src/thunk.rs | 141 ++++++++++++++-------------------------- tailcall/tests/thunk.rs | 26 ++++---- 4 files changed, 133 insertions(+), 104 deletions(-) create mode 100644 tailcall/src/slot.rs diff --git a/tailcall/src/lib.rs b/tailcall/src/lib.rs index c1dc579..95dfd44 100644 --- a/tailcall/src/lib.rs +++ b/tailcall/src/lib.rs @@ -10,5 +10,6 @@ pub use tailcall_impl::tailcall; +pub(crate) mod slot; pub mod thunk; pub mod trampoline; diff --git a/tailcall/src/slot.rs b/tailcall/src/slot.rs new file mode 100644 index 0000000..0060b74 --- /dev/null +++ b/tailcall/src/slot.rs @@ -0,0 +1,69 @@ +use core::mem::{align_of, size_of, ManuallyDrop, MaybeUninit}; + +#[repr(C, align(16))] +pub struct Slot { + bytes: MaybeUninit<[u8; SIZE]>, +} + +#[repr(C)] +union SlotView { + value: ManuallyDrop, + slot: ManuallyDrop>, +} + +impl Slot { + pub const fn uninit() -> Self { + Self { + bytes: MaybeUninit::uninit(), + } + } + + pub const fn new(value: T) -> Self { + assert!( + align_of::() <= align_of::(), + "unsupport value alignment", + ); + + assert!( + size_of::() <= size_of::(), + "value size exceeds slot capacity", + ); + + SlotView::of_value(value).into_slot() + } + + // SAFETY: The caller must ensure that `self` contains a valid `T`. + pub const unsafe fn into_value(self) -> T { + unsafe { SlotView::of_slot(self).into_value() } + } +} + +impl Default for Slot { + fn default() -> Self { + Self::uninit() + } +} + +impl SlotView { + const fn of_value(value: T) -> Self { + Self { + value: ManuallyDrop::new(value), + } + } + + const fn of_slot(slot: Slot) -> Self { + Self { + slot: ManuallyDrop::new(slot), + } + } + + const fn into_slot(self) -> Slot { + // SAFETY: `Slot` is valid at all bit patterns. + ManuallyDrop::into_inner(unsafe { self.slot }) + } + + // SAFETY: The caller must ensure that `self` contains a valid `T`. + const unsafe fn into_value(self) -> T { + ManuallyDrop::into_inner(unsafe { self.value }) + } +} diff --git a/tailcall/src/thunk.rs b/tailcall/src/thunk.rs index d15147d..1cb6703 100644 --- a/tailcall/src/thunk.rs +++ b/tailcall/src/thunk.rs @@ -1,122 +1,81 @@ -use core::marker::PhantomData; +use crate::slot::Slot; +use core::{marker::PhantomData, mem::transmute, ptr::drop_in_place}; -mod slot { - use core::mem::{align_of, size_of, ManuallyDrop, MaybeUninit}; +const MAX_THUNK_DATA_SIZE: usize = 48; - #[repr(C, align(16))] - pub struct Slot { - bytes: MaybeUninit<[u8; SIZE]>, - } - - union SlotView { - value: ManuallyDrop, - slot: ManuallyDrop>, - } - - impl Slot { - pub const fn new(value: T) -> Self { - Self::assert_valid_at::(); - - SlotView::of_value(value).into_slot() - } - - // SAFETY: The caller must ensure that `self` contains a valid `T`. - pub const unsafe fn into_inner(self) -> T { - Self::assert_valid_at::(); - - unsafe { SlotView::of_slot(self).into_value() } - } - - const fn assert_valid_at() { - assert!(size_of::() <= size_of::()); - assert!(align_of::() <= align_of::()); - } - } - - impl SlotView { - const fn of_value(value: T) -> Self { - Self { - value: ManuallyDrop::new(value), - } - } +type ThunkSlot = Slot; +type CallFn = fn(ThunkSlot) -> T; +type DropInPlaceFn = unsafe fn(*mut ThunkSlot); - const fn into_slot(self) -> Slot { - // SAFETY: `Slot` is valid at all bit patterns. - ManuallyDrop::into_inner(unsafe { self.slot }) - } - - const fn of_slot(slot: Slot) -> Self { - Self { - slot: ManuallyDrop::new(slot), - } - } - - // SAFETY: The caller must ensure that `self` contains a valid `T`. - const unsafe fn into_value(self) -> T { - ManuallyDrop::into_inner(unsafe { self.value }) - } - } +#[repr(transparent)] +pub struct Thunk<'a, T = ()> { + inner: ThunkInner<'a, T>, } -mod guard { - use core::mem::forget; - - pub struct Gurad(); +struct ThunkInner<'a, T> { + slot: ThunkSlot, + call_fn: CallFn, + drop_in_place_fn: DropInPlaceFn, + _marker: PhantomData T + 'a>, +} - impl Gurad { - pub const fn new() -> Self { - Self() +impl<'a, T> Thunk<'a, T> { + pub const fn new(fn_once: F) -> Self + where + F: FnOnce() -> T + 'a, + { + Self { + inner: ThunkInner::new(fn_once), } + } - pub const fn disarm(self) { - forget(self) - } + #[inline(always)] + pub fn call(self) -> T { + self.into_inner().call() } - impl Drop for Gurad { - fn drop(&mut self) { - unreachable!() - } + const fn into_inner(self) -> ThunkInner<'a, T> { + // SAFETY: `Thunk` is a transparent wrapper around `ThunkInner`. + unsafe { transmute(self) } } } -use guard::Gurad; -use slot::Slot; - -const SLOT_SIZE: usize = 48; - -#[must_use] -pub struct Thunk<'a, T = ()> { - guard: Gurad, - slot: Slot, - call_impl: fn(Slot) -> T, - _marker: PhantomData T + 'a>, +impl<'a, T> Drop for Thunk<'a, T> { + fn drop(&mut self) { + // SAFETY: We own `inner`, and it cannot be used after dropping. + unsafe { self.inner.drop_in_place() } + } } -impl<'a, T> Thunk<'a, T> { +impl<'a, T> ThunkInner<'a, T> { pub const fn new(fn_once: F) -> Self where F: FnOnce() -> T + 'a, { Self { - guard: Gurad::new(), slot: Slot::new(fn_once), - call_impl: |slot| unsafe { slot.into_inner::()() }, + call_fn: |slot| { + // SAFETY: `slot` is initialized above with `F`. + unsafe { slot.into_value::()() } + }, + drop_in_place_fn: |slot_ptr| { + // SAFETY: `slot` is initialized above with `F`. + unsafe { drop_in_place(slot_ptr.cast::()) }; + }, _marker: PhantomData, } } #[inline(always)] pub fn call(self) -> T { - let Self { - call_impl, - slot, - guard, - _marker, - } = self; + let Self { slot, call_fn, .. } = self; - guard.disarm(); + call_fn(slot) + } - call_impl(slot) + // SAFETY: `Self::call` cannot be called after dropping in place. + #[inline(always)] + pub unsafe fn drop_in_place(&mut self) { + unsafe { (self.drop_in_place_fn)(&mut self.slot) } } } diff --git a/tailcall/tests/thunk.rs b/tailcall/tests/thunk.rs index 63cd506..69970ce 100644 --- a/tailcall/tests/thunk.rs +++ b/tailcall/tests/thunk.rs @@ -16,17 +16,17 @@ fn with_captures() { assert_eq!(thunk.call(), x + y); } -// #[test] -// fn with_too_many_captures() { -// let a = 1; -// let b = 2; -// let c = 3; -// let d = 4; -// let e = 5; -// let f = 6; -// let g = 7; - -// let thunk = Thunk::new(move || a + b + c + d + e + f + g); +#[test] +#[should_panic] +fn with_too_many_captures() { + let a = 1usize; + let b = 2usize; + let c = 3usize; + let d = 4usize; + let e = 5usize; + let f = 6usize; + let g = 7usize; + let h = 8usize; -// assert_eq!(thunk.call(), a + b + c + d + e + f + g,); -// } + Thunk::new(move || a + b + c + d + e + f + g + h); +}