From 2f36436d378b03c0f3fd32bbb1f8389206a3906a Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Thu, 21 Nov 2024 13:02:48 +0800 Subject: [PATCH 1/2] fix(linear.py): linear module uneven split is forbidden --- internlm/initialize/launch.py | 6 ++++++ internlm/model/modules/linear.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index fc63b8a2..1ac8ef31 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -395,11 +395,17 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name + assert ( + gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0 + ), "VOCAB_SIZE must be integer multiple of tensor parallel size" if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" assert ( torch.__version__ >= "2.1.0" ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" + assert ( + gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0 + ), "VOCAB_SIZE must be integer multiple of wp size" assert gpc.config.parallel["tensor"].get("mode", None) in [ TensorParallelMode.mtp.name, diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 29070b42..856e6ba0 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -602,8 +602,10 @@ def __init__( split_features = out_features if split_mode == "column" else in_features multiple = split_features // multiple_of # We want to split @multiple across world_size, but it could be an uneven split + # uneven split is forbidden div = multiple // world_size mod = multiple % world_size + assert mod == 0, "linear module uneven split is forbidden" # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(rank < mod) From 22a2c04e8034b7e161d7c7a42846a0938be7a75e Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Thu, 21 Nov 2024 14:15:26 +0800 Subject: [PATCH 2/2] fix(test_forward_output_no_fa.py): fix ci test err --- tests/test_training/test_forward_output_no_fa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index e846594e..d089934b 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -31,6 +31,7 @@ TOTAL_STEPS = 1 config = Config( dict( + VOCAB_SIZE=92544, parallel=dict( zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"),