add vit training with TP/PP #146

Open
wants to merge 1 commit into base: main_before_rebase
49 changes: 49 additions & 0 deletions examples/pretrain_vit_distributed.sh
@@ -0,0 +1,49 @@
#!/bin/bash

GPUS_PER_NODE=2
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=/workspace/dataset_image
CHECKPOINT_PATH=./ckpt

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.run $DISTRIBUTED_ARGS \
pretrain_vit.py \
--num-classes 200 \
--num-layers 6 \
--hidden-size 128 \
--num-attention-heads 8 \
--kv-channels 64 \
--ffn-hidden-size 3072 \
--encoder-seq-length 197 \
--decoder-seq-length 128 \
--micro-batch-size 128 \
--global-batch-size 1024 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 1 \
--lr-decay-iters 1000000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--data-impl mmap \
--split 949,50,1 \
--lr 0.0001 \
--min-lr 0.00001 \
--lr-decay-style linear \
--lr-warmup-fraction .01 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--tensorboard-dir vit-nl-hs-nh \
--fp16
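
For reference, a minimal sketch (not part of this PR; img_dim/patch_dim are assumed ViT defaults of a 224-pixel image with 16-pixel patches, not set explicitly above) of how the launch arithmetic works out: both GPUs are consumed by tensor parallelism, leaving a single data-parallel replica, and --encoder-seq-length 197 is the (img/patch)^2 + 1 patch sequence.

# Sketch only: sanity-check the parallelism and sequence-length arithmetic.
gpus_per_node, nnodes = 2, 1
tp, pp = 2, 1                                   # --tensor/--pipeline-model-parallel-size
world_size = gpus_per_node * nnodes             # matches WORLD_SIZE in the script
assert world_size % (tp * pp) == 0
dp = world_size // (tp * pp)                    # data-parallel replicas -> 1
grad_accum = 1024 // (128 * dp)                 # global / (micro * dp) -> 8 accumulation steps
print(world_size, dp, grad_accum)               # 2 1 8

img_dim, patch_dim = 224, 16                    # assumed defaults
seq_length = (img_dim // patch_dim) ** 2 + 1    # 14*14 patches + [CLS]
assert seq_length == 197                        # --encoder-seq-length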
109 changes: 109 additions & 0 deletions examples/run_deepspeed_vit_zero2.sh
@@ -0,0 +1,109 @@
#!/bin/bash
set -ex

#BASE_PATH=/vc_data/Megatron-LM/data
#DATA_PATH=${BASE_PATH}/indexed_datasets/megatron
DATA_PATH=/workspace/dataset_image
DS_CONFIG=ds_config.json

TP=1
PP=1
NLAYERS=24
HIDDEN=512

GLOBAL_BATCH=4
MICRO_BATCH=2

ZERO_STAGE=2
OFFLOAD_DEVICE="cpu"
CPU_OPTIM=" --cpu-optimizer"

OUTPUT_DIR=ds_z${ZERO_STAGE}_nl${NLAYERS}_hs${HIDDEN}_gb${GLOBAL_BATCH}_mb${MICRO_BATCH}
#OUTPUT_DIR=baseline_nl${NLAYERS}_hs${HIDDEN}_gb${GLOBAL_BATCH}_mb${MICRO_BATCH}
mkdir -p $OUTPUT_DIR

cat <<EOT > $DS_CONFIG
{
  "train_batch_size": $GLOBAL_BATCH,
  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
  "steps_per_print": 10,

  "zero_optimization": {
    "stage": $ZERO_STAGE,
    "stage3_max_live_parameters": 3e9,
    "stage3_max_reuse_distance": 3e9,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_prefetch_bucket_size": 5e7,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_bucket_size": 90000000,
    "sub_group_size": 1e8,
    "offload_optimizer": {
      "device": "$OFFLOAD_DEVICE",
      "buffer_count": 4,
      "pipeline_read": false,
      "pipeline_write": false,
      "pin_memory": true
    }
  },

  "fp16": {
    "enabled": true,
    "initial_scale_power": 12
  },
  "wall_clock_breakdown": true,
  "zero_allow_untested_optimizer": false,
  "aio": {
    "block_size": 1048576,
    "queue_depth": 16,
    "single_submit": false,
    "overlap_events": true,
    "thread_count": 2
  }
}
EOT

export NCCL_DEBUG=warn

ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --no-pipeline-parallel ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
export NCCL_SOCKET_IFNAME=eth0

deepspeed pretrain_vit.py \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--num-layers $NLAYERS \
--hidden-size $HIDDEN \
--num-attention-heads 16 \
--seq-length 256 \
--loss-scale 12 \
--max-position-embeddings 1024 \
--data-impl mmap \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--train-iters 1000 \
--lr 6.0e-5 \
--min-lr 6.0e-6 \
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 40 \
--eval-interval 1 \
--data-path $DATA_PATH \
--save-interval 1000 \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.006 \
--fp16 \
--checkpoint-activations \
--tensorboard-dir $OUTPUT_DIR \
$CPU_OPTIM $ds_args \
--exit-interval 5000 | tee ${OUTPUT_DIR}/output.log
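
For readers unfamiliar with how the generated ds_config.json is consumed: inside pretrain_vit.py the config reaches DeepSpeed through Megatron's training loop via the --deepspeed/--deepspeed_config flags above. The sketch below is illustrative only, with a toy model and DeepSpeedCPUAdam standing in for what --cpu-optimizer selects; it is not the PR's actual code path, but shows the same config driving deepspeed.initialize with ZeRO stage 2 and CPU optimizer offload.

# Illustrative sketch, not the script's real entry point. Assumes the
# ds_config.json written by the heredoc above is present in the working dir.
import torch
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam   # the optimizer --cpu-optimizer maps to

model = torch.nn.Linear(512, 512)                 # toy stand-in for the ViT
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=6.0e-5, weight_decay=0.1)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="ds_config.json",                      # ZeRO-2 + CPU offload + fp16
)
# engine.backward(loss) / engine.step() then partition gradients and optimizer
# state per ZeRO stage 2, keeping the optimizer state on CPU as configured.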

12 changes: 10 additions & 2 deletions megatron/model/transformer.py
@@ -584,14 +584,22 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer):
    If no mask is provided, the module will query `self._args.attn_mask`
    for the mask and only return `super().forward(...)`
    """
    def __init__(self, isvit=False, **kwargs):
        super(ParallelTransformerLayerPipe, self).__init__(**kwargs)
        self.isvit = isvit

    def forward(self, inputs, **kwargs):
        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
        if torch.is_tensor(inputs) or len(inputs) == 1:
            # No attention mask forwarded, search for args.attn_mask
            if not hasattr(self, '_args'):
                self._args = get_args()
            if not self.isvit:
                hidden_states, attention_mask = inputs, self._args.attn_mask
            else:
                # ViT attention is unmasked, so no mask is fetched from args.
                hidden_states, attention_mask = inputs, None
            # HACK: currently MoE model does not support pipeline parallel, so
            # here we just ignore the moe_loss returned by forward()
            return super().forward(hidden_states, attention_mask, **kwargs)[0]
        elif len(inputs) == 2:
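
The new isvit flag only changes what happens when a bare hidden-states tensor arrives without a mask: text models fall back to the attention mask cached in args.attn_mask, while the ViT passes None so the layer runs unmasked attention. Below is a hedged sketch of how the flag would be threaded through DeepSpeed's pipeline layer specs; the spec list, init-method helpers, and layer count are placeholders, since the ViT pipe model itself is not part of this diff.

# Assumed wiring, for illustration only; the actual ViT pipeline model that
# passes isvit=True is not shown in this diff.
from deepspeed.pipe import LayerSpec
from megatron.model.transformer import ParallelTransformerLayerPipe
from megatron.model.utils import init_method_normal, scaled_init_method_normal

num_layers, init_std = 6, 0.02                    # placeholder hyperparameters
specs = [
    LayerSpec(
        ParallelTransformerLayerPipe,
        isvit=True,                               # new kwarg: skip args.attn_mask, use mask=None
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, num_layers),
        layer_number=i,
    )
    for i in range(num_layers)
]
# LayerSpec defers module construction, so building the spec list does not yet
# require an initialized Megatron args/parallel state.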
148 changes: 148 additions & 0 deletions megatron/model/vit_embedding.py
@@ -0,0 +1,148 @@
from megatron import get_args
from megatron.model.utils import init_method_normal
import math
import torch
import torch.nn.functional as F
import einops
from .module import MegatronModule
from deepspeed.accelerator import get_accelerator
from megatron.mpu.utils import ClsUtility
from megatron.mpu.initialize import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from megatron.mpu.mappings import reduce_from_tensor_model_parallel_region
from megatron.mpu.layers import ColumnParallelLinear

def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Load-state-dict pre-hook: if the checkpoint's position-embedding table
    does not match the current (img_dim // patch_dim)^2 + 1 sequence length,
    bilinearly interpolate the patch grid while keeping the [CLS] slot as-is."""
    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    num_patches = num_patches_per_dim ** 2
    seq_length = num_patches + 1
    hidden_size = args.hidden_size

    key = prefix + "weight"
    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        assert input_param.shape[1] == hidden_size
        if input_param.shape[0] != seq_length:
            # update input_param and load it to state_dict[key]
            num_tok_input = input_param.shape[0] - 1
            num_tok_new = seq_length - 1
            input_param_tok, input_param_grid = (
                input_param[:1, :],
                input_param[1:, :],
            )

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = int(math.sqrt(num_tok_new))

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = gs_new / gs_input

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size
            input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
            assert (
                input_param.shape[0] == seq_length
                and input_param.shape[1] == hidden_size
            )

        state_dict[key] = input_param

class VitEmbedding(MegatronModule):
    """Patch + position embeddings for the ViT: flattens image patches, projects
    them with a column-parallel linear layer, prepends a [CLS] token and adds
    learned position embeddings."""

    def __init__(self):
        super(VitEmbedding, self).__init__(share_word_embeddings=False)
        args = get_args()
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim

        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        # cls_token
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()

        # Linear encoder
        self.linear_encoder = ColumnParallelLinear(
            self.flatten_dim, self.hidden_size, gather_output=True
        )

        # embedding
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        init_method_normal(args.init_method_std)(
            self.position_embeddings.weight
        )
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).to(get_accelerator().device_name())

        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook
        )

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, x):
        # [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c]
        x = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert x.dtype == torch.half
        x, _ = self.linear_encoder(x)
        # gather_output=True already gathered the full hidden dimension across
        # the tensor-model-parallel GPUs, so no extra reduce is needed here.

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.position_embeddings(self.position_ids)
        x = self.embedding_dropout(x)
        return x


class VitEmbeddingPipe(VitEmbedding):

    def forward(self, inputs, **kwargs):
        if not hasattr(self, '_args'):
            self._args = get_args()

        embeddings = super().forward(inputs)
        return embeddings

    @property
    def linear_encoder_weight(self):
        """Easy accessor for the DeepSpeed pipeline engine to tie embeddings across stages."""
        return self.linear_encoder.weight
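
A standalone sketch of the patch arithmetic implemented by VitEmbedding.forward. The shapes are assumptions for illustration (224-pixel RGB images with 16-pixel patches), chosen to match the --encoder-seq-length 197 used in examples/pretrain_vit_distributed.sh.

# Sketch only: shows what the einops.rearrange in VitEmbedding.forward produces
# before the column-parallel linear encoder and the [CLS]/position embeddings.
import torch
import einops

b, c, img_dim, patch_dim = 8, 3, 224, 16          # assumed example shapes
x = torch.randn(b, c, img_dim, img_dim)
patches = einops.rearrange(
    x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_dim, p2=patch_dim
)
print(patches.shape)                              # torch.Size([8, 196, 768])
# 196 = (224 // 16) ** 2 patches, each flattened to 16 * 16 * 3 = 768 values;
# prepending the CLS token gives the sequence length 196 + 1 = 197.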