From 4c0ab3639b75c5f001e58625c8e8d3abb66c6b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=9A=93?= Date: Fri, 10 Jan 2025 00:47:52 -0500 Subject: [PATCH] Implement handling of fibers and exception catching in libretro sandbox runtime --- retro/extra-ruby-bindings.h | 35 ++++++----- retro/sandbox-bindgen.rb | 122 +++++++++++++++++++++--------------- src/core.cpp | 69 ++++++++++++++++++-- src/sandbox/sandbox.h | 7 +-- 4 files changed, 159 insertions(+), 74 deletions(-) diff --git a/retro/extra-ruby-bindings.h b/retro/extra-ruby-bindings.h index e221dfe41..82d35101e 100644 --- a/retro/extra-ruby-bindings.h +++ b/retro/extra-ruby-bindings.h @@ -55,12 +55,20 @@ MKXP_SANDBOX_API void mkxp_sandbox_free(void *ptr) { free(ptr); } -/* Ruby's `rb_`/`ruby_` functions may return early before they're actually finished running. - * You can use `mkxp_sandbox_complete()` to check if the most recent call to a `rb_`/`ruby_` function finished. - * If `mkxp_sandbox_complete()` returns false, the `rb_`/`ruby_` function is not done executing yet and needs to be called again with the same arguments. */ -MKXP_SANDBOX_API bool mkxp_sandbox_complete(void) { - extern void *rb_asyncify_unwind_buf; /* Defined in wasm/setjmp.c in Ruby source code */ - return rb_asyncify_unwind_buf == NULL; +static void (*_mkxp_sandbox_fiber_entry_point)(void *, void *) = NULL; +static void *_mkxp_sandbox_fiber_arg0 = NULL; +static void *_mkxp_sandbox_fiber_arg1 = NULL; + +MKXP_SANDBOX_API void *mkxp_sandbox_fiber_entry_point(void) { + return (void *)_mkxp_sandbox_fiber_entry_point; +} + +MKXP_SANDBOX_API void *mkxp_sandbox_fiber_arg0(void) { + return _mkxp_sandbox_fiber_arg0; +} + +MKXP_SANDBOX_API void *mkxp_sandbox_fiber_arg1(void) { + return _mkxp_sandbox_fiber_arg1; } /* This function drives Ruby's asynchronous runtime. It's based on the `rb_wasm_rt_start()` function from wasm/runtime.c in the Ruby source code. @@ -69,18 +77,17 @@ MKXP_SANDBOX_API bool mkxp_sandbox_complete(void) { * However, if it returns true, then you need to call the `rb_`/`ruby_` function again with the same arguments * and then call `mkxp_sandbox_yield()` again, and repeat until `mkxp_sandbox_yield()` returns false. */ MKXP_SANDBOX_API bool mkxp_sandbox_yield(void) { - static void (*fiber_entry_point)(void *, void *) = NULL; static bool new_fiber_started = false; - static void *arg0; - static void *arg1; void *asyncify_buf; bool unwound = false; + extern void *rb_asyncify_unwind_buf; /* Defined in wasm/setjmp.c in Ruby source code */ + while (1) { if (unwound) { - if (fiber_entry_point != NULL) { - fiber_entry_point(arg0, arg1); + if (_mkxp_sandbox_fiber_entry_point != NULL) { + _mkxp_sandbox_fiber_entry_point(_mkxp_sandbox_fiber_arg0, _mkxp_sandbox_fiber_arg1); } else { return true; } @@ -88,7 +95,7 @@ MKXP_SANDBOX_API bool mkxp_sandbox_yield(void) { unwound = true; } - if (mkxp_sandbox_complete()) { + if (rb_asyncify_unwind_buf == NULL) { break; } @@ -103,7 +110,7 @@ MKXP_SANDBOX_API bool mkxp_sandbox_yield(void) { continue; } - asyncify_buf = rb_wasm_handle_fiber_unwind(&fiber_entry_point, &arg0, &arg1, &new_fiber_started); + asyncify_buf = rb_wasm_handle_fiber_unwind(&_mkxp_sandbox_fiber_entry_point, &_mkxp_sandbox_fiber_arg0, &_mkxp_sandbox_fiber_arg1, &new_fiber_started); if (asyncify_buf != NULL) { asyncify_start_rewind(asyncify_buf); continue; @@ -114,7 +121,7 @@ MKXP_SANDBOX_API bool mkxp_sandbox_yield(void) { break; } - fiber_entry_point = NULL; + _mkxp_sandbox_fiber_entry_point = NULL; new_fiber_started = false; return false; } diff --git a/retro/sandbox-bindgen.rb b/retro/sandbox-bindgen.rb index 302a3824d..1fe6cd8bf 100644 --- a/retro/sandbox-bindgen.rb +++ b/retro/sandbox-bindgen.rb @@ -34,8 +34,6 @@ # The name of the `free()` binding defined in extra-ruby-bindings.h FREE_FUNC = 'mkxp_sandbox_free' -COMPLETE_FUNC = 'mkxp_sandbox_complete' - ################################################################################ IGNORED_FUNCTIONS = Set[ @@ -155,7 +153,9 @@ #include #include #include + #include #include + #include #include #include #include @@ -174,43 +174,78 @@ namespace mkxp_sandbox { struct bindings { private: + + typedef std::tuple key_t; + + struct fiber { + key_t key; + std::vector stack; + size_t stack_ptr; + }; + wasm_ptr_t next_func_ptr; - std::shared_ptr instance; - size_t depth; - std::vector stack; + std::shared_ptr instance; + std::unordered_map> fibers; wasm_ptr_t sbindgen_malloc(wasm_ptr_t); wasm_ptr_t sbindgen_create_func_ptr(); public: - bindings(std::shared_ptr); + + bindings(std::shared_ptr); template struct stack_frame { friend struct bindings; private: - struct bindings &bindings; + + struct bindings &bind; + struct fiber &fiber; T &inner; - static inline T &init(struct bindings &bindings) { - if (bindings.depth == bindings.stack.size()) { - bindings.stack.push_back(T(bindings)); - } else if (bindings.depth > bindings.stack.size()) { + + static inline struct fiber &init_fiber(struct bindings &bind) { + key_t key = { + w2c_ruby_mkxp_sandbox_fiber_entry_point(bind.instance.get()), + w2c_ruby_mkxp_sandbox_fiber_arg0(bind.instance.get()), + w2c_ruby_mkxp_sandbox_fiber_arg1(bind.instance.get()), + }; + if (bind.fibers.count(key) == 0) { + bind.fibers[key] = (struct fiber){.key = key}; + } + return bind.fibers[key]; + } + + static inline T &init_inner(struct bindings &bind, struct fiber &fiber) { + if (fiber.stack_ptr == fiber.stack.size()) { + fiber.stack.push_back(T(bind)); + } else if (fiber.stack_ptr > fiber.stack.size()) { throw SandboxTrapException(); } + try { - return boost::any_cast(bindings.stack[bindings.depth++]); + T &inner = boost::any_cast(fiber.stack[fiber.stack_ptr]); + ++fiber.stack_ptr; + return inner; } catch (boost::bad_any_cast &) { - throw SandboxTrapException(); + fiber.stack.resize(fiber.stack_ptr++); + fiber.stack.push_back(T(bind)); + return boost::any_cast(fiber.stack.back()); } } - stack_frame(struct bindings &b) : bindings(b), inner(init(b)) {} + + stack_frame(struct bindings &b) : bind(b), fiber(init_fiber(b)), inner(init_inner(b, fiber)) {} public: + ~stack_frame() { if (inner.is_complete()) { - bindings.stack.pop_back(); + fiber.stack.pop_back(); + } + --fiber.stack_ptr; + if (fiber.stack.empty()) { + bind.fibers.erase(fiber.key); } - --bindings.depth; } + inline T &operator()() { return inner; } @@ -269,7 +304,7 @@ using namespace mkxp_sandbox; - bindings::bindings(std::shared_ptr m) : next_func_ptr(-1), instance(m), depth(0) {} + bindings::bindings(std::shared_ptr m) : next_func_ptr(-1), instance(m) {} wasm_ptr_t bindings::sbindgen_malloc(wasm_size_t size) { @@ -297,7 +332,7 @@ // Make sure that an integer overflow won't occur if we double the max size of the funcref table wasm_size_t new_max_size; if (__builtin_add_overflow(instance->w2c_T0.max_size, instance->w2c_T0.max_size, &new_max_size)) { - return 0; + return -1; } // Double the max size of the funcref table @@ -312,7 +347,7 @@ .module_instance = instance.get(), }) != old_max_size) { instance->w2c_T0.max_size = old_max_size; - return 0; + return -1; } return next_func_ptr++; @@ -364,12 +399,7 @@ if !handler[:func_ptr_args].nil? || handler[:anyargs] coroutine_initializer += <<~HEREDOC f#{i} = bind.sbindgen_create_func_ptr(); - if (f#{i} == 0) { - HEREDOC - buffers.reverse_each { |buf| coroutine_initializer += " w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buf});\n" } - coroutine_initializer += <<~HEREDOC - throw SandboxOutOfMemoryException(); - } + if (f#{i} == (wasm_ptr_t)-1) throw SandboxOutOfMemoryException(); HEREDOC if handler[:anyargs] coroutine_initializer += <<~HEREDOC @@ -395,12 +425,7 @@ elsif !handler[:buf_size].nil? coroutine_initializer += <<~HEREDOC f#{i} = bind.sbindgen_malloc(#{handler[:buf_size].gsub('PREV_ARG', "a#{i - 1}").gsub('ARG', "a#{i}")}); - if (f#{i} == 0) { - HEREDOC - buffers.reverse_each { |buf| coroutine_initializer += " w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buf});\n" } - coroutine_initializer += <<~HEREDOC - throw SandboxOutOfMemoryException(); - } + if (f#{i} == 0) throw SandboxOutOfMemoryException(); HEREDOC coroutine_initializer += handler[:serialize].gsub('PREV_ARG', "a#{i - 1}").gsub('ARG', "a#{i}").gsub('BUF', "f#{i}") coroutine_initializer += "\n" @@ -419,12 +444,7 @@ when 'rb_funcall' coroutine_initializer += <<~HEREDOC f#{args.length - 1} = bind.sbindgen_malloc(a#{args.length - 2} * sizeof(VALUE)); - if (f#{args.length - 1} == 0) { - HEREDOC - buffers.reverse_each { |buf| coroutine_initializer += " w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buf});\n" } - coroutine_initializer += <<~HEREDOC - throw SandboxOutOfMemoryException(); - } + if (f#{args.length - 1} == 0) throw SandboxOutOfMemoryException(); std::va_list a; va_start(a, a#{args.length - 2}); for (long i = 0; i < a#{args.length - 2}; ++i) { @@ -446,9 +466,6 @@ f#{args.length - 1} = bind.sbindgen_malloc(n * sizeof(VALUE)); if (f#{args.length - 1} == 0) { va_end(a); - HEREDOC - buffers.reverse_each { |buf| coroutine_initializer += " w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buf});\n" } - coroutine_initializer += <<~HEREDOC throw SandboxOutOfMemoryException(); } for (wasm_size_t i = 0; i < n; ++i) { @@ -488,21 +505,24 @@ coroutine_inner = <<~HEREDOC #{handler[:primitive] == :void ? '' : 'r = '}w2c_#{MODULE_NAME}_#{func_name}(#{(['bind.instance.get()'] + (0...args.length).map { |i| args[i] == '...' || transformed_args.include?(i) ? "f#{i}" : "a#{i}" }).join(', ')}); - if (w2c_#{MODULE_NAME}_#{COMPLETE_FUNC}(bind.instance.get())) break; + if (w2c_#{MODULE_NAME}_asyncify_get_state(bind.instance.get()) != 1) break; yield; HEREDOC - coroutine_finalizer = (0...buffers.length).map { |i| "w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buffers[buffers.length - 1 - i]});" } + coroutine_destructor = buffers.empty? ? '' : <<~HEREDOC + #{func_name}::~#{func_name}() { + #{(0...buffers.length).map { |i| " try { if (#{buffers[buffers.length - 1 - i]} != 0) w2c_#{MODULE_NAME}_#{FREE_FUNC}(bind.instance.get(), #{buffers[buffers.length - 1 - i]}); } catch (SandboxTrapException) {}" }.join("\n")} + } + HEREDOC coroutine_definition = <<~HEREDOC - #{func_name}::#{func_name}(bindings &bind) : bind(bind) {} #{coroutine_ret} #{func_name}::operator()(#{coroutine_args.join(', ')}) {#{coroutine_vars.empty? ? '' : (coroutine_vars.map { |var| "\n #{var} = 0;" }.join + "\n")} reenter (this) { #{coroutine_initializer.empty? ? '' : (coroutine_initializer.split("\n").map { |line| " #{line}" }.join("\n") + "\n\n")} for (;;) { #{coroutine_inner.split("\n").map { |line| " #{line}" }.join("\n")} - }#{coroutine_finalizer.empty? ? '' : ("\n\n" + coroutine_finalizer.map { |line| " #{line}" }.join("\n"))} + } }#{handler[:primitive] == :void ? '' : "\n\n return r;"} - } + }#{coroutine_destructor.empty? ? '' : ("\n" + coroutine_destructor)} HEREDOC coroutine_declaration = <<~HEREDOC @@ -510,9 +530,9 @@ friend struct bindings; friend struct bindings::stack_frame; #{coroutine_ret} operator()(#{declaration_args.join(', ')}); - private: - #{func_name}(bindings &bind); - bindings &bind; + #{coroutine_destructor.empty? ? '' : "~#{func_name}();\n "}private: + struct bindings &bind; + inline #{func_name}(struct bindings &b) : #{(['bind(b)'] + buffers.map { |buffer| "#{buffer}(0)" }).join(', ')} {} #{fields.empty? ? '' : fields.map { |field| " #{field};\n" }.join}}; HEREDOC @@ -526,9 +546,9 @@ for func_name in func_names file.write(" friend struct #{func_name};\n") end - file.write(" };\n") + file.write(" };") for declaration in declarations - file.write("\n" + declaration.split("\n").map { |line| " #{line}" }.join("\n").rstrip) + file.write("\n\n" + declaration.split("\n").map { |line| " #{line}" }.join("\n").rstrip) end file.write(HEADER_END) end @@ -536,6 +556,6 @@ file.write(PRELUDE) for coroutine in coroutines file.write("\n\n") - file.write(coroutine.rstrip) + file.write(coroutine.rstrip + "\n") end end diff --git a/src/core.cpp b/src/core.cpp index c3a5d511d..92f6c9dcf 100644 --- a/src/core.cpp +++ b/src/core.cpp @@ -37,6 +37,16 @@ yield; \ } while (1) +#define SANDBOX_AWAIT_AND_SET(variable, coroutine, ...) \ + do { \ + { \ + auto frame = sandbox->bindings.bind(); \ + variable = frame()(__VA_ARGS__); \ + if (frame().is_complete()) break; \ + } \ + yield; \ + } while (1) + using namespace mkxp_retro; static void fallback_log(enum retro_log_level level, const char *fmt, ...) { @@ -50,23 +60,72 @@ static uint32_t *frame_buf; static std::unique_ptr sandbox; static const char *game_path = NULL; -static VALUE my_cpp_func(w2c_ruby *ruby, int32_t argc, wasm_ptr_t argv, VALUE self) { - log_printf(RETRO_LOG_INFO, "Hello from Ruby land! my_cpp_func(argc=%d, argv=0x%08x, self=0x%08x)\n", argc, argv, self); +static VALUE my_cpp_func(void *_, VALUE self, VALUE args) { + struct co : boost::asio::coroutine { + inline co(struct mkxp_sandbox::bindings &bind) {} + + void operator()(VALUE args) { + reenter (this) { + SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "puts 'Hello from Ruby land!'"); + SANDBOX_AWAIT(mkxp_sandbox::rb_p, args); + } + } + }; + + sandbox->bindings.bind()()(args); + return self; } -static bool init_sandbox() { - struct main : boost::asio::coroutine { +static VALUE func(void *_, VALUE arg) { + struct co : boost::asio::coroutine { + inline co(struct mkxp_sandbox::bindings &bind) {} + void operator()() { reenter (this) { SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "puts 'Hello, World!'"); SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "require 'zlib'; p Zlib::Deflate::deflate('hello')"); - SANDBOX_AWAIT(mkxp_sandbox::rb_define_global_function, "my_cpp_func", (VALUE (*)(void *, ANYARGS))my_cpp_func, -1); + SANDBOX_AWAIT(mkxp_sandbox::rb_define_global_function, "my_cpp_func", (VALUE (*)(void *, ANYARGS))my_cpp_func, -2); SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "my_cpp_func(1, nil, 3, 'this is a string', :symbol, 2)"); SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "p Dir.glob '/mkxp-retro-game/*'"); + + SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "throw 'Throw an error on purpose to see if we can catch it'"); + + SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "puts 'Unreachable code'"); + } + } + }; + + sandbox->bindings.bind()()(); + + return arg; +} + +static VALUE rescue(void *_, VALUE arg, VALUE error) { + struct co : boost::asio::coroutine { + inline co(struct mkxp_sandbox::bindings &bind) {} + + void operator()(VALUE error) { + reenter (this) { + SANDBOX_AWAIT(mkxp_sandbox::rb_eval_string, "puts 'Entered rescue()'"); + SANDBOX_AWAIT(mkxp_sandbox::rb_p, error); + } + } + }; + + sandbox->bindings.bind()()(error); + + return arg; +} + +static bool init_sandbox() { + struct main : boost::asio::coroutine { + void operator()() { + reenter (this) { + SANDBOX_AWAIT(mkxp_sandbox::rb_rescue, func, 0, rescue, 0); } } }; diff --git a/src/sandbox/sandbox.h b/src/sandbox/sandbox.h index bcd512bfb..094002fe7 100644 --- a/src/sandbox/sandbox.h +++ b/src/sandbox/sandbox.h @@ -39,14 +39,13 @@ namespace mkxp_sandbox { sandbox(const char *game_path); ~sandbox(); - // TODO: handle Ruby fibers properly instead of crashing whenever Ruby switches to a different fiber than the main one template inline void run() { T coroutine = T(); - do { + for (;;) { coroutine(); + if (coroutine.is_complete()) break; w2c_ruby_mkxp_sandbox_yield(ruby.get()); - } while (!coroutine.is_complete()); - + } } }; }