[ROCM] Add supports_concurrent_managed_access
raikonenfnu authored and monorimet committed Jan 24, 2024
1 parent d1b8a64 commit ddb0610
Showing 4 changed files with 106 additions and 18 deletions.
1 change: 1 addition & 0 deletions experimental/rocm/dynamic_symbol_tables.h
@@ -36,6 +36,7 @@ RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind,
hipStream_t)
RC_PFN_DECL(hipMalloc, void **, size_t)
RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int)
RC_PFN_DECL(hipMemPrefetchAsync, const void *, size_t, int, hipStream_t)
RC_PFN_DECL(hipFree, void *)
RC_PFN_DECL(hipHostFree, void *)
RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
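Each RC_PFN_DECL entry above exposes the named HIP API as a function pointer on the dynamically loaded symbol table. As a rough sketch (the wrapper name is hypothetical; the call pattern mirrors the allocator change below), the new hipMemPrefetchAsync entry would be invoked like this:

// Hypothetical helper: routes hipMemPrefetchAsync through the loaded symbol
// table and converts the HIP result to an iree_status_t.
static iree_status_t iree_hal_rocm_prefetch_to_device(
    iree_hal_rocm_context_wrapper_t* context, hipDeviceptr_t device_ptr,
    size_t size, hipDevice_t device, hipStream_t stream) {
  return ROCM_RESULT_TO_STATUS(
      context->syms,
      hipMemPrefetchAsync(device_ptr, size, device, stream),
      "hipMemPrefetchAsync");
}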
120 changes: 102 additions & 18 deletions experimental/rocm/rocm_allocator.c
@@ -15,8 +15,11 @@

typedef struct iree_hal_rocm_allocator_t {
iree_hal_resource_t resource;
iree_hal_device_t* base_device;
iree_hal_rocm_context_wrapper_t* context;
hipDevice_t device;
hipStream_t stream;

bool supports_concurrent_managed_access;

IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
} iree_hal_rocm_allocator_t;
@@ -30,17 +33,40 @@ static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast(
}

iree_status_t iree_hal_rocm_allocator_create(
iree_hal_rocm_context_wrapper_t* context,
iree_hal_rocm_context_wrapper_t* context, hipDevice_t device, hipStream_t stream,
iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);

// To support device-local + host-visible memory we need concurrent managed
// access indicating that the host and devices can concurrently access the
// device memory. If we don't have this feature then we fall back to forcing
// all device-local + host-visible memory into host-local + device-visible
// page-locked memory. The compiler tries to avoid this for high-traffic
// buffers except for readback staging buffers.
int supports_concurrent_managed_access = 0;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, ROCM_RESULT_TO_STATUS(
context->syms,
hipDeviceGetAttribute(
&supports_concurrent_managed_access,
hipDeviceAttributeConcurrentManagedAccess, device),
"hipDeviceGetAttribute"));
IREE_TRACE_ZONE_APPEND_TEXT(
z0, supports_concurrent_managed_access
? "has CONCURRENT_MANAGED_ACCESS"
: "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on "
"device-local + host-visible memory)");
iree_hal_rocm_allocator_t* allocator = NULL;
iree_status_t status = iree_allocator_malloc(
context->host_allocator, sizeof(*allocator), (void**)&allocator);
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable,
&allocator->resource);
allocator->context = context;
allocator->device = device;
allocator->stream = stream;
allocator->supports_concurrent_managed_access =
    supports_concurrent_managed_access != 0;
*out_allocator = (iree_hal_allocator_t*)allocator;
}

@@ -87,52 +113,78 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps(
iree_host_size_t capacity,
iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps,
iree_host_size_t* IREE_RESTRICT out_count) {
const iree_host_size_t count = 3;
iree_hal_rocm_allocator_t* allocator =
iree_hal_rocm_allocator_cast(base_allocator);

// TODO(benvanik): check hipDeviceAttributeIntegrated and return a unified
// set of heaps (likely still a cached and uncached, at minimum).
iree_host_size_t count = 3;
if (allocator->supports_concurrent_managed_access) {
++count; // device-local | host-visible
}
if (out_count) *out_count = count;
if (capacity < count) {
// NOTE: lightweight as this is hit in normal pre-sizing usage.
return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
}

// NOTE: this is all a guess - someone who is familiar with ROCm will want
// to refine this further.

// There doesn't appear to be a query for these limits.
// Max allocation size may be much smaller in certain memory types such as
// page-locked memory and it'd be good to enforce that.
const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0;
const iree_device_size_t min_alignment = 64;

int i = 0;

// Device-local memory (dispatch resources):
heaps[0] = (iree_hal_allocator_memory_heap_t){
heaps[i++] = (iree_hal_allocator_memory_heap_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.allowed_usage =
IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH,
.max_allocation_size = max_allocation_size,
.min_alignment = min_alignment,
};

if (allocator->supports_concurrent_managed_access) {
// Device-local managed memory with host mapping support:
heaps[i++] = (iree_hal_allocator_memory_heap_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE |
IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
.allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_DISPATCH |
IREE_HAL_BUFFER_USAGE_MAPPING,
.max_allocation_size = max_allocation_size,
.min_alignment = min_alignment,
};
}

// Write-combined page-locked host-local memory (upload):
heaps[1] = (iree_hal_allocator_memory_heap_t){
.type =
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
.allowed_usage =
IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
heaps[i++] = (iree_hal_allocator_memory_heap_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE |
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
.allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_DISPATCH |
IREE_HAL_BUFFER_USAGE_MAPPING,
.max_allocation_size = max_allocation_size,
.min_alignment = min_alignment,
};

// Cached page-locked host-local memory (download):
heaps[2] = (iree_hal_allocator_memory_heap_t){
.type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
heaps[i++] = (iree_hal_allocator_memory_heap_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE |
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_COHERENT |
IREE_HAL_MEMORY_TYPE_HOST_CACHED,
.allowed_usage =
IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
.allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_DISPATCH |
IREE_HAL_BUFFER_USAGE_MAPPING,
.max_allocation_size = max_allocation_size,
.min_alignment = min_alignment,
};

IREE_ASSERT(i == count);
return iree_ok_status();
}
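The IREE_STATUS_OUT_OF_RANGE early-return above exists so callers can pre-size their heap array. A caller-side sketch of that pattern (the exact public entry point and macros are assumed from the wider HAL API, not shown in this diff):

// Sketch: query with a fixed-size array; when more heaps exist than fit,
// *out_count still reports the required capacity so the caller can retry.
iree_hal_allocator_memory_heap_t heaps[8];
iree_host_size_t heap_count = 0;
iree_status_t status = iree_hal_allocator_query_memory_heaps(
    device_allocator, IREE_ARRAYSIZE(heaps), heaps, &heap_count);
if (iree_status_code(status) == IREE_STATUS_OUT_OF_RANGE) {
  // heap_count holds the total number of heaps; allocate a larger array and
  // query again.
}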

@@ -141,22 +193,46 @@ iree_hal_rocm_allocator_query_buffer_compatibility(
iree_hal_allocator_t* IREE_RESTRICT base_allocator,
iree_hal_buffer_params_t* IREE_RESTRICT params,
iree_device_size_t* IREE_RESTRICT allocation_size) {
iree_hal_rocm_allocator_t* allocator =
iree_hal_rocm_allocator_cast(base_allocator);

// All buffers can be allocated on the heap.
iree_hal_buffer_compatibility_t compatibility =
IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;

if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
// Buffers are importable in ROCm in most cases, though performance may
// vary wildly. We don't fully verify that the buffer parameters are
// self-consistent and just look at whether we can get a device pointer.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE;
}

// Buffers can only be used on the queue if they are device visible.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
}
if (iree_any_bit_set(params->usage,
IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
}
}

// If concurrent managed access is not supported then make device-local +
// host-visible allocations fall back to host-local + device-visible
// page-locked memory. This will be significantly slower for the device to
// access but the compiler only uses this type for readback staging buffers
// and it is better to function at all than to function fast.
if (!allocator->supports_concurrent_managed_access &&
iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE;
params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
params->type |=
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
}

// We are now optimal.
params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;

@@ -209,6 +285,14 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
status = ROCM_RESULT_TO_STATUS(
allocator->context->syms,
hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal));
if (iree_status_is_ok(status) &&
allocator->supports_concurrent_managed_access) {
// Prefetch the buffer on the GPU device.
status = ROCM_RESULT_TO_STATUS(
allocator->context->syms,
hipMemPrefetchAsync(device_ptr, allocation_size, allocator->device,
allocator->stream));
}
host_ptr = (void*)device_ptr;
} else {
// Device only.
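For reference, the raw HIP sequence the allocation path above builds on (attribute query, managed allocation, then prefetch when concurrent access is available) looks roughly like the following standalone sketch; error handling is elided and it is not part of this commit:

// Standalone HIP sketch of the device-local + host-visible allocation path.
#include <hip/hip_runtime.h>

static void* alloc_managed_prefetched(size_t size, int device,
                                      hipStream_t stream) {
  int concurrent = 0;
  hipDeviceGetAttribute(&concurrent,
                        hipDeviceAttributeConcurrentManagedAccess, device);
  void* ptr = NULL;
  hipMallocManaged(&ptr, size, hipMemAttachGlobal);
  if (concurrent) {
    // Migrate the pages to the device up front to avoid first-touch faults
    // when the GPU accesses the buffer.
    hipMemPrefetchAsync(ptr, size, device, stream);
  }
  return ptr;
}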
2 changes: 2 additions & 0 deletions experimental/rocm/rocm_allocator.h
@@ -19,6 +19,8 @@ extern "C" {
// Create a ROCM allocator.
iree_status_t iree_hal_rocm_allocator_create(
iree_hal_rocm_context_wrapper_t* context,
hipDevice_t device,
hipStream_t stream,
iree_hal_allocator_t** out_allocator);

#ifdef __cplusplus
1 change: 1 addition & 0 deletions experimental/rocm/rocm_device.c
@@ -138,6 +138,7 @@ static iree_status_t iree_hal_rocm_device_create_internal(
}
if (iree_status_is_ok(status)) {
status = iree_hal_rocm_allocator_create(&device->context_wrapper,
device->device, device->stream,
&device->device_allocator);
}
if (iree_status_is_ok(status) &&
