Merge pull request #883 from OptimalScale/yizhenjia-idpo-finalize
[Feature] Iterative DPO
wheresmyhair authored Sep 4, 2024
2 parents 0737b2b + 4ae273a commit 8b31bfd
Showing 26 changed files with 1,503 additions and 132 deletions.
21 changes: 21 additions & 0 deletions configs/accelerate_dsz0_config.yaml
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 16
  zero3_init_flag: false
  zero_stage: 0
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
gpu_ids:
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 12580
22 changes: 22 additions & 0 deletions configs/accelerate_dsz2_config.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
gpu_ids:
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 12580
1 change: 1 addition & 0 deletions configs/accelerate_dsz3_config.yaml
@@ -21,3 +21,4 @@ tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 12580
89 changes: 89 additions & 0 deletions configs/iterative_dpo.yaml
@@ -0,0 +1,89 @@
# general
## model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reference_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model_name_or_path: sfairXC/FsfairX-LLaMA3-RM-v0.1
trust_remote_code: True

## data
dataset_path_list:
- data/iterative-prompt-3it/iter1
- data/iterative-prompt-3it/iter2
- data/iterative-prompt-3it/iter3
conversation_template: llama3
preprocessing_num_workers: 16

## pipeline
output_dir: ./output_models/iterative_dpo
run_name: iterative_dpo
random_seed: 42
use_accelerator: True
enable_distributed_inference: True
distributed_inference_num_instances: 8
initial_iter_idx: 0 # 0 refers to the first dataset in dataset_path_list
do_response_generation: True
do_scoring: True
do_dpo_align: True


# inference phase
## general
apply_chat_template: True
num_output_sequences: 8
use_beam_search: False
temperature: 1.0
top_p: 1.0
max_new_tokens: 2048
enable_decode_inference_result: True

## vllm
use_vllm: True
vllm_gpu_memory_utilization: 0.95
vllm_tensor_parallel_size: 1
vllm_inference_batch_size: 16


# reward model scoring phase
reward_arch_type: text_regression
reward_torch_dtype: bf16
reward_use_flash_attention: True
reward_model_inference_block_size: 2048
overwrite_cache: True
reward_model_inference_batch_size: 10 # the actual batch size for rm forward will be reward_model_inference_batch_size * num_output_sequences


# dpo phase
## model
do_train: True
use_flash_attention: True

## data
sampling_paired_method: max_min
margin_scale: 1.0
length_penalty: 0
max_prompt_length: 1000
mask_prompt: False

## pipeline
### training
accelerate_config_file: configs/accelerate_dsz2_config.yaml
bf16: True
num_train_epochs: 2
max_steps: 1200
learning_rate: 5.0e-7
warmup_steps: 100
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
gradient_accumulation_steps: 16
gradient_checkpointing: True
loss_type: sigmoid
lr_scheduler_type: cosine
optim: paged_adamw_32bit

### logging
logging_steps: 2
save_strategy: steps
save_steps: 500
evaluation_strategy: steps
eval_steps: 500
report_to: wandb
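The three blocks of this config (inference, reward scoring, DPO) map onto one iteration of the pipeline: sample candidate responses with vLLM, score them with the reward model, build chosen/rejected pairs, and run a DPO update before moving on to the next prompt dataset. The Python outline below is only an illustration of that control flow under the config keys shown above; the generate/score/dpo_step callables are hypothetical placeholders, not the actual LMFlow pipeline code.

# Illustrative outline only; generate, score and dpo_step stand in for the real
# LMFlow stages and are not actual LMFlow functions.
from typing import Callable, List, Sequence

def iterative_dpo(prompt_datasets: Sequence[List[str]],
                  generate: Callable,   # vLLM sampling, num_output_sequences per prompt
                  score: Callable,      # text_regression reward model forward pass
                  dpo_step: Callable,   # one DPO training run on chosen/rejected pairs
                  initial_iter_idx: int = 0,
                  do_response_generation: bool = True,
                  do_scoring: bool = True,
                  do_dpo_align: bool = True) -> None:
    for it, prompts in enumerate(prompt_datasets):  # iter1, iter2, iter3 above
        if it < initial_iter_idx:
            continue  # skip already-completed iterations when resuming
        responses = generate(prompts) if do_response_generation else None
        scores = score(responses) if do_scoring else None
        if do_dpo_align:
            # sampling_paired_method max_min: best-scored sample is "chosen", worst is "rejected"
            dpo_step(prompts, responses, scores)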
87 changes: 87 additions & 0 deletions examples/iterative_dpo_train.py
@@ -0,0 +1,87 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
import os
import sys
import copy

from transformers import (
    HfArgumentParser
)

from lmflow.datasets import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.args import (
    ModelArguments,
    DatasetArguments,
    AutoArguments,
)
from lmflow.utils.common import remove_dataclass_attr_prefix, create_copied_dataclass


logger = logging.getLogger(__name__)


# NOTE:
# In training processes that need more than one model, such as dpo (reference & target),
# ppo (actor & critic), etc., we use the following function to create separate model arguments
# to distinguish among them.
ReferenceModelArguments = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reference_",
    class_prefix="Reference"
)

RewardModelArguments = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reward_",
    class_prefix="Reward"
)


def main():
    pipeline_name = "iterative_dpo_aligner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        ReferenceModelArguments,
        RewardModelArguments,
        DatasetArguments,
        PipelineArguments
    ))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    ref_model_args_dict = remove_dataclass_attr_prefix(ref_model_args, "reference_")
    ref_model_args = ModelArguments(**ref_model_args_dict)
    reward_model_args_dict = remove_dataclass_attr_prefix(reward_model_args, "reward_")
    reward_model_args = ModelArguments(**reward_model_args_dict)

    dataset_list = []
    for dataset in pipeline_args.dataset_path_list:
        iter_data_args = copy.deepcopy(data_args)
        iter_data_args.dataset_path = dataset
        dataset_list.append(Dataset(iter_data_args))

    aligner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
        ref_model_args=ref_model_args,
        reward_model_args=reward_model_args,
    )

    aligner.align(dataset_list=dataset_list)


if __name__ == "__main__":
    main()
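The NOTE above relies on create_copied_dataclass and remove_dataclass_attr_prefix from lmflow.utils.common. A simplified stand-alone illustration of that prefixing pattern (not the actual LMFlow implementation) might look like this:

# Simplified illustration of the prefixed-dataclass pattern described in the
# NOTE above; not the actual lmflow.utils.common implementation.
from dataclasses import dataclass, fields, make_dataclass

@dataclass
class ModelArgs:
    model_name_or_path: str = "gpt2"
    torch_dtype: str = "bfloat16"

def create_copied_dataclass(original_dataclass, field_prefix, class_prefix):
    # Copy every field with a prefix so two (or three) models can be configured
    # from one flat namespace, e.g. model_name_or_path vs reward_model_name_or_path.
    new_fields = [(field_prefix + f.name, f.type, f.default) for f in fields(original_dataclass)]
    return make_dataclass(class_prefix + original_dataclass.__name__, new_fields)

def remove_dataclass_attr_prefix(instance, prefix):
    # Strip the prefix again so the values can rebuild the original dataclass.
    return {k[len(prefix):]: v for k, v in vars(instance).items() if k.startswith(prefix)}

RewardModelArgs = create_copied_dataclass(ModelArgs, "reward_", "Reward")
reward = RewardModelArgs(reward_model_name_or_path="sfairXC/FsfairX-LLaMA3-RM-v0.1")
print(ModelArgs(**remove_dataclass_attr_prefix(reward, "reward_")))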
1 change: 1 addition & 0 deletions examples/vllm_inference.py
@@ -53,6 +53,7 @@ def main():
        dataset,
        release_gpu=False,
        enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
        enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
    )


3 changes: 2 additions & 1 deletion requirements.txt
@@ -22,4 +22,5 @@ pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
vllm>=0.4.1
vllm>=0.4.3
ray>=2.22.0
1 change: 1 addition & 0 deletions scripts/run_iterative_dpo.sh
@@ -0,0 +1 @@
python examples/iterative_dpo_train.py configs/iterative_dpo.yaml
1 change: 1 addition & 0 deletions scripts/run_vllm_inference.sh
@@ -74,4 +74,5 @@ python examples/vllm_inference.py \
--enable_decode_inference_result False \
--vllm_gpu_memory_utilization 0.95 \
--vllm_tensor_parallel_size 2 \
--enable_distributed_vllm_inference False \
2>&1 | tee ${log_dir}/vllm_inference.log
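The new --enable_distributed_vllm_inference flag, together with the ray>=2.22.0 requirement added above, points at multi-instance vLLM generation. As a rough sketch of the idea only (an assumed Ray-actor layout, not LMFlow's actual implementation), sharding prompts across several single-GPU vLLM engines could look like this:

# Hypothetical sketch of multi-instance vLLM generation with Ray; not the
# actual LMFlow code behind enable_distributed_vllm_inference.
import ray
from vllm import LLM, SamplingParams

@ray.remote(num_gpus=1)
class VLLMWorker:
    def __init__(self, model_name: str):
        # One vLLM engine per GPU; tensor_parallel_size=1 matches the config above.
        self.llm = LLM(model=model_name, gpu_memory_utilization=0.95, tensor_parallel_size=1)

    def generate(self, prompts, n=8, temperature=1.0, top_p=1.0, max_tokens=2048):
        params = SamplingParams(n=n, temperature=temperature, top_p=top_p, max_tokens=max_tokens)
        outputs = self.llm.generate(prompts, params)
        return [[o.text for o in out.outputs] for out in outputs]

ray.init()
num_instances = 8  # distributed_inference_num_instances in the config
workers = [VLLMWorker.remote("meta-llama/Meta-Llama-3-8B-Instruct") for _ in range(num_instances)]
prompts = ["Write a haiku about DPO."] * 64
shards = [prompts[i::num_instances] for i in range(num_instances)]
results = ray.get([w.generate.remote(shard) for w, shard in zip(workers, shards)])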
64 changes: 63 additions & 1 deletion src/lmflow/args.py
@@ -991,6 +991,10 @@ class InferencerArguments:
        default=1,
        metadata={"help": "batch size for inference"},
    )
    vllm_inference_batch_size: int = field(
        default=1,
        metadata={"help": "The batch size for VLLM inference."}
    )
    temperature: float = field(
        default=0.0,
        metadata={"help": "Temperature during inference."},
@@ -1072,6 +1076,18 @@ class InferencerArguments:
        default=False,
        metadata={"help": "Whether to decode the inference results."},
    )
    tensor_parallel_size: Optional[int] = field(
        default=1,
        metadata={"help": "The tp size for distributed (multi-instance) inference."}
    )
    enable_distributed_inference: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use multi-instance VLLM inference."}
    )
    distributed_inference_num_instances: Optional[int] = field(
        default=1,
        metadata={"help": "The number of instances for multi-instance VLLM inference."}
    )

    # vllm inference args
    use_vllm: bool = field(
@@ -1354,6 +1370,12 @@ class DPOv2AlignerArguments(FinetunerArguments):
"""
The arguments for the DPOv2 training script.
"""
# general args
random_seed: Optional[int] = field(default=42, metadata={"help": "the random seed"})
accelerate_config_file: Optional[str] = field(
default=None,
metadata={"help": "file path for accelerate config file, only used in memory safe dpov2 align."}
)
# pair sampling args
margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"})
sampling_paired_method: Optional[str] = field(default="max_random", metadata={"help": "the choose type"})
@@ -1372,7 +1394,46 @@ class IterativeAlignerArguments(InferencerArguments):
"""
Arguments for iterative aligners.
"""
pass
dataset_path_list: List[str] = field(
default_factory=list,
metadata={"help": "The list of dataset paths for iterative aligners."}
)
initial_iter_idx: int = field(
default=0,
metadata={"help": "The initial iteration index, 0 refers to the first dataset in dataset_path_list."}
)



@dataclass
class IterativeDPOAlignerArguments(IterativeAlignerArguments, DPOv2AlignerArguments):
    """
    Arguments for iterative DPO aligners.
    """
    output_dir: Optional[str] = field(
        default="./runs",
        metadata={"help": "Output path for the inferenced results"},
    )
    reward_model_inference_batch_size: int = field(
        default=1,
        metadata={"help": "The batch size for reward model inference."}
    )
    reward_model_inference_block_size: int = field(
        default=2048,
        metadata={"help": "The block size for reward model inference."}
    )
    do_response_generation: bool = field(
        default=True,
        metadata={"help": "Whether to generate responses using the model."}
    )
    do_scoring: bool = field(
        default=True,
        metadata={"help": "Whether to score the responses using the reward model."}
    )
    do_dpo_align: bool = field(
        default=True,
        metadata={"help": "Whether to perform DPO alignment."}
    )


PIPELINE_ARGUMENT_MAPPING = {
@@ -1385,6 +1446,7 @@ class IterativeAlignerArguments(InferencerArguments):
"dpo_aligner": DPOAlignerArguments,
"rm_tuner": RewardModelTunerArguments,
"dpov2_aligner": DPOv2AlignerArguments,
"iterative_dpo_aligner": IterativeDPOAlignerArguments,
}
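The pair-sampling fields above (sampling_paired_method, margin_scale, length_penalty) suggest how chosen/rejected pairs are built from the scored candidates. The snippet below is a hypothetical sketch of the max_min strategy under those names, not the actual DPOv2 aligner code:

# Hypothetical sketch of "max_min" pair sampling with margin_scale and
# length_penalty; not the actual LMFlow DPOv2 aligner implementation.
from dataclasses import dataclass
from typing import List

@dataclass
class ScoredResponse:
    text: str
    reward: float

def sample_pair_max_min(candidates: List[ScoredResponse],
                        margin_scale: float = 1.0,
                        length_penalty: float = 0.0) -> dict:
    # Optionally penalize longer responses, then pair the best and worst samples.
    adjusted = sorted(
        ((c, c.reward - length_penalty * len(c.text)) for c in candidates),
        key=lambda pair: pair[1],
        reverse=True,
    )
    (chosen, chosen_score), (rejected, rejected_score) = adjusted[0], adjusted[-1]
    margin = margin_scale * (chosen_score - rejected_score)
    return {"chosen": chosen.text, "rejected": rejected.text, "margin": margin}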

