Add Floating Point Control for Qwen2 Model Convertor (#274)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Jun 21, 2024
1 parent b9a9e7b commit 4c94f2e
Showing 6 changed files with 59 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/deepspeed/ds_train_huggingface_finetune.py
@@ -417,7 +417,7 @@ def main():
model = AutoModelForCausalLM.from_pretrained(
args.load,
from_tf=False,
torch_dtype=torch.float16,
torch_dtype=torch.float16 if args.fp16 else torch.bfloat16,
trust_remote_code=True
)

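For context, the change above makes the fine-tuning script load the Hugging Face model in the dtype selected by its precision flag instead of hard-coding fp16. A minimal sketch of the same pattern, assuming an argparse-style `--fp16` switch comparable to the `args.fp16` used in the diff:

```python
import argparse

import torch
from transformers import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument("--load", required=True, help="checkpoint path to load")
parser.add_argument("--fp16", action="store_true")  # illustrative flag, mirrors args.fp16 in the diff
args = parser.parse_args()

# fp16 when requested, otherwise bf16, matching the patched line above
dtype = torch.float16 if args.fp16 else torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
    args.load,
    from_tf=False,
    torch_dtype=dtype,
    trust_remote_code=True,
)
```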
10 changes: 7 additions & 3 deletions examples/qwen2/README.md
@@ -51,9 +51,10 @@ TARGET_CKPT_PATH=$3 # target path
TP=$4 # model parallel size
PP=$5 # pipeline parallel size
EP=$6 # expert parallel size
USE_TE=$7 # whether to use Transformer Engine for the model
mg2hf=$8 # whether to run the mcore-to-HF conversion
HG_CKPT_PATH=$9 # path to the HF checkpoint
PR=$7 # conversion precision
USE_TE=$8 # whether to use Transformer Engine for the model
mg2hf=$9 # whether to run the mcore-to-HF conversion
HG_CKPT_PATH=${10} # path to the HF checkpoint
```


@@ -122,6 +123,7 @@ sh hf2mcore_qwen2_convertor.sh \
1 \
1 \
1 \
fp32 \
true \
false
```
@@ -233,6 +235,7 @@ A14B \
4 \
1 \
4 \
fp32 \
true \
false
```
@@ -310,6 +313,7 @@ bash hf2mcore_qwen2_convertor.sh \
1 \
1 \
1 \
fp32 \
true \
true \
/mnt/qwen-ckpts/Qwen2-0.5B
18 changes: 13 additions & 5 deletions megatron_patch/model/qwen2/transformer_block.py
@@ -188,12 +188,20 @@ def build_layer(layer_spec, layer_number):
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])

if self.post_process and self.post_layer_norm:
use_te = self.config.transformer_impl == "transformer_engine"
# Final layer norm before output.
self.final_layernorm = Qwen2RMSNorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
if use_te:
self.final_layernorm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = Qwen2RMSNorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

def _get_layer(self, layer_number: int):
return self.layers[layer_number]
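The block above now builds the final layer norm with `TENorm` when `transformer_impl == "transformer_engine"` and falls back to `Qwen2RMSNorm` otherwise. Both are RMS layer norms; a minimal plain-PyTorch sketch of the computation the two implementations are expected to share here (not the patch's actual classes):

```python
import torch


class RMSNormSketch(torch.nn.Module):
    """Illustrative RMSNorm: x * rsqrt(mean(x^2) + eps) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps  # corresponds to config.layernorm_epsilon above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # compute the norm in fp32 for stability, then cast back to the input dtype
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(variance + self.eps)
        return (self.weight * x_normed).to(x.dtype)
```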
2 changes: 2 additions & 0 deletions megatron_patch/model/qwen2/transformer_config.py
@@ -18,6 +18,8 @@
@dataclass
class Qwen2TransformerConfig(TransformerConfig):

transformer_impl: str = 'transformer_engine'

moe_ffn_hidden_size: int = None

shared_moe_ffn_hidden_size: int = None
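The new `transformer_impl` field defaults to `'transformer_engine'`, so callers that never set it keep the TE norm path, while any other value takes the `Qwen2RMSNorm` fallback shown above. A small sketch of that behaviour with a stand-in dataclass (not the real config class; the value `"local"` is illustrative):

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    # same default as the patched Qwen2TransformerConfig field
    transformer_impl: str = "transformer_engine"


cfg_default = ConfigSketch()
cfg_local = ConfigSketch(transformer_impl="local")

for cfg in (cfg_default, cfg_local):
    use_te = cfg.transformer_impl == "transformer_engine"
    print(cfg.transformer_impl, "-> TENorm" if use_te else "-> Qwen2RMSNorm")
```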
19 changes: 15 additions & 4 deletions hf2mcore_qwen2_convertor.sh
@@ -11,9 +11,10 @@ TARGET_CKPT_PATH=$3
TP=$4
PP=$5
EP=$6
USE_TE=$7
MG2HF=$8
HF_CKPT_PATH=$9
PR=$7
USE_TE=$8
MG2HF=$9
HF_CKPT_PATH=${10}

CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
MEGATRON_PATH=$( dirname $(dirname $( dirname ${CURRENT_DIR})))
@@ -148,6 +149,16 @@ elif [ $USE_TE = false ]; then
"
fi

if [ $PR = fp16 ]; then
pr_options=" \
--fp16"

elif [ $PR = bf16 ]; then
pr_options=" \
--bf16"

fi


DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

@@ -158,7 +169,6 @@ torchrun ${DISTRIBUTED_ARGS} hf2mcore_qwen2_dense_and_moe_gqa.py \
--target-pipeline-model-parallel-size ${PP} \
--micro-batch-size 1 \
--save-interval 1 \
--bf16 \
--swiglu \
--num-layers ${NUM_HIDDEN_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
@@ -185,6 +195,7 @@ torchrun ${DISTRIBUTED_ARGS} hf2mcore_qwen2_dense_and_moe_gqa.py \
${moe_options} \
${te_options} \
${convert_options} \
${pr_options} \
${cpu_options}


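The `pr_options` block above emits `--fp16` or `--bf16`, and for any other value, such as the `fp32` used in the README examples, it adds no precision flag at all, so the convertor runs at its default precision now that the hard-coded `--bf16` has been removed from the torchrun command. A one-function Python sketch of the same mapping (the helper name is illustrative):

```python
def precision_flags(pr: str) -> list:
    """Mirror the shell script's pr_options: fp16/bf16 add a flag, anything else adds none."""
    return {"fp16": ["--fp16"], "bf16": ["--bf16"]}.get(pr, [])


print(precision_flags("fp16"))  # ['--fp16']
print(precision_flags("bf16"))  # ['--bf16']
print(precision_flags("fp32"))  # [] -> convertor runs without --fp16/--bf16
```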
27 changes: 21 additions & 6 deletions hf2mcore_qwen2_dense_and_moe_gqa.py
@@ -26,6 +26,8 @@
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

def add_model_args(parser):

@@ -298,16 +300,18 @@ def load_megatron_model(args):
else:
raise ValueError('not support yet')

model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
return model


def convert_checkpoint_from_megatron_to_transformers(mgmodel, hfmodel, args):

if args.fp16:
mgmodel = mgmodel.float16()
mgmodel = mgmodel.half()
hfmodel = hfmodel.half()
elif args.bf16:
mgmodel = mgmodel.bfloat16()
hfmodel = hfmodel.bfloat16()

num_query_groups = args.num_query_groups
hidden_size = args.hidden_size
@@ -376,9 +380,11 @@ def convert_checkpoint_from_megatron_to_transformers(mgmodel, hfmodel, args):
def convert_checkpoint_from_transformers_to_megatron(hfmodel, mgmodel, args):

if args.fp16:
mgmodel = mgmodel.float16()
mgmodel = mgmodel.half()
hfmodel = hfmodel.half()
elif args.bf16:
mgmodel = mgmodel.bfloat16()
hfmodel = hfmodel.bfloat16()

assert args.num_query_groups >= args.target_tensor_model_parallel_size

@@ -828,25 +834,28 @@ def print_output_hook(module, args, kwargs, output, layer_idx, mode):
with_kwargs=True)


input_ids = torch.tensor([[1, 2, 3]]).long().cuda()
input_ids = torch.tensor([[151644, 8506, 22564, 27608, 75188, 4344, 121395, 61991, 79554, 36689]]).long().cuda()
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
print(hfmodel)
print(mgmodel)
is_oom = False
with torch.inference_mode():
try:
hfmodel.cuda()
hfmodel(input_ids=input_ids)
hflogits = hfmodel(input_ids=input_ids).logits
except torch.cuda.OutOfMemoryError:
print('oom for huggingface model forward')
is_oom = True
hfmodel.cpu()
del hfmodel

with torch.inference_mode():
try:
mgmodel.cuda()
mgmodel(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
mglogits = mgmodel(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
except torch.cuda.OutOfMemoryError:
print('oom for megatron model forward')
is_oom = True
mgmodel.cpu()
del mgmodel

@@ -860,6 +869,12 @@ def print_output_hook(module, args, kwargs, output, layer_idx, mode):
diff_max = (hfv - mgv).abs().max()
print(f'layer:{idx}, {k}, diff: {same_num}, diff>{epsilon}:[{diff_num}/{hfv.numel()}] diff_max:{diff_max}')

if not is_oom:
same_num = (hflogits != mglogits).sum()
diff_num = ((hflogits - mglogits) > epsilon).sum()
diff_max = (hflogits - mglogits).abs().max()
print(f'logits: {same_num}, diff>{epsilon}:[{diff_num}/{hflogits.numel()}] diff_max:{diff_max}')


def add_extra_args(parser):
parser = get_patch_args(parser)
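The forward-check code above now disables TF32 and the fused SDP kernels, runs both models on a longer, non-trivial token sequence, and compares the resulting logits element-wise instead of only per-layer activations. A compact sketch of that final comparison step (function name, tolerance, and tensor shapes are illustrative):

```python
import torch


def report_logits_diff(hflogits: torch.Tensor, mglogits: torch.Tensor, epsilon: float = 1e-5):
    """Count elements whose HF/Megatron logits differ by more than epsilon and report the max gap."""
    diff = (hflogits - mglogits).abs()
    num_diff = (diff > epsilon).sum().item()
    print(f"diff>{epsilon}: [{num_diff}/{hflogits.numel()}], diff_max: {diff.max().item():.6e}")


# usage sketch with stand-in tensors; shapes are arbitrary
a = torch.randn(1, 10, 4096)
b = a + 1e-7 * torch.randn_like(a)
report_logits_diff(a, b)
```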
