
Commit

Merge pull request #9 from okoge-kaz/feature/phi-3
Fix minor bugs
okoge-kaz authored Jul 17, 2024
2 parents 606cdfb + 13ad141 commit f7da9a0
Showing 5 changed files with 36 additions and 10 deletions.
8 changes: 4 additions & 4 deletions megatron_lm/megatron/core/datasets/gpt_dataset.py
@@ -99,13 +99,13 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         """
         text, _ = self._query_document_sample_shuffle_indices(idx)

-        text = torch.from_numpy(text)
+        text = torch.from_numpy(text).long()

-        tokens = text.long()
-        labels = tokens.clone()
+        tokens = text.contiguous()
+        labels = text.clone().contiguous()

         # HF Transformers automatically shifts input_ids and labels, so don't shift manually.
-        # ref: Mistral https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L1171-L1174
+        # ref: Mistral https://github.com/huggingface/transformers/blob/48d35b21789ad80a90ea242e46cb1d53e4db4f1c/src/transformers/models/mistral/modeling_mistral.py#L1211-L1212
         # ref: https://discuss.huggingface.co/t/where-does-the-transformers-do-the-target-text-shifting-in-causal-lm/32408/4
         # Also, if the attention mask is all 1 (= True), you don't have to pass an attention mask.
         # HF Transformers' attention mask is 1-D, like this: [1, 1, 1, ..., 0, 0]
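Note: for context on why tokens and labels can now be identical, HF causal-LM heads apply the one-token shift internally whenever labels is passed. A minimal sketch of that internal loss computation, paraphrasing the Mistral lines referenced above (simplified illustration, not code from this repo):

import torch
import torch.nn.functional as F

def hf_style_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Predictions at position t are scored against the token at position t + 1,
    # so the dataset can pass identical tokens/labels with no manual shift.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # HF's conventional "ignore this position" label
    )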
2 changes: 1 addition & 1 deletion megatron_lm/megatron/core/datasets/helpers.cpp
@@ -122,7 +122,7 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
     while (sample_index <= num_samples)
     {
         // Start with a fresh sequence.
-        int32_t remaining_seq_length = seq_length + 1;
+        int32_t remaining_seq_length = seq_length;
         while (remaining_seq_length != 0)
         {
             // Get the document length.
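Note: the dropped "+ 1" follows from the dataset change above. Upstream Megatron reserves one extra token per sample so that labels can be shifted by one position; since labels are no longer shifted here, each sample consumes exactly seq_length tokens. A simplified Python sketch of the indexing idea (illustration only, not the real helpers.cpp, and it assumes doc_sizes holds enough tokens):

from typing import List, Tuple

def build_sample_idx_sketch(doc_sizes: List[int], seq_length: int, num_samples: int) -> List[Tuple[int, int]]:
    # Each entry marks a sample boundary as (document index, offset into that document).
    sample_idx: List[Tuple[int, int]] = [(0, 0)]
    doc, offset = 0, 0
    while len(sample_idx) <= num_samples:
        remaining = seq_length  # was seq_length + 1 when labels were shifted manually
        while remaining != 0:
            available = doc_sizes[doc] - offset  # tokens left in the current document
            if available >= remaining:
                offset += remaining              # sample ends inside this document
                remaining = 0
            else:
                remaining -= available           # take the rest and move to the next document
                doc, offset = doc + 1, 0
        sample_idx.append((doc, offset))
    return sample_idx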
2 changes: 1 addition & 1 deletion scripts/gcp/yi-1.5-9b.sh
@@ -2,7 +2,7 @@
 #SBATCH --job-name=yi-1.5-9b
 #SBATCH --partition=a3
 #SBATCH --exclusive
-#SBATCH --nodes 2
+#SBATCH --nodes 1
 #SBATCH --gpus-per-node=8
 #SBATCH --ntasks-per-node=8
 #SBATCH --output=outputs/yi-1.5-9b/%x-%j.out
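Note: with --nodes 1, --gpus-per-node=8 and --ntasks-per-node=8 this job now runs on 8 GPUs (one task per GPU) instead of the previous 16, assuming the launch command later in the script derives its world size from these SLURM values.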
11 changes: 7 additions & 4 deletions src/llama_recipes/finetuning.py
@@ -38,6 +38,7 @@

 from llama_recipes.arguments import parse_args
 from llama_recipes.get_fsdp import get_sharding_strategy
+from llama_recipes.utils.precision import preserve_fp32_buffers
 from megatron_lm.megatron.global_vars import set_global_variables


@@ -105,10 +106,12 @@ def main() -> None:
     print_model_size(model, args.base_model, rank)  # type: ignore

     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
-    if args.bf16:
-        model.to(torch.bfloat16)  # type: ignore
-    elif args.fp16:
-        model.to(torch.float16)  # type: ignore
+    # RoPE inv_freq etc. are stored in fp32, so we need to preserve them
+    with preserve_fp32_buffers(model):  # type: ignore
+        if args.bf16:
+            model.to(torch.bfloat16)  # type: ignore
+        elif args.fp16:
+            model.to(torch.float16)  # type: ignore

     if args.use_freeze_layers:
         print_rank_0("NOTE: freeze transformer layers")
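Note: a quick stand-alone illustration (not from this repo) of why the fp32 buffers are worth preserving. RoPE's inv_freq buffer holds small frequencies, and a blanket model.to(torch.bfloat16) would round them down to bfloat16's two to three significant decimal digits:

import torch

# Standard RoPE frequencies, as computed in HF rotary-embedding modules.
dim, base = 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Round-trip through bfloat16, as a bare model.to(torch.bfloat16) would do to the buffer.
rounded = inv_freq.to(torch.bfloat16).float()
max_rel_err = ((inv_freq - rounded).abs() / inv_freq).max().item()
print(f"max relative error after bf16 rounding: {max_rel_err:.1e}")  # on the order of 1e-3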
23 changes: 23 additions & 0 deletions src/llama_recipes/utils/precision.py
@@ -0,0 +1,23 @@
+import torch
+from contextlib import contextmanager
+
+
+@contextmanager
+def preserve_fp32_buffers(model: torch.nn.Module):
+    fp32_buffers = dict()
+    for name, param in model.named_buffers():
+        if param.dtype == torch.float32:
+            fp32_buffers[name] = param.clone()
+
+    # model.to(torch.float16) or model.to(torch.bfloat16)
+    yield
+
+    for name, param in model.named_buffers():
+        if name in fp32_buffers:
+            if "." in name:
+                module_name, buffer_name = name.rsplit(".", 1)
+                target_module = model.get_submodule(module_name)
+            else:
+                buffer_name = name
+                target_module = model
+            setattr(target_module, buffer_name, fp32_buffers[name])
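Note: a minimal usage sketch of the new helper; the Toy module below is a hypothetical stand-in for a real model, and the with-block mirrors the pattern in finetuning.py:

import torch
from llama_recipes.utils.precision import preserve_fp32_buffers

class Toy(torch.nn.Module):
    # Stand-in for a transformer: one weight plus an fp32 buffer like RoPE's inv_freq.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer("inv_freq", torch.rand(4, dtype=torch.float32))

model = Toy()
with preserve_fp32_buffers(model):
    model.to(torch.bfloat16)

print(model.linear.weight.dtype)  # torch.bfloat16
print(model.inv_freq.dtype)       # torch.float32, restored by the context manager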
