Skip to content

Commit

Permalink
fix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
season0528 committed Aug 26, 2024
1 parent 3e1a30b commit d712899
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 9 deletions.
4 changes: 1 addition & 3 deletions internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,7 @@ def load_hf_weights(folder: str, model: nn.Module):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
new_state_dict["embedding.weight"] = torch.chunk(
state_dict.pop("model.embed_tokens.weight"),
split_size,
Expand Down
4 changes: 1 addition & 3 deletions internlm/model/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,7 @@ def load_hf_weights(folder: str, model: nn.Module):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
new_state_dict["tok_embeddings.weight"] = torch.chunk(
state_dict.pop("model.tok_embeddings.weight"),
split_size,
Expand Down
4 changes: 1 addition & 3 deletions internlm/model/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,7 @@ def load_hf_weights(folder: str, model: nn.Module):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (
not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
):
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
new_state_dict["tok_embeddings.weight"] = torch.chunk(
state_dict.pop("model.embed_tokens.weight"),
split_size,
Expand Down

0 comments on commit d712899

Please sign in to comment.