-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetune-lora-ds-int.py
812 lines (735 loc) · 31.7 KB
/
finetune-lora-ds-int.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
#!/usr/bin/env python
# coding=utf-8
import logging
import math
import os
import sys
import random
from dataclasses import dataclass, field
from itertools import chain
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import numpy as np
from typing import Dict, Optional, List, Union
import datasets
import evaluate
import torch
from datasets import load_dataset
from peft import (
LoraConfig,
PeftModel,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
set_peft_model_state_dict,
)
import transformers
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
TrainerCallback,
TrainerState,
TrainerControl,
HfArgumentParser,
Trainer,
TrainingArguments,
default_data_collator,
BitsAndBytesConfig,
is_torch_tpu_available,
set_seed,
get_scheduler,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from utils.prompter import Prompter
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from ds.ds_utils import get_train_ds_config
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, \
get_optimizer_grouped_parameters, save_zero_three_model
# Module-level logger; its level is configured inside train().
logger = logging.getLogger(__name__)

# Config classes of every architecture registered for causal LM, and their
# `model_type` strings (listed in the --model_type help text).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple([config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES])

# Fallback special tokens; train() uses them for whichever of pad/eos/bos/unk
# the tokenizer does not already define.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
@dataclass
class ModelArguments:
    """Arguments selecting the base model, its config/tokenizer, and the LoRA/quantization setup."""

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    # LoRA rank and alpha, passed to LoraConfig as `r` and `lora_alpha` in train().
    lora_r: Optional[int] = field(default=16)
    lora_alpha: Optional[int] = field(default=32)
    # Arrives from the CLI as one comma-separated string; __post_init__ turns it into a list of names.
    target_modules: Optional[str] = field(
        default='q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj',
        metadata={
            "help": "List of module names or regex expression of the module names to replace with Lora. "
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # bitsandbytes quantization width for the base model weights (4 or 8).
    load_in_bits: Optional[int] = field(default=8)
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    # NOTE(review): train() always loads the model with torch.float16 and does not read this field.
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

    def __post_init__(self):
        """Validate mutually exclusive options and normalise ``target_modules``.

        Raises:
            ValueError: if ``config_overrides`` is combined with ``config_name``
                or ``model_name_or_path``.
        """
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
        # Split the comma-separated string into module names. Strip whitespace and drop
        # empty entries, so "q_proj, v_proj" yields "v_proj" rather than " v_proj"
        # (the latter would never match a module name).
        if isinstance(self.target_modules, str):
            self.target_modules = [name.strip() for name in self.target_modules.split(',') if name.strip()]
@dataclass
class DataTrainingArguments:
    """Arguments describing the training/evaluation data and how it is preprocessed."""

    train_on_inputs: bool = field(
        default=False,
        metadata={"help": "Whether to also train on the prompt (instruction/input) tokens instead of only the response."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_files: Optional[List[str]] = field(
        default=None, metadata={"help": "The input training data files (csv, json or txt)."}
    )
    validation_files: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Optional input evaluation data files (csv, json or txt) to evaluate the perplexity on."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        """Normalise file arguments and validate that usable data sources were given.

        Raises:
            ValueError: if no dataset name or file is provided, if a file list is
                empty, or if the first file of a list is not csv/json/txt.
        """
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        # Accept a single path as well as a list of paths. A JSON config such as
        # {"train_files": "a.json"} is passed through as a plain string.
        if isinstance(self.train_files, str):
            self.train_files = [self.train_files]
        if isinstance(self.validation_files, str):
            self.validation_files = [self.validation_files]

        if self.dataset_name is None and self.train_files is None and self.validation_files is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Raise explicitly instead of using `assert`, which is stripped under `python -O`.
        for arg_name, files in (("train_file", self.train_files), ("validation_file", self.validation_files)):
            if files is None:
                continue
            if not files:
                raise ValueError(f"`{arg_name}` must contain at least one file.")
            # The loader builder is chosen from the first file's extension.
            extension = files[0].split(".")[-1]
            if extension not in ("csv", "json", "txt"):
                raise ValueError(f"`{arg_name}` should be a csv, a json or a txt file.")
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that saves the PEFT model into every checkpoint folder.

    After ``save_pretrained`` it deletes any ``pytorch_model.bin`` left in that
    folder. Presumably this drops the full-model weights written by the Trainer
    and keeps only the adapter files (confirm against the Trainer version in use).
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Only the global main process writes to disk.
        if not state.is_world_process_zero:
            return control

        step_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        kwargs["model"].save_pretrained(step_dir)

        full_weights_file = os.path.join(step_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)
        return control
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add ``special_tokens_dict`` to ``tokenizer`` and resize ``model``'s embeddings to ``len(tokenizer)``.

    When tokens were actually added, their rows in both the input and the output
    embedding matrices are set to the mean of all preceding rows.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    # Fetch both matrices first, then overwrite the trailing `added` rows of each
    # with the average of the rows before them.
    matrices = (model.get_input_embeddings().weight.data, model.get_output_embeddings().weight.data)
    for matrix in matrices:
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)
def _load_dataset(data_args, training_args, model_args):
    """Load the local training/validation files into a ``datasets.DatasetDict``.

    Args:
        data_args: supplies ``train_files`` / ``validation_files`` (lists of paths),
            ``validation_split_percentage`` and ``keep_linebreaks``.
        training_args: ``output_dir`` hosts the dataset cache (``<output_dir>/dataset_cache``).
        model_args: ``use_auth_token`` is forwarded to ``load_dataset``.

    Returns:
        The loaded DatasetDict. When no validation files are given, the first
        ``validation_split_percentage`` percent of the training files becomes the
        "validation" split and the remainder the "train" split.

    Raises:
        ValueError: if neither training nor validation files were provided.
    """
    data_files = {}
    dataset_args = {}
    if data_args.train_files is not None:
        data_files["train"] = data_args.train_files
    if data_args.validation_files is not None:
        data_files["validation"] = data_args.validation_files

    # Both fields hold *lists* of paths; the builder is chosen from the first file's extension.
    reference_files = data_args.train_files if data_args.train_files is not None else data_args.validation_files
    if not reference_files:
        raise ValueError(
            "No data files given: pass --train_files and/or --validation_files (--dataset_name is not supported here)."
        )
    extension = reference_files[0].split(".")[-1]
    if extension == "txt":
        # `load_dataset` calls the plain-text builder "text"; it also accepts `keep_linebreaks`.
        extension = "text"
        dataset_args["keep_linebreaks"] = data_args.keep_linebreaks

    # Use one cache directory for all loads below, so the split reloads reuse the
    # data already cached by the first load.
    cache_dir = os.path.join(training_args.output_dir, 'dataset_cache')
    auth_token = True if model_args.use_auth_token else None

    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=cache_dir,
        use_auth_token=auth_token,
        **dataset_args,
    )
    # If there is no validation data, validation_split_percentage is used to divide the training files.
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            extension,
            data_files=data_files,
            split=f"train[:{data_args.validation_split_percentage}%]",
            cache_dir=cache_dir,
            use_auth_token=auth_token,
            **dataset_args,
        )
        raw_datasets["train"] = load_dataset(
            extension,
            data_files=data_files,
            split=f"train[{data_args.validation_split_percentage}%:]",
            cache_dir=cache_dir,
            use_auth_token=auth_token,
            **dataset_args,
        )
    return raw_datasets
def get_optimizer_grouped_parameters(model,
                                     weight_decay,
                                     no_decay_name_list=("bias", "LayerNorm.weight")):
    """Split the trainable parameters of ``model`` into decay / no-decay groups.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        weight_decay: weight decay applied to the first group.
        no_decay_name_list: name substrings; a parameter whose name contains any of
            them is placed in the group with ``weight_decay`` 0.0. Any iterable of
            strings works (a tuple default avoids a shared mutable default).

    Returns:
        A list of optimizer param-group dicts, in the order [decay, no-decay].
        Parameters with ``requires_grad=False`` are skipped. Groups that end up
        with no parameters are omitted.
    """
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay_name_list):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    # When only LoRA adapter weights are trainable (bias="none"), the no-decay group is
    # usually empty. Empty param groups can break optimizer wrappers such as DeepSpeed
    # ZeRO, so drop them.
    return [group for group in optimizer_grouped_parameters if group["params"]]
def train():
    """Fine-tune a causal LM with LoRA adapters using a hand-written DeepSpeed loop.

    Steps:
      1. Parse arguments from the CLI or from a single JSON file.
      2. Initialise the distributed backend and the DeepSpeed config.
      3. Load the data, config, tokenizer and base model (optionally 4/8-bit quantized).
      4. Tokenize Alpaca-style prompts and wrap the model with LoRA.
      5. Train for ``num_train_epochs`` epochs, evaluating validation perplexity after
         each epoch when --do_eval is set.
      6. Save the model from global rank 0.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single .json argument holds every option.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # ---- Distributed / DeepSpeed setup ----
    torch.cuda.set_device(training_args.local_rank)
    device = torch.device("cuda", training_args.local_rank)
    # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs.
    deepspeed.init_distributed()
    training_args.global_rank = torch.distributed.get_rank()

    ds_config = get_train_ds_config(offload=False, stage=2)
    ds_config['train_micro_batch_size_per_gpu'] = training_args.per_device_train_batch_size
    # Global batch = per-GPU micro batch x world size x gradient accumulation steps.
    ds_config['train_batch_size'] = (
        training_args.per_device_train_batch_size
        * torch.distributed.get_world_size()
        * training_args.gradient_accumulation_steps
    )
    print_rank_0('train_batch_size:' + str(ds_config['train_batch_size']), training_args.global_rank)

    send_example_telemetry("run_clm", model_args, data_args)

    # ---- Logging ----
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # ---- Output directory / existing checkpoint check ----
    logger.info("**********判断是否存在检查点**********")
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            # NOTE(review): the DeepSpeed loop below never loads `last_checkpoint`;
            # despite this message, training restarts from the base model.
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    logger.info("**********set seed**********")
    set_seed(training_args.seed)

    logger.info("**********装载数据集**********")
    raw_datasets = _load_dataset(data_args, training_args, model_args)

    # ---- Config ----
    logger.info("**********从模型装载config**********")
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    # ---- Tokenizer ----
    logger.info("**********装载tokenizer**********")
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
        "padding_side": 'left'
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # ---- Model (optionally quantized) and LoRA config ----
    logger.info("**********使用Lora方式装载模型**********")
    lora_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        target_modules=model_args.target_modules,
        fan_in_fan_out=False,
        lora_dropout=0.05,
        inference_mode=False,
        bias="none",
        task_type="CAUSAL_LM",
    )
    bnb_config_4bit = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    bnb_config_8bit = BitsAndBytesConfig(
        load_in_8bit=True
    )
    # Only 4 and 8 select a bitsandbytes quantization; any other value loads unquantized weights.
    quantization_config = {4: bnb_config_4bit, 8: bnb_config_8bit}.get(model_args.load_in_bits)
    if model_args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            torch_dtype=torch.float16,
            # Quantization is requested only via `quantization_config`. Newer transformers
            # releases reject `load_in_8bit=` combined with `quantization_config=`.
            quantization_config=quantization_config,
            device_map={"": int(os.environ.get("LOCAL_RANK") or 0)},
        )
    else:
        model = AutoModelForCausalLM.from_config(config)
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params / 2 ** 20:.2f}M params")

    logger.info("**********调整嵌入的大小**********")
    # Resize the embeddings only when necessary, to avoid index errors. If you are creating a
    # model from scratch on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): peft's prepare_model_for_int8_training / prepare_model_for_kbit_training is
    # not applied to the quantized model; confirm that is intended.

    # ---- Special tokens: set up before tokenizing the data and before LoRA wrapping ----
    # Pad with id 0 (assumed to be <unk>, as in LLaMA vocabularies) so padding differs from EOS.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"  # Allow batched inference
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    # ---- Prompt construction and tokenization ----
    logger.info("**********生成和词元化prompt**********")
    train_on_inputs = data_args.train_on_inputs
    logger.info(f"train_on_inputs: {train_on_inputs}")
    cutoff_len = 512 if data_args.block_size is None else data_args.block_size
    prompter = Prompter("alpaca")
    # Only some tokenizers (e.g. LLaMA) expose `add_eos_token`; treat it as False otherwise.
    add_eos_token = getattr(tokenizer, "add_eos_token", False)

    def tokenize(prompt, add_eos_token=True):
        """Tokenize one prompt (truncated to cutoff_len), append EOS if room, and copy ids to labels."""
        # Single example, so no padding here; the data collator pads each batch.
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        if (
            add_eos_token
            and len(input_ids) < cutoff_len
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
        ):
            input_ids.append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        result["labels"] = input_ids.copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        """Build the full Alpaca prompt; unless training on inputs, mask the prompt part of the labels."""
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(data_point["instruction"], data_point["input"])
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos_token)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            if add_eos_token:
                # The user-prompt EOS does not appear at this position in the full prompt.
                user_prompt_len -= 1
            # -100 is ignored by the loss, so only response tokens are trained on.
            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
            )
        return tokenized_full_prompt

    # Drop the raw text columns: the collator can only batch the tokenized fields.
    tokenized_datasets = datasets.DatasetDict({
        split: split_dataset.map(
            generate_and_tokenize_prompt,
            remove_columns=split_dataset.column_names,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        for split, split_dataset in raw_datasets.items()
    })

    # ---- Train / eval splits ----
    logger.info("**********校验数据集**********")
    # The DeepSpeed loop below always trains, so a train split is required unconditionally.
    if "train" not in tokenized_datasets:
        raise ValueError("A train dataset is required: this script always runs the training loop.")
    train_dataset = tokenized_datasets["train"]
    if data_args.max_train_samples is not None:
        max_train_samples = min(len(train_dataset), data_args.max_train_samples)
        train_dataset = train_dataset.select(range(max_train_samples))
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
    train_dataset = train_dataset.shuffle(seed=training_args.seed)

    eval_dataset = None
    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # ---- LoRA wrapping ----
    logger.info("**********peft处理model**********")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # ---- Optimizer, scheduler, DeepSpeed engine ----
    logger.info("**********deepspeed模型优化**********")
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(model, training_args.weight_decay)
    optimizer = FusedAdam(optimizer_grouped_parameters, lr=training_args.learning_rate, betas=(0.9, 0.95))

    logger.info("**********optimizer**********")
    # One optimizer update consumes ds_config['train_batch_size'] samples across all ranks.
    num_update_steps_per_epoch = math.ceil(len(train_dataset) / ds_config['train_batch_size'])
    lr_scheduler = get_scheduler(
        name='cosine',
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=training_args.num_train_epochs * num_update_steps_per_epoch,
    )
    training_args.gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)

    logger.info("**********deepspeed.initialize**********")
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=training_args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True)

    # ---- Data loaders ----
    print_rank_0(train_dataset, training_args.global_rank)
    # DataCollatorForSeq2Seq pads `labels` (with -100) as well as input_ids/attention_mask.
    # DataCollatorWithPadding leaves variable-length labels unpadded, so they cannot be batched.
    data_collator = transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    )
    train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=data_collator,
                                  sampler=train_sampler,
                                  batch_size=training_args.per_device_train_batch_size)
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = DataLoader(eval_dataset,
                                     collate_fn=data_collator,
                                     sampler=DistributedSampler(eval_dataset, shuffle=False),
                                     batch_size=training_args.per_device_eval_batch_size)

    def evaluation(model, eval_dataloader):
        """Return exp(mean validation loss).

        The per-rank mean loss is combined with get_all_reduce_mean (assumed to
        average the tensor across ranks) before exponentiating.
        """
        model.eval()
        total_loss = torch.zeros((), device=device)
        num_batches = 0
        for batch in eval_dataloader:
            batch = to_device(batch, device)
            with torch.no_grad():
                outputs = model(**batch)
            total_loss += outputs.loss.float()
            num_batches += 1
        mean_loss = total_loss / max(num_batches, 1)
        mean_loss = get_all_reduce_mean(mean_loss)
        return torch.exp(mean_loss).item()

    # ---- Training loop ----
    for epoch in range(int(training_args.num_train_epochs)):
        print_rank_0(
            f"Beginning of Epoch {epoch + 1}/{training_args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
            training_args.global_rank)
        # Without this call, DistributedSampler replays the same shuffle every epoch.
        train_sampler.set_epoch(epoch)
        model.train()
        for batch in train_dataloader:
            batch = to_device(batch, device)
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss
            # The DeepSpeed engine handles loss scaling, gradient accumulation and LR scheduling.
            model.backward(loss)
            model.step()

        if eval_dataloader is not None:
            # Evaluate perplexity on the validation set.
            print_rank_0(
                f"***** Evaluating perplexity, Epoch {epoch + 1}/{training_args.num_train_epochs} *****",
                training_args.global_rank)
            perplexity = evaluation(model, eval_dataloader)
            print_rank_0(f"ppl: {perplexity}", training_args.global_rank)
        model.tput_timer.update_epoch_count()

    if training_args.output_dir is not None:
        print_rank_0('saving the final model ...', training_args.global_rank)
        if training_args.global_rank == 0:
            save_hf_format(model, tokenizer, training_args)
if __name__ == "__main__":
    # Script entry point; expected to be started by a distributed launcher such as `deepspeed`.
    train()