fix generation when input is shorter than max length
benob committed Sep 26, 2021
1 parent 2bf6c1c commit bcc0643
Showing 2 changed files with 64 additions and 48 deletions.
22 changes: 12 additions & 10 deletions README.md
@@ -26,6 +26,8 @@ In training, batches containing fewer than 256 tokens are simulated by drawing
a length uniformly at random and replacing all tokens and labels after that point with
padding (called Cut-drop).
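
A minimal standalone sketch of the Cut-drop idea described above (compare the drop_end helper removed later in this diff); the batch shapes and the padding id of 0 are illustrative assumptions, not values taken from the repository:

```python
import random
import torch

def cut_drop(rate, x, y, pad_id=0):
    # With probability `rate`, truncate each sequence at a uniformly drawn
    # length and overwrite the tail of both tokens and labels with padding.
    for i in range(len(x)):
        if random.random() < rate:
            length = random.randint(1, x.size(1))
            x[i, length:] = pad_id
            y[i, length:] = 0

# Toy usage: a batch of 2 sequences, 8 positions, 2 label channels each.
x = torch.randint(5, 100, (2, 8))
y = torch.randint(0, 4, (2, 8, 2))
cut_drop(0.5, x, y)
```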

Changelog:
* Fix generation when input is shorter than max length

Installation
------------
@@ -50,21 +52,21 @@ python recasepunc.py predict checkpoint/path.iteration < input.txt > output.txt
Models
------

* French: [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.1/fr-txt.large.19000) trained on 160M tokens from Common Crawl
* French: [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.2/fr-txt.large.19000) trained on 160M tokens from Common Crawl
* Iterations: 19000
* Batch size: 16
* Max length: 256
* Seed: 871253
* Cut-drop probability: 0.1
* Train loss: 0.007630254164338112
* Valid loss: 0.016180261224508285
* Recasing accuracy: 96.63
* Punctuation accuracy: 94.96
* All punctuation F-score: 68.04
* Comma F-score: 67.87
* Period F-score: 73.83
* Question F-score: 58.82
* Exclamation mark F-score: 15.38
* Train loss: 0.021128975618630648
* Valid loss: 0.015684964135289192
* Recasing accuracy: 96.73
* Punctuation accuracy: 95.02
* All punctuation F-score: 67.79
* Comma F-score: 67.94
* Period F-score: 72.91
* Question F-score: 57.57
* Exclamation mark F-score: 15.78
* Training data: First 100M words from [Common Crawl](http://data.statmt.org/cc-100/fr.txt.xz)


90 changes: 52 additions & 38 deletions recasepunc.py
@@ -48,6 +48,12 @@
'OTHER': 3,
}

tokenizer = AutoTokenizer.from_pretrained(flavor, do_lower_case=True)
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.bos_token_id
cls_token = tokenizer.bos_token
sep_token_id = tokenizer.sep_token_id
sep_token = tokenizer.sep_token

class Model(nn.Module):
def __init__(self, flavor=flavor, device=device):
@@ -60,20 +66,13 @@ def __init__(self, flavor=flavor, device=device):

def forward(self, x):
output = self.bert(x)
representations = self.dropout(F.gelu(output['last_hidden_state']))
punc = self.punc(representations)
case = self.case(representations)
return punc, case


# randomly drop the end of sequences
def drop_end(rate, x, y):
for i, dropped in enumerate(torch.rand((len(x),)) > rate):
if dropped:
length = random.randint(1, len(x[i]))
x[i, length:] = 0
y[i, length:] = 0

# randomly create sequences that align to punctuation boundaries
def drop_at_boundaries(rate, x, y):
#num_dropped = 0
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
@@ -83,15 +82,27 @@ def drop_at_boundaries(rate, x, y):
if len(indices) < 2:
continue
start = indices[0] + 1
end = indices[random.randint(1, len(indices) - 1)]
end = indices[random.randint(1, len(indices) - 1)] + 1
length = end - start
if length + 2 > len(x[i]):
continue
#print(y[i,:,0])
#print(indices)
#print(x[i].tolist())
#print(y[i,:,0].tolist())
#print(y[i,:,1].tolist())
#print(indices.tolist())
#print(start.item(), end.item(), length.item())
x[i, 0: length] = x[i, start: end].clone()
x[i, length:] = 0
y[i, 0: length] = y[i, start: end].clone()
y[i, length:] = 0
x[i, 0] = cls_token_id
x[i, 1: length + 1] = x[i, start: end].clone()
x[i, length + 1] = sep_token_id
x[i, length + 2:] = pad_token_id
y[i, 0] = 0
y[i, 1: length + 1] = y[i, start: end].clone()
y[i, length + 1:] = 0
#print(x[i].tolist())
#print(y[i,:,0].tolist())
#print(y[i,:,1].tolist())
#print()
#num_dropped += 1
#print(num_dropped / len(x))

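For illustration only, a toy reconstruction of the re-packing that the new drop_at_boundaries code performs on a single batch row; the token ids, span positions, and special-token ids below are placeholders (the real code takes them from the tokenizer), not values from the repository:

```python
import torch

CLS, SEP, PAD = 101, 102, 0  # placeholder special-token ids

# A padded row of token ids and a span [start, end) chosen so that it starts
# just after one punctuation boundary and ends just after another.
x = torch.tensor([CLS, 11, 12, 13, 14, 15, 16, SEP])
start, end = 3, 6
length = end - start

# Re-pack the span as [CLS] span [SEP] [PAD] ..., as in the code above.
row = x.clone()
row[0] = CLS
row[1:length + 1] = x[start:end]
row[length + 1] = SEP
row[length + 2:] = PAD
print(row.tolist())  # [101, 13, 14, 15, 102, 0, 0, 0]
```
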
@@ -150,7 +161,6 @@ def fit(model, checkpoint_path, train_loader, valid_loader, iterations, valid_pe
for x, y in tqdm(train_loader):
x = x.long().to(device)
y = y.long().to(device)
#drop_end(0.1, x, y)
drop_at_boundaries(0.1, x, y)
y1 = y[:,:,0]
y2 = y[:,:,1]
@@ -243,12 +253,12 @@ def __init__(self, checkpoint_path, flavor=flavor, device=device):
self.rev_punc = {b: a for a, b in punctuation.items()}

def tokenize(self, text):
return self.tokenizer.tokenize(text.lower())
return [cls_token] + tokenizer.tokenize(text.lower()) + [sep_token]

def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=device):
if type(tokens) == str:
tokens = self.tokenize(tokens)
last_label = punctuation['PERIOD']
previous_label = punctuation['PERIOD']
for start in range(0, len(tokens), max_length):
instance = tokens[start: start + max_length]
if type(getter(instance[0])) == str:
@@ -261,14 +271,16 @@ def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=devi
y_scores1, y_scores2 = self.model(x)
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for i, token, punc_label, case_label in zip(range(len(instance)), instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if last_label != None and last_label > 1:
for i, id, token, punc_label, case_label in zip(range(len(instance)), ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if id == cls_token_id or id == sep_token_id:
continue
if previous_label != None and previous_label > 1:
if case_label in [0, 3]: # LOWER, OTHER
case_label = case['CAPITALIZE']
if i + start == len(tokens) - 1 and punc_label == punctuation['O']:
if i + start == len(tokens) - 2 and punc_label == punctuation['O']:
punc_label = punctuation['PERIOD']
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
last_label = punc_label
previous_label = punc_label

def map_case_label(self, token, case_label):
if token.endswith('</w>'):
@@ -292,9 +304,9 @@ def generate_predictions(checkpoint_path, debug=False):
rev_punc = {b: a for a, b in punctuation.items()}

for line in sys.stdin:
tokens = tokenizer.tokenize(line.lower())
tokens = [cls_token] + tokenizer.tokenize(line.lower()) + [sep_token]
was_word = False
last_label = punctuation['PERIOD']
previous_label = punctuation['PERIOD']
for start in range(0, len(tokens), max_length):
instance = tokens[start: start + max_length]
ids = tokenizer.convert_tokens_to_ids(instance)
@@ -305,13 +317,15 @@
y_scores1, y_scores2 = model(x)
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for token, punc_label, case_label in zip(instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if debug:
print(token, punc_label, case_label, file=sys.stderr)
if last_label != None and last_label > 1:
print(id, token, punc_label, case_label, file=sys.stderr)
if id == cls_token_id or id == sep_token_id:
continue
if previous_label != None and previous_label > 1:
if case_label in [0, 3]: # LOWER, OTHER
case_label = case['CAPITALIZE']
last_label = punc_label
previous_label = punc_label
if token.endswith('</w>'):
cased_token = recase(token[:-4], case_label)
if was_word:
Expand All @@ -324,7 +338,7 @@ def generate_predictions(checkpoint_path, debug=False):
print(' ', end='')
print(cased_token, end='')
was_word = False
if last_label == 0:
if previous_label == 0:
print('.', end='')
print()

@@ -390,20 +404,20 @@ def preprocess_text():
if line.strip() != '':
for sentence in splitsents([normalize(line)]):
tokens = tokenize(sentence)
last_token = None
previous_token = None
for token in tokens:
if token in punctuation:
if last_token != None:
print(last_token, punctuation[token], sep='\t')
last_token = None
if previous_token != None:
print(previous_token, punctuation[token], sep='\t')
previous_token = None
elif not re.search('[\p{Ll}\p{Lu}\d]', token): # remove non-alphanumeric tokens
continue
else:
if last_token != None:
print(last_token, 'O', sep='\t')
last_token = token
if last_token != None:
print(last_token, 'PERIOD', sep='\t')
if previous_token != None:
print(previous_token, 'O', sep='\t')
previous_token = token
if previous_token != None:
print(previous_token, 'PERIOD', sep='\t')
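
To make the output format concrete: a hypothetical run of the loop above on the sentence "bonjour , comment ça va ?" would emit one kept token per line, tab-separated from the label of the punctuation mark that follows it ('O' when nothing follows). The COMMA and QUESTION strings assume the punctuation mapping defined earlier in the file sends ',' and '?' to those labels, which this excerpt does not show:

```
bonjour	COMMA
comment	O
ça	O
va	QUESTION
```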

if __name__ == '__main__':
command = sys.argv[1]
