To run fine-tuning on multiple GPUs, we will make use of two packages:
- PEFT methods, in particular the Hugging Face PEFT library.
- FSDP, which helps us parallelize the training over multiple GPUs.
By combining PEFT and FSDP, we can fine-tune a Meta Llama 8B model on multiple GPUs in a single node. For large models such as 405B, fine-tuning requires a multi-node setup even when 4-bit quantization is enabled.
To run the examples, make sure to install the llama-recipes package and clone the GitHub repository in order to use the provided finetuning.py script with torchrun (see README.md for details).
You will need access to a machine with multiple GPUs (in this case, we tested with 4 A100s and A10s).
By default, this runs with the samsum_dataset for a summarization application.
Multi-GPU, single node:
NOTE: please make sure to use PyTorch nightlies when using PEFT+FSDP. Also note that int8 quantization from bitsandbytes is currently not supported with FSDP.
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
The args used in the command above are:
- --enable_fsdp: boolean flag to enable FSDP in the script
- --use_peft: boolean flag to enable PEFT methods in the script
- --peft_method: specifies the PEFT method; here we use lora, and the other option is llama_adapter.
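To give a sense of why a PEFT method such as lora keeps fine-tuning cheap: LoRA freezes a weight matrix of shape (d_out, d_in) and trains two low-rank factors, B of shape (d_out, r) and A of shape (r, d_in), so each adapted matrix adds only r·(d_in + d_out) trainable parameters. The minimal sketch below does this arithmetic; the square 4096x4096 shape and rank 8 in the example are illustrative assumptions, not values read from the PEFT config file.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one frozen (d_out x d_in) weight.

    LoRA learns B (d_out x rank) and A (rank x d_in); the weight update is B @ A.
    """
    return rank * (d_in + d_out)


def dense_params(d_in: int, d_out: int) -> int:
    """Parameters in the frozen dense weight itself."""
    return d_in * d_out


if __name__ == "__main__":
    # Illustrative square projection and rank (assumptions, not config values).
    d, r = 4096, 8
    lora = lora_trainable_params(d, d, r)
    dense = dense_params(d, d)
    print(f"LoRA params: {lora}, dense params: {dense}, ratio: {lora / dense:.4%}")
```

Because the adapter grows linearly in d while the frozen weight grows quadratically, the trainable fraction shrinks as layers get wider.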
We use torchrun here to spawn multiple processes for FSDP.
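torchrun starts one worker process per GPU (as set by --nproc_per_node) and exports the RANK, LOCAL_RANK and WORLD_SIZE environment variables to each worker. The sketch below shows one way a worker can read them before setting up distributed training; DistInfo and read_torchrun_env are hypothetical names, and this is not the code in finetuning.py.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class DistInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # rank within this node, typically used as the GPU index
    world_size: int  # total number of worker processes


def read_torchrun_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read the rank variables that torchrun exports to every worker.

    Falls back to a single-process setup when they are absent, for example
    when the script is started with plain `python` instead of torchrun.
    """
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


if __name__ == "__main__":
    info = read_torchrun_env()
    print(f"rank {info.rank} of {info.world_size}, local rank {info.local_rank}")
```

With --nnodes 1 --nproc_per_node 4, the four workers would see RANK and LOCAL_RANK values 0 to 3 and WORLD_SIZE=4.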
Setting use_fast_kernels enables Flash Attention or xFormers memory-efficient kernels, depending on the hardware being used, which speeds up the fine-tuning job. This has been enabled in the optimum library from Hugging Face as a one-liner API; see the optimum documentation for more details.
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
If you are interested in running full-parameter fine-tuning without PEFT methods, use the following command. Make sure to set nproc_per_node to the number of available GPUs. This has been tested with BF16 on 8x A100 40GB GPUs.
torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --fsdp_config.pure_bf16 --use_fast_kernels
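To run PEFT fine-tuning of the 70B model with 4-bit quantization and LoRA, use the following command. This has been tested on 4 H100 GPUs.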
FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --quantization 4bit --model_name /path_of_model_folder/70B --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
If you are interested in running full-parameter fine-tuning on the 70B model, you can enable low_cpu_fsdp mode as in the following command. This option loads the model on rank 0 only, before moving it to the devices to construct FSDP. This can dramatically reduce CPU memory when loading large models like 70B (on an 8-GPU node, it reduces CPU memory from 2+ TB to 280 GB for the 70B model). This has been tested with BF16 on 16x A100 80GB GPUs.
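The CPU memory figures above are consistent with each loading process holding a full FP32 copy of the weights (4 bytes per parameter): 70B parameters x 4 bytes is about 280 GB per copy, so 8 ranks loading at once need about 2.24 TB, while low_cpu_fsdp keeps a single copy on rank 0. The FP32-loading assumption is ours, and this back-of-envelope sketch ignores framework and allocator overhead.

```python
def cpu_load_memory_gb(num_params: float, bytes_per_param: int, ranks_loading: int) -> float:
    """Rough host RAM (GB) if `ranks_loading` processes each materialize the full model.

    Ignores allocator, tokenizer and framework overhead.
    """
    return num_params * bytes_per_param * ranks_loading / 1e9


if __name__ == "__main__":
    params_70b = 70e9
    fp32 = 4  # assumption: weights are materialized in FP32 while loading
    print("every rank loads:", cpu_load_memory_gb(params_70b, fp32, ranks_loading=8), "GB")
    print("rank 0 only:     ", cpu_load_memory_gb(params_70b, fp32, ranks_loading=1), "GB")
```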
torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
Multi-GPU, multi-node:
Here we use a Slurm script to schedule a job over multiple nodes.
sbatch recipes/quickstart/finetuning/multi_node.slurm
# Change the number of nodes and GPUs per node in the script before running.
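When you change the node and GPU counts, the total number of workers changes with them. Assuming every node runs the same number of processes and nodes are ordered by node rank, the world size is nnodes x nproc_per_node and a worker's global rank is node_rank x nproc_per_node + local_rank. The sketch below enumerates that layout; rank_layout is a hypothetical helper and the 2-node, 8-GPU example values are not taken from multi_node.slurm.

```python
from typing import List, Tuple


def rank_layout(nnodes: int, nproc_per_node: int) -> List[Tuple[int, int, int]]:
    """Return (node_rank, local_rank, global_rank) for every worker.

    Assumes every node runs the same number of processes and that nodes
    are ordered by node_rank.
    """
    layout = []
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            global_rank = node_rank * nproc_per_node + local_rank
            layout.append((node_rank, local_rank, global_rank))
    return layout


if __name__ == "__main__":
    # Example: 2 nodes with 8 GPUs each (illustrative values).
    workers = rank_layout(nnodes=2, nproc_per_node=8)
    print("world size:", len(workers))
    for node_rank, local_rank, global_rank in workers[6:10]:
        print(f"node {node_rank} gpu {local_rank} -> global rank {global_rank}")
```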
Currently, 4 datasets are supported; they can be found in the Datasets config file.
- grammar_dataset: use the provided notebook to pull and process the Jfleg and C4 200M datasets for grammar checking.
- alpaca_dataset: to get this open-source data, download alpaca_data.json to the datasets folder:
wget -P src/llama_recipes/datasets https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json
- samsum_dataset
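The alpaca_data.json file from the Stanford Alpaca repository is a JSON list of records, each with "instruction", "input" and "output" keys. Before launching a run, a quick sanity check of the downloaded file can save a failed job; the sketch below is one way to do it. missing_keys and load_alpaca_records are hypothetical helpers, not the loader llama-recipes uses.

```python
import json
from pathlib import Path
from typing import Dict, List

REQUIRED_KEYS = ("instruction", "input", "output")


def missing_keys(record: Dict[str, str]) -> List[str]:
    """Return the expected Alpaca keys that this record lacks."""
    return [key for key in REQUIRED_KEYS if key not in record]


def load_alpaca_records(path: str) -> List[Dict[str, str]]:
    """Load alpaca_data.json and fail fast if any record is malformed."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(records, list):
        raise ValueError("expected a JSON list of records")
    for index, record in enumerate(records):
        absent = missing_keys(record)
        if absent:
            raise ValueError(f"record {index} is missing keys: {absent}")
    return records


if __name__ == "__main__":
    path = Path("src/llama_recipes/datasets/alpaca_data.json")  # the wget -P target above
    if path.exists():
        data = load_alpaca_records(str(path))
        print(f"{len(data)} records; first instruction: {data[0]['instruction']!r}")
    else:
        print(f"{path} not found; run the wget command above first")
```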
To run with each of the datasets, set the --dataset flag in the command as shown below:
# grammar_dataset
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --dataset grammar_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
# alpaca_dataset
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --dataset alpaca_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
# samsum_dataset
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --dataset samsum_dataset --save_model --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
- The training config file is the main config file for specifying the settings of a run and can be found in the configs folder. It lets us specify training settings for everything from model_name to dataset_name, batch_size, and so on. Below is the list of supported settings:
model_name: str="PATH/to/Model"
tokenizer_name: str=None
enable_fsdp: bool=False
low_cpu_fsdp: bool=False
run_validation: bool=True
batch_size_training: int=4
batching_strategy: str="packing" #alternative: padding
context_length: int=4096
gradient_accumulation_steps: int=1
gradient_clipping: bool = False
gradient_clipping_threshold: float = 1.0
num_epochs: int=3
max_train_step: int=0
max_eval_step: int=0
num_workers_dataloader: int=1
lr: float=1e-4
weight_decay: float=0.0
gamma: float= 0.85
seed: int=42
use_fp16: bool=False
mixed_precision: bool=True
val_batch_size: int=1
dataset: str = "samsum_dataset"
peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
use_peft: bool=False
from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
quantization: bool = False
one_gpu: bool = False
save_model: bool = True
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
use_wandb: bool = False # Enable wandb for experiment tracking
save_metrics: bool = False # saves training metrics to a json file for later plotting
flop_counter: bool = False # Enable the FLOP counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
flop_counter_start: int = 3 # The step at which to start counting FLOPs; the default of 3 means counting starts after 3 warm-up steps.
use_profiler: bool = False # Enable the PyTorch profiler; cannot be used with the FLOP counter at the same time.
profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
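The commands earlier in this document override these defaults with flags such as --enable_fsdp or --batch_size_training 1. The minimal sketch below illustrates that pattern with a dataclass holding a handful of the settings above and a hypothetical apply_overrides helper that rejects unknown names; it is not the script's own config-update code.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    # A subset of the settings listed above, with the same defaults.
    model_name: str = "PATH/to/Model"
    enable_fsdp: bool = False
    batch_size_training: int = 4
    num_epochs: int = 3
    lr: float = 1e-4
    use_peft: bool = False
    peft_method: str = "lora"


def apply_overrides(config: TrainConfig, **overrides) -> TrainConfig:
    """Hypothetical helper: set known fields in place, reject unknown ones."""
    known = {f.name for f in fields(config)}
    for key, value in overrides.items():
        if key not in known:
            raise KeyError(f"unknown training setting: {key}")
        setattr(config, key, value)
    return config


if __name__ == "__main__":
    cfg = apply_overrides(TrainConfig(), enable_fsdp=True, use_peft=True, batch_size_training=1)
    print(cfg)
```

Rejecting unknown names catches typos such as a misspelled flag, which would otherwise leave the default silently in place.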
- The Datasets config file provides the available options for datasets.
- The PEFT config file provides the supported PEFT methods and their respective settings, which can be modified.
- The FSDP config file provides FSDP settings such as:
  - mixed_precision: boolean flag to specify using mixed precision; defaults to True.
  - use_fp16: boolean flag to specify using FP16 for mixed precision; defaults to False. We recommend not setting this flag and only setting mixed_precision, which will use BF16. This helps with speed and memory savings while avoiding the challenges of gradient-scaler accuracy with FP16.
  - sharding_strategy: specifies the sharding strategy for FSDP. It can be:
    - FULL_SHARD: shards model parameters, gradients, and optimizer states, resulting in the most memory savings.
    - SHARD_GRAD_OP: shards gradients and optimizer states and keeps the parameters after the first all_gather. This reduces communication overhead, especially on slower networks, which makes it particularly beneficial in multi-node cases. The trade-off is higher memory consumption.
    - NO_SHARD: equivalent to DDP; does not shard model parameters, gradients, or optimizer states. It keeps the full parameters after the first all_gather.
    - HYBRID_SHARD: available on PyTorch nightlies. It does FSDP within a node and DDP between nodes. It is intended for multi-node cases and is helpful on slower networks, provided your model fits into one node.
  - checkpoint_type: specifies the state dict checkpoint type for saving the model. FULL_STATE_DICT streams the state_dict of each model shard from a rank to CPU and assembles the full state_dict on CPU. SHARDED_STATE_DICT saves one checkpoint per rank and enables re-loading the model with a different world size.
  - fsdp_activation_checkpointing: enables activation checkpointing for FSDP. This saves a significant amount of memory at the cost of recomputing intermediate activations during the backward pass. The saved memory can be re-invested in higher batch sizes to increase throughput. We recommend using this option.
  - fsdp_config.pure_bf16: moves the model to BFloat16, and if optimizer is set to anyprecision, the optimizer states will be kept in BFloat16 as well. You can use this option if necessary.
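To make the sharding trade-offs above concrete, the back-of-envelope sketch below estimates per-GPU memory for model states (parameters, gradients, optimizer states) under FULL_SHARD, SHARD_GRAD_OP and NO_SHARD, following the descriptions in the list. It assumes BF16 parameters and gradients (2 bytes each) and an Adam-style optimizer with two FP32 state tensors (8 bytes) per parameter, and it ignores activations, FP32 master weights and communication buffers, so treat it as a rough comparison rather than a prediction.

```python
def model_state_gb_per_gpu(num_params: float, world_size: int, strategy: str,
                           param_bytes: int = 2, grad_bytes: int = 2,
                           optim_bytes: int = 8) -> float:
    """Approximate per-GPU memory (GB) for params, grads and optimizer states."""
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if strategy == "FULL_SHARD":
        # Everything is sharded across all ranks.
        total = (params + grads + optim) / world_size
    elif strategy == "SHARD_GRAD_OP":
        # Full parameters stay resident after the first all_gather.
        total = params + (grads + optim) / world_size
    elif strategy == "NO_SHARD":
        # DDP-like: every rank holds everything.
        total = params + grads + optim
    else:
        raise ValueError(f"unsupported strategy in this sketch: {strategy}")
    return total / 1e9


if __name__ == "__main__":
    # Roughly 8B parameters on 8 GPUs (illustrative values).
    for name in ("FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"):
        print(f"{name:>13}: {model_state_gb_per_gpu(8e9, 8, name):.1f} GB per GPU")
```

HYBRID_SHARD is left out because its footprint depends on how many GPUs each node contributes to the sharding group.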
To help with benchmarking, we support counting FLOPs during the fine-tuning process. You can enable this by setting --flop_counter when launching single- or multi-GPU fine-tuning. Use --flop_counter_start to choose the step at which FLOP counting starts. We recommend allowing a warm-up stage before using the FLOP counter.
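A FLOP count is usually turned into a throughput figure by dividing by the elapsed wall-clock time and, for a per-device number, by the number of GPUs. The sketch below shows that conversion; tflops_per_gpu is a hypothetical helper, and whether the reported count is summed across ranks or per rank is not stated here, so check which one you have before dividing by num_gpus.

```python
def tflops_per_gpu(total_flops: float, elapsed_seconds: float, num_gpus: int) -> float:
    """Convert a FLOP count measured over some wall-clock time into TFLOPS per GPU.

    Assumes `total_flops` is summed across all `num_gpus` devices.
    """
    if elapsed_seconds <= 0 or num_gpus <= 0:
        raise ValueError("elapsed_seconds and num_gpus must be positive")
    return total_flops / elapsed_seconds / num_gpus / 1e12


if __name__ == "__main__":
    # Illustrative numbers only, not measurements.
    print(f"{tflops_per_gpu(total_flops=4e15, elapsed_seconds=10.0, num_gpus=4):.1f} TFLOPS per GPU")
```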
Similarly, you can set the --use_profiler flag and pass a profiling output path with --profiler_dir to capture the profile traces of your model using the PyTorch profiler. To get accurate profiling results, the PyTorch profiler requires a warm-up stage; the current config is wait=1, warmup=2, active=3, so the profiler starts profiling after step 3 and records the next 3 steps. Therefore, to use the PyTorch profiler, --max_train_step must be greater than 6. The PyTorch profiler is helpful for debugging purposes. However, --flop_counter and --use_profiler cannot be used at the same time, to ensure measurement accuracy.
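The wait=1, warmup=2, active=3 schedule can be traced step by step. The sketch below is a simplified single-cycle model of torch.profiler.schedule in plain Python (the real schedule also has repeat and skip_first parameters, which are ignored here): the profiler idles during wait, traces but discards results during warmup, and records during active. profiler_phase and min_train_steps are hypothetical helper names.

```python
def profiler_phase(step: int, wait: int = 1, warmup: int = 2, active: int = 3) -> str:
    """Phase of a single wait/warmup/active profiling cycle at a 0-based step."""
    if step < wait:
        return "wait"      # profiler idle
    if step < wait + warmup:
        return "warmup"    # tracing, results discarded
    if step < wait + warmup + active:
        return "active"    # tracing, results recorded
    return "done"          # cycle finished


def min_train_steps(wait: int = 1, warmup: int = 2, active: int = 3) -> int:
    """Steps needed to complete one full profiling cycle."""
    return wait + warmup + active


if __name__ == "__main__":
    for step in range(min_train_steps() + 1):
        print(f"step {step}: {profiler_phase(step)}")
```

With the default values, steps 3 to 5 (0-based) are the recorded ones, which is why the run needs more than six training steps.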