TL/UCP: fix for DPU
ferrol aderholdt committed Feb 14, 2024
1 parent 5c25b9c commit 169931a
Showing 7 changed files with 42 additions and 52 deletions.
2 changes: 0 additions & 2 deletions src/components/tl/ucp/alltoall/alltoall.c
@@ -80,7 +80,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
ucc_status_t status;

ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
#if 1
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) {
tl_error(UCC_TL_TEAM_LIB(tl_team),
"global work buffer not provided nor associated with team");
@@ -95,7 +94,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
goto out;
}
}
#endif
task = ucc_tl_ucp_init_task(coll_args, team);
*task_h = &task->super;
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
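Note: with the #if 1/#endif guards removed, the global work buffer check in ucc_tl_ucp_alltoall_onesided_init is now unconditional, so callers must always supply one. A minimal caller-side sketch, not part of this commit; the buffers, count, and datatype below are illustrative assumptions:

    ucc_coll_args_t args = {0};
    args.mask               = UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER;
    args.coll_type          = UCC_COLL_TYPE_ALLTOALL;
    args.src.info.buffer    = src_buf;   /* hypothetical source buffer */
    args.src.info.count     = count;
    args.src.info.datatype  = UCC_DT_INT64;
    args.src.info.mem_type  = UCC_MEMORY_TYPE_HOST;
    args.dst.info.buffer    = dst_buf;   /* hypothetical destination buffer */
    args.dst.info.count     = count;
    args.dst.info.datatype  = UCC_DT_INT64;
    args.dst.info.mem_type  = UCC_MEMORY_TYPE_HOST;
    /* team-wide synchronization/work space used by the onesided path */
    args.global_work_buffer = work_buf;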
8 changes: 5 additions & 3 deletions src/components/tl/ucp/alltoall/alltoall_onesided.c
@@ -29,24 +29,26 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
ucc_rank_t peer;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

ucc_tl_ucp_coll_dynamic_segments(&TASK_ARGS(task), task);

/* TODO: change when support for library-based work buffers is complete */
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
dest = dest + grank * nelems;
mtype = (TASK_ARGS(task).src.info.mem_type == UCC_MEMORY_TYPE_EXPORTED) ? UCC_MEMORY_TYPE_HOST : TASK_ARGS(task).src.info.mem_type;

mtype = (TASK_ARGS(task).src.info.mem_type == UCC_MEMORY_TYPE_EXPORTED) ?
UCC_MEMORY_TYPE_HOST : TASK_ARGS(task).src.info.mem_type;
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + start * nelems),
(void *)dest, nelems, start, mtype, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, team), task, out);

for (peer = (start + 1) % gsize; peer != start; peer = (peer + 1) % gsize) {
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
(void *)dest, nelems, peer, mtype, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task,
out);
}

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
out:
return task->super.status;
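Note: the start routine now registers dynamic segments up front via ucc_tl_ucp_coll_dynamic_segments and treats UCC_MEMORY_TYPE_EXPORTED as host memory for the transfer. The put loop visits peers round-robin beginning at the caller's own rank; a standalone, runnable illustration of that ordering (not part of the commit):

    #include <stdio.h>

    /* Prints the order in which each rank targets its peers: itself first,
     * then (start + 1) % gsize, (start + 2) % gsize, ... so that all ranks
     * do not direct their puts at the same peer simultaneously. */
    int main(void)
    {
        int gsize = 4;
        for (int start = 0; start < gsize; start++) {
            printf("rank %d order: %d", start, start);
            for (int peer = (start + 1) % gsize; peer != start;
                 peer = (peer + 1) % gsize) {
                printf(" %d", peer);
            }
            printf("\n");
        }
        return 0;
    }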
1 change: 1 addition & 0 deletions src/components/tl/ucp/alltoallv/alltoallv_onesided.c
@@ -29,6 +29,7 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
size_t sd_disp, dd_disp, data_size;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
ucc_tl_ucp_coll_dynamic_segments(&TASK_ARGS(task), task);

mtype = (TASK_ARGS(task).src.info_v.mem_type == UCC_MEMORY_TYPE_EXPORTED) ? UCC_MEMORY_TYPE_HOST : TASK_ARGS(task).src.info_v.mem_type;
/* perform a put to each member peer using the peer's index in the
2 changes: 0 additions & 2 deletions src/components/tl/ucp/tl_ucp.h
@@ -36,7 +36,6 @@
#define ONESIDED_SYNC_SIZE 1
#define ONESIDED_REDUCE_SIZE 4


typedef struct ucc_tl_ucp_iface {
ucc_tl_iface_t super;
} ucc_tl_ucp_iface_t;
@@ -123,7 +122,6 @@ typedef struct ucc_tl_ucp_remote_info {
void * va_base;
size_t len;
void * mem_h;
void * packed_memh;
void * packed_key;
size_t packed_key_len;
} ucc_tl_ucp_remote_info_t;
31 changes: 4 additions & 27 deletions src/components/tl/ucp/tl_ucp_coll.c
@@ -190,8 +190,7 @@ ucc_status_t ucc_tl_ucp_memmap_append_segment(ucc_tl_ucp_task_t *task,
ucc_mem_map_t *map,
int segid)
{
// ucc_tl_ucp_task_t *ctask = ucc_derived_of(task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);// ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);
ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
ucs_status_t ucs_status;
ucp_mem_map_params_t mmap_params;
@@ -216,37 +215,15 @@
tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
tl_ctx->dynamic_remote_info[segid].packed_memh = map->resource;

//tl_ctx->dynamic_remote_info[segid].packed_memh = map->resource;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
return ucs_status_to_ucc_status(ucs_status);
}
} else {
#if 0
/* FIXME: will this code path ever be hit for DPU? it should be handled by
* urom cl or doca urom cl */
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map->address;
mmap_params.length = map->len;

ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
if (ucs_status != UCS_OK) {
tl_error(UCC_TASK_LIB(task), "failure in ucp_mem_map %s", ucs_status_string(ucs_status));
return ucs_status_to_ucc_status(ucs_status);
}
pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT;
ucs_status = ucp_memh_pack(mh, &pack_params, &packed_memh, &packed_memh_len);
if (ucs_status != UCS_OK) {
tl_error(UCC_TASK_LIB(task), "failure in ucp_memh_pack %s", ucs_status_string(ucs_status));
return ucs_status_to_ucc_status(ucs_status);
}
#endif
}
} else {
/* do something
/*
* type cpu: we need to map and generate an rkey
* type gpu: is this any different
*/
@@ -263,7 +240,7 @@ ucc_status_t ucc_tl_ucp_memmap_append_segment(ucc_tl_ucp_task_t *task,
tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
tl_ctx->dynamic_remote_info[segid].packed_memh = NULL;
//tl_ctx->dynamic_remote_info[segid].packed_memh = NULL;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
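Note: for segments without an exported resource, the append path maps the buffer and packs an rkey for exchange. A condensed sketch of that map-then-pack flow, assuming an already-initialized ucp_context and abbreviating error handling; buf and len are hypothetical:

    ucp_mem_map_params_t mmap_params = {
        .field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
                      UCP_MEM_MAP_PARAM_FIELD_LENGTH,
        .address    = buf,   /* hypothetical segment base address */
        .length     = len,   /* hypothetical segment length */
    };
    ucp_mem_h mh;
    void     *packed_key;
    size_t    packed_key_len;

    if (ucp_mem_map(ucp_context, &mmap_params, &mh) != UCS_OK)
        return UCC_ERR_NO_MESSAGE;
    if (ucp_rkey_pack(ucp_context, mh, &packed_key,
                      &packed_key_len) != UCS_OK) {
        ucp_mem_unmap(ucp_context, mh);
        return UCC_ERR_NO_MESSAGE;
    }
    /* packed_key is exchanged out of band; peers unpack it with
     * ucp_ep_rkey_unpack() to obtain a usable ucp_rkey_h */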
44 changes: 30 additions & 14 deletions src/components/tl/ucp/tl_ucp_context.c
@@ -164,9 +164,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
ucp_params.field_mask =
UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_TAG_SENDER_MASK | UCP_PARAM_FIELD_NAME;
ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_AM;
//if (params->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) {
ucp_params.features |= UCP_FEATURE_RMA | UCP_FEATURE_AMO64;
//}
ucp_params.features |= UCP_FEATURE_RMA | UCP_FEATURE_AMO64 | UCP_FEATURE_EXPORTED_MEMH;
ucp_params.tag_sender_mask = UCC_TL_UCP_TAG_SENDER_MASK;
ucp_params.name = "UCC_UCP_CONTEXT";

@@ -499,17 +497,35 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
}

for (i = 0; i < nsegs; i++) {
mmap_params.field_mask =
UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map.segments[i].address;
mmap_params.length = map.segments[i].len;

status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"ucp_mem_map failed with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
if (map.segments[i].resource == NULL) {
mmap_params.field_mask =
UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map.segments[i].address;
mmap_params.length = map.segments[i].len;

status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"ucp_mem_map failed with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
// ctx->remote_info[i].packed_memh = NULL;
} else {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = map.segments[i].resource;

status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
if (status == UCS_ERR_UNREACHABLE) {
tl_error(ctx->super.super.lib, "exported memh is unsupported");
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
} else if (status < UCS_OK) {
tl_error(ctx->super.super.lib, "ucp_mem_map failed with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
// ctx->remote_info[i].packed_memh = map.segments[i].resource;
}
ctx->remote_info[i].mem_h = (void *)mh;
status = ucp_rkey_pack(ctx->worker.ucp_context, mh,
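Note: ucc_tl_ucp_ctx_remote_populate now branches on the segment's resource field: a NULL resource is mapped from its raw address as before, while a non-NULL resource is treated as an exported memh blob and imported through ucp_mem_map (the context init above adds UCP_FEATURE_EXPORTED_MEMH to support this). A hedged sketch of the import branch, assuming segment_resource holds the packed memh produced on the exporting process:

    ucp_mem_map_params_t mmap_params = {
        .field_mask           = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER,
        .exported_memh_buffer = segment_resource,
    };
    ucp_mem_h    mh;
    ucs_status_t status = ucp_mem_map(ucp_context, &mmap_params, &mh);

    if (status == UCS_ERR_UNREACHABLE) {
        /* no available transport can import this exported memh */
    } else if (status != UCS_OK) {
        /* generic mapping failure */
    }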
6 changes: 2 additions & 4 deletions src/components/tl/ucp/tl_ucp_sendrecv.h
@@ -269,7 +269,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, size_t msglen, u
}
}
*rkey = UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment);
*packed_memh = ctx->remote_info[i].packed_memh;
*packed_memh = ctx->remote_info[i].mem_h;
return UCC_OK;
}
key_offset += key_sizes[i];
@@ -305,7 +305,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, size_t msglen, u
}
}
*rkey = UCC_TL_UCP_DYN_REMOTE_RKEY(ctx, peer, team_size, *segment);
*packed_memh = ctx->dynamic_remote_info[i].packed_memh;
*packed_memh = ctx->dynamic_remote_info[i].mem_h;
return UCC_OK;
}
}
@@ -387,14 +387,12 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
req_param.cb.send = ucc_tl_ucp_put_completion_cb;
req_param.user_data = (void *)task;
req_param.memory_type = ucc_memtype_to_ucs[mtype];

if (packed_memh) {
req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
req_param.memh = packed_memh;
}

ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param);

task->onesided.put_posted++;
if (UCS_OK != ucp_status) {
if (UCS_PTR_IS_ERR(ucp_status)) {
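Note: with packed_memh removed from ucc_tl_ucp_remote_info_t, resolve_p2p_by_va now returns the UCP memory handle (mem_h) itself, and ucc_tl_ucp_put_nb attaches it to the operation via UCP_OP_ATTR_FIELD_MEMH. A condensed sketch of that attachment, assuming ep, rkey, rva, lbuf, len, and memh were already resolved; put_done_cb is a hypothetical completion callback:

    ucp_request_param_t param = {
        .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
                        UCP_OP_ATTR_FIELD_USER_DATA,
        .cb.send      = put_done_cb,
        .user_data    = task,
    };

    if (memh != NULL) {
        /* pass the pre-registered handle so UCP uses it directly
         * instead of trying to register the buffer itself */
        param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
        param.memh          = memh;
    }
    ucs_status_ptr_t sp = ucp_put_nbx(ep, lbuf, len, (uint64_t)rva,
                                      rkey, &param);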
