Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[usability] Add hymba lora target (#924) #925

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/lmflow/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ class ModelArguments:
metadata={
"help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."},
)
lora_target_modules: List[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name",}
lora_target_modules: str = field(
default=None, metadata={"help": "Model modules to apply LoRA to. Use comma to separate multiple modules."}
)
lora_dropout: float = field(
default=0.1,
Expand Down Expand Up @@ -364,6 +364,9 @@ def __post_init__(self):
if not is_flash_attn_available():
self.use_flash_attention = False
logger.warning("Flash attention is not available in the current environment. Disabling flash attention.")

if self.lora_target_modules is not None:
self.lora_target_modules: List[str] = split_args(self.lora_target_modules)


@dataclass
Expand Down Expand Up @@ -1464,3 +1467,7 @@ class AutoArguments:

def get_pipeline_args_class(pipeline_name: str):
return PIPELINE_ARGUMENT_MAPPING[pipeline_name]


def split_args(args):
return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args
11 changes: 8 additions & 3 deletions src/lmflow/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,17 @@
DEFAULT_IM_END_TOKEN = "<im_end>"

# Lora
# NOTE: Be careful, when passing lora_target_modules through arg parser, the
# value should be like'--lora_target_modules q_proj, v_proj \', while specifying
# here, it should be in list format.
# NOTE: This work as a mapping for those models that `peft` library doesn't support yet, and will be
# overwritten by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
# if the model is supported (see hf_model_mixin.py).
# NOTE: When passing lora_target_modules through arg parser, the
# value should be a string. Using commas to separate the module names, e.g.
# "--lora_target_modules 'q_proj, v_proj'".
# However, when specifying here, they should be lists.
LMFLOW_LORA_TARGET_MODULES_MAPPING = {
'qwen2': ["q_proj", "v_proj"],
'internlm2': ["wqkv"],
'hymba': ["x_proj.0", "in_proj", "out_proj", "dt_proj.0"]
}

# vllm inference
Expand Down
Loading