From 31376c59241c97311bc8cfed08be60991d96de7a Mon Sep 17 00:00:00 2001
From: congw729 <115451386+congw729@users.noreply.github.com>
Date: Thu, 26 Sep 2024 15:02:21 +0800
Subject: [PATCH] ShareGPT4v infer codes. Tested on 910B with MS2.3.1.
---
examples/sharegpt_4v/readme.md | 13 +
examples/sharegpt_4v/share4v/__init__.py | 1 +
.../sharegpt_4v/share4v/configs/config.json | 46 +
.../share4v/configs/tokenizer_config.json | 36 +
.../share4v/configs/vit/config.json | 23 +
.../configs/vit/preprocessor_config.json | 29 +
examples/sharegpt_4v/share4v/constants.py | 12 +
examples/sharegpt_4v/share4v/conversation.py | 369 ++++++
examples/sharegpt_4v/share4v/mm_utils.py | 105 ++
.../sharegpt_4v/share4v/model/__init__.py | 3 +
examples/sharegpt_4v/share4v/model/builder.py | 74 ++
.../model/language_model/share4v_llama.py | 423 +++++++
.../model/multimodal_encoder/builder.py | 15 +
.../model/multimodal_encoder/clip_encoder.py | 116 ++
.../model/multimodal_projector/builder.py | 57 +
.../sharegpt_4v/share4v/model/share4v_arch.py | 352 +++++
examples/sharegpt_4v/share4v/model/utils.py | 20 +
.../sharegpt_4v/share4v/pipeline/__init__.py | 1 +
.../share4v/pipeline/helpers/__init__.py | 1 +
.../pipeline/helpers/stopping_criteria.py | 56 +
.../share4v/pipeline/text_generation.py | 275 ++++
.../share4v/transformers/__init__.py | 9 +
.../share4v/transformers/activations_ms.py | 218 ++++
.../share4v/transformers/modeling_ms_utils.py | 1127 +++++++++++++++++
.../share4v/transformers/models/__init__.py | 2 +
.../share4v/transformers/models/cache.py | 67 +
.../transformers/models/clip/__init__.py | 7 +
.../models/clip/modeling_ms_clip.py | 965 ++++++++++++++
.../transformers/models/llama/__init__.py | 1 +
.../models/llama/modeling_ms_llama.py | 633 +++++++++
30 files changed, 5056 insertions(+)
create mode 100644 examples/sharegpt_4v/readme.md
create mode 100644 examples/sharegpt_4v/share4v/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/configs/config.json
create mode 100644 examples/sharegpt_4v/share4v/configs/tokenizer_config.json
create mode 100644 examples/sharegpt_4v/share4v/configs/vit/config.json
create mode 100644 examples/sharegpt_4v/share4v/configs/vit/preprocessor_config.json
create mode 100644 examples/sharegpt_4v/share4v/constants.py
create mode 100644 examples/sharegpt_4v/share4v/conversation.py
create mode 100644 examples/sharegpt_4v/share4v/mm_utils.py
create mode 100644 examples/sharegpt_4v/share4v/model/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/model/builder.py
create mode 100644 examples/sharegpt_4v/share4v/model/language_model/share4v_llama.py
create mode 100644 examples/sharegpt_4v/share4v/model/multimodal_encoder/builder.py
create mode 100644 examples/sharegpt_4v/share4v/model/multimodal_encoder/clip_encoder.py
create mode 100644 examples/sharegpt_4v/share4v/model/multimodal_projector/builder.py
create mode 100644 examples/sharegpt_4v/share4v/model/share4v_arch.py
create mode 100644 examples/sharegpt_4v/share4v/model/utils.py
create mode 100644 examples/sharegpt_4v/share4v/pipeline/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/pipeline/helpers/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/pipeline/helpers/stopping_criteria.py
create mode 100644 examples/sharegpt_4v/share4v/pipeline/text_generation.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/activations_ms.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/modeling_ms_utils.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/cache.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/clip/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/clip/modeling_ms_clip.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/llama/__init__.py
create mode 100644 examples/sharegpt_4v/share4v/transformers/models/llama/modeling_ms_llama.py
diff --git a/examples/sharegpt_4v/readme.md b/examples/sharegpt_4v/readme.md
new file mode 100644
index 0000000000..b01d09ae1d
--- /dev/null
+++ b/examples/sharegpt_4v/readme.md
@@ -0,0 +1,13 @@
+# ShareGPT4V: Improving Large Multi-modal Models with Better Captions
+
+[Paper](!https://arxiv.org/pdf/2311.12793.pdf)
+
+[Official Repo](!https://github.com/ShareGPT4Omni/ShareGPT4V)
+
+[Image](!https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/teaser.png)
+
+
+## Inference
+
+
+1. Prepare weight files:
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/__init__.py b/examples/sharegpt_4v/share4v/__init__.py
new file mode 100644
index 0000000000..cc5dd007d3
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/__init__.py
@@ -0,0 +1 @@
+from .model import Share4VLlamaForCausalLM
diff --git a/examples/sharegpt_4v/share4v/configs/config.json b/examples/sharegpt_4v/share4v/configs/config.json
new file mode 100644
index 0000000000..032967f6e5
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/configs/config.json
@@ -0,0 +1,46 @@
+{
+ "_name_or_path": "MS_ShareGPT4V-7B",
+ "architectures": [
+ "Share4VLlamaForCausalLM"
+ ],
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "freeze_mm_mlp_adapter": false,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "image_aspect_ratio": "pad",
+ "image_grid_pinpoints": null,
+ "initializer_range": 0.02,
+ "intermediate_size": 11008,
+ "max_position_embeddings": 4096,
+ "mm_hidden_size": 1024,
+ "mm_projector_lr": null,
+ "mm_projector_type": "mlp2x_gelu",
+ "mm_use_im_patch_token": false,
+ "mm_use_im_start_end": false,
+ "mm_vision_select_feature": "patch",
+ "mm_vision_select_layer": -2,
+ "mm_vision_tower": "/root/congw/project/ms_ShareGPT4V/share4v/configs/vit/",
+ "mm_vision_tower_path":"/root/congw/project/ms_ShareGPT4V/share4v/configs/vit/vit-large336-l12.ckpt",
+ "model_type": "share4v",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
+ "pad_token_id": 0,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": null,
+ "tie_word_embeddings": false,
+ "dtype": "float32",
+ "transformers_version": "4.31.0",
+ "tune_entire_model": false,
+ "tune_mm_mlp_adapter": false,
+ "tune_vision_tower": false,
+ "use_cache": true,
+ "use_mm_proj": true,
+ "vision_tower_lr": null,
+ "vocab_size": 32000,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "use_return_dict": true
+}
diff --git a/examples/sharegpt_4v/share4v/configs/tokenizer_config.json b/examples/sharegpt_4v/share4v/configs/tokenizer_config.json
new file mode 100644
index 0000000000..8455d61094
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/configs/tokenizer_config.json
@@ -0,0 +1,36 @@
+{
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "bos_token": {
+ "__type": "AddedToken",
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "clean_up_tokenization_spaces": false,
+ "eos_token": {
+ "__type": "AddedToken",
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "legacy": false,
+ "model_max_length": 2048,
+ "pad_token": null,
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": {
+ "__type": "AddedToken",
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
+
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/configs/vit/config.json b/examples/sharegpt_4v/share4v/configs/vit/config.json
new file mode 100644
index 0000000000..99620410b3
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/configs/vit/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "ShareGPT4V-7B_Pretrained_vit-large336-l12",
+ "architectures": [
+ "CLIPVisionModel"
+ ],
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 336,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.31.0"
+}
diff --git a/examples/sharegpt_4v/share4v/configs/vit/preprocessor_config.json b/examples/sharegpt_4v/share4v/configs/vit/preprocessor_config.json
new file mode 100644
index 0000000000..9d81580626
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/configs/vit/preprocessor_config.json
@@ -0,0 +1,29 @@
+{
+ "crop_size": {
+ "height": 336,
+ "width": 336
+ },
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "feature_extractor_type": "CLIPFeatureExtractor",
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_processor_type": "CLIPImageProcessor",
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "shortest_edge": 336
+ }
+ }
+
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/constants.py b/examples/sharegpt_4v/share4v/constants.py
new file mode 100644
index 0000000000..be8cf02049
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/constants.py
@@ -0,0 +1,12 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = ""
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = ""
diff --git a/examples/sharegpt_4v/share4v/conversation.py b/examples/sharegpt_4v/share4v/conversation.py
new file mode 100644
index 0000000000..b6ec8732a4
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/conversation.py
@@ -0,0 +1,369 @@
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ init_msg = init_msg[0].replace("", "").strip()
+ if 'mmtag' in self.version:
+ messages[0] = (init_role, init_msg)
+ messages.insert(0, (self.roles[0], ""))
+ messages.insert(1, (self.roles[1], "Received."))
+ else:
+ messages[0] = (init_role, "\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+ def wrap_sys(msg): return f"<>\n{msg}\n<>\n\n"
+ def wrap_inst(msg): return f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0:
+ message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(
+ pil_img.mode, (width, width), background_color)
+ result.paste(
+ pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(
+ pil_img.mode, (height, height), background_color)
+ result.paste(
+ pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(
+ f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(
+ min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if longest_edge != max(image.size):
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ img_b64_str = base64.b64encode(
+ buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(
+ min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(
+ buffered.getvalue()).decode()
+ img_str = f''
+ msg = img_str + msg.replace('', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_share4v_llama_2 = Conversation(
+ system="You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_share4v_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_share4v_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_share4v_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_share4v_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_share4v_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+ version="v1_mmtag",
+)
+
+default_conversation = conv_vicuna_v1
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "llama_2": conv_llama_2,
+
+ "plain": conv_share4v_plain,
+ "v0_plain": conv_share4v_plain,
+ "share4v_v0": conv_share4v_v0,
+ "v0_mmtag": conv_share4v_v0_mmtag,
+ "share4v_v1": conv_share4v_v1,
+ "v1_mmtag": conv_share4v_v1_mmtag,
+ "share4v_llama_2": conv_share4v_llama_2
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/examples/sharegpt_4v/share4v/mm_utils.py b/examples/sharegpt_4v/share4v/mm_utils.py
new file mode 100644
index 0000000000..c432550312
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/mm_utils.py
@@ -0,0 +1,105 @@
+import base64
+from io import BytesIO
+
+import mindspore as ms
+from PIL import Image
+from transformers import StoppingCriteria
+
+from share4v.constants import IMAGE_TOKEN_INDEX
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255)
+ for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')[
+ 'pixel_values'][0]
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = ms.ops.stack(new_images, axis=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+ prompt_chunks = [
+ tokenizer(chunk).input_ids for chunk in prompt.split('')]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'ms':
+ return ms.Tensor(input_ids, ms.int64)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1].split('_')[0]
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(ms.Tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: ms.int64, scores: ms.float32, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"
+ offset = min(output_ids.shape[1] -
+ self.start_len, self.max_keyword_len)
+ for keyword_id in self.keyword_ids:
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.batch_decode(
+ output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
diff --git a/examples/sharegpt_4v/share4v/model/__init__.py b/examples/sharegpt_4v/share4v/model/__init__.py
new file mode 100644
index 0000000000..43acf5db60
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/__init__.py
@@ -0,0 +1,3 @@
+# from .language_model.share4v_llama import (Share4VConfig,
+# Share4VLlamaForCausalLM)
+from .language_model.share4v_llama import Share4VLlamaForCausalLM
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/model/builder.py b/examples/sharegpt_4v/share4v/model/builder.py
new file mode 100644
index 0000000000..a5cdc4596c
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/builder.py
@@ -0,0 +1,74 @@
+import os
+import warnings
+
+import mindspore as ms
+# from transformers import AutoTokenizer, BitsAndBytesConfig
+
+# cong TODO: double check this import
+from transformers import BitsAndBytesConfig, AutoTokenizer
+
+from share4v.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN)
+from share4v.model import Share4VLlamaForCausalLM
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device="Ascend"):
+ # device: 'Ascend', 'GPU', 'CPU'
+ kwargs = {}
+
+ if load_8bit:
+ kwargs['load_in_8bit'] = True
+ elif load_4bit:
+ kwargs['load_in_4bit'] = True
+ kwargs['quantization_config'] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=ms.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4'
+ )
+ else:
+ kwargs['dtype'] = ms.float16
+
+ if 'sharegpt4v' in model_name.lower():
+ # Load ShareGPT4V model
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=False)
+ model = Share4VLlamaForCausalLM.from_pretrained(model_path, **kwargs)
+
+ else:
+ raise NotImplementedError("Please make sure the model is ShareGPT4V model")
+
+ image_processor = None
+
+ if 'sharegpt4v' in model_name.lower():
+ mm_use_im_start_end = getattr(
+ model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(
+ model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens(
+ [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens(
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ vision_tower = model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ # CLIPVisionTower.load_model()
+ print('trying load vision tower')
+ vision_tower.load_model()
+ # cong TODO: check if this setting is successful
+ # vision_tower = vision_tower.to(ms.float16)
+ image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ # set llama model, clip vision tower, mm_projector to inference mode
+ model.set_train(False)
+ model.set_dtype(kwargs['dtype'])
+
+ return tokenizer, model, image_processor, context_len
diff --git a/examples/sharegpt_4v/share4v/model/language_model/share4v_llama.py b/examples/sharegpt_4v/share4v/model/language_model/share4v_llama.py
new file mode 100644
index 0000000000..b14f4bd26c
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/language_model/share4v_llama.py
@@ -0,0 +1,423 @@
+import os
+from typing import List, Optional, Tuple, Union
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.nn import CrossEntropyLoss
+
+# from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig,
+# LlamaForCausalLM, LlamaModel)
+# from mindformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig,
+# LlamaForCausalLM, LlamaModel)
+
+# from mindformers import LlamaConfig
+from share4v.transformers.models.llama import LlamaForCausalLM, LlamaModel
+
+
+# from transformers.modeling_outputs import CausalLMOutputWithPast
+# cong TODO: move this part to a suitable file, mindformers doesn't have this class
+from dataclasses import dataclass
+from collections import OrderedDict
+from typing import Any, ContextManager, Iterable, List, Tuple
+
+os.sys.path.append('/Users/congwang/Documents/M/ms_ShareGPT4V')
+from share4v.pipeline import TextGenerator
+from share4v.transformers.models.cache import Cache, DynamicCache
+
+# @dataclass
+# class CausalLMOutputWithPast(OrderedDict):
+# """
+# Base class for causal language model (or autoregressive) outputs.
+
+# Args:
+# loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+# Language modeling loss (for next-token prediction).
+# logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+# Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+# `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+# Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+# `past_key_values` input) to speed up sequential decoding.
+# hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+# Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+# one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+# Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+# attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+# Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+# sequence_length)`.
+
+# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+# heads.
+# """
+# def __init__(self, *args, **kwargs):
+# super().__init__(*args, **kwargs)
+
+# def __delitem__(self, *args, **kwargs):
+# raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+# def setdefault(self, *args, **kwargs):
+# raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+# def pop(self, *args, **kwargs):
+# raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+# def update(self, *args, **kwargs):
+# raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+# def __getitem__(self, k):
+# if isinstance(k, str):
+# inner_dict = dict(self.items())
+# return inner_dict[k]
+# else:
+# return self.to_tuple()[k]
+
+# def __setattr__(self, name, value):
+# if name in self.keys() and value is not None:
+# # Don't call self.__setitem__ to avoid recursion errors
+# super().__setitem__(name, value)
+# super().__setattr__(name, value)
+
+# def __setitem__(self, key, value):
+# # Will raise a KeyException if needed
+# super().__setitem__(key, value)
+# # Don't call self.__setattr__ to avoid recursion errors
+# super().__setattr__(key, value)
+
+# # def __reduce__(self):
+# # if not is_dataclass(self):
+# # return super().__reduce__()
+# # callable, _args, *remaining = super().__reduce__()
+# # args = tuple(getattr(self, field.name) for field in fields(self))
+# # return callable, args, *remaining
+
+# def to_tuple(self) -> Tuple[Any]:
+# """
+# Convert self to a tuple containing all the attributes/keys that are not `None`.
+# """
+# return tuple(self[k] for k in self.keys())
+
+# loss: Optional[float] = None
+# logits: float = None
+# past_key_values: Optional[Tuple[Tuple[float]]] = None
+# hidden_states: Optional[Tuple[float, ...]] = None
+# attentions: Optional[Tuple[float, ...]] = None
+# # loss: Optional[ms.float32] = None
+# # logits: ms.float32 = None
+# # past_key_values: Optional[Tuple[Tuple[ms.float32]]] = None
+# # hidden_states: Optional[Tuple[ms.float32, ...]] = None
+# # attentions: Optional[Tuple[ms.float32, ...]] = None
+
+from ..share4v_arch import Share4VMetaForCausalLM, Share4VMetaModel
+
+# this part used for mindformers.LlamaModel
+# class Share4VConfig(LlamaConfig):
+# model_type = "share4v"
+# def __init__(self, **kwargs):
+# self.mm_vision_tower_path = kwargs.pop("mm_vision_tower_path", None)
+# # self.use_return_dict = kwargs.pop("use_return_dict", None)
+# # self.seq_length = kwargs.pop("seq_length", None)
+# super(Share4VConfig, self).__init__(**kwargs)
+
+
+
+class Share4VLlamaModel(Share4VMetaModel, LlamaModel):
+ # config_class = Share4VConfig
+
+ # this func used for self-defined LlamaModel
+ def __init__(self, config):
+ super(Share4VLlamaModel, self).__init__(config)
+ self.dtype = config.get('dtype')
+ self.set_dtype(self.dtype)
+
+ # this func used for mindformers.LlamaModel
+ # def __init__(self, config: LlamaConfig):
+ # super(Share4VLlamaModel, self).__init__(config)
+
+ # this func used for mindformers.LlamaModel
+ # def embed_tokens(self, input_ids):
+ # if type(input_ids) is not ms.Tensor:
+ # print('Share4VLlamaModel.embed_tokens')
+ # print(type(input_ids), input_ids.shape, input_ids)
+ # # np.ndarray to ms.Tensor
+ # input_ids = ms.Tensor(input_ids, dtype=ms.int64)
+ # return self.tok_embeddings(input_ids)
+ def set_train(self, mode:bool):
+ '''Set the model(llama model, vision_tower and mm_projector) to training mode or evaluation mode.'''
+ # cong TODO: edit alert message
+ for param in self.get_parameters():
+ param.requires_grad = mode
+
+ vision_tower = self.get_vision_tower()
+ if not vision_tower.is_loaded:
+ print(f"set train to {mode}, but vision_tower is not load")
+
+ # mm_projector = self.get_mm_projector()
+ # if mm_projector is not None:
+ # for cell in mm_projector:
+ # for param in cell.get_parameters():
+ # param.requires_grad = mode
+ # else:
+ # print("mm_projector is None")
+ return self
+
+ def set_dtype(self, dtype):
+ '''Set the model(llama model, vision_tower and mm_projector) to target data type'''
+ self.dtype = dtype
+ for param in self.get_parameters():
+ param.set_dtype(self.dtype)
+
+
+# class Share4VLlamaForCausalLM(LlamaForCausalLM, Share4VMetaForCausalLM, TextGenerator):
+class Share4VLlamaForCausalLM(LlamaForCausalLM, Share4VMetaForCausalLM):
+ # config_class = Share4VConfig
+
+ def __init__(self, config):
+ # print("Share4VLlamaForCausalLM init")
+
+ self.vocab_size = int(config.get('vocab_size')) if config.get('vocab_size') else 32000
+ config['vocab_size'] = self.vocab_size
+ config['dtype'] = ms.float32 if config.get('vocab_size') == "float32" else ms.float16
+ self.config = config
+
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = Share4VLlamaModel(config)
+ # LlamaForCausalLM.construct()
+
+
+ # self.model = LlamaModel(
+ # hidden_size=hidden_size,
+ # intermediate_size=intermediate_size,
+ # max_position_embeddings=max_position_embeddings,
+ # num_attention_heads=num_attention_heads,
+ # num_hidden_layers=num_hidden_layers,
+ # num_key_value_heads=num_key_value_heads,
+ # rms_norm_eps=rms_norm_eps,
+ # rope_theta=rope_theta,
+ # vocab_size=vocab_size,
+ # attention_dropout=attention_dropout,
+ # hidden_act=hidden_act,
+ # pad_token_id=pad_token_id,
+ # past_key_value_cache=past_key_value_cache,
+ # )
+ # print(type(self.model))
+ # super(Share4VMetaForCausalLM, self).__init__(self.model)
+ # super(TextGenerator, self).__init__(self.model)
+ self.lm_head = nn.Dense(
+ config['hidden_size'], config['vocab_size'], has_bias=False)
+ # self.lm_head = nn.Dense(
+ # config.hidden_size, config.vocab_size, has_bias=False)
+
+ # Initialize weights and apply final processing
+ # self.post_init()
+
+
+ def get_model(self):
+ return self.model
+
+ # this func is used for self-defined LlamaModel
+ def load_model(self, model_path, **kwargs):
+ # cong TODO: fix input args model_path to load from config.json
+ # import json
+ # with open(os.path.join(model_path, "config.json"), "r") as f:
+ # config = json.load(f)
+ ms_source_data = ms.load_checkpoint(model_path)
+ # due to version update reason, need to align the key name with latest version
+ if 'model.embed_tokens.embedding_table' in ms_source_data.keys():
+ ms_source_data['model.embed_tokens.weight'] = ms_source_data['model.embed_tokens.embedding_table']
+ params_not_load = ms.load_param_into_net(self, ms_source_data, strict_load=False)
+ print(f"Params not loaded: {params_not_load}")
+
+
+ def set_train(self, mode:bool):
+ self.get_model().set_train(mode)
+
+ return self
+ # # cong TODO: edit alert message
+
+ # model = self.get_model()
+ # if model is not None:
+ # for param in model.get_parameters():
+ # param.requires_grad = mode
+ # else:
+ # print("model is None")
+
+ # vision_tower = self.get_vision_tower()
+ # if not vision_tower.is_loaded:
+ # vision_tower.set_train(mode)
+ # else:
+ # print("vision_tower is not load")
+
+ # mm_projector = self.get_mm_projector()
+ # if mm_projector is not None:
+ # for cell in mm_projector:
+ # for param in cell.get_parameters():
+ # param.requires_grad = mode
+ # else:
+ # print("mm_projector is None")
+
+ def set_dtype(self, dtype):
+ self.dtype = dtype
+ self.lm_head.weight.set_dtype(dtype)
+ self.get_model().set_dtype(dtype)
+ # self.dtype = dtype
+ # for param in self.get_model().get_parameters():
+ # param.set_dtype(self.dtype)
+
+ def construct(
+ self,
+ # cong TODO: modify this area
+ # input_ids: ms.int64 = None,
+ # attention_mask: Optional[ms.Tensor] = None,
+ # past_key_values: Optional[List[ms.float32]] = None,
+ # inputs_embeds: Optional[ms.float32] = None,
+ # labels: Optional[ms.int64] = None,
+ # use_cache: Optional[bool] = None,
+ # output_attentions: Optional[bool] = None,
+ # output_hidden_states: Optional[bool] = None,
+ # images: Optional[ms.float32] = None,
+ # return_dict: Optional[bool] = None,
+ # ) -> Union[Tuple, CausalLMOutputWithPast]:
+ input_ids = None,
+ attention_mask = None,
+ # past_key_values = None,
+ inputs_embeds = None,
+ labels = None,
+ # use_cache = None,
+ past_key_cache_list = None,
+ past_value_cache_list = None,
+ images = None,
+ # return_dict = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple:
+ # output_attentions = output_attentions if output_attentions is not None else self.config['output_attentions']
+ # output_hidden_states = (
+ # output_hidden_states if output_hidden_states is not None else self.config['output_hidden_states']
+ # )
+ # return_dict = return_dict if return_dict is not None else self.config['use_return_dict']
+
+ input_ids, attention_mask, past_key_cache_list, past_value_cache_list, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
+ input_ids, attention_mask, past_key_cache_list, past_value_cache_list, labels, images)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ # outputs = self.model(
+ # input_ids=input_ids,
+ # attention_mask=attention_mask,
+ # past_key_values=past_key_values,
+ # inputs_embeds=inputs_embeds,
+ # use_cache=use_cache,
+ # output_attentions=output_attentions,
+ # output_hidden_states=output_hidden_states,
+ # return_dict=return_dict
+ # )
+ # print("input_ids", input_ids)
+ # if attention_mask is not None:
+ # print("attention_mask", type(attention_mask), attention_mask.shape)
+ # if inputs_embeds is not None:
+ # print("inputs_embeds", type(inputs_embeds), inputs_embeds.shape)
+
+ # this part for self-defined LlamaModel
+ # transformers.LlamaModel outputs:
+ # last_hidden_state=hidden_states,
+ # past_key_values=next_cache,
+ # hidden_states=all_hidden_states,
+ # attentions=all_self_attns,
+
+ # 0625 self-defined LlamaModel outputs:
+ # (hidden_states, key_cache_list, value_cache_list)
+ # cong TODO: edit inputs for llamamodel, other args:
+ # position_ids
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ # past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ # use_cache=use_cache,
+ past_key_cache_list=past_key_cache_list,
+ past_value_cache_list=past_value_cache_list,
+ return_key_value_cache=return_key_value_cache
+ )
+
+ # cong TODO: for 0625 version, llama model return hidden_states, key_cache_list, value_cache_list
+ hidden_states = outputs[0]
+ # cong TODO: .to(ms.float32)?
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config['vocab_size'])
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ loss = loss_fct(shift_logits, shift_labels)
+
+ # print('Share4VLlamaForCausalLM.construct return tuple...')
+ # don't have hidden_states, attentions for current self-defined LlamaModel
+ # original results: loss, logits, past_key_values, hidden_states, attentions
+ # now results: loss, logits, key_cache_list, value_cache_list
+ results = (loss, logits, outputs[1], outputs[2])
+ return results
+
+ # cong TODO: modify this area
+ # if not return_dict:
+ # output = (logits,) + outputs[1:]
+ # return (loss,) + output if loss is not None else output
+
+ # return CausalLMOutputWithPast(
+ # loss=loss,
+ # logits=logits,
+ # past_key_values=outputs.past_key_values,
+ # hidden_states=outputs.hidden_states,
+ # attentions=outputs.attentions,
+ # )
+
+ def prepare_inputs_for_generation(
+ # self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ self, input_ids, **kwargs
+
+ ):
+ # mindformers.LlamaForCausalLM.prepare_inputs_for_generation only takes (input_ids, **kwarg)
+
+ # cong TODO: what is this doing?
+ if kwargs.get("return_key_value_cache"):
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if kwargs.get("inputs_embeds") is not None and not kwargs.get("return_key_value_cache"):
+ model_inputs = {"inputs_embeds": kwargs.get("inputs_embeds") }
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ # "past_key_values": past_key_values if past_key_values is not None else kwargs.get("past_key_values"),
+ "return_key_value_cache": kwargs.get("return_key_value_cache", None),
+ "past_key_cache_list": kwargs.get("past_key_cache_list", None),
+ "past_value_cache_list": kwargs.get("past_value_cache_list", None),
+ "attention_mask": kwargs.get("attention_mask", None),
+ "images": kwargs.get("images", None),
+ }
+ )
+ # print("model_inputs.get('input_ids')", model_inputs.get('input_ids').shape if model_inputs.get('input_ids') is not None else None)
+ # print("model_inputs.get('inputs_embeds')", model_inputs.get('inputs_embeds').shape if model_inputs.get('inputs_embeds') is not None else None)
+ # print("model_inputs.get('images')", model_inputs.get('images').shape if model_inputs.get('images') is not None else None)
+ # print("model_inputs.get('past_key_cache_list')", len(model_inputs.get('past_key_cache_list')) if model_inputs.get('past_key_cache_list') is not None else None)
+ # print("model_inputs.get('past_value_cache_list')", len(model_inputs.get('past_value_cache_list')) if model_inputs.get('past_value_cache_list') is not None else None)
+ # print("model_inputs.get('return_key_value_cache')", model_inputs.get('return_key_value_cache'))
+
+ # print("model_inputs", model_inputs)
+ return model_inputs
+
+ # def init_weights(self, module=None):
+ # """Initialize the weights"""
+ # pass
+
+# AutoConfig.register("share4v", Share4VConfig)
+# AutoModelForCausalLM.register(Share4VConfig, Share4VLlamaForCausalLM)
diff --git a/examples/sharegpt_4v/share4v/model/multimodal_encoder/builder.py b/examples/sharegpt_4v/share4v/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000..e4b5c78252
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/multimodal_encoder/builder.py
@@ -0,0 +1,15 @@
+import os
+
+from .clip_encoder import CLIPVisionTower
+
+
+def build_vision_tower(vision_tower_cfg, **kwargs):
+ vision_tower = vision_tower_cfg.get('mm_vision_tower', None)
+ # vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+
+ # Lin-Chen/ShareGPT4V-7B_Pretrained_vit-large336-l12
+ is_absolute_path_exists = os.path.exists(vision_tower)
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("Lin-Chen"):
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
diff --git a/examples/sharegpt_4v/share4v/model/multimodal_encoder/clip_encoder.py b/examples/sharegpt_4v/share4v/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000..30ffc6548d
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,116 @@
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+# from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
+import sys
+from share4v.transformers.models.clip import CLIPVisionModel
+
+# from examples.stable_diffusion_v2.tools._common.clip.clip_config import CLIPVisionConfig
+# from transformers import CLIPImageProcessor, CLIPVisionConfig
+from transformers import CLIPImageProcessor, CLIPVisionConfig
+
+
+class CLIPVisionTower(nn.Cell):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+ self.vision_tower_name = vision_tower
+ self.select_layer = args['mm_vision_select_layer']
+ self.select_feature = args.get('mm_vision_select_feature', 'patch')
+ self.model_path = args['mm_vision_tower_path']
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+ if not delay_load:
+ self.load_model()
+
+
+ def load_model(self):
+ print(f'Load vision tower from {self.vision_tower_name}')
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ self.vision_tower_name)
+ if 'eva' in self.vision_tower_name.lower():
+ raise NotImplementedError("Not support eva models")
+ else:
+ if self.model_path is not None:
+ print('load vision tower using ms.load_param_into_net function')
+ self.vision_tower = CLIPVisionModel(self.cfg_only)
+ ms_vit_source_data = ms.load_checkpoint(self.model_path)
+ # modify dict keys
+ ms_vit_source_data = {k.replace('vision_model', 'vision_tower.vision_model'): v for k, v in ms_vit_source_data.items()}
+ params_not_load = ms.load_param_into_net(self.vision_tower, ms_vit_source_data, strict_load=False)
+ print(f"Params not loaded: {params_not_load}")
+ else:
+ print('load vision tower using CLIPVisionModel.from_pretrained function')
+ self.vision_tower = CLIPVisionModel.from_pretrained(
+ self.vision_tower_name)
+
+ # freeze the vision tower
+ self.set_train(False)
+ print("vision tower set_train to False ")
+
+ self.is_loaded = True
+
+ def set_train(self, mode:bool):
+ for param in self.vision_tower.get_parameters():
+ param.requires_grad = mode
+ return self
+
+ def set_dtype(self, dtype):
+ for param in self.vision_tower.get_parameters():
+ param.set_dtype(dtype)
+
+ def feature_select(self, image_forward_outs):
+ # image_features = image_forward_outs['hidden_states'][self.select_layer]
+ image_features = image_forward_outs[2][self.select_layer]
+ if self.select_feature == 'patch':
+ image_features = image_features[:, 1:]
+ elif self.select_feature == 'cls_patch':
+ image_features = image_features
+ else:
+ raise ValueError(
+ f'Unexpected select feature: {self.select_feature}')
+ return image_features
+
+ # @model.set_train(False) comment to enable fine-tune vit
+ def construct(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(self.dtype).unsqueeze(0),
+ output_hidden_states=True)
+ # image_forward_out: last_hidden_state, pooled_output, hidden_states, attentions
+ image_feature = self.feature_select(
+ image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(
+ images.to(self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(
+ image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return ops.zeros(1, self.hidden_size, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.vision_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+ @property
+ def num_patches(self):
+ return (self.config.image_size // self.config.patch_size) ** 2
diff --git a/examples/sharegpt_4v/share4v/model/multimodal_projector/builder.py b/examples/sharegpt_4v/share4v/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000..907ed3a645
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/multimodal_projector/builder.py
@@ -0,0 +1,57 @@
+import re
+import mindspore.nn as nn
+
+class IdentityMap(nn.Cell):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": 'identity'}
+
+
+class SimpleResBlock(nn.Cell):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.SequentialCell(
+ nn.Dense(channels, channels),
+ # cong TODO: according to ms document,
+ # set approximate to False can get similar result with pt
+ # need to check it
+ nn.GELU(approximate=False),
+ nn.Dense(channels, channels)
+ )
+
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+
+def build_vision_projector(config, delay_load=False, **kwargs):
+ projector_type = config.get('mm_projector_type', 'linear')
+
+ if projector_type == 'linear':
+ # return nn.Dense(config.mm_hidden_size, config.hidden_size)
+ return nn.Dense(config['mm_hidden_size'], config['hidden_size'])
+
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Dense(config['mm_hidden_size'], config['hidden_size'])]
+ for _ in range(1, mlp_depth):
+ # cong TODO: according to ms document,
+ # set approximate to False can get similar result with pt
+ # need to check it
+ modules.append(nn.GELU(approximate=False))
+ modules.append(nn.Dense(config['hidden_size'], config['hidden_size']))
+ return nn.SequentialCell(*modules)
+
+ if projector_type == 'identity':
+ return IdentityMap()
+
+ raise ValueError(f'Unknown projector type: {projector_type}')
diff --git a/examples/sharegpt_4v/share4v/model/share4v_arch.py b/examples/sharegpt_4v/share4v/model/share4v_arch.py
new file mode 100644
index 0000000000..9a28a23090
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/share4v_arch.py
@@ -0,0 +1,352 @@
+from abc import ABC, abstractmethod
+
+import mindspore as ms
+import mindspore.ops as ops
+import mindspore.nn as nn
+
+from share4v.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, IGNORE_INDEX,
+ IMAGE_TOKEN_INDEX)
+
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_projector.builder import build_vision_projector
+
+
+class Share4VMetaModel:
+
+ def __init__(self, config):
+ # hidden_size = config.hidden_size if hasattr(config, 'hidden_size') else 4096
+ # intermediate_size = config.intermediate_size if hasattr(config, 'intermediate_size') else 14336
+ # max_position_embeddings = config.max_position_embeddings if hasattr(config, 'max_position_embeddings') else 32768
+ # num_attention_heads = config.num_attention_heads if hasattr(config, 'num_attention_heads') else 32
+ # num_hidden_layers = config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 32
+ # num_key_value_heads = config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else 8
+ # rms_norm_eps = config.rms_norm_eps if hasattr(config, 'rms_norm_eps') else 1e-5
+ # rope_theta = config.rope_scaling if hasattr(config, 'rope_theta') else 1000000.0
+ # vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else 32064
+ # attention_dropout = config.attention_dropout if hasattr(config, 'attention_dropout') else 0.0
+ # hidden_act = config.hidden_act if hasattr(config, 'hidden_act') else "silu"
+ # pad_token_id = config.pad_token_id if hasattr(config, 'pad_token_id') else None
+
+ rope_theta = config['rope_scaling'] if config.get('rope_scaling') else 1000000.0
+ attention_dropout = config['attention_dropout'] if config.get('attention_dropout') else 0.0
+ dtype = ms.float16 if config.get('dtype') == "float16" else ms.float32
+ # this part is used for self-defined LlamaModel
+ # and all config related codes are ajusted to fit self-defined LlamaModel
+ # for now, config file don't have attention_dropout
+ super(Share4VMetaModel, self).__init__(hidden_size=config['hidden_size'],
+ intermediate_size=config['intermediate_size'],
+ max_position_embeddings=config['max_position_embeddings'],
+ num_attention_heads=config['num_attention_heads'],
+ num_hidden_layers=config['num_hidden_layers'],
+ num_key_value_heads=config['num_key_value_heads'],
+ rms_norm_eps=config['rms_norm_eps'],
+ rope_theta=rope_theta,
+ vocab_size=config['vocab_size'],
+ attention_dropout=attention_dropout,
+ hidden_act=config['hidden_act'],
+ pad_token_id=config['pad_token_id'],
+ dtype=dtype)
+
+ # this part is used for mindformers.LlamaModel
+ # super(Share4VMetaModel, self).__init__(config)
+
+ if config.get('mm_vision_tower'):
+ # if hasattr(config, "mm_vision_tower"):
+ self.vision_tower = build_vision_tower(config, delay_load=True)
+ self.mm_projector = build_vision_projector(config)
+ self.config = config
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, 'vision_tower', None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ def get_mm_projector(self):
+ return getattr(self, 'mm_projector', None)
+
+ def set_mm_projector_dtype(self, dtype):
+ mm_projector = self.get_mm_projector()
+ if mm_projector is nn.SequentialCell:
+ for cell in mm_projector:
+ for param in cell.get_parameters():
+ param.set_dtype(dtype)
+ else:
+ for param in mm_projector.get_parameters():
+ param.set_dtype(dtype)
+
+
+
+ def initialize_vision_modules(self, model_args, fsdp=None):
+ vision_tower = model_args.vision_tower
+ mm_vision_select_layer = model_args.mm_vision_select_layer
+ mm_vision_select_feature = model_args.mm_vision_select_feature
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+
+ # self.config.mm_vision_tower = vision_tower
+ self.config['mm_vision_tower'] = vision_tower
+ if self.get_vision_tower() is None:
+ vision_tower = build_vision_tower(model_args)
+
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [vision_tower]
+ else:
+ self.vision_tower = vision_tower
+ elif self.get_vision_tower().vision_tower_name != vision_tower:
+ vision_tower = build_vision_tower(model_args)
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [vision_tower]
+ else:
+ self.vision_tower = vision_tower
+ else:
+ if fsdp is not None and len(fsdp) > 0:
+ vision_tower = self.vision_tower[0]
+ vision_tower.load_model()
+ else:
+ vision_tower = self.vision_tower
+ vision_tower.load_model()
+
+ # self.config.use_mm_proj = True
+ # self.config.mm_projector_type = getattr(
+ # model_args, 'mm_projector_type', 'linear')
+ # self.config.mm_hidden_size = vision_tower.hidden_size
+ # self.config.mm_vision_select_layer = mm_vision_select_layer
+ # self.config.mm_vision_select_feature = mm_vision_select_feature
+ self.config['use_mm_proj'] = True
+ self.config['mm_projector_type'] = getattr(
+ model_args, 'mm_projector_type', 'linear')
+ self.config['mm_hidden_size'] = vision_tower.hidden_size
+ self.config['mm_vision_select_layer'] = mm_vision_select_layer
+ self.config['mm_vision_select_feature'] = mm_vision_select_feature
+
+ if getattr(self, 'mm_projector', None) is None:
+ self.mm_projector = build_vision_projector(self.config)
+ else:
+ # In case it is frozen by LoRA
+ for p in self.mm_projector.get_parameters():
+ p.requires_grad = True
+
+ if pretrain_mm_mlp_adapter is not None:
+ print(f'Load mm_mlp_adapter from {pretrain_mm_mlp_adapter}')
+ mm_projector_weights = ms.load_checkpoint(pretrain_mm_mlp_adapter)
+
+ def get_w(weights, keyword):
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
+
+ self.mm_projector.load_state_dict(
+ get_w(mm_projector_weights, 'mm_projector'))
+
+
+class Share4VMetaForCausalLM(ABC):
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ def get_mm_projector(self):
+ return self.get_model().get_mm_projector()
+
+ def encode_images(self, images):
+ image_features = self.get_model().get_vision_tower()(images)
+ image_features = self.get_model().mm_projector(image_features)
+ return image_features
+
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, attention_mask, past_key_cache_list, past_value_cache_list, labels, images
+ ):
+ vision_tower = self.get_vision_tower()
+ # print(input_ids.shape)
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
+ if past_key_cache_list is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
+ attention_mask = ops.ones(
+ # cong TODO: check the shape
+ (attention_mask.shape[0], past_value_cache_list[-1].shape[-2] + 1), dtype=attention_mask.dtype)
+ # (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype)
+ # print('-------------------------prepare_inputs_labels_for_multimodal return inputs_embeds None----------')
+ return input_ids, attention_mask, past_key_cache_list, past_value_cache_list, None, labels
+
+ #cong TODO: check the usage images.ndim
+ if type(images) is list or images.ndim == 5:
+ concat_images = ops.cat([image for image in images], axis=0)
+ image_features = self.encode_images(concat_images)
+ split_sizes = [image.shape[0] for image in images]
+ image_features = ops.split(image_features, split_sizes, axis=0)
+ image_features = [x.flatten(0, 1) for x in image_features]
+ else:
+ image_features = self.encode_images(images)
+
+ new_input_embeds = []
+ new_labels = [] if labels is not None else None
+ cur_image_idx = 0
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
+ # multimodal LLM, but the current sample is not multimodal
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
+ half_len = cur_input_ids.shape[0] // 2
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(
+ cur_input_ids[:half_len])
+ cur_input_embeds_2 = self.get_model().embed_tokens(
+ cur_input_ids[half_len:])
+ cur_input_embeds = ops.cat(
+ [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], axis=0)
+ new_input_embeds.append(cur_input_embeds)
+ if labels is not None:
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+ image_token_indices = ops.nonzero(ops.where(
+ cur_input_ids == IMAGE_TOKEN_INDEX, 1, 0)).squeeze(1)
+ cur_new_input_embeds = []
+ if labels is not None:
+ cur_labels = labels[batch_idx]
+ cur_new_labels = []
+ assert cur_labels.shape == cur_input_ids.shape
+ while image_token_indices.numel() > 0:
+ # print(3)
+ cur_image_features = image_features[cur_image_idx]
+ image_token_start = image_token_indices[0]
+ if self.config.get('tune_mm_mlp_adapter', False) and self.config.get('mm_use_im_start_end', False):
+ # if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ cur_new_input_embeds.append(self.get_model().embed_tokens(
+ cur_input_ids[:image_token_start-1]).detach())
+ cur_new_input_embeds.append(self.get_model().embed_tokens(
+ cur_input_ids[image_token_start-1:image_token_start]))
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_input_embeds.append(self.get_model().embed_tokens(
+ cur_input_ids[image_token_start+1:image_token_start+2]))
+ if labels is not None:
+ cur_new_labels.append(cur_labels[:image_token_start])
+ cur_new_labels.append(ops.full(
+ (cur_image_features.shape[0],), IGNORE_INDEX, dtype=labels.dtype))
+ cur_new_labels.append(
+ cur_labels[image_token_start+1:image_token_start+2])
+ cur_labels = cur_labels[image_token_start+2:]
+ else:
+ cur_new_input_embeds.append(self.get_model().embed_tokens(
+ cur_input_ids[:image_token_start]))
+ cur_new_input_embeds.append(cur_image_features)
+ if labels is not None:
+ cur_new_labels.append(cur_labels[:image_token_start])
+ cur_new_labels.append(ops.full(
+ (cur_image_features.shape[0],), IGNORE_INDEX, dtype=labels.dtype))
+ cur_labels = cur_labels[image_token_start+1:]
+ cur_image_idx += 1
+ if self.config.get('tune_mm_mlp_adapter', False) and self.config.get('mm_use_im_start_end', False):
+ # if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ cur_input_ids = cur_input_ids[image_token_start+2:]
+ else:
+ cur_input_ids = cur_input_ids[image_token_start+1:]
+ image_token_indices = ops.nonzero(ops.where(
+ cur_input_ids == IMAGE_TOKEN_INDEX, 1, 0)).squeeze(1)
+ if cur_input_ids.numel() > 0:
+ if self.config.get('tune_mm_mlp_adapter', False) and self.config.get('mm_use_im_start_end', False):
+ # if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ cur_new_input_embeds.append(
+ self.get_model().embed_tokens(cur_input_ids).detach())
+ else:
+ cur_new_input_embeds.append(
+ self.get_model().embed_tokens(cur_input_ids))
+ if labels is not None:
+ cur_new_labels.append(cur_labels)
+ cur_new_input_embeds = ops.cat(cur_new_input_embeds, axis=0)
+ new_input_embeds.append(cur_new_input_embeds)
+ if labels is not None:
+ cur_new_labels = ops.cat(cur_new_labels, axis=0)
+ new_labels.append(cur_new_labels)
+
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
+ max_len = max(x.shape[0] for x in new_input_embeds)
+
+ new_input_embeds_align = []
+ for cur_new_embed in new_input_embeds:
+ cur_new_embed = ops.cat((cur_new_embed, ops.zeros(
+ (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype)), axis=0)
+ new_input_embeds_align.append(cur_new_embed)
+ new_input_embeds = ops.stack(new_input_embeds_align, axis=0)
+
+ if labels is not None:
+ new_labels_align = []
+ _new_labels = new_labels
+ for cur_new_label in new_labels:
+ cur_new_label = ops.cat((cur_new_label, ops.full(
+ (max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype)), axis=0)
+ new_labels_align.append(cur_new_label)
+ new_labels = ops.stack(new_labels_align, axis=0)
+
+ if attention_mask is not None:
+ new_attention_mask = []
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
+ new_attn_mask_pad_left = ops.full(
+ (cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype)
+ new_attn_mask_pad_right = ops.full(
+ (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype)
+ cur_new_attention_mask = ops.cat(
+ (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), axis=0)
+ new_attention_mask.append(cur_new_attention_mask)
+ attention_mask = ops.stack(new_attention_mask, axis=0)
+ assert attention_mask.shape == new_labels.shape
+ else:
+ new_input_embeds = ops.stack(new_input_embeds, axis=0)
+ if labels is not None:
+ new_labels = ops.stack(new_labels, axis=0)
+
+ if attention_mask is not None:
+ new_attn_mask_pad_left = ops.full(
+ (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype)
+ attention_mask = ops.cat(
+ (new_attn_mask_pad_left, attention_mask), axis=1)
+ assert attention_mask.shape == new_input_embeds.shape[:2]
+
+ return None, attention_mask, past_key_cache_list, past_value_cache_list, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens(
+ [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens(
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ axis=0, keep_dims=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ axis=0, keep_dims=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().get_parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().get_parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = ms.load_checkpoint(model_args.pretrain_mm_mlp_adapter)
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+ raise ValueError(
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().get_parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().get_parameters():
+ p.requires_grad = False
diff --git a/examples/sharegpt_4v/share4v/model/utils.py b/examples/sharegpt_4v/share4v/model/utils.py
new file mode 100644
index 0000000000..a0d63bb01c
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/model/utils.py
@@ -0,0 +1,20 @@
+from mindformers import AutoConfig
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if 'share4v' in config and 'share4v' not in cfg.model_type:
+ assert cfg.model_type == 'llama'
+ print("You are using newer ShareGPT4V code base, while the checkpoint of v0 is from older code base.")
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "share4v")
+ cfg.architectures[0] = 'Share4VLlamaForCausalLM'
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
diff --git a/examples/sharegpt_4v/share4v/pipeline/__init__.py b/examples/sharegpt_4v/share4v/pipeline/__init__.py
new file mode 100644
index 0000000000..bfa5257cff
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/pipeline/__init__.py
@@ -0,0 +1 @@
+from .text_generation import TextGenerator
diff --git a/examples/sharegpt_4v/share4v/pipeline/helpers/__init__.py b/examples/sharegpt_4v/share4v/pipeline/helpers/__init__.py
new file mode 100644
index 0000000000..1ee5cab11e
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/pipeline/helpers/__init__.py
@@ -0,0 +1 @@
+from .stopping_criteria import *
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/pipeline/helpers/stopping_criteria.py b/examples/sharegpt_4v/share4v/pipeline/helpers/stopping_criteria.py
new file mode 100644
index 0000000000..232100a65d
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/pipeline/helpers/stopping_criteria.py
@@ -0,0 +1,56 @@
+import abc
+import logging
+from typing import List, Optional, Union
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+__all__ = ["MaxLengthCriteria", "EosTokenCriteria", "StoppingCriteriaList"]
+
+
+class StoppingCriteria(abc.ABC):
+ @abc.abstractmethod
+ def __call__(self, input_ids: Tensor) -> Tensor:
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
+
+
+class MaxLengthCriteria(StoppingCriteria):
+ def __init__(self, max_length: int) -> None:
+ self.max_length = max_length
+
+ def __call__(self, input_ids: Tensor) -> Tensor:
+ cur_len = input_ids.shape[-1]
+ is_done = cur_len >= self.max_length
+ return ops.full((input_ids.shape[0],), is_done, dtype=ms.bool_)
+
+
+class EosTokenCriteria(StoppingCriteria):
+ def __init__(self, eos_token_id: Union[int, List[int], Tensor]) -> None:
+ if not isinstance(eos_token_id, Tensor):
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id = Tensor(eos_token_id)
+ self.eos_token_id = eos_token_id
+
+ def __call__(self, input_ids: Tensor) -> Tensor:
+ is_done = ms.numpy.isin(input_ids[:, -1], self.eos_token_id)
+ return is_done
+
+
+class StoppingCriteriaList(list):
+ def __call__(self, input_ids: Tensor) -> Tensor:
+ is_done = ops.full((input_ids.shape[0],), False, dtype=ms.bool_)
+ for criteria in self:
+ is_done = ops.logical_or(is_done, criteria(input_ids))
+ return is_done
+
+ @property
+ def max_length(self) -> Optional[int]:
+ for stopping_criterium in self:
+ if isinstance(stopping_criterium, MaxLengthCriteria):
+ return stopping_criterium.max_length
+ return None
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/pipeline/text_generation.py b/examples/sharegpt_4v/share4v/pipeline/text_generation.py
new file mode 100644
index 0000000000..d7bda257fd
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/pipeline/text_generation.py
@@ -0,0 +1,275 @@
+import logging
+from typing import Dict, Optional, Tuple
+
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Tensor
+
+from .helpers import EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList
+
+logger = logging.getLogger(__name__)
+
+
+class TextGenerator:
+ def __init__(
+ self,
+ model: nn.Cell,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ pad_token_id: Optional[int] = None,
+ max_new_tokens: Optional[int] = 100,
+ min_new_tokens: Optional[int] = None,
+ use_kv_cache: bool = False,
+ ) -> None:
+ self.model = model.set_train(False)
+ for param in self.model.trainable_params():
+ param.requires_grad = False
+
+ self._bos_token_id = bos_token_id
+ self._eos_token_id = eos_token_id
+ self._pad_token_id = pad_token_id
+ self._max_new_tokens = max_new_tokens
+ self._min_new_tokens = min_new_tokens
+ self._use_kv_cache = use_kv_cache
+
+ self._max_length: Optional[int] = None
+ self._min_length: Optional[int] = None
+
+ if not hasattr(self.model, "prepare_inputs_for_generation"):
+ raise NotImplementedError(
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
+ )
+
+ if self._use_kv_cache:
+ self._past_key_cache_list: Optional[Tensor] = None
+ self._past_value_cache_list: Optional[Tensor] = None
+
+ def _prepare_model_inputs(
+ self, bos_token_id: Optional[Tensor] = None, model_kwargs: Optional[Dict[str, Tensor]] = None
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ input_name = "input_ids" # support inputs id only
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+ inputs = model_kwargs.pop(input_name, None)
+
+ # if `inputs` is still None, try to create `input_ids` from BOS token
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
+ return inputs, model_kwargs
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[Tensor] = None,
+ bos_token_id: Optional[Tensor] = None,
+ model_kwargs: Optional[Dict[str, Tensor]] = None,
+ ) -> Tensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, Tensor):
+ batch_size = value.shape[0]
+ break
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ return ops.ones((batch_size, 1), dtype=ms.int32) * bos_token_id
+
+ def _prepare_attention_mask_for_generation(
+ self, inputs: Tensor, pad_token_id: Optional[Tensor], eos_token_id: Optional[Tensor]
+ ) -> Tensor:
+ # No information for attention mask inference -> return default attention mask
+ default_attention_mask = ops.ones(inputs.shape[:2], dtype=ms.int32)
+ if pad_token_id is None:
+ return default_attention_mask
+
+ is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ms.int32, ms.int64]
+ if not is_input_ids:
+ return default_attention_mask
+
+ is_pad_token_in_inputs = (pad_token_id is not None) and (
+ ms.numpy.isin(element=inputs, test_elements=pad_token_id).any()
+ )
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+ ms.numpy.isin(element=eos_token_id, test_elements=pad_token_id).any()
+ )
+ can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+ attention_mask_from_padding = inputs.ne(pad_token_id).to(ms.int32)
+
+ attention_mask = (
+ attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+ )
+ return attention_mask
+
+ def _update_model_kwargs_for_generation(
+ self,
+ model_kwargs: Dict[str, Tensor],
+ key_cache_list: Optional[Tensor] = None,
+ value_cache_list: Optional[Tensor] = None,
+ ) -> Dict[str, Tensor]:
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = ops.concat(
+ [attention_mask, ops.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], axis=-1
+ )
+
+ # update kv cache
+ if key_cache_list is not None and value_cache_list is not None:
+ if self._past_key_cache_list is not None and self._past_value_cache_list is not None:
+ self._past_key_cache_list = ops.concat([self._past_key_cache_list, key_cache_list], axis=-2)
+ self._past_value_cache_list = ops.concat([self._past_value_cache_list, value_cache_list], axis=-2)
+ else:
+ self._past_key_cache_list = key_cache_list
+ self._past_value_cache_list = value_cache_list
+
+ model_kwargs["past_key_cache_list"] = self._past_key_cache_list
+ model_kwargs["past_value_cache_list"] = self._past_value_cache_list
+
+ return model_kwargs
+
+ def _get_stopping_criteria(self) -> StoppingCriteriaList:
+ criteria = StoppingCriteriaList()
+ if self._max_length is not None:
+ criteria.append(MaxLengthCriteria(self._max_length))
+
+ if self._eos_token_id is not None:
+ criteria.append(EosTokenCriteria(eos_token_id=self._eos_token_id))
+ return criteria
+
+ def _prepare_generated_length(self, input_ids_length: int) -> None:
+ """Prepared max and min length in generaion configs to avoid clashes between similar attributes"""
+ if self._max_new_tokens is not None:
+ self._max_length = self._max_new_tokens + input_ids_length
+ if self._min_new_tokens is not None:
+ self._min_length = self._min_new_tokens + input_ids_length
+
+ def _prepare_special_tokens(self, kwargs_has_attention_mask: Optional[bool] = None):
+ # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+ def _tensor_or_none(token):
+ if token is None or isinstance(token, Tensor):
+ return token
+ return Tensor(token, dtype=ms.int32)
+
+ bos_token_id = _tensor_or_none(self._bos_token_id)
+ eos_token_id = _tensor_or_none(self._eos_token_id)
+ pad_token_id = _tensor_or_none(self._pad_token_id)
+
+ # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
+ if eos_token_id is not None and eos_token_id.ndim == 0:
+ eos_token_id = eos_token_id.unsqueeze(0)
+
+ # Set pad token if unset (and there are conditions to do so)
+ if pad_token_id is None and eos_token_id is not None:
+ if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ pad_token_id = eos_token_id[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
+
+ # we can't infer attn mask if pad token is set to be eos token in model's generation config
+ if eos_token_id is not None and ms.numpy.isin(element=eos_token_id, test_elements=pad_token_id).any():
+ if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+ logger.warning(
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
+ "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
+ "to obtain reliable results."
+ )
+
+ # Sanity checks/warnings
+ if eos_token_id is not None and (eos_token_id < 0).any():
+ logger.warning(
+ f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
+ "stop until the maximum length is reached. Depending on other flags, it may even crash."
+ )
+
+ # Update generation config with the updated special tokens tensors
+ self._bos_token_id = bos_token_id
+ self._eos_token_id = eos_token_id
+ self._pad_token_id = pad_token_id
+
+ def generate(self, **model_kwargs) -> Tensor:
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # Define model inputs
+ input_ids, model_kwargs = self._prepare_model_inputs(self._bos_token_id, model_kwargs)
+ batch_size = input_ids.shape[0]
+ self._prepare_special_tokens(kwargs_has_attention_mask)
+
+ # decoder-only models must use left-padding for batched generation.
+ # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+ if (
+ self._pad_token_id is not None
+ and batch_size > 1
+ and len(input_ids.shape) == 2
+ and ops.sum(input_ids[:, -1] == self._pad_token_id) > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ if not kwargs_has_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ input_ids, self._pad_token_id, self._eos_token_id
+ )
+
+ # prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ self._prepare_generated_length(input_ids_length)
+
+ # reset cache if neccesary
+ if self._use_kv_cache:
+ self._past_key_cache_list, self._past_value_cache_list = None, None
+
+ # prepare stopping criteria
+ prepared_stopping_criteria = self._get_stopping_criteria()
+
+ # run sample
+ result = self._sample(input_ids, stopping_criteria=prepared_stopping_criteria, **model_kwargs)
+
+ return result
+
+ def _sample(self, input_ids: Tensor, stopping_criteria: StoppingCriteriaList, **model_kwargs: Tensor) -> Tensor:
+ # init values
+ pad_token_id = self._pad_token_id
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+ # keep track of which sequences are already finished
+ batch_size = input_ids.shape[0]
+ this_peer_finished = False
+ unfinished_sequences = ops.ones(batch_size, dtype=ms.int32)
+
+ while not this_peer_finished:
+ # prepare model inputs
+ model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # inject kv cache state
+ model_inputs["return_key_value_cache"] = self._use_kv_cache
+
+ # forward pass to get next token
+ loss, logits, key_cache_list, value_cache_list = self.model(**model_inputs)
+ next_token_scores = logits[:, -1, :]
+
+ # token selection
+ next_tokens = ops.argmax(next_token_scores, dim=-1)
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = ops.concat([input_ids, next_tokens[:, None]], axis=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(model_kwargs, key_cache_list, value_cache_list)
+
+ unfinished_sequences = ops.logical_and(unfinished_sequences, ~stopping_criteria(input_ids))
+ this_peer_finished = unfinished_sequences.max() == 0
+
+ return input_ids
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/transformers/__init__.py b/examples/sharegpt_4v/share4v/transformers/__init__.py
new file mode 100644
index 0000000000..5965a6081f
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/__init__.py
@@ -0,0 +1,9 @@
+from .modeling_ms_utils import MSPreTrainedModel
+from .models import (
+ CLIPModel,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+)
+from .models import (LlamaForCausalLM, LlamaModel)
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/transformers/activations_ms.py b/examples/sharegpt_4v/share4v/transformers/activations_ms.py
new file mode 100644
index 0000000000..7309b5adbe
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/activations_ms.py
@@ -0,0 +1,218 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+
+from mindspore import Tensor, nn, ops
+
+
+class PytorchGELUTanh(nn.Cell):
+ """
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
+ https://arxiv.org/abs/1606.08415.
+
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+ match due to rounding errors.
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return ops.gelu(input, approximate="tanh")
+
+
+class NewGELUActivation(nn.Cell):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return (
+ 0.5 * input * (1.0 + ops.tanh(ops.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0))))
+ )
+
+
+class GELUActivation(nn.Cell):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, use_gelu_python: bool = False):
+ super().__init__()
+ if use_gelu_python:
+ self.act = self._gelu_python
+ else:
+ self.act = ops.gelu
+
+ def _gelu_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + ops.erf(input / math.sqrt(2.0)))
+
+ def construct(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class FastGELUActivation(nn.Cell):
+ """
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + ops.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+class QuickGELUActivation(nn.Cell):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return input * ops.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Cell):
+ """
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
+ https://arxiv.org/abs/2004.09602.
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created.
+
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, min: float, max: float):
+ if min > max:
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+
+ super().__init__()
+ self.min = min
+ self.max = max
+ self.gelu = get_activation("gelu")
+
+ def construct(self, x: Tensor) -> Tensor:
+ return ops.clip(self.gelu(x), self.min, self.max)
+
+
+class AccurateGELUActivation(nn.Cell):
+ """
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
+ https://github.com/hendrycks/GELUs
+
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.precomputed_constant = math.sqrt(2 / math.pi)
+
+ def construct(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1 + ops.tanh(self.precomputed_constant * (input + 0.044715 * ops.pow(input, 3))))
+
+
+class SiLUActivation(nn.Cell):
+ """
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
+ later.
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return ops.silu(input)
+
+
+class MishActivation(nn.Cell):
+ """
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return ops.mish(input)
+
+
+class LinearActivation(nn.Cell):
+ """
+ Applies the linear activation function, i.e. forwarding input directly to output.
+ """
+
+ def construct(self, input: Tensor) -> Tensor:
+ return input
+
+
+class LaplaceActivation(nn.Cell):
+ """
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
+ https://arxiv.org/abs/2209.10655
+
+ Inspired by squared relu, but with bounded range and gradient for better stability
+ """
+
+ def construct(self, input, mu=0.707107, sigma=0.282095):
+ input = (input - mu).div(sigma * math.sqrt(2.0))
+ return 0.5 * (1.0 + ops.erf(input))
+
+
+class ReLUSquaredActivation(nn.Cell):
+ """
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+ """
+
+ def construct(self, input):
+ relu_applied = ops.relu(input)
+ squared = ops.square(relu_applied)
+ return squared
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+ACT2CLS = {
+ "gelu": GELUActivation,
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+ "gelu_fast": FastGELUActivation,
+ "gelu_new": NewGELUActivation,
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+ "gelu_pytorch_tanh": PytorchGELUTanh,
+ "gelu_accurate": AccurateGELUActivation,
+ "laplace": LaplaceActivation,
+ "linear": LinearActivation,
+ "mish": MishActivation,
+ "quick_gelu": QuickGELUActivation,
+ "relu": nn.ReLU,
+ "relu2": ReLUSquaredActivation,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": SiLUActivation,
+ "swish": SiLUActivation,
+ "tanh": nn.Tanh,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
diff --git a/examples/sharegpt_4v/share4v/transformers/modeling_ms_utils.py b/examples/sharegpt_4v/share4v/transformers/modeling_ms_utils.py
new file mode 100644
index 0000000000..53df0ee019
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/modeling_ms_utils.py
@@ -0,0 +1,1127 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import os
+import warnings
+from typing import Callable, Optional, Tuple, Union
+
+import numpy as np
+from transformers.configuration_utils import PretrainedConfig
+from transformers.safetensors_conversion import auto_conversion
+from transformers.utils import (
+ CONFIG_NAME,
+ FLAX_WEIGHTS_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ TF2_WEIGHTS_NAME,
+ TF_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ PushToHubMixin,
+ cached_file,
+ download_url,
+ extract_commit_hash,
+ has_file,
+ is_offline_mode,
+ is_remote_url,
+ is_safetensors_available,
+ logging,
+)
+from transformers.utils.hub import get_checkpoint_shard_files
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import nn, ops
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+ from mindone.safetensors.mindspore import load_file as safe_load_file
+
+logger = logging.get_logger(__name__)
+
+
+def get_parameter_dtype(parameter: Union[nn.Cell, "ModuleUtilsMixin"]):
+ """
+ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+ """
+ last_dtype = None
+ for t in parameter.get_parameters():
+ last_dtype = t.dtype
+ if t.is_floating_point():
+ return t.dtype
+
+ # if no floating dtype was found return whatever the first dtype is
+ return last_dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
+ # Check format of the archive
+ with safe_open(checkpoint_file, framework="np") as f:
+ metadata = f.metadata()
+ if metadata.get("format") not in ["pt", "tf", "flax", "np"]:
+ raise OSError(
+ f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+ "you save your model with the `save_pretrained` method."
+ )
+ return safe_load_file(checkpoint_file)
+ else:
+ raise NotImplementedError(
+ f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}"
+ )
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read(7) == "version":
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+ )
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+ if variant is not None:
+ splits = weights_name.split(".")
+ splits = splits[:-1] + [variant] + splits[-1:]
+ weights_name = ".".join(splits)
+
+ return weights_name
+
+
+class ModuleUtilsMixin:
+ """
+ A few utilities for `mindspore.nn.Cell`, to be used as a mixin.
+ """
+
+ def to(self, dtype: Optional[ms.Type] = None):
+ for p in self.get_parameters():
+ p.set_dtype(dtype)
+ return self
+
+ def float(self):
+ for p in self.get_parameters():
+ p.set_dtype(ms.float32)
+ return self
+
+ def half(self):
+ for p in self.get_parameters():
+ p.set_dtype(ms.float16)
+ return self
+
+ @property
+ def dtype(self) -> ms.Type:
+ """
+ `ms.Type`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.cells_and_names()
+ if isinstance(module_type, nn.Embedding)
+ ]
+ total_parameters = [
+ parameter for name, parameter in self.cells_and_names() if name not in embedding_param_names
+ ]
+ else:
+ total_parameters = list(self.get_parameters())
+
+ total_numel = []
+ for param in total_parameters:
+ if param.requires_grad or not only_trainable:
+ total_numel.append(param.numel())
+
+ return sum(total_numel)
+
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+ """
+ Invert an attention mask (e.g., switches 0. and 1.).
+
+ Args:
+ encoder_attention_mask (`Tensor`): An attention mask.
+
+ Returns:
+ `Tensor`: The inverted attention mask.
+ """
+ encoder_extended_attention_mask = None
+ if encoder_attention_mask.dim() == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+
+ if encoder_attention_mask.dim() == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+ # /transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+ # encoder_extended_attention_mask.transpose(-1, -2))
+ if encoder_extended_attention_mask is not None:
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * Tensor(
+ np.finfo(mstype.dtype_to_nptype(self.dtype)).min
+ )
+
+ return encoder_extended_attention_mask
+
+ @staticmethod
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
+ batch_size, seq_length = input_shape
+ seq_ids = ops.arange(seq_length)
+ causal_mask = seq_ids[None, None, :].tile((batch_size, seq_length, 1)) <= seq_ids[None, :, None]
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = ops.cat(
+ [
+ ops.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ # extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ # extended_attention_mask = ops.mul(causal_mask[:, None, :, :], attention_mask[:, None, None, :])
+ extended_attention_mask = ops.mul(causal_mask.unsqueeze(1), attention_mask.unsqueeze(1).unsqueeze(1))
+ return extended_attention_mask
+
+ def get_extended_attention_mask(
+ self, attention_mask: Tensor, input_shape: Tuple[int], decoder_type, dtype: ms.float32 = None
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`Tuple[int]`):
+ The shape of the input to the model.
+
+ Returns:
+ `Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
+ """
+ if dtype is None:
+ dtype = self.dtype
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if decoder_type:
+ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
+ input_shape, attention_mask
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * Tensor(
+ np.finfo(mstype.dtype_to_nptype(self.dtype)).min
+ )
+ return extended_attention_mask
+
+ def get_head_mask(
+ self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
+ ) -> Tensor:
+ """
+ Prepare the head mask if needed.
+
+ Args:
+ head_mask (`Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+ num_hidden_layers (`int`):
+ The number of hidden layers in the model.
+ is_attention_chunked (`bool`, *optional*, defaults to `False`):
+ Whether or not the attentions scores are computed by chunks or not.
+
+ Returns:
+ `Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+ `[None]` for each layer.
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ if is_attention_chunked is True:
+ head_mask = head_mask.unsqueeze(-1)
+ else:
+ head_mask = [None] * num_hidden_layers
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.broadcast_to(num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+ head_mask = Tensor(head_mask, self.dtype) # switch to float if need + fp16 compatibility
+ return head_mask
+
+
+class MSPreTrainedModel(nn.Cell, ModuleUtilsMixin, PushToHubMixin):
+ r"""
+ Base class for all models.
+
+ [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+ downloading and saving models as well as a few methods common to all models to:
+
+ - resize the input embeddings,
+ - prune heads in the self-attention heads.
+
+ Class attributes (overridden by derived classes):
+
+ - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+ for this model architecture.
+ - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+ taking as arguments:
+
+ - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+ - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+ - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+ - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+ classes of the same architecture adding modules on top of the base model.
+ - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+ - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+ models, `pixel_values` for vision models and `input_values` for speech models).
+ """
+ config_class = None
+ base_model_prefix = ""
+ main_input_name = "input_ids"
+ _auto_class = None
+ _no_split_modules = None
+ _skip_keys_device_placement = None
+ _keep_in_fp32_modules = None
+
+ # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+ # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+ _keys_to_ignore_on_load_missing = None
+ # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+ # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+ # warnings.
+ _keys_to_ignore_on_load_unexpected = None
+ # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
+ # trained, but which are either deterministic or tied variables)
+ _keys_to_ignore_on_save = None
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = False
+
+ @property
+ def framework(self) -> str:
+ """
+ :str: Identifies that this is a MindSpore model.
+ """
+ return "ms"
+
+ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
+ super().__init__()
+ if not isinstance(config, PretrainedConfig):
+ raise ValueError(
+ f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+ "`PretrainedConfig`. To create a model from a pretrained model use "
+ f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ # Save config and origin of the pretrained weights if given in model
+ self.config = config
+ self.name_or_path = config.name_or_path
+ self.warnings_issued = {}
+ self.generation_config = None
+
+ def post_init(self):
+ """
+ A method executed at the end of each Transformer model initialization, to execute code that needs the model's
+ modules properly initialized (such as weight initialization).
+ """
+ pass
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ state_dict: Optional[dict] = None,
+ save_function: Callable = ms.save_checkpoint,
+ push_to_hub: bool = False,
+ max_shard_size: Union[int, str] = "5GB",
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ save_peft_format: bool = True,
+ **kwargs,
+ ):
+ logger.warning(f"{self.__class__.__name__}.save_pretrained is not implemented.")
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ *model_args,
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ use_safetensors: bool = None,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+ - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
+ `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
+ `True`.
+ - `None` if you are both providing the configuration and state dictionary (resp. with keyword
+ arguments `config` and `state_dict`).
+ model_args (sequence of positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+ Can be either:
+
+ - an instance of a class derived from [`PretrainedConfig`],
+ - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
+ A state dictionary to use instead of a state dictionary loaded from saved weights file.
+
+ This option can be used if you want to create a model from a pretrained configuration but load your own
+ weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
+ [`~PreTrainedModel.from_pretrained`] is not a simpler option.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_tf (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a TensorFlow checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/".
+
+
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+ _fast_init(`bool`, *optional*, defaults to `True`):
+ Whether or not to disable fast initialization.
+
+
+
+ One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
+ 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
+ [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
+
+
+
+ > Parameters for big model inference
+
+ low_cpu_mem_usage(`bool`, *optional*):
+ Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ This is an experimental feature and a subject to change at any moment.
+ mindspore_dtype (`str` or `mindspore.Type`, *optional*):
+ Override the default `mindspore.Type` and load the model under a specific `dtype`. The different options
+ are:
+
+ 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
+ `dtype`, ignoring the model's `config.mindspore_dtype` if one exists. If not specified
+ - the model will get loaded in `torch.float` (fp32).
+
+ 2. `"auto"` - A `mindspore_dtype` entry in the `config.json` file of the model will be
+ attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
+ the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
+ using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
+ the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
+
+
+
+ For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
+ reach out to the authors and ask them to add this information to the model's card and to insert the
+ `mindspore_dtype` entry in `config.json` on the hub.
+
+
+
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
+ like `1`) on which the model will be allocated, the device map will map the entire model to this
+ device. Passing `device_map = 0` means put the whole model on GPU 0.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
+ RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
+ `True` when there is some disk offload.
+ quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
+ A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
+ bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
+ `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes
+ quantizations and not preferred. consider inserting all such arguments into quantization_config
+ instead.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ variant (`str`, *optional*):
+ If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+ ignored when using `from_tf` or `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+ is not installed, it will be set to `False`.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, BertModel
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+ >>> model = BertModel.from_pretrained("./test/saved_model/")
+ >>> # Update configuration during loading.
+ >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
+ >>> assert model.config.output_attentions == True
+ >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+ >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
+ >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
+ >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
+ >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
+ ```
+
+ * `low_cpu_mem_usage` algorithm:
+
+ This is an experimental function that loads the model using ~1x model size CPU memory
+
+ Here is how it works:
+
+ 1. save which state_dict keys we have
+ 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
+ 3. after the model has been instantiated switch to the meta device all params/buffers that
+ are going to be replaced from the loaded state_dict
+ 4. load state_dict 2nd time
+ 5. replace the params/buffers from the state_dict
+
+ Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
+
+ """
+ state_dict = kwargs.pop("state_dict", None)
+ from_tf = kwargs.pop("from_tf", False)
+ from_flax = kwargs.pop("from_flax", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ _ = kwargs.pop("mirror", None)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ mindspore_dtype = kwargs.pop("mindspore_dtype", None)
+ offload_folder = kwargs.pop("offload_folder", None)
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
+ subfolder = kwargs.pop("subfolder", "")
+ commit_hash = kwargs.pop("_commit_hash", None)
+ variant = kwargs.pop("variant", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if use_safetensors is None and not is_safetensors_available():
+ use_safetensors = False
+
+ if commit_hash is None:
+ if not isinstance(config, PretrainedConfig):
+ # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ CONFIG_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+ else:
+ commit_hash = getattr(config, "_commit_hash", None)
+
+ from_pt = not (from_tf | from_flax)
+
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ # Load config if we don't provide a configuration
+ if not isinstance(config, PretrainedConfig):
+ config_path = config if config is not None else pretrained_model_name_or_path
+ config, model_kwargs = cls.config_class.from_pretrained(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ **kwargs,
+ )
+ else:
+ # In case one passes a config to `from_pretrained` + "attn_implementation"
+ # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
+ # Please see: https://github.com/huggingface/transformers/issues/28038
+
+ # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
+ # we pop attn_implementation from the kwargs but this handles the case where users
+ # passes manually the config to `from_pretrained`.
+ config = copy.deepcopy(config)
+
+ kwarg_attn_imp = kwargs.pop("attn_implementation", None)
+ if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
+ config._attn_implementation = kwarg_attn_imp
+ model_kwargs = kwargs
+
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+ # index of the files.
+ is_sharded = False
+ sharded_metadata = None
+ # Load model
+ loading_info = None
+
+ if pretrained_model_name_or_path is not None:
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if is_local:
+ if from_tf and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+ ):
+ # Load from a TF 1.0 checkpoint in priority if from_tf
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+ elif from_tf and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+ ):
+ # Load from a TF 2.0 checkpoint in priority if from_tf
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+ elif from_flax and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+ ):
+ # Load from a Flax checkpoint in priority if from_flax
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+ elif use_safetensors is not False and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+ ):
+ # Load from a safetensors checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+ )
+ elif use_safetensors is not False and os.path.isfile(
+ os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+ )
+ ):
+ # Load from a sharded safetensors checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+ )
+ is_sharded = True
+ elif os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+ ):
+ # Load from a PyTorch checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+ )
+ elif os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+ ):
+ # Load from a sharded PyTorch checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+ )
+ is_sharded = True
+ # At this stage we don't have a weight file so we will raise an error.
+ elif os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+ ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+ f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
+ " `from_tf=True` to load this model from those weights."
+ )
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+ f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
+ " to load this model from those weights."
+ )
+ elif use_safetensors:
+ raise EnvironmentError(
+ f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
+ f" {pretrained_model_name_or_path}."
+ )
+ else:
+ raise EnvironmentError(
+ f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
+ f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
+ f" {pretrained_model_name_or_path}."
+ )
+ elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+ archive_file = pretrained_model_name_or_path
+ is_local = True
+ elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
+ if not from_tf:
+ raise ValueError(
+ f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
+ "from_tf to True to load from this checkpoint."
+ )
+ archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ filename = pretrained_model_name_or_path
+ resolved_archive_file = download_url(pretrained_model_name_or_path)
+ else:
+ # set correct filename
+ if from_tf:
+ filename = TF2_WEIGHTS_NAME
+ elif from_flax:
+ filename = FLAX_WEIGHTS_NAME
+ elif use_safetensors is not False:
+ filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+ else:
+ filename = _add_variant(WEIGHTS_NAME, variant)
+
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "resume_download": resume_download,
+ "local_files_only": local_files_only,
+ "token": token,
+ "user_agent": user_agent,
+ "revision": revision,
+ "subfolder": subfolder,
+ "_raise_exceptions_for_gated_repo": False,
+ "_raise_exceptions_for_missing_entries": False,
+ "_commit_hash": commit_hash,
+ }
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
+ # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path,
+ _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+ **cached_file_kwargs,
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+ elif use_safetensors:
+ if revision == "main":
+ resolved_archive_file, revision, is_sharded = auto_conversion(
+ pretrained_model_name_or_path, **cached_file_kwargs
+ )
+ cached_file_kwargs["revision"] = revision
+ if resolved_archive_file is None:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
+ "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
+ "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+ )
+ else:
+ # This repo has no safetensors file of any kind, we switch to PyTorch.
+ filename = _add_variant(WEIGHTS_NAME, variant)
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, filename, **cached_file_kwargs
+ )
+ if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
+ # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path,
+ _add_variant(WEIGHTS_INDEX_NAME, variant),
+ **cached_file_kwargs,
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None:
+ # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+ # message.
+ has_file_kwargs = {
+ "revision": revision,
+ "proxies": proxies,
+ "token": token,
+ }
+ if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
+ " Use `from_tf=True` to load this model from those weights."
+ )
+ elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
+ " `from_flax=True` to load this model from those weights."
+ )
+ elif variant is not None and has_file(
+ pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+ ):
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+ f" {variant}. Use `variant=None` to load this model from those weights."
+ )
+ else:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
+ f" {FLAX_WEIGHTS_NAME}."
+ )
+ except EnvironmentError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception as e:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
+ f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+ ) from e
+
+ if is_local:
+ logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
+ else:
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+ else:
+ resolved_archive_file = None
+
+ # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+ if is_sharded:
+ # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+ pretrained_model_name_or_path,
+ resolved_archive_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _commit_hash=commit_hash,
+ )
+
+ if (
+ is_safetensors_available()
+ and isinstance(resolved_archive_file, str)
+ and resolved_archive_file.endswith(".safetensors")
+ ):
+ with safe_open(resolved_archive_file, framework="np") as f:
+ metadata = f.metadata()
+
+ if metadata.get("format") == "pt":
+ pass
+ elif metadata.get("format") == "tf":
+ from_tf = True
+ logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+ elif metadata.get("format") == "flax":
+ from_flax = True
+ logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
+ else:
+ raise ValueError(
+ f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
+ )
+
+ from_pt = not (from_tf | from_flax)
+
+ # load pt weights early so that we know which dtype to init the model under
+ if from_pt:
+ if not is_sharded and state_dict is None:
+ # Time to load the checkpoint
+ state_dict = load_state_dict(resolved_archive_file)
+
+ if is_sharded:
+ loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+ for sharded_file in resolved_archive_file:
+ if state_dict is None:
+ state_dict = safe_load_file(sharded_file)
+ else:
+ state_dict.update(safe_load_file(sharded_file))
+ else:
+ loaded_state_dict_keys = list(state_dict.keys())
+
+ config.name_or_path = pretrained_model_name_or_path
+ config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
+ model = cls(config, *model_args, **model_kwargs)
+
+ # todo: load state_dict into model
+ (
+ model,
+ missing_keys,
+ unexpected_keys,
+ mismatched_keys,
+ offload_index,
+ error_msgs,
+ ) = cls._load_pretrained_model(
+ model,
+ state_dict,
+ loaded_state_dict_keys, # XXX: rename?
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ sharded_metadata=sharded_metadata,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=mindspore_dtype,
+ )
+
+ if mindspore_dtype is not None:
+ model = model.to(mindspore_dtype)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.set_train(False)
+
+ if output_loading_info:
+ if loading_info is None:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ loaded_keys,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ sharded_metadata=None,
+ offload_folder=None,
+ offload_state_dict=None,
+ dtype=None,
+ ):
+ def get_pt2ms_mappings(m):
+ mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
+ for name, cell in m.cells_and_names():
+ if isinstance(cell, nn.Conv1d):
+ mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
+ elif isinstance(cell, nn.Embedding):
+ if "shared" in name:
+ if "decoder" in loaded_keys:
+ mappings[f"{name}.weight"] = (
+ "decoder.embed_tokens.embedding_table",
+ lambda x: x,
+ )
+ else:
+ mappings[f"{name}.weight"] = (
+ "encoder.embed_tokens.embedding_table",
+ lambda x: x,
+ )
+ else:
+ mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x
+ elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
+ mappings[f"{name}.weight"] = f"{name}.gamma", lambda x: x
+ mappings[f"{name}.bias"] = f"{name}.beta", lambda x: x
+ if isinstance(cell, (nn.BatchNorm2d,)):
+ mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x
+ mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x
+ mappings[f"{name}.num_batches_tracked"] = None, lambda x: x
+ return mappings
+
+ def convert_state_dict(m, state_dict_pt):
+ mappings = get_pt2ms_mappings(m)
+ state_dict_ms = {}
+ for name_pt, data_pt in state_dict_pt.items():
+ name_ms, data_mapping = mappings.get(name_pt, (name_pt, lambda x: x))
+ data_ms = data_mapping(data_pt)
+ if name_ms is not None:
+ state_dict_ms[name_ms] = data_ms
+ return state_dict_ms
+
+ missing_keys, unexpected_keys = ms.load_param_into_net(
+ model, convert_state_dict(model, state_dict), strict_load=True
+ )
+ mismatched_keys, offload_index, error_msgs = [], 0, ""
+ return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
diff --git a/examples/sharegpt_4v/share4v/transformers/models/__init__.py b/examples/sharegpt_4v/share4v/transformers/models/__init__.py
new file mode 100644
index 0000000000..6253816175
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/__init__.py
@@ -0,0 +1,2 @@
+from .clip import CLIPModel, CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModel, CLIPVisionModelWithProjection
+from .llama import LlamaForCausalLM, LlamaModel
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/transformers/models/cache.py b/examples/sharegpt_4v/share4v/transformers/models/cache.py
new file mode 100644
index 0000000000..ea672b5fa9
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/cache.py
@@ -0,0 +1,67 @@
+import abc
+from typing import List, Optional, Tuple
+
+import mindspore.ops as ops
+from mindspore import Tensor
+
+
+class Cache(abc.ABC):
+ @abc.abstractmethod
+ def __getitem__(self, layer_idx: int) -> List[Tuple[Tensor]]:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def update(self, key_states: Tensor, value_states: Tensor, layer_idx: int) -> Tuple[Tensor, Tensor]:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def reset(self, key_states: Tensor, value_states: Tensor, layer_idx: int) -> Tuple[Tensor, Tensor]:
+ raise NotImplementedError
+
+ @property
+ def seen_tokens(self) -> Optional[int]:
+ return getattr(self, "_seen_tokens", None)
+
+
+class DynamicCache(Cache):
+ def __init__(self) -> None:
+ self.key_cache: List[Tensor] = []
+ self.value_cache: List[Tensor] = []
+ self._seen_tokens = 0
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[Tensor]]:
+ if layer_idx < len(self.key_cache):
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def update(self, key_states: Tensor, value_states: Tensor, layer_idx: int) -> Tuple[Tensor, Tensor]:
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ if len(self.key_cache) == layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ elif len(self.key_cache) > layer_idx:
+ self.key_cache[layer_idx] = ops.concat([self.key_cache[layer_idx], key_states], axis=-2)
+ self.value_cache[layer_idx] = ops.concat([self.value_cache[layer_idx], value_states], axis=-2)
+ else:
+ raise KeyError(
+ f"Layer index {layer_idx} is larger than the current key_cache length {len(self.key_cache)}."
+ )
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ if len(self.key_cache) <= layer_idx:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def reset(self) -> None:
+ self.key_cache = []
+ self.value_cache = []
+ self._seen_tokens = 0
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/transformers/models/clip/__init__.py b/examples/sharegpt_4v/share4v/transformers/models/clip/__init__.py
new file mode 100644
index 0000000000..ebf248c8d5
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/clip/__init__.py
@@ -0,0 +1,7 @@
+from .modeling_ms_clip import (
+ CLIPModel,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+)
diff --git a/examples/sharegpt_4v/share4v/transformers/models/clip/modeling_ms_clip.py b/examples/sharegpt_4v/share4v/transformers/models/clip/modeling_ms_clip.py
new file mode 100644
index 0000000000..52927b189d
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/clip/modeling_ms_clip.py
@@ -0,0 +1,965 @@
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" MindSpore CLIP model."""
+from typing import List, Optional, Tuple, Union
+
+from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from transformers.utils import logging
+
+import mindspore as ms
+from mindspore import nn, ops
+
+from ...activations_ms import ACT2FN
+from ...modeling_ms_utils import MSPreTrainedModel
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "CLIPConfig"
+_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
+
+CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "openai/clip-vit-base-patch32",
+ # See all CLIP models at https://huggingface.co/models?filter=clip
+]
+
+
+def _prepare_4d_attention_mask(mask: ms.Tensor, dtype: ms.Type, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ bsz, src_len = mask.shape
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].tile((1, 1, tgt_len, 1)).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(ms.bool_), float("-inf"))
+
+
+def _create_4d_causal_attention_mask(
+ input_shape: Union[Tuple, List],
+ dtype: ms.Type,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+) -> Optional[ms.Tensor]:
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+ Args:
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+
+ key_value_length = past_key_values_length + input_shape[-1]
+ past_key_values_length = key_value_length - input_shape[-1]
+
+ def _make_causal_mask(
+ input_ids_shape: Union[Tuple, List],
+ dtype: ms.Type,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+ ):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = ops.full((tgt_len, tgt_len), float("-inf"), dtype=dtype)
+ mask_cond = ops.arange(mask.shape[-1])
+ mask = mask.masked_fill(mask_cond < (mask_cond + 1).view(mask.shape[-1], 1), 0)
+
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = ops.cat([ops.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
+
+ # add lower triangular sliding window mask if necessary
+ if sliding_window is not None:
+ diagonal = past_key_values_length - sliding_window + 1
+
+ context_mask = 1 - ops.triu(ops.ones_like(mask, dtype=ms.int32), diagonal=diagonal)
+ mask = mask.masked_fill(context_mask.bool(), float("-inf"))
+
+ return mask[None, None, :, :].tile((bsz, 1, 1, 1))
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if input_shape[-1] > 1 or sliding_window is not None:
+ causal_4d_mask = _make_causal_mask(
+ input_shape,
+ dtype,
+ past_key_values_length=past_key_values_length,
+ sliding_window=sliding_window,
+ )
+
+ return causal_4d_mask
+
+
+class CLIPVisionEmbeddings(nn.Cell):
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = ms.Parameter(ops.randn(self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ has_bias=False,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.position_ids = ops.arange(self.num_positions).unsqueeze(0)
+
+ def construct(self, pixel_values: ms.Tensor) -> ms.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(start_dim=2).swapaxes(1, 2)
+
+ class_embeds = self.class_embedding.reshape((1, 1, -1)).tile((batch_size, 1, 1)).to(dtype=target_dtype)
+ embeddings = ops.cat([class_embeds, patch_embeds], axis=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class CLIPTextEmbeddings(nn.Cell):
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_ids = ops.arange(config.max_position_embeddings).unsqueeze(0)
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ ) -> ms.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class CLIPAttention(nn.Cell):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Dense(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Dense(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2)
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ causal_attention_mask: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.shape
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(proj_shape)
+ key_states = key_states.view(proj_shape)
+ value_states = value_states.view(proj_shape)
+
+ src_len = key_states.shape[1]
+ attn_weights = ops.bmm(query_states, key_states.swapaxes(1, 2))
+
+ if attn_weights.shape != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.shape}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.shape != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.shape}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.shape != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.shape}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = ops.softmax(attn_weights, axis=-1)
+
+ if output_attentions:
+ # this operation is a bit akward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = ops.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = ops.bmm(attn_probs, value_states)
+
+ if attn_output.shape != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.shape}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.swapaxes(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+class CLIPMLP(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Dense(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Dense(config.intermediate_size, config.hidden_size)
+
+ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class CLIPEncoderLayer(nn.Cell):
+ def __init__(self, config: CLIPConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = CLIPAttention(config)
+ self.layer_norm1 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
+ self.mlp = CLIPMLP(config)
+ self.layer_norm2 = nn.LayerNorm((self.embed_dim,), epsilon=config.layer_norm_eps)
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: ms.Tensor,
+ causal_attention_mask: ms.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[ms.Tensor]:
+ """
+ Args:
+ hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`ms.Tensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class CLIPPreTrainedModel(MSPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = CLIPConfig
+ base_model_prefix = "clip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ pass
+
+
+class CLIPEncoder(nn.Cell):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`CLIPEncoderLayer`].
+
+ Args:
+ config: CLIPConfig
+ """
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__()
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.layers = nn.CellList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def construct(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[ms.Tensor] = None,
+ causal_attention_mask: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ # return_dict: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Args:
+ inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ # if not return_dict:
+ # return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+
+ # print("CLIPEncoder return tuple")
+ # results: last_hidden_state, hidden_states, attentions
+ results = (hidden_states, encoder_states, all_attentions)
+
+ # results = {
+ # 'last_hidden_state': hidden_states,
+ # 'hidden_states': encoder_states,
+ # 'attentions': all_attentions}
+ # print(type(hidden_states), type(encoder_states), type(all_attentions))
+ return results
+
+
+
+
+class CLIPTextTransformer(nn.Cell):
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__()
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ embed_dim = config.hidden_size
+ self.embeddings = CLIPTextEmbeddings(config)
+ self.encoder = CLIPEncoder(config)
+ self.final_layer_norm = nn.LayerNorm((embed_dim,), epsilon=config.layer_norm_eps)
+
+ # For `pooled_output` computation
+ self.eos_token_id = config.eos_token_id
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(input_shape, hidden_states.dtype)
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ if self.eos_token_id == 2:
+ # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+ # ------------------------------------------------------------
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ pooled_output = last_hidden_state[
+ ops.arange(last_hidden_state.shape[0]),
+ input_ids.to(dtype=ms.int32).argmax(axis=-1),
+ ]
+ else:
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
+ pooled_output = last_hidden_state[
+ ops.arange(last_hidden_state.shape[0]),
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
+ (input_ids.to(dtype=ms.int32) == self.eos_token_id).int().argmax(axis=-1),
+ ]
+
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+
+class CLIPTextModel(CLIPPreTrainedModel):
+ config_class = CLIPTextConfig
+
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__(config)
+ self.text_model = CLIPTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class CLIPVisionTransformer(nn.Cell):
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__()
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ embed_dim = config.hidden_size
+
+ self.embeddings = CLIPVisionEmbeddings(config)
+ self.pre_layrnorm = nn.LayerNorm((embed_dim,), epsilon=config.layer_norm_eps)
+ self.encoder = CLIPEncoder(config)
+ self.post_layernorm = nn.LayerNorm((embed_dim,), epsilon=config.layer_norm_eps)
+
+ def construct(
+ self,
+ pixel_values: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ # return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, dict]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ # (last_hidden_state, hidden_states, attentions)
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ # if not return_dict:
+ # return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ # print("CLIPVisionTransformer return tuple")
+ # results: last_hidden_state, pooled_output, hidden_states, attentions
+ results = (last_hidden_state, pooled_output, encoder_outputs[1], encoder_outputs[2])
+ # results = {}
+ # results["last_hidden_state"] = last_hidden_state
+ # results["pooled_output"] = pooled_output
+ # results["hidden_states"] = encoder_outputs['hidden_states']
+ # results["attentions"] = encoder_outputs['attentions']
+ return results
+
+
+
+class CLIPVisionModel(CLIPPreTrainedModel):
+ config_class = CLIPVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__(config)
+ self.vision_model = CLIPVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.vision_model.embeddings.patch_embedding
+
+ def construct(
+ self,
+ pixel_values: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPVisionModel
+
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class CLIPModel(CLIPPreTrainedModel):
+ config_class = CLIPConfig
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, CLIPTextConfig):
+ raise ValueError(
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, CLIPVisionConfig):
+ raise ValueError(
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.logit_scale_init_value = config.logit_scale_init_value
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = CLIPTextTransformer(text_config)
+ self.vision_model = CLIPVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Dense(self.vision_embed_dim, self.projection_dim, has_bias=False)
+ self.text_projection = nn.Dense(self.text_embed_dim, self.projection_dim, has_bias=False)
+ self.logit_scale = ms.Parameter(ms.Tensor(self.logit_scale_init_value))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_text_features(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> ms.Tensor:
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = text_outputs[1]
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ def get_image_features(
+ self,
+ pixel_values: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> ms.Tensor:
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+ image_features = self.visual_projection(pooled_output)
+
+ return image_features
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ pixel_values: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(ord=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(ord=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = ops.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.t()
+
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return output
+
+
+class CLIPTextModelWithProjection(CLIPPreTrainedModel):
+ config_class = CLIPTextConfig
+
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__(config)
+
+ self.text_model = CLIPTextTransformer(config)
+
+ self.text_projection = nn.Dense(config.hidden_size, config.projection_dim, has_bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
+
+ >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> text_embeds = outputs.text_embeds
+ ```"""
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = text_outputs[1]
+
+ text_embeds = self.text_projection(pooled_output)
+
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
+ return tuple(output for output in outputs if output is not None)
+
+
+class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
+ config_class = CLIPVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionTransformer(config)
+
+ self.visual_projection = nn.Dense(config.hidden_size, config.projection_dim, has_bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.vision_model.embeddings.patch_embedding
+
+ def construct(
+ self,
+ pixel_values: Optional[ms.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Tuple:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
+
+ >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> image_embeds = outputs.image_embeds
+ ```"""
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+
+ image_embeds = self.visual_projection(pooled_output)
+
+ outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
+ return tuple(output for output in outputs if output is not None)
diff --git a/examples/sharegpt_4v/share4v/transformers/models/llama/__init__.py b/examples/sharegpt_4v/share4v/transformers/models/llama/__init__.py
new file mode 100644
index 0000000000..cf58d6af71
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/llama/__init__.py
@@ -0,0 +1 @@
+from .modeling_ms_llama import (LlamaForCausalLM, LlamaModel)
\ No newline at end of file
diff --git a/examples/sharegpt_4v/share4v/transformers/models/llama/modeling_ms_llama.py b/examples/sharegpt_4v/share4v/transformers/models/llama/modeling_ms_llama.py
new file mode 100644
index 0000000000..df8f0b972f
--- /dev/null
+++ b/examples/sharegpt_4v/share4v/transformers/models/llama/modeling_ms_llama.py
@@ -0,0 +1,633 @@
+import logging
+from typing import Literal, Optional, Tuple, Union
+import numbers
+
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Parameter, Tensor
+from mindspore.common.initializer import Initializer, initializer
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+
+
+from ...activations_ms import ACT2FN
+
+logger = logging.getLogger(__name__)
+
+class Embedding(nn.Embedding):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ use_one_hot: bool = False,
+ embedding_table: Union[Tensor, str, Initializer, numbers.Number] = "normal",
+ dtype: ms.dtype = ms.float32,
+ padding_idx: Optional[int] = None,
+ ):
+ """Initialize Embedding."""
+ super(nn.Embedding, self).__init__()
+ self.vocab_size = vocab_size
+ self.embedding_size = embedding_size
+ self.use_one_hot = use_one_hot
+ self.dtype = dtype
+ self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size], dtype=dtype)
+ self.padding_idx = padding_idx
+ if padding_idx is not None:
+ self.padding_idx = padding_idx
+ if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
+ self.init_tensor = self.init_tensor.init_data()
+ self.init_tensor = self.init_tensor.asnumpy()
+ self.init_tensor[self.padding_idx] = 0
+ self.init_tensor = Tensor(self.init_tensor)
+ self.weight = Parameter(self.init_tensor)
+ self.expand = ops.ExpandDims()
+ self.reshape_flat = ops.Reshape()
+ self.shp_flat = (-1,)
+ self.gather = ops.Gather()
+ self.one_hot = ops.OneHot()
+ self.on_value = Tensor(1.0, self.dtype)
+ self.off_value = Tensor(0.0, self.dtype)
+ self.array_mul = ops.MatMul()
+ self.reshape = ops.Reshape()
+ self.get_shp = ops.Shape()
+ self.concat = ops.Concat()
+
+ def construct(self, ids):
+ out_shape = self.get_shp(ids) + (self.embedding_size,)
+ flat_ids = self.reshape_flat(ids, self.shp_flat)
+
+ if self.use_one_hot:
+ one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
+ output_for_reshape = self.array_mul(one_hot_ids, self.weight)
+ else:
+ output_for_reshape = self.gather(self.weight, flat_ids, 0)
+
+ output = self.reshape(output_for_reshape, out_shape)
+ return output
+
+ def extend_repr(self):
+ return (
+ f"vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, "
+ f"embedding_table={self.weight}, dtype={self.dtype}, padding_idx={self.padding_idx}"
+ )
+
+
+class LlamaRMSNorm(nn.Cell):
+ def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.dtype = ms.float32) -> None:
+ super().__init__()
+ self.weight = Parameter(ops.ones(hidden_size, dtype=dtype))
+ self.variance_epsilon = eps
+
+ def construct(self, hidden_states: Tensor) -> Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(ms.float32)
+ variance = ops.pow(hidden_states, 2)
+ variance = ops.mean(variance, axis=-1, keep_dims=True)
+ hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LlamaRotaryEmbedding(nn.Cell):
+ def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0) -> None:
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=ms.float32) / self.dim))
+
+ def construct(self, x: Tensor, position_ids: Tensor) -> Tuple[Tensor, Tensor]:
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = ops.broadcast_to(self.inv_freq[None, :, None], (position_ids.shape[0], -1, 1))
+ position_ids_expanded = position_ids[:, None, :].to(ms.float32)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ freqs = ops.matmul(inv_freq_expanded.to(ms.float32), position_ids_expanded.to(ms.float32))
+ freqs = ops.transpose(freqs, (0, 2, 1))
+ emb = ops.concat((freqs, freqs), axis=-1)
+ cos = ops.cos(emb)
+ sin = ops.sin(emb)
+ return cos.to(x.dtype), sin.to(x.dtype)
+
+
+def rotate_half(x: Tensor) -> Tensor:
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return ops.concat((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(
+ q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, unsqueeze_dim: int = 1
+) -> Tuple[Tensor, Tensor]:
+ cos = ops.unsqueeze(cos, unsqueeze_dim)
+ sin = ops.unsqueeze(sin, unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Cell):
+ def __init__(
+ self,
+ intermediate_size: int = 14336,
+ hidden_size: int = 4096,
+ hidden_act: str = "silu",
+ dtype: ms.dtype = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype)
+ self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype)
+ self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False, dtype=dtype)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def construct(self, hidden_state: Tensor) -> Tensor:
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :]
+ hidden_states = ops.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim))
+ hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim))
+ return hidden_states
+
+
+class LlamaAttention(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ max_position_embeddings: int = 32768,
+ rope_theta: float = 1000000.0,
+ attention_dropout: float = 0.0,
+ dtype: ms.dtype = ms.float32,
+ ) -> None:
+ super().__init__()
+
+ self.attention_dropout = attention_dropout
+ self.hidden_size = hidden_size
+ self.num_heads = num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype)
+
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ def construct(
+ self,
+ hidden_states: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_ids: Optional[Tensor] = None,
+ past_key_cache: Optional[Tensor] = None,
+ past_value_cache: Optional[Tensor] = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim))
+ query_states = ops.transpose(query_states, (0, 2, 1, 3))
+
+ key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim))
+ key_states = ops.transpose(key_states, (0, 2, 1, 3))
+
+ value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim))
+ value_states = ops.transpose(value_states, (0, 2, 1, 3))
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if return_key_value_cache:
+ key_cache, value_cache = key_states, value_states
+ else:
+ key_cache, value_cache = None, None
+
+ if past_key_cache is not None and past_value_cache is not None:
+ key_states = ops.concat([past_key_cache, key_states], axis=-2)
+ value_states = ops.concat([past_value_cache, value_states], axis=-2)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ key_states = ops.transpose(key_states, (0, 1, 3, 2))
+ attn_weights = ops.matmul(query_states, key_states) / ms.numpy.sqrt(ms.Tensor(self.head_dim, query_states.dtype))
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = ops.softmax(attn_weights.to(ms.float32), axis=-1).to(query_states.dtype)
+ attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = ops.matmul(attn_weights, value_states)
+
+ attn_output = ops.transpose(attn_output, (0, 2, 1, 3))
+ attn_output = ops.reshape(attn_output, (bsz, q_len, -1))
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, key_cache, value_cache
+
+
+class LlamaFlashAttention(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ max_position_embeddings: int = 32768,
+ rope_theta: float = 1000000.0,
+ attention_dropout: float = 0.0,
+ dtype: ms.dtype = ms.float32,
+ ) -> None:
+ super().__init__()
+
+ self.attention_dropout = attention_dropout
+ self.hidden_size = hidden_size
+ self.num_heads = num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype)
+ self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype)
+
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ self.flash_attention = FlashAttentionScore(
+ self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND"
+ )
+
+ def construct(
+ self,
+ hidden_states: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_ids: Optional[Tensor] = None,
+ past_key_cache: Optional[Tensor] = None,
+ past_value_cache: Optional[Tensor] = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ bsz, q_len, _ = hidden_states.shape
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim))
+ query_states = ops.transpose(query_states, (0, 2, 1, 3))
+
+ key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim))
+ key_states = ops.transpose(key_states, (0, 2, 1, 3))
+
+ value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim))
+ value_states = ops.transpose(value_states, (0, 2, 1, 3))
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if return_key_value_cache:
+ key_cache, value_cache = key_states, value_states
+ else:
+ key_cache, value_cache = None, None
+
+ if past_key_cache is not None and past_value_cache is not None:
+ key_states = ops.concat([key_states, past_key_cache], axis=-2)
+ value_states = ops.concat([value_states, past_value_cache], axis=-2)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Reshape to the expected shape and dtype for Flash Attention
+ query_states = ops.transpose(query_states, (0, 2, 1, 3))
+ key_states = ops.transpose(key_states, (0, 2, 1, 3))
+ value_states = ops.transpose(value_states, (0, 2, 1, 3))
+ attention_mask = attention_mask.to(ms.uint8)
+
+ _, _, _, attn_output = self.flash_attention(
+ query_states, key_states, value_states, None, None, None, attention_mask
+ )
+ attn_output = ops.reshape(attn_output, (bsz, q_len, -1))
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, key_cache, value_cache
+
+MISTRAL_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention": LlamaFlashAttention,
+}
+
+class LlamaDecoderLayer(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ intermediate_size: int = 14336,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-5,
+ max_position_embeddings: int = 32768,
+ rope_theta: float = 1000000.0,
+ attention_dropout: float = 0.0,
+ hidden_act: str = "silu",
+ attn_implementation: Literal["eager", "flash_attention"] = "eager",
+ dtype: ms.dtype = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[attn_implementation](
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ max_position_embeddings=max_position_embeddings,
+ rope_theta=rope_theta,
+ attention_dropout=attention_dropout,
+ dtype=dtype,
+ )
+
+ self.mlp = LlamaMLP(
+ intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype
+ )
+ self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype)
+ self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype)
+
+ def construct(
+ self,
+ hidden_states: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ position_ids: Optional[Tensor] = None,
+ past_key_cache: Optional[Tensor] = None,
+ past_value_cache: Optional[Tensor] = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, key_cache, value_cache = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_cache=past_key_cache,
+ past_value_cache=past_value_cache,
+ return_key_value_cache=return_key_value_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states, key_cache, value_cache
+
+
+class LlamaModel(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ intermediate_size: int = 14336,
+ max_position_embeddings: int = 32768,
+ num_attention_heads: int = 32,
+ num_hidden_layers: int = 32,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-5,
+ rope_theta: float = 1000000.0,
+ vocab_size: int = 32064,
+ attention_dropout: float = 0.0,
+ hidden_act: str = "silu",
+ pad_token_id: Optional[int] = None,
+ attn_implementation: Literal["eager", "flash_attention"] = "eager",
+ dtype: ms.dtype = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.padding_idx = pad_token_id
+ self.vocab_size = vocab_size
+ self.attn_implementation = attn_implementation
+
+ self.embed_tokens = Embedding(vocab_size, hidden_size, padding_idx=self.padding_idx, dtype=dtype)
+ self.layers = nn.CellList(
+ [
+ LlamaDecoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ max_position_embeddings=max_position_embeddings,
+ rope_theta=rope_theta,
+ attention_dropout=attention_dropout,
+ hidden_act=hidden_act,
+ attn_implementation=attn_implementation,
+ dtype=dtype,
+ )
+ for _ in range(num_hidden_layers)
+ ]
+ )
+ self.norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype)
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value: nn.Cell) -> None:
+ self.embed_tokens = value
+
+ def construct(
+ self,
+ input_ids: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ position_ids: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ past_key_cache_list: Optional[Tensor] = None,
+ past_value_cache_list: Optional[Tensor] = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ past_seen_tokens = past_key_cache_list.shape[-2] if past_key_cache_list is not None else 0
+ cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], dtype=ms.int32)
+ if position_ids is None:
+ position_ids = ops.unsqueeze(cache_position, 0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ hidden_states = inputs_embeds
+
+ if return_key_value_cache:
+ key_cache_list, value_cache_list = [], []
+ else:
+ key_cache_list, value_cache_list = None, None
+
+ for i, decoder_layer in enumerate(self.layers):
+ if past_key_cache_list is not None and past_value_cache_list is not None:
+ past_key_cache, past_value_cache = past_key_cache_list[i], past_value_cache_list[i]
+ else:
+ past_key_cache, past_value_cache = None, None
+
+ hidden_states, key_cache, value_cache = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_cache=past_key_cache,
+ past_value_cache=past_value_cache,
+ return_key_value_cache=return_key_value_cache,
+ )
+
+ if return_key_value_cache:
+ key_cache_list.append(key_cache)
+ value_cache_list.append(value_cache)
+
+ hidden_states = self.norm(hidden_states)
+
+ if return_key_value_cache:
+ key_cache_list = ops.stack(key_cache_list)
+ value_cache_list = ops.stack(value_cache_list)
+
+ return hidden_states, key_cache_list, value_cache_list
+
+ def _update_causal_mask(self, attention_mask: Tensor, input_tensor: Tensor, cache_position: Tensor) -> Tensor:
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ target_length = attention_mask.shape[-1]
+
+ if len(attention_mask.shape) == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+ causal_mask = attention_mask
+ else:
+ fill_value = -ms.numpy.inf if self.attn_implementation == "eager" else 1.0
+ causal_mask = ops.full((sequence_length, target_length), fill_value=fill_value, dtype=dtype)
+ exclude_mask = ops.arange(target_length) > cache_position.reshape(-1, 1)
+ causal_mask = ops.masked_fill(causal_mask, ~exclude_mask, Tensor(0, dtype=dtype))
+ causal_mask = ops.broadcast_to(causal_mask[None, None, :, :], (input_tensor.shape[0], 1, -1, -1))
+ if len(attention_mask.shape) == 2:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = ops.masked_fill(
+ causal_mask[:, :, :, :mask_length], padding_mask, Tensor(fill_value, dtype=dtype)
+ )
+
+ return causal_mask
+
+
+class LlamaForCausalLM(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ intermediate_size: int = 14336,
+ max_position_embeddings: int = 32768,
+ num_attention_heads: int = 32,
+ num_hidden_layers: int = 32,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-5,
+ rope_theta: float = 1000000.0,
+ vocab_size: int = 32000,
+ attention_dropout: float = 0.0,
+ hidden_act: str = "silu",
+ pad_token_id: Optional[int] = None,
+ attn_implementation: Literal["eager", "flash_attention"] = "eager",
+ dtype: ms.dtype = ms.float32,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ self.model = LlamaModel(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ max_position_embeddings=max_position_embeddings,
+ num_attention_heads=num_attention_heads,
+ num_hidden_layers=num_hidden_layers,
+ num_key_value_heads=num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ rope_theta=rope_theta,
+ vocab_size=vocab_size,
+ attention_dropout=attention_dropout,
+ hidden_act=hidden_act,
+ pad_token_id=pad_token_id,
+ attn_implementation=attn_implementation,
+ dtype=dtype,
+ )
+ self.vocab_size = vocab_size
+ self.lm_head = nn.Dense(hidden_size, vocab_size, has_bias=False, dtype=dtype)
+
+ def get_input_embeddings(self) -> nn.Cell:
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value: nn.Cell) -> None:
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings: nn.Cell) -> None:
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder: nn.Cell) -> None:
+ self.model = decoder
+
+ def get_decoder(self) -> nn.Cell:
+ return self.model
+
+ @ms.jit
+ def construct(
+ self,
+ input_ids: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ position_ids: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ past_key_cache_list: Optional[Tensor] = None,
+ past_value_cache_list: Optional[Tensor] = None,
+ return_key_value_cache: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ hidden_states, key_cache_list, value_cache_list = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_cache_list=past_key_cache_list,
+ past_value_cache_list=past_value_cache_list,
+ return_key_value_cache=return_key_value_cache,
+ )
+ logits = self.lm_head(hidden_states).to(ms.float32)
+ return logits, key_cache_list, value_cache_list
\ No newline at end of file