forked from zhangfaen/finetune-Qwen2-VL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetune_distributed.py
193 lines (157 loc) · 9.09 KB
/
finetune_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
import json
import datetime
import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from torch.utils.data import Dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from functools import partial
from util.vision_util import process_vision_info
from torch.optim import AdamW
from util.logutil import init_logger, get_logger
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=2)
device = accelerator.device
if accelerator.is_local_main_process:
output_dir = f'train_output/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}/'
init_logger(output_dir)
logger = get_logger()
class ToyDataSet(Dataset): # for toy demo, for train_data/data.json
def __init__(self, data_path):
super().__init__()
with open(data_path, "r") as f:
self.data = json.load(f)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def find_assistant_content_sublist_indexes(l):
'''
A message from data.json may look like below:
[[{'role': 'user', 'content': [{'type': 'image', 'image': 'train_data/1.jpeg'}, {'type': 'text', 'text': '描述一下这个图片'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': '这张图片展示了一位年轻女子和她的狗在海滩上玩耍的场景。女子穿着格子衬衫和黑色裤子,坐在沙滩上,与她的金毛犬互动。她们的手臂伸展着,似乎在进行某种游戏或训练。背景是广阔的海洋和晴朗的天空,阳光洒在沙滩上,营造出温暖而宁静的氛围。整体画面充满了快乐和放松的感觉。'}]}]]
After apply_chat_template, the text will look like below:
['<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>描述一下这个图片<|im_end|>\n<|im_start|>assistant\n这张图片展示了一位年轻女子和她的狗在海滩上玩耍的场景。女子穿着格子衬衫和黑色裤子,坐在沙滩上,与她的金毛犬互动。她们的手臂伸展着,似乎在进行某种游戏或训练。背景是广阔的海洋和晴朗的天空,阳光洒在沙滩上,营造出温暖而宁静的氛围。整体画面充满了快乐和放松的感觉。<|im_end|>\n']
The key of this function is to find the indexes of the assistant content in the input_ids list to build labels.
'''
# (Pdb++) processor.tokenizer.encode("<|im_start|>assistant")
# [151644, 77091]
# (Pdb++) processor.tokenizer.encode("<|im_end|>")
# [151645]
start_indexes = []
end_indexes = []
# Iterate through the list to find starting points
for i in range(len(l) - 1):
# Check if the current and next element form the start sequence
if l[i] == 151644 and l[i + 1] == 77091:
start_indexes.append(i)
# Now look for the first 151645 after the start
for j in range(i + 2, len(l)):
if l[j] == 151645:
end_indexes.append(j)
break # Move to the next start after finding the end
return list(zip(start_indexes, end_indexes))
def collate_fn(batch, processor, device):
# (Pdb++) processor.tokenizer.encode("<|im_start|>assistant")
# [151644, 77091]
# (Pdb++) processor.tokenizer.encode("<|im_end|>")
# [151645]
messages = [m['messages'] for m in batch]
texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=False) for msg in messages]
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=texts,
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
input_ids_lists = inputs['input_ids'].tolist()
assert len(messages) == len(input_ids_lists)
labels_list = []
for ids_list in input_ids_lists:
label_ids = [-100] * len(ids_list)
for begin_end_indexs in find_assistant_content_sublist_indexes(ids_list):
label_ids[begin_end_indexs[0]+2:begin_end_indexs[1]+1] = ids_list[begin_end_indexs[0]+2:begin_end_indexs[1]+1]
labels_list.append(label_ids)
labels_ids = torch.tensor(labels_list, dtype=torch.int64)
return inputs, labels_ids
def write_chat_template(processor, output_dir):
'''
***Note**
We should have not had this function, as normal processor.save_pretrained(output_dir) would save chat_template.json file.
However, on 2024/09/05, I think a commit introduced a bug to "huggingface/transformers", which caused the chat_template.json file not to be saved.
See the below commit, src/transformers/processing_utils.py line 393, this commit avoided chat_template.json to be saved.
https://github.com/huggingface/transformers/commit/43df47d8e78238021a4273746fc469336f948314#diff-6505546ec5a9ab74b2ce6511681dd31194eb91e9fa3ce26282e487a5e61f9356
To walk around that bug, we need manually save the chat_template.json file.
I hope this bug will be fixed soon and I can remove this function then.
'''
output_chat_template_file = os.path.join(output_dir, "chat_template.json")
chat_template_json_string = json.dumps({"chat_template": processor.chat_template}, indent=2, sort_keys=True) + "\n"
with open(output_chat_template_file, "w", encoding="utf-8") as writer:
writer.write(chat_template_json_string)
logger.info(f"chat template saved in {output_chat_template_file}")
def train():
# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="bfloat16"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
# "Qwen/Qwen2-VL-2B-Instruct",
# torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
# default processer, min image tokens 256, max image tokens 512
# Note: typically, in training, when we batch size of training dataloader is > 1, it is often we need pad shorter inputs to the same length.
# in training, we often add "padding_token_id" to the right side of shorter inputs to make them the same length. padding_side right
# make casual_mask easier to build by attention mask. for more detail, see *** notes.txt *** of this repo.
# in batching inference, we must use "padding_side" left, as generation usually ust last token of output list of tokens.
# in training, it is recommended to use "padding_side" right.
# https://github.com/huggingface/transformers/pull/26572
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
# see transformers/models/qwen2_vl/modeling_qwen2_vl.py: causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=256*28*28, max_pixels=512*28*28, padding_side="right")
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
train_loader = DataLoader(
ToyDataSet("train_data/data.json"),
batch_size=1,
collate_fn=partial(collate_fn, processor=processor, device=device)
)
model.train()
epochs = 10
optimizer = AdamW(model.parameters(), lr=1e-5)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
for epoch in range(epochs):
steps = 0
for batch in train_loader:
steps += 1
with accelerator.accumulate(model):
optimizer.zero_grad()
inputs, labels = batch
outputs = model(**inputs, labels=labels)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
if accelerator.is_local_main_process:
logger.info(f"Batch {steps} of epoch {epoch + 1}/{epochs}, training loss : {loss.item()}")
accelerator.wait_for_everyone()
if accelerator.is_local_main_process:
os.makedirs(output_dir, exist_ok=True)
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
)
processor.save_pretrained(output_dir)
write_chat_template(processor, output_dir)
if __name__ == "__main__":
train()