Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jan 27, 2025
1 parent 80b9a74 commit 82a6274
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/source/Customization/自定义数据集.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ RLHF的数据格式可以参考纯文本大模型的格式。
该格式比通用格式多了objects字段,该字段包含的字段有:
- ref:用于替换`<ref-object>`
- bbox:用于替换`<bbox>`
- bbox_type: 可选项为'real','norm1'。默认为real,即bbox为真实bbox值。若是'norm1',则bbox已经归一化为0~1
- image_id: 该参数只有当bbox_type为real时生效。代表bbox对应的图片是第几张,用于缩放bbox。索引从0开始,默认全为第0张
- bbox_type: 可选项为'real','norm1'。默认为'real',即bbox为真实bbox值。若是'norm1',则bbox已经归一化为0~1
- image_id: 该参数只有当bbox_type为'real'时生效。代表bbox对应的图片是第几张,用于缩放bbox。索引从0开始,默认全为第0张

### 文生图格式

Expand Down
29 changes: 16 additions & 13 deletions swift/llm/dataset/dataset/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,12 +777,10 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
bbox[i] = round(float(bbox[i]))
res = {}

objects = [{
'caption': caption,
'bbox': bbox,
'bbox_type': 'real',
'image': 0,
}]
objects = {
'ref': [caption],
'bbox': [bbox],
}
res['query'], res['response'] = self.construct_grounding_prompt()
res['images'] = [image_path]
res['objects'] = objects
Expand Down Expand Up @@ -996,10 +994,14 @@ def replace_intervals_with_tags(response, start_ends):
return ''.join(result)

def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
images = row['url']
images = row['images']
caption = row['caption']
ref_exps = row['ref_exps']
objects = []
objects = {
'ref': [],
'bbox': [],
'bbox_type': 'norm1'
}
start_end_pairs = []
for ref_exp in ref_exps:
start = ref_exp[0]
Expand All @@ -1008,10 +1010,11 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
start_end_pairs.append(ref_exp[0:2])

object_part = caption[int(start):int(end)]
objects.append({'caption': object_part, 'bbox': ref_exp[2:6], 'bbox_type': 'real', 'image': 0})
objects['ref'].append(object_part)
objects['bbox'].append(ref_exp[2:6])

start_end_pairs.sort(key=lambda x: (x[0], x[1]))
if self.has_overlap(start_end_pairs) or not objects:
if self.has_overlap(start_end_pairs) or not ref_exps:
return

if self.task_type in ('grounding', 'caption'):
Expand All @@ -1038,15 +1041,15 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
hf_dataset_id='zzliang/GRIT',
subsets=[
SubsetDataset(
subset='caption',
name='caption',
preprocess_func=GritPreprocessor('caption', columns_mapping={'url': 'images'}),
),
SubsetDataset(
subset='grounding',
name='grounding',
preprocess_func=GritPreprocessor('grounding', columns_mapping={'url': 'images'}),
),
SubsetDataset(
subset='vqa',
name='vqa',
preprocess_func=GritPreprocessor('vqa', columns_mapping={'url': 'images'}),
)
],
Expand Down
6 changes: 5 additions & 1 deletion swift/llm/template/grounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ def normalize_bbox(images: List[Image.Image],
return
bbox_list = objects['bbox']
ref_list = objects['ref']
bbox_type = objects.get('bbox_type') or 'real'
image_id_list = objects.get('image_id') or []
image_id_list += [0] * (len(ref_list) - len(image_id_list))
for bbox, ref, image_id in zip(bbox_list, ref_list, image_id_list):
image = images[image_id]
if norm_bbox == 'norm1000':
width, height = image.width, image.height
if bbox_type == 'norm1':
width, height = 1, 1
else:
width, height = image.width, image.height
for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
bbox[2 * i] = int(x / width * 1000)
bbox[2 * i + 1] = int(y / height * 1000)

0 comments on commit 82a6274

Please sign in to comment.