SQFT Finetuning #947

Open
wants to merge 19 commits into
base: main
43 changes: 39 additions & 4 deletions comps/finetuning/README.md
@@ -114,7 +114,42 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
}'
```

#### 3.2.2 Reranking Model Training
#### 3.2.2 Instruction Tuning with SQFT's Neural Low-Rank Adapter Search (NLS)

In addition to traditional fine-tuning, you can use SQFT's NLS to fine-tune your model.
More details about SQFT can be found in [this paper](https://aclanthology.org/2024.findings-emnlp.749.pdf).
Please follow the additional installation requirements [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#-start-the-nls-microservice-with-python).
Use the following command to launch a finetuning job with the NLS algorithm:

```bash
# create a fine-tuning job with NLS
# Max LoRA rank: 16
# LoRA target modules -> Low-rank search space
# ["q_proj", "k_proj", "v_proj"] -> [16,12,8]
# ["up_proj"] -> [16,12,8]
# ["down_proj"] -> [16,12,8]
curl http://${your_ip}:8015/v1/fine_tuning/jobs \
  -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "training_file": "alpaca_data.json",
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "General": {
      "lora_config": {
        "r": 16,
        "neural_lora_search": true,
        "target_module_groups": [["q_proj", "k_proj", "v_proj"], ["up_proj"], ["down_proj"]],
        "search_space": ["16,12,8", "16,12,8", "16,12,8"]
      }
    }
  }'
```
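
The request body above can also be assembled programmatically. The following is a minimal Python sketch, not part of this PR, that builds the same JSON payload and checks the one constraint the server-side validator in `finetune_config.py` enforces: one `search_space` entry per module group. The helper name `build_nls_payload` is hypothetical.

```python
import json


def build_nls_payload(training_file, model, target_module_groups, search_space, max_rank=16):
    """Build the JSON body for an NLS fine-tuning job (hypothetical helper)."""
    # Mirror the server-side check: one search-space entry per module group.
    if len(search_space) != len(target_module_groups):
        raise ValueError("search_space needs one entry per target module group")
    return {
        "training_file": training_file,
        "model": model,
        "General": {
            "lora_config": {
                "r": max_rank,
                "neural_lora_search": True,
                "target_module_groups": target_module_groups,
                "search_space": search_space,
            }
        },
    }


payload = build_nls_payload(
    "alpaca_data.json",
    "meta-llama/Llama-2-7b-chat-hf",
    [["q_proj", "k_proj", "v_proj"], ["up_proj"], ["down_proj"]],
    ["16,12,8", "16,12,8", "16,12,8"],
)
print(json.dumps(payload, indent=2))
```

The printed JSON can be passed to `curl -d` or sent with any HTTP client.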

Detailed explanations for the parameters can be found [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#create-an-nls-fine-tuning-job).
Additional use-cases and benefits of SQFT are available [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea).
Instructions for extracting the desired sub-adapter and merging it with the base model can be found [here](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/SQFT/opea#leverage-the-fine-tuned-super-adapter).

#### 3.2.3 Reranking Model Training

Use the following command to launch a finetuning job for reranking model finetuning, such as `BAAI/bge-reranker-large`:

@@ -133,7 +168,7 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
}'
```

#### 3.2.3 Embedding Model Training
#### 3.2.4 Embedding Model Training

Use the following command to launch a finetuning job for embedding model finetuning, such as `BAAI/bge-base-en-v1.5`:

@@ -173,7 +208,7 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \

```

#### 3.2.4 LLM Pretraining
#### 3.2.5 LLM Pretraining

Use the following command to launch a job for LLM pretraining, such as `meta-llama/Llama-2-7b-hf`:

@@ -199,7 +234,7 @@ Below is an example for the format of the pretraining dataset:
{"text": "A boy with a blue tank top sitting watching three dogs."}
```

#### 3.2.5 Direct Preference Optimization (DPO)
#### 3.2.6 Direct Preference Optimization (DPO)

Use the following command to launch a job for LLM Direct Preference Optimization, such as `meta-llama/Llama-2-7b-hf`:

33 changes: 30 additions & 3 deletions comps/finetuning/finetune_config.py
@@ -5,9 +5,9 @@

from typing import List, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator

from comps.cores.proto.api_protocol import FineTuningJobsRequest
from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest

PRECISION_BF16 = "bf16"
PRECISION_FP16 = "fp16"
@@ -37,6 +37,24 @@ class LoraConfig(BaseModel):
target_modules: Optional[List[str]] = None


class SQFTNLSConfig(LoraConfig):
    neural_lora_search: bool = False
    target_module_groups: Optional[List[List[str]]] = None
    search_space: Optional[List[str]] = None

    @root_validator(pre=True)
    def set_target_modules(cls, values):
        # When NLS is enabled, derive the flat `target_modules` list from the module groups.
        if values.get("neural_lora_search"):
            target_module_groups = values.get("target_module_groups")
            search_space = values.get("search_space")
            if target_module_groups is None or search_space is None:
                raise ValueError("Please specify `target_module_groups` and `search_space` when using the NLS strategy.")
            if len(search_space) != len(target_module_groups):
                raise ValueError("The length of `search_space` must be equal to the length of `target_module_groups`.")
            values["target_modules"] = [module for groups in target_module_groups for module in groups]
        return values
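
To make the validator's effect concrete, here is a standalone sketch that mirrors its logic on a plain dictionary, without pydantic. It is illustrative only and not code from this PR. The `parse_ranks` helper is hypothetical: it assumes each `search_space` entry is a comma-separated list of candidate ranks, as the README comments suggest; how the NNCF config builder actually parses these strings is not shown in this diff.

```python
from typing import Any, Dict, List


def set_target_modules(values: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone mirror of the root validator's logic (illustrative only)."""
    if values.get("neural_lora_search"):
        groups = values.get("target_module_groups")
        space = values.get("search_space")
        if groups is None or space is None:
            raise ValueError("target_module_groups and search_space are required for NLS")
        if len(space) != len(groups):
            raise ValueError("search_space and target_module_groups must have equal length")
        # Flatten the groups into the list that plain LoRA expects.
        values["target_modules"] = [module for group in groups for module in group]
    return values


def parse_ranks(entry: str) -> List[int]:
    # Assumption: an entry such as "16,12,8" lists the candidate ranks for one group.
    return [int(rank) for rank in entry.split(",")]


cfg = set_target_modules(
    {
        "neural_lora_search": True,
        "target_module_groups": [["q_proj", "k_proj", "v_proj"], ["up_proj"], ["down_proj"]],
        "search_space": ["16,12,8", "16,12,8", "16,12,8"],
    }
)
print(cfg["target_modules"])  # ['q_proj', 'k_proj', 'v_proj', 'up_proj', 'down_proj']
print([parse_ranks(entry) for entry in cfg["search_space"]])  # [[16, 12, 8], [16, 12, 8], [16, 12, 8]]
```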


class GeneralConfig(BaseModel):
base_model: str = None
tokenizer_name: Optional[str] = None
@@ -47,7 +65,7 @@ class GeneralConfig(BaseModel):
resume_from_checkpoint: Optional[str] = None
save_strategy: str = "no"
config: LoadConfig = LoadConfig()
lora_config: Optional[LoraConfig] = LoraConfig()
lora_config: Optional[Union[LoraConfig, SQFTNLSConfig]] = LoraConfig()
enable_gradient_checkpointing: bool = False
task: str = "instruction_tuning"

@@ -200,3 +218,12 @@ class FineTuningParams(FineTuningJobsRequest):
General: GeneralConfig = GeneralConfig()
Dataset: DatasetConfig = DatasetConfig()
Training: TrainingConfig = TrainingConfig()


class ExtractSubAdapterParams(FineTuningJobIDRequest):
    adapter_version: str = "heuristic"
    custom_config: Optional[List[int]] = None


class MergeAdapterParams(FineTuningJobIDRequest):
    adapter_version: Optional[str] = None
16 changes: 15 additions & 1 deletion comps/finetuning/finetuning_service.py
@@ -4,12 +4,14 @@

from comps import opea_microservices, register_microservice
from comps.cores.proto.api_protocol import FineTuningJobIDRequest, UploadFileRequest
from comps.finetuning.finetune_config import FineTuningParams
from comps.finetuning.finetune_config import ExtractSubAdapterParams, FineTuningParams, MergeAdapterParams
from comps.finetuning.handlers import (
handle_cancel_finetuning_job,
handle_create_finetuning_jobs,
handle_extract_sub_adapter,
handle_list_finetuning_checkpoints,
handle_list_finetuning_jobs,
handle_merge_adapter,
handle_retrieve_finetuning_job,
handle_upload_training_files,
upload_file,
@@ -63,5 +65,17 @@ def list_checkpoints(request: FineTuningJobIDRequest):
return checkpoints


@register_microservice(
    name="opea_service@finetuning", endpoint="/v1/finetune/extract_sub_adapter", host="0.0.0.0", port=8015
)
def extract_sub_adapter(request: ExtractSubAdapterParams):
    return handle_extract_sub_adapter(request)


@register_microservice(name="opea_service@finetuning", endpoint="/v1/finetune/merge_adapter", host="0.0.0.0", port=8015)
def merge_adapter(request: MergeAdapterParams):
    return handle_merge_adapter(request)


if __name__ == "__main__":
    opea_microservices["opea_service@finetuning"].start()
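
A client could call the two new endpoints roughly as follows. This is a sketch under assumptions, not part of the PR: it assumes both endpoints accept a JSON POST body like the existing fine-tuning endpoints, that the service is reachable at `localhost:8015`, and it uses a placeholder job ID. The field names `fine_tuning_job_id`, `adapter_version`, and `custom_config` come from the request models in this diff. The requests are only constructed here, not sent.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8015"  # assumption: local service on the port used in this PR


def build_request(endpoint: str, payload: dict) -> urllib.request.Request:
    """Create a JSON POST request for one of the fine-tuning endpoints."""
    return urllib.request.Request(
        url=f"{BASE_URL}{endpoint}",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


job_id = "ft-job-placeholder"  # hypothetical job ID returned when the job was created

extract_req = build_request(
    "/v1/finetune/extract_sub_adapter",
    {"fine_tuning_job_id": job_id, "adapter_version": "heuristic"},
)
merge_req = build_request(
    "/v1/finetune/merge_adapter",
    {"fine_tuning_job_id": job_id, "adapter_version": "heuristic"},
)

# To actually send a request against a running service:
# with urllib.request.urlopen(extract_req) as response:
#     print(response.read().decode())
print(extract_req.full_url)
print(merge_req.full_url)
```
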
101 changes: 99 additions & 2 deletions comps/finetuning/handlers.py
@@ -11,7 +11,7 @@
from typing import Dict

from fastapi import BackgroundTasks, File, Form, HTTPException, UploadFile
from pydantic_yaml import to_yaml_file
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
from ray.job_submission import JobSubmissionClient

from comps import CustomLogger
@@ -23,7 +23,12 @@
FineTuningJobList,
UploadFileRequest,
)
from comps.finetuning.finetune_config import FinetuneConfig, FineTuningParams
from comps.finetuning.finetune_config import (
ExtractSubAdapterParams,
FinetuneConfig,
FineTuningParams,
MergeAdapterParams,
)

logger = CustomLogger("finetuning_handlers")

@@ -134,6 +139,98 @@ def handle_create_finetuning_jobs(request: FineTuningParams, background_tasks: B
return job


def handle_extract_sub_adapter(request: ExtractSubAdapterParams):
    fine_tuning_job_id = request.fine_tuning_job_id
    finetune_config_file = f"{JOBS_PATH}/{fine_tuning_job_id}.yaml"
    finetune_config = parse_yaml_file_as(FinetuneConfig, finetune_config_file)

    job = running_finetuning_jobs.get(fine_tuning_job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
    finetuned_model_path = os.path.join(OUTPUT_DIR, fine_tuning_job_id)
    assert finetuned_model_path == finetune_config.General.output_dir
    if not os.path.exists(finetuned_model_path):
        raise HTTPException(
            status_code=404,
            detail=f"The fine-tuned model saved by the fine-tuning job '{fine_tuning_job_id}' was not found!",
        )
    if job.status != "succeeded":
        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' has not completed!")

    if finetune_config.General.lora_config is None:
        raise HTTPException(
            status_code=404,
            detail=f"The fine-tuning job '{fine_tuning_job_id}' does not enable LoRA adapter fine-tuning!",
        )
    if not finetune_config.General.lora_config.neural_lora_search:
        raise HTTPException(
            status_code=404,
            detail=f"The fine-tuning job '{fine_tuning_job_id}' did not enable NLS algorithm, "
            f"there is no need to extract sub-adapters!",
        )
    nncf_config_path = os.path.join(finetune_config.General.output_dir, "nncf_config.json")
    if not os.path.exists(nncf_config_path):
        raise HTTPException(
            status_code=404, detail=f"The NNCF config file does not exist in the fine-tuning job '{fine_tuning_job_id}'!"
        )

    # Imported lazily so the service still starts when the NLS dependencies are not installed.
    from comps.finetuning.utils.extract_sub_adapter import main as extract_sub_adapter_main

    extract_sub_adapter_main(
        adapter_model_path=finetuned_model_path,
        nncf_config=nncf_config_path,
        adapter_version=request.adapter_version,
        custom_config=request.custom_config,
    )

    return fine_tuning_job_id
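
Both new handlers run the same sequence of precondition checks before doing any work. As a reading aid, the sequence can be sketched as one standalone function. This is a hypothetical refactor, not code from the PR: `JobNotReady` stands in for `HTTPException(status_code=404, ...)`, and the job lookup is reduced to a status string (`None` meaning the job is unknown).

```python
import os
import tempfile
from typing import Optional


class JobNotReady(Exception):
    """Stand-in for fastapi.HTTPException(status_code=404, ...) in this sketch."""


def check_job_ready(job_id: str, status: Optional[str], model_path: str, has_lora: bool) -> None:
    """Run the shared checks in the same order as the handlers in this diff."""
    if status is None:
        raise JobNotReady(f"Fine-tuning job '{job_id}' not found!")
    if not os.path.exists(model_path):
        raise JobNotReady(f"The fine-tuned model saved by job '{job_id}' was not found!")
    if status != "succeeded":
        raise JobNotReady(f"Fine-tuning job '{job_id}' has not completed!")
    if not has_lora:
        raise JobNotReady(f"Fine-tuning job '{job_id}' does not enable LoRA adapter fine-tuning!")


# Usage: a temporary directory plays the role of the job's output directory.
with tempfile.TemporaryDirectory() as output_dir:
    check_job_ready("job-1", "succeeded", output_dir, has_lora=True)  # passes silently
    try:
        check_job_ready("job-1", "running", output_dir, has_lora=True)
    except JobNotReady as err:
        print(err)
```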


def handle_merge_adapter(request: MergeAdapterParams):
    fine_tuning_job_id = request.fine_tuning_job_id
    finetune_config_file = f"{JOBS_PATH}/{fine_tuning_job_id}.yaml"
    finetune_config = parse_yaml_file_as(FinetuneConfig, finetune_config_file)

    job = running_finetuning_jobs.get(fine_tuning_job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
    finetuned_model_path = os.path.join(OUTPUT_DIR, fine_tuning_job_id)
    assert finetuned_model_path == finetune_config.General.output_dir
    if not os.path.exists(finetuned_model_path):
        raise HTTPException(
            status_code=404,
            detail=f"The fine-tuned model saved by the fine-tuning job '{fine_tuning_job_id}' was not found!",
        )
    if job.status != "succeeded":
        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' has not completed!")

    if finetune_config.General.lora_config is None:
        raise HTTPException(
            status_code=404,
            detail=f"The fine-tuning job '{fine_tuning_job_id}' does not enable LoRA adapter fine-tuning!",
        )

    adapter_path = finetuned_model_path
    adapter_version = request.adapter_version
    if adapter_version is not None:
        adapter_path = os.path.join(adapter_path, adapter_version)
        if not os.path.exists(adapter_path):
            raise HTTPException(
                status_code=404,
                detail=f"The fine-tuning job '{fine_tuning_job_id}' does not have a '{adapter_version}' adapter!",
            )

    from comps.finetuning.utils.merge_adapter import main as merge_adapter_main

    merge_adapter_main(
        base_model_path=finetune_config.General.base_model,
        adapter_model_path=adapter_path,
        output_path=os.path.join(adapter_path, "merged_model"),
    )

    return fine_tuning_job_id
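
The path logic in `handle_merge_adapter` determines where the merged model ends up. The sketch below isolates that logic; the function name `merged_model_path` and the example directory names are hypothetical, while the `merged_model` subdirectory and the optional `adapter_version` component follow the handler above.

```python
import os
from typing import Optional


def merged_model_path(output_dir: str, job_id: str, adapter_version: Optional[str] = None) -> str:
    """Return the directory the merge handler would write the merged model to."""
    adapter_path = os.path.join(output_dir, job_id)
    if adapter_version is not None:
        # A named sub-adapter lives in its own subdirectory of the job output.
        adapter_path = os.path.join(adapter_path, adapter_version)
    return os.path.join(adapter_path, "merged_model")


print(merged_model_path("output", "job-1"))
print(merged_model_path("output", "job-1", "heuristic"))
```

Without `adapter_version` the adapter saved directly in the job's output directory is merged; with it, the merge targets the sub-adapter directory of that name.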


def handle_list_finetuning_jobs():
finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)

57 changes: 45 additions & 12 deletions comps/finetuning/llm_on_ray/finetune/finetune.py
@@ -39,6 +39,18 @@

logger = CustomLogger("llm_on_ray/finetune")

try:
    from nncf.experimental.torch.nas.bootstrapNAS.training.model_creator_helpers import (
        create_compressed_model_from_algo_names,
    )
    from nncf.torch.model_creation import create_nncf_network

    from comps.finetuning.utils.create_sqft_nncf_config import create_sqft_nncf_config

    is_nncf_available = True
except ImportError:
    is_nncf_available = False
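
The guarded import above treats NNCF as an optional dependency and defers the failure until NLS is actually requested. An alternative way to probe for an optional package, shown here only as a sketch with hypothetical helper names, is `importlib.util.find_spec`, which locates a module without importing it.

```python
import importlib.util


def is_module_available(name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(name) is not None


def require_module(name: str, feature: str) -> None:
    """Raise a descriptive error when an optional dependency is missing."""
    if not is_module_available(name):
        raise ImportError(f"{name} is not installed. Please install it to enable {feature}.")


print(is_module_available("json"))  # True: part of the standard library
print(is_module_available("nncf"))  # depends on the local environment
```

The two approaches are not equivalent: the `try`/`except ImportError` form used in this PR also catches failures inside NNCF's own submodule imports, whereas `find_spec` only reports whether the top-level package can be found.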


def adapt_transformers_to_device(config: Dict):
device = config["Training"]["device"]
@@ -338,6 +350,7 @@ def load_model(config: Dict):
model_config = config["General"].get("config", {})
task = config["General"].get("task", "instruction_tuning")
ref_model = None
nls_controller = None
if task in ["instruction_tuning", "pretraining", "dpo"]:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config)
if task == "dpo":
@@ -346,8 +359,21 @@ def load_model(config: Dict):
)
lora_config = config["General"].get("lora_config", None)
if lora_config and task == "instruction_tuning":
neural_lora_search = lora_config.pop("neural_lora_search", False)
target_module_groups = lora_config.pop("target_module_groups", None)
search_space = lora_config.pop("search_space", None)
peft_config = LoraConfig(**lora_config)
model = get_peft_model(model, peft_config)

# Neural LoRA Search (NLS)
if neural_lora_search:
if not is_nncf_available:
raise NotImplementedError("NNCF is not installed. Please install it to enable the NLS algorithm.")
nncf_config = create_sqft_nncf_config(
config=config, model=model, target_module_groups=target_module_groups, search_space=search_space
)
model = create_nncf_network(model, nncf_config)
nls_controller, model = create_compressed_model_from_algo_names(model, nncf_config, algo_names=["nls"])
elif task == "rerank":
model = CrossEncoder.from_pretrained(
config["Dataset"].get("train_group_size", 8),
@@ -383,10 +409,10 @@ def load_model(config: Dict):

model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"]))

return model, ref_model
return model, ref_model, nls_controller


def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator):
def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator, nls_controller=None):
device = config["Training"]["device"]
task = config["General"].get("task", "instruction_tuning")
if device in ["cpu", "gpu", "cuda"]:
@@ -411,18 +437,23 @@ def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, da
max_length=config["Dataset"].get("max_length", 1024),
)
else:
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=(
trainer_args = {
"model": model,
"args": training_args,
"train_dataset": tokenized_dataset["train"],
"eval_dataset": (
tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None
),
tokenizer=tokenizer,
data_collator=data_collator,
)
"tokenizer": tokenizer,
"data_collator": data_collator,
}
if nls_controller is not None:
trainer_args["compression_ctrl"] = nls_controller
trainer = Trainer(**trainer_args)
return training_args, trainer
elif device in ["hpu"]:
if nls_controller is not None:
raise NotImplementedError("The NLS algorithm is not yet supported on HPU.")
from optimum.habana import GaudiConfig
from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments

@@ -495,9 +526,11 @@ def train_func(config: Dict[str, Any]):

data_collator = prepare_data_collator(config, tokenizer)

model, ref_model = load_model(config)
model, ref_model, nls_controller = load_model(config)

training_args, trainer = get_trainer(config, model, ref_model, tokenizer, tokenized_dataset, data_collator)
training_args, trainer = get_trainer(
config, model, ref_model, tokenizer, tokenized_dataset, data_collator, nls_controller=nls_controller
)

logger.info("train start")
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)