From 2bf6c1cec5f4693e8a0ec4445ab2da3563f56105 Mon Sep 17 00:00:00 2001
From: Benoit Favre
Date: Fri, 24 Sep 2021 16:15:22 +0200
Subject: [PATCH] add predictor class

---
 recasepunc.py | 109 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 92 insertions(+), 17 deletions(-)

diff --git a/recasepunc.py b/recasepunc.py
index c0c1b03..0c9e592 100644
--- a/recasepunc.py
+++ b/recasepunc.py
@@ -29,6 +29,8 @@
 random.seed(seed)
 np.random.seed(seed)
 
+# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
+
 punctuation = {
     'O': 0,
     'COMMA': 1,
@@ -37,7 +39,7 @@
     'EXCLAMATION': 4,
 }
 
-punctuation_syms = ['', ',', '.', '?', '!']
+punctuation_syms = ['', ',', '.', ' ?', ' !']
 
 case = {
     'LOWER': 0,
@@ -48,7 +50,7 @@
 
 
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, flavor=flavor, device=device):
         super().__init__()
         self.bert = AutoModel.from_pretrained(flavor)
         self.punc = nn.Linear(self.bert.dim, 5)
@@ -72,6 +74,27 @@ def drop_end(rate, x, y):
             x[i, length:] = 0
             y[i, length:] = 0
 
+def drop_at_boundaries(rate, x, y):
+    #num_dropped = 0
+    for i, dropped in enumerate(torch.rand((len(x),)) < rate):
+        if dropped:
+            # select all indices that are sentence endings
+            indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
+            if len(indices) < 2:
+                continue
+            start = indices[0] + 1
+            end = indices[random.randint(1, len(indices) - 1)]
+            length = end - start
+            #print(y[i,:,0])
+            #print(indices)
+            #print(start.item(), end.item(), length.item())
+            x[i, 0: length] = x[i, start: end].clone()
+            x[i, length:] = 0
+            y[i, 0: length] = y[i, start: end].clone()
+            y[i, length:] = 0
+            #num_dropped += 1
+    #print(num_dropped / len(x))
+
 
 def compute_performance(model, loader):
     criterion = nn.CrossEntropyLoss()
@@ -127,7 +150,8 @@ def fit(model, checkpoint_path, train_loader, valid_loader, iterations, valid_pe
     for x, y in tqdm(train_loader):
         x = x.long().to(device)
         y = y.long().to(device)
-        drop_end(0.1, x, y)
+        #drop_end(0.1, x, y)
+        drop_at_boundaries(0.1, x, y)
         y1 = y[:,:,0]
         y2 = y[:,:,1]
         optimizer.zero_grad()
@@ -207,6 +231,56 @@ def recase(token, label):
     return token
 
 
+class CasePuncPredictor:
+    def __init__(self, checkpoint_path, flavor=flavor, device=device):
+        self.model = Model(flavor)
+        loaded = torch.load(checkpoint_path, map_location=device)
+        self.model.load_state_dict(loaded['model_state_dict'])
+
+        self.tokenizer = AutoTokenizer.from_pretrained(flavor, do_lower_case=True)
+
+        self.rev_case = {b: a for a, b in case.items()}
+        self.rev_punc = {b: a for a, b in punctuation.items()}
+
+    def tokenize(self, text):
+        return self.tokenizer.tokenize(text.lower())
+
+    def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=device):
+        if type(tokens) == str:
+            tokens = self.tokenize(tokens)
+        last_label = punctuation['PERIOD']
+        for start in range(0, len(tokens), max_length):
+            instance = tokens[start: start + max_length]
+            if type(getter(instance[0])) == str:
+                ids = self.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
+            else:
+                ids = [getter(token) for token in instance]
+            if len(ids) < max_length:
+                ids += [0] * (max_length - len(ids))
+            x = torch.tensor([ids]).long().to(device)
+            y_scores1, y_scores2 = self.model(x)
+            y_pred1 = torch.max(y_scores1, 2)[1]
+            y_pred2 = torch.max(y_scores2, 2)[1]
+            for i, token, punc_label, case_label in zip(range(len(instance)), instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
+                if last_label != None and last_label > 1:
+                    if case_label in [0, 3]: # LOWER, OTHER
+                        case_label = case['CAPITALIZE']
+                if i + start == len(tokens) - 1 and punc_label == punctuation['O']:
+                    punc_label = punctuation['PERIOD']
+                yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
+                last_label = punc_label
+
+    def map_case_label(self, token, case_label):
+        if token.endswith('</w>'):
+            token = token[:-4]
+        return recase(token, case[case_label])
+
+    def map_punc_label(self, token, punc_label):
+        if token.endswith('</w>'):
+            token = token[:-4]
+        return token + punctuation_syms[punctuation[punc_label]]
+
+
 def generate_predictions(checkpoint_path, debug=False):
     model = Model()
     loaded = torch.load(checkpoint_path, map_location=device)
@@ -331,17 +405,18 @@ def preprocess_text():
     if last_token != None:
         print(last_token, 'PERIOD', sep='\t')
 
-command = sys.argv[1]
-if command == 'train':
-    train(*sys.argv[2:])
-elif command == 'eval':
-    run_eval(*sys.argv[2:])
-elif command == 'predict':
-    generate_predictions(*sys.argv[2:])
-elif command == 'tensorize':
-    make_tensors(*sys.argv[2:])
-elif command == 'preprocess':
-    preprocess_text()
-else:
-    print('usage: %s train|eval|predict|tensorize|preprocess' % sys.argv[0])
-    sys.exit(1)
+if __name__ == '__main__':
+    command = sys.argv[1]
+    if command == 'train':
+        train(*sys.argv[2:])
+    elif command == 'eval':
+        run_eval(*sys.argv[2:])
+    elif command == 'predict':
+        generate_predictions(*sys.argv[2:])
+    elif command == 'tensorize':
+        make_tensors(*sys.argv[2:])
+    elif command == 'preprocess':
+        preprocess_text()
+    else:
+        print('usage: %s train|eval|predict|tensorize|preprocess' % sys.argv[0])
+        sys.exit(1)
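
Usage sketch (untested, not part of the patch): the new CasePuncPredictor class can be driven roughly as follows, assuming a trained checkpoint saved at 'checkpoint.pt' (hypothetical path) and the module-level flavor/device defaults from recasepunc.py:

    # Minimal sketch: recase and repunctuate a raw lowercase string.
    from recasepunc import CasePuncPredictor

    predictor = CasePuncPredictor('checkpoint.pt')  # hypothetical checkpoint path

    tokens = predictor.tokenize('bonjour comment allez vous aujourd hui')
    output = ''
    for token, case_label, punc_label in predictor.predict(tokens):
        # map_case_label strips the BPE end-of-word marker '</w>' and recases
        # the subword; map_punc_label then appends the predicted punctuation symbol
        word = predictor.map_punc_label(predictor.map_case_label(token, case_label), punc_label)
        # only subwords that end a word are followed by a space
        output += word + (' ' if token.endswith('</w>') else '')
    print(output.strip())

Since predict() forces a PERIOD on the final token and capitalizes after sentence-final punctuation, the reassembled string comes out as complete, sentence-cased text.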