diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_cache.c b/src/uct/cuda/cuda_ipc/cuda_ipc_cache.c index 7e6e5429a2b..7e9d5fda83b 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_cache.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_cache.c @@ -124,7 +124,7 @@ static ucs_status_t uct_cuda_ipc_close_memhandle(uct_cuda_ipc_cache_region_t *re (CUdeviceptr)region->mapped_addr, region->key.b_len)); } } else if (region->key.ph.handle_type == UCT_CUDA_IPC_KEY_HANDLE_TYPE_MEMPOOL) { - return UCT_CUDADRV_FUNC_LOG_WARN(cuMemPoolDestroy(region->key.ph.pool)); + return UCT_CUDADRV_FUNC_LOG_WARN(cuMemFree((CUdeviceptr)region->mapped_addr)); } else #endif { @@ -480,14 +480,9 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_ipc_map_memhandle, (key, mapped_addr), ucs_pgt_region_t *pgt_region; uct_cuda_ipc_cache_region_t *region; int ret; + const void *arg1, *arg2; size_t cmp_size; -#if HAVE_CUDA_FABRIC - cmp_size = sizeof(key->ph.handle); -#else - cmp_size = sizeof(key->ph); -#endif - status = uct_cuda_ipc_get_remote_cache(key->pid, &cache); if (status != UCS_OK) { return status; @@ -497,9 +492,17 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_ipc_map_memhandle, (key, mapped_addr), pgt_region = UCS_PROFILE_CALL(ucs_pgtable_lookup, &cache->pgtable, key->d_bptr); if (ucs_likely(pgt_region != NULL)) { - region = ucs_derived_of(pgt_region, uct_cuda_ipc_cache_region_t); - if (memcmp((const void *)&key->ph, (const void *)&region->key.ph, - cmp_size) == 0) { + region = ucs_derived_of(pgt_region, uct_cuda_ipc_cache_region_t); +#if HAVE_CUDA_FABRIC + cmp_size = sizeof(key->ph.buffer_id); + arg1 = (const void*)&key->ph.buffer_id; + arg2 = (const void*)&region->key.ph.buffer_id; +#else + cmp_size = sizeof(key->ph); + arg1 = (const void*)&key->ph; + arg2 = (const void*)&region->key.ph; +#endif + if (memcmp(arg1, arg2, cmp_size) == 0) { /*cache hit */ ucs_trace("%s: cuda_ipc cache hit addr:%p size:%lu region:" UCS_PGT_REGION_FMT, cache->name, (void *)key->d_bptr, diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c 
b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c index fd1f588657e..10829edb4a6 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c @@ -117,7 +117,7 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, uct_cuda_ipc_lkey_t *key; ucs_status_t status; #if HAVE_CUDA_FABRIC -#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 2 +#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 4 CUmemGenericAllocationHandle handle; CUmemoryPool mempool; CUpointer_attribute attr_type[UCT_CUDA_IPC_QUERY_NUM_ATTRS]; @@ -143,6 +143,10 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, attr_data[0] = &legacy_capable; attr_type[1] = CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES; attr_data[1] = &allowed_handle_types; + attr_type[2] = CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE; + attr_data[2] = &mempool; + attr_type[3] = CU_POINTER_ATTRIBUTE_BUFFER_ID; + attr_data[3] = &key->ph.buffer_id; status = UCT_CUDADRV_FUNC_LOG_ERR( cuPointerGetAttributes(ucs_static_array_size(attr_data), attr_type, @@ -185,9 +189,7 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, goto common_path; } - status = UCT_CUDADRV_FUNC_LOG_ERR(cuPointerGetAttribute(&mempool, - CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)addr)); - if ((status != UCS_OK) || (mempool == 0)) { + if (mempool == 0) { /* cuda_ipc can only handle UCS_MEMORY_TYPE_CUDA, which has to be either * legacy type, or VMM type, or mempool type. Return error if memory * does not belong to any of the three types */ diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h index eb621bd5ce8..d78611450e5 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h @@ -31,6 +31,7 @@ typedef struct uct_cuda_ipc_md_handle { } handle; CUmemPoolPtrExportData ptr; CUmemoryPool pool; + unsigned long long buffer_id; } uct_cuda_ipc_md_handle_t; #else typedef CUipcMemHandle uct_cuda_ipc_md_handle_t;