Skip to content

Commit

Permalink
fix bugs (#2954)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Jan 22, 2025
1 parent c34b128 commit 7f0d9d6
Show file tree
Hide file tree
Showing 16 changed files with 56 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
- 🔥stream: 流式输出,默认为`False`
- stop_words: 额外的停止词,默认为`[]`
- logprobs: 是否输出logprobs,默认为False
- top_logprobs: 每个输出token返回的top logprobs的数量,默认为None

### 量化参数
以下为拉起模型时量化的参数,具体含义可以查看[量化](https://huggingface.co/docs/transformers/main/en/main_classes/quantization)文档。这里不包含`swift export`中涉及的`gptq`、`awq`量化参数
Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Refer to the [generation_config](https://huggingface.co/docs/transformers/main_c
- 🔥stream: Stream output, default is `False`.
- stop_words: Additional stop words, default is `[]`.
- logprobs: Whether to output logprobs, default is False.
- top_logprobs: The number of top logprobs to return for each output token, default is `None`.

### Quantization Arguments

Expand Down
3 changes: 2 additions & 1 deletion examples/custom/sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ swift sft \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--max_length 2048 \
--output_dir output
--output_dir output \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/cpo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/dpo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/kto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/orpo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/ppo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ swift rlhf \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2 \
--response_length 512
--response_length 512 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/rm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
3 changes: 2 additions & 1 deletion examples/train/rlhf/simpo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ swift rlhf \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero2
--deepspeed zero2 \
--dataset_num_proc 4
4 changes: 3 additions & 1 deletion swift/llm/argument/base_args/generation_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class GenerationArguments:
stream: bool = False
stop_words: List[str] = field(default_factory=list)
logprobs: bool = False
top_logprobs: Optional[int] = None

def get_request_config(self):
if getattr(self, 'task_type') != 'causal_lm':
Expand All @@ -50,4 +51,5 @@ def get_request_config(self):
stop=self.stop_words,
stream=self.stream,
repetition_penalty=self.repetition_penalty,
logprobs=self.logprobs)
logprobs=self.logprobs,
top_logprobs=self.top_logprobs)
9 changes: 6 additions & 3 deletions swift/llm/infer/infer_engine/lmdeploy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployG
if request_config.top_logprobs is not None:
kwargs['logprobs'] = max(1, request_config.top_logprobs)

return LmdeployGenerationConfig(**kwargs)
res = LmdeployGenerationConfig(**kwargs)
res.top_logprobs = request_config.top_logprobs
return res

async def _infer_stream_async(
self, template: Template, inputs: Dict[str, Any],
Expand All @@ -204,7 +206,8 @@ async def _infer_stream_async(
if not delta_text and not is_finished:
continue

logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:], generation_config.logprobs)
logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:],
generation_config.top_logprobs)
token_idx = len(output.token_ids)

usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
Expand Down Expand Up @@ -233,7 +236,7 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
pass

response = template.decode(output.token_ids)
logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.logprobs)
logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs)

usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
toolcall = self._get_toolcall(response, template.tools_prompt)
Expand Down
12 changes: 8 additions & 4 deletions swift/llm/infer/infer_engine/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def _get_logprobs(self,
logprobs_list: Optional[List[Dict[int, float]]],
token_ids: List[int],
top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]:
if logprobs_list is None:
if logprobs_list is None or len(token_ids) == 0:
return None
if len(token_ids) > 0:
logprobs_list = logprobs_list[-len(token_ids):]
for logprobs in logprobs_list:
for token_id, logprob in logprobs.items():
logprobs[token_id] = logprob.logprob
Expand All @@ -244,7 +246,9 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP
for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']:
kwargs[key] = getattr(request_config, key)

return SamplingParams(**kwargs)
res = SamplingParams(**kwargs)
res.top_logprobs = request_config.top_logprobs
return res

async def _infer_stream_async(self, template: Template, inputs: Dict[str, Any], generation_config: SamplingParams,
**kwargs) -> AsyncIterator[ChatCompletionStreamResponse]:
Expand All @@ -271,7 +275,7 @@ async def _infer_stream_async(self, template: Template, inputs: Dict[str, Any],
choices = []
for output in result.outputs:
logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idxs[output.index]:],
generation_config.logprobs)
generation_config.top_logprobs)
token_idxs[output.index] = len(output.token_ids)
toolcall = None
if output.is_finished:
Expand Down Expand Up @@ -301,7 +305,7 @@ async def _infer_full_async(self,
for output in result.outputs:
output.token_ids = list(output.token_ids)
response = template.decode(output.token_ids)
logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.logprobs)
logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs)
toolcall = self._get_toolcall(response, template.tools_prompt)
choice = ChatCompletionResponseChoice(
index=output.index,
Expand Down
17 changes: 17 additions & 0 deletions swift/llm/model/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,20 @@ def __new_init__(self, *args, **kwargs):
yield
finally:
PreTrainedModel.from_pretrained = from_pretrained


@contextmanager
def patch_automodel_for_awq():
    """Temporarily make ``PreTrainedModel.from_pretrained`` ignore ``use_cache``.

    While the context is active, ``PreTrainedModel.from_pretrained`` is replaced by a
    wrapper that drops any ``use_cache`` keyword argument and forwards every other
    argument unchanged to the original implementation. The original classmethod is
    reinstated on exit, including when the body of the ``with`` statement raises.

    NOTE(review): presumably needed because AWQ loaders pass ``use_cache`` down to
    ``from_pretrained``, which does not accept it -- confirm against the caller.
    """
    # Keep the underlying function rather than the bound method. Assigning a bound
    # method back onto the class would pin ``cls`` to ``PreTrainedModel`` for every
    # subclass that calls ``from_pretrained`` after this context exits.
    from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _new_from_pretrained(cls, *args, **kwargs):
        kwargs.pop('use_cache', None)
        return from_pretrained(cls, *args, **kwargs)

    PreTrainedModel.from_pretrained = _new_from_pretrained

    try:
        yield
    finally:
        # Re-wrap so subclasses are bound to themselves again on attribute access.
        PreTrainedModel.from_pretrained = classmethod(from_pretrained)
6 changes: 4 additions & 2 deletions swift/llm/model/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from swift.utils import get_dist_setting, get_logger, is_mp, is_unsloth_available, patch_getattr, use_torchacc
from .constant import ModelType
from .patcher import patch_automodel_for_sequence_classification
from .patcher import patch_automodel_for_awq, patch_automodel_for_sequence_classification
from .utils import AttnImpl, HfConfigFactory, ModelInfo, safe_snapshot_download

GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
Expand Down Expand Up @@ -200,13 +200,15 @@ def get_model_tokenizer_from_local(model_dir: str,
except ValueError:
model = None

automodel_class = automodel_class or AutoModelForCausalLM
if model is None:
if model_info.task_type == 'seq_cls':
context = partial(patch_automodel_for_sequence_classification, model_meta=kwargs['model_meta'])
elif 'AutoAWQFor' in automodel_class.__name__:
context = patch_automodel_for_awq
else:
context = nullcontext
with context():
automodel_class = automodel_class or AutoModelForCausalLM
model = automodel_class.from_pretrained(
model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs)

Expand Down
File renamed without changes.

0 comments on commit 7f0d9d6

Please sign in to comment.