From dbbbcfd05ade2692ca1881ae121f9509c3fd8c2e Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 15 Apr 2021 12:00:55 -0700
Subject: [PATCH] offer support for chinese

---
 README.md                  | 14 +++++++++++++
 dalle_pytorch/tokenizer.py | 43 +++++++++++++++++++++++++++++++++++++-
 generate.py                | 11 +++++++++-
 setup.py                   |  3 ++-
 train_dalle.py             |  6 +++++-
 5 files changed, 73 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 441e4a0d..29bb63de 100644
--- a/README.md
+++ b/README.md
@@ -440,6 +440,20 @@ ex.
 $ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.json
 ```
 
+#### Chinese
+
+You can train with a pretrained Chinese tokenizer offered by Hugging Face 🤗 by simply passing in an extra flag `--chinese`.
+
+ex.
+
+```sh
+$ python train_dalle.py --chinese --image_text_folder ./path/to/data
+```
+
+```sh
+$ python generate.py --chinese --text '追老鼠的猫'
+```
+
 ## Citations
 
 ```bibtex
diff --git a/dalle_pytorch/tokenizer.py b/dalle_pytorch/tokenizer.py
index 7f6bffb6..ef6a5744 100644
--- a/dalle_pytorch/tokenizer.py
+++ b/dalle_pytorch/tokenizer.py
@@ -2,7 +2,9 @@
 # to give users a quick easy start to training DALL-E without doing BPE
 
 import torch
+
 from tokenizers import Tokenizer
+from transformers import BertTokenizer
 
 import html
 import os
@@ -123,6 +125,7 @@ def encode(self, text):
     def decode(self, tokens, remove_start_end = True):
         if torch.is_tensor(tokens):
             tokens = tokens.tolist()
+
         if remove_start_end:
             tokens = [token for token in tokens if token not in (49406, 40407, 0)]
         text = ''.join([self.decoder[token] for token in tokens])
@@ -160,7 +163,10 @@ def __init__(self, bpe_path = None):
         self.vocab_size = tokenizer.get_vocab_size()
 
     def decode(self, tokens):
-        tokens = [token for token in tokens.tolist() if token not in (0,)]
+        if torch.is_tensor(tokens):
+            tokens = tokens.tolist()
+
+        tokens = [token for token in tokens if token not in (0,)]
         return self.tokenizer.decode(tokens, skip_special_tokens = True)
 
     def encode(self, text):
@@ -182,3 +188,38 @@ def tokenize(self, texts, context_length = 256, truncate_text = False):
             result[i, :len(tokens)] = torch.tensor(tokens)
 
         return result
+
+# chinese tokenizer
+
+class ChineseTokenizer:
+    def __init__(self):
+        tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
+        self.tokenizer = tokenizer
+        self.vocab_size = tokenizer.vocab_size
+
+    def decode(self, tokens):
+        if torch.is_tensor(tokens):
+            tokens = tokens.tolist()
+
+        tokens = [token for token in tokens if token not in (0,)]
+        return self.tokenizer.decode(tokens)
+
+    def encode(self, text):
+        return torch.tensor(self.tokenizer.encode(text, add_special_tokens = False))
+
+    def tokenize(self, texts, context_length = 256, truncate_text = False):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        all_tokens = [self.encode(text) for text in texts]
+
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate_text:
+                    tokens = tokens[:context_length]
+                else:
+                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
diff --git a/generate.py b/generate.py
index ab6c725c..cb699dba 100644
--- a/generate.py
+++ b/generate.py
@@ -43,14 +43,23 @@
 parser.add_argument('--bpe_path', type = str,
                     help='path to your huggingface BPE json file')
 
+parser.add_argument('--chinese', dest='chinese', action = 'store_true')
+
 parser.add_argument('--taming', dest='taming', action='store_true')
 
 args = parser.parse_args()
 
+# helper fns
+
+def exists(val):
+    return val is not None
+
 # tokenizer
 
-if args.bpe_path is not None:
+if exists(args.bpe_path):
     tokenizer = HugTokenizer(args.bpe_path)
+elif args.chinese:
+    tokenizer = ChineseTokenizer()
 
 # load DALL-E
diff --git a/setup.py b/setup.py
index b7973fc4..98a06416 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.10.1',
+  version = '0.10.2',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
@@ -27,6 +27,7 @@
     'tokenizers',
     'torch>=1.6',
     'torchvision',
+    'transformers',
     'tqdm'
   ],
   classifiers=[
diff --git a/train_dalle.py b/train_dalle.py
index 04f91163..3b59d811 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -21,7 +21,7 @@
 
 from dalle_pytorch import distributed_utils
 from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
-from dalle_pytorch.tokenizer import tokenizer, HugTokenizer
+from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer
 
 # argument parsing
 
@@ -41,6 +41,8 @@
 parser.add_argument('--truncate_captions', dest='truncate_captions',
                     help='Captions passed in which exceed the max token length will be truncated if this is set.')
 
+parser.add_argument('--chinese', dest='chinese', action = 'store_true')
+
 parser.add_argument('--taming', dest='taming', action='store_true')
 
 parser.add_argument('--bpe_path', type = str,
@@ -89,6 +91,8 @@ def exists(val):
 
 if exists(args.bpe_path):
     tokenizer = HugTokenizer(args.bpe_path)
+elif args.chinese:
+    tokenizer = ChineseTokenizer()
 
 # reconstitute vae
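
For reference, here is a minimal sketch (not part of the patch itself) of how the new `ChineseTokenizer` behaves on its own, assuming `torch` and `transformers` are installed and the pretrained `bert-base-chinese` vocabulary can be fetched:

```python
from dalle_pytorch.tokenizer import ChineseTokenizer

tokenizer = ChineseTokenizer()  # loads the 'bert-base-chinese' vocab on first use

# encode: text -> 1-D LongTensor of token ids; add_special_tokens = False means
# no [CLS]/[SEP] markers are inserted around the caption
ids = tokenizer.encode('追老鼠的猫')

# tokenize: list of captions -> (batch, context_length) tensor, zero-padded on
# the right; captions longer than context_length raise unless truncate_text = True
batch = tokenizer.tokenize(['追老鼠的猫'], context_length = 256)
assert batch.shape == (1, 256)

# decode: filters out the 0 padding id, then maps the remaining ids back to text
print(tokenizer.decode(batch[0]))
```

Filtering token id 0 in `decode` works because 0 is the `[PAD]` id in the `bert-base-chinese` vocabulary, matching the zeros that `tokenize` uses to pad unused positions.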