diff --git a/assets/Hymba_loss.png b/assets/Hymba_loss.png
new file mode 100644
index 000000000..173ea2546
Binary files /dev/null and b/assets/Hymba_loss.png differ
diff --git a/experimental/Hymba/README.md b/experimental/Hymba/README.md
new file mode 100644
index 000000000..3d11018c8
--- /dev/null
+++ b/experimental/Hymba/README.md
@@ -0,0 +1,43 @@
+# Training Hymba with LMFlow
+
+## Hymba
+[GitHub](https://github.com/NVlabs/hymba/tree/main)
+Hymba is a family of small language models (SLMs) featuring a hybrid-head parallel architecture that integrates transformer attention mechanisms with state space models (SSMs) to achieve the best of both worlds: enhanced efficiency and improved performance. In Hymba, attention heads provide high-resolution recall, while SSM heads enable efficient context summarization.
+
+## Preparing the environment
+
+- Using Docker
+
+```bash
+docker pull ghcr.io/tilmto/hymba:v1
+docker run --gpus all -v /home/$USER:/home/$USER -it ghcr.io/tilmto/hymba:v1 bash
+```
+
+- Install LMFlow inside the Docker container
+
+```bash
+git clone https://github.com/OptimalScale/LMFlow.git
+cd LMFlow
+conda create -n lmflow python=3.9 -y
+conda activate lmflow
+conda install mpi4py
+pip install -e .
+```
+
+- Tips
+
+To train the Hymba model, add the following arguments to the `run_finetune.sh` script:
+
+```bash
+--trust_remote_code True
+--bf16
+```
+
+Demo script: [run_finetune_hymba.sh](./run_finetune_hymba.sh)
+
+Training is recommended on A100, H100, or A40 GPUs.
+
+
+## Training Loss
+The training loss curves for `nvidia/Hymba-1.5B-Instruct`, fine-tuned on the `MedMCQA/train` dataset with a learning rate of $5\times 10^{-5}$ over 100 steps using SFT, LoRA, LISA, and DoRA, are shown below:
+![Training Loss](../../assets/Hymba_loss.png)
\ No newline at end of file
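For reference, a typical invocation of the demo script added in the next file might look like the sketch below. This is an assumption about usage rather than part of the patch: it presumes the script is launched from the LMFlow repository root inside the Hymba Docker image (so that `examples/finetune.py` and `configs/` resolve), and the values shown are simply the script's defaults made explicit.

```bash
# Hypothetical usage sketch; every flag below is parsed by run_finetune_hymba.sh.
cd LMFlow
bash experimental/Hymba/run_finetune_hymba.sh \
    --model_name_or_path nvidia/Hymba-1.5B-Instruct \
    --dataset_path MedMCQA/train \
    --output_model_path output_models/hymba_finetune \
    --deepspeed_args "--master_port=11000"
```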
diff --git a/experimental/Hymba/run_finetune_hymba.sh b/experimental/Hymba/run_finetune_hymba.sh
new file mode 100644
index 000000000..37b6b05b5
--- /dev/null
+++ b/experimental/Hymba/run_finetune_hymba.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+# Please run this script under ${project_id} in project directory of
+# https://github.com/shizhediao/llm-ft
+# COMMIT: d5fecf30ba8011067b10cf51fede53a5ab6574e4
+
+# Parses arguments
+model_name_or_path=nvidia/Hymba-1.5B-Instruct
+dataset_path=MedMCQA/train
+output_dir=output_models/finetune
+deepspeed_args="--master_port=11000"
+conversation_template=llama2
+
+# Safety related arguments
+# Hymba ships custom modeling code, so remote code must be trusted.
+trust_remote_code=True
+
+while [[ $# -ge 1 ]]; do
+  key="$1"
+  case ${key} in
+    -m|--model_name_or_path)
+      model_name_or_path="$2"
+      shift
+      ;;
+    -d|--dataset_path)
+      dataset_path="$2"
+      shift
+      ;;
+    -o|--output_model_path)
+      output_dir="$2"
+      shift
+      ;;
+    --conversation_template)
+      conversation_template="$2"
+      shift
+      ;;
+    --deepspeed_args)
+      deepspeed_args="$2"
+      shift
+      ;;
+    --trust_remote_code)
+      trust_remote_code="$2"
+      shift
+      ;;
+    *)
+      echo "error: unknown option \"${key}\"" 1>&2
+      exit 1
+  esac
+  shift
+done
+
+# Finetune
+exp_id=finetune
+project_dir=$(cd "$(dirname $0)"/..; pwd)
+log_dir=${project_dir}/log/${exp_id}
+mkdir -p ${output_dir} ${log_dir}
+
+deepspeed ${deepspeed_args} \
+  examples/finetune.py \
+    --model_name_or_path ${model_name_or_path} \
+    --trust_remote_code ${trust_remote_code} \
+    --dataset_path ${dataset_path} \
+    --output_dir ${output_dir} --overwrite_output_dir \
+    --conversation_template ${conversation_template} \
+    --num_train_epochs 0.01 \
+    --learning_rate 5e-5 \
+    --disable_group_texts 1 \
+    --block_size 256 \
+    --per_device_train_batch_size 1 \
+    --deepspeed configs/ds_config_zero2_no_offload.json \
+    --bf16 \
+    --run_name hymba_finetune \
+    --validation_split_percentage 0 \
+    --logging_steps 1 \
+    --do_train \
+    --gradient_checkpointing 1 \
+    --use_flash_attention 1 \
+    --ddp_timeout 72000 \
+    --save_steps 5000 \
+    --dataloader_num_workers 1 \
+    > >(tee ${log_dir}/train.log) \
+    2> >(tee ${log_dir}/train.err >&2)
+
+
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index 65572eb43..dbe6fbb32 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -203,6 +203,10 @@ class ModelArguments:
             "choices": ["auto", "bfloat16", "float16", "float32"],
         },
     )
+    use_dora: bool = field(
+        default=False,
+        metadata={"help": "Whether to use DoRA (https://github.com/NVlabs/DoRA)."},
+    )
     use_lora: bool = field(
         default=False,
         metadata={"help": "Whether to lora."},
diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py
index 0d2c26671..45f414292 100644
--- a/src/lmflow/models/hf_model_mixin.py
+++ b/src/lmflow/models/hf_model_mixin.py
@@ -277,7 +277,29 @@ def __prepare_peft_config(
                 lora_dropout=model_args.lora_dropout,
                 target_modules=lora_target_modules,
             )
+        if model_args.use_dora:
+            if model_args.lora_target_modules:
+                lora_target_modules = model_args.lora_target_modules
+            else:
+                model_config = self.hf_model_config
+                if hasattr(model_config, "to_dict"):
+                    model_config = model_config.to_dict()
+                if "model_type" not in model_config or not model_config["model_type"]:
+                    logger.warning("It seems that your base model is a custom model, since "
+                                   "model_type is not found in model_config when preparing peft config. "
+                                   "Setting model_type to 'custom' as a fallback.")
+                    model_config["model_type"] = "custom"
+                lora_target_modules = LORA_TARGET_MODULES_MAPPING.get(model_config["model_type"], None)
+            peft_config = LoraConfig(
+                use_dora=True,
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                r=model_args.lora_r,
+                lora_alpha=model_args.lora_alpha,
+                lora_dropout=model_args.lora_dropout,
+                target_modules=lora_target_modules,
+            )
 
         return peft_config
 
 
@@ -307,6 +329,13 @@ def __prepare_model_for_training(
     ):
         assert self.do_train, "To prepare the model for training, please set do_train=True."
         # TODO: change to accelerate
+
+        # Hymba's custom modeling code can trip torch dynamo, so disable graph compilation.
+        if 'hymba' in model_args.model_name_or_path.lower():
+            import torch._dynamo
+            torch._dynamo.config.suppress_errors = True
+            torch._dynamo.config.disable = True
+
         logger.info("Preparing model for training")
         if model_args.model_name_or_path:
             model = hf_auto_model.from_pretrained(
@@ -314,6 +343,7 @@
                 torch_dtype=self.torch_dtype,
                 config=self.hf_model_config,
                 quantization_config=self.quant_config,
+                trust_remote_code=model_args.trust_remote_code,
             )
 
             if model_args.use_qlora:
@@ -333,7 +363,7 @@
                 name for name, buffer in model.named_buffers()
                 if buffer.dtype == torch.bool
             ]
-        if model_args.use_lora:
+        if model_args.use_lora or model_args.use_dora:
             model.enable_input_require_grads()
             if model_args.lora_model_path is not None:
                 # Load model from LoRA weights
diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py
index de5eb6279..34eb948c0 100644
--- a/src/lmflow/pipeline/finetuner.py
+++ b/src/lmflow/pipeline/finetuner.py
@@ -524,6 +524,7 @@ def __init__(self, n_layers, interval_steps, model):
             'MixtralForCausalLM': 'model.model.layers',
             'GemmaForCausalLM': 'model.model.layers',
             'GPT2LMHeadModel': 'model.transformer.h',
+            'HymbaForCausalLM': 'model.model.layers',
         }
         model_class_name = self.model.__class__.__name__
         if model_class_name in class_to_layers_map:
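A note on the new `use_dora` option: the demo script's own parser rejects unknown flags, so `--use_dora` cannot be passed through `run_finetune_hymba.sh` directly; it has to be added to the `deepspeed examples/finetune.py` invocation inside the script, or the script extended to forward it. The sketch below is an assumption rather than part of the patch: it presumes `examples/finetune.py` exposes the `ModelArguments` fields as command-line flags in the same way the existing options are exposed, and the rank/alpha values are illustrative only.

```bash
# Hypothetical sketch: launch finetuning with the new DoRA path enabled.
# Assumes the LMFlow repository root as the working directory.
deepspeed --master_port=11000 \
  examples/finetune.py \
    --model_name_or_path nvidia/Hymba-1.5B-Instruct \
    --trust_remote_code True \
    --dataset_path MedMCQA/train \
    --output_dir output_models/hymba_dora --overwrite_output_dir \
    --conversation_template llama2 \
    --use_dora True \
    --lora_r 8 \
    --lora_alpha 32 \
    --learning_rate 5e-5 \
    --bf16 \
    --do_train \
    --deepspeed configs/ds_config_zero2_no_offload.json
```

When `--use_dora True` is set, `__prepare_peft_config` builds a `LoraConfig(use_dora=True, ...)`, so the usual LoRA knobs (`--lora_r`, `--lora_alpha`, `--lora_dropout`, `--lora_target_modules`) govern the DoRA adapters as well.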