From bcc0643d9759e161dc08297e4ff04462ae6c0b41 Mon Sep 17 00:00:00 2001
From: Benoit Favre
Date: Sun, 26 Sep 2021 13:04:28 +0200
Subject: [PATCH] fix generation when input is shorter than max length

---
 README.md     | 22 +++++++------
 recasepunc.py | 90 +++++++++++++++++++++++++++++----------------------
 2 files changed, 64 insertions(+), 48 deletions(-)

diff --git a/README.md b/README.md
index 3d1fe75..18252e2 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,8 @@ In training, batches containing less than 256 tokens are simulated by drawing
 uniformly a length and replacing all tokens and labels after that point with
 padding (called Cut-drop).
 
+Changelog:
+* Fix generation when input is shorter than max length
 
 Installation
 ------------
@@ -50,21 +52,21 @@ python recasepunc.py predict checkpoint/path.iteration < input.txt > output.txt
 
 Models
 ------
-* French: [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.1/fr-txt.large.19000) trained on 160M tokens from Common Crawl
+* French: [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.2/fr-txt.large.19000) trained on 160M tokens from Common Crawl
   * Iterations: 19000
   * Batch size: 16
   * Max length: 256
   * Seed: 871253
   * Cut-drop probability: 0.1
-  * Train loss: 0.007630254164338112
-  * Valid loss: 0.016180261224508285
-  * Recasing accuracy: 96.63
-  * Punctuation accuracy: 94.96
-  * All punctuation F-score: 68.04
-  * Comma F-score: 67.87
-  * Period F-score: 73.83
-  * Question F-score: 58.82
-  * Exclamation mark F-score: 15.38
+  * Train loss: 0.021128975618630648
+  * Valid loss: 0.015684964135289192
+  * Recasing accuracy: 96.73
+  * Punctuation accuracy: 95.02
+  * All punctuation F-score: 67.79
+  * Comma F-score: 67.94
+  * Period F-score: 72.91
+  * Question F-score: 57.57
+  * Exclamation mark F-score: 15.78
   * Training data: First 100M words from [Common Crawl](http://data.statmt.org/cc-100/fr.txt.xz)
 
diff --git a/recasepunc.py b/recasepunc.py
index 0c9e592..081b2c6 100644
--- a/recasepunc.py
+++ b/recasepunc.py
@@ -48,6 +48,12 @@
     'OTHER': 3,
 }
 
+tokenizer = AutoTokenizer.from_pretrained(flavor, do_lower_case=True)
+pad_token_id = tokenizer.pad_token_id
+cls_token_id = tokenizer.bos_token_id
+cls_token = tokenizer.bos_token
+sep_token_id = tokenizer.sep_token_id
+sep_token = tokenizer.sep_token
 
 class Model(nn.Module):
     def __init__(self, flavor=flavor, device=device):
@@ -60,20 +66,13 @@ def __init__(self, flavor=flavor, device=device):
 
     def forward(self, x):
         output = self.bert(x)
         representations = self.dropout(F.gelu(output['last_hidden_state']))
         punc = self.punc(representations)
         case = self.case(representations)
         return punc, case
 
-# randomly drop the end of sequences
-def drop_end(rate, x, y):
-    for i, dropped in enumerate(torch.rand((len(x),)) > rate):
-        if dropped:
-            length = random.randint(1, len(x[i]))
-            x[i, length:] = 0
-            y[i, length:] = 0
-
+# randomly create sequences that align to punctuation boundaries
 def drop_at_boundaries(rate, x, y):
     #num_dropped = 0
     for i, dropped in enumerate(torch.rand((len(x),)) < rate):
@@ -83,15 +82,27 @@ def drop_at_boundaries(rate, x, y):
             if len(indices) < 2:
                 continue
             start = indices[0] + 1
-            end = indices[random.randint(1, len(indices) - 1)]
+            end = indices[random.randint(1, len(indices) - 1)] + 1
             length = end - start
+            if length + 2 > len(x[i]):
+                continue
             #print(y[i,:,0])
-            #print(indices)
+            #print(x[i].tolist())
+            #print(y[i,:,0].tolist())
+            #print(y[i,:,1].tolist())
+            #print(indices.tolist())
             #print(start.item(), end.item(), length.item())
-            x[i, 0: length] = x[i, start: end].clone()
-            x[i, length:] = 0
-            y[i, 0: length] = y[i, start: end].clone()
-            y[i, length:] = 0
+            x[i, 0] = cls_token_id
+            x[i, 1: length + 1] = x[i, start: end].clone()
+            x[i, length + 1] = sep_token_id
+            x[i, length + 2:] = pad_token_id
+            y[i, 0] = 0
+            y[i, 1: length + 1] = y[i, start: end].clone()
+            y[i, length + 1:] = 0
+            #print(x[i].tolist())
+            #print(y[i,:,0].tolist())
+            #print(y[i,:,1].tolist())
+            #print()
             #num_dropped += 1
     #print(num_dropped / len(x))
@@ -150,7 +161,6 @@ def fit(model, checkpoint_path, train_loader, valid_loader, iterations, valid_pe
     for x, y in tqdm(train_loader):
         x = x.long().to(device)
         y = y.long().to(device)
-        #drop_end(0.1, x, y)
         drop_at_boundaries(0.1, x, y)
         y1 = y[:,:,0]
         y2 = y[:,:,1]
@@ -243,12 +253,12 @@ def __init__(self, checkpoint_path, flavor=flavor, device=device):
         self.rev_punc = {b: a for a, b in punctuation.items()}
 
     def tokenize(self, text):
-        return self.tokenizer.tokenize(text.lower())
+        return [cls_token] + tokenizer.tokenize(text.lower()) + [sep_token]
 
     def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=device):
         if type(tokens) == str:
             tokens = self.tokenize(tokens)
-        last_label = punctuation['PERIOD']
+        previous_label = punctuation['PERIOD']
         for start in range(0, len(tokens), max_length):
             instance = tokens[start: start + max_length]
             if type(getter(instance[0])) == str:
@@ -261,14 +271,16 @@ def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=devi
             y_scores1, y_scores2 = self.model(x)
             y_pred1 = torch.max(y_scores1, 2)[1]
             y_pred2 = torch.max(y_scores2, 2)[1]
-            for i, token, punc_label, case_label in zip(range(len(instance)), instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
-                if last_label != None and last_label > 1:
+            for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
+                if id == cls_token_id or id == sep_token_id:
+                    continue
+                if previous_label != None and previous_label > 1:
                     if case_label in [0, 3]: # LOWER, OTHER
                         case_label = case['CAPITALIZE']
-                if i + start == len(tokens) - 1 and punc_label == punctuation['O']:
+                if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
                     punc_label = punctuation['PERIOD']
                 yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
-                last_label = punc_label
+                previous_label = punc_label
 
     def map_case_label(self, token, case_label):
         if token.endswith('</w>'):
@@ -292,9 +304,9 @@ def generate_predictions(checkpoint_path, debug=False):
     rev_punc = {b: a for a, b in punctuation.items()}
 
     for line in sys.stdin:
-        tokens = tokenizer.tokenize(line.lower())
+        tokens = [cls_token] + tokenizer.tokenize(line.lower()) + [sep_token]
         was_word = False
-        last_label = punctuation['PERIOD']
+        previous_label = punctuation['PERIOD']
         for start in range(0, len(tokens), max_length):
             instance = tokens[start: start + max_length]
             ids = tokenizer.convert_tokens_to_ids(instance)
@@ -305,13 +317,15 @@ def generate_predictions(checkpoint_path, debug=False):
             y_scores1, y_scores2 = model(x)
             y_pred1 = torch.max(y_scores1, 2)[1]
             y_pred2 = torch.max(y_scores2, 2)[1]
-            for token, punc_label, case_label in zip(instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
+            for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
                 if debug:
-                    print(token, punc_label, case_label, file=sys.stderr)
+                    print(id, token, punc_label, case_label, file=sys.stderr)
+                if id == cls_token_id or id == sep_token_id:
+                    continue
-                if last_label != None and last_label > 1:
+                if previous_label != None and previous_label > 1:
                     if case_label in [0, 3]: # LOWER, OTHER
                         case_label = case['CAPITALIZE']
-                last_label = punc_label
+                previous_label = punc_label
                 if token.endswith('</w>'):
                     cased_token = recase(token[:-4], case_label)
                     if was_word:
@@ -324,7 +338,7 @@ def generate_predictions(checkpoint_path, debug=False):
                         print(' ', end='')
                     print(cased_token, end='')
                     was_word = False
-        if last_label == 0:
+        if previous_label == 0:
             print('.', end='')
         print()
@@ -390,20 +404,20 @@ def preprocess_text():
         if line.strip() != '':
             for sentence in splitsents([normalize(line)]):
                 tokens = tokenize(sentence)
-                last_token = None
+                previous_token = None
                 for token in tokens:
                     if token in punctuation:
-                        if last_token != None:
-                            print(last_token, punctuation[token], sep='\t')
-                        last_token = None
+                        if previous_token != None:
+                            print(previous_token, punctuation[token], sep='\t')
+                        previous_token = None
                     elif not re.search('[\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
                         continue
                     else:
-                        if last_token != None:
-                            print(last_token, 'O', sep='\t')
-                        last_token = token
-                if last_token != None:
-                    print(last_token, 'PERIOD', sep='\t')
+                        if previous_token != None:
+                            print(previous_token, 'O', sep='\t')
+                        previous_token = token
+                if previous_token != None:
+                    print(previous_token, 'PERIOD', sep='\t')
 
 if __name__ == '__main__':
     command = sys.argv[1]
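
Note (illustration only, not part of the patch): the core of the fix is that a cropped training window is re-wrapped with the tokenizer's CLS/SEP ids and padded to max length, and that prediction then skips those special tokens. Below is a minimal standalone sketch of the wrapping step, assuming the FlauBERT flavor 'flaubert/flaubert_base_uncased' and a hypothetical helper name wrap_segment; it only mirrors what drop_at_boundaries and the new tokenize()/predict() do.

# Sketch only: re-wrap a window of token ids with CLS/SEP and pad the rest,
# mirroring drop_at_boundaries (training) and tokenize()/predict() (inference).
import torch
from transformers import AutoTokenizer

flavor = 'flaubert/flaubert_base_uncased'  # assumed value of `flavor`
tokenizer = AutoTokenizer.from_pretrained(flavor, do_lower_case=True)

def wrap_segment(ids, max_length=256):
    # leave room for the two special tokens, like the new `length + 2 > len(x[i])` check
    ids = list(ids)[:max_length - 2]
    # bos_token_id stands in for CLS, exactly as in the patch
    wrapped = [tokenizer.bos_token_id] + ids + [tokenizer.sep_token_id]
    wrapped += [tokenizer.pad_token_id] * (max_length - len(wrapped))
    return torch.tensor(wrapped)

x = wrap_segment(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('bonjour le monde')))
print(x.shape)  # torch.Size([256])

At prediction time the added ids are filtered back out by the `id == cls_token_id or id == sep_token_id` checks, so the special tokens never surface in the generated output.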