Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support InternVideo2 #1769

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ class ModelType:
internvl2_26b = 'internvl2-26b'
internvl2_40b = 'internvl2-40b'
internvl2_llama3_76b = 'internvl2-llama3-76b'
internvideo2_chat_8b = 'internvideo2-chat-8b'
# deepseek
deepseek_7b = 'deepseek-7b'
deepseek_7b_chat = 'deepseek-7b-chat'
Expand Down Expand Up @@ -540,6 +541,7 @@ class LoRATM(NamedTuple):
llava = f'{get_regex_for_mm_default_lora("llava")}'
internlm_xcomposer = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
internvl = f'{get_regex_for_mm_default_lora("internvl")}'
internvideo = f'{get_regex_for_mm_default_lora("internvideo")}'
deepseek_vl = f'{get_regex_for_mm_default_lora("deepseek_vl")}'
minicpm_v = f'{get_regex_for_mm_default_lora("minicpm_v")}'
phi3v = f'{get_regex_for_mm_default_lora("phi3v")}'
Expand Down Expand Up @@ -955,7 +957,7 @@ def _new_forward(self, x):
support_flash_attn=True,
support_lmdeploy=True,
hf_model_id='hfl/chinese-alpaca-2-13b-16k')
def get_model_tokenizer_from_repo(model_dir: str,
def get_model_tokenizer_from_repo(model_dir: str,
torch_dtype: Optional[Dtype],
model_kwargs: Dict[str, Any],
load_model: bool = True,
Expand Down Expand Up @@ -4283,6 +4285,77 @@ def get_model_tokenizer_internvl(model_dir: str,
return model, tokenizer


@register_model(
    ModelType.internvideo2_chat_8b,
    'OpenGVLab/InternVideo2-Chat-8B',
    LoRATM.internvideo,
    TemplateType.internvideo2,
    requires=['transformers>=4.35', 'timm'],
    ignore_file_pattern=[r'.+\.zip$'],
    support_flash_attn=True,
    # NOTE(review): vllm/lmdeploy serving of a video-chat remote-code model is
    # unusual -- confirm both backends actually support InternVideo2 before
    # advertising them here.
    support_lmdeploy=True,
    support_vllm=True,
    placeholder_tokens=['<IMG_CONTEXT>'],
    tags=['multi-modal', 'vision'],
    hf_model_id='OpenGVLab/InternVideo2-Chat-8B')
def get_model_tokenizer_internvideo(model_dir: str,
                                    torch_dtype: Dtype,
                                    model_kwargs: Dict[str, Any],
                                    load_model: bool = True,
                                    **kwargs):
    """Load InternVideo2-Chat-8B (remote code) plus its tokenizer.

    Args:
        model_dir: Local directory of the downloaded checkpoint.
        torch_dtype: Dtype to load the weights in.
        model_kwargs: Extra kwargs forwarded to ``from_pretrained``
            (``device_map`` is overwritten below).
        load_model: If False, only the tokenizer is materialized.
        **kwargs: May contain ``use_flash_attn``; the rest is forwarded to
            ``get_model_tokenizer_from_repo``.

    Returns:
        Tuple of ``(model_or_None, tokenizer)``.
    """
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # Flash attention is toggled on the *inner* model_config of the remote code.
    model_config.model_config.use_flash_attention = kwargs.pop('use_flash_attn', False)

    # NOTE(review): the internvl loader computes a `use_bnb` flag here (from
    # `model_config.quantization_config` / a BitsAndBytesConfig in model_kwargs)
    # and applies a bnb-specific training patch; the original copy of this
    # function computed the flag but never used it, so the dead code was
    # dropped -- confirm whether the bnb patch is needed for this model too.

    # The LLM inside InternVideo2 is a separate Mistral checkpoint; point the
    # remote code at a locally downloaded copy instead of its hard-coded id.
    model_config.model_config.llm.pretrained_llm_path = snapshot_download('LLM-Research/Mistral-7B-Instruct-v0.3')

    # The remote code builds its Q-Former via
    # BertTokenizer/BertConfig.from_pretrained('bert-base-uncased').
    # Redirect that id to a ModelScope download so the load works without
    # Hugging Face access.
    bert_model_dir = snapshot_download('AI-ModelScope/bert-base-uncased')

    from transformers import BertTokenizer, BertConfig

    def _redirect_bert_id(old_from_pretrained):
        # Wrap a classmethod-bound `from_pretrained`, swapping the hard-coded
        # hub id for the local ModelScope path. (The original code defined two
        # wrappers that shadowed each other under the same name.)
        @wraps(old_from_pretrained)
        def new_from_pretrained(model_id, *args, **kw):
            if model_id == 'bert-base-uncased':
                model_id = bert_model_dir
            return old_from_pretrained(model_id, *args, **kw)

        return new_from_pretrained

    BertTokenizer.from_pretrained = _redirect_bert_id(BertTokenizer.from_pretrained)
    BertConfig.from_pretrained = _redirect_bert_id(BertConfig.from_pretrained)

    # `_no_split_modules` of the remote model classes is deliberately left
    # untouched: the model contains too many distinct submodules for a useful
    # override, so the whole model is pinned to a single device instead.
    # NOTE(review): this assumes CUDA is available and silently overrides any
    # user-provided device_map -- confirm this is intended.
    model_kwargs['device_map'] = torch.cuda.current_device()

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    model, tokenizer = get_model_tokenizer_from_repo(
        model_dir, torch_dtype, model_kwargs, load_model, tokenizer=tokenizer, model_config=model_config,
        automodel_class=AutoModel, **kwargs)

    # The language model is stored under the `lm` attribute of
    # InternVideo2_VideoChat2; expose that to generic utilities.
    model.base_model_prefix = 'lm'

    return model, tokenizer


@register_model(
ModelType.internlm_xcomposer2_5_7b_chat,
'Shanghai_AI_Laboratory/internlm-xcomposer2d5-7b',
Expand Down
23 changes: 22 additions & 1 deletion swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class TemplateType:
internvl2 = 'internvl2'
internvl_phi3 = 'internvl-phi3'
internvl2_phi3 = 'internvl2-phi3'
internvideo2 = 'internvideo2'
florence = 'florence'
yi = 'yi'
yi1_5 = 'yi1_5'
Expand Down Expand Up @@ -963,7 +964,10 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
# multimodal
pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
if len(pixel_values) > 0:
res['pixel_values'] = torch.concat(pixel_values)
try:
res['pixel_values'] = torch.concat(pixel_values)
except:
raise ValueError('pixel_values is not aligned, cannot be concat.')

image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
if len(image_sizes) > 0:
Expand Down Expand Up @@ -1840,6 +1844,23 @@ class Internvl2Phi3Template(InternvlPhi3TemplateMixin, Internvl2Template):
register_template(TemplateType.internvl2_phi3, Internvl2Phi3Template(), use_model=True, lazy_tokenize=True)


class InternVideo2Template(Internvl2Template):
    """Chat template for InternVideo2-Chat-8B.

    The underlying LLM is Mistral-7B-Instruct, so the dialogue format follows
    the Llama2/Mistral ``[INST]`` convention rather than the ChatML format
    used by the other InternVL2 templates.
    """
    # Number of segments a video is sampled into (inherited pipeline from
    # Internvl2Template uses this).
    video_segments = 8
    system = '你是由上海人工智能实验室联合商汤科技开发的多模态大模型, 英文名叫InternVideo, 是一个有用无害的人工智能助手。'

    def __init__(self):
        # Fix: the system placeholder must be the uppercase ``{{SYSTEM}}``
        # (cf. ``{{QUERY}}`` below) -- the lowercase ``{{system}}`` used
        # previously is not a recognized placeholder, so the system prompt
        # would have been emitted literally instead of substituted.
        Template.__init__(self, ['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s>'], ['</s>'],
                          system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n'])


register_template(TemplateType.internvideo2, InternVideo2Template(), use_model=True, lazy_tokenize=True)


class FlorenceTemplate(Template):

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion swift/tuners/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _create_and_replace_hook2(self, *args, **kwargs):

is_multimodal = getattr(self.model, 'is_multimodal', False)

if is_multimodal and target and (not any(
if is_multimodal and target is not None and (not any(
[name in target.__class__.__name__.lower()
for name in all_supported_names]) and not any([isinstance(target, type) for type in all_supported_types])):
return
Expand Down
7 changes: 7 additions & 0 deletions swift/utils/module_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ class ModelKeys:
vision_tower='vision_model',
)

# LoRA/freeze grouping for InternVideo2: the wrapped LLM lives under `lm`,
# and the Q-Former plus `project_up` bridge the vision encoder to it.
# NOTE(review): attribute names must match the remote modeling code
# (modeling_videochat2) -- confirm against the checkpoint.
INTERNVIDEO_KEYS = MultiModelKeys(
    language_model='lm',
    projector=['project_up', 'qformer'],
    vision_tower='vision_encoder',
)

DEEPSEEK_VL_KEYS = MultiModelKeys(
language_model='language_model',
projector='aligner',
Expand Down Expand Up @@ -283,6 +289,7 @@ class ModelKeys:
('llava', LLAVA_KEYS),
('internlm_xcomposer', INTERNLM_XCOMPOSER_KEYS),
('internvl', INTERNVL_KEYS),
('internvideo', INTERNVIDEO_KEYS),
('deepseek_vl', DEEPSEEK_VL_KEYS),
('minicpm_v', MINICPM_V_KEYS),
('phi3v', PHI3V_KEYS),
Expand Down
Loading