feat(model): update model internlm2 #47

Merged 3 commits on Feb 27, 2024.
Changes from 1 commit.
28 changes: 19 additions & 9 deletions internlm/model/modeling_internlm2.py
@@ -39,7 +39,11 @@
     get_mlp_cls,
 )
 from internlm.model.multi_head_attention import DistributedAttention
-from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+from internlm.model.utils import (
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+    try_import_RMSNorm,
+)
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -61,7 +65,7 @@ class MHA(nn.Module):
         num_heads (int): The number of attention heads.
         num_kv_heads (int): The number of attention heads for key and value.
         process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
-        sequence_process_group (torch.distributed.ProcessGroup): The group for `sequence_parallel`.
+        sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation.
         bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
             output projection. False by default.
         dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
@@ -165,11 +169,9 @@ def __init__(
         self.inner_cross_attn_softmax_scale = softmax_scale
         self.inner_cross_attn_dropout = dropout
 
+        self.attn = flash_attn_varlen_kvpacked_func
         if self.tp_mode == "isp":
-            self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=sequence_process_group)
-            self.inner_cross_attn = DistributedAttention(
-                self.inner_cross_attn, sequence_process_group=sequence_process_group
-            )
+            self.attn = DistributedAttention(self.attn, sequence_process_group=sequence_process_group)
 
         wo_cls = get_linear_cls(self.tp_mode, "row")
         self.wo = wo_cls(
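
Reviewer note: the hunk above collapses the two per-module wrappers (inner_attn, inner_cross_attn) into a single self.attn entry point that is only wrapped by DistributedAttention when the tensor-parallel mode is "isp". The toy sketch below (plain Python, not part of the diff, and not InternLM's DistributedAttention implementation) illustrates the Ulysses-style idea behind such a wrapper: the wrapped callable keeps the caller-facing signature, while an all-to-all internally trades the local sequence shard for a head shard so each rank attends over the full sequence with a fraction of the heads. Autograd handling and adjustment of the varlen metadata (cu_seqlens, max_seqlen) are deliberately omitted here.

import torch
import torch.distributed as dist
from torch import nn


def _seq_all_to_all(x, scatter_dim, gather_dim, group):
    # Exchange shards across the group: scatter `x` along scatter_dim,
    # gather the received pieces along gather_dim (assumes equal shard sizes).
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)


class ToySequenceParallelAttention(nn.Module):
    # Wrap a local attention callable (e.g. flash_attn_varlen_kvpacked_func)
    # so each rank attends over the full sequence but only its share of heads.
    def __init__(self, local_attn, sequence_process_group):
        super().__init__()
        self.local_attn = local_attn
        self.spg = sequence_process_group

    def forward(self, q, kv, **kwargs):
        # q: (local_tokens, heads, head_dim); kv: (local_tokens, 2, heads, head_dim)
        if dist.get_world_size(self.spg) == 1:
            return self.local_attn(q=q, kv=kv, **kwargs)
        q = _seq_all_to_all(q, scatter_dim=1, gather_dim=0, group=self.spg)   # heads -> full sequence
        kv = _seq_all_to_all(kv, scatter_dim=2, gather_dim=0, group=self.spg)
        out = self.local_attn(q=q, kv=kv, **kwargs)
        return _seq_all_to_all(out, scatter_dim=0, gather_dim=1, group=self.spg)  # sequence -> heads

Because the wrapped object is called exactly like the raw kernel, the call sites in _packed_forward below can use self.attn(...) unconditionally, whether or not "isp" is enabled.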
@@ -435,7 +437,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs):
             if kv.dtype not in [torch.float16, torch.bfloat16]:
                 kv = kv.to(torch.bfloat16)
             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                context = flash_attn_varlen_kvpacked_func(
+                context = self.attn(
                     q=q,
                     kv=kv,
                     cu_seqlens_q=kwargs["cu_seqlens"],
@@ -447,7 +449,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs):
                     causal=self.inner_cross_attn_causal,
                 ).to(self.dtype)
         else:
-            context = flash_attn_varlen_kvpacked_func(
+            context = self.attn(
                 q=q,
                 kv=kv,
                 cu_seqlens_q=kwargs["cu_seqlens"],
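
For context on the kwargs seen in both branches: the packed path drives the varlen kernel with cumulative sequence lengths rather than a padded batch. A small standalone example (assumed per-sample lengths, not taken from the diff) of how cu_seqlens and max_seqlen relate to the packed layout:

import torch

lengths = torch.tensor([3, 5, 2])                       # three sequences packed back to back
cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)           # token offsets of each sequence boundary
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(cu_seqlens.tolist(), max_seqlen)                  # [0, 3, 8, 10] 5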
@@ -962,6 +964,10 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
+            # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp":
+                indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
+
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
 
         for _, block in enumerate(self.layers):
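
Reviewer note on the new split_forward_gather_backward call: under "isp" the hidden states each rank processes are already sharded along the sequence dimension, so the per-token position ids used for rotary embeddings must be sharded the same way. The single-process toy below uses a hypothetical split_along_dim0 helper, not the library function (the real helper is autograd-aware; this only shows the forward slicing), on a packed batch of two sequences:

import torch

def split_along_dim0(x: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # forward-only stand-in: keep this rank's equal slice of the sequence dimension
    return x.chunk(world_size, dim=0)[rank]

# packed batch of two sequences (lengths 3 and 5); position ids restart per sequence
indexes = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])
for rank in range(2):
    print(rank, split_along_dim0(indexes, world_size=2, rank=rank))
# rank 0 -> tensor([0, 1, 2, 0]); rank 1 -> tensor([1, 2, 3, 4]),
# matching the sequence shard of hidden states held on each rank.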
@@ -977,7 +983,11 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "output"):
-            hidden_states = self.output(hidden_states)
+            # Evaluation
+            if gpc.is_evaluating is True:
+                hidden_states = self.output(hidden_states, gather_dim=1)
+            else:  # Training
+                hidden_states = self.output(hidden_states, gather_dim=0)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
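
Reviewer note on the gather_dim split: the diff distinguishes evaluation (gather_dim=1) from training (gather_dim=0), presumably because training runs on the packed layout where the sequence axis is the leading dimension, while evaluation uses a batched (batch, seq, hidden) layout where the sequence axis is dimension 1. The toy head below is a hypothetical class, not InternLM's output head, and torch.cat stands in for the all-gather across the sequence-parallel group; it only illustrates why the same projection needs a different gather axis in the two modes.

import torch
from torch import nn

class ToySequenceParallelHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, shards, gather_dim):
        full = torch.cat(shards, dim=gather_dim)   # stand-in for the sequence-parallel all-gather
        return self.proj(full)

head = ToySequenceParallelHead(hidden_size=8, vocab_size=16)
packed_shards = [torch.randn(4, 8), torch.randn(4, 8)]          # training: (tokens, hidden) per rank
batched_shards = [torch.randn(2, 4, 8), torch.randn(2, 4, 8)]   # evaluation: (batch, seq, hidden) per rank
print(head(packed_shards, gather_dim=0).shape)    # torch.Size([8, 16])
print(head(batched_shards, gather_dim=1).shape)   # torch.Size([2, 8, 16])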