fix max_length error print (#2960)
Jintao-Huang authored Jan 22, 2025
1 parent db01dea commit 23559d4
Showing 5 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/export/quantize/mllm/awq.sh
@@ -8,4 +8,4 @@ swift export \
     --max_length 2048 \
     --quant_method awq \
     --quant_bits 4 \
-    --output_dir Qwen/Qwen2-VL-2B-Instruct-AWQ
+    --output_dir Qwen2-VL-2B-Instruct-AWQ
2 changes: 1 addition & 1 deletion examples/export/quantize/mllm/gptq.sh
@@ -15,4 +15,4 @@ swift export \
     --max_length 2048 \
     --quant_method gptq \
     --quant_bits 4 \
-    --output_dir Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
+    --output_dir Qwen2-VL-2B-Instruct-GPTQ-Int4
20 changes: 15 additions & 5 deletions swift/llm/dataset/preprocessor/core.py
@@ -143,7 +143,9 @@ def _fix_streaming_keys(row):
             new_k = k[len('__@'):]
             row[new_k] = row.pop(k)

-    def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+    def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
+                           ignore_max_length_error: bool) -> Dict[str, Any]:
+        from ...template import MaxLengthError
         batched_row = dict(batched_row)
         assert len(batched_row) > 0
         self._fix_streaming_keys(batched_row)
@@ -162,13 +164,15 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
                 self._check_messages(r)
                 self._check_rejected_response(r)
                 self._cast_images(r)
-            except Exception:
+            except Exception as e:
                 if strict:
                     logger.warning('To avoid errors, you can pass `strict=False`.')
                     raise
-                if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
+                if isinstance(e, MaxLengthError) and ignore_max_length_error:
+                    pass
+                elif self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
                     import traceback
-                    print(traceback.format_exc())
+                    logger.info(traceback.format_exc())
                     logger.error('👆👆👆There are errors in the dataset, the data will be deleted')
                     self._traceback_counter += 1
                 row = []
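This hunk carries the substance of the fix: rows that fail with MaxLengthError are silently dropped when ignore_max_length_error is set, and every other failure is now reported through logger.info instead of print, still capped by traceback_limit. Below is a condensed, self-contained sketch of that branch logic; MaxLengthError is stubbed out and the per-row checks are replaced by a caller-supplied function, so everything except the branching and the log messages is illustrative rather than swift's actual code.

# Condensed sketch of the new error handling (illustrative names; only the
# branch structure and log messages come from the diff above).
import logging
import traceback
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class MaxLengthError(Exception):
    """Stand-in for the MaxLengthError imported from swift's template package."""


def preprocess_rows(rows: List[Dict[str, Any]],
                    check_row: Callable[[Dict[str, Any]], Dict[str, Any]],
                    *,
                    strict: bool,
                    ignore_max_length_error: bool,
                    traceback_limit: Optional[int] = 10) -> List[Dict[str, Any]]:
    kept: List[Dict[str, Any]] = []
    printed = 0
    for row in rows:
        try:
            kept.append(check_row(row))
        except Exception as e:
            if strict:
                logger.warning('To avoid errors, you can pass `strict=False`.')
                raise
            if isinstance(e, MaxLengthError) and ignore_max_length_error:
                pass  # over-long samples are dropped without any logging
            elif traceback_limit is not None and printed < traceback_limit:
                logger.info(traceback.format_exc())  # was print() before this commit
                logger.error('There are errors in the dataset, the data will be deleted')
                printed += 1
            # the failing row is simply not appended, i.e. it is deleted
    return kept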
@@ -256,15 +260,21 @@ def __call__(
         dataset = self.prepare_dataset(dataset)
         dataset = self._cast_pil_image(dataset)
         map_kwargs = {}
+        ignore_max_length_error = False
         if isinstance(dataset, HfDataset):
             map_kwargs['num_proc'] = num_proc
+            if num_proc > 1:
+                ignore_max_length_error = True
         with self._patch_arrow_writer():
             try:
                 dataset_mapped = dataset.map(
                     self.batched_preprocess,
                     batched=True,
                     batch_size=batch_size,
-                    fn_kwargs={'strict': strict},
+                    fn_kwargs={
+                        'strict': strict,
+                        'ignore_max_length_error': ignore_max_length_error
+                    },
                     remove_columns=list(dataset.features.keys()),
                     **map_kwargs)
             except NotImplementedError:
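fn_kwargs is how Dataset.map forwards extra keyword arguments to the mapping function, which is what lets the flag computed above (only True when num_proc > 1) reach every worker process. A toy sketch of the same wiring follows, assuming only the Hugging Face datasets library; the batch function and data are made up and stand in for swift's real preprocessor.

# Toy illustration of threading a flag into a batched map function via fn_kwargs
# (assumes the Hugging Face `datasets` library; data and function are invented).
from datasets import Dataset


def batched_filter(batch, *, strict: bool, ignore_max_length_error: bool):
    # `strict` is accepted only to mirror the diff; this toy does not use it.
    kept = []
    for m in batch['messages']:
        if len(m) > 16:
            if not ignore_max_length_error:
                raise ValueError(f'message too long: {len(m)} > 16')
            continue  # silently drop over-long samples, like MaxLengthError above
        kept.append(m)
    return {'messages': kept}


if __name__ == '__main__':
    dataset = Dataset.from_dict({'messages': ['short', 'a rather long message indeed']})
    num_proc = 2
    ignore_max_length_error = num_proc > 1  # mirror the new behaviour above

    mapped = dataset.map(
        batched_filter,
        batched=True,
        batch_size=1000,
        fn_kwargs={
            'strict': False,
            'ignore_max_length_error': ignore_max_length_error
        },
        remove_columns=list(dataset.features.keys()),
        num_proc=num_proc)
    print(mapped['messages'])  # -> ['short']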
2 changes: 1 addition & 1 deletion swift/llm/dataset/utils.py
@@ -183,7 +183,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
                 raise
             if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
                 import traceback
-                print(traceback.format_exc())
+                logger.info(traceback.format_exc())
                 logger.error('👆👆👆There are errors in the template.encode, '
                              'and another piece of data will be randomly selected.')
                 self._traceback_counter += 1
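The hunk only shows the logging branch; judging from the error message, the surrounding __getitem__ lazily encodes a sample and falls back to a randomly chosen one when encoding fails. A hedged sketch of that retry pattern follows, in which everything except the two logging calls is invented for illustration.

# Hedged sketch of a lazily-encoding dataset that retries on bad samples
# (illustrative class; only the two logging lines come from the diff above).
import logging
import random
import traceback
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class LazyEncodeDataset:

    def __init__(self, raw_rows: List[Dict[str, Any]],
                 encode_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
                 traceback_limit: Optional[int] = 10) -> None:
        self.raw_rows = raw_rows
        self.encode_fn = encode_fn  # plays the role of template.encode
        self.traceback_limit = traceback_limit
        self._traceback_counter = 0

    def __len__(self) -> int:
        return len(self.raw_rows)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        for _ in range(len(self)):  # bound the number of retries
            try:
                return self.encode_fn(self.raw_rows[idx])
            except Exception:
                if (self.traceback_limit is not None
                        and self._traceback_counter < self.traceback_limit):
                    logger.info(traceback.format_exc())  # was print() before this commit
                    logger.error('There are errors in the template.encode, '
                                 'and another piece of data will be randomly selected.')
                    self._traceback_counter += 1
                idx = random.randrange(len(self))  # retry with a random sample
        raise RuntimeError('Failed to encode any sample.')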
2 changes: 1 addition & 1 deletion swift/llm/infer/deploy.py
@@ -159,7 +159,7 @@ def pre_infer_hook(kwargs):
             res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
         except Exception as e:
             import traceback
-            print(traceback.format_exc())
+            logger.info(traceback.format_exc())
             return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
         if request_config.stream:
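The same print-to-logger change on the deploy side: the traceback now lands in the service log while the client still receives a 400 response. A minimal sketch of that error path under an asyncio-style handler follows; create_error_response below is a simplified stand-in for the method referenced in the diff, and all other names are assumptions.

# Minimal sketch of the deploy error path (assumed names; only the logging call
# and the HTTPStatus.BAD_REQUEST response shape come from the diff above).
import asyncio
import logging
import traceback
from http import HTTPStatus
from typing import Any, Dict

logger = logging.getLogger(__name__)


def create_error_response(status: HTTPStatus, message: str) -> Dict[str, Any]:
    # Simplified stand-in for the server's real error-response helper.
    return {'error': {'code': int(status), 'message': message}}


async def handle_request(infer_async, infer_request, request_config, **infer_kwargs):
    try:
        return await infer_async(infer_request, request_config, **infer_kwargs)
    except Exception as e:
        logger.info(traceback.format_exc())  # was print() before this commit
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))


if __name__ == '__main__':
    async def failing_backend(req, cfg, **kwargs):
        raise ValueError('max_length exceeded')

    print(asyncio.run(handle_request(failing_backend, {}, {})))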
