fix(linear.py): linear module uneven split is forbidden #374

Merged · 2 commits · Nov 27, 2024
6 changes: 6 additions & 0 deletions internlm/initialize/launch.py
```diff
@@ -395,11 +395,17 @@ def args_sanity_check():
         gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
     if gpc.config.parallel["tensor"].get("mode", None) is None:
         gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
+    assert (
+        gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
+    ), "VOCAB_SIZE must be integer multiple of tensor parallel size"
     if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
         assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
         assert (
             torch.__version__ >= "2.1.0"
         ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
+        assert (
+            gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
+        ), "VOCAB_SIZE must be integer multiple of wp size"

     assert gpc.config.parallel["tensor"].get("mode", None) in [
         TensorParallelMode.mtp.name,
```
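Why both assertions are needed: the vocabulary dimension is partitioned across the tensor-parallel ranks (or, under isp, the weight-parallel ranks), so `VOCAB_SIZE` must split into equal shards. A minimal sketch of the invariant being enforced; `rows_per_rank` is an illustrative helper, not a function from the repo:

```python
def rows_per_rank(vocab_size: int, world_size: int) -> int:
    # Each rank holds an equal, contiguous slice of the vocab dimension.
    # With this PR, an uneven slice fails fast at config-check time instead
    # of surfacing later as a shape mismatch inside the linear module.
    assert vocab_size % world_size == 0, "vocab size must divide evenly"
    return vocab_size // world_size

# e.g. the vocab size used in the updated test splits cleanly across 8 ranks:
print(rows_per_rank(92544, 8))  # 11568
```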
2 changes: 2 additions & 0 deletions internlm/model/modules/linear.py
```diff
@@ -602,8 +602,10 @@ def __init__(
         split_features = out_features if split_mode == "column" else in_features
         multiple = split_features // multiple_of
         # We want to split @multiple across world_size, but it could be an uneven split
+        # uneven split is forbidden
         div = multiple // world_size
         mod = multiple % world_size
+        assert mod == 0, "linear module uneven split is forbidden"
         # The first @mod ranks get @div + 1 copies, the rest get @div copies
         local_multiple = div + int(rank < mod)
```
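To see what the new assertion rejects, here is the surrounding arithmetic with illustrative numbers (the feature sizes below are examples, not values from the PR):

```python
# A column-split linear with out_features = 4096 and multiple_of = 128:
multiple = 4096 // 128   # 32 blocks of 128 features to distribute

world_size = 8
div, mod = divmod(multiple, world_size)  # div = 4, mod = 0: even split, allowed

world_size = 6
div, mod = divmod(multiple, world_size)  # div = 5, mod = 2: previously the first
                                         # 2 ranks got 6 blocks and the rest 5;
                                         # after this PR, mod != 0 raises instead.
```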
1 change: 1 addition & 0 deletions tests/test_training/test_forward_output_no_fa.py
```diff
@@ -31,6 +31,7 @@
 TOTAL_STEPS = 1
 config = Config(
     dict(
+        VOCAB_SIZE=92544,
         parallel=dict(
             zero1=dict(size=-1),
             tensor=dict(size=1, mode="mtp"),
```
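This test config did not previously define `VOCAB_SIZE`, and the new check in `args_sanity_check` reads `gpc.config.VOCAB_SIZE` unconditionally, so it must now be set. The value passes trivially here since the test runs with `tensor=dict(size=1)`, and it also shards evenly at larger power-of-two sizes; a quick check (the sizes below are illustrative):

```python
for size in (1, 2, 4, 8):
    assert 92544 % size == 0  # the test's vocab size splits evenly at these sizes
```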