From 74d8c7d8bf96fec298da649285d2bbdef119055d Mon Sep 17 00:00:00 2001
From: sallyjunjun
Date: Mon, 2 Dec 2024 17:28:14 +0800
Subject: [PATCH] fix check CUDA_DEVICE_MAX_CONNECTIONS

---
 internlm/utils/common.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 56ebcfbe..8283c36e 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -250,8 +250,22 @@ def enable_pytorch_expandable_segments():
 
 
 def check_cuda_env():
-    if os.getenv("CUDA_DEVICE_MAX_CONNECTIONS") is None:
-        logger.warning("Env var CUDA_DEVICE_MAX_CONNECTIONS has not be set, please note this!")
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+        max_connections = os.getenv("CUDA_DEVICE_MAX_CONNECTIONS")
+        assert (
+            max_connections is not None
+        ), "Env var CUDA_DEVICE_MAX_CONNECTIONS has not been set, please set it to 1!"
+        assert (
+            max_connections == '1'
+        ), "Env var CUDA_DEVICE_MAX_CONNECTIONS is set to {}, it should be set to 1!".format(max_connections)
+
+        avoid_record_streams = os.getenv("TORCH_NCCL_AVOID_RECORD_STREAMS")
+        assert (
+            avoid_record_streams is not None
+        ), "Env var TORCH_NCCL_AVOID_RECORD_STREAMS has not been set, please set it to 1!"
+        assert (
+            avoid_record_streams == '1'
+        ), "Env var TORCH_NCCL_AVOID_RECORD_STREAMS is set to {}, it should be set to 1!".format(avoid_record_streams)
 
 
 class DummyProfile: