From e000353aed068d1801def3675e52764e843e60b9 Mon Sep 17 00:00:00 2001
From: Ben Vanik
Date: Wed, 10 Jul 2024 09:06:27 -0700
Subject: [PATCH] Adding indirect command buffer emulation for Metal. (#17849)

Until we switch to using MTLIndirectCommandBuffer any reusable command
buffers or ones with indirect bindings will need to be recorded into
deferred command buffers and replayed upon submission.
---
 .../src/iree/hal/drivers/metal/CMakeLists.txt |   1 +
 .../src/iree/hal/drivers/metal/metal_device.m | 100 +++++++++++++++---
 2 files changed, 86 insertions(+), 15 deletions(-)

diff --git a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
index 0d60ae22e623..c3186a109824 100644
--- a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
@@ -41,6 +41,7 @@ iree_cc_library(
     iree::base::internal::flatcc::parsing
     iree::hal
     iree::hal::drivers::metal::builtin
+    iree::hal::utils::deferred_command_buffer
     iree::hal::utils::file_transfer
     iree::hal::utils::memory_file
     iree::hal::utils::resource_set
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 620ff571f3d8..21386f853f64 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -17,6 +17,7 @@
 #include "iree/hal/drivers/metal/pipeline_layout.h"
 #include "iree/hal/drivers/metal/shared_event.h"
 #include "iree/hal/drivers/metal/staging_buffer.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
 #include "iree/hal/utils/file_transfer.h"
 #include "iree/hal/utils/memory_file.h"
 #include "iree/hal/utils/resource_set.h"
@@ -247,12 +248,17 @@ static iree_status_t iree_hal_metal_device_create_command_buffer(
     iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
   iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
 
-  if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                            "multi-shot command buffer not yet supported");
-  } else if (binding_capacity > 0) {
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                            "indirect command buffers not yet supported");
+  // Native Metal command buffers are not reusable so we emulate by recording into our own reusable
+  // instance. This will be replayed against a Metal command buffer upon submission.
+  //
+  // TODO(indirect-cmd): natively support indirect command buffers in Metal via
+  // MTLIndirectCommandBuffer. We could switch to exclusively using that for all modes to keep the
+  // number of code paths down. MTLIndirectCommandBuffer is both reusable and has what we require
+  // for argument buffer updates to pass in binding tables.
+  if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) || binding_capacity > 0) {
+    return iree_hal_deferred_command_buffer_create(
+        device->device_allocator, mode, command_categories, binding_capacity, &device->block_pool,
+        device->host_allocator, out_command_buffer);
   }
 
   return iree_hal_metal_direct_command_buffer_create(
@@ -390,6 +396,38 @@ static iree_status_t iree_hal_metal_device_queue_write(
   return loop_status;
 }
 
+static iree_status_t iree_hal_metal_replay_command_buffer(
+    iree_hal_metal_device_t* device, iree_hal_command_buffer_t* deferred_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_command_buffer_t** out_direct_command_buffer) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Create the transient command buffer. Note that it is one-shot and has no indirect bindings as
+  // we will be replaying it once with all the bindings resolved.
+  iree_hal_command_buffer_t* direct_command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_direct_command_buffer_create(
+              (iree_hal_device_t*)device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+              iree_hal_command_buffer_allowed_categories(deferred_command_buffer),
+              /*binding_capacity=*/0, device->command_buffer_resource_reference_mode, device->queue,
+              &device->block_pool, &device->staging_buffer, device->builtin_executable,
+              device->host_allocator, &direct_command_buffer));
+
+  // Attempt to replay all commands against the transient command buffer. Note that this will fail
+  // if any binding does not meet the requirements - having succeeded when recording initially is
+  // not a guarantee that this will succeed.
+  iree_status_t status = iree_hal_deferred_command_buffer_apply(
+      deferred_command_buffer, direct_command_buffer, binding_table);
+
+  if (iree_status_is_ok(status)) {
+    *out_direct_command_buffer = direct_command_buffer;
+  } else {
+    iree_hal_command_buffer_release(direct_command_buffer);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
 static iree_status_t iree_hal_metal_device_queue_execute(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -403,20 +441,46 @@ static iree_status_t iree_hal_metal_device_queue_execute(
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set));
 
-  iree_status_t status =
-      iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers);
-
   // Put the full semaphore list into a resource set, which retains them--we will need to access
   // them until the command buffer completes.
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
-                                          wait_semaphore_list.semaphores);
-  }
+  iree_status_t status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+                                                      wait_semaphore_list.semaphores);
   if (iree_status_is_ok(status)) {
     status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
                                           signal_semaphore_list.semaphores);
   }
 
+  // Translate any deferred command buffers into real Metal command buffers.
+  // We do this prior to beginning execution so that if we fail we don't leave the system in an
+  // inconsistent state.
+  iree_hal_command_buffer_t** direct_command_buffers = (iree_hal_command_buffer_t**)iree_alloca(
+      command_buffer_count * sizeof(iree_hal_command_buffer_t*));
+  if (iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+      iree_hal_command_buffer_t* command_buffer = command_buffers[i];
+      iree_hal_command_buffer_t* direct_command_buffer = NULL;
+      if (iree_hal_deferred_command_buffer_isa(command_buffer)) {
+        // Create a temporary command buffer and replay the deferred command buffer with the
+        // binding table provided. Note that any resources used will be retained by the command
+        // buffer so we only need to retain the command buffer itself instead of the binding
+        // tables provided.
+        iree_hal_buffer_binding_table_t binding_table =
+            binding_tables ? binding_tables[i] : iree_hal_buffer_binding_table_empty();
+        @autoreleasepool {
+          status = iree_hal_metal_replay_command_buffer(device, command_buffer, binding_table,
+                                                        &direct_command_buffer);
+        }
+      } else {
+        // Retain the command buffer until the submission has completed.
+        direct_command_buffer = command_buffer;
+      }
+      if (!iree_status_is_ok(status)) break;
+      status = iree_hal_resource_set_insert(resource_set, 1, &direct_command_buffer);
+      if (!iree_status_is_ok(status)) break;
+      direct_command_buffers[i] = direct_command_buffer;
+    }
+  }
+
   if (iree_status_is_ok(status)) {
     @autoreleasepool {
       // First create a new command buffer and encode wait commands for all wait semaphores.
@@ -436,8 +500,14 @@ static iree_status_t iree_hal_metal_device_queue_execute(
       // up with semaphore signaling.
       id signal_command_buffer = nil;
       for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
-        iree_hal_command_buffer_t* command_buffer = command_buffers[i];
-        id handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
+        // NOTE: translation happens above such that we always know these are direct command
+        // buffers.
+        //
+        // TODO(indirect-cmd): support indirect command buffers and switch here, or only use
+        // indirect command buffers and assume that instead.
+        iree_hal_command_buffer_t* direct_command_buffer = direct_command_buffers[i];
+        id handle =
+            iree_hal_metal_direct_command_buffer_handle(direct_command_buffer);
         if (i + 1 != command_buffer_count) [handle commit];
         signal_command_buffer = handle;
       }