diff --git a/configure.ac b/configure.ac index 49ad4fa7..322b5794 100644 --- a/configure.ac +++ b/configure.ac @@ -80,7 +80,7 @@ AC_ARG_WITH([verbs], AC_CHECK_HEADER( [infiniband/verbs.h], [],[AC_MSG_FAILURE([ibverbs header files not found])]) AC_CHECK_LIB([ibverbs], [ibv_get_device_list], [],[AC_MSG_FAILURE([libibverbs not found]);]) -AC_CHECK_DECLS([IBV_ACCESS_RELAXED_ORDERING, IBV_QPF_GRH_REQUIRED, ibv_reg_dmabuf_mr], [], [], +AC_CHECK_DECLS([IBV_ACCESS_RELAXED_ORDERING, IBV_QPF_GRH_REQUIRED, ibv_reg_dmabuf_mr, ibv_query_ece, ibv_set_ece], [], [], [[#include <infiniband/verbs.h>]]) # check for ucx diff --git a/include/ibvwrap.h b/include/ibvwrap.h index c79de372..da076b56 100644 --- a/include/ibvwrap.h +++ b/include/ibvwrap.h @@ -59,6 +59,8 @@ static inline ncclResult_t wrap_ibv_poll_cq(struct ibv_cq *cq, int num_entries, ncclResult_t wrap_ibv_create_qp(struct ibv_qp **ret, struct ibv_pd *pd, struct ibv_qp_init_attr *qp_init_attr); ncclResult_t wrap_ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask); ncclResult_t wrap_ibv_destroy_qp(struct ibv_qp *qp); +ncclResult_t wrap_ibv_query_ece(struct ibv_qp *qp, struct ibv_ece *ece, int* supported); +ncclResult_t wrap_ibv_set_ece(struct ibv_qp *qp, struct ibv_ece *ece, int* supported); ncclResult_t wrap_ibv_post_send(struct ibv_qp *qp, struct ibv_send_wr *wr, struct ibv_send_wr **bad_wr); ncclResult_t wrap_ibv_post_recv(struct ibv_qp *qp, struct ibv_recv_wr *wr, struct ibv_recv_wr **bad_wr); ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event); diff --git a/include/nccl.h b/include/nccl.h index 67736cf8..a234af96 100644 --- a/include/nccl.h +++ b/include/nccl.h @@ -14,11 +14,11 @@ #endif #define NCCL_MAJOR 2 -#define NCCL_MINOR 15 -#define NCCL_PATCH 1 +#define NCCL_MINOR 20 +#define NCCL_PATCH 3 #define NCCL_SUFFIX "" -#define NCCL_VERSION_CODE 21510 +#define NCCL_VERSION_CODE 22003 #define NCCL_VERSION(X,Y,Z) (((X) <= 2 && (Y) <= 8) ? 
(X) * 1000 + (Y) * 100 + (Z) : (X) * 10000 + (Y) * 100 + (Z)) #ifdef __cplusplus @@ -42,15 +42,24 @@ typedef enum { ncclSuccess = 0, ncclInProgress = 7, ncclNumResults = 8 } ncclResult_t; +#define NCCL_CONFIG_UNDEF_INT INT_MIN +#define NCCL_CONFIG_UNDEF_PTR NULL +#define NCCL_SPLIT_NOCOLOR -1 + /* Communicator configuration. Users can assign value to attributes to specify the * behavior of a communicator. */ -typedef struct ncclConfig_v21400 { +typedef struct ncclConfig_v21700 { /* attributes that users should never touch. */ size_t size; unsigned int magic; unsigned int version; /* attributes that users are able to customize. */ int blocking; + int cgaClusterSize; + int minCTAs; + int maxCTAs; + const char *netName; + int splitShare; } ncclConfig_t; /* Config initializer must be assigned to initialize config structure when it is created. @@ -59,9 +68,23 @@ typedef struct ncclConfig_v21400 { sizeof(ncclConfig_t), /* size */ \ 0xcafebeef, /* magic */ \ NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH), /* version */ \ - 1 /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \ + NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \ + NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \ + NCCL_CONFIG_UNDEF_PTR, /* netName */ \ + NCCL_CONFIG_UNDEF_INT /* splitShare */ \ } +/* NCCL malloc and free function for all types of NCCL optimizations + * (e.g. user buffer registration). The actual allocated size might + * be larger than requested due to granularity requirement. */ +ncclResult_t ncclMemAlloc(void** ptr, size_t size); +ncclResult_t pncclMemAlloc(void** ptr, size_t size); + +ncclResult_t ncclMemFree(void *ptr); +ncclResult_t pncclMemFree(void *ptr); + /* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. 
* This integer is coded with the MAJOR, MINOR and PATCH level of the * NCCL library @@ -119,6 +142,10 @@ ncclResult_t pncclCommAbort(ncclComm_t comm); const char* ncclGetErrorString(ncclResult_t result); const char* pncclGetErrorString(ncclResult_t result); +/* Returns a human-readable message of the last error that occurred. */ + const char* ncclGetLastError(ncclComm_t comm); + const char* pncclGetLastError(ncclComm_t comm); + /* Checks whether the comm has encountered any asynchronous errors */ ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); @@ -135,6 +162,16 @@ ncclResult_t pncclCommCuDevice(const ncclComm_t comm, int* device); ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank); ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank); + +/* Register CUDA buffer for zero-copy operation */ +ncclResult_t ncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle); +ncclResult_t pncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle); + +/* Deregister CUDA buffer */ +ncclResult_t ncclCommDeregister(const ncclComm_t comm, void* handle); +ncclResult_t pncclCommDeregister(const ncclComm_t comm, void* handle); + + /* Reduction operation selector */ typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t; typedef enum { ncclSum = 0, diff --git a/include/p2p_plugin.h b/include/p2p_plugin.h index 1ac916d6..7d0df5bc 100644 --- a/include/p2p_plugin.h +++ b/include/p2p_plugin.h @@ -44,18 +44,27 @@ struct ncclIbMrCache { int capacity, population; }; +#define NCCL_IB_MAX_DEVS_PER_NIC 2 +#define MAX_MERGED_DEV_NAME (MAXNAMESIZE*NCCL_IB_MAX_DEVS_PER_NIC)+NCCL_IB_MAX_DEVS_PER_NIC +struct ncclIbMergedDev { + int ndevs; + int devs[NCCL_IB_MAX_DEVS_PER_NIC]; // Points to an index in ncclIbDevs + int speed; + char devName[MAX_MERGED_DEV_NAME]; // Up to NCCL_IB_MAX_DEVS_PER_NIC * name size, and a character for 
each '+' +} __attribute__((aligned(64))); + struct ncclIbRequest { - struct ncclIbVerbs* verbs; + struct ncclIbNetCommBase* base; int type; - int events; struct ncclSocket* sock; - struct ncclIbGidInfo* gidInfo; + int events[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbNetCommDevBase* devBases[NCCL_IB_MAX_DEVS_PER_NIC]; int nreqs; union { struct { int size; void* data; - uint32_t lkey; + uint32_t lkeys[NCCL_IB_MAX_DEVS_PER_NIC]; int offset; } send; struct { @@ -64,56 +73,57 @@ struct ncclIbRequest { }; }; -struct ncclIbVerbs { - int dev; - struct ibv_pd* pd; // duplicate of ncclIbDevs[dev].pd +// Retain local RoCE address for error logging +struct ncclIbGidInfo { + uint8_t link_layer; + union ibv_gid localGid; +}; + +typedef struct ncclIbNetCommDevBase { + int ibDevN; + struct ibv_pd* pd; struct ibv_cq* cq; uint64_t pad[1]; - struct ncclIbRequest reqs[MAX_REQUESTS]; -}; + struct ncclIbGidInfo gidInfo; +} ncclIbNetCommDevBase; typedef struct ncclIbDev { pthread_mutex_t lock; int device; uint64_t guid; - uint8_t port; + uint8_t portNum; uint8_t link; uint8_t isSharpDev; int speed; struct ibv_context* context; int pdRefs; struct ibv_pd* pd; - struct ncclIbVerbs verbs; char devName[MAXNAMESIZE]; char *pciPath; int realPort; int maxQp; struct ncclIbMrCache mrCache; int ar; // ADAPTIVE_ROUTING -} __attribute__((aligned(64))) nccl_ib_dev_t; + struct ibv_port_attr portAttr; +} __attribute__((aligned(64))) ncclIbDev; -#define MAX_IB_PORT 15 -struct userIbDev { - char devName[MAXNAMESIZE]; - uint16_t port_en; -}; #define MAX_IB_DEVS 32 +struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_DEVS]; extern struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; -extern struct ncclIbDev userIbDevs[MAX_IB_DEVS]; /* Detect whether GDR can work on a given NIC with the current CUDA device * Returns : * ncclSuccess : GDR works * ncclSystemError : no module or module loaded but not supported by GPU */ -ncclResult_t nccl_p2p_gdr_support(int dev); +ncclResult_t nccl_p2p_gdr_support(); ncclResult_t 
nccl_p2p_dmabuf_support(int dev); -ncclResult_t nccl_p2p_ib_pci_path(nccl_ib_dev_t *devs, int num_devs, char* dev_name, char** path, int* real_port); +ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port); -ncclResult_t nccl_p2p_ib_get_properties(nccl_ib_dev_t *devs, int dev, ncclNetProperties_t* props); +ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetProperties_t* props); -ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction); +ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction); /* Convert value returtned by ibv_query_port to actual link width */ int nccl_p2p_ib_width(int width); @@ -125,6 +135,8 @@ int64_t ncclParamSharpMaxComms(); int64_t ncclParamIbMergeVfs(); +int64_t ncclParamIbMergeNics(); + int ncclIbRelaxedOrderingCapable(void); nccl_p2p_plugin_t nccl_p2p_get_plugin_type(); diff --git a/include/param.h b/include/param.h index 376da209..da019606 100644 --- a/include/param.h +++ b/include/param.h @@ -12,6 +12,7 @@ const char* userHomeDir(); void setEnvFile(const char* fileName); void initEnv(); +const char *ncclGetEnv(const char *name); void ncclLoadParam(char const* env, int64_t deftVal, int64_t uninitialized, int64_t* cache); diff --git a/src/ib_plugin.c b/src/ib_plugin.c index f9f95731..5cddebec 100644 --- a/src/ib_plugin.c +++ b/src/ib_plugin.c @@ -65,6 +65,7 @@ int ncclIbRelaxedOrderingCapable(void) { NCCL_PARAM(IbDisable, "IBEXT_DISABLE", 0); NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1); +NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1); extern ncclDebugLogger_t pluginLogFunction; @@ -113,26 +114,46 @@ ncclResult_t ncclIbGetProperties_v6(int dev, ncclNetProperties_v6_t* props_v6) props_v6->maxRecvs = 
props.maxRecvs; return ncclSuccess; -} +}; #define NCCL_IB_MAX_QPS 128 -typedef struct ncclIbQpInfo { +struct ncclIbQpInfo { + uint32_t qpn; + + // Fields needed for ece (enhanced connection establishment) + struct ibv_ece ece; + int ece_supported; + int devIndex; +}; + + +// Per-Dev connection metadata +typedef struct ncclIbDevInfo { uint32_t lid; uint8_t ib_port; + enum ibv_mtu mtu; uint8_t link_layer; uint8_t is_global; - uint32_t qpn[NCCL_IB_MAX_QPS]; // For RoCE and IB GRH uint64_t spn; uint64_t iid; - enum ibv_mtu mtu; // FIFO RDMA info uint32_t fifoRkey; + union ibv_gid remoteGid; +} ncclIbDevInfo; + + +// Struct containing everything needed to establish connections +typedef struct ncclIbConnectionMetadata { + struct ncclIbQpInfo qpInfo[NCCL_IB_MAX_QPS]; + struct ncclIbDevInfo devs[NCCL_IB_MAX_DEVS_PER_NIC]; + char devName[MAX_MERGED_DEV_NAME]; uint64_t fifoAddr; -} ncclIbQpInfo; + int ndevs; +} ncclIbConnectionMetadata; enum ncclIbCommState { ncclIbCommStateStart = 0, @@ -159,12 +180,7 @@ struct ncclIbHandle { struct ncclIbCommStage stage; // Used by the other side when connecting }; -// Retain local and remote RoCE addresses for error logging -struct ncclIbGidInfo { - uint8_t link_layer; - union ibv_gid localGid; - union ibv_gid remoteGid; -}; + #define NCCL_NET_IB_REQ_UNUSED 0 #define NCCL_NET_IB_REQ_SEND 1 @@ -181,70 +197,97 @@ struct ncclIbListenComm { struct ncclIbSendFifo { uint64_t addr; int size; - uint32_t rkey; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; uint32_t nreqs; uint32_t tag; uint64_t idx; + char padding[24]; }; +typedef struct ncclIbQp { + struct ibv_qp* qp; + int devIndex; + int remDevIdx; +} ncclIbQp; + struct ncclIbRemSizesFifo { int elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; uint64_t fifoTail; uint64_t addr; - uint32_t rkey; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; uint32_t flags; - struct ibv_mr* mr; + struct ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ibv_sge sge; }; +// A per-dev struct for netIbSendComm +typedef struct 
ncclIbSendCommDev { + struct ncclIbNetCommDevBase base; + struct ibv_mr* fifoMr; +} __attribute__((aligned(8))) ncclIbSendCommDev; + + +// Wrapper to track an MR per-device, if needed +struct ncclIbMrHandle { + struct ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; +}; + +typedef struct ncclIbNetCommBase { + int ndevs; + bool isSend; + struct ncclIbRequest reqs[MAX_REQUESTS]; + struct ncclIbQp qps[NCCL_IB_MAX_QPS]; + int nqps; + int qpIndex; + int devIndex; + struct ncclSocket sock; + int ready; + // Track necessary remDevInfo here + int nRemDevs; + struct ncclIbDevInfo remDevs[NCCL_IB_MAX_DEVS_PER_NIC]; + } __attribute__((aligned(32))) ncclIbNetCommBase; + struct ncclIbSendComm { - struct ncclIbVerbs verbs; + struct ncclIbNetCommBase base; struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; - struct ncclIbRemSizesFifo remSizesFifo; - uint64_t fifoHead; + // Each dev correlates to a mergedIbDev + struct ncclIbSendCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRequest* fifoReqs[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; - struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS+1]; struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS]; - struct ncclSocket sock; - - int ready; - struct ibv_qp* qps[NCCL_IB_MAX_QPS]; - int nqps; - int qpIndex; - struct ibv_mr* fifoMr; - int ar; - struct ncclIbGidInfo gidInfo; + struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS+1]; + struct ncclIbRemSizesFifo remSizesFifo; + uint64_t fifoHead; + int ar; // Use adaptive routing when all merged devices have it enabled }; struct ncclIbGpuFlush { - int enabled; - int hostMem; struct ibv_mr* hostMr; struct ibv_sge sge; - struct ibv_qp* qp; + struct ncclIbQp qp; }; struct ncclIbRemFifo { struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; uint64_t fifoTail; uint64_t addr; - uint32_t rkey; uint32_t flags; - struct ibv_mr* mr; - struct ibv_sge sge; }; +struct ncclIbRecvCommDev { + struct ncclIbNetCommDevBase base; + struct ncclIbGpuFlush gpuFlush; + uint32_t fifoRkey; + struct ibv_mr* fifoMr; + struct 
ibv_sge fifoSge; + struct ibv_mr* sizesFifoMr; +} __attribute__((aligned(16))); + struct ncclIbRecvComm { - struct ncclIbVerbs verbs; + struct ncclIbNetCommBase base; + struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRemFifo remFifo; int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; - struct ncclSocket sock; - int ready; - struct ibv_qp* qps[NCCL_IB_MAX_QPS]; - int nqps; - int qpIndex; - struct ncclIbGpuFlush gpuFlush; - struct ibv_mr* sizesFifoMr; - struct ncclIbGidInfo gidInfo; + int gpuFlushHostMem; + int flushEnabled; }; ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { @@ -253,55 +296,65 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { // The SendFifo needs to be 32-byte aligned and each element needs // to be a 32-byte multiple, so that an entry does not get split and // written out of order when IB Relaxed Ordering is enabled + NCCL_STATIC_ASSERT((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset"); NCCL_STATIC_ASSERT((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); NCCL_STATIC_ASSERT((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples"); + NCCL_STATIC_ASSERT((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned"); + NCCL_STATIC_ASSERT((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned"); NCCL_STATIC_ASSERT((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); + return nccl_p2p_ib_init(&ncclNIbDevs, ncclIbDevs, ncclIbIfName, &ncclIbIfAddr, &ncclIbAsyncThread, logFunction); } NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); -ncclResult_t ncclIbInitVerbs(int dev, struct ibv_context* ctx, struct ncclIbVerbs* verbs) { - verbs->dev = dev; +static void ncclIbAddEvent(struct ncclIbRequest* req, int devIndex, struct ncclIbNetCommDevBase* base) { + 
req->events[devIndex]++; + req->devBases[devIndex] = base; +} - pthread_mutex_lock(&ncclIbDevs[dev].lock); - if (0 == ncclIbDevs[dev].pdRefs++) { +ncclResult_t ncclIbInitCommDevBase(int ibDevN, struct ncclIbNetCommDevBase* base) { + base->ibDevN = ibDevN; + ncclIbDev* ibDev = ncclIbDevs + ibDevN; + pthread_mutex_lock(&ibDev->lock); + if (0 == ibDev->pdRefs++) { ncclResult_t res; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&ncclIbDevs[dev].pd, ctx), res, failure); + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&ibDev->pd, ibDev->context), res, failure); if (0) { failure: - pthread_mutex_unlock(&ncclIbDevs[dev].lock); + pthread_mutex_unlock(&ibDev->lock); return res; } } - verbs->pd = ncclIbDevs[dev].pd; - pthread_mutex_unlock(&ncclIbDevs[dev].lock); + base->pd = ibDev->pd; + pthread_mutex_unlock(&ibDev->lock); // Recv requests can generate 2 completions (one for the post FIFO, one for the Recv). - NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0)); + NCCLCHECK(wrap_ibv_create_cq(&base->cq, ibDev->context, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0)); + return ncclSuccess; } -ncclResult_t ncclIbDestroyVerbs(struct ncclIbVerbs* verbs) { +ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) { ncclResult_t res; - NCCLCHECK(wrap_ibv_destroy_cq(verbs->cq)); + NCCLCHECK(wrap_ibv_destroy_cq(base->cq)); - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); - if (0 == --ncclIbDevs[verbs->dev].pdRefs) { - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(ncclIbDevs[verbs->dev].pd), res, returning); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); + if (0 == --ncclIbDevs[base->ibDevN].pdRefs) { + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(ncclIbDevs[base->ibDevN].pd), res, returning); } res = ncclSuccess; returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); return res; } -ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int access_flags, struct ibv_qp** qp) { 
+ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, int access_flags, struct ncclIbQp* qp) { struct ibv_qp_init_attr qpInitAttr; memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr)); - qpInitAttr.send_cq = verbs->cq; - qpInitAttr.recv_cq = verbs->cq; + qpInitAttr.send_cq = base->cq; + qpInitAttr.recv_cq = base->cq; qpInitAttr.qp_type = IBV_QPT_RC; // We might send 2 messages per send (RDMA and RDMA_WITH_IMM) qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS; @@ -309,23 +362,23 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce qpInitAttr.cap.max_send_sge = 1; qpInitAttr.cap.max_recv_sge = 1; qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0; - NCCLCHECK(wrap_ibv_create_qp(qp, verbs->pd, &qpInitAttr)); + NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr)); struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_INIT; qpAttr.pkey_index = ncclParamIbPkey(); qpAttr.port_num = ib_port; qpAttr.qp_access_flags = access_flags; - NCCLCHECK(wrap_ibv_modify_qp(*qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t qpn, struct ncclIbQpInfo* info) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) { struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_RTR; qpAttr.path_mtu = info->mtu; - qpAttr.dest_qp_num = qpn; + qpAttr.dest_qp_num = dest_qp_num; qpAttr.rq_psn = 0; qpAttr.max_dest_rd_atomic = 1; qpAttr.min_rnr_timer = 12; @@ -380,7 +433,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; struct 
ncclIbCommStage* stage = &handle->stage; struct ncclIbSendComm* comm = (struct ncclIbSendComm*)stage->comm; - struct ncclIbQpInfo remQpInfo; + struct ncclIbConnectionMetadata remMeta; int ready; *sendComm = NULL; @@ -394,100 +447,188 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet } NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); - NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); stage->comm = comm; stage->state = ncclIbCommStateConnect; - NCCLCHECK(ncclSocketConnect(&comm->sock)); + NCCLCHECK(ncclSocketConnect(&comm->base.sock)); ib_connect_check: /* since ncclSocketConnect is async, we must check if connection is complete */ - NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + NCCLCHECK(ncclSocketReady(&comm->base.sock, &ready)); if (!ready) return ncclSuccess; // IB Setup - struct ibv_context* ctx; - ctx = ncclIbDevs[dev].context; - NCCLCHECK(ncclIbInitVerbs(dev, ctx, &comm->verbs)); - uint8_t ib_port; - ib_port = ncclIbDevs[dev].port; - comm->nqps = ncclParamIbQpsPerConn(); - for (int q=0; qnqps; q++) { - NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q)); + struct ncclIbMergedDev* mergedDev; + mergedDev = ncclIbMergedDevs + dev; + comm->base.ndevs = mergedDev->ndevs; + comm->base.nqps = ncclParamIbQpsPerConn() * comm->base.ndevs; // We must have at least 1 qp per-device + comm->base.isSend = true; + + // Init PD, Ctx for each IB device + comm->ar = 1; // Set to 1 for logic + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDevN = mergedDev->devs[i]; + NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &comm->devs[i].base)); + comm->ar = comm->ar && ncclIbDevs[ibDevN].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled } - comm->ar = ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING - - // Send my QP Info to 
receiver through the socket. Hope this won't block. - struct ibv_port_attr portAttr; - NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr)); - struct ncclIbQpInfo qpInfo; - qpInfo.ib_port = ib_port; - for (int q=0; qnqps; q++) qpInfo.qpn[q] = comm->qps[q]->qp_num; - qpInfo.mtu = portAttr.active_mtu; - - // Prepare my fifo - NCCLCHECK(wrap_ibv_reg_mr(&comm->fifoMr, comm->verbs.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); - qpInfo.fifoRkey = comm->fifoMr->rkey; - qpInfo.fifoAddr = (uint64_t)comm->fifo; - - // RoCE support - qpInfo.lid = portAttr.lid; - qpInfo.link_layer = comm->gidInfo.link_layer = portAttr.link_layer; - qpInfo.is_global = (ncclParamIbIsGlobal() -#if HAVE_DECL_IBV_QPF_GRH_REQUIRED - || (portAttr.flags & IBV_QPF_GRH_REQUIRED) + + struct ncclIbConnectionMetadata meta; + meta.ndevs = comm->base.ndevs; + + // Alternate QPs between devices + int devIndex; + devIndex = 0; + for (int q = 0; q < comm->base.nqps; q++) { + ncclIbSendCommDev* commDev = comm->devs + devIndex; + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, comm->base.qps+q)); + comm->base.qps[q].devIndex = devIndex; + meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex; + + // Query ece capabilities (enhanced connection establishment) + NCCLCHECK(wrap_ibv_query_ece(comm->base.qps[q].qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); + devIndex = (devIndex + 1) % comm->base.ndevs; + } + + for (int i = 0; i < comm->base.ndevs; i++) { + ncclIbSendCommDev* commDev = comm->devs + i; + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + // Send my QP Info to receiver through the socket. Hope this won't block. + // TODO - I thought I queried this in init? 
+ NCCLCHECK(wrap_ibv_query_port(ibDev->context, ibDev->portNum, &ibDev->portAttr)); + + // Write to the metadata struct via this pointer + ncclIbDevInfo* devInfo = meta.devs + i; + devInfo->ib_port = ibDev->portNum; + devInfo->mtu = ibDev->portAttr.active_mtu; + devInfo->lid = ibDev->portAttr.lid; + + // Prepare my fifo + NCCLCHECK(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); + devInfo->fifoRkey = commDev->fifoMr->rkey; + + // RoCE support + devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + devInfo->is_global = (ncclParamIbIsGlobal() + #if HAVE_DECL_IBV_QPF_GRH_REQUIRED + || (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED) #endif ); - if (qpInfo.link_layer == IBV_LINK_LAYER_INFINIBAND && !qpInfo.is_global) { // IB - for (int q=0; qnqps; q++) - INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid); - } else { // RoCE - NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &comm->gidInfo.localGid)); - qpInfo.spn = comm->gidInfo.localGid.global.subnet_prefix; - qpInfo.iid = comm->gidInfo.localGid.global.interface_id; - for (int q=0; qnqps; q++) - INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, ncclParamIbGidIndex(), qpInfo.spn, qpInfo.iid); + + if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) { + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid)); + devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id; + } + + if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND && !devInfo->is_global) { // IB + for (int q = 0; q < comm->base.nqps; q++) { + // Print just the QPs for this 
dev + if (comm->base.qps[q].devIndex == i) + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d fifoRkey=0x%x fifoLkey=0x%x", + comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", + dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->fifoRkey, commDev->fifoMr->lkey); + } + } else { // RoCE + for (int q = 0; q < comm->base.nqps; q++) { + // Print just the QPs for this dev + if (comm->base.qps[q].devIndex == i) + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", + comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(), + devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey); + } + } } + meta.fifoAddr = (uint64_t)comm->fifo; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); + stage->state = ncclIbCommStateSend; stage->offset = 0; - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(qpInfo))); - memcpy(stage->buffer, &qpInfo, sizeof(qpInfo)); + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(meta))); + + memcpy(stage->buffer, &meta, sizeof(meta)); ib_send: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, stage->buffer, sizeof(qpInfo), &stage->offset)); - if (stage->offset != sizeof(qpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(meta), &stage->offset)); + if (stage->offset != sizeof(meta)) return ncclSuccess; + stage->state = ncclIbCommStateConnecting; stage->offset = 0; // Clear the staging buffer for re-use - memset(stage->buffer, 0, sizeof(qpInfo)); + memset(stage->buffer, 0, sizeof(meta)); ib_connect: - 
NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, stage->buffer, sizeof(ncclIbQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclIbConnectionMetadata), &stage->offset)); + if (stage->offset != sizeof(remMeta)) return ncclSuccess; - memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo)); + memcpy(&remMeta, stage->buffer, sizeof(ncclIbConnectionMetadata)); - comm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn; - comm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid; - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = comm->qps[q]; - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); + comm->base.nRemDevs = remMeta.ndevs; + if (comm->base.nRemDevs != comm->base.ndevs) { + mergedDev = ncclIbMergedDevs + dev; + WARN("NET/IB : Local mergedDev=%s has a different number of devices=%d as remoteDev=%s nRemDevs=%d", + mergedDev->devName, comm->base.ndevs, remMeta.devName, comm->base.nRemDevs); + } + + int link_layer; + link_layer = remMeta.devs[0].link_layer; + for (int i = 1; i < remMeta.ndevs; i++) { + if (remMeta.devs[i].link_layer != link_layer) { + WARN("NET/IB : Can't merge net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", + i, remMeta.ndevs, link_layer, remMeta.devs[i].link_layer); + return ncclInternalError; + } + } + + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. 
+ for (int i = 0; i < remMeta.ndevs; i++) { + comm->base.remDevs[i] = remMeta.devs[i]; + comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].iid; + comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].spn; + + // Retain remote sizes fifo info and prepare RDMA ops + comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey; + comm->remSizesFifo.addr = remMeta.fifoAddr; + } + + for (int i=0; i < comm->base.ndevs; i++) { + NCCLCHECK(wrap_ibv_reg_mr(comm->remSizesFifo.mrs+i, comm->devs[i].base.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); } + comm->base.nRemDevs = remMeta.ndevs; + for (int q = 0; q < comm->base.nqps; q++) { + struct ncclIbQpInfo* remQpInfo = remMeta.qpInfo + q; + struct ncclIbDevInfo* remDevInfo = remMeta.devs + remQpInfo->devIndex; - // Retain remote sizes fifo info and prepare RDMA ops - comm->remSizesFifo.rkey = remQpInfo.fifoRkey; - comm->remSizesFifo.addr = remQpInfo.fifoAddr; - NCCLCHECK(wrap_ibv_reg_mr(&comm->remSizesFifo.mr, comm->verbs.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); - comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mr->lkey; + // Assign per-QP remDev + comm->base.qps[q].remDevIdx = remQpInfo->devIndex; - comm->ready = 1; + struct ibv_qp* qp = comm->base.qps[q].qp; + if (remQpInfo->ece_supported) + NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported)); + + NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo)); + NCCLCHECK(ncclIbRtsQp(qp)); + } + + if (link_layer == IBV_LINK_LAYER_ETHERNET ) { // RoCE + for (int q = 0; q < comm->base.nqps; q++) { + struct ncclIbQp* qp = comm->base.qps + q; + int ibDevN = comm->devs[qp->devIndex].base.ibDevN; + struct ncclIbDev* ibDev = ncclIbDevs + ibDevN; + INFO(NCCL_NET,"NET/IB: IbDev %d Port 
%d qpn %d set_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x}", + ibDevN, ibDev->portNum, remMeta.qpInfo[q].qpn, remMeta.qpInfo[q].ece_supported, remMeta.qpInfo[q].ece.vendor_id, remMeta.qpInfo[q].ece.options, remMeta.qpInfo[q].ece.comp_mask); + } + } + comm->base.ready = 1; stage->state = ncclIbCommStateConnected; stage->offset = 0; ib_send_ready: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, &comm->ready, sizeof(int), &stage->offset)); + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, &comm->base.ready, sizeof(int), &stage->offset)); if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); @@ -523,121 +664,167 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm))); stage->comm = rComm; stage->state = ncclIbCommStateAccept; - NCCLCHECK(ncclSocketInit(&rComm->sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0)); - NCCLCHECK(ncclSocketAccept(&rComm->sock, &lComm->sock)); + NCCLCHECK(ncclSocketInit(&rComm->base.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0)); + NCCLCHECK(ncclSocketAccept(&rComm->base.sock, &lComm->sock)); ib_accept_check: - NCCLCHECK(ncclSocketReady(&rComm->sock, &ready)); + NCCLCHECK(ncclSocketReady(&rComm->base.sock, &ready)); if (!ready) return ncclSuccess; - struct ncclIbQpInfo remQpInfo; + struct ncclIbConnectionMetadata remMeta; stage->state = ncclIbCommStateRecv; stage->offset = 0; - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remQpInfo))); + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta))); ib_recv: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, stage->buffer, sizeof(remQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset)); + if (stage->offset != 
sizeof(remMeta)) return ncclSuccess; /* copy back the received info */ - memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo)); - - rComm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn; - rComm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid; + memcpy(&remMeta, stage->buffer, sizeof(struct ncclIbConnectionMetadata)); + + // IB setup + // Pre-declare variables because of goto + struct ncclIbMergedDev* mergedDev; + struct ncclIbDev* ibDev; + int ibDevN; + struct ncclIbRecvCommDev* rCommDev; + struct ncclIbDevInfo* remDevInfo; + struct ncclIbQp* qp; + + mergedDev = ncclIbMergedDevs + lComm->dev; + rComm->base.ndevs = mergedDev->ndevs; + rComm->base.nqps = ncclParamIbQpsPerConn() * rComm->base.ndevs; // We must have at least 1 qp per-device + rComm->base.isSend = false; + + rComm->base.nRemDevs = remMeta.ndevs; + if (rComm->base.nRemDevs != rComm->base.ndevs) { + WARN("NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", + mergedDev->devName, rComm->base.ndevs, remMeta.devName, rComm->base.nRemDevs); + } - // IB setup - struct ibv_context* ctx; - uint8_t ib_port; - ctx = ncclIbDevs[lComm->dev].context; - ib_port = ncclIbDevs[lComm->dev].port; - struct ibv_port_attr portAttr; - NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr)); - NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &rComm->gidInfo.localGid)); - - // QP Creation - NCCLCHECK(ncclIbInitVerbs(lComm->dev, ctx, &rComm->verbs)); - rComm->nqps = ncclParamIbQpsPerConn(); - for (int q=0; qnqps; q++) { - NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps+q)); + // Metadata to send back to requestor (sender) + struct ncclIbConnectionMetadata meta; + for (int i = 0; i < rComm->base.ndevs; i++) { + rCommDev = rComm->devs + i; + ibDevN = mergedDev->devs[i]; + NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base)); + ibDev = ncclIbDevs + ibDevN; + NCCLCHECK(wrap_ibv_query_port(ibDev->context, ibDev->portNum, 
&ibDev->portAttr)); + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid)); } - // Adjust the MTU - remQpInfo.mtu = (enum ibv_mtu)MIN(remQpInfo.mtu, portAttr.active_mtu); + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. + for (int i = 0; i < remMeta.ndevs; i++) { + rComm->base.remDevs[i] = remMeta.devs[i]; + rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].iid; + rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].spn; + } - // Setup QP - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = rComm->qps[q]; - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); + // Stripe QP creation across merged devs + // Make sure to get correct remote peer dev and QP info + int remDevIndex; + int devIndex; + devIndex = 0; + for (int q = 0; q < rComm->base.nqps; q++) { + remDevIndex = remMeta.qpInfo[q].devIndex; + remDevInfo = remMeta.devs + remDevIndex; + qp = rComm->base.qps+q; + rCommDev = rComm->devs + devIndex; + qp->remDevIdx = remDevIndex; + + // Local ibDevN + ibDevN = rComm->devs[devIndex].base.ibDevN; + ibDev = ncclIbDevs + ibDevN; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, qp)); + qp->devIndex = devIndex; + devIndex = (devIndex + 1) % rComm->base.ndevs; + + // Set the ece (enhanced connection establishment) on this QP before RTR + if (remMeta.qpInfo[q].ece_supported) { + NCCLCHECK(wrap_ibv_set_ece(qp->qp, &remMeta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); + + // Query the reduced ece for this QP (matching enhancements between the requestor and the responder) + // Store this in our own qpInfo for returning to the requestor + if (meta.qpInfo[q].ece_supported) + NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); + } + NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo)); + NCCLCHECK(ncclIbRtsQp(qp->qp)); 
} - // Retain remote fifo info and prepare my RDMA ops - rComm->remFifo.rkey = remQpInfo.fifoRkey; - rComm->remFifo.addr = remQpInfo.fifoAddr; - NCCLCHECK(wrap_ibv_reg_mr(&rComm->remFifo.mr, rComm->verbs.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); - rComm->remFifo.sge.lkey = rComm->remFifo.mr->lkey; - if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; - - // Allocate Flush dummy buffer for GPU Direct RDMA - rComm->gpuFlush.enabled = ((nccl_p2p_gdr_support(lComm->dev) == ncclSuccess) || nccl_p2p_dmabuf_support(lComm->dev) == ncclSuccess) && - (ncclParamIbGdrFlushDisable() == 0) ? 1 : 0; - if (rComm->gpuFlush.enabled) { - NCCLCHECK(wrap_ibv_reg_mr(&rComm->gpuFlush.hostMr, rComm->verbs.pd, &rComm->gpuFlush.hostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE)); - rComm->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlush.hostMem; - rComm->gpuFlush.sge.length = 1; - rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey; - NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp)); - struct ncclIbQpInfo localQpInfo; - localQpInfo.lid=portAttr.lid; - localQpInfo.link_layer=portAttr.link_layer; - localQpInfo.ib_port=ib_port; - localQpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix; - localQpInfo.iid=rComm->gidInfo.localGid.global.interface_id; - localQpInfo.is_global=(ncclParamIbIsGlobal() -#if HAVE_DECL_IBV_QPF_GRH_REQUIRED - || (portAttr.flags & IBV_QPF_GRH_REQUIRED) -#endif - ); - localQpInfo.mtu=portAttr.active_mtu; - NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo)); - NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp)); + rComm->flushEnabled = ((nccl_p2p_gdr_support() == ncclSuccess || nccl_p2p_dmabuf_support(lComm->dev) == ncclSuccess) + && (ncclParamIbGdrFlushDisable() == 0)) ? 
1 : 0; + + for (int i = 0; i < mergedDev->ndevs; i++) { + rCommDev = rComm->devs + i; + ibDevN = rCommDev->base.ibDevN; + ibDev = ncclIbDevs + ibDevN; + + // Retain remote fifo info and prepare my RDMA ops + rCommDev->fifoRkey = remMeta.devs[i].fifoRkey; + rComm->remFifo.addr = remMeta.fifoAddr; + NCCLCHECK(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); + rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; + if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; + + // Allocate Flush dummy buffer for GPU Direct RDMA + if (rComm->flushEnabled) { + NCCLCHECK(wrap_ibv_reg_mr(&rCommDev->gpuFlush.hostMr, rCommDev->base.pd, &rComm->gpuFlushHostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE)); + rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem; + rCommDev->gpuFlush.sge.length = 1; + rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rCommDev->gpuFlush.qp)); + struct ncclIbDevInfo devInfo; + devInfo.lid = ibDev->portAttr.lid; + devInfo.link_layer = ibDev->portAttr.link_layer; + devInfo.ib_port = ibDev->portNum; + devInfo.spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo.iid = rCommDev->base.gidInfo.localGid.global.interface_id; + devInfo.mtu = ibDev->portAttr.active_mtu; + NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo)); + NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); + } + + // Fill Handle + meta.devs[i].lid = ibDev->portAttr.lid; + meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + meta.devs[i].ib_port = ibDev->portNum; + meta.devs[i].spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + meta.devs[i].iid = 
rCommDev->base.gidInfo.localGid.global.interface_id; + + // Adjust the MTU + remMeta.devs[i].mtu = (enum ibv_mtu)MIN(remMeta.devs[i].mtu, ibDev->portAttr.active_mtu); + meta.devs[i].mtu = remMeta.devs[i].mtu; + + // Prepare sizes fifo + NCCLCHECK(wrap_ibv_reg_mr(&rComm->devs[i].sizesFifoMr, rComm->devs[i].base.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); + meta.devs[i].fifoRkey = rComm->devs[i].sizesFifoMr->rkey; } + meta.fifoAddr = (uint64_t)rComm->sizesFifo; - // Fill Handle - struct ncclIbQpInfo qpInfo; - qpInfo.lid=portAttr.lid; - qpInfo.link_layer= rComm->gidInfo.link_layer = portAttr.link_layer; - qpInfo.ib_port=ib_port; - for (int q=0; qnqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num; - qpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix; - qpInfo.iid=rComm->gidInfo.localGid.global.interface_id; - qpInfo.is_global=(ncclParamIbIsGlobal() -#if HAVE_DECL_IBV_QPF_GRH_REQUIRED - || (portAttr.flags & IBV_QPF_GRH_REQUIRED) -#endif - ); - qpInfo.mtu=remQpInfo.mtu; + for (int q = 0; q < rComm->base.nqps; q++) { + meta.qpInfo[q].qpn = rComm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = rComm->base.qps[q].devIndex; + } - // Prepare sizes fifo - NCCLCHECK(wrap_ibv_reg_mr(&rComm->sizesFifoMr, rComm->verbs.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); - qpInfo.fifoRkey = rComm->sizesFifoMr->rkey; - qpInfo.fifoAddr = (uint64_t)rComm->sizesFifo; + meta.ndevs = rComm->base.ndevs; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); stage->state = ncclIbCommStateSend; stage->offset = 0; if (stage->buffer) free(stage->buffer); - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbQpInfo))); - memcpy(stage->buffer, &qpInfo, sizeof(struct ncclIbQpInfo)); - + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbConnectionMetadata))); + 
memcpy(stage->buffer, &meta, sizeof(struct ncclIbConnectionMetadata)); ib_send: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->sock, stage->buffer, sizeof(struct ncclIbQpInfo), &stage->offset)); - if (stage->offset < sizeof(struct ncclIbQpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(struct ncclIbConnectionMetadata), &stage->offset)); + if (stage->offset < sizeof(struct ncclIbConnectionMetadata)) return ncclSuccess; stage->offset = 0; stage->state = ncclIbCommStatePendingReady; ib_recv_ready: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, &rComm->ready, sizeof(int), &stage->offset)); + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, &rComm->base.ready, sizeof(int), &stage->offset)); if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); @@ -650,19 +837,21 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl stage->buffer = NULL; return ncclSuccess; } + ncclResult_t ncclIbAccept_v6(void* listenComm, void** recvComm) { ncclNetDeviceHandle_v7_t* handle = NULL; return ncclIbAccept(listenComm, recvComm, &handle); } -ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** req) { +ncclResult_t ncclIbGetRequest(struct ncclIbNetCommBase* base, struct ncclIbRequest** req) { for (int i=0; ireqs+i; + struct ncclIbRequest* r = base->reqs+i; if (r->type == NCCL_NET_IB_REQ_UNUSED) { - r->verbs = verbs; - r->events = 1; + r->base = base; r->sock = NULL; - r->gidInfo = NULL; + r->devBases[0] = NULL; + r->devBases[1] = NULL; + r->events[0] = r->events[1] = 0; *req = r; return ncclSuccess; } @@ -671,6 +860,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** *req = NULL; return ncclInternalError; } + ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) { r->type = NCCL_NET_IB_REQ_UNUSED; return ncclSuccess; @@ -678,20 +868,15 @@ ncclResult_t 
ncclIbFreeRequest(struct ncclIbRequest* r) { ncclResult_t ncclIbTest(void* request, int* done, int* size); -/* DMA-BUF support */ -ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { - NCCL_STATIC_ASSERT(offsetof(struct ncclIbSendComm, verbs) == offsetof(struct ncclIbRecvComm, verbs), "Send and recv comms must have verbs at the same offset") - assert(size > 0); - +ncclResult_t ncclIbRegMrDmaBufInternal(ncclIbNetCommDevBase* base, void* data, size_t size, int type, uint64_t offset, int fd, struct ibv_mr** mhandle) { static __thread uintptr_t pageSize = 0; if (pageSize == 0) pageSize = sysconf(_SC_PAGESIZE); - struct ncclIbVerbs* verbs = (struct ncclIbVerbs*)comm; - struct ncclIbMrCache* cache = &ncclIbDevs[verbs->dev].mrCache; + struct ncclIbMrCache* cache = &ncclIbDevs[base->ibDevN].mrCache; uintptr_t addr = (uintptr_t)data & -pageSize; size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; ncclResult_t res; - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); for (int slot=0; /*true*/; slot++) { if (slot == cache->population || addr < cache->slots[slot].addr) { // didn't find in cache if (cache->population == cache->capacity) { // must grow cache @@ -704,17 +889,17 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui if (ncclIbRelaxedOrderingEnabled) flags |= IBV_ACCESS_RELAXED_ORDERING; if (fd != -1) { /* DMA-BUF support */ - NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, verbs->pd, offset, pages*pageSize, addr, fd, flags), res, returning); + NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, base->pd, offset, pages*pageSize, addr, fd, flags), res, returning); } else { if (ncclIbRelaxedOrderingEnabled) { // Use IBVERBS_1.8 API - needed for IBV_ACCESS_RELAXED_ORDERING support - NCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, verbs->pd, (void*)addr, pages*pageSize, addr, flags), res, returning); + 
NCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, base->pd, (void*)addr, pages*pageSize, addr, flags), res, returning); } else { - NCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, verbs->pd, (void*)addr, pages*pageSize, flags), res, returning); + NCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, base->pd, (void*)addr, pages*pageSize, flags), res, returning); } } - TRACE(NCCL_INIT,"regAddr %llx size %lld rkey %x fd %d", (unsigned long long)addr, (long long)pages*pageSize, mr->rkey, fd); + TRACE(NCCL_INIT|NCCL_NET,"regAddr=0x%lx size=%lld rkey=0x%x lkey=0x%x fd=%d", (unsigned long)addr, (long long)pages*pageSize, mr->rkey, mr->lkey, fd); if (slot != cache->population) memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclIbMr)); cache->slots[slot].addr = addr; cache->slots[slot].pages = pages; @@ -724,31 +909,55 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui *mhandle = (void*)mr; res = ncclSuccess; goto returning; - } else if ((addr >= cache->slots[slot].addr) && + } else if ((addr >= cache->slots[slot].addr) && ((addr-cache->slots[slot].addr)/pageSize+pages) <= cache->slots[slot].pages) { cache->slots[slot].refs += 1; - *mhandle = (void*)cache->slots[slot].mr; + *mhandle = cache->slots[slot].mr; res = ncclSuccess; goto returning; } } returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); return res; } +struct ncclIbNetCommDevBase* ncclIbGetNetCommDevBase(ncclIbNetCommBase* base, int devIndex) { + if (base->isSend) { + struct ncclIbSendComm* sComm = (struct ncclIbSendComm*) base; + return &sComm->devs[devIndex].base; + } else { + struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*) base; + return &rComm->devs[devIndex].base; + } +} + +/* DMA-BUF support */ +ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { + assert(size > 0); + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; 
+ struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) malloc(sizeof(struct ncclIbMrHandle)); + for (int i = 0; i < base->ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); + NCCLCHECK(ncclIbRegMrDmaBufInternal(devComm, data, size, type, offset, fd, mhandleWrapper->mrs + i)); + } + *mhandle = (void*) mhandleWrapper; + return ncclSuccess; +} + ncclResult_t ncclIbRegMr(void* comm, void* data, size_t size, int type, void** mhandle) { return ncclIbRegMrDmaBuf(comm, data, size, type, 0ULL, -1, mhandle); } + ncclResult_t ncclIbRegMr_v7(void* comm, void* data, int size, int type, void** mhandle) { return ncclIbRegMrDmaBuf(comm, data, (size_t)size, type, 0ULL, -1, mhandle); } -ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { - struct ncclIbVerbs* verbs = (struct ncclIbVerbs*)comm; - struct ncclIbMrCache* cache = &ncclIbDevs[verbs->dev].mrCache; +ncclResult_t ncclIbDeregMrInternal(ncclIbNetCommDevBase* base, struct ibv_mr* mhandle) { + struct ncclIbMrCache* cache = &ncclIbDevs[base->ibDevN].mrCache; ncclResult_t res; - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); for (int i=0; i < cache->population; i++) { if (mhandle == cache->slots[i].mr) { if (0 == --cache->slots[i].refs) { @@ -758,7 +967,7 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { cache->slots = NULL; cache->capacity = 0; } - NCCLCHECKGOTO(wrap_ibv_dereg_mr((struct ibv_mr*)mhandle), res, returning); + NCCLCHECKGOTO(wrap_ibv_dereg_mr(mhandle), res, returning); } res = ncclSuccess; goto returning; @@ -767,10 +976,22 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { WARN("NET/IB: could not find mr %p inside cache of %d entries", mhandle, cache->population); res = ncclInternalError; returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); 
return res; } +ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; + for (int i = 0; i < base->ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); + NCCLCHECK(ncclIbDeregMrInternal(devComm, mhandleWrapper->mrs[i])); + } + free(mhandleWrapper); + return ncclSuccess; +} + NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0); ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { @@ -780,21 +1001,18 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; uint64_t wr_id = 0ULL; - for (int r=0; rwrs+r; memset(wr, 0, sizeof(struct ibv_send_wr)); struct ibv_sge* sge = comm->sges+r; sge->addr=(uintptr_t)reqs[r]->send.data; - sge->lkey=reqs[r]->send.lkey; wr->opcode = IBV_WR_RDMA_WRITE; wr->send_flags = 0; wr->wr.rdma.remote_addr = slots[r].addr; - wr->wr.rdma.rkey = slots[r].rkey; - wr->next = wr+1; - wr_id += (reqs[r] - comm->verbs.reqs) << (r*8); + wr->next = wr + 1; + wr_id += (reqs[r] - comm->base.reqs) << (r*8); } // Write size as immediate data. In the case of multi-send, only write @@ -819,7 +1037,6 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { if (nreqs > 1) { // Write remote sizes Fifo lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); - lastWr->wr.rdma.rkey = comm->remSizesFifo.rkey; lastWr->num_sge = 1; lastWr->sg_list = &comm->remSizesFifo.sge; } @@ -832,23 +1049,40 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work const int align = 128; - const int nqps = ncclParamIbSplitDataOnQps() ? 
comm->nqps : 1; - for (int q=0; qbase.nqps : comm->base.ndevs; + for (int i = 0; i < nqps; i++) { + int qpIndex = comm->base.qpIndex; + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; for (int r=0; rdevs[devIndex].base); + + // Select proper rkey (needed even for 0-size send) + comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; + int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align; int length = MIN(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); if (length <= 0) { comm->wrs[r].sg_list = NULL; comm->wrs[r].num_sge = 0; } else { + // Select proper lkey + comm->sges[r].lkey = reqs[r]->send.lkeys[devIndex]; comm->sges[r].length = length; comm->wrs[r].sg_list = comm->sges+r; comm->wrs[r].num_sge = 1; } } + + if (nreqs > 1) { + // Also make sure lastWr writes remote sizes using the right lkey + comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey; + lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex]; + } + struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->qps[comm->qpIndex], comm->wrs, &bad_wr)); - comm->qpIndex = (comm->qpIndex+1)%comm->nqps; + NCCLCHECK(wrap_ibv_post_send(qp->qp, comm->wrs, &bad_wr)); for (int r=0; rsend.size, nqps), align) * align; @@ -856,6 +1090,9 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { comm->sges[r].addr += chunkSize; comm->wrs[r].wr.rdma.remote_addr += chunkSize; } + + // Select the next qpIndex + comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; } return ncclSuccess; @@ -863,16 +1100,16 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - if (comm->ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->ready == 0"); return ncclInternalError; } - if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + if 
(comm->base.ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->base.ready == 0"); return ncclInternalError; } + if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } - struct ibv_mr* mr = (struct ibv_mr*)mhandle; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; // Wait for the receiver to have posted the corresponding receive int nreqs = 0; volatile struct ncclIbSendFifo* slots; - int slot = (comm->fifoHead)%MAX_REQUESTS; + int slot = (comm->fifoHead) % MAX_REQUESTS; struct ncclIbRequest** reqs = comm->fifoReqs[slot]; slots = comm->fifo[slot]; uint64_t idx = comm->fifoHead+1; @@ -884,29 +1121,45 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh for (int r=0; r slots[r].size) size = slots[r].size; + if (size > slots[r].size) size = slots[r].size; // Sanity checks - if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) { + if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; - ncclSocketGetAddr(&comm->sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", - r, nreqs, tag, ncclSocketToString(&addr, line, 1), slots[r].size, slots[r].addr, slots[r].rkey); - return ncclInternalError; + ncclSocketGetAddr(&comm->base.sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkeys[0]=%x", + r, nreqs, tag, ncclSocketToString(&addr, line, 1), slots[r].size, slots[r].addr, slots[r].rkeys[0]); + return ncclInternalError; } struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_SEND; - req->sock = &comm->sock; - req->verbs = &comm->verbs; + req->sock = &comm->base.sock; + req->base = &comm->base; req->nreqs = nreqs; req->send.size = size; req->send.data = data; - req->send.lkey = mr->lkey; req->send.offset = 0; - req->events 
= ncclParamIbSplitDataOnQps() ? comm->nqps : 1; - if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo; + + // Populate events + int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int qpIndex = comm->base.qpIndex; + // Count down + while (nEvents > 0) { + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; + ncclIbAddEvent(req, devIndex, &comm->devs[devIndex].base); + // Track the valid lkey for this RDMA_Write + req->send.lkeys[devIndex] = mhandleWrapper->mrs[devIndex]->lkey; + nEvents--; + // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs inside ncclIbMultiSend() + qpIndex = (qpIndex+1)%comm->base.nqps; + } + + // Store all lkeys + for (int i = 0; i < comm->base.ndevs; i++) { + req->send.lkeys[i] = mhandleWrapper->mrs[i]->lkey; + } + *request = reqs[r] = req; // If this is a multi-recv, send only when all requests have matched. 
@@ -938,10 +1191,19 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int for (int i=0; irecv.sizes[i] = 0; struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot]; + // Select the next devIndex (local) and QP to use for posting this CTS message + // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value + ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.ndevs; + for (int i=0; irkey; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandles[i]; + + // Send all applicable rkeys + for (int j = 0; j < comm->base.ndevs; j++) + localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + localElem[i].nreqs = n; localElem[i].size = sizes[i]; // Sanity/Debugging localElem[i].tag = tags[i]; @@ -949,11 +1211,17 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int } wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo); - wr.wr.rdma.rkey = comm->remFifo.rkey; - comm->remFifo.sge.addr = (uint64_t)localElem; - comm->remFifo.sge.length = n*sizeof(struct ncclIbSendFifo); - wr.sg_list = &comm->remFifo.sge; + + // Lookup the correct fifoRkey + wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey; + + // Set the correct sge properties + comm->devs[ctsQp->devIndex].fifoSge.addr = (uint64_t)localElem; + comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); + wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = comm->remFifo.flags; // IBV_SEND_INLINE @@ -978,14 +1246,17 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int // polling it will empty the Send Queue, can be posted) // - The status of all posted Send Request is considered unknown // - if (slot == 0) { + // slot == devIndex - When writing to fifo slot N, and 
this QP lives on device index N, it should send signalled. + // This works out that each fifo posting QP gets drained + if (slot == ctsQp->devIndex) { + wr.send_flags |= IBV_SEND_SIGNALED; - wr.wr_id = req - comm->verbs.reqs; - req->events++; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); } struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr)); + NCCLCHECK(wrap_ibv_post_send(ctsQp->qp, &wr, &bad_wr)); comm->remFifo.fifoTail++; return ncclSuccess; @@ -993,41 +1264,48 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - if (comm->ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->ready == 0"); return ncclInternalError; } - if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); return ncclInternalError; } + if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } + if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_RECV; - req->sock = &comm->sock; + req->sock = &comm->base.sock; req->nreqs = n; - if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo; + + for (int i = 0; i < comm->base.ndevs; i++) { + req->devBases[i] = &comm->devs[i].base; + } struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = req - comm->verbs.reqs; + wr.wr_id = req - comm->base.reqs; wr.sg_list = NULL; wr.num_sge = 0; TIME_START(1); - const int nqps = ncclParamIbSplitDataOnQps() ? 
comm->nqps : 1; - for (int q=0; qqps[comm->qpIndex]; - struct ibv_recv_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr)); - comm->qpIndex = (comm->qpIndex+1)%comm->nqps; + // Select either all QPs, or one qp per-device + const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + + // Post recvs + struct ibv_recv_wr* bad_wr; + for (int i = 0; i < nqps; i++) { + struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; + ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); + NCCLCHECK(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr)); + comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; } TIME_STOP(1); - req->events = nqps; - - *request = req; // Post to FIFO to notify sender TIME_START(2); NCCLCHECK(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req)); TIME_STOP(2); + + *request = req; return ncclSuccess; } @@ -1035,30 +1313,34 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; int last = -1; for (int i=0; igpuFlush.enabled == 0 || last == -1) return ncclSuccess; + if (comm->flushEnabled == 0 || last == -1) return ncclSuccess; // Only flush once using the last non-zero receive struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_FLUSH; - req->sock = &comm->sock; - struct ibv_mr* mr = (struct ibv_mr*)mhandles[last]; - - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = req - comm->verbs.reqs; - - wr.wr.rdma.remote_addr = (uint64_t)data[last]; - wr.wr.rdma.rkey = mr->rkey; - wr.sg_list = &comm->gpuFlush.sge; - wr.num_sge = 1; - wr.opcode = IBV_WR_RDMA_READ; - wr.send_flags = IBV_SEND_SIGNALED; - - TIME_START(4); - struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->gpuFlush.qp, &wr, &bad_wr)); - TIME_STOP(4); + req->sock = &comm->base.sock; + struct ncclIbMrHandle* mhandle = (struct 
ncclIbMrHandle*) mhandles[last]; + + // We don't know which devIndex the recv was on, so we flush on all devices + for (int i = 0; i < comm->base.ndevs; i++) { + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; + + wr.wr.rdma.remote_addr = (uint64_t)data[last]; + wr.wr.rdma.rkey = mhandle->mrs[i]->rkey; + wr.sg_list = &comm->devs[i].gpuFlush.sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; + + TIME_START(4); + struct ibv_send_wr* bad_wr; + NCCLCHECK(wrap_ibv_post_send(comm->devs[i].gpuFlush.qp.qp, &wr, &bad_wr)); + TIME_STOP(4); + ncclIbAddEvent(req, i, &comm->devs[i].base); + } *request = req; return ncclSuccess; @@ -1069,7 +1351,8 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { *done = 0; while (1) { - if (r->events == 0) { + if (r->events[0] == 0 && r->events[1] == 0) { + TRACE(NCCL_NET, "r=%p done", r); *done = 1; if (sizes && r->type == NCCL_NET_IB_REQ_RECV) { for (int i=0; inreqs; i++) sizes[i] = r->recv.sizes[i]; @@ -1078,65 +1361,98 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { sizes[0] = r->send.size; } + if (sizes && r->type == NCCL_NET_IB_REQ_SEND) { + sizes[0] = r->send.size; + } + NCCLCHECK(ncclIbFreeRequest(r)); return ncclSuccess; } + int totalWrDone = 0; int wrDone = 0; struct ibv_wc wcs[4]; - TIME_START(3); - NCCLCHECK(wrap_ibv_poll_cq(r->verbs->cq, 4, wcs, &wrDone)); - if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } - if (wrDone == 0) return ncclSuccess; - - for (int w=0; wstatus != IBV_WC_SUCCESS) { - char line[SOCKET_NAME_MAXLEN+1]; - union ncclSocketAddress addr; - ncclSocketGetAddr(r->sock, &addr); - char localGidString[INET6_ADDRSTRLEN] = ""; - char remoteGidString[INET6_ADDRSTRLEN] = ""; - const char* localGidStr = NULL, *remoteGidStr = NULL; - if (r->gidInfo) { - localGidStr = inet_ntop(AF_INET6, &r->gidInfo->localGid, localGidString, sizeof(localGidString)); - remoteGidStr = inet_ntop(AF_INET6, 
&r->gidInfo->remoteGid, remoteGidString, sizeof(remoteGidString)); - } - WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d (%s)%s%s%s%s", - ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], - localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGid ":"", remoteGidString); - return ncclRemoteError; - } + for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) { + TIME_START(3); + // If we expect any completions from this device's CQ + if (r->events[i]) { + NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone)); + totalWrDone += wrDone; + if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } + if (wrDone == 0) continue; + for (int w=0; wstatus != IBV_WC_SUCCESS) { + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + char localGidString[INET6_ADDRSTRLEN] = ""; + char remoteGidString[INET6_ADDRSTRLEN] = ""; + const char* localGidStr = NULL, *remoteGidStr = NULL; + if (r->devBases[i]->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { + localGidStr = inet_ntop(AF_INET6, &r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); + remoteGidStr = inet_ntop(AF_INET6, &r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); + } + + char line[SOCKET_NAME_MAXLEN+1]; + WARN("NET/IB : Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s", + ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], + localGidStr ? " localGid ":"", localGidString, remoteGidStr ? 
" remoteGids":"", remoteGidString); + return ncclRemoteError; + } - struct ncclIbRequest* req = r->verbs->reqs+(wc->wr_id & 0xff); - if (req->type == NCCL_NET_IB_REQ_SEND) { - for (int i=0; inreqs; i++) { - struct ncclIbRequest* sendReq = r->verbs->reqs+((wc->wr_id >> (i*8)) & 0xff); - if ((sendReq->events <= 0)) return ncclInternalError; - sendReq->events--; - } - } else { - if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - if (req->type != NCCL_NET_IB_REQ_RECV) return ncclInternalError; - if (req->nreqs == 1) { - req->recv.sizes[0] += wc->imm_data; + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + struct ncclIbRequest* req = r->base->reqs+(wc->wr_id & 0xff); + + #ifdef ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN+1]; + TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%d r=%p type=%d events={%d,%d}, i=%d", + ncclSocketToString(&addr, line, 1), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i); + #endif + if (req->type == NCCL_NET_IB_REQ_SEND) { + for (int j = 0; j < req->nreqs; j++) { + struct ncclIbRequest* sendReq = r->base->reqs+((wc->wr_id >> (j*8)) & 0xff); + if ((sendReq->events[i] <= 0)) { + WARN("NET/IB: sendReq(%p)->events={%d,%d}, i=%d, j=%d <= 0", sendReq, sendReq->events[0], sendReq->events[1], i, j); + return ncclInternalError; + } + sendReq->events[i]--; + } + } else { + if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + if (req->type != NCCL_NET_IB_REQ_RECV) { + WARN("NET/IB: wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM and req->type=%d", req->type); + return ncclInternalError; + } + if (req->nreqs == 1) { + req->recv.sizes[0] += wc->imm_data; + } + } + req->events[i]--; } } - req->events--; } } + + // If no CQEs found on any device, return and come back later + if (totalWrDone == 0) return ncclSuccess; } } ncclResult_t ncclIbCloseSend(void* sendComm) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm) { - 
NCCLCHECK(ncclSocketClose(&comm->sock)); - for (int q=0; qnqps; q++) - if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); - if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr)); - if (comm->remSizesFifo.mr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mr)); - NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs)); + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.ndevs; i++) { + struct ncclIbSendCommDev* commDev = comm->devs + i; + if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (comm->remSizesFifo.mrs[i] != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mrs[i])); + NCCLCHECK(ncclIbDestroyBase(&commDev->base)); + } free(comm); } TIME_PRINT("IB"); @@ -1146,16 +1462,21 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { ncclResult_t ncclIbCloseRecv(void* recvComm) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm) { - NCCLCHECK(ncclSocketClose(&comm->sock)); - for (int q=0; qnqps; q++) - if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); - if (comm->gpuFlush.enabled) { - if (comm->gpuFlush.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp)); - if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr)); + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.ndevs; i++) { + struct ncclIbRecvCommDev* commDev = comm->devs + i; + if (comm->flushEnabled) { + if (commDev->gpuFlush.qp.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(commDev->gpuFlush.qp.qp)); + if (commDev->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->gpuFlush.hostMr)); + } + if (commDev->fifoMr != NULL) 
NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (commDev->sizesFifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->sizesFifoMr)); + NCCLCHECK(ncclIbDestroyBase(&commDev->base)); } - if (comm->remFifo.mr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remFifo.mr)); - if (comm->sizesFifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->sizesFifoMr)); - NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs)); free(comm); } return ncclSuccess; diff --git a/src/ibvwrap.c b/src/ibvwrap.c index 9d210638..53735adf 100644 --- a/src/ibvwrap.c +++ b/src/ibvwrap.c @@ -24,10 +24,24 @@ } \ return ncclSuccess; +#define IBV_INT_CHECK_RET_ERRNO_OPTIONAL(call, success_retval, name, supported) \ + int ret = call; \ + if (ret == ENOTSUP || ret == EOPNOTSUPP) { \ + INFO(NCCL_NET, "Call to " name " failed with error %s errno %d", strerror(ret), ret); \ + *supported = 0; \ + return ncclSuccess; \ + } else if (ret != success_retval) { \ + WARN("Call to " name " failed with error %s errno %d", strerror(ret), ret); \ + *supported = 1; \ + return ncclSystemError; \ + } \ + *supported = 1; \ + return ncclSuccess; + #define IBV_INT_CHECK_RET_ERRNO(call, success_retval, name) \ int ret = call; \ if (ret != success_retval) { \ - WARN("Call to " name " failed with error %s", strerror(ret)); \ + WARN("Call to " name " failed with error %s errno %d", strerror(ret), ret); \ return ncclSystemError; \ } \ return ncclSuccess; @@ -169,6 +183,26 @@ ncclResult_t wrap_ibv_post_recv(struct ibv_qp *qp, struct ibv_recv_wr *wr, struc return ncclSuccess; } +ncclResult_t wrap_ibv_query_ece(struct ibv_qp *qp, struct ibv_ece *ece, int* supported) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ +#if HAVE_DECL_IBV_QUERY_ECE + IBV_INT_CHECK_RET_ERRNO_OPTIONAL(ibv_query_ece(qp, ece), 0, "ibv_query_ece", supported); +#else + INFO(NCCL_NET, "Call to ibv_query_ece is skipped, doesn't exist"); + *supported = 0; + return ncclSuccess; +#endif +} + +ncclResult_t wrap_ibv_set_ece(struct 
ibv_qp *qp, struct ibv_ece *ece, int* supported) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ +#if HAVE_DECL_IBV_SET_ECE + IBV_INT_CHECK_RET_ERRNO_OPTIONAL(ibv_set_ece(qp, ece), 0, "ibv_set_ece", supported); +#else + INFO(NCCL_NET, "Call to ibv_set_ece skipped, doesn't exist"); + *supported = 0; + return ncclSuccess; +#endif +} + ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event) { *ret = (char *) ibv_event_type_str(event); return ncclSuccess; diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 8a257b52..b781b9cc 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -32,6 +32,7 @@ extern ncclNet_v5_t ibPlugin_v5; pthread_mutex_t nccl_p2p_lock = PTHREAD_MUTEX_INITIALIZER; ncclDebugLogger_t pluginLogFunction; +static int ncclNMergedIbDevs = -1; #ifdef HAVE_SHARP_PLUGIN extern int ncclNSharpDevs; @@ -144,7 +145,7 @@ ncclResult_t pluginInit_v5(ncclDebugLogger_t logFunction) { return ncclNetPlugin_v5.init(logFunction); } -ncclResult_t nccl_p2p_gdr_support(int dev) +ncclResult_t nccl_p2p_gdr_support() { static int module_loaded = -1; @@ -169,13 +170,20 @@ ncclResult_t nccl_p2p_dmabuf_support(int dev) { ncclResult_t res; struct ibv_pd* pd; struct ibv_context* ctx; - ctx = ncclIbDevs[dev].context; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); - // Test kernel DMA-BUF support with a dummy call (fd=-1) - (void) wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL/*offset*/, 0ULL/*len*/, 0ULL/*iova*/, -1/*fd*/, 0/*flags*/); - // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) - dmaBufSupported = (errno != EOPNOTSUPP && errno != EPROTONOSUPPORT) ? 
1 : 0; - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + dev; + + // Test each dev + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDev = mergedDev->devs[i]; + ctx = ncclIbDevs[ibDev].context; + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void) wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL/*offset*/, 0ULL/*len*/, 0ULL/*iova*/, -1/*fd*/, 0/*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) + dmaBufSupported = (errno != EOPNOTSUPP && errno != EPROTONOSUPPORT) ? 1 : 0; + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + } + } if (dmaBufSupported == 0) return ncclSystemError; return ncclSuccess; @@ -185,13 +193,19 @@ ncclResult_t nccl_p2p_dmabuf_support(int dev) { } -ncclResult_t nccl_p2p_ib_get_properties(nccl_ib_dev_t *devs, int dev, ncclNetProperties_t* props) +ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetProperties_t* props) { - props->name = devs[dev].devName; - props->pciPath = devs[dev].pciPath; - props->guid = devs[dev].guid; + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs+dev; + props->name = mergedDev->devName; + props->speed = mergedDev->speed; + + // Take the rest of the properties from an arbitrary sub-device (should be the same) + struct ncclIbDev* ibDev = ncclIbDevs + mergedDev->devs[0]; + props->pciPath = ibDev->pciPath; + props->guid = ibDev->guid; + props->ptrSupport = NCCL_PTR_HOST; - if (nccl_p2p_gdr_support(dev) == ncclSuccess) { + if (nccl_p2p_gdr_support() == ncclSuccess) { props->ptrSupport |= NCCL_PTR_CUDA; // GDR support via nv_peermem INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (nvidia-peermem) enabled for HCA %d '%s", dev, devs[dev].devName); } @@ -201,10 +215,10 @@ ncclResult_t nccl_p2p_ib_get_properties(nccl_ib_dev_t *devs, int dev, ncclNetPro INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (DMABUF) enabled for HCA 
%d '%s", dev, devs[dev].devName); } - props->speed = devs[dev].speed; props->latency = 0; // Not set - props->port = devs[dev].port + devs[dev].realPort; - props->maxComms = devs[dev].maxQp; + props->port = ibDev->portNum + ibDev->realPort; + props->maxComms = ibDev->maxQp; + if (p2p_plugin == NCCL_P2P_IB || p2p_plugin == NCCL_P2P_UCX) { props->maxRecvs = NCCL_NET_IB_MAX_RECVS; } else { @@ -217,14 +231,14 @@ ncclResult_t nccl_p2p_ib_get_properties(nccl_ib_dev_t *devs, int dev, ncclNetPro } static void* ncclIbAsyncThreadMain(void* args) { - struct ibv_context* context = (struct ibv_context*)args; + struct ncclIbDev* dev = (struct ncclIbDev*)args; while (1) { struct ibv_async_event event; - if (ncclSuccess != wrap_ibv_get_async_event(context, &event)) { break; } + if (ncclSuccess != wrap_ibv_get_async_event(dev->context, &event)) { break; } char *str; if (ncclSuccess != wrap_ibv_event_type_str(&str, event.event_type)) { break; } if (event.event_type != IBV_EVENT_COMM_EST) - WARN("NET/IB : Got async event : %s", str); + WARN("NET/IB : %s:%d Got async event : %s", dev->devName, dev->portNum, str); if (ncclSuccess != wrap_ibv_ack_async_event(&event)) { break; } } return NULL; @@ -240,7 +254,26 @@ int devSharpCompare(const void *a, const void *b) else { return 1; } } -ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction) +// Compare ncclIbDev[dev] to all stored mergedIbDevs +int ncclIbFindMatchingDev(int dev) { + for (int i = 0; i < ncclNMergedIbDevs; i++) { + if (ncclIbMergedDevs[i].ndevs < NCCL_IB_MAX_DEVS_PER_NIC) { + int compareDev = ncclIbMergedDevs[i].devs[0]; + if (strcmp(ncclIbDevs[dev].pciPath, ncclIbDevs[compareDev].pciPath) == 0 && + (ncclIbDevs[dev].guid == ncclIbDevs[compareDev].guid) && + (ncclIbDevs[dev].link == ncclIbDevs[compareDev].link)) { + TRACE(NCCL_NET, "NET/IB: Matched name1=%s pciPath1=%s guid1=0x%lx link1=%u 
name2=%s pciPath2=%s guid2=0x%lx link2=%u", + ncclIbDevs[dev].devName, ncclIbDevs[dev].pciPath, ncclIbDevs[dev].guid, ncclIbDevs[dev].link, + ncclIbDevs[compareDev].devName, ncclIbDevs[compareDev].pciPath, ncclIbDevs[compareDev].guid, ncclIbDevs[compareDev].link); + return i; + } + } + } + + return ncclNMergedIbDevs; +} + +ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction) { int ncclNIbDevs = *num_devs; @@ -250,6 +283,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *nc wrap_ibv_fork_init(); if (ncclNIbDevs == -1) { ncclNIbDevs = 0; + ncclNMergedIbDevs = 0; ncclNSharpDevs = 0; if (ncclFindInterfaces(ncclIbIfName, ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) { WARN("NET/IB : No IP interface found."); @@ -283,10 +317,10 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *nc if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; } continue; } - for (int port = 1; port <= devAttr.phys_port_cnt; port++) { + for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) { struct ibv_port_attr portAttr; - if (ncclSuccess != wrap_ibv_query_port(context, port, &portAttr)) { - WARN("NET/IB : Unable to query port %d", port); + if (ncclSuccess != wrap_ibv_query_port(context, port_num, &portAttr)) { + WARN("NET/IB : Unable to query port_num %d", port_num); continue; } if (portAttr.state != IBV_PORT_ACTIVE) continue; @@ -294,15 +328,13 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *nc && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) continue; // check against user specified HCAs/ports - if (! (matchIfList(devices[d]->name, port, userIfs, nUserIfs, searchExact) ^ searchNot)) { + if (! 
(matchIfList(devices[d]->name, port_num, userIfs, nUserIfs, searchExact) ^ searchNot)) { continue; } - TRACE(NCCL_INIT|NCCL_NET,"NET/IB: [%d] %s:%d/%s ", d, devices[d]->name, port, - portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); pthread_mutex_init(&ncclIbDevs[ncclNIbDevs].lock, NULL); ncclIbDevs[ncclNIbDevs].device = d; ncclIbDevs[ncclNIbDevs].guid = devAttr.sys_image_guid; - ncclIbDevs[ncclNIbDevs].port = port; + ncclIbDevs[ncclNIbDevs].portNum = port_num; ncclIbDevs[ncclNIbDevs].link = portAttr.link_layer; ncclIbDevs[ncclNIbDevs].speed = nccl_p2p_ib_speed(portAttr.active_speed) * nccl_p2p_ib_width(portAttr.active_width); ncclIbDevs[ncclNIbDevs].context = context; @@ -327,12 +359,37 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *nc ncclIbDevs[ncclNIbDevs].maxQp = ncclParamSharpMaxComms(); ncclNSharpDevs++; } - + TRACE(NCCL_NET,"NET/IB: [%d] %s:%s:%d/%s speed=%d context=%p pciPath=%s ar=%d", d, devices[d]->name, devices[d]->dev_name, ncclIbDevs[ncclNIbDevs].portNum, + portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND ? 
"IB" : "RoCE", ncclIbDevs[ncclNIbDevs].speed, context, ncclIbDevs[ncclNIbDevs].pciPath, ncclIbDevs[ncclNIbDevs].ar); if (ncclIbAsyncThread != NULL) { - pthread_create(ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, context); + pthread_create(ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, ncclIbDevs + ncclNIbDevs); ncclSetThreadName(*ncclIbAsyncThread, "NCCL IbAsync %2d", ncclNIbDevs); pthread_detach(*ncclIbAsyncThread); // will not be pthread_join()'d } + + int mergedDev = ncclNMergedIbDevs; + if (ncclParamIbMergeNics()) { + mergedDev = ncclIbFindMatchingDev(ncclNIbDevs); + } + + // No matching dev found, create new mergedDev entry (it's okay if there's only one dev inside) + if (mergedDev == ncclNMergedIbDevs) { + // Set ndevs to 1, assign first ibDevN to the current IB device + ncclIbMergedDevs[mergedDev].ndevs = 1; + ncclIbMergedDevs[mergedDev].devs[0] = ncclNIbDevs; + ncclNMergedIbDevs++; + strncpy(ncclIbMergedDevs[mergedDev].devName, ncclIbDevs[ncclNIbDevs].devName, MAXNAMESIZE); + // Matching dev found, edit name + } else { + // Set next device in this array to the current IB device + int ndevs = ncclIbMergedDevs[mergedDev].ndevs; + ncclIbMergedDevs[mergedDev].devs[ndevs] = ncclNIbDevs; + ncclIbMergedDevs[mergedDev].ndevs++; + snprintf(ncclIbMergedDevs[mergedDev].devName + strlen(ncclIbMergedDevs[mergedDev].devName), MAXNAMESIZE+1, "+%s", ncclIbDevs[ncclNIbDevs].devName); + } + + // Aggregate speed + ncclIbMergedDevs[mergedDev].speed += ncclIbDevs[ncclNIbDevs].speed; ncclNIbDevs++; nPorts++; } @@ -348,33 +405,48 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, nccl_ib_dev_t *ncclIbDevs, char *nc qsort(ncclIbDevs, ncclNIbDevs, sizeof(struct ncclIbDev), devSharpCompare); } - char line[1024]; + char line[2048]; line[0] = '\0'; // Determine whether RELAXED_ORDERING is enabled and possible ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); - for (int d=0; dndevs > 1) { + // Print out merged dev info + snprintf(line+strlen(line), 2047-strlen(line), " 
[%d]={", d); + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDev = mergedDev->devs[i]; + snprintf(line+strlen(line), 2047-strlen(line), "[%d] %s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, + ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", + // Insert comma to delineate + i == (mergedDev->ndevs - 1) ? "" : ", "); + } + snprintf(line+strlen(line), 2047-strlen(line), "}"); + } else { + int ibDev = mergedDev->devs[0]; #ifdef HAVE_SHARP_PLUGIN - snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%d/%s%s", d, ncclIbDevs[d].devName, - ncclIbDevs[d].port, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", - ncclIbDevs[d].isSharpDev ? "/SHARP" : ""); + snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, + ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", + ncclIbDevs[ibDev].isSharpDev ? "/SHARP" : ""); #else - snprintf(line+strlen(line), 1023-strlen(line), " [%d]%s:%d/%s", d, ncclIbDevs[d].devName, - ncclIbDevs[d].port, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); + snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s", ibDev, ncclIbDevs[ibDev].devName, + ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); #endif + } } - line[1023] = '\0'; + line[2047] = '\0'; char addrline[SOCKET_NAME_MAXLEN+1]; INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? 
"[RO]" : "", ncclIbIfName, ncclSocketToString(ncclIbIfAddr, addrline, 1)); } - *num_devs = ncclNIbDevs; + *num_devs = ncclNMergedIbDevs; pthread_mutex_unlock(&nccl_p2p_lock); } return ncclSuccess; } -ncclResult_t nccl_p2p_ib_pci_path(nccl_ib_dev_t *devs, int num_devs, char* dev_name, char** path, int* real_port) +ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port) { char device_path[PATH_MAX]; snprintf(device_path, PATH_MAX, "/sys/class/infiniband/%s/device", dev_name); diff --git a/src/param.c b/src/param.c index c9653319..06605a0f 100755 --- a/src/param.c +++ b/src/param.c @@ -64,7 +64,7 @@ void ncclLoadParam(char const* env, int64_t deftVal, int64_t uninitialized, int6 static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_lock(&mutex); if (__atomic_load_n(cache, __ATOMIC_RELAXED) == uninitialized) { - char* str = getenv(env); + const char* str = ncclGetEnv(env); int64_t value = deftVal; if (str && strlen(str) > 0) { errno = 0; @@ -80,3 +80,9 @@ void ncclLoadParam(char const* env, int64_t deftVal, int64_t uninitialized, int6 } pthread_mutex_unlock(&mutex); } + +const char *ncclGetEnv(const char *name) { + static pthread_once_t once = PTHREAD_ONCE_INIT; + pthread_once(&once, initEnv); + return getenv(name); +} diff --git a/src/socket.c b/src/socket.c index bbb17905..ea711f85 100755 --- a/src/socket.c +++ b/src/socket.c @@ -13,6 +13,7 @@ #include #include #include +#include "param.h" static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { int bytes = 0; @@ -35,7 +36,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr } } (*offset) += bytes; - if (sock->abortFlag && *sock->abortFlag != 0) { + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) { INFO(NCCL_NET, "socketProgressOpt: abort called"); return ncclInternalError; } @@ -86,7 +87,7 @@ static uint16_t 
socketToPort(union ncclSocketAddress *addr) { /* Allow the user to force the IPv4/IPv6 interface selection */ static int envSocketFamily(void) { int family = -1; // Family selection is not forced, will use first one found - char* env = getenv("NCCL_SOCKET_FAMILY"); + const char* env = ncclGetEnv("NCCL_SOCKET_FAMILY"); if (env == NULL) return family; @@ -327,7 +328,7 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa // Allow user to force the INET socket family selection int sock_family = envSocketFamily(); // User specified interface - char* env = getenv("NCCL_SOCKET_IFNAME"); + const char* env = ncclGetEnv("NCCL_SOCKET_IFNAME"); if (env && strlen(env) > 1) { INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env); // Specified by user : find or fail @@ -339,9 +340,10 @@ int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNa nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); // else see if we can get some hint from COMM ID if (nIfs == 0) { - char* commId = getenv("NCCL_COMM_ID"); + const char* commId = ncclGetEnv("NCCL_COMM_ID"); if (commId && strlen(commId) > 1) { - INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); + INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); + // Try to find interface that is in the same subnet as the IP in comm id union ncclSocketAddress idAddr; ncclSocketGetAddrFromString(&idAddr, commId); @@ -530,6 +533,8 @@ static ncclResult_t socketPollConnect(struct ncclSocket* sock) { sock->state = ncclSocketStateConnecting; } else if (ret != EINPROGRESS) { sock->state = ncclSocketStateError; + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketPollConnect: Connect to %s returned %d(%s) errno %d(%s)", ncclSocketToString(&sock->addr, line, 1), ret, strerror(ret), errno, strerror(errno)); return ncclSystemError; } return ncclSuccess; @@ -619,12
+624,12 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { do { NCCLCHECK(socketProgressState(sock)); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && (sock->state == ncclSocketStateConnecting || sock->state == ncclSocketStateConnectPolling || sock->state == ncclSocketStateConnected)); - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; switch (sock->state) { case ncclSocketStateConnecting: @@ -666,11 +671,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen do { NCCLCHECKGOTO(socketProgressState(sock), ret, exit); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && (sock->state == ncclSocketStateAccepting || sock->state == ncclSocketStateAccepted)); - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; switch (sock->state) { case ncclSocketStateAccepting: diff --git a/src/ucx_plugin.c b/src/ucx_plugin.c index e578696c..68e8dfee 100644 --- a/src/ucx_plugin.c +++ b/src/ucx_plugin.c @@ -278,7 +278,7 @@ static ncclResult_t ucx_init_context(ucp_context_h *ctx, int dev) { if (ucp_ctx[dev] == NULL) { snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, - ncclIbDevs[dev].port); + ncclIbDevs[dev].portNum); UCXCHECK(ucp_config_read("NCCL", NULL, &config)); UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name)); diff --git a/src/ucx_rma_plugin.c b/src/ucx_rma_plugin.c index d1b2debe..42243ef7 100644 --- a/src/ucx_rma_plugin.c +++ b/src/ucx_rma_plugin.c @@ -242,7 +242,7 @@ static ncclResult_t 
nccl_ucx_rma_init_ucp(int dev, ucp_context_h *ctx) char ucx_dev_name[PATH_MAX]; snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, - ncclIbDevs[dev].port); + ncclIbDevs[dev].portNum); UCXCHECK(ucp_config_read("NCCL", NULL, &config)); UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name));