Skip to content

Commit

Permalink
add predictor class
Browse files Browse the repository at this point in the history
  • Loading branch information
benob committed Sep 24, 2021
1 parent 16fa430 commit 2bf6c1c
Showing 1 changed file with 92 additions and 17 deletions.
109 changes: 92 additions & 17 deletions recasepunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
random.seed(seed)
np.random.seed(seed)

# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!

punctuation = {
'O': 0,
'COMMA': 1,
Expand All @@ -37,7 +39,7 @@
'EXCLAMATION': 4,
}

punctuation_syms = ['', ',', '.', '?', '!']
punctuation_syms = ['', ',', '.', ' ?', ' !']

case = {
'LOWER': 0,
Expand All @@ -48,7 +50,7 @@


class Model(nn.Module):
def __init__(self):
def __init__(self, flavor=flavor, device=device):
super().__init__()
self.bert = AutoModel.from_pretrained(flavor)
self.punc = nn.Linear(self.bert.dim, 5)
Expand All @@ -72,6 +74,27 @@ def drop_end(rate, x, y):
x[i, length:] = 0
y[i, length:] = 0

def drop_at_boundaries(rate, x, y):
#num_dropped = 0
for i, dropped in enumerate(torch.rand((len(x),)) < rate):
if dropped:
# select all indices that are sentence endings
indices = (y[i,:,0] > 1).nonzero(as_tuple=True)[0]
if len(indices) < 2:
continue
start = indices[0] + 1
end = indices[random.randint(1, len(indices) - 1)]
length = end - start
#print(y[i,:,0])
#print(indices)
#print(start.item(), end.item(), length.item())
x[i, 0: length] = x[i, start: end].clone()
x[i, length:] = 0
y[i, 0: length] = y[i, start: end].clone()
y[i, length:] = 0
#num_dropped += 1
#print(num_dropped / len(x))


def compute_performance(model, loader):
criterion = nn.CrossEntropyLoss()
Expand Down Expand Up @@ -127,7 +150,8 @@ def fit(model, checkpoint_path, train_loader, valid_loader, iterations, valid_pe
for x, y in tqdm(train_loader):
x = x.long().to(device)
y = y.long().to(device)
drop_end(0.1, x, y)
#drop_end(0.1, x, y)
drop_at_boundaries(0.1, x, y)
y1 = y[:,:,0]
y2 = y[:,:,1]
optimizer.zero_grad()
Expand Down Expand Up @@ -207,6 +231,56 @@ def recase(token, label):
return token


class CasePuncPredictor:
def __init__(self, checkpoint_path, flavor=flavor, device=device):
self.model = Model(flavor)
loaded = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(loaded['model_state_dict'])

self.tokenizer = AutoTokenizer.from_pretrained(flavor, do_lower_case=True)

self.rev_case = {b: a for a, b in case.items()}
self.rev_punc = {b: a for a, b in punctuation.items()}

def tokenize(self, text):
return self.tokenizer.tokenize(text.lower())

def predict(self, tokens, getter=lambda x: x, max_length=max_length, device=device):
if type(tokens) == str:
tokens = self.tokenize(tokens)
last_label = punctuation['PERIOD']
for start in range(0, len(tokens), max_length):
instance = tokens[start: start + max_length]
if type(getter(instance[0])) == str:
ids = self.tokenizer.convert_tokens_to_ids(getter(token) for token in instance)
else:
ids = [getter(token) for token in instance]
if len(ids) < max_length:
ids += [0] * (max_length - len(ids))
x = torch.tensor([ids]).long().to(device)
y_scores1, y_scores2 = self.model(x)
y_pred1 = torch.max(y_scores1, 2)[1]
y_pred2 = torch.max(y_scores2, 2)[1]
for i, token, punc_label, case_label in zip(range(len(instance)), instance, y_pred1[0].tolist()[:len(instance)], y_pred2[0].tolist()[:len(instance)]):
if last_label != None and last_label > 1:
if case_label in [0, 3]: # LOWER, OTHER
case_label = case['CAPITALIZE']
if i + start == len(tokens) - 1 and punc_label == punctuation['O']:
punc_label = punctuation['PERIOD']
yield (token, self.rev_case[case_label], self.rev_punc[punc_label])
last_label = punc_label

def map_case_label(self, token, case_label):
if token.endswith('</w>'):
token = token[:-4]
return recase(token, case[case_label])

def map_punc_label(self, token, punc_label):
if token.endswith('</w>'):
token = token[:-4]
return token + punctuation_syms[punctuation[punc_label]]


def generate_predictions(checkpoint_path, debug=False):
model = Model()
loaded = torch.load(checkpoint_path, map_location=device)
Expand Down Expand Up @@ -331,17 +405,18 @@ def preprocess_text():
if last_token != None:
print(last_token, 'PERIOD', sep='\t')

command = sys.argv[1]
if command == 'train':
train(*sys.argv[2:])
elif command == 'eval':
run_eval(*sys.argv[2:])
elif command == 'predict':
generate_predictions(*sys.argv[2:])
elif command == 'tensorize':
make_tensors(*sys.argv[2:])
elif command == 'preprocess':
preprocess_text()
else:
print('usage: %s train|eval|predict|tensorize|preprocess' % sys.argv[0])
sys.exit(1)
if __name__ == '__main__':
command = sys.argv[1]
if command == 'train':
train(*sys.argv[2:])
elif command == 'eval':
run_eval(*sys.argv[2:])
elif command == 'predict':
generate_predictions(*sys.argv[2:])
elif command == 'tensorize':
make_tensors(*sys.argv[2:])
elif command == 'preprocess':
preprocess_text()
else:
print('usage: %s train|eval|predict|tensorize|preprocess' % sys.argv[0])
sys.exit(1)

0 comments on commit 2bf6c1c

Please sign in to comment.