Skip to content

Commit

Permalink
Adding indirect command buffer emulation for Metal. (iree-org#17849)
Browse files Browse the repository at this point in the history
Until we switch to using MTLIndirectCommandBuffer, any reusable command
buffers or ones with indirect bindings will need to be recorded into
deferred command buffers and replayed upon submission.
  • Loading branch information
benvanik authored Jul 10, 2024
1 parent 9ac1015 commit e000353
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 15 deletions.
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ iree_cc_library(
iree::base::internal::flatcc::parsing
iree::hal
iree::hal::drivers::metal::builtin
iree::hal::utils::deferred_command_buffer
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
Expand Down
100 changes: 85 additions & 15 deletions runtime/src/iree/hal/drivers/metal/metal_device.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "iree/hal/drivers/metal/pipeline_layout.h"
#include "iree/hal/drivers/metal/shared_event.h"
#include "iree/hal/drivers/metal/staging_buffer.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"
#include "iree/hal/utils/resource_set.h"
Expand Down Expand Up @@ -247,12 +248,17 @@ static iree_status_t iree_hal_metal_device_create_command_buffer(
iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);

if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"multi-shot command buffer not yet supported");
} else if (binding_capacity > 0) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect command buffers not yet supported");
// Native Metal command buffers are not reusable so we emulate by recording into our own reusable
// instance. This will be replayed against a Metal command buffer upon submission.
//
// TODO(indirect-cmd): natively support indirect command buffers in Metal via
// MTLIndirectCommandBuffer. We could switch to exclusively using that for all modes to keep the
// number of code paths down. MTLIndirectCommandBuffer is both reusable and has what we require
// for argument buffer updates to pass in binding tables.
if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) || binding_capacity > 0) {
return iree_hal_deferred_command_buffer_create(
device->device_allocator, mode, command_categories, binding_capacity, &device->block_pool,
device->host_allocator, out_command_buffer);
}

return iree_hal_metal_direct_command_buffer_create(
Expand Down Expand Up @@ -390,6 +396,38 @@ static iree_status_t iree_hal_metal_device_queue_write(
return loop_status;
}

// Replays |deferred_command_buffer| against a freshly created one-shot direct command buffer,
// resolving bindings through |binding_table|.
//
// On success the new command buffer is stored in |out_direct_command_buffer|; it is not released
// here, so the caller takes over the reference produced by the create call. On any failure the
// transient command buffer (if one was created) is released and |out_direct_command_buffer| is
// left as NULL.
static iree_status_t iree_hal_metal_replay_command_buffer(
    iree_hal_metal_device_t* device, iree_hal_command_buffer_t* deferred_command_buffer,
    iree_hal_buffer_binding_table_t binding_table,
    iree_hal_command_buffer_t** out_direct_command_buffer) {
  // Give the out-parameter a defined value on every exit path, including the early return taken
  // when creating the transient command buffer fails, so callers never read a stale pointer.
  *out_direct_command_buffer = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  // Create the transient command buffer. Note that it is one-shot and has no indirect bindings as
  // we will be replaying it once with all the bindings resolved.
  iree_hal_command_buffer_t* direct_command_buffer = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_metal_direct_command_buffer_create(
              (iree_hal_device_t*)device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
              iree_hal_command_buffer_allowed_categories(deferred_command_buffer),
              /*binding_capacity=*/0, device->command_buffer_resource_reference_mode, device->queue,
              &device->block_pool, &device->staging_buffer, device->builtin_executable,
              device->host_allocator, &direct_command_buffer));

  // Attempt to replay all commands against the transient command buffer. Note that this will fail
  // if any binding does not meet the requirements - having succeeded when recording initially is
  // not a guarantee that this will succeed.
  iree_status_t status = iree_hal_deferred_command_buffer_apply(
      deferred_command_buffer, direct_command_buffer, binding_table);

  if (iree_status_is_ok(status)) {
    // Hand the populated command buffer to the caller.
    *out_direct_command_buffer = direct_command_buffer;
  } else {
    // Replay failed: drop the partially recorded transient command buffer.
    iree_hal_command_buffer_release(direct_command_buffer);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_device_queue_execute(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
Expand All @@ -403,20 +441,46 @@ static iree_status_t iree_hal_metal_device_queue_execute(
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set));

iree_status_t status =
iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers);

// Put the full semaphore list into a resource set, which retains them--we will need to access
// them until the command buffer completes.
if (iree_status_is_ok(status)) {
status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
wait_semaphore_list.semaphores);
}
iree_status_t status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
wait_semaphore_list.semaphores);
if (iree_status_is_ok(status)) {
status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
signal_semaphore_list.semaphores);
}

// Translate any deferred command buffers into real Metal command buffers.
// We do this prior to beginning execution so that if we fail we don't leave the system in an
// inconsistent state.
iree_hal_command_buffer_t** direct_command_buffers = (iree_hal_command_buffer_t**)iree_alloca(
command_buffer_count * sizeof(iree_hal_command_buffer_t*));
if (iree_status_is_ok(status)) {
for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
iree_hal_command_buffer_t* command_buffer = command_buffers[i];
iree_hal_command_buffer_t* direct_command_buffer = NULL;
if (iree_hal_deferred_command_buffer_isa(command_buffer)) {
// Create a temporary command buffer and replay the deferred command buffer with the
// binding table provided. Note that any resources used will be retained by the command
// buffer so we only need to retain the command buffer itself instead of the binding
// tables provided.
iree_hal_buffer_binding_table_t binding_table =
binding_tables ? binding_tables[i] : iree_hal_buffer_binding_table_empty();
@autoreleasepool {
status = iree_hal_metal_replay_command_buffer(device, command_buffer, binding_table,
&direct_command_buffer);
}
} else {
// Retain the command buffer until the submission has completed.
direct_command_buffer = command_buffer;
}
if (!iree_status_is_ok(status)) break;
status = iree_hal_resource_set_insert(resource_set, 1, &direct_command_buffer);
if (!iree_status_is_ok(status)) break;
direct_command_buffers[i] = direct_command_buffer;
}
}

if (iree_status_is_ok(status)) {
@autoreleasepool {
// First create a new command buffer and encode wait commands for all wait semaphores.
Expand All @@ -436,8 +500,14 @@ static iree_status_t iree_hal_metal_device_queue_execute(
// up with semaphore signaling.
id<MTLCommandBuffer> signal_command_buffer = nil;
for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
iree_hal_command_buffer_t* command_buffer = command_buffers[i];
id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
// NOTE: translation happens above such that we always know these are direct command
// buffers.
//
// TODO(indirect-cmd): support indirect command buffers and switch here, or only use
// indirect command buffers and assume that instead.
iree_hal_command_buffer_t* direct_command_buffer = direct_command_buffers[i];
id<MTLCommandBuffer> handle =
iree_hal_metal_direct_command_buffer_handle(direct_command_buffer);
if (i + 1 != command_buffer_count) [handle commit];
signal_command_buffer = handle;
}
Expand Down

0 comments on commit e000353

Please sign in to comment.