diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
index 6f554ca221..83b2adddc5 100644
--- a/comps/finetuning/README.md
+++ b/comps/finetuning/README.md
@@ -114,7 +114,42 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
   }'
 ```
 
-#### 3.2.2 Reranking Model Training
+#### 3.2.2 Instruction Tuning with SQFT's Neural Low-Rank Adapter Search (NLS)
+
+In addition to traditional fine-tuning, you can use SQFT's NLS to fine-tune your model.
+More details about SQFT can be found in [this paper](https://aclanthology.org/2024.findings-emnlp.749.pdf).
+Please follow the additional installation requirements [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#-start-the-nls-microservice-with-python).
+Use the following command to launch a finetuning job with the NLS algorithm:
+
+```bash
+# create a fine-tuning job with NLS
+# Max LoRA rank: 16
+# LoRA target modules -> Low-rank search space
+# ["q_proj", "k_proj", "v_proj"] -> [16,12,8]
+# ["up_proj"] -> [16,12,8]
+# ["down_proj"] -> [16,12,8]
+curl http://${your_ip}:8015/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "alpaca_data.json",
+    "model": "meta-llama/Llama-2-7b-chat-hf",
+    "General": {
+      "lora_config": {
+        "r": 16,
+        "neural_lora_search": true,
+        "target_module_groups": [["q_proj", "k_proj", "v_proj"], ["up_proj"], ["down_proj"]],
+        "search_space": ["16,12,8", "16,12,8", "16,12,8"]
+      }
+    }
+  }'
+```
+
+Detailed explanations for the parameters can be found [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#create-an-nls-fine-tuning-job).
+Additional use cases and benefits of SQFT are available [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea).
+Instructions for extracting the desired sub-adapter and merging it with the base model can be found [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#leverage-the-fine-tuned-super-adapter).
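+
+The super-adapter produced by an NLS job can then be post-processed through the `/v1/finetune/extract_sub_adapter` and `/v1/finetune/merge_adapter` endpoints added by this microservice. A minimal sketch (the job id `ft-job-xxx` is a placeholder for the id returned by the request above):
+
+```bash
+# extract a sub-adapter (here the "heuristic" version) from the fine-tuned super-adapter
+curl http://${your_ip}:8015/v1/finetune/extract_sub_adapter \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{"fine_tuning_job_id": "ft-job-xxx", "adapter_version": "heuristic"}'
+
+# merge the extracted sub-adapter into the base model
+curl http://${your_ip}:8015/v1/finetune/merge_adapter \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{"fine_tuning_job_id": "ft-job-xxx", "adapter_version": "heuristic"}'
+```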
+
+#### 3.2.3 Reranking Model Training
 
 Use the following command to launch a finetuning job for reranking model finetuning, such as `BAAI/bge-reranker-large`:
 
@@ -133,7 +168,7 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
   }'
 ```
 
-#### 3.2.3 Embedding Model Training
+#### 3.2.4 Embedding Model Training
 
 Use the following command to launch a finetuning job for embedding model finetuning, such as `BAAI/bge-base-en-v1.5`:
 
@@ -173,7 +208,7 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
 
 ```
 
-#### 3.2.4 LLM Pretraining
+#### 3.2.5 LLM Pretraining
 
 Use the following command to launch a job for LLM pretraining, such as `meta-llama/Llama-2-7b-hf`:
 
@@ -199,7 +234,7 @@ Below is an example for the format of the pretraining dataset:
 {"text": "A boy with a blue tank top sitting watching three dogs."}
 ```
 
-#### 3.2.5 Direct Preference Optimization (DPO)
+#### 3.2.6 Direct Preference Optimization (DPO)
 
 Use the following command to launch a job for LLM Direct Preference Optimization, such as `meta-llama/Llama-2-7b-hf`:
 
diff --git a/comps/finetuning/finetune_config.py b/comps/finetuning/finetune_config.py
index 0b2faf53db..1cf3b08794 100644
--- a/comps/finetuning/finetune_config.py
+++ b/comps/finetuning/finetune_config.py
@@ -5,9 +5,9 @@
 
 from typing import List, Optional, Union
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, root_validator, validator
 
-from comps.cores.proto.api_protocol import FineTuningJobsRequest
+from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest
 
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"
@@ -37,6 +37,24 @@ class LoraConfig(BaseModel):
     target_modules: Optional[List[str]] = None
 
 
+class SQFTNLSConfig(LoraConfig):
+    neural_lora_search: bool = False
+    target_module_groups: Optional[List[List[str]]] = None
+    search_space: Optional[List[str]] = None
+
+    @root_validator(pre=True)
+    def set_target_modules(cls, values):
+        if values.get("neural_lora_search"):
+            target_module_groups = values.get("target_module_groups")
+            search_space = values.get("search_space")
+            if target_module_groups is None or search_space is None:
+                raise ValueError("Please specify `target_module_groups` and `search_space` when using the NLS strategy.")
+            if len(search_space) != len(target_module_groups):
+                raise ValueError("The length of `search_space` must be equal to the length of `target_module_groups`.")
+            values["target_modules"] = [module for groups in target_module_groups for module in groups]
+        return values
+
+
 class GeneralConfig(BaseModel):
     base_model: str = None
     tokenizer_name: Optional[str] = None
@@ -47,7 +65,7 @@ class GeneralConfig(BaseModel):
     resume_from_checkpoint: Optional[str] = None
     save_strategy: str = "no"
     config: LoadConfig = LoadConfig()
-    lora_config: Optional[LoraConfig] = LoraConfig()
+    lora_config: Optional[Union[LoraConfig, SQFTNLSConfig]] = LoraConfig()
     enable_gradient_checkpointing: bool = False
     task: str = "instruction_tuning"
 
@@ -200,3 +218,12 @@ class FineTuningParams(FineTuningJobsRequest):
     General: GeneralConfig = GeneralConfig()
     Dataset: DatasetConfig = DatasetConfig()
     Training: TrainingConfig = TrainingConfig()
+
+
+class ExtractSubAdapterParams(FineTuningJobIDRequest):
+    adapter_version: str = "heuristic"
+    custom_config: Optional[List[int]] = None
+
+
+class MergeAdapterParams(FineTuningJobIDRequest):
+    adapter_version: Optional[str] = None
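A self-contained sketch of what the `root_validator` in `SQFTNLSConfig` produces, i.e. how `target_module_groups` is flattened into LoRA's `target_modules` (pydantic v1 API, matching the `root_validator` import above; `NLSConfigSketch` is a hypothetical stand-in for the class in this diff):

```python
# Sketch: SQFTNLSConfig's root_validator flattens `target_module_groups`
# into LoRA's `target_modules`. Assumes pydantic v1, as used in this repo.
from typing import List, Optional

from pydantic import BaseModel, root_validator


class NLSConfigSketch(BaseModel):
    r: int = 8
    neural_lora_search: bool = False
    target_module_groups: Optional[List[List[str]]] = None
    search_space: Optional[List[str]] = None
    target_modules: Optional[List[str]] = None

    @root_validator(pre=True)
    def set_target_modules(cls, values):
        if values.get("neural_lora_search"):
            groups = values.get("target_module_groups")
            space = values.get("search_space")
            if groups is None or space is None:
                raise ValueError("Both `target_module_groups` and `search_space` are required.")
            if len(space) != len(groups):
                raise ValueError("`search_space` must have one entry per module group.")
            # one flat list of modules for peft's LoraConfig
            values["target_modules"] = [m for group in groups for m in group]
        return values


cfg = NLSConfigSketch(
    r=16,
    neural_lora_search=True,
    target_module_groups=[["q_proj", "k_proj", "v_proj"], ["up_proj"], ["down_proj"]],
    search_space=["16,12,8", "16,12,8", "16,12,8"],
)
print(cfg.target_modules)  # ['q_proj', 'k_proj', 'v_proj', 'up_proj', 'down_proj']
```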
diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py
index 64097c720c..1d76eab0ae 100644
--- a/comps/finetuning/finetuning_service.py
+++ b/comps/finetuning/finetuning_service.py
@@ -4,12 +4,14 @@
 
 from comps import opea_microservices, register_microservice
 from comps.cores.proto.api_protocol import FineTuningJobIDRequest, UploadFileRequest
-from comps.finetuning.finetune_config import FineTuningParams
+from comps.finetuning.finetune_config import ExtractSubAdapterParams, FineTuningParams, MergeAdapterParams
 from comps.finetuning.handlers import (
     handle_cancel_finetuning_job,
     handle_create_finetuning_jobs,
+    handle_extract_sub_adapter,
     handle_list_finetuning_checkpoints,
     handle_list_finetuning_jobs,
+    handle_merge_adapter,
     handle_retrieve_finetuning_job,
     handle_upload_training_files,
     upload_file,
@@ -63,5 +65,17 @@ def list_checkpoints(request: FineTuningJobIDRequest):
     return checkpoints
 
 
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/finetune/extract_sub_adapter", host="0.0.0.0", port=8015
+)
+def extract_sub_adapter(request: ExtractSubAdapterParams):
+    return handle_extract_sub_adapter(request)
+
+
+@register_microservice(name="opea_service@finetuning", endpoint="/v1/finetune/merge_adapter", host="0.0.0.0", port=8015)
+def merge_adapter(request: MergeAdapterParams):
+    return handle_merge_adapter(request)
+
+
 if __name__ == "__main__":
     opea_microservices["opea_service@finetuning"].start()
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
index a47b9f980a..fb4fe9afdc 100644
--- a/comps/finetuning/handlers.py
+++ b/comps/finetuning/handlers.py
@@ -11,7 +11,7 @@ from typing import Dict
 
 from fastapi import BackgroundTasks, File, Form, HTTPException, UploadFile
-from pydantic_yaml import to_yaml_file
+from pydantic_yaml import parse_yaml_file_as, to_yaml_file
 from ray.job_submission import JobSubmissionClient
 
 from comps import CustomLogger
@@ -23,7 +23,12 @@
     FineTuningJobList,
     UploadFileRequest,
 )
-from comps.finetuning.finetune_config import FinetuneConfig, FineTuningParams
+from comps.finetuning.finetune_config import (
+    ExtractSubAdapterParams,
+    FinetuneConfig,
+    FineTuningParams,
+    MergeAdapterParams,
+)
 
 logger = CustomLogger("finetuning_handlers")
 
@@ -134,6 +139,98 @@ def handle_create_finetuning_jobs(request: FineTuningParams, background_tasks: BackgroundTasks):
     return job
 
 
+def handle_extract_sub_adapter(request: ExtractSubAdapterParams):
+    fine_tuning_job_id = request.fine_tuning_job_id
+    finetune_config_file = f"{JOBS_PATH}/{fine_tuning_job_id}.yaml"
+    finetune_config = parse_yaml_file_as(FinetuneConfig, finetune_config_file)
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    finetuned_model_path = os.path.join(OUTPUT_DIR, fine_tuning_job_id)
+    assert finetuned_model_path == finetune_config.General.output_dir
+    if not os.path.exists(finetuned_model_path):
+        raise HTTPException(
+            status_code=404,
+            detail=f"The fine-tuned model saved by the fine-tuning job '{fine_tuning_job_id}' was not found!",
+        )
+    if job.status != "succeeded":
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' has not completed!")
+
+    if finetune_config.General.lora_config is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"The fine-tuning job '{fine_tuning_job_id}' does not enable LoRA adapter fine-tuning!",
+        )
+    if not finetune_config.General.lora_config.neural_lora_search:
+        raise HTTPException(
+            status_code=404,
+            detail=f"The fine-tuning job '{fine_tuning_job_id}' did not enable the NLS algorithm, "
+ f"there is no need to extract sub-adapters!", + ) + nncf_config_path = os.path.join(finetune_config.General.output_dir, "nncf_config.json") + if not os.path.exists(nncf_config_path): + raise HTTPException( + status_code=404, detail=f"The NNCF config file does not exist in the fine-tuning job '{fine_tuning_job_id}!" + ) + + from comps.finetuning.utils.extract_sub_adapter import main as extract_sub_adapter_main + + extract_sub_adapter_main( + adapter_model_path=finetuned_model_path, + nncf_config=nncf_config_path, + adapter_version=request.adapter_version, + custom_config=request.custom_config, + ) + + return fine_tuning_job_id + + +def handle_merge_adapter(request: MergeAdapterParams): + fine_tuning_job_id = request.fine_tuning_job_id + finetune_config_file = f"{JOBS_PATH}/{fine_tuning_job_id}.yaml" + finetune_config = parse_yaml_file_as(FinetuneConfig, finetune_config_file) + + job = running_finetuning_jobs.get(fine_tuning_job_id) + if job is None: + raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!") + finetuned_model_path = os.path.join(OUTPUT_DIR, fine_tuning_job_id) + assert finetuned_model_path == finetune_config.General.output_dir + if not os.path.exists(finetuned_model_path): + raise HTTPException( + status_code=404, + detail=f"The fine-tuned model saved by the fine-tuning job '{fine_tuning_job_id}' was not found!", + ) + if job.status != "succeeded": + raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' has not completed!") + + if finetune_config.General.lora_config is None: + raise HTTPException( + status_code=404, + detail=f"The fine-tuning job '{fine_tuning_job_id}' does not enable LoRA adapter fine-tuning!", + ) + + adapter_path = finetuned_model_path + adapter_version = request.adapter_version + if adapter_version is not None: + adapter_path = os.path.join(adapter_path, adapter_version) + if not os.path.exists(adapter_path): + raise HTTPException( + status_code=404, + detail=f"The fine-tuning job '{fine_tuning_job_id}' does not have a '{adapter_version}' adapter!", + ) + + from comps.finetuning.utils.merge_adapter import main as merge_adapter_main + + merge_adapter_main( + base_model_path=finetune_config.General.base_model, + adapter_model_path=adapter_path, + output_path=os.path.join(adapter_path, "merged_model"), + ) + + return fine_tuning_job_id + + def handle_list_finetuning_jobs(): finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False) diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index d105269a40..5216cc660b 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -39,6 +39,18 @@ logger = CustomLogger("llm_on_ray/finetune") +try: + from nncf.experimental.torch.nas.bootstrapNAS.training.model_creator_helpers import ( + create_compressed_model_from_algo_names, + ) + from nncf.torch.model_creation import create_nncf_network + + from comps.finetuning.utils.create_sqft_nncf_config import create_sqft_nncf_config + + is_nncf_available = True +except ImportError: + is_nncf_available = False + def adapt_transformers_to_device(config: Dict): device = config["Training"]["device"] @@ -338,6 +350,7 @@ def load_model(config: Dict): model_config = config["General"].get("config", {}) task = config["General"].get("task", "instruction_tuning") ref_model = None + nls_controller = None if task in ["instruction_tuning", "pretraining", 
"dpo"]: model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) if task == "dpo": @@ -346,8 +359,21 @@ def load_model(config: Dict): ) lora_config = config["General"].get("lora_config", None) if lora_config and task == "instruction_tuning": + neural_lora_search = lora_config.pop("neural_lora_search", False) + target_module_groups = lora_config.pop("target_module_groups", None) + search_space = lora_config.pop("search_space", None) peft_config = LoraConfig(**lora_config) model = get_peft_model(model, peft_config) + + # Neural LoRA Search (NLS) + if neural_lora_search: + if not is_nncf_available: + raise NotImplementedError("NNCF is not installed. Please install it for enabling NLS algorithm.") + nncf_config = create_sqft_nncf_config( + config=config, model=model, target_module_groups=target_module_groups, search_space=search_space + ) + model = create_nncf_network(model, nncf_config) + nls_controller, model = create_compressed_model_from_algo_names(model, nncf_config, algo_names=["nls"]) elif task == "rerank": model = CrossEncoder.from_pretrained( config["Dataset"].get("train_group_size", 8), @@ -383,10 +409,10 @@ def load_model(config: Dict): model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"])) - return model, ref_model + return model, ref_model, nls_controller -def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator): +def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator, nls_controller=None): device = config["Training"]["device"] task = config["General"].get("task", "instruction_tuning") if device in ["cpu", "gpu", "cuda"]: @@ -411,18 +437,23 @@ def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, da max_length=config["Dataset"].get("max_length", 1024), ) else: - trainer = Trainer( - model=model, - args=training_args, - train_dataset=tokenized_dataset["train"], - eval_dataset=( + trainer_args = { + "model": model, + "args": training_args, + "train_dataset": tokenized_dataset["train"], + "eval_dataset": ( tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None ), - tokenizer=tokenizer, - data_collator=data_collator, - ) + "tokenizer": tokenizer, + "data_collator": data_collator, + } + if nls_controller is not None: + trainer_args["compression_ctrl"] = nls_controller + trainer = Trainer(**trainer_args) return training_args, trainer elif device in ["hpu"]: + if nls_controller is not None: + raise NotImplementedError("NLS algorithm is not supported on HPU now.") from optimum.habana import GaudiConfig from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments @@ -495,9 +526,11 @@ def train_func(config: Dict[str, Any]): data_collator = prepare_data_collator(config, tokenizer) - model, ref_model = load_model(config) + model, ref_model, nls_controller = load_model(config) - training_args, trainer = get_trainer(config, model, ref_model, tokenizer, tokenized_dataset, data_collator) + training_args, trainer = get_trainer( + config, model, ref_model, tokenizer, tokenized_dataset, data_collator, nls_controller=nls_controller + ) logger.info("train start") trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) diff --git a/comps/finetuning/utils/create_sqft_nncf_config.py b/comps/finetuning/utils/create_sqft_nncf_config.py new file mode 100644 index 0000000000..731791da41 --- /dev/null +++ b/comps/finetuning/utils/create_sqft_nncf_config.py @@ -0,0 
+1,130 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import os + +try: + from nncf import NNCFConfig + from nncf.experimental.torch import sqft + + is_nncf_available = True +except ImportError: + is_nncf_available = False + + +NNCF_CONFIG_TEMPLATE = { + "input_info": [ + {"sample_size": [1, 256], "type": "long", "keyword": "input_ids"}, + {"sample_size": [1, 256], "type": "long", "keyword": "attention_mask"}, + ], + "SQFT": { + "training": { + "algorithm": "nls", + "elasticity": { + "available_elasticity_dims": ["width"], + "width": {"overwrite_groups": [], "overwrite_groups_widths": []}, + }, + } + }, +} + + +def add_lr_epochs(nncf_config, learning_rate=3e-4, num_train_epochs=3): + """Add learning rate and epochs to the NNCF configuration. + + Args: + nncf_config (dict): The NNCF configuration dictionary. + learning_rate (float): The initial learning rate to set. + num_epochs (int): The number of epochs to set. + + Returns: + dict: The updated NNCF configuration. + """ + overwrite_groups_widths = nncf_config["SQFT"]["training"]["elasticity"]["width"]["overwrite_groups_widths"] + # Add learning rate and epochs to the configuration + nncf_config["SQFT"]["training"]["schedule"] = { + "list_stage_descriptions": [ + { + "train_dims": ["width"], + "width_indicator": max([len(widths) for widths in overwrite_groups_widths]), + "init_lr": learning_rate, + "epochs": num_train_epochs, + "epochs_lr": num_train_epochs, + } + ] + } + return nncf_config + + +def get_model_paths(model, target_module_name): + """Find all paths to the target layer in the model. + + Args: + model (torch.nn.Module): The model to search. + target_module_name (str): The name of the target layer. + + Returns: + list: A list of paths to the target layer. + """ + + def find_layers(module, target_module_name, path, paths): + for name, sub_module in module.named_children(): + new_path = f"{path}/{sub_module.__class__.__name__}[{name}]" + if target_module_name in name: + # Check if 'lora_A' is in the sub_module's children + for sub_name, _ in sub_module.named_children(): + if "lora_A" in sub_name: + paths.append(f"{new_path}/ModuleDict[lora_A]/NNCFLinear[default]/linear_0") + find_layers(sub_module, target_module_name, new_path, paths) + + base_path = model.__class__.__name__ + paths = [] + find_layers(model, target_module_name, base_path, paths) + return paths + + +def create_sqft_nncf_config(config, model, target_module_groups=None, search_space=None): + """Load and preprocess the NNCF configuration file. + + Returns: + NNCFConfig: The preprocessed NNCF configuration object. + """ + if not is_nncf_available: + raise NotImplementedError("NNCF is not installed. Please install it for enabling NLS algorithm.") + if search_space is None and target_module_groups: + raise ValueError("Neural LoRA search is enabled, `search_space` and `target_module_groups` must be provided.") + # The NNCF Config will be automatically generated based on `target_module_groups` and `search_space`. 
+    num_hidden_layers = model.config.num_hidden_layers
+    nncf_config_dict = NNCF_CONFIG_TEMPLATE
+    overwrite_groups = []
+    for group in target_module_groups:
+        group_paths = []
+        for module in group:
+            target_layer_name = module
+            paths = get_model_paths(model, target_layer_name)
+            assert paths, f"No paths found for module {module}"
+            group_paths.append(paths)
+        # Transpose the list of lists to combine paths by their positions
+        transposed_paths = list(zip(*group_paths))
+        overwrite_groups.extend([list(path_group) for path_group in transposed_paths])
+    nncf_config_dict["SQFT"]["training"]["elasticity"]["width"]["overwrite_groups"] = overwrite_groups
+
+    overwrite_groups_widths = []
+    for space in search_space:
+        space = [int(width) for width in space.split(",")]
+        overwrite_groups_widths.extend([space] * num_hidden_layers)
+    nncf_config_dict["SQFT"]["training"]["elasticity"]["width"]["overwrite_groups_widths"] = overwrite_groups_widths
+    assert len(overwrite_groups) == len(overwrite_groups_widths)
+    nncf_config_dict = add_lr_epochs(
+        nncf_config_dict,
+        learning_rate=config["Training"]["learning_rate"],
+        num_train_epochs=config["Training"]["epochs"],
+    )
+    nncf_config = NNCFConfig.from_dict(nncf_config_dict)
+
+    nncf_config["log_dir"] = config["General"]["output_dir"]
+    os.makedirs(nncf_config["log_dir"], exist_ok=True)
+    with open(os.path.join(nncf_config["log_dir"], "nncf_config.json"), "w") as f:
+        json.dump(nncf_config, f, indent=4)
+    return nncf_config
diff --git a/comps/finetuning/utils/extract_sub_adapter.py b/comps/finetuning/utils/extract_sub_adapter.py
new file mode 100644
index 0000000000..2f5eccde32
--- /dev/null
+++ b/comps/finetuning/utils/extract_sub_adapter.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import re
+import shutil
+
+import torch
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
+
+try:
+    from nncf import NNCFConfig
+
+    is_nncf_available = True
+except ImportError:
+    is_nncf_available = False
+
+
+PATTERN = re.compile(r"[[](.*?)[]]", re.S)
+
+
+def get_width_for_query_prefix(torch_module_to_width, query_module, length=5):
+    """Get the width for a given query module prefix.
+
+    Args:
+        torch_module_to_width (dict): Mapping from torch module to width.
+        query_module (str): The query module name.
+        length (int, optional): The length of the prefix to match. Default is 5.
+
+    Returns:
+        int: The width for the query module prefix.
+    """
+    query_module_list = query_module.split(".")
+    width = next(
+        (
+            value
+            for torch_module, value in torch_module_to_width.items()
+            if torch_module.split(".")[:length] == query_module_list[:length]
+        ),
+        None,
+    )
+    return width
+
+
+def main(adapter_model_path, nncf_config, adapter_version, custom_config=None):
+    if not is_nncf_available:
+        raise NotImplementedError("NNCF is not installed. Please install it.")
+    output_dir = os.path.join(adapter_model_path, adapter_version)
+    os.makedirs(output_dir, exist_ok=True)
+    nncf_config = NNCFConfig.from_json(nncf_config)
+    try:
+        groups = nncf_config["SQFT"]["training"]["elasticity"]["width"]["overwrite_groups"]
+        groups_widths = nncf_config["SQFT"]["training"]["elasticity"]["width"]["overwrite_groups_widths"]
+        assert len(groups) == len(groups_widths)
+    except Exception:
+        raise ValueError("Cannot get the search space from the NNCF config.")
+
+    if adapter_version == "maximal":
+        subnetwork_config = {idx: space[0] for idx, space in enumerate(groups_widths)}
+    elif adapter_version == "heuristic":
+        subnetwork_config = {idx: space[(len(space) - 1) // 2] for idx, space in enumerate(groups_widths)}
+    elif adapter_version == "minimal":
+        subnetwork_config = {idx: space[-1] for idx, space in enumerate(groups_widths)}
+    else:
+        assert custom_config is not None, "Missing custom subnetwork config."
+        assert isinstance(custom_config, list), "Custom config must be a list."
+        subnetwork_config = {i: value for i, value in enumerate(custom_config)}
+
+    # Mapping: nncf node -> width
+    nncf_node_to_width = {}
+    for idx, value in subnetwork_config.items():
+        space = groups_widths[idx]
+        assert min(space) <= value <= max(space)
+        cur_dict = {node: value for node in groups[idx]}
+        nncf_node_to_width.update(cur_dict)
+
+    # Prune adapter model (LoRA low-rank)
+    lora_torch_module_to_width = {
+        ".".join(re.findall(PATTERN, k)): v for k, v in nncf_node_to_width.items() if "lora_A" in k
+    }
+    num_module_name_item = list(lora_torch_module_to_width.keys())[0].split(".").index("lora_A")
+    # Load adapter weights
+    try:
+        super_adapter_weights = torch.load(os.path.join(adapter_model_path, WEIGHTS_NAME))
+    except Exception:
+        from safetensors.torch import load_file
+
+        super_adapter_weights = load_file(os.path.join(adapter_model_path, SAFETENSORS_WEIGHTS_NAME))
+    sub_adapter_weights = {}
+    for weight_key, weight_tensor in super_adapter_weights.items():
+        width = get_width_for_query_prefix(lora_torch_module_to_width, weight_key, length=num_module_name_item)
+        if width is not None:
+            is_loraA = "lora_A" in weight_key
+            new_weight_tensor = weight_tensor[:width].clone() if is_loraA else weight_tensor[:, :width].clone()
+        else:
+            new_weight_tensor = weight_tensor.clone()
+        sub_adapter_weights[weight_key] = new_weight_tensor
+    os.makedirs(output_dir, exist_ok=True)
+    torch.save(sub_adapter_weights, os.path.join(output_dir, WEIGHTS_NAME))
+    config_path = os.path.join(adapter_model_path, CONFIG_NAME)
+    shutil.copy(config_path, output_dir)
diff --git a/comps/finetuning/utils/merge_adapter.py b/comps/finetuning/utils/merge_adapter.py
new file mode 100644
index 0000000000..44fd01e8ad
--- /dev/null
+++ b/comps/finetuning/utils/merge_adapter.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import intel_extension_for_pytorch  # noqa: F401 (imported for its side effects)
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def main(base_model_path, adapter_model_path, output_path):
+    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True)
+    model = PeftModel.from_pretrained(base_model, adapter_model_path)
+    model.eval()
+    merged_model = model.merge_and_unload()
+    merged_model.train(False)
+    base_model.save_pretrained(output_path, state_dict=merged_model.state_dict())
+
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
+    tokenizer.save_pretrained(output_path)
diff --git a/tests/finetuning/test_finetuning.sh b/tests/finetuning/test_finetuning.sh
index 11a544dfda..1dcb6d7b60 100644
--- a/tests/finetuning/test_finetuning.sh
+++ b/tests/finetuning/test_finetuning.sh
@@ -22,6 +22,19 @@ function build_docker_images() {
     fi
 }
 
+function build_sqft_docker_images() {
+    cd $WORKPATH
+    echo $(pwd)
+    curl -o comps/finetuning/Dockerfile.sqft https://raw.githubusercontent.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/main/SQFT/opea/Dockerfile
+    docker build --no-cache -t opea/finetuning:comps --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/Dockerfile.sqft .
+    if [ $? -ne 0 ]; then
+        echo "opea/finetuning (sqft) build failed"
+        exit 1
+    else
+        echo "opea/finetuning (sqft) build succeeded"
+    fi
+}
+
 function start_service() {
     export no_proxy="localhost,127.0.0.1,"${ip_address}
     docker run -d --name="test-comps-finetuning-server" -p $finetuning_service_port:$finetuning_service_port -p $ray_port:$ray_port --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy opea/finetuning:comps
@@ -119,6 +132,35 @@ function validate_finetune() {
     done
 }
 
+function validate_merge_or_extract_adapter() {
+    local URL="$1"
+    local SERVICE_NAME="$2"
+    local DOCKER_NAME="$3"
+    local EXPECTED_DATA="$4"
+    local INPUT_DATA="$5"
+
+    HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -H 'Content-Type: application/json' -d "$INPUT_DATA" "$URL")
+    HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://')
+    RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g')
+
+    if [ "$HTTP_STATUS" -ne "200" ]; then
+        echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS"
+        docker logs $DOCKER_NAME >> ${LOG_PATH}/finetuning-server_merge_or_extract_adapter.log
+        exit 1
+    else
+        echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..."
+    fi
+
+    # Check if the parsed values match the expected values
+    if [[ "$RESPONSE_BODY" != *"$EXPECTED_DATA"* ]]; then
+        echo "[ $SERVICE_NAME ] Content does not match the expected result: $RESPONSE_BODY"
+        docker logs $DOCKER_NAME >> ${LOG_PATH}/finetuning-server_merge_or_extract_adapter.log
+        exit 1
+    else
+        echo "[ $SERVICE_NAME ] Content is as expected."
+    fi
+}
+
 function validate_microservice() {
     cd $LOG_PATH
     export no_proxy="localhost,127.0.0.1,"${ip_address}
@@ -225,6 +267,84 @@ EOF
 }
 
+function validate_sqft_microservice() {
+    cd $LOG_PATH
+    export no_proxy="localhost,127.0.0.1,"${ip_address}
+
+    ##########################
+    #      general test      #
+    ##########################
+    # test /v1/files upload file
+    echo '[{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. 
Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. 
Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."},{"instruction": "Give three tips for staying healthy.", "input": "", "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."}]' > $LOG_PATH/test_data.json
+    validate_upload \
+        "http://${ip_address}:$finetuning_service_port/v1/files" \
+        "general - upload" \
+        "test-comps-finetuning-server" \
+        "fine-tune" \
+        "test_data.json"
+
+    # test /v1/fine_tuning/jobs (LoRA)
+    validate_finetune \
+        "http://${ip_address}:$finetuning_service_port/v1/fine_tuning/jobs" \
+        "general - finetuning" \
+        "test-comps-finetuning-server" \
+        '{"id":"ft-job' \
+        '{"training_file": "test_data.json","model": "facebook/opt-125m", "General": {"lora_config": {"r": 8, "target_modules": ["q_proj"]}}}'
+
+    # test merging the LoRA adapter into the base model
+    validate_merge_or_extract_adapter \
+        "http://${ip_address}:$finetuning_service_port/v1/finetune/merge_adapter" \
+        "adapter merge" \
+        "test-comps-finetuning-server" \
+        "${FINTUNING_ID}" \
+        "{\"fine_tuning_job_id\": \"${FINTUNING_ID}\"}"
+
+    ##########################
+    #       sqft test        #
+    ##########################
+    # test /v1/fine_tuning/jobs (SQFT-NLS)
+    validate_finetune \
+        "http://${ip_address}:$finetuning_service_port/v1/fine_tuning/jobs" \
+        "sqft - finetuning" \
+        "test-comps-finetuning-server" \
+        '{"id":"ft-job' \
+        '{"training_file": "test_data.json","model": "facebook/opt-125m", "General": {"lora_config": {"r": 8, "neural_lora_search": true, "target_module_groups": [["q_proj"]], "search_space": ["8,6,4"]}}}'
+
+    # test extracting heuristic sub-adapter
+    validate_merge_or_extract_adapter \
+        "http://${ip_address}:$finetuning_service_port/v1/finetune/extract_sub_adapter" \
+        "extract heuristic sub-adapter" \
+        "test-comps-finetuning-server" \
+        "${FINTUNING_ID}" \
+        "{\"fine_tuning_job_id\": \"${FINTUNING_ID}\", \"adapter_version\": \"heuristic\"}"
+
+    # test merging the heuristic sub-adapter into the base model
+    validate_merge_or_extract_adapter \
+        "http://${ip_address}:$finetuning_service_port/v1/finetune/merge_adapter" \
+        "merge heuristic sub-adapter" \
+        "test-comps-finetuning-server" \
+        "${FINTUNING_ID}" \
+        "{\"fine_tuning_job_id\": \"${FINTUNING_ID}\", \"adapter_version\": \"heuristic\"}"
+
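+    # NOTE: facebook/opt-125m has 12 hidden layers and the NLS job above targets a
+    # single module group (["q_proj"]) with search space "8,6,4", so `custom_config`
+    # must supply 12 values, one low rank per layer, each drawn from {8, 6, 4}.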
+    # test extracting sub-adapter with custom configuration
+    validate_merge_or_extract_adapter \
+        "http://${ip_address}:$finetuning_service_port/v1/finetune/extract_sub_adapter" \
+        "extract custom sub-adapter" \
+        "test-comps-finetuning-server" \
+        "${FINTUNING_ID}" \
+        "{\"fine_tuning_job_id\": \"${FINTUNING_ID}\", \"adapter_version\": \"custom\", \"custom_config\": [8, 6, 4, 4, 8, 6, 8, 8, 8, 8, 4, 8]}"
+
+    # test merging the custom sub-adapter into the base model
+    validate_merge_or_extract_adapter \
+        "http://${ip_address}:$finetuning_service_port/v1/finetune/merge_adapter" \
+        "merge custom sub-adapter" \
+        "test-comps-finetuning-server" \
+        "${FINTUNING_ID}" \
+        "{\"fine_tuning_job_id\": \"${FINTUNING_ID}\", \"adapter_version\": \"custom\"}"
+
+}
+
 function stop_docker() {
     cid=$(docker ps -aq --filter "name=test-comps-finetuning-server*")
     if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi
@@ -233,12 +353,16 @@
 function main() {
     stop_docker
-
     build_docker_images
     start_service
-
     validate_microservice
+
+    # test sqft
+    stop_docker
+    build_sqft_docker_images
+    start_service
+    validate_sqft_microservice
+
     stop_docker
     echo y | docker system prune