[tool]GQA convert support #454

Merged
merged 2 commits on Dec 18, 2024
4 changes: 2 additions & 2 deletions examples_deepspeed/finetune_hf_llama/README.md
@@ -10,9 +10,9 @@ The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggin

#### 1. Converting Hugging Face Model Weights to Megatron-Deepspeed Model
```bash
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds
```
This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script.
This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script. `convert_mds2hf` converts a Megatron-Deepspeed model back into the Hugging Face format.
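For example, a fine-tuned checkpoint can be converted back through the same script entry point (a usage sketch based on the `convert_mds2hf` branch added to `finetune_llama.sh` in this PR):

```bash
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_mds2hf
```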

#### 2. Fine-tuning Process
```bash
8 changes: 7 additions & 1 deletion examples_deepspeed/finetune_hf_llama/ds_config.json
@@ -1,5 +1,11 @@
{
"train_batch_size" : 256,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 1
"steps_per_print": 100,
"zero_optimization": {
"stage": 0
},
"bf16": {
"enabled": true
}
}
5 changes: 5 additions & 0 deletions examples_deepspeed/finetune_hf_llama/ds_config_empty.json
@@ -0,0 +1,5 @@
{
"train_batch_size" : 256,
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 100
}
10 changes: 9 additions & 1 deletion examples_deepspeed/finetune_hf_llama/finetune_llama.sh
@@ -43,6 +43,13 @@ cat <<EOT > $DS_CONFIG
}
EOT

if [ "$1" = "convert_hf2mds" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
elif [ "$1" = "convert_mds2hf" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
else
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config.json"
fi

covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \
--hf-ckpt-num-shards 2 \
@@ -69,6 +76,7 @@ comm_args="--tensor-model-parallel-size $TP \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
--finetune \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--attention-dropout 0 \
--hidden-dropout 0 \
@@ -97,7 +105,7 @@ comm_args="--tensor-model-parallel-size $TP \
--zero-stage 0 \
--tokenizer-type HFTokenizer \
--tokenizer-model $HF_LLAMA_PATH \
--deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \
--deepspeed_config $DS_CONFIG_PATH \
--deepspeed \
--distributed-backend nccl \
--num-workers 0 \
66 changes: 39 additions & 27 deletions tools/hf2megads_weight_converter.py
@@ -193,28 +193,43 @@ def _qkv_refactor(self, pname, p, hf_layer):
wk = self.hf_model[hf_wk_name]
wv = self.hf_model[hf_wv_name]

hidden_size = wq.shape[0]
per_partition_size, start_index, end_index = compute_partition_range(
hidden_size, self.tp_rank, self.tp_size)
hidden_size_per_attention_head = divide(hidden_size,
query_hidden_size = wq.shape[0]
kv_hidden_size = wk.shape[0]

per_partition_size, start_qindex, end_index = compute_partition_range(
query_hidden_size, self.tp_rank, self.tp_size)
_, start_kvindex, _ = compute_partition_range(
kv_hidden_size, self.tp_rank, self.tp_size)

hidden_size_per_attention_head = divide(query_hidden_size,
self.config.num_attention_heads)
num_attention_heads_per_partition = divide(self.config.num_attention_heads,
self.tp_size)

new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype)
num_kv_heads_per_partition = divide(self.config.num_key_value_heads,
self.tp_size)
qkv_size = (num_attention_heads_per_partition + 2 * num_kv_heads_per_partition) * hidden_size_per_attention_head
num_qheads_per_group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
num_groups = divide(num_attention_heads_per_partition, num_qheads_per_group)
new_w = torch.zeros((qkv_size, wq.shape[1]), dtype=wq.dtype)

# pack one KV group per iteration: num_qheads_per_group query heads, then one key head and one value head
for i in range(num_groups):
    query_current_index = start_qindex + i * num_qheads_per_group * hidden_size_per_attention_head
    query_next_index = query_current_index + num_qheads_per_group * hidden_size_per_attention_head
    kv_current_index = start_kvindex + i * hidden_size_per_attention_head
    kv_next_kvindex = kv_current_index + hidden_size_per_attention_head

    new_w_index = i * (num_qheads_per_group + 2) * hidden_size_per_attention_head

for i in range(num_attention_heads_per_partition):
current_index = start_index + i * hidden_size_per_attention_head
next_index = current_index + hidden_size_per_attention_head
new_w_index = i * (3 * hidden_size_per_attention_head)
new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \
    new_w[new_w_index:new_w_index + (num_qheads_per_group + 2) * hidden_size_per_attention_head, :] = \
        torch.cat([
            wq[query_current_index:query_next_index, :],
            wk[kv_current_index:kv_next_kvindex, :],
            wv[kv_current_index:kv_next_kvindex, :]
        ], dim=0)

self.record_mapping_info(
f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}"
f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{query_current_index}:{query_next_index},:] of q,k,v{wq.shape}"
)
return new_w
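For reference, here is a minimal standalone sketch of the per-group packing the loop above produces, assuming `tp_size == 1` and hypothetical toy dimensions (8 query heads, 4 KV heads, head size 2); it is an illustration, not part of the converter:

```python
import torch

# Toy GQA dimensions (assumed for illustration only).
num_heads, num_kv_heads, head_dim = 8, 4, 2
hidden = num_heads * head_dim                       # 16
wq = torch.randn(num_heads * head_dim, hidden)      # HF q_proj weight, [16, 16]
wk = torch.randn(num_kv_heads * head_dim, hidden)   # HF k_proj weight, [8, 16]
wv = torch.randn(num_kv_heads * head_dim, hidden)   # HF v_proj weight, [8, 16]

group = num_heads // num_kv_heads                   # query heads per KV head
blocks = []
for g in range(num_kv_heads):
    q_rows = wq[g * group * head_dim:(g + 1) * group * head_dim]  # `group` query heads
    k_rows = wk[g * head_dim:(g + 1) * head_dim]                  # one key head
    v_rows = wv[g * head_dim:(g + 1) * head_dim]                  # one value head
    blocks.append(torch.cat([q_rows, k_rows, v_rows], dim=0))

# Packed layout: [q_1 .. q_group, k, v] repeated once per KV head.
qkv = torch.cat(blocks, dim=0)
assert qkv.shape == ((group + 2) * num_kv_heads * head_dim, hidden)
```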

@@ -383,17 +398,18 @@ def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer):
hidden_size = oldshape[-1]
hidden_size_per_attention_head = divide(hidden_size,
self.config.num_attention_heads)
num_attention_heads_per_partition = divide(self.config.num_attention_heads,
self.tp_size)
newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)
# MHA & GQA
group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
newshape = (self.config.num_key_value_heads, group + 2, hidden_size_per_attention_head, hidden_size)
ds_w_out = ds_w_all_rank.reshape(*newshape)
self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))
query_weight, key_weight, value_weight = torch.split(ds_w_out, [group, 1, 1], dim=1)
self.hf_dict[hf_q_name] = copy.deepcopy(query_weight.reshape(-1, hidden_size))
self.hf_dict[hf_k_name] = copy.deepcopy(key_weight.reshape(-1, hidden_size))
self.hf_dict[hf_v_name] = copy.deepcopy(value_weight.reshape(-1, hidden_size))
del query_weight, key_weight, value_weight
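A matching sketch of the reverse split used above when exporting back to Hugging Face, with the same hypothetical toy dimensions and `tp_size` of 1 (the real code operates on the gathered `ds_w_all_rank` tensor):

```python
import torch

# Same toy GQA dimensions as in the packing sketch (assumed for illustration).
num_heads, num_kv_heads, head_dim = 8, 4, 2
hidden = num_heads * head_dim
group = num_heads // num_kv_heads

# Packed QKV weight in the per-group [q_1 .. q_group, k, v] layout.
qkv = torch.randn((group + 2) * num_kv_heads * head_dim, hidden)

# Reshape so dim 1 indexes the heads inside one KV group, then split into Q/K/V.
ds_w_out = qkv.reshape(num_kv_heads, group + 2, head_dim, hidden)
q, k, v = torch.split(ds_w_out, [group, 1, 1], dim=1)

wq = q.reshape(-1, hidden)   # [num_heads * head_dim, hidden]
wk = k.reshape(-1, hidden)   # [num_kv_heads * head_dim, hidden]
wv = v.reshape(-1, hidden)   # [num_kv_heads * head_dim, hidden]
assert wq.shape[0] == num_heads * head_dim and wk.shape[0] == num_kv_heads * head_dim
```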


def transform_from_megads_to_hf(self):
use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False

for pname, p in self.ds_model.named_parameters():
if pname in [
@@ -411,11 +427,7 @@ def transform_from_megads_to_hf(self):
subname = mobj.group(2)
hf_layer = layer_num - self.offset_num
if subname in ["self_attention.query_key_value.weight"]:
if not use_gqa:
self._qkv_refactor_to_hf(pname, p, hf_layer)
else:
#TODO(billishyahao): Not impl yet ...
assert False
self._qkv_refactor_to_hf(pname, p, hf_layer)
elif subname in ["mlp.dense_h_to_4h.weight"]:
self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer)
elif subname in [