Skip to content

Commit

Permalink
fix(linear.py): linear module uneven split is forbidden (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 authored Nov 27, 2024
1 parent aee457c commit 6e7163c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
6 changes: 6 additions & 0 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,17 @@ def args_sanity_check():
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
), "VOCAB_SIZE must be integer multiple of tensor parallel size"
if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
assert (
torch.__version__ >= "2.1.0"
), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
), "VOCAB_SIZE must be integer multiple of wp size"

assert gpc.config.parallel["tensor"].get("mode", None) in [
TensorParallelMode.mtp.name,
Expand Down
2 changes: 2 additions & 0 deletions internlm/model/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,10 @@ def __init__(
split_features = out_features if split_mode == "column" else in_features
multiple = split_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
# uneven split is forbidden
div = multiple // world_size
mod = multiple % world_size
assert mod == 0, "linear module uneven split is forbidden"
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(rank < mod)

Expand Down
1 change: 1 addition & 0 deletions tests/test_training/test_forward_output_no_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TOTAL_STEPS = 1
config = Config(
dict(
VOCAB_SIZE=92544,
parallel=dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
Expand Down

0 comments on commit 6e7163c

Please sign in to comment.