Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement thread creation deletion event callback. #506

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ pub use crate::thread::{Thread, ThreadStatus};
pub use crate::traits::{
FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike,
};
#[cfg(feature = "luau")]
pub use crate::types::ThreadEventInfo;
pub use crate::types::{
AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
};
Expand Down
9 changes: 5 additions & 4 deletions src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ pub use crate::{
#[doc(no_inline)]
pub use crate::HookTriggers as LuaHookTriggers;

#[cfg(feature = "luau")]
#[doc(no_inline)]
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};

#[cfg(feature = "async")]
#[doc(no_inline)]
pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn};
#[cfg(feature = "luau")]
#[doc(no_inline)]
pub use crate::{
CoverageInfo as LuaCoverageInfo, ThreadEventInfo as LuaThreadEventInfo, Vector as LuaVector,
};

#[cfg(feature = "serialize")]
#[doc(no_inline)]
Expand Down
70 changes: 70 additions & 0 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ use crate::types::{
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
};

#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
use crate::types::ThreadEventInfo;
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
use crate::util::{
assert_stack, check_stack, protect_lua_closure, push_string, push_table, rawset_field, StackGuard,
Expand Down Expand Up @@ -671,6 +675,72 @@ impl Lua {
}
}

/// Sets a callback that will be called by Luau whenever a thread is created/destroyed.
///
/// Often used for keeping track of threads.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub fn set_thread_event_callback<F>(&self, callback: F)
where
F: Fn(&Lua, ThreadEventInfo) -> Result<()> + MaybeSend + 'static,
{
use std::rc::Rc;

unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, state: *mut ffi::lua_State) {
callback_error_ext(state, ptr::null_mut(), move |extra, _| {
let raw_lua: &RawLua = (*extra).raw_lua();
let _guard = StateGuard::new(raw_lua, state);

let userthread_cb = (*extra).userthread_callback.clone();
let userthread_cb =
mlua_expect!(userthread_cb, "no userthread callback set in userthread_proc");
if parent.is_null() {
raw_lua.push(Value::Nil).unwrap();
} else {
raw_lua.push_ref_thread(parent).unwrap();
}
if parent.is_null() {
let event_info = ThreadEventInfo::Destroyed(state.cast_const().cast());
let main_state = raw_lua.main_state();
if main_state == state {
return Ok(()); // Don't process Destroyed event on main thread.
}
let main_extra = ExtraData::get(main_state);
let main_raw_lua: &RawLua = (*main_extra).raw_lua();
let _guard = StateGuard::new(main_raw_lua, state);
userthread_cb((*main_extra).lua(), event_info)
} else {
raw_lua.push_ref_thread(parent).unwrap();
let event_info = match raw_lua.pop_value() {
Value::Thread(thr) => ThreadEventInfo::Created(thr),
_ => unimplemented!(),
};
userthread_cb((*extra).lua(), event_info)
}
});
}

// Set interrupt callback
let lua = self.lock();
unsafe {
(*lua.extra.get()).userthread_callback = Some(Rc::new(callback));
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc);
}
}

/// Removes any thread event function previously set by `set_thread_event_callback`.
///
/// This function has no effect if a callback was not previously set.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub fn remove_thread_event_callback(&self) {
let lua = self.lock();
unsafe {
(*lua.extra.get()).userthread_callback = None;
(*ffi::lua_callbacks(lua.main_state())).userthread = None;
}
}

/// Sets the warning function to be used by Lua to emit warnings.
///
/// Requires `feature = "lua54"`
Expand Down
4 changes: 4 additions & 0 deletions src/state/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub(crate) struct ExtraData {
pub(super) warn_callback: Option<crate::types::WarnCallback>,
#[cfg(feature = "luau")]
pub(super) interrupt_callback: Option<crate::types::InterruptCallback>,
#[cfg(feature = "luau")]
pub(super) userthread_callback: Option<crate::types::ThreadEventCallback>,

#[cfg(feature = "luau")]
pub(super) sandboxed: bool,
Expand Down Expand Up @@ -177,6 +179,8 @@ impl ExtraData {
#[cfg(feature = "luau")]
interrupt_callback: None,
#[cfg(feature = "luau")]
userthread_callback: None,
#[cfg(feature = "luau")]
sandboxed: false,
#[cfg(feature = "luau")]
compiler: None,
Expand Down
20 changes: 19 additions & 1 deletion src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ impl Drop for RawLua {
}

let mem_state = MemoryState::get(self.main_state());

#[cfg(feature = "luau")] // Fixes a crash during shutdown
{
(*ffi::lua_callbacks(self.main_state())).userthread = None;
}
ffi::lua_close(self.main_state());

// Deallocate `MemoryState`
Expand Down Expand Up @@ -556,6 +559,21 @@ impl RawLua {
value.push_into_stack(self)
}

pub(crate) unsafe fn push_ref_thread(&self, ref_thread: *mut ffi::lua_State) -> Result<()> {
let state = self.state();
check_stack(state, 1)?;
let _sg = StackGuard::new(ref_thread);
check_stack(ref_thread, 1)?;

if self.unlikely_memory_error() {
ffi::lua_pushthread(ref_thread)
} else {
protect_lua!(ref_thread, 0, 1, |ref_thread| ffi::lua_pushthread(ref_thread))?
};
ffi::lua_xmove(ref_thread, self.state(), 1);
Ok(())
}

/// Pushes a `Value` (by reference) onto the Lua stack.
///
/// Uses 2 stack spaces, does not call `checkstack`.
Expand Down
23 changes: 23 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use crate::error::Result;
use crate::hook::Debug;
use crate::state::{ExtraData, Lua, RawLua};

#[cfg(any(feature = "luau", doc))]
use crate::thread::Thread;

// Re-export mutex wrappers
pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak};

Expand Down Expand Up @@ -73,6 +76,20 @@ pub enum VmState {
Yield,
}

/// Information about a thread event.
///
/// For creating a thread, it contains the thread that created it.
///
/// This is useful for tracking the origin of all threads.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub enum ThreadEventInfo {
/// When a thread is created, it contains the thread that created it.
Created(Thread),
/// When a thread is destroyed, it returns its .to_pointer representation.
Destroyed(*const c_void),
}

#[cfg(all(feature = "send", not(feature = "luau")))]
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;

Expand All @@ -85,6 +102,12 @@ pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
#[cfg(all(not(feature = "send"), feature = "luau"))]
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "luau"))]
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()> + Send>;

#[cfg(all(not(feature = "send"), feature = "luau"))]
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()>>;

#[cfg(all(feature = "send", feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;

Expand Down
Loading