From ea0c5ca12096768df1626c1dec7f98f1abda522b Mon Sep 17 00:00:00 2001
From: YizhenJia
Date: Tue, 24 Dec 2024 10:30:12 +0800
Subject: [PATCH 1/3] [usability] add hymba lora target

---
 src/lmflow/args.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index dbe6fbb3..085d090b 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -238,8 +238,8 @@ class ModelArguments:
         metadata={
             "help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."},
     )
-    lora_target_modules: List[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name",}
+    lora_target_modules: str = field(
+        default=None, metadata={"help": "Model modules to apply LoRA to. Use commas to separate multiple modules."}
     )
     lora_dropout: float = field(
         default=0.1,
@@ -364,6 +364,9 @@ def __post_init__(self):
         if not is_flash_attn_available():
             self.use_flash_attention = False
             logger.warning("Flash attention is not available in the current environment. Disabling flash attention.")
+
+        if self.lora_target_modules is not None:
+            self.lora_target_modules: List[str] = split_args(self.lora_target_modules)
 
 
 @dataclass
@@ -1464,3 +1467,7 @@ class AutoArguments:
 
 def get_pipeline_args_class(pipeline_name: str):
     return PIPELINE_ARGUMENT_MAPPING[pipeline_name]
+
+
+def split_args(args):
+    return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args
\ No newline at end of file
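A quick standalone sketch of the parsing flow this patch introduces: a
comma-separated --lora_target_modules string from the CLI is normalized into a
list of module names in __post_init__. The split_args body is copied verbatim
from the hunk above; the surrounding driver lines are illustrative only.

    from typing import List, Optional

    def split_args(args):
        # Comma-separated string -> list of stripped module names;
        # lists (and None) pass through unchanged.
        return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args

    raw: Optional[str] = "q_proj, v_proj"   # as received from --lora_target_modules 'q_proj, v_proj'
    modules: List[str] = split_args(raw)
    print(modules)  # ['q_proj', 'v_proj']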
From 748a3632916686538789d3c76647c4d79f75ed60 Mon Sep 17 00:00:00 2001
From: YizhenJia
Date: Tue, 24 Dec 2024 10:34:41 +0800
Subject: [PATCH 2/3] [usability] add hymba lora target

---
 src/lmflow/utils/constants.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/lmflow/utils/constants.py b/src/lmflow/utils/constants.py
index f0278985..df4c969e 100644
--- a/src/lmflow/utils/constants.py
+++ b/src/lmflow/utils/constants.py
@@ -386,12 +386,17 @@
 DEFAULT_IM_END_TOKEN = ""
 
 # Lora
-# NOTE: Be careful, when passing lora_target_modules through arg parser, the
-# value should be like'--lora_target_modules q_proj, v_proj \', while specifying
-# here, it should be in list format.
+# NOTE: This works as a mapping for models that the `peft` library doesn't support yet, and will be
+# overridden by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
+# if the model is supported (see hf_model_mixin.py).
+# NOTE: When passing lora_target_modules through the arg parser, the
+# value should be a string, using commas to separate the module names, e.g.
+# "--lora_target_modules 'q_proj, v_proj'".
+# However, when specifying them here, they should be lists.
 LMFLOW_LORA_TARGET_MODULES_MAPPING = {
     'qwen2': ["q_proj", "v_proj"],
     'internlm2': ["wqkv"],
+    'hymba': ["x_proj.0", "in_proj", "out_proj", "dt_proj.0"]
 }
 
 # vllm inference

From 400af44bdb24010722f7669bfbcb02ce69c05abb Mon Sep 17 00:00:00 2001
From: YizhenJia
Date: Tue, 24 Dec 2024 10:36:02 +0800
Subject: [PATCH 3/3] typo fix

---
 src/lmflow/utils/constants.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lmflow/utils/constants.py b/src/lmflow/utils/constants.py
index df4c969e..04ad5696 100644
--- a/src/lmflow/utils/constants.py
+++ b/src/lmflow/utils/constants.py
@@ -387,7 +387,7 @@
 
 # Lora
 # NOTE: This works as a mapping for models that the `peft` library doesn't support yet, and will be
-# overridden by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
+# overwritten by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
 # if the model is supported (see hf_model_mixin.py).
 # NOTE: When passing lora_target_modules through the arg parser, the
 # value should be a string, using commas to separate the module names, e.g.
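To make the precedence described in the NOTE concrete, here is a minimal sketch
of the fallback logic. resolve_lora_target_modules is a hypothetical helper for
illustration only; the real resolution lives in hf_model_mixin.py, which is not
part of this series. peft's built-in mapping wins whenever it already covers the
model type, and the LMFlow mapping only fills the gaps (e.g. hymba).

    from peft.utils.constants import (
        TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    )

    LMFLOW_LORA_TARGET_MODULES_MAPPING = {
        'qwen2': ["q_proj", "v_proj"],
        'internlm2': ["wqkv"],
        'hymba': ["x_proj.0", "in_proj", "out_proj", "dt_proj.0"],
    }

    def resolve_lora_target_modules(model_type: str):
        # Hypothetical helper: peft's mapping takes precedence ...
        if model_type in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
            return list(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type])
        # ... and the LMFlow mapping is the fallback for models peft lacks.
        return LMFLOW_LORA_TARGET_MODULES_MAPPING.get(model_type)

    print(resolve_lora_target_modules('hymba'))
    # ['x_proj.0', 'in_proj', 'out_proj', 'dt_proj.0']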