From 66773fd202cbf58ee3ccb6d88ced76572245ef3d Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sun, 18 Aug 2024 17:51:54 +0800
Subject: [PATCH 1/3] cherry pick the internvideo2 part from branch `trsfm`

---
 swift/llm/utils/model.py      | 73 ++++++++++++++++++++++++++++++++++-
 swift/llm/utils/template.py   | 28 +++++++++++++-
 swift/utils/module_mapping.py |  7 ++++
 3 files changed, 105 insertions(+), 3 deletions(-)

diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index 9c5ed983c..d88fcc11a 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -312,6 +312,7 @@ class ModelType:
     internvl2_26b = 'internvl2-26b'
     internvl2_40b = 'internvl2-40b'
     internvl2_llama3_76b = 'internvl2-llama3-76b'
+    internvideo2_chat_8b = 'internvideo2-chat-8b'
     # deepseek
     deepseek_7b = 'deepseek-7b'
     deepseek_7b_chat = 'deepseek-7b-chat'
@@ -526,6 +527,7 @@ class LoRATM(NamedTuple):
     yi_vl = f'{get_regex_for_mm_default_lora("yi_vl")}'
     internlm_xcomposer = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
     internvl = f'{get_regex_for_mm_default_lora("internvl")}'
+    internvideo = f'{get_regex_for_mm_default_lora("internvideo")}'
     deepseek_vl = f'{get_regex_for_mm_default_lora("deepseek_vl")}'
     paligemma = f'{get_regex_for_mm_default_lora("paligemma")}'
     minicpm_v = f'{get_regex_for_mm_default_lora("minicpm_v")}'
@@ -943,7 +945,7 @@ def _new_forward(self, x):
     support_flash_attn=True,
     support_lmdeploy=True,
     hf_model_id='hfl/chinese-alpaca-2-13b-16k')
-def get_model_tokenizer_from_repo(model_dir: str, 
+def get_model_tokenizer_from_repo(model_dir: str,
                                   torch_dtype: Optional[Dtype],
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
@@ -4336,6 +4338,75 @@ def new_get_rank(group=None):
     return model, tokenizer
 
 
+@register_model(
+    ModelType.internvideo2_chat_8b,
+    'OpenGVLab/InternVideo2-Chat-8B',
+    LoRATM.internvideo,
+    TemplateType.internvideo2,
+    requires=['transformers>=4.35', 'timm'],
+    ignore_file_pattern=[r'.+\.zip$'],
+    support_flash_attn=True,
+    support_lmdeploy=True,
+    support_vllm=True,
+    placeholder_tokens=[''],
+    tags=['multi-modal', 'vision'],
+    hf_model_id='OpenGVLab/InternVideo2-Chat-8B')
+def get_model_tokenizer_internvideo(model_dir: str,
+                                    torch_dtype: Dtype,
+                                    model_kwargs: Dict[str, Any],
+                                    load_model: bool = True,
+                                    **kwargs):
+    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    use_flash_attn = kwargs.pop('use_flash_attn', False)
+    model_config.model_config.use_flash_attention = use_flash_attn
+    model_quant_config = getattr(model_config, 'quantization_config', None)
+
+    use_bnb = False
+    if model_quant_config is not None:
+        use_bnb = model_quant_config.get('quant_method', None) == 'bitsandbytes'
+    quantization_config = model_kwargs.get('quantization_config', None)
+    if isinstance(quantization_config, BitsAndBytesConfig):
+        use_bnb = True
+
+    model_config.model_config.llm.pretrained_llm_path = snapshot_download('LLM-Research/Mistral-7B-Instruct-v0.3')
+
+    # replace the from_pretrained of BertTokenizer and BertConfig,
+    # so that the Bert inside InternVideo2_VideoChat2 is loaded from ModelScope instead of HF
+    bert_model_dir = snapshot_download('AI-ModelScope/bert-base-uncased')
+
+    from transformers import BertTokenizer, BertConfig
+    _bert_tokenizer_old_from_pretrained = BertTokenizer.from_pretrained
+    @wraps(_bert_tokenizer_old_from_pretrained)
+    def new_from_pretrained(model_id, *args, **kwargs):
+        if model_id == 'bert-base-uncased':
+            model_id = bert_model_dir
+        return _bert_tokenizer_old_from_pretrained(model_id, *args, **kwargs)
+    BertTokenizer.from_pretrained = new_from_pretrained
+
+    _bert_config_old_from_pretrained = BertConfig.from_pretrained
+    @wraps(_bert_config_old_from_pretrained)
+    def new_from_pretrained(model_id, *args, **kwargs):
+        if model_id == 'bert-base-uncased':
+            model_id = bert_model_dir
+        return _bert_config_old_from_pretrained(model_id, *args, **kwargs)
+    BertConfig.from_pretrained = new_from_pretrained
+
+    # here we don't replace the _no_split_modules of the original model class,
+    # because it contains more than one or two modules
+    # model_cls = get_class_from_dynamic_module('modeling_videochat2.InternVideo2_VideoChat2', model_dir)
+    # model_cls._no_split_modules = []
+    # bert_model_cls = get_class_from_dynamic_module('modeling_qformer.BertLMHeadModel', model_dir)
+    # bert_model_cls._no_split_modules = []
+    model_kwargs['device_map'] = torch.cuda.current_device()
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
+    model, tokenizer = get_model_tokenizer_from_repo(
+        model_dir, torch_dtype, model_kwargs, load_model, tokenizer=tokenizer, model_config=model_config,
+        automodel_class=AutoModel, **kwargs)
+
+    return model, tokenizer
+
+
 @register_model(
     ModelType.internlm_xcomposer2_5_7b_chat,
     'Shanghai_AI_Laboratory/internlm-xcomposer2d5-7b',
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index e3f39f0e7..82423cf09 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -71,6 +71,7 @@ class TemplateType:
     internvl2 = 'internvl2'
     internvl_phi3 = 'internvl-phi3'
    internvl2_phi3 = 'internvl2-phi3'
+    internvideo2 = 'internvideo2'
     florence = 'florence'
     yi = 'yi'
     yi1_5 = 'yi1_5'
@@ -849,7 +850,10 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         else:
             input_ids = [torch.tensor(b['input_ids']) for b in batch]
             attention_mask = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
-        labels = [torch.tensor(b['labels']) for b in batch]
+        try:
+            labels = [torch.tensor(b['labels']) for b in batch]
+        except Exception as e:
+            raise ValueError('label is None') from e
         loss_scale = [torch.tensor(b['loss_scale']) for b in batch] if 'loss_scale' in batch[0] else None
 
         padding_right = self.padding_side == 'right'
@@ -912,7 +916,10 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         # multimodal
         pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
         if len(pixel_values) > 0:
-            res['pixel_values'] = torch.concat(pixel_values)
+            try:
+                res['pixel_values'] = torch.concat(pixel_values)
+            except Exception as e:
+                raise ValueError('pixel_values are not aligned and cannot be concatenated.') from e
 
         image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
         if len(image_sizes) > 0:
@@ -1752,6 +1759,23 @@ def __init__(self):
 register_template(TemplateType.internvl2_phi3, Internvl2Phi3Template(), use_model=True, lazy_tokenize=True)
 
 
+class InternVideo2Template(Internvl2Template):
+    video_segments = 8
+    system = '你是由上海人工智能实验室联合商汤科技开发的多模态大模型, 英文名叫InternVideo, 是一个有用无害的人工智能助手。'
+
+    # def __init__(self):
+    #     Template.__init__(
+    #         self, [], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'], ['<|im_end|>'],
+    #         self.system, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>'],
+    #         auto_add_bos=True)
+    def __init__(self):
+        Template.__init__(self, ['[INST] '], ['{{QUERY}} [/INST]'], [''], [''],
+                          system_prefix=['<>\n{{SYSTEM}}\n<>\n\n'])
+
+
+register_template(TemplateType.internvideo2, InternVideo2Template(), use_model=True, lazy_tokenize=True)
+
+
 class FlorenceTemplate(Template):
 
     def __init__(self):
diff --git a/swift/utils/module_mapping.py b/swift/utils/module_mapping.py
index 7c952e88f..3ba5d15c1 100644
--- a/swift/utils/module_mapping.py
+++ b/swift/utils/module_mapping.py
@@ -212,6 +212,12 @@ class ModelKeys:
     vision_tower='vision_model',
 )
 
+INTERNVIDEO_KEYS = MultiModelKeys(
+    language_model='lm',
+    # projector='',
+    vision_tower='vision_encoder',
+)
+
 DEEPSEEK_VL_KEYS = MultiModelKeys(
     language_model='language_model',
     projector='aligner',
@@ -279,6 +285,7 @@ class ModelKeys:
     ('yi_vl', LLAVA_LLAMA_KEYS),
     ('internlm_xcomposer', INTERNLM_XCOMPOSER_KEYS),
     ('internvl', INTERNVL_KEYS),
+    ('internvideo', INTERNVIDEO_KEYS),
     ('deepseek_vl', DEEPSEEK_VL_KEYS),
     ('paligemma', LLAVA_KEYS),
     ('minicpm_v', MINICPM_V_KEYS),

From 0bcc209a5c311f5b25fbd441dd17f8caec9d3650 Mon Sep 17 00:00:00 2001
From: DaozeZhang
Date: Mon, 19 Aug 2024 15:44:48 +0800
Subject: [PATCH 2/3] modify projector to list

---
 swift/utils/module_mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/utils/module_mapping.py b/swift/utils/module_mapping.py
index 3ba5d15c1..d63ca6cff 100644
--- a/swift/utils/module_mapping.py
+++ b/swift/utils/module_mapping.py
@@ -214,7 +214,7 @@ class ModelKeys:
 
 INTERNVIDEO_KEYS = MultiModelKeys(
     language_model='lm',
-    # projector='',
+    projector=['project_up', 'qformer'],
     vision_tower='vision_encoder',
 )
 

From c9d1f4a52681a29ceecd4287eac55915b238bdd7 Mon Sep 17 00:00:00 2001
From: DaozeZhang
Date: Mon, 19 Aug 2024 20:02:35 +0800
Subject: [PATCH 3/3] update

---
 swift/llm/utils/model.py | 2 ++
 swift/tuners/peft.py     | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index d88fcc11a..fca5a795c 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -4404,6 +4404,8 @@ def new_from_pretrained(model_id, *args, **kwargs):
         model_dir, torch_dtype, model_kwargs, load_model, tokenizer=tokenizer, model_config=model_config,
         automodel_class=AutoModel, **kwargs)
 
+    model.base_model_prefix = 'lm'
+
     return model, tokenizer
 
 
diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py
index cd6f37832..5c5f890c7 100644
--- a/swift/tuners/peft.py
+++ b/swift/tuners/peft.py
@@ -105,7 +105,7 @@ def _create_and_replace_hook2(self, *args, **kwargs):
 
     is_multimodal = getattr(self.model, 'is_multimodal', False)
 
-    if is_multimodal and target and (not any(
+    if is_multimodal and target is not None and (not any(
             [name in target.__class__.__name__.lower() for name in all_supported_names])
             and not any([isinstance(target, type) for type in all_supported_types])):
         return