From 11b8af112274d7df1e54515f4a25b8701d30c4dd Mon Sep 17 00:00:00 2001 From: Yuxiang Date: Tue, 26 Dec 2023 15:53:39 +0100 Subject: [PATCH 1/4] add code for apple silicon --- .gitignore | 1 + eval.py | 8 +- generate.py | 240 +++++++++++++++++++++++++++----------- generate_texts.py | 187 ++++++++++++++++++++---------- requirements.txt | 2 +- train.json | 1 - train.py | 290 +++++++++++++++++++++++++++++++--------------- train_single.py | 8 +- 8 files changed, 506 insertions(+), 231 deletions(-) delete mode 100644 train.json diff --git a/.gitignore b/.gitignore index 37078580..d706e6cb 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ data/ .idea/modules.xml .idea/vcs.xml .idea +tensorboard_summary diff --git a/eval.py b/eval.py index 14124ab6..bd11fba1 100644 --- a/eval.py +++ b/eval.py @@ -70,7 +70,13 @@ def main(): n_ctx = model_config.n_ctx full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) full_tokenizer.max_len = n_ctx - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = 'cpu' + if torch.cuda.is_available() : + device = 'cuda' + elif torch.backends.mps.is_available(): + mps_device = torch.device("mps") + x = torch.ones(1, device=mps_device) + device = "mps" print('using device:', device) raw_data_path = args.raw_data_path diff --git a/generate.py b/generate.py index fb1649a6..5bf538d6 100644 --- a/generate.py +++ b/generate.py @@ -1,14 +1,15 @@ +import argparse +import os + import torch import torch.nn.functional as F -import os -import argparse from tqdm import trange from transformers import GPT2LMHeadModel def is_word(word): for item in list(word): - if item not in 'qwertyuiopasdfghjklzxcvbnm': + if item not in "qwertyuiopasdfghjklzxcvbnm": return False return True @@ -24,29 +25,33 @@ def _is_chinese_char(char): # space-separated words, so they are not treated specially and handled # like the all of the other languages. cp = ord(char) - if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # return True return False -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): - """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (vocabulary size) - top_k > 0: keep only top k tokens with highest probability (top-k filtering). - top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_k > 0: keep only top k tokens with highest probability (top-k filtering). 
+ top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ - assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear + assert ( + logits.dim() == 1 + ) # batch size 1 for now - could be updated for more but the code would be less clear top_k = min(top_k, logits.size(-1)) # Safety check if top_k > 0: # Remove all tokens with a probability less than the last token of the top-k @@ -68,28 +73,49 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') return logits -def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0, - device='cpu'): +def sample_sequence( + model, + context, + length, + n_ctx, + tokenizer, + temperature=1.0, + top_k=30, + top_p=0.0, + repitition_penalty=1.0, + device="cpu", +): + if torch.backends.mps.is_available(): + device = "mps" context = torch.tensor(context, dtype=torch.long, device=device) context = context.unsqueeze(0) generated = context with torch.no_grad(): for _ in trange(length): - inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)} + inputs = {"input_ids": generated[0][-(n_ctx - 1) :].unsqueeze(0)} outputs = model( - **inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) + **inputs + ) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) next_token_logits = outputs[0][0, -1, :] for id in set(generated): next_token_logits[id] /= repitition_penalty next_token_logits = next_token_logits / temperature - next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + next_token_logits[tokenizer.convert_tokens_to_ids("[UNK]")] = -float("Inf") + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) return generated.tolist()[0] -def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'): +def fast_sample_sequence( + model, context, length, temperature=1.0, top_k=30, top_p=0.0, device="cpu" +): + if torch.backends.mps.is_available(): + device = "mps" inputs = torch.LongTensor(context).view(1, -1).to(device) if len(context) > 1: _, past = model(inputs[:, :-1], None)[:2] @@ -104,46 +130,107 @@ def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_ output, past = output[:2] output = output[-1].squeeze(0) / temperature filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1) + next_token = torch.multinomial( + torch.softmax(filtered_logits, dim=-1), num_samples=1 + ) generate.append(next_token.item()) prev = next_token.view(1, 1) return generate # 通过命令行参数--fast_pattern,指定模式 -def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu', - is_fast_pattern=False): +def generate( + n_ctx, + model, + context, + length, + tokenizer, + temperature=1, 
+ top_k=0, + top_p=0.0, + repitition_penalty=1.0, + device="cpu", + is_fast_pattern=False, +): + if torch.backends.mps.is_available(): + device = "mps" if is_fast_pattern: - return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p, - device=device) + return fast_sample_sequence( + model, + context, + length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + device=device, + ) else: - return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature, top_k=top_k, top_p=top_p, - repitition_penalty=repitition_penalty, device=device) + return sample_sequence( + model, + context, + length, + n_ctx, + tokenizer=tokenizer, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repitition_penalty=repitition_penalty, + device=device, + ) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='生成设备') - parser.add_argument('--length', default=-1, type=int, required=False, help='生成长度') - parser.add_argument('--batch_size', default=1, type=int, required=False, help='生成的batch size') - parser.add_argument('--nsamples', default=10, type=int, required=False, help='生成几个样本') - parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度') - parser.add_argument('--topk', default=8, type=int, required=False, help='最高几选一') - parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率') - parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, - help='模型参数') - parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='词表路径') - parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='模型路径') - parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='生成文章的开头') - parser.add_argument('--no_wordpiece', action='store_true', help='不做word piece切词') - parser.add_argument('--segment', action='store_true', help='中文以词为单位') - parser.add_argument('--fast_pattern', action='store_true', help='采用更加快的方式生成文本') - parser.add_argument('--save_samples', action='store_true', help='保存产生的样本') - parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="保存样本的路径") - parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False) + parser.add_argument( + "--device", default="0,1,2,3", type=str, required=False, help="生成设备" + ) + parser.add_argument("--length", default=-1, type=int, required=False, help="生成长度") + parser.add_argument( + "--batch_size", default=1, type=int, required=False, help="生成的batch size" + ) + parser.add_argument( + "--nsamples", default=10, type=int, required=False, help="生成几个样本" + ) + parser.add_argument( + "--temperature", default=1, type=float, required=False, help="生成温度" + ) + parser.add_argument("--topk", default=8, type=int, required=False, help="最高几选一") + parser.add_argument("--topp", default=0, type=float, required=False, help="最高积累概率") + parser.add_argument( + "--model_config", + default="config/model_config_small.json", + type=str, + required=False, + help="模型参数", + ) + parser.add_argument( + "--tokenizer_path", + default="cache/vocab_small.txt", + type=str, + required=False, + help="词表路径", + ) + parser.add_argument( + "--model_path", + default="model/final_model", + type=str, + required=False, + help="模型路径", + ) + parser.add_argument( + "--prefix", default="萧炎", type=str, required=False, 
help="生成文章的开头" + ) + parser.add_argument("--no_wordpiece", action="store_true", help="不做word piece切词") + parser.add_argument("--segment", action="store_true", help="中文以词为单位") + parser.add_argument("--fast_pattern", action="store_true", help="采用更加快的方式生成文本") + parser.add_argument("--save_samples", action="store_true", help="保存产生的样本") + parser.add_argument( + "--save_samples_path", default=".", type=str, required=False, help="保存样本的路径" + ) + parser.add_argument("--repetition_penalty", default=1.0, type=float, required=False) args = parser.parse_args() - print('args:\n' + args.__repr__()) + print("args:\n" + args.__repr__()) if args.segment: from tokenizations import tokenization_bert_word_level as tokenization_bert @@ -159,7 +246,13 @@ def main(): topp = args.topp repetition_penalty = args.repetition_penalty - device = "cuda" if torch.cuda.is_available() else "cpu" + device = 'cpu' + if torch.cuda.is_available() : + device = 'cuda' + elif torch.backends.mps.is_available(): + mps_device = torch.device("mps") + x = torch.ones(1, device=mps_device) + device = "mps" tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) model = GPT2LMHeadModel.from_pretrained(args.model_path) @@ -173,7 +266,9 @@ def main(): if args.save_samples: if not os.path.exists(args.save_samples_path): os.makedirs(args.save_samples_path) - samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8') + samples_file = open( + args.save_samples_path + "/samples.txt", "w", encoding="utf8" + ) while True: raw_text = args.prefix context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text)) @@ -184,32 +279,37 @@ def main(): model=model, context=context_tokens, length=length, - is_fast_pattern=args.fast_pattern, tokenizer=tokenizer, - temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device + is_fast_pattern=args.fast_pattern, + tokenizer=tokenizer, + temperature=temperature, + top_k=topk, + top_p=topp, + repitition_penalty=repetition_penalty, + device=device, ) for i in range(batch_size): generated += 1 text = tokenizer.convert_ids_to_tokens(out) for i, item in enumerate(text[:-1]): # 确保英文前后有空格 if is_word(item) and is_word(text[i + 1]): - text[i] = item + ' ' + text[i] = item + " " for i, item in enumerate(text): - if item == '[MASK]': - text[i] = '' - elif item == '[CLS]': - text[i] = '\n\n' - elif item == '[SEP]': - text[i] = '\n' + if item == "[MASK]": + text[i] = "" + elif item == "[CLS]": + text[i] = "\n\n" + elif item == "[SEP]": + text[i] = "\n" info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n" print(info) - text = ''.join(text).replace('##', '').strip() + text = "".join(text).replace("##", "").strip() print(text) if args.save_samples: samples_file.write(info) samples_file.write(text) - samples_file.write('\n') - samples_file.write('=' * 90) - samples_file.write('\n' * 2) + samples_file.write("\n") + samples_file.write("=" * 90) + samples_file.write("\n" * 2) print("=" * 80) if generated == nsamples: # close file when finish writing. 
@@ -218,5 +318,5 @@ def main(): break -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/generate_texts.py b/generate_texts.py index 1554350a..30a794b6 100644 --- a/generate_texts.py +++ b/generate_texts.py @@ -1,7 +1,8 @@ +import argparse +import os + import torch import torch.nn.functional as F -import os -import argparse from tqdm import trange from transformers import GPT2LMHeadModel @@ -10,7 +11,7 @@ def is_word(word): for item in list(word): - if item not in 'qwertyuiopasdfghjklzxcvbnm': + if item not in "qwertyuiopasdfghjklzxcvbnm": return False return True @@ -26,29 +27,33 @@ def _is_chinese_char(char): # space-separated words, so they are not treated specially and handled # like the all of the other languages. cp = ord(char) - if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # return True return False -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): - """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (vocabulary size) - top_k > 0: keep only top k tokens with highest probability (top-k filtering). - top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_k > 0: keep only top k tokens with highest probability (top-k filtering). + top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ - assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear + assert ( + logits.dim() == 1 + ) # batch size 1 for now - could be updated for more but the code would be less clear top_k = min(top_k, logits.size(-1)) # Safety check if top_k > 0: # Remove all tokens with a probability less than the last token of the top-k @@ -70,49 +75,102 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') return logits -def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0, - device='cpu'): +def sample_sequence( + model, + context, + length, + n_ctx, + tokenizer, + temperature=1.0, + top_k=30, + top_p=0.0, + repitition_penalty=1.0, + device="cpu", +): + if torch.backends.mps.is_available(): + device = "mps" context = torch.tensor(context, dtype=torch.long, device=device) context = context.unsqueeze(0) generated = context with torch.no_grad(): for _ in trange(length): - inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)} + inputs = {"input_ids": generated[0][-(n_ctx - 1) :].unsqueeze(0)} outputs = model( - **inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) + **inputs + ) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) next_token_logits = outputs[0][0, -1, :] for id in set(generated): next_token_logits[id] /= repitition_penalty next_token_logits = next_token_logits / temperature - next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + next_token_logits[tokenizer.convert_tokens_to_ids("[UNK]")] = -float("Inf") + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) return generated def main(): parser = argparse.ArgumentParser() - parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡') - parser.add_argument('--length', default=-1, type=int, required=False, help='生成长度') - parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度,越高越随机') - parser.add_argument('--topk', default=8, type=int, required=False, help='生成的时候最高几选一') - parser.add_argument('--topp', default=0, type=float, required=False, help='生成的时候积累概率最高多少') - parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, - help='模型参数路径') - parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='词表路径') - parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='模型路径') - parser.add_argument('--save_path', default='generated/', type=str, required=False, help='存放生成的文件的路径') - parser.add_argument('--articles_per_title', default=5, type=int, required=False, help='每个标题生成多少篇文章') - parser.add_argument('--titles', default='萧炎', type=str, required=False, help='标题列表,是一个字符串,用空格分开') - parser.add_argument('--titles_file', default='', type=str, required=False, - help='标题列表文件,文件中每行一个标题。如果这个选项有值则titles无效') - parser.add_argument('--no_wordpiece', 
action='store_true', help='不做word piece切词') - parser.add_argument('--segment', action='store_true', help='中文以词为单位') - parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False) + parser.add_argument( + "--device", default="0,1,2,3", type=str, required=False, help="设置使用哪些显卡" + ) + parser.add_argument("--length", default=-1, type=int, required=False, help="生成长度") + parser.add_argument( + "--temperature", default=1, type=float, required=False, help="生成温度,越高越随机" + ) + parser.add_argument( + "--topk", default=8, type=int, required=False, help="生成的时候最高几选一" + ) + parser.add_argument( + "--topp", default=0, type=float, required=False, help="生成的时候积累概率最高多少" + ) + parser.add_argument( + "--model_config", + default="config/model_config_small.json", + type=str, + required=False, + help="模型参数路径", + ) + parser.add_argument( + "--tokenizer_path", + default="cache/vocab_small.txt", + type=str, + required=False, + help="词表路径", + ) + parser.add_argument( + "--model_path", + default="model/final_model", + type=str, + required=False, + help="模型路径", + ) + parser.add_argument( + "--save_path", default="generated/", type=str, required=False, help="存放生成的文件的路径" + ) + parser.add_argument( + "--articles_per_title", default=5, type=int, required=False, help="每个标题生成多少篇文章" + ) + parser.add_argument( + "--titles", default="萧炎", type=str, required=False, help="标题列表,是一个字符串,用空格分开" + ) + parser.add_argument( + "--titles_file", + default="", + type=str, + required=False, + help="标题列表文件,文件中每行一个标题。如果这个选项有值则titles无效", + ) + parser.add_argument("--no_wordpiece", action="store_true", help="不做word piece切词") + parser.add_argument("--segment", action="store_true", help="中文以词为单位") + parser.add_argument("--repetition_penalty", default=1.0, type=float, required=False) args = parser.parse_args() - print('args:\n' + args.__repr__()) + print("args:\n" + args.__repr__()) if args.segment: from tokenizations import tokenization_bert_word_level as tokenization_bert @@ -128,8 +186,8 @@ def main(): titles = args.titles.split() # 列表,里面每个元素是一个生成的标题 if args.titles_file: - with open(args.titles_file, 'r') as f: - titles = [line.strip('\n') for line in f.readlines()] + with open(args.titles_file, "r") as f: + titles = [line.strip("\n") for line in f.readlines()] articles_per_title = args.articles_per_title # 这里定义一个标题生成多少篇文章 save_path = args.save_path # 设置存到哪 @@ -149,15 +207,22 @@ def main(): for i, title in enumerate(titles): for j in range(articles_per_title): - with open(save_path + str(i) + '-' + str(j) + '.txt', 'w') as f: - context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(title)) + with open(save_path + str(i) + "-" + str(j) + ".txt", "w") as f: + context_tokens = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(title) + ) generated = 0 out = sample_sequence( n_ctx=n_ctx, - model=model, length=length, - context=context_tokens, tokenizer=tokenizer, - temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, - device=device + model=model, + length=length, + context=context_tokens, + tokenizer=tokenizer, + temperature=temperature, + top_k=topk, + top_p=topp, + repitition_penalty=repetition_penalty, + device=device, ) out = out.tolist()[0] @@ -166,21 +231,21 @@ def main(): for i, item in enumerate(text[:-1]): # 确保英文前后有空格 if is_word(item) and is_word(text[i + 1]): - text[i] = item + ' ' + text[i] = item + " " for i, item in enumerate(text): - if item == '[MASK]': - text[i] = '' - if item == '[CLS]' or item == '[SEP]': - text[i] = '\n' + if item == "[MASK]": + text[i] = "" 
+ if item == "[CLS]" or item == "[SEP]": + text[i] = "\n" print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) - text = ''.join(text).replace('##', '').strip() + text = "".join(text).replace("##", "").strip() # text = ''.join(text.split('\n')[:-1]) print(text) - f.write(text + '\n') + f.write(text + "\n") print("=" * 80) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index 9cdd29f6..5e9989c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ transformers==2.1.1 torch numpy tqdm -sklearn +scikit-learn keras tb-nightly future diff --git a/train.json b/train.json deleted file mode 100644 index e2fd75b0..00000000 --- a/train.json +++ /dev/null @@ -1 +0,0 @@ -["第一篇文章的正文", "第二篇文章的正文", "第三篇文章的正文"] \ No newline at end of file diff --git a/train.py b/train.py index f08bebd7..aa50c5c1 100644 --- a/train.py +++ b/train.py @@ -1,75 +1,145 @@ -import transformers -import torch -import os +import argparse import json +import os import random +from datetime import datetime + import numpy as np -import argparse +import torch +import transformers +from tokenizations.bpe_tokenizer import get_encoder +from torch.nn import DataParallel +from torch.utils import bottleneck from torch.utils.tensorboard import SummaryWriter -from datetime import datetime from tqdm import tqdm -from torch.nn import DataParallel -from tokenizations.bpe_tokenizer import get_encoder def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length): - with open(data_path, 'r', encoding='utf8') as f: - print('reading lines') + with open(data_path, "r", encoding="utf8") as f: + print("reading lines") lines = json.load(f) - lines = [line.replace('\n', ' [SEP] ') for line in lines] # 用[SEP]表示换行, 段落之间使用SEP表示段落结束 + lines = [ + line.replace("\n", " [SEP] ") for line in lines + ] # 用[SEP]表示换行, 段落之间使用SEP表示段落结束 all_len = len(lines) if not os.path.exists(tokenized_data_path): os.mkdir(tokenized_data_path) for i in tqdm(range(num_pieces)): - sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)] + sublines = lines[all_len // num_pieces * i : all_len // num_pieces * (i + 1)] if i == num_pieces - 1: - sublines.extend(lines[all_len // num_pieces * (i + 1):]) # 把尾部例子添加到最后一个piece - sublines = [full_tokenizer.tokenize(line) for line in sublines if - len(line) > min_length] # 只考虑长度超过min_length的句子 + sublines.extend( + lines[all_len // num_pieces * (i + 1) :] + ) # 把尾部例子添加到最后一个piece + sublines = [ + full_tokenizer.tokenize(line) for line in sublines if len(line) > min_length + ] # 只考虑长度超过min_length的句子 sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines] full_line = [] for subline in sublines: - full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]')) # 文章开头添加MASK表示文章开始 + full_line.append( + full_tokenizer.convert_tokens_to_ids("[MASK]") + ) # 文章开头添加MASK表示文章开始 full_line.extend(subline) - full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]')) # 文章之间添加CLS表示文章结束 - with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f: + full_line.append( + full_tokenizer.convert_tokens_to_ids("[CLS]") + ) # 文章之间添加CLS表示文章结束 + with open(tokenized_data_path + "tokenized_train_{}.txt".format(i), "w") as f: for id in full_line: - f.write(str(id) + ' ') - print('finish') + f.write(str(id) + " ") + print("finish") def main(): parser = argparse.ArgumentParser() - parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡') - 
parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, - help='选择模型参数') - parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库') - parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='原始训练语料') - parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False, - help='tokenized语料存放位置') - parser.add_argument('--raw', action='store_true', help='是否先做tokenize') - parser.add_argument('--epochs', default=5, type=int, required=False, help='训练循环') - parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size') - parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率') - parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数') - parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss,设置为gradient accumulation的整数倍') - parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长') - parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') - parser.add_argument('--fp16', action='store_true', help='混合精度') - parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False) - parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) - parser.add_argument('--num_pieces', default=100, type=int, required=False, help='将训练语料分成多少份') - parser.add_argument('--min_length', default=128, type=int, required=False, help='最短收录文章长度') - parser.add_argument('--output_dir', default='model/', type=str, required=False, help='模型输出路径') - parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型训练起点路径') - parser.add_argument('--writer_dir', default='tensorboard_summary/', type=str, required=False, help='Tensorboard路径') - parser.add_argument('--segment', action='store_true', help='中文以词为单位') - parser.add_argument('--bpe_token', action='store_true', help='subword') - parser.add_argument('--encoder_json', default="tokenizations/encoder.json", type=str, help="encoder.json") - parser.add_argument('--vocab_bpe', default="tokenizations/vocab.bpe", type=str, help="vocab.bpe") + parser.add_argument( + "--device", default="0,1,2,3", type=str, required=False, help="设置使用哪些显卡" + ) + parser.add_argument( + "--model_config", + default="config/model_config_small.json", + type=str, + required=False, + help="选择模型参数", + ) + parser.add_argument( + "--tokenizer_path", + default="cache/vocab_small.txt", + type=str, + required=False, + help="选择词库", + ) + parser.add_argument( + "--raw_data_path", + default="data/train.json", + type=str, + required=False, + help="原始训练语料", + ) + parser.add_argument( + "--tokenized_data_path", + default="data/tokenized/", + type=str, + required=False, + help="tokenized语料存放位置", + ) + parser.add_argument("--raw", action="store_true", help="是否先做tokenize") + parser.add_argument("--epochs", default=5, type=int, required=False, help="训练循环") + parser.add_argument( + "--batch_size", default=8, type=int, required=False, help="训练batch size" + ) + parser.add_argument("--lr", default=1.5e-4, type=float, required=False, help="学习率") + parser.add_argument( + "--warmup_steps", default=2000, type=int, required=False, help="warm up步数" + ) + parser.add_argument( + "--log_step", + default=1, + type=int, + required=False, + help="多少步汇报一次loss,设置为gradient accumulation的整数倍", + ) + parser.add_argument( 
+ "--stride", default=768, type=int, required=False, help="训练时取训练数据的窗口步长" + ) + parser.add_argument( + "--gradient_accumulation", default=1, type=int, required=False, help="梯度积累" + ) + parser.add_argument("--fp16", action="store_true", help="混合精度") + parser.add_argument("--fp16_opt_level", default="O1", type=str, required=False) + parser.add_argument("--max_grad_norm", default=1.0, type=float, required=False) + parser.add_argument( + "--num_pieces", default=100, type=int, required=False, help="将训练语料分成多少份" + ) + parser.add_argument( + "--min_length", default=128, type=int, required=False, help="最短收录文章长度" + ) + parser.add_argument( + "--output_dir", default="model/", type=str, required=False, help="模型输出路径" + ) + parser.add_argument( + "--pretrained_model", default="", type=str, required=False, help="模型训练起点路径" + ) + parser.add_argument( + "--writer_dir", + default="tensorboard_summary/", + type=str, + required=False, + help="Tensorboard路径", + ) + parser.add_argument("--segment", action="store_true", help="中文以词为单位") + parser.add_argument("--bpe_token", action="store_true", help="subword") + parser.add_argument( + "--encoder_json", + default="tokenizations/encoder.json", + type=str, + help="encoder.json", + ) + parser.add_argument( + "--vocab_bpe", default="tokenizations/vocab.bpe", type=str, help="vocab.bpe" + ) args = parser.parse_args() - print('args:\n' + args.__repr__()) + print("args:\n" + args.__repr__()) if args.segment: from tokenizations import tokenization_bert_word_level as tokenization_bert @@ -78,8 +148,10 @@ def main(): os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 - model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config) - print('config:\n' + model_config.to_json_string()) + model_config = transformers.modeling_gpt2.GPT2Config.from_json_file( + args.model_config + ) + print("config:\n" + model_config.to_json_string()) n_ctx = model_config.n_ctx if args.bpe_token: @@ -87,8 +159,14 @@ def main(): else: full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) full_tokenizer.max_len = 999999 - device = 'cuda' if torch.cuda.is_available() else 'cpu' - print('using device:', device) + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + mps_device = torch.device("mps") + x = torch.ones(1, device=mps_device) + device = "mps" + print("using device:", device) raw_data_path = args.raw_data_path tokenized_data_path = args.tokenized_data_path @@ -113,15 +191,22 @@ def main(): os.mkdir(output_dir) if raw: - print('building files') - build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces, - full_tokenizer=full_tokenizer, min_length=min_length) - print('files built') + print("building files") + build_files( + data_path=raw_data_path, + tokenized_data_path=tokenized_data_path, + num_pieces=num_pieces, + full_tokenizer=full_tokenizer, + min_length=min_length, + ) + print("files built") if not args.pretrained_model: model = transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config) else: - model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model) + model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained( + args.pretrained_model + ) model.train() model.to(device) @@ -129,58 +214,62 @@ def main(): parameters = model.parameters() for parameter in parameters: num_parameters += parameter.numel() - print('number of parameters: {}'.format(num_parameters)) + print("number of parameters: 
{}".format(num_parameters)) multi_gpu = False full_len = 0 - print('calculating total steps') + print("calculating total steps") for i in tqdm(range(num_pieces)): - with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f: + with open(tokenized_data_path + "tokenized_train_{}.txt".format(i), "r") as f: full_len += len([int(item) for item in f.read().strip().split()]) total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation) - print('total steps = {}'.format(total_steps)) + print("total steps = {}".format(total_steps)) optimizer = transformers.AdamW(model.parameters(), lr=lr, correct_bias=True) - scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, - t_total=total_steps) + scheduler = transformers.WarmupLinearSchedule( + optimizer, warmup_steps=warmup_steps, t_total=total_steps + ) if fp16: try: from apex import amp except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + raise ImportError( + "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." + ) model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") - model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')]) + model = DataParallel(model, device_ids=[int(i) for i in args.device.split(",")]) multi_gpu = True - print('starting training') + print("starting training") overall_step = 0 running_loss = 0 for epoch in range(epochs): - print('epoch {}'.format(epoch + 1)) + print("epoch {}".format(epoch + 1)) now = datetime.now() - print('time: {}'.format(now)) + print("time: {}".format(now)) x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32) random.shuffle(x) piece_num = 0 for i in x: - with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f: + with open( + tokenized_data_path + "tokenized_train_{}.txt".format(i), "r" + ) as f: line = f.read().strip() tokens = line.split() tokens = [int(token) for token in tokens] start_point = 0 samples = [] while start_point < len(tokens) - n_ctx: - samples.append(tokens[start_point: start_point + n_ctx]) + samples.append(tokens[start_point : start_point + n_ctx]) start_point += stride if start_point < len(tokens): - samples.append(tokens[len(tokens)-n_ctx:]) + samples.append(tokens[len(tokens) - n_ctx :]) random.shuffle(samples) for step in range(len(samples) // batch_size): # drop last - # prepare data - batch = samples[step * batch_size: (step + 1) * batch_size] + batch = samples[step * batch_size : (step + 1) * batch_size] batch_inputs = [] for ids in batch: int_ids = [int(x) for x in ids] @@ -201,7 +290,9 @@ def main(): if fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm) + torch.nn.utils.clip_grad_norm_( + amp.master_params(optimizer), max_grad_norm + ) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) @@ -213,39 +304,46 @@ def main(): optimizer.zero_grad() scheduler.step() if (overall_step + 1) % log_step == 0: - tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step) - print('now time: {}:{}. 
Step {} of piece {} of epoch {}, loss {}'.format(
-                        datetime.now().hour,
-                        datetime.now().minute,
-                        step + 1,
-                        piece_num,
-                        epoch + 1,
-                        running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
+                    tb_writer.add_scalar(
+                        "loss", loss.item() * gradient_accumulation, overall_step
+                    )
+                    print(
+                        "now time: {}:{}. Step {} of piece {} of epoch {}, loss {}".format(
+                            datetime.now().hour,
+                            datetime.now().minute,
+                            step + 1,
+                            piece_num,
+                            epoch + 1,
+                            running_loss
+                            * gradient_accumulation
+                            / (log_step / gradient_accumulation),
+                        )
+                    )
                     running_loss = 0
                 overall_step += 1
             piece_num += 1
 
-        print('saving model for epoch {}'.format(epoch + 1))
-        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
-            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
-        model_to_save = model.module if hasattr(model, 'module') else model
-        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
+        print("saving model for epoch {}".format(epoch + 1))
+        if not os.path.exists(output_dir + "model_epoch{}".format(epoch + 1)):
+            os.mkdir(output_dir + "model_epoch{}".format(epoch + 1))
+        model_to_save = model.module if hasattr(model, "module") else model
+        model_to_save.save_pretrained(output_dir + "model_epoch{}".format(epoch + 1))
         # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
         # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
-        print('epoch {} finished'.format(epoch + 1))
+        print("epoch {} finished".format(epoch + 1))
 
         then = datetime.now()
-        print('time: {}'.format(then))
-        print('time for one epoch: {}'.format(then - now))
-
-    print('training finished')
-    if not os.path.exists(output_dir + 'final_model'):
-        os.mkdir(output_dir + 'final_model')
-    model_to_save = model.module if hasattr(model, 'module') else model
-    model_to_save.save_pretrained(output_dir + 'final_model')
+        print("time: {}".format(then))
+        print("time for one epoch: {}".format(then - now))
+
+    print("training finished")
+    if not os.path.exists(output_dir + "final_model"):
+        os.mkdir(output_dir + "final_model")
+    model_to_save = model.module if hasattr(model, "module") else model
+    model_to_save.save_pretrained(output_dir + "final_model")
     # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
     # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/train_single.py b/train_single.py
index dd06dc3e..37263649 100644
--- a/train_single.py
+++ b/train_single.py
@@ -75,7 +75,13 @@ def main():
     n_ctx = model_config.n_ctx
     full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
     full_tokenizer.max_len = 999999
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = 'cpu'
+    if torch.cuda.is_available() :
+        device = 'cuda'
+    elif torch.backends.mps.is_available():
+        mps_device = torch.device("mps")
+        x = torch.ones(1, device=mps_device)
+        device = "mps"
     print('using device:', device)
 
     raw_data_path = args.raw_data_path

From 3f8e16b06fa5909bb275633ebd8983bbbf209f5f Mon Sep 17 00:00:00 2001
From: Yuxiang
Date: Tue, 2 Jan 2024 01:24:07 +0100
Subject: [PATCH 2/4] periodically checkpoint full training state

---
 train.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index aa50c5c1..059525b5 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 import json
 import os
 import random
+import time
 from datetime import datetime
 
 import numpy as np
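The hunks below wire a wall-clock timer into the training loop: every save_interval seconds (hard-coded to 300 here) the full training state is serialized, including the PyTorch, NumPy and Python RNG streams, so an interrupted run can be reproduced. The same pattern, condensed into a standalone helper for reference (the save_checkpoint name is ours; storing loss.item() rather than the live tensor is also our substitution, to avoid pickling a device-resident tensor):

import random

import numpy as np
import torch


def save_checkpoint(path, model, optimizer, epoch, batch_idx, loss):
    """Capture everything needed to resume training deterministically."""
    torch.save(
        {
            "epoch": epoch,
            "batch_idx": batch_idx,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss.item() if torch.is_tensor(loss) else loss,
            "random_state_pytorch": torch.get_rng_state(),
            "random_state_np": np.random.get_state(),
            "random_state_python": random.getstate(),
        },
        path,
    )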
@@ -167,7 +168,8 @@ def main():
         x = torch.ones(1, device=mps_device)
         device = "mps"
     print("using device:", device)
-
+    start_time = time.time()
+    save_interval = 300
     raw_data_path = args.raw_data_path
     tokenized_data_path = args.tokenized_data_path
     raw = args.raw  # 选择是否从零开始构建数据集
@@ -253,6 +255,21 @@ def main():
         random.shuffle(x)
         piece_num = 0
         for i in x:
+            if time.time() - start_time > save_interval:
+                checkpoint = {
+                    "epoch": epoch,
+                    "batch_idx": i,
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "loss": loss,  # Assuming 'loss' is your loss variable
+                    "random_state_pytorch": torch.get_rng_state(),
+                    "random_state_np": np.random.get_state(),
+                    "random_state_python": random.getstate(),
+                    # Add scheduler state_dict if you're using a scheduler
+                    # 'scheduler_state_dict': scheduler.state_dict(),
+                }
+                torch.save(checkpoint, f"checkpoint_{epoch}_{i}.pth")
+                start_time = time.time()
             with open(
                 tokenized_data_path + "tokenized_train_{}.txt".format(i), "r"
             ) as f:

From 20663ea963bcf180c1e6ccedaa6e520a478fa588 Mon Sep 17 00:00:00 2001
From: Yuxiang
Date: Tue, 2 Jan 2024 02:27:21 +0100
Subject: [PATCH 3/4] move periodic checkpointing into the step loop

---
 train.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/train.py b/train.py
index 059525b5..26c2c529 100644
--- a/train.py
+++ b/train.py
@@ -255,21 +255,6 @@ def main():
         random.shuffle(x)
         piece_num = 0
         for i in x:
-            if time.time() - start_time > save_interval:
-                checkpoint = {
-                    "epoch": epoch,
-                    "batch_idx": i,
-                    "model_state_dict": model.state_dict(),
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "loss": loss,  # Assuming 'loss' is your loss variable
-                    "random_state_pytorch": torch.get_rng_state(),
-                    "random_state_np": np.random.get_state(),
-                    "random_state_python": random.getstate(),
-                    # Add scheduler state_dict if you're using a scheduler
-                    # 'scheduler_state_dict': scheduler.state_dict(),
-                }
-                torch.save(checkpoint, f"checkpoint_{epoch}_{i}.pth")
-                start_time = time.time()
             with open(
                 tokenized_data_path + "tokenized_train_{}.txt".format(i), "r"
             ) as f:
@@ -337,6 +322,21 @@ def main():
                         )
                     )
                     running_loss = 0
+                if time.time() - start_time > save_interval:
+                    checkpoint = {
+                        "epoch": epoch,
+                        "batch_idx": i,
+                        "model_state_dict": model.state_dict(),
+                        "optimizer_state_dict": optimizer.state_dict(),
+                        "loss": loss,  # Assuming 'loss' is your loss variable
+                        "random_state_pytorch": torch.get_rng_state(),
+                        "random_state_np": np.random.get_state(),
+                        "random_state_python": random.getstate(),
+                        # Add scheduler state_dict if you're using a scheduler
+                        # 'scheduler_state_dict': scheduler.state_dict(),
+                    }
+                    torch.save(checkpoint, f"{output_dir}/checkpoint_{epoch}_{i}.pth")
+                    start_time = time.time()
                 overall_step += 1
             piece_num += 1

From cbd071ecb1ccae46154293527094584485d3f7d5 Mon Sep 17 00:00:00 2001
From: Yuxiang
Date: Wed, 3 Jan 2024 17:40:11 +0100
Subject: [PATCH 4/4] support resuming training from a saved checkpoint

---
 .gitignore |  1 +
 train.py   | 32 ++++++++++++++++++++++++++++----
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/.gitignore b/.gitignore
index d706e6cb..5f6b5f5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ data/
 .idea/vcs.xml
 .idea
 tensorboard_summary
+/checkpoint
\ No newline at end of file
diff --git a/train.py b/train.py
index 26c2c529..4f650b4c 100644
--- a/train.py
+++ b/train.py
@@ -139,6 +139,14 @@ def main():
         "--vocab_bpe", default="tokenizations/vocab.bpe", type=str, help="vocab.bpe"
     )
 
+    parser.add_argument(
+        "--input_file",
+        type=str,
+        help="Path to the .pth checkpoint file to 
continue training", + default=None, + ) + # ... rest of the argument setup ... + args = parser.parse_args() print("args:\n" + args.__repr__()) @@ -209,10 +217,8 @@ def main(): model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained( args.pretrained_model ) - model.train() - model.to(device) - num_parameters = 0 + parameters = model.parameters() for parameter in parameters: num_parameters += parameter.numel() @@ -231,6 +237,24 @@ def main(): scheduler = transformers.WarmupLinearSchedule( optimizer, warmup_steps=warmup_steps, t_total=total_steps ) + + if args.input_file is not None: + checkpoint = torch.load(args.input_file) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + epoch = checkpoint["epoch"] + batch_idx = checkpoint["batch_idx"] + loss = checkpoint["loss"] + torch.set_rng_state(checkpoint["random_state_pytorch"]) + np.random.set_state(checkpoint["random_state_np"]) + random.setstate(checkpoint["random_state_python"]) + model.eval() + model.to(device) + print("Loaded checkpoint from %s" % args.input_file) + + model.train() + model.to(device) + if fp16: try: from apex import amp @@ -335,7 +359,7 @@ def main(): # Add scheduler state_dict if you're using a scheduler # 'scheduler_state_dict': scheduler.state_dict(), } - torch.save(checkpoint, f"{output_dir}/checkpoint_{epoch}_{i}.pth") + torch.save(checkpoint, f"{output_dir}/checkpoint.pth") start_time = time.time() overall_step += 1 piece_num += 1