Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Internvl2TemplateWithAngles template #1784

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ python llm_infer.py \
--repetition_penalty 1. \
--do_sample true \
--merge_lora false
swdf
2 changes: 1 addition & 1 deletion swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def new_func(self, *args, **kwargs):

standard_keys = {
'query', 'query_role', 'response', 'rejected_response', 'system', 'history', 'history_roles', 'images', 'objects',
'videos', 'audios', 'tools', 'label'
'videos', 'audios', 'tools', 'label', 'angles'
}


Expand Down
47 changes: 47 additions & 0 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class TemplateType:
internlm_xcomposer2_5 = 'internlm-xcomposer2_5'
internvl = 'internvl'
internvl2 = 'internvl2'
internvl2_angle = 'internvl2-angle'
internvl_phi3 = 'internvl-phi3'
internvl2_phi3 = 'internvl2-phi3'
florence = 'florence'
Expand Down Expand Up @@ -2524,6 +2525,48 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs.pop('loss_scale', None)
return inputs, {}

class Internvl2TemplateWithAngles(Internvl2Template):
def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, _ = super(InternvlTemplate, self)._encode(example)
if len(inputs) == 0:
return inputs, {}
input_ids = inputs['input_ids']
idx_list = _findall(input_ids, -100)
labels = inputs.get('labels')
images = example.get('images')
angles = example.get('angles')
if images:
has_video = bool(example.get('videos'))
if angles:
assert len(images) == len(angles), "len(angles) must be equal to len(images)!"
images = [
image.rotate(360-int(angle), expand=True) if int(angle) != 0 else image
for image, angle in zip(images, angles)
]
pixel_values = [transform_image(image, max_num=1 if has_video else 12) for image in images]
num_patches = [pv.shape[0] for pv in pixel_values]
pixel_values = torch.cat(pixel_values).to(self.model.dtype)
else:
pixel_values = None
num_patches = []
assert len(num_patches) == len(
idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
added_tokens_len = 0
for idx, num_patch in zip(idx_list, num_patches):
img_tokens: List[int] = self.tokenizer.encode(
'<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patch
input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
if labels is not None:
labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) + labels[idx + added_tokens_len
+ 1:]
added_tokens_len += len(img_tokens) - 1
inputs['input_ids'] = input_ids
inputs['labels'] = labels
inputs['_data'] = {'input_ids': torch.tensor(input_ids), 'pixel_values': pixel_values}
inputs.pop('loss_scale', None)
return inputs, {}



class InternvlPhi3TemplateMixin:

Expand All @@ -2543,6 +2586,7 @@ class Internvl2Phi3Template(InternvlPhi3TemplateMixin, Internvl2Template):
pass



register_template(
TemplateType.internvl, InternvlTemplate(), use_model=True, lazy_tokenize=True, infer_media_type='dialogue')

Expand All @@ -2551,6 +2595,9 @@ class Internvl2Phi3Template(InternvlPhi3TemplateMixin, Internvl2Template):

register_template(TemplateType.internvl2, Internvl2Template(), use_model=True, lazy_tokenize=True)

# 支持图片角度输入
register_template(TemplateType.internvl2_angle, Internvl2TemplateWithAngles(), use_model=True, lazy_tokenize=True)

register_template(TemplateType.internvl2_phi3, Internvl2Phi3Template(), use_model=True, lazy_tokenize=True)


Expand Down