Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix max_length error print #2960

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/export/quantize/mllm/awq.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ swift export \
--max_length 2048 \
--quant_method awq \
--quant_bits 4 \
- --output_dir Qwen/Qwen2-VL-2B-Instruct-AWQ
+ --output_dir Qwen2-VL-2B-Instruct-AWQ
2 changes: 1 addition & 1 deletion examples/export/quantize/mllm/gptq.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ swift export \
--max_length 2048 \
--quant_method gptq \
--quant_bits 4 \
- --output_dir Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
+ --output_dir Qwen2-VL-2B-Instruct-GPTQ-Int4
20 changes: 15 additions & 5 deletions swift/llm/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def _fix_streaming_keys(row):
new_k = k[len('__@'):]
row[new_k] = row.pop(k)

- def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
+ ignore_max_length_error: bool) -> Dict[str, Any]:
+ from ...template import MaxLengthError
batched_row = dict(batched_row)
assert len(batched_row) > 0
self._fix_streaming_keys(batched_row)
Expand All @@ -162,13 +164,15 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di
self._check_messages(r)
self._check_rejected_response(r)
self._cast_images(r)
- except Exception:
+ except Exception as e:
if strict:
logger.warning('To avoid errors, you can pass `strict=False`.')
raise
- if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
+ if isinstance(e, MaxLengthError) and ignore_max_length_error:
+ pass
+ elif self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
import traceback
- print(traceback.format_exc())
+ logger.info(traceback.format_exc())
logger.error('👆👆👆There are errors in the dataset, the data will be deleted')
self._traceback_counter += 1
row = []
Expand Down Expand Up @@ -256,15 +260,21 @@ def __call__(
dataset = self.prepare_dataset(dataset)
dataset = self._cast_pil_image(dataset)
map_kwargs = {}
+ ignore_max_length_error = False
if isinstance(dataset, HfDataset):
map_kwargs['num_proc'] = num_proc
+ if num_proc > 1:
+ ignore_max_length_error = True
with self._patch_arrow_writer():
try:
dataset_mapped = dataset.map(
self.batched_preprocess,
batched=True,
batch_size=batch_size,
- fn_kwargs={'strict': strict},
+ fn_kwargs={
+ 'strict': strict,
+ 'ignore_max_length_error': ignore_max_length_error
+ },
remove_columns=list(dataset.features.keys()),
**map_kwargs)
except NotImplementedError:
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
raise
if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
import traceback
- print(traceback.format_exc())
+ logger.info(traceback.format_exc())
logger.error('👆👆👆There are errors in the template.encode, '
'and another piece of data will be randomly selected.')
self._traceback_counter += 1
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/infer/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def pre_infer_hook(kwargs):
res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
except Exception as e:
import traceback
- print(traceback.format_exc())
+ logger.info(traceback.format_exc())
return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
if request_config.stream:

Expand Down
Loading