-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathfinetune.py
868 lines (805 loc) · 37.2 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
#!/usr/bin/env python
# coding=utf-8
import argparse
import logging
import math
import os
import random
import datasets
from datetime import timedelta
import torch
from functools import partial
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed, InitProcessGroupKwargs
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import deepspeed
from deepspeed import get_accelerator
import json
import jsonlines
from peft import AutoPeftModelForCausalLM
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
LlamaTokenizer,
LlamaTokenizerFast,
MistralForCausalLM,
SchedulerType,
DataCollatorForSeq2Seq,
get_scheduler,
GPTNeoXTokenizerFast,
GPT2Tokenizer,
OPTForCausalLM,
BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
# Module-level logger from accelerate; its methods accept `main_process_only`
# (main() passes main_process_only=False to log on every process).
logger = get_logger(__name__)
def parse_args():
    """Parse and validate the command-line arguments for fine-tuning.

    Returns:
        argparse.Namespace with all options below.

    Raises:
        ValueError: if neither --dataset_name nor --train_file is given, or if
            --train_file does not have a .json/.jsonl extension.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A json or jsonl file containing the training data."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--peft_model_name_or_path",
        type=str,
        help="Path to a pretrained PEFT (e.g. LoRA) adapter or its identifier on huggingface.co/models; used when --model_name_or_path is not given.",
        required=False,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="If passed, will use LORA (low-rank parameter-efficient training) to train the model.",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=64,
        help="The rank of lora.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=16,
        help="The alpha parameter of lora.",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="The dropout rate of lora modules.",
    )
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        help="If passed, will use flash attention to train the model.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=512,
        help="The maximum total sequence length (prompt+completion) of each training example.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--num_train_examples",
        type=int,
        default=50000,
        help="Maximum number of training examples to keep after length filtering.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--warmup_ratio", type=float, default=0, help="Ratio of total training steps used for warmup."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=None,
        help="Log the training loss and learning rate every logging_steps steps.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--low_cpu_mem_usage",
        action="store_true",
        help=(
            "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
            "If passed, LLM loading time and RAM consumption will be benefited."
        ),
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help=(
            "Turn on gradient checkpointing. Saves memory but slows training."
        ),
    )
    parser.add_argument(
        "--use_qlora",
        action="store_true",
        help=(
            "Use qLoRA training - main thing is initialising model in quantised form. Not compatible with deepspeed."
        ),
    )
    parser.add_argument(
        '--clip_grad_norm',
        type=float,
        default=-1,
        help='Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead).',
    )
    parser.add_argument(
        '--use_8bit_optimizer',
        action='store_true',
        help='Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed (use deepspeed config instead).',
    )
    parser.add_argument(
        '--add_bos',
        action='store_true',
        help='Forcibly add bos token to the beginning of the input sequence. Use only when tokenizer does not add bos token by default (e.g., olmo).',
    )
    parser.add_argument(
        '--timeout',
        type=int,
        default=1800,
        help='Timeout for the training process. Useful if tokenization process is long. Default is 1800 seconds (30 minutes).',
    )
    parser.add_argument(
        '--trust_remote_code',
        action='store_true',
        help='Trust remote code when loading pretrained models and tokenizers. Use only when you trust the remote code.',
    )
    parser.add_argument(
        '--reduce_loss',
        default='mean',
        choices=['mean', 'sum'],
        help='How to reduce loss over tokens. Default is mean, but using sum can improve chat model performance.',
    )
    args = parser.parse_args()

    # Sanity checks. These validate user input, so they raise instead of using
    # `assert` (asserts are stripped when Python runs with -O).
    if args.dataset_name is None and args.train_file is None:
        raise ValueError("Need either a dataset name or a training file.")
    if args.train_file is not None:
        extension = args.train_file.split(".")[-1]
        if extension not in ["json", "jsonl"]:
            raise ValueError("`train_file` should be a json/jsonl file.")
    return args
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Tokenize a prompt/completion pair for causal-LM training with the prompt masked out of the loss.

    Here we assume each example has 'prompt' and 'completion' fields.
    We concatenate prompt and completion and tokenize them together because otherwise the prompt would be
    padded/truncated on its own, and it doesn't make sense to follow it directly with the completion.

    Args:
        example: mapping with 'prompt' and 'completion' strings.
        tokenizer: callable tokenizer exposing `eos_token` / `bos_token`.
        max_seq_length: truncation length for both tokenizations.
        add_bos: if True, prepend `tokenizer.bos_token` to the text. The same prefix is applied
            to the prompt, so the masked span also covers the bos token.

    Returns:
        dict with 1-D 'input_ids', 'labels' (prompt positions set to -100) and 'attention_mask'.
    '''
    prompt = example['prompt']
    completion = example['completion']
    # if prompt doesn't end with whitespace and completion doesn't start with whitespace, add a space
    if not prompt.endswith((' ', '\n', '\t')) and not completion.startswith((' ', '\n', '\t')):
        example_text = prompt + ' ' + completion
    else:
        example_text = prompt + completion
    example_text = example_text + tokenizer.eos_token
    if add_bos:
        # Prefix the prompt as well; otherwise the mask below would be one token short
        # and the last prompt token would contribute to the loss.
        example_text = tokenizer.bos_token + example_text
        prompt = tokenizer.bos_token + prompt
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    tokenized_prompt = tokenizer(prompt, return_tensors='pt', max_length=max_seq_length, truncation=True)
    # mask the prompt part for avoiding loss (-100 is ignored by the cross-entropy loss)
    labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }
def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Tokenize a chat example for causal-LM training, masking everything before the final message.

    Here we assume each example has a 'messages' field. Each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with [INST]/[/INST] delimiters around system/user turns and the eos token after
    assistant turns, then tokenize them together. Loss is only computed on tokens after the prompt, where the
    prompt is the concatenation of all messages except the last one.

    Args:
        example: mapping with a non-empty 'messages' list.
        tokenizer: callable tokenizer exposing `eos_token` / `bos_token`.
        max_seq_length: truncation length for both tokenizations.
        add_bos: if True, prepend `tokenizer.bos_token` to both the full text and the prompt.

    Returns:
        dict with 1-D 'input_ids', 'labels' (prompt positions set to -100) and 'attention_mask'.

    Raises:
        ValueError: if 'messages' is empty or a message has an unknown role.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "[INST]\n" + message["content"].strip() + "[/INST]\n"
            elif message["role"] == "user":
                message_text += "[INST]\n" + message["content"].strip() + "[/INST]\n"
            elif message["role"] == "assistant":
                message_text += message["content"].strip() + tokenizer.eos_token + "\n"
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = _concat_messages(messages).strip()
    prompt_text = _concat_messages(messages[:-1]).strip()
    if add_bos:
        # The prompt must carry the same bos prefix, or the mask would be one token short.
        example_text = tokenizer.bos_token + example_text
        prompt_text = tokenizer.bos_token + prompt_text
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    tokenized_prompt = tokenizer(prompt_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    # -100 positions are ignored by the cross-entropy loss
    labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }
def encode_with_qr_format(question, response, tokenizer, max_seq_length, add_bos=False):
    '''
    Tokenize a single question/response pair in the [INST] ... [/INST] format, masking the question part.

    Args:
        question: user question text (surrounding whitespace is stripped).
        response: target response text (surrounding whitespace is stripped).
        tokenizer: callable tokenizer exposing `bos_token`.
        max_seq_length: truncation length for both tokenizations.
        add_bos: if True, prepend `tokenizer.bos_token` to both the full text and the masked prompt.

    Returns:
        dict with 1-D 'input_ids', 'labels' (prompt positions set to -100) and 'attention_mask'.
    '''
    prompt_text = f"[INST]\n{question.strip()}\n[/INST]\n"
    example_text = prompt_text + f"{response.strip()}\n"
    if add_bos:
        # Same prefix on both texts so the masked span includes the bos token.
        example_text = tokenizer.bos_token + example_text
        prompt_text = tokenizer.bos_token + prompt_text
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    tokenized_prompt = tokenizer(prompt_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }
def save_with_accelerate(accelerator, model, tokenizer, output_dir, args):
    """Write the model (or, with --use_lora, only its adapter weights) to ``output_dir``.

    ``tokenizer`` is part of the signature but is not used here.
    """
    inner_model = accelerator.unwrap_model(model)
    # The state dict must be taken from the *wrapped* model via the accelerator:
    # under multi-GPU training, reading it any other way can yield only a subset
    # of the parameters.
    full_state_dict = accelerator.get_state_dict(model)

    if not args.use_lora:
        # Regular model: save_pretrained handles main-process gating itself and
        # writes through accelerator.save. Safetensors is deliberately disabled for now.
        inner_model.save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=full_state_dict,
            safe_serialization=False,
        )
        return

    # LoRA: the unwrapped model is a PeftModel whose save_pretrained stores only
    # the LoRA modules and (per the original note) does not take is_main_process,
    # so gate on the main process here instead.
    if accelerator.is_main_process:
        inner_model.save_pretrained(output_dir, state_dict=full_state_dict)
def _load_train_file(train_file):
    """Load a local json/jsonl training file into a ``datasets.Dataset``.

    The file is first parsed as one JSON object in ``Dataset.from_dict`` layout
    (column name -> list of values). If that fails, it is re-read as JSON Lines in
    which every record has ``question`` and ``response`` fields; each record becomes
    a user/assistant conversation in a single ``messages`` column.
    """
    try:
        with open(train_file, encoding="utf-8") as f:
            return datasets.Dataset.from_dict(json.load(f))
    except (ValueError, TypeError, AttributeError):
        # json.load raises JSONDecodeError (a ValueError) on multi-record JSON Lines files,
        # and from_dict is expected to fail on JSON that is not column-oriented.
        # In either case fall back to JSON Lines below.
        pass
    conversations = []
    with jsonlines.open(train_file) as reader:
        for record in reader:
            conversations.append([
                {"role": "user", "content": record["question"]},
                {"role": "assistant", "content": record["response"]},
            ])
    return datasets.Dataset.from_dict({"messages": conversations})


def _prepare_raw_datasets(args):
    """Load the raw data, drop overly long conversations and cap the train split size.

    Local files get a tiny fixed-seed test split. Examples with a 'messages' column
    are kept only if their messages total fewer than 200 whitespace-separated words.
    The train split is then truncated to ``args.num_train_examples``.
    """
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
        )
    else:
        raw_datasets = _load_train_file(args.train_file)
        # Only split local files: load_dataset already returns a DatasetDict,
        # which has no train_test_split method.
        raw_datasets = raw_datasets.train_test_split(test_size=0.0005, seed=42)

    # filter out examples with too long messages (currently: 200 words or more in total)
    max_total_words = 200

    def _is_short(example):
        return sum(len(message['content'].split()) for message in example['messages']) < max_total_words

    logger.info(f"Currently, there are {len(raw_datasets['train'])} examples in the training set.")
    for split in list(raw_datasets.keys()):
        if "messages" in raw_datasets[split].column_names:
            raw_datasets[split] = raw_datasets[split].filter(_is_short)
    logger.info(f"After filtering, there are {len(raw_datasets['train'])} examples in the training set.")
    num_kept = min(args.num_train_examples, len(raw_datasets['train']))
    raw_datasets['train'] = raw_datasets['train'].select(range(num_kept))
    if len(raw_datasets['train']) > 0:
        logger.info(f"Sample of the training set: {random.choice(raw_datasets['train'])}")
    return raw_datasets


def _add_special_tokens(tokenizer, model, args):
    """Register the special tokens this script relies on for the given tokenizer/model family.

    Llama tokenizers and Mistral models get bos/eos/unk/pad re-registered (there is no default
    pad token for llama). GPT-NeoX tokenizers get a pad token. GPT2 tokenizers paired with an OPT
    model get an unk token. OLMo reuses its eos token as bos, which requires --add_bos.
    """
    # here we add all special tokens again, because the default ones are not in the special_tokens_map
    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
        num_added_tokens = tokenizer.add_special_tokens({
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
        })
        assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
    elif isinstance(tokenizer, GPTNeoXTokenizerFast):
        num_added_tokens = tokenizer.add_special_tokens({
            "pad_token": "<pad>",
        })
        assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token."
    elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
        tokenizer.add_special_tokens({'unk_token': '<unk>'})
    elif type(tokenizer).__name__ == "OLMoTokenizerFast":
        # OLMoTokenizerFast is not imported by this script, so match on the class
        # name; referencing the bare name here raised NameError.
        # only the eos for olmo, but we use it as bos
        tokenizer.bos_token = tokenizer.eos_token
        if not args.add_bos:
            raise ValueError("For OLMo, you must add bos token to the beginning of the input sequence.")
    elif isinstance(model, MistralForCausalLM):
        num_added_tokens = tokenizer.add_special_tokens({
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
        })
        assert num_added_tokens in [0, 1], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."


def _log_sample_generation(accelerator, model, tokenizer, train_dataset):
    """Generate a continuation for the prompt of one random training example and log it.

    The prompt is the leading span of ``input_ids`` whose labels are -100; the
    encode_* functions only ever mask a prefix.
    """
    logger.info("Generating test...")
    example = random.choice(train_dataset)
    prompt_len = int((example["labels"] == -100).sum())
    # Keep only the prompt tokens. Padding out the response would leave pad tokens
    # (possibly with a None pad id) between the prompt and the generated text.
    input_ids = example["input_ids"][:prompt_len].unsqueeze(0).to(accelerator.device)
    attention_mask = example["attention_mask"][:prompt_len].unsqueeze(0).to(accelerator.device)
    # The prepared model may be wrapped (e.g. DistributedDataParallel has no generate()),
    # so sample from the underlying model with dropout disabled, then restore train mode.
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.eval()
    try:
        with torch.no_grad():
            output = unwrapped_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                max_new_tokens=128,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                num_return_sequences=1,
            )
    finally:
        unwrapped_model.train()
    logger.info(f"Generated test:\nInput: {tokenizer.decode(input_ids[0], skip_special_tokens=True)}\nOutput: {tokenizer.decode(output[0], skip_special_tokens=True)}")


def main():
    """Fine-tune a causal LM (optionally with LoRA/QLoRA) using Accelerate; see parse_args for options."""
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_log_kwargs = {}
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir
    # if you get timeouts (e.g. due to long tokenization) increase this.
    timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        **accelerator_log_kwargs,
        kwargs_handlers=[timeout_kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    raw_datasets = _prepare_raw_datasets(args)

    # Load pretrained model and tokenizer
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name, trust_remote_code=args.trust_remote_code)
    elif args.model_name_or_path or args.peft_model_name_or_path:
        config = None
    else:
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported by this script."
        )

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=args.trust_remote_code, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=args.trust_remote_code, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        if args.use_qlora:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            device_index = accelerator.local_process_index
            device_map = {"": device_index}  # force data-parallel training.
            # 4-bit loading is configured only through quantization_config: recent transformers
            # versions raise ValueError when load_in_4bit is also passed as a kwarg.
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                quantization_config=bnb_config,
                device_map=device_map,
                trust_remote_code=args.trust_remote_code,
                torch_dtype=torch.bfloat16,
                use_flash_attention_2=bool(args.use_flash_attn),
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name_or_path,
                from_tf=bool(".ckpt" in args.model_name_or_path),
                config=config,
                trust_remote_code=args.trust_remote_code,
                low_cpu_mem_usage=args.low_cpu_mem_usage,
                use_flash_attention_2=bool(args.use_flash_attn),
            )
    elif args.peft_model_name_or_path:
        model = AutoPeftModelForCausalLM.from_pretrained(
            args.peft_model_name_or_path,
            trust_remote_code=args.trust_remote_code,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
            use_flash_attention_2=bool(args.use_flash_attn),
            is_trainable=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else None,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    _add_special_tokens(tokenizer, model, args)

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    # gather deepspeed to get "real" embedding size
    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_size = embeddings.weight.shape[0]
    # Compare against the gathered size; outside the context manager a partitioned
    # weight may not report its real shape.
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # --gradient_checkpointing was previously only honoured through the QLoRA path below.
    if args.gradient_checkpointing and not (args.use_lora and args.use_qlora):
        model.gradient_checkpointing_enable()
        if args.use_lora or args.peft_model_name_or_path:
            # Base weights are frozen for adapters, so inputs must require grad for
            # gradients to flow through checkpointed segments.
            model.enable_input_require_grads()

    if args.use_lora:
        if args.use_qlora:
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
        logger.info("Initializing LORA model...")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Preprocessing the datasets.
    if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names:
        encode_function = partial(
            encode_with_prompt_completion_format,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            add_bos=args.add_bos,
        )
    elif "messages" in raw_datasets["train"].column_names:
        encode_function = partial(
            encode_with_messages_format,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            add_bos=args.add_bos,
        )
    else:
        raise ValueError("You need to have either 'prompt'&'completion' or 'messages' in your column names.")

    with accelerator.main_process_first():
        lm_datasets = raw_datasets.map(
            encode_function,
            batched=False,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            remove_columns=[name for name in raw_datasets["train"].column_names if name not in ["input_ids", "labels", "attention_mask"]],
            desc="Tokenizing and reformatting instruction data",
        )
        lm_datasets.set_format(type="pt")
        # drop examples whose every label is masked (nothing to learn from)
        lm_datasets = lm_datasets.filter(lambda example: (example['labels'] != -100).any())

    train_dataset = lm_datasets["train"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
        batch_size=args.per_device_train_batch_size
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    if args.use_qlora:
        from bitsandbytes.optim import AdamW
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            optim_bits=8 if args.use_8bit_optimizer else 32,
            is_paged=True
        )
    else:
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Create the learning rate scheduler.
    # Note: the current accelerator.step() calls the .step() of the real scheduler for the `num_processes` times. This is because they assume
    # the user initialize the scheduler with the entire training set. In the case of data parallel training, each process only
    # sees a subset (1/num_processes) of the training set. So each time the process needs to update the lr multiple times so that the total
    # number of updates in the end matches the num_training_steps here.
    # Here we need to set the num_training_steps to either using the entire training set (when epochs is specified) or we need to multiply the
    # num_training_steps by num_processes so that the total number of updates matches the num_training_steps.
    num_training_steps_for_scheduler = args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_training_steps=num_training_steps_for_scheduler,
        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
    )

    # Prepare everything with `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        # copy so that converting the enum below does not mutate args
        experiment_config = dict(vars(args))
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("open_instruct", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        checkpoint_path = args.resume_from_checkpoint
        # normpath drops a trailing separator so basename yields e.g. "step_100" instead of "".
        path = os.path.basename(os.path.normpath(checkpoint_path))
        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        # NOTE(review): accelerator.load_state expects a directory written by accelerator.save_state,
        # while this script writes its checkpoints with save_with_accelerate (model weights only).
        # Confirm resuming works with the checkpoints this script produces.
        accelerator.load_state(checkpoint_path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = (
                int(training_difference.replace("step_", ""))
                * args.gradient_accumulation_steps
            )
            starting_epoch = resume_step // len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
            resume_step -= starting_epoch * len(train_dataloader)

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        total_loss = 0
        if (
            args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch, use_cache=False)
                if args.reduce_loss == 'mean':
                    loss = outputs.loss
                else:
                    # reduce loss is sum
                    # this ensures that we weight all tokens in the dataset equally,
                    # rather than weighting each overall example equally when
                    # using high amounts of gradient accumulation.
                    # this can result in > 5 point improvements in AlpacaEval
                    # see https://github.com/huggingface/transformers/issues/24725 for
                    # more discussion and details.
                    logits = outputs.logits
                    labels = batch["labels"]
                    # Shift so that tokens < n predict n
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    # Flatten the tokens. Use the logits' own vocab dimension: the embedding
                    # size captured before resize_token_embeddings may be stale.
                    loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')
                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                    shift_labels = shift_labels.view(-1)
                    # Enable model parallelism
                    shift_labels = shift_labels.to(shift_logits.device)
                    loss = loss_fct(shift_logits, shift_labels)
                # We keep track of the loss at each logged step
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                # clip gradient norm. don't do this with deepspeed
                if accelerator.sync_gradients and args.clip_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1
                if args.logging_steps and completed_steps % args.logging_steps == 0:
                    avg_loss = accelerator.gather(total_loss).mean().item() / args.gradient_accumulation_steps / args.logging_steps
                    logger.info(f"  Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
                    if args.with_tracking:
                        accelerator.log(
                            {
                                "learning_rate": lr_scheduler.get_last_lr()[0],
                                "train_loss": avg_loss,
                            },
                            step=completed_steps,
                        )
                    total_loss = 0

                if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    save_with_accelerate(accelerator, model, tokenizer, output_dir, args)

                if completed_steps >= args.max_train_steps:
                    break
                get_accelerator().empty_cache()

            # periodically test output with a random example from the training set
            if step % 200 == 0 and step != 0:
                _log_sample_generation(accelerator, model, tokenizer, train_dataset)

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            save_with_accelerate(accelerator, model, tokenizer, output_dir, args)

    if args.with_tracking:
        accelerator.end_training()

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
        # every process must call this: the state dict is gathered across processes
        save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args)
if __name__ == "__main__":
    # Script entry point: main() parses the CLI arguments and runs fine-tuning.
    main()