Add Qwen2-MoE Pipeline Parallel Model Convertor (#272)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Jun 20, 2024
1 parent 96a11ba commit e31375f
Showing 1 changed file with 136 additions and 0 deletions.
@@ -175,6 +175,7 @@ def load_megatron_model(args):
            else:
                raise ValueError
            state_dict[k] = target_v

    elif (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
@@ -235,6 +236,65 @@ def load_megatron_model(args):
                exit()
            state_dict[k] = target_v

    elif (
        args.tensor_model_parallel_size > 1
        and args.pipeline_model_parallel_size > 1
        and args.expert_model_parallel_size > 1
        and args.num_experts % args.expert_model_parallel_size == 0
    ):
        num_layers = args.num_layers // args.pipeline_model_parallel_size
        layers_to_copy = {}
        for tp_rank in range(args.tensor_model_parallel_size):
            for ep_rank in range(args.expert_model_parallel_size):
                for pp_rank in range(args.pipeline_model_parallel_size):
                    # Map rank-local layer names to global layer ids for this pp stage.
                    layer_offset = pp_rank * num_layers
                    for layer in range(num_layers):
                        pp_layer_id = layer + layer_offset
                        layers_to_copy[f"decoder.layers.{layer}"] = pp_layer_id

                    if args.expert_model_parallel_size > 1:
                        checkpoint_name = get_checkpoint_name(model_path, iteration, release, True,
                                                              tp_rank, pp_rank, True, ep_rank)
                    elif args.expert_model_parallel_size == 1:
                        checkpoint_name = get_checkpoint_name(model_path, iteration, release, True,
                                                              tp_rank, pp_rank, False)
                    print(f'load {checkpoint_name}')
                    split_state = torch.load(checkpoint_name, map_location="cpu")['model']
                    for k, v in split_state.items():
                        if 'local_experts' in k and 'norm' not in k:
                            # Rename rank-local expert indices to global expert indices.
                            local_expert_rank = name_to_expert_rank(k)
                            expert_rank = local_expert_rank + num_local_experts * ep_rank
                            k = k.replace(f'local_experts.{local_expert_rank}', f'local_experts.{expert_rank}')
                            mid_state[k].append(v)
                        elif ep_rank == 0:
                            # Non-expert weights are replicated across EP ranks; collect them
                            # once, remapping the layer index to its global id.
                            try:
                                pattern = re.compile(r'\d+')
                                res = pattern.findall(k)
                                k = re.sub(r"decoder.layers.\d+",
                                           "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
                                mid_state[k].append(v)
                            except (IndexError, KeyError):
                                # Keys without a layer index (embeddings, final norm, output layer).
                                mid_state[k].append(v)
        for k, v in mid_state.items():
            if not isinstance(v[0], torch.Tensor) or 'norm' in k or 'router' in k or 'gate' in k:
                target_v = v[0]
            elif 'embedding' in k or 'output_layer' in k:
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                target_v = torch.cat(v, dim=1)
            elif 'linear_qkv.weight' in k:
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key {k}')
            state_dict[k] = target_v

    else:
        raise ValueError('unsupported parallel configuration')
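For reference, here is a minimal, self-contained sketch of the two view-based merge rules used above (the GQA linear_qkv merge and the SwiGLU linear_fc1 merge). All sizes and variable names are made up for illustration and are not part of the convertor:

import torch

# Hypothetical sizes, for illustration only; the real values come from args.
hidden_size = 8
head_dim = 2
num_query_groups = 4                 # total GQA groups in the full model
tp = 2                               # tensor-parallel world size
group_per_split = num_query_groups // tp
q_heads_per_group = 2                # query heads sharing one K/V pair
rows_per_group = (q_heads_per_group + 2) * head_dim   # Q heads + K + V per group

# Full (unsharded) qkv weight, with groups stacked along dim 0.
full = torch.arange(num_query_groups * rows_per_group * hidden_size,
                    dtype=torch.float32).view(-1, hidden_size)

# TP sharding gives each rank a contiguous block of query groups.
viewed = full.view(num_query_groups, rows_per_group, hidden_size)
shards = [viewed[group_per_split * r: group_per_split * (r + 1)].reshape(-1, hidden_size)
          for r in range(tp)]

# The convertor's merge rule: view each shard by group, concatenate along the
# group dim (rank r contributes groups r*group_per_split .. (r+1)*group_per_split - 1),
# then flatten back to 2-D.
merged = torch.cat([x.view(group_per_split, -1, head_dim, hidden_size) for x in shards],
                   dim=0).view(-1, hidden_size)
assert torch.equal(merged, full)

# The SwiGLU fc1 merge follows the same pattern along the other axis: each
# shard holds matching row slices of the gate and up projections stacked on dim 0.
moe_ffn_hidden_size = 6
fc1_full = torch.randn(2 * moe_ffn_hidden_size, hidden_size)
seg = moe_ffn_hidden_size // tp
fc1_shards = [fc1_full.view(2, moe_ffn_hidden_size, hidden_size)[:, seg * r: seg * (r + 1), :]
              .reshape(-1, hidden_size) for r in range(tp)]
fc1_merged = torch.cat([x.view(2, -1, hidden_size) for x in fc1_shards],
                       dim=1).view(-1, hidden_size)
assert torch.equal(fc1_merged, fc1_full)

The group-wise view matters because each qkv shard stores its groups as interleaved [Q heads, K, V] row blocks, and fc1 stacks the gate and up projections on dim 0; merging along the group (or ffn) dimension keeps those layouts intact.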

@@ -590,6 +650,82 @@ def save_mgmodel(mgmodel, args):
                    model_split[k] = target_v
                save_state_dict(args, model_split, checkpoint_name)

    elif (
        args.pipeline_model_parallel_size > 1
        and args.num_experts % args.expert_model_parallel_size == 0
    ):
        for tp_rank in range(args.tensor_model_parallel_size):
            for ep_rank in range(args.expert_model_parallel_size):
                for pp_rank in range(args.pipeline_model_parallel_size):
                    model_split = {}
                    # Map global layer ids to rank-local indices on this pp stage.
                    layer_offset = pp_rank * num_layers
                    layers_to_copy = {}
                    for layer in range(num_layers):
                        pp_layer_id = layer + layer_offset
                        layers_to_copy[f"decoder.layers.{pp_layer_id}"] = layer
                    if args.expert_model_parallel_size > 1:
                        checkpoint_name = get_checkpoint_name(args.save, 0, True, True, tp_rank, pp_rank,
                                                              True, ep_rank)
                    elif args.expert_model_parallel_size == 1:
                        checkpoint_name = get_checkpoint_name(args.save, 0, True, True, tp_rank, pp_rank,
                                                              False)
                    print(f'tensor_parallel & pipeline_parallel & expert_parallel, save model to {checkpoint_name}')
                    for k, v in full_model.items():
                        if check_layer(layers_to_copy, k):
                            pattern = re.compile(r'\d+')
                            res = pattern.findall(k)
                            k = re.sub(r"decoder.layers.\d+",
                                       "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
                        elif not ("word_embeddings" in k or "output_layer" in k or "final_layernorm" in k):
                            continue
                        if not isinstance(v, torch.Tensor):
                            target_v = v
                        elif 'linear_qkv.weight' in k:
                            viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
                            viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
                            target_v = viewed.view(-1, args.hidden_size)
                        elif 'linear_qkv.bias' in k:
                            viewed = v.view(args.num_query_groups, -1, head_dim)
                            viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
                            target_v = viewed.view(-1)
                        elif 'linear_proj' in k:
                            seg = v.shape[1] // args.tensor_model_parallel_size
                            target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
                        elif 'embedding' in k or 'output_layer' in k:
                            seg = v.shape[0] // args.tensor_model_parallel_size
                            target_v = v[seg * tp_rank: seg * (tp_rank + 1)]
                        elif 'local_experts' in k:
                            # Take the expert index from the key name; the first integer in
                            # the key is the (already remapped) layer id, not the expert id.
                            expert_rank = name_to_expert_rank(k)
                            if expert_rank // num_local_experts != ep_rank:
                                continue
                            expert_local_rank = expert_rank % num_local_experts
                            if 'linear_fc1' in k and 'norm' not in k:
                                viewed = v.view(-1, args.moe_ffn_hidden_size, args.hidden_size)
                                seg = args.moe_ffn_hidden_size // args.tensor_model_parallel_size
                                target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
                            elif 'linear_fc2' in k:
                                seg = v.shape[1] // args.tensor_model_parallel_size
                                target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
                            k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
                        elif 'shared_expert' in k and 'gate' not in k:
                            if 'linear_fc1' in k:
                                viewed = v.view(-1, args.shared_moe_ffn_hidden_size, args.hidden_size)
                                seg = args.shared_moe_ffn_hidden_size // args.tensor_model_parallel_size
                                target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
                            elif 'linear_fc2' in k:
                                seg = v.shape[1] // args.tensor_model_parallel_size
                                target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
                        else:
                            target_v = v
                        if "word_embeddings" in k:
                            # Word embeddings belong only to the first pipeline stage.
                            if pp_rank == 0:
                                model_split[k] = target_v
                        elif "output_layer" in k or "final_layernorm" in k:
                            # Output head and final norm belong only to the last stage.
                            if pp_rank == args.pipeline_model_parallel_size - 1:
                                model_split[k] = target_v
                        else:
                            model_split[k] = target_v
                    save_state_dict(args, model_split, checkpoint_name)

    else:
        raise ValueError('pipeline-parallel conversion is not supported for this configuration')
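For orientation, the expert and layer renumbering used in both conversion directions reduces to simple divmod arithmetic. A self-contained sketch with made-up sizes (illustrative only, not part of the convertor):

num_experts = 8
expert_model_parallel_size = 4
num_local_experts = num_experts // expert_model_parallel_size  # 2 experts per EP rank

num_layers_total = 12
pipeline_model_parallel_size = 3
num_layers = num_layers_total // pipeline_model_parallel_size  # 4 layers per pp stage

# Loading maps rank-local expert indices to global ones; saving inverts it.
for ep_rank in range(expert_model_parallel_size):
    for local_expert_rank in range(num_local_experts):
        expert_rank = local_expert_rank + num_local_experts * ep_rank  # local -> global
        assert expert_rank // num_local_experts == ep_rank             # global -> EP rank
        assert expert_rank % num_local_experts == local_expert_rank    # global -> local

# Pipeline layers follow the same pattern with a per-stage offset.
for pp_rank in range(pipeline_model_parallel_size):
    layer_offset = pp_rank * num_layers
    for layer in range(num_layers):
        pp_layer_id = layer + layer_offset      # stage-local -> global layer id
        assert pp_layer_id // num_layers == pp_rank
        assert pp_layer_id % num_layers == layer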
