API: remove extra memory type
ferrol aderholdt committed Feb 15, 2024
1 parent 169931a commit ed08c30
Showing 7 changed files with 61 additions and 118 deletions.
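
The change removes UCC_MEMORY_TYPE_EXPORTED and the mem_type field from ucc_mem_map: whether a segment carries an exported memory handle is now inferred from its resource pointer. A minimal caller-side sketch, assuming the public ucc_mem_map_t typedef (the helper name is illustrative, not part of the commit):

    #include <ucc/api/ucc.h>

    /* Illustrative helper: describe one segment for ucc_mem_map_params_t.
     * With this change there is no mem_type to set; a non-NULL resource
     * (e.g. an exported/packed memh) selects the exported-memh import path,
     * while NULL lets the TL map address/len itself. */
    static void fill_segment(ucc_mem_map_t *seg, void *buf, size_t len,
                             void *exported_memh /* may be NULL */)
    {
        seg->address  = buf;
        seg->len      = len;
        seg->resource = exported_memh;
    }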
4 changes: 1 addition & 3 deletions src/components/tl/ucp/alltoall/alltoall_onesided.c
@@ -25,7 +25,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t start = (grank + 1) % gsize;
long * pSync = TASK_ARGS(task).global_work_buffer;
ucc_memory_type_t mtype;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type;
ucc_rank_t peer;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
@@ -34,8 +34,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
/* TODO: change when support for library-based work buffers is complete */
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
dest = dest + grank * nelems;
mtype = (TASK_ARGS(task).src.info.mem_type == UCC_MEMORY_TYPE_EXPORTED) ?
UCC_MEMORY_TYPE_HOST : TASK_ARGS(task).src.info.mem_type;
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + start * nelems),
(void *)dest, nelems, start, mtype, team, task),
task, out);
3 changes: 1 addition & 2 deletions src/components/tl/ucp/alltoallv/alltoallv_onesided.c
@@ -24,14 +24,13 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
ucc_memory_type_t mtype;
ucc_memory_type_t mtype = TASK_ARGS(task).src.info_v.mem_type;
ucc_rank_t peer;
size_t sd_disp, dd_disp, data_size;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
ucc_tl_ucp_coll_dynamic_segments(&TASK_ARGS(task), task);

mtype = (TASK_ARGS(task).src.info_v.mem_type == UCC_MEMORY_TYPE_EXPORTED) ? UCC_MEMORY_TYPE_HOST : TASK_ARGS(task).src.info_v.mem_type;
/* perform a put to each member peer using the peer's index in the
* destination displacement. */
for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize;
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp.h
@@ -122,6 +122,7 @@ typedef struct ucc_tl_ucp_remote_info {
void * va_base;
size_t len;
void * mem_h;
void * packed_memh;
void * packed_key;
size_t packed_key_len;
} ucc_tl_ucp_remote_info_t;
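
For reference, the descriptor now carries both the local UCP memory handle and the user-supplied packed handle; a sketch with editorial comments (field names are from the diff, the comments are interpretation):

    typedef struct ucc_tl_ucp_remote_info {
        void  *va_base;        /* base address of the registered region */
        size_t len;            /* region length in bytes */
        void  *mem_h;          /* ucp_mem_h returned by ucp_mem_map() */
        void  *packed_memh;    /* exported memh buffer supplied by the user,
                                  or NULL when the TL mapped the region itself */
        void  *packed_key;     /* rkey buffer produced by ucp_rkey_pack() */
        size_t packed_key_len; /* length of packed_key in bytes */
    } ucc_tl_ucp_remote_info_t;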
129 changes: 41 additions & 88 deletions src/components/tl/ucp/tl_ucp_coll.c
@@ -197,73 +197,50 @@ ucc_status_t ucc_tl_ucp_memmap_append_segment(ucc_tl_ucp_task_t *task,
ucp_mem_h mh;

// map the memory
if (map->mem_type == UCC_MEMORY_TYPE_EXPORTED) {
if (map->resource != NULL) {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = map->resource;

ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
if (ucs_status == UCS_ERR_UNREACHABLE) {
tl_error(tl_ctx->super.super.lib, "exported memh unsupported");
return ucs_status_to_ucc_status(ucs_status);
} else if (ucs_status < UCS_OK) {
tl_error(tl_ctx->super.super.lib, "error on ucp_mem_map");
return ucs_status_to_ucc_status(ucs_status);
}
/* generate rkeys / packed keys */

tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
//tl_ctx->dynamic_remote_info[segid].packed_memh = map->resource;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
return ucs_status_to_ucc_status(ucs_status);
}
if (map->resource != NULL) {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = map->resource;

ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
if (ucs_status == UCS_ERR_UNREACHABLE) {
tl_error(tl_ctx->super.super.lib, "exported memh unsupported");
return ucs_status_to_ucc_status(ucs_status);
} else if (ucs_status < UCS_OK) {
tl_error(tl_ctx->super.super.lib, "error on ucp_mem_map");
return ucs_status_to_ucc_status(ucs_status);
}
/* generate rkeys / packed keys */

tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
tl_ctx->dynamic_remote_info[segid].packed_memh = map->resource;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
return ucs_status_to_ucc_status(ucs_status);
}
} else {
/*
* type cpu: we need to map and generate an rkey
* type gpu: is this any different
*/
if (map->mem_type == UCC_MEMORY_TYPE_HOST) {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map->address;
mmap_params.length = map->len;

ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
if (ucs_status != UCS_OK) {
tl_error(UCC_TASK_LIB(task), "failure in ucp_mem_map %s", ucs_status_string(ucs_status));
return ucs_status_to_ucc_status(ucs_status);
}
tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
//tl_ctx->dynamic_remote_info[segid].packed_memh = NULL;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
return ucs_status_to_ucc_status(ucs_status);
}
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map->address;
mmap_params.length = map->len;

ucs_status = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
if (ucs_status != UCS_OK) {
tl_error(UCC_TASK_LIB(task), "failure in ucp_mem_map %s", ucs_status_string(ucs_status));
return ucs_status_to_ucc_status(ucs_status);
}
}
return UCC_OK;
}

static inline int find_dynamic_segment(void *va,
size_t len,
ucc_tl_ucp_context_t *ctx)
{
int i = 0;

for (; i < ctx->n_dynrinfo_segs; i++) {
printf("va_base: %p va %p, len: %lu, len: %lu\n", ctx->dynamic_remote_info[i].va_base, va, ctx->dynamic_remote_info[i].len, len);
if (ctx->dynamic_remote_info[i].va_base == va && ctx->dynamic_remote_info[i].len == len) {
return i;
tl_ctx->dynamic_remote_info[segid].va_base = map->address;
tl_ctx->dynamic_remote_info[segid].len = map->len;
tl_ctx->dynamic_remote_info[segid].mem_h = mh;
tl_ctx->dynamic_remote_info[segid].packed_memh = NULL;
ucs_status = ucp_rkey_pack(tl_ctx->worker.ucp_context, mh, &tl_ctx->dynamic_remote_info[segid].packed_key, &tl_ctx->dynamic_remote_info[segid].packed_key_len);
if (UCS_OK != ucs_status) {
tl_error(tl_ctx->super.super.lib, "failed to pack UCP key with error code: %d", ucs_status);
return ucs_status_to_ucc_status(ucs_status);
}
}
return -1;
return UCC_OK;
}

ucc_status_t ucc_tl_ucp_coll_dynamic_segments(ucc_coll_args_t *coll_args,
@@ -272,9 +249,8 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segments(ucc_coll_args_t *coll_args,
ucc_status_t status;
int i = 0;

if ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) &&
(coll_args->flags & UCC_COLL_ARGS_FLAG_MEM_MAP)) {
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);// ucc_derived_of(team, ucc_tl_ucp_team_t);
if (coll_args->mem_map.n_segments > 0) {
ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task);
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
int starting_index = ctx->n_dynrinfo_segs;
size_t seg_pack_size = 0;
@@ -289,28 +265,6 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segments(ucc_coll_args_t *coll_args,
void *ex_buffer;
ptrdiff_t old_offset;

/* FIXME: add the following (isn't this the users responsbility?)
* If there are duplicate entries (same base, same len) for all segments: skip it, no allgather
* If there are updates to entries (same base, high len) for a segment in a segment group: update the group
* confirm equal number of segments
* unmap original, map new
* set dyn_buff pointer to that groups offset
* perform allgather
* release unpacked keys, set to rkey[i] = 0
*/
#if 0
int n_found = 0;
// if it's already there, do nothing
for (i = 0; i < coll_args->mem_map.n_segments; i++) {
if (find_dynamic_segment(coll_args->mem_map.segments[i].address, coll_args->mem_map.segments[i].len, ctx) >= 0) {
++n_found;
}
}
printf("n_found: %d n add: %ld\n", n_found, coll_args->mem_map.n_segments);
if (n_found == coll_args->mem_map.n_segments) {
return UCC_OK;
}
#endif
/* increase dynamic remote info size */
ctx->dynamic_remote_info = ucc_realloc(ctx->dynamic_remote_info, sizeof(ucc_tl_ucp_remote_info_t) * (ctx->n_dynrinfo_segs + coll_args->mem_map.n_segments), "dyn remote info");
if (!ctx->dynamic_remote_info) {
@@ -353,7 +307,6 @@ ucc_status_t ucc_tl_ucp_coll_dynamic_segments(ucc_coll_args_t *coll_args,
old_offset = ctx->dyn_seg.buff_size;
ctx->dyn_seg.buff_size += local_size * core_team->size;
ctx->dyn_seg.dyn_buff = ucc_realloc(ctx->dyn_seg.dyn_buff, ctx->dyn_seg.buff_size, "dyn buff");
// FIXME
ctx->dyn_seg.num_groups += 1;
ctx->dyn_seg.seg_groups = ucc_realloc(ctx->dyn_seg.seg_groups,sizeof(uint64_t) * ctx->n_dynrinfo_segs, "n_dynrinfo_segs");
ctx->dyn_seg.seg_group_start = ucc_realloc(ctx->dyn_seg.seg_group_start,sizeof(uint64_t) * ctx->n_dynrinfo_segs, "n_dynrinfo_segs");
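
Condensed, the new ucc_tl_ucp_memmap_append_segment() path branches only on map->resource; a paraphrased sketch of the flow in this diff (error handling and logging abbreviated):

    ucp_mem_map_params_t mmap_params;
    ucp_mem_h            mh;
    ucs_status_t         st;

    if (map->resource != NULL) {
        /* import an externally exported memory handle */
        mmap_params.field_mask           = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
        mmap_params.exported_memh_buffer = map->resource;
    } else {
        /* map the raw address range directly */
        mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
                                 UCP_MEM_MAP_PARAM_FIELD_LENGTH;
        mmap_params.address    = map->address;
        mmap_params.length     = map->len;
    }
    st = ucp_mem_map(tl_ctx->worker.ucp_context, &mmap_params, &mh);
    /* on success: record va_base, len, and mem_h, set packed_memh to
     * map->resource (or NULL), then ucp_rkey_pack() into packed_key */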
4 changes: 2 additions & 2 deletions src/components/tl/ucp/tl_ucp_context.c
@@ -510,7 +510,7 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
// ctx->remote_info[i].packed_memh = NULL;
ctx->remote_info[i].packed_memh = NULL;
} else {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = map.segments[i].resource;
@@ -525,7 +525,7 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
// ctx->remote_info[i].packed_memh = map.segments[i].resource;
ctx->remote_info[i].packed_memh = map.segments[i].resource;
}
ctx->remote_info[i].mem_h = (void *)mh;
status = ucp_rkey_pack(ctx->worker.ucp_context, mh,
4 changes: 2 additions & 2 deletions src/components/tl/ucp/tl_ucp_sendrecv.h
@@ -269,7 +269,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, size_t msglen, u
}
}
*rkey = UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment);
*packed_memh = ctx->remote_info[i].mem_h;
*packed_memh = (ctx->remote_info[i].packed_memh) ? ctx->remote_info[i].mem_h : NULL;
return UCC_OK;
}
key_offset += key_sizes[i];
@@ -305,7 +305,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, size_t msglen, u
}
}
*rkey = UCC_TL_UCP_DYN_REMOTE_RKEY(ctx, peer, team_size, *segment);
*packed_memh = ctx->dynamic_remote_info[i].mem_h;
*packed_memh = (ctx->dynamic_remote_info[i].packed_memh) ? ctx->dynamic_remote_info[i].mem_h : NULL;
return UCC_OK;
}
}
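
The intent of the two changed lines above: a local memory handle is only handed back for segments that were registered from a user-supplied exported handle, so TL-mapped segments return NULL and the transfer path behaves as before. In comment form:

    /* resolve_p2p_by_va(): advertise the local ucp_mem_h only when the
     * segment was created from an exported memh (packed_memh != NULL);
     * otherwise return NULL so the put/get relies on the rkey alone. */
    *packed_memh = ctx->remote_info[i].packed_memh ? ctx->remote_info[i].mem_h
                                                   : NULL;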
34 changes: 13 additions & 21 deletions src/ucc/api/ucc.h
@@ -167,10 +167,7 @@ typedef enum ucc_memory_type {
UCC_MEMORY_TYPE_ROCM, /*!< AMD ROCM memory */
UCC_MEMORY_TYPE_ROCM_MANAGED, /*!< AMD ROCM managed system memory */
UCC_MEMORY_TYPE_LAST,
UCC_MEMORY_TYPE_UNKNOWN = UCC_MEMORY_TYPE_LAST,
UCC_MEMORY_TYPE_EXPORTED /*!< Exported memory for use by
DPU / SmartNIC. Memory is not valid for
any other use. */
UCC_MEMORY_TYPE_UNKNOWN = UCC_MEMORY_TYPE_LAST
} ucc_memory_type_t;

/**
@@ -910,7 +907,6 @@ typedef struct ucc_mem_map {
void * address; /*!< the address of a buffer to be attached to
a UCC context */
size_t len; /*!< the length of the buffer */
ucc_memory_type_t mem_type; /*!< the memory type */
void * resource; /*!< resource associated with the address.
examples of resources include memory
keys. */
@@ -1720,23 +1716,12 @@ typedef enum {
Note, the status is not guaranteed
to be global on all the processes
participating in the collective.*/
UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS = UCC_BIT(7), /*!< If set, both src
UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS = UCC_BIT(7) /*!< If set, both src
and dst buffers
reside in a memory
mapped region.
Useful for one-sided
collectives. */
UCC_COLL_ARGS_FLAG_MEM_MAP = UCC_BIT(8) /*!< If set, map the
memory map parameters.
This requires an
allgather collective
to be performed in addition
to the requested collective.
There is significant
overhead for the first
call. Useful for one-sided
collectives in message
passing programming models */
} ucc_coll_args_flags_t;

/**
@@ -1898,10 +1883,17 @@ typedef struct ucc_coll_args {
ucc_coll_callback_t cb;
double timeout; /*!< Timeout in seconds */
ucc_mem_map_params_t mem_map; /*!< Memory regions to be used
in during the collective
operation for one-sided
collectives. Not necessary
for two-sided collectives */
for the current and/or
future one-sided collectives.
If set, the designated regions
will be mapped and information
exchanged with the team
associated with the collective
via an allgather operation.
It is recommended to use this
option sparingly due to the
increased overhead. Not necessary
for two-sided collectives. */
struct {
uint64_t start;
int64_t stride;
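
Given the revised documentation of mem_map above, a hypothetical caller for a one-sided collective might populate the arguments as sketched below. This is only a sketch under assumptions: the segs/src_buf/dst_buf names and the helper are illustrative, team setup, the global work buffer, and error handling are omitted, and field names follow the public ucc.h in this commit.

    #include <string.h>
    #include <ucc/api/ucc.h>

    /* Hypothetical helper: one-sided alltoall whose buffers live in two
     * mapped segments. Attaching segments via mem_map is what triggers the
     * mapping and allgather described in the comment above. */
    static void fill_onesided_alltoall_args(ucc_coll_args_t *args,
                                            ucc_mem_map_t    segs[2],
                                            void *src_buf, void *dst_buf,
                                            size_t buf_len, uint64_t count)
    {
        segs[0] = (ucc_mem_map_t){ .address = src_buf, .len = buf_len,
                                   .resource = NULL };
        segs[1] = (ucc_mem_map_t){ .address = dst_buf, .len = buf_len,
                                   .resource = NULL };

        memset(args, 0, sizeof(*args));
        args->mask      = UCC_COLL_ARGS_FIELD_FLAGS;
        args->flags     = UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS;
        args->coll_type = UCC_COLL_TYPE_ALLTOALL;

        args->src.info.buffer   = src_buf;
        args->src.info.count    = count;
        args->src.info.datatype = UCC_DT_INT64;
        args->src.info.mem_type = UCC_MEMORY_TYPE_HOST; /* EXPORTED is gone */
        args->dst.info          = args->src.info;
        args->dst.info.buffer   = dst_buf;

        args->mem_map.segments   = segs;
        args->mem_map.n_segments = 2;
    }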
