diff --git a/PCFG.py b/PCFG.py new file mode 100755 index 0000000..596f1f7 --- /dev/null +++ b/PCFG.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import itertools +import random + +class PCFG(nn.Module): + def __init__(self, nt_states, t_states): + super(PCFG, self).__init__() + self.nt_states = nt_states + self.t_states = t_states + self.states = nt_states + t_states + self.huge = 1e9 + + def logadd(self, x, y): + d = torch.max(x,y) + return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d + + def logsumexp(self, x, dim=1): + d = torch.max(x, dim)[0] + if x.dim() == 1: + return torch.log(torch.exp(x - d).sum(dim)) + d + else: + return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d + + def _inside(self, unary_scores, rule_scores, root_scores): + #inside step + #unary scores : b x n x T + #rule scores : b x NT x (NT+T) x (NT+T) + #root : b x NT + + # statistics + batch_size = unary_scores.size(0) + n = unary_scores.size(1) + + # uses conventional python numbering scheme: [s, t] represents span [s, t) + # this scheme facilitates fast computation + # f[s, t] = logsumexp(f[s, :] * f[:, t]) + self.beta = unary_scores.new(batch_size, n + 1, n + 1, self.states).fill_(-self.huge) + + # initialization: f[k, k+1] + for k in range(n): + for state in range(self.t_states): + self.beta[:, k, k+1, self.nt_states + state] = unary_scores[:, k, state] + + # span length w, at least 2 + for w in np.arange(2, n+1): + + # start point s + for s in range(n-w+1): + t = s + w + + f = lambda x:torch.logsumexp(x.view(batch_size, self.nt_states, -1), dim=2) + + if w == 2: + tmp = self.beta[:, s, s+1, self.nt_states:].unsqueeze(2).unsqueeze(1) \ + + self.beta[:, s+1, t, self.nt_states:].unsqueeze(1).unsqueeze(2) \ + + rule_scores[:, :, self.nt_states:, self.nt_states:] + tmp = f(tmp) + elif w == 3: + tmp1 = self.beta[:, s, s+1, self.nt_states:].unsqueeze(2).unsqueeze(1) \ + + self.beta[:, s+1, t, :self.nt_states].unsqueeze(1).unsqueeze(2) \ + + rule_scores[:, :, self.nt_states:, :self.nt_states] + tmp2 = self.beta[:, s, t-1, :self.nt_states].unsqueeze(2).unsqueeze(1) \ + + self.beta[:, t-1, t, self.nt_states:].unsqueeze(1).unsqueeze(2) \ + + rule_scores[:, :, :self.nt_states, self.nt_states:] + tmp = self.logadd(f(tmp1), f(tmp2)) + elif w >= 4: + tmp1 = self.beta[:, s, s+1, self.nt_states:].unsqueeze(2).unsqueeze(1) \ + + self.beta[:, s+1, t, :self.nt_states].unsqueeze(1).unsqueeze(2) \ + + rule_scores[:, :, self.nt_states:, :self.nt_states] + tmp2 = self.beta[:, s, t-1, :self.nt_states].unsqueeze(2).unsqueeze(1) \ + + self.beta[:, t-1, t, self.nt_states:].unsqueeze(1).unsqueeze(2) \ + + rule_scores[:, :, :self.nt_states, self.nt_states:] + tmp3 = self.beta[:, s, s+2:t-1, :self.nt_states].unsqueeze(3).unsqueeze(1) \ + + self.beta[:, s+2:t-1, t, :self.nt_states].unsqueeze(1).unsqueeze(3) \ + + rule_scores[:, :, :self.nt_states, :self.nt_states].unsqueeze(2) + tmp = self.logadd(self.logadd(f(tmp1), f(tmp2)), f(tmp3)) + + self.beta[:, s, t, :self.nt_states] = tmp + log_Z = self.beta[:, 0, n, :self.nt_states] + root_scores + log_Z = self.logsumexp(log_Z, 1) + return log_Z + + def _viterbi(self, unary_scores, rule_scores, root_scores): + #unary scores : b x n x T + #rule scores : b x NT x (NT+T) x (NT+T) + + batch_size = unary_scores.size(0) + n = unary_scores.size(1) + + # dummy rules + rule_scores = torch.cat([rule_scores, \ + rule_scores.new(batch_size, self.t_states, self.states, self.states) \ + 
.fill_(-self.huge)], dim=1) + + self.scores = unary_scores.new(batch_size, n+1, n+1, self.states).fill_(-self.huge) + self.bp = unary_scores.new(batch_size, n+1, n+1, self.states).fill_(-1) + self.left_bp = unary_scores.new(batch_size, n+1, n+1, self.states).fill_(-1) + self.right_bp = unary_scores.new(batch_size, n+1, n+1, self.states).fill_(-1) + self.argmax = unary_scores.new(batch_size, n, n).fill_(-1) + self.argmax_tags = unary_scores.new(batch_size, n).fill_(-1) + self.spans = [[] for _ in range(batch_size)] + + for k in range(n): + for state in range(self.t_states): + self.scores[:, k, k + 1, self.nt_states + state] = unary_scores[:, k, state] + + for w in np.arange(2, n+1): + for s in range(n-w+1): + t = s + w + + tmp = self.scores[:, s, s+1:t, :].unsqueeze(3).unsqueeze(1) \ + + self.scores[:, s+1:t, t, :].unsqueeze(1).unsqueeze(3) \ + + rule_scores.unsqueeze(2) + + # view once and marginalize + tmp, max_pos = torch.max(tmp.view(batch_size, self.states, -1), dim=2) + + # step by step marginalization + # tmp = self.logsumexp(tmp, dim=4) + # tmp = self.logsumexp(tmp, dim=3) + # tmp = self.logsumexp(tmp, dim=2) + + max_idx = max_pos / (self.states * self.states) + s + 1 + left_child = (max_pos % (self.states * self.states)) / self.states + right_child = max_pos % self.states + + self.scores[:, s, t, :self.nt_states] = tmp[:, :self.nt_states] + + self.bp[:, s, t, :self.nt_states] = max_idx[:, :self.nt_states] + self.left_bp[:, s, t, :self.nt_states] = left_child[:, :self.nt_states] + self.right_bp[:, s, t, :self.nt_states] = right_child[:, :self.nt_states] + max_score = self.scores[:, 0, n, :self.nt_states] + root_scores + max_score, max_idx = torch.max(max_score, 1) + for b in range(batch_size): + self._backtrack(b, 0, n, max_idx[b].item()) + return self.scores, self.argmax, self.spans + + def _backtrack(self, b, s, t, state): + u = int(self.bp[b][s][t][state]) + assert(s < t), "s: %d, t %d"%(s, t) + left_state = int(self.left_bp[b][s][t][state]) + right_state = int(self.right_bp[b][s][t][state]) + self.argmax[b][s][t-1] = 1 + if s == t-1: + self.spans[b].insert(0, (s, t-1, state)) + self.argmax_tags[b][s] = state - self.nt_states + return None + else: + self.spans[b].insert(0, (s, t-1, state)) + self._backtrack(b, s, u, left_state) + self._backtrack(b, u, t, right_state) + return None diff --git a/data.py b/data.py new file mode 100755 index 0000000..4e89692 --- /dev/null +++ b/data.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +import numpy as np +import torch +import pickle + +class Dataset(object): + def __init__(self, data_file, load_dep=False): + data = pickle.load(open(data_file, 'rb')) #get text data + self.sents = self._convert(data['source']).long() + self.other_data = data['other_data'] + self.sent_lengths = self._convert(data['source_l']).long() + self.batch_size = self._convert(data['batch_l']).long() + self.batch_idx = self._convert(data['batch_idx']).long() + self.vocab_size = data['vocab_size'][0] + self.num_batches = self.batch_idx.size(0) + self.word2idx = data['word2idx'] + self.idx2word = data['idx2word'] + self.load_dep = load_dep + + def _convert(self, x): + return torch.from_numpy(np.asarray(x)) + + def __len__(self): + return self.num_batches + + def __getitem__(self, idx): + assert(idx < self.num_batches and idx >= 0) + start_idx = self.batch_idx[idx] + end_idx = start_idx + self.batch_size[idx] + length = self.sent_lengths[idx].item() + sents = self.sents[start_idx:end_idx] + other_data = self.other_data[start_idx:end_idx] + sent_str = [d[0] for d in 
other_data] + tags = [d[1] for d in other_data] + actions = [d[2] for d in other_data] + binary_tree = [d[3] for d in other_data] + spans = [d[5] for d in other_data] + if(self.load_dep): + heads = [d[7] for d in other_data] + batch_size = self.batch_size[idx].item() + # original data includes , which we don't need + data_batch = [sents[:, 1:length-1], length-2, batch_size, tags, actions, + spans, binary_tree, other_data] + if(self.load_dep): + data_batch.append(heads) + return data_batch diff --git a/lexicalizedPCFG.py b/lexicalizedPCFG.py new file mode 100644 index 0000000..12148e5 --- /dev/null +++ b/lexicalizedPCFG.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import itertools +import random +from torch.cuda import memory_allocated +import pdb + +class LexicalizedPCFG(nn.Module): + # Lexicalized PCFG: + # S → A[x] A ∈ N, x ∈ 𝚺 + # A[x] → B[x] C[y] A, B, C ∈ N ∪ P, x, y ∈ 𝚺 + # A[x] → B[y] C[x] A, B, C ∈ N ∪ P, x, y ∈ 𝚺 + # T[x] → x T ∈ P, x ∈ 𝚺 + + def __init__(self, nt_states, t_states, nt_emission=False, supervised_signals = []): + super(LexicalizedPCFG, self).__init__() + self.nt_states = nt_states + self.t_states = t_states + self.states = nt_states + t_states + self.nt_emission = nt_emission + self.huge = 1e9 + + if(self.nt_emission): + self.word_span_slice = slice(self.states) + else: + self.word_span_slice = slice(self.nt_states,self.states) + + self.supervised_signals = supervised_signals + + # def logadd(self, x, y): + # d = torch.max(x,y) + # return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d + def logadd(self, x, y): + names = x.names + assert names == y.names, "Two operants' names are not matched {} and {}.".format(names, y.names) + return torch.logsumexp(torch.stack([x.rename(None), y.rename(None)]), dim=0).refine_names(*names) + + def logsumexp(self, x, dim=1): + d = torch.max(x, dim)[0] + if x.dim() == 1: + return torch.log(torch.exp(x - d).sum(dim)) + d + else: + return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d + + def __get_scores(self, unary_scores, rule_scores, root_scores, dir_scores): + # INPUT + # unary scores : b x n x (NT + T) + # rule scores : b x (NT+T) x (NT+T) x (NT+T) + # root_scores : b x NT + # dir_scores : 2 x b x NT x (NT + T) x (NT + T) x N + # OUTPUT + # rule scores: 2 x B x (NT x T) x (NT x T) x (NT x T) x N + # (D, B, T, TL, TR, H) + # root_scores : b x NT x n + # (B, T, H) + assert unary_scores.names == ('B', 'H', 'T') + assert rule_scores.names == ('B', 'T', 'TL', 'TR') + assert root_scores.names == ('B', 'T') + assert dir_scores.names == ('D', 'B', 'T', 'H', 'TL', 'TR') + + rule_shape = ('D', 'B', 'T', 'H', 'TL', 'TR') + root_shape = ('B', 'T', 'H') + rule_scores = rule_scores.align_to(*rule_shape) \ + + dir_scores.align_to(*rule_shape) + + if rule_scores.size('H') == 1: + rule_scores = rule_scores.expand(-1, -1, -1, unary_scores.size('H'), -1, -1) + return rule_scores, root_scores, unary_scores + + def __get_scores(self, unary_scores, rule_scores, root_scores): + # INPUT + # unary scores : b x n x (NT + T) + # rule scores : b x (NT+T) x (NT+T) x (NT+T) + # root_scores : b x NT + # dir_scores : 2 x b x NT x (NT + T) x (NT + T) x N + # OUTPUT + # rule scores: 2 x B x (NT x T) x (NT x T) x (NT x T) x N + # (D, B, T, TL, TR, H) + # root_scores : b x NT x n + # (B, T, H) + assert unary_scores.names == ('B', 'H', 'T') + assert rule_scores.names == ('B', 'T', 'H', 'TL', 'TR', 'D') + assert root_scores.names == ('B', 'T') + + 
rule_shape = ('D', 'B', 'T', 'H', 'TL', 'TR') + root_shape = ('B', 'T', 'H') + rule_scores = rule_scores.align_to(*rule_shape) + + if rule_scores.size('H') == 1: + rule_scores = rule_scores.expand(-1, -1, -1, unary_scores.size('H'), -1, -1) + return rule_scores, root_scores, unary_scores + + def print_name_size(self, x): + print(x.size(), x.names) + + def print_memory_usage(self, lineno, device="cuda:0"): + print("Line {}: {}M".format(lineno, int(memory_allocated(device)/1000000))) + + def cross_bracket(self, l, r, gold_brackets): + for bl, br in gold_brackets: + if((blbr) or (l= 4: + tmp1 = g(self.beta[:, l, l+1, self.word_span_slice, l:r], + self.beta[:, l+1, r, :self.nt_states, l:r], + self.beta_[:, l, l+1, self.word_span_slice], + self.beta_[:, l+1, r, :self.nt_states], + rule_scores[:, :, :, l:r, self.word_span_slice, :self.nt_states]) + + tmp2 = g(self.beta[:, l, r-1, :self.nt_states, l:r], + self.beta[:, r-1, r, self.word_span_slice, l:r], + self.beta_[:, l, r-1, :self.nt_states], + self.beta_[:, r-1, r, self.word_span_slice], + rule_scores[:, :, :, l:r, :self.nt_states, self.word_span_slice]) + + tmp3 = g(self.beta[:, l, l+2:r-1, :self.nt_states, l:r].rename(R='U'), + self.beta[:, l+2:r-1, r, :self.nt_states, l:r].rename(L='U'), + self.beta_[:, l, l+2:r-1, :self.nt_states].rename(R='U'), + self.beta_[:, l+2:r-1, r, :self.nt_states].rename(L='U'), + rule_scores[:, :, :, l:r, :self.nt_states, :self.nt_states].align_to('D', 'B', 'T', 'H', 'U', ...)) + tmp = self.logadd(self.logadd(f(tmp1), f(tmp2)), f(tmp3)) + + tmp = tmp + mask[:, l, r, :self.nt_states, l:r] + self.beta[:, l, r, :self.nt_states, l:r] = tmp.rename(None) + tmp_ = torch.logsumexp(tmp + unary_scores[:, l:r, :self.nt_states].align_as(tmp), dim='H') + self.beta_[:, l, r, :self.nt_states] = tmp_.rename(None) + + + log_Z = self.beta_[:, 0, N, :self.nt_states] + root_scores + log_Z = torch.logsumexp(log_Z, dim='T') + return log_Z + + def _viterbi(self, **kwargs): + #unary scores : b x n x T + #rule scores : b x NT x (NT+T) x (NT+T) + + rule_scores, root_scores, unary_scores = self.__get_scores(**kwargs) + + # statistics + B = rule_scores.size('B') + N = unary_scores.size('H') + T = self.states + + # # dummy rules + # rule_scores = torch.cat([rule_scores, \ + # rule_scores.new(B, self.t_states, T, T) \ + # .fill_(-self.huge)], dim=1) + + self.scores = rule_scores.new(B, N+1, N+1, T, N).fill_(-self.huge).refine_names('B', 'L', 'R', 'T', 'H') + self.scores_ = rule_scores.new(B, N+1, N+1, T).fill_(-self.huge).refine_names('B', 'L', 'R', 'T') + self.bp = rule_scores.new(B, N+1, N+1, T, N).long().fill_(-1).refine_names('B', 'L', 'R', 'T', 'H') + self.left_bp = rule_scores.new(B, N+1, N+1, T, N).long().fill_(-1).refine_names('B', 'L', 'R', 'T', 'H') + self.right_bp = rule_scores.new(B, N+1, N+1, T, N).long().fill_(-1).refine_names('B', 'L', 'R', 'T', 'H') + self.dir_bp = rule_scores.new(B, N+1, N+1, T, N).long().fill_(-1).refine_names('B', 'L', 'R', 'T', 'H') + self.new_head_bp = rule_scores.new(B, N+1, N+1, T).long().fill_(-1).refine_names('B', 'L', 'R', 'T') + self.argmax = rule_scores.new(B, N, N).long().fill_(-1) + self.argmax_tags = rule_scores.new(B, N).long().fill_(-1) + self.spans = [[] for _ in range(B)] + + # initialization: f[k, k+1] + for k in range(N): + for state in range(self.states): + if(not self.nt_emission and state < self.nt_states): + continue + self.scores[:, k, k+1, state, k] = 0 + self.scores_[:, k, k+1, state] = unary_scores[:, k, state].rename(None) + self.new_head_bp[:, k, k+1, state] = k + + for W in 
np.arange(2, N+1): + for l in range(N-W+1): + r = l + W + + left = lambda x, y, z: x.rename(T='TL').align_as(z) + y.rename(T='TR').align_as(z) + z + right = lambda x, y, z: x.rename(T='TL').align_as(z) + y.rename(T='TR').align_as(z) + z + g = lambda x, y, x_, y_, z: torch.cat((left(x, y_, z[0]).align_as(z), + right(x_, y, z[1]).align_as(z)), dim='D') + + # self.print_name_size(self.scores[:, l, l+1:r, :, l:r]) + # self.print_name_size(rule_scores[:, :, :, l:r, :self.nt_states, :self.nt_states, l:r].align_to('D', 'B', 'T', 'H', 'U', ...)) + tmp = g(self.scores[:, l, l+1:r, :, l:r].rename(R='U'), + self.scores[:, l+1:r, r, :, l:r].rename(L='U'), + self.scores_[:, l, l+1:r, :].rename(R='U'), + self.scores_[:, l+1:r, r, :].rename(L='U'), + rule_scores[:, :, :, l:r, :, :].align_to('D', 'B', 'T', 'H', 'U', ...)) + + tmp = tmp.align_to('B', 'T', 'H', 'D', 'U', 'TL', 'TR').flatten(['D', 'U', 'TL', 'TR'], 'position') + + assert(tmp.size('position') == self.states * self.states * (W-1) * 2), "{}".format(tmp.size('position')) + # view once and marginalize + tmp, max_pos = torch.max(tmp, dim=3) + + max_pos = max_pos.rename(None) + right_child = max_pos % self.states + max_pos /= self.states + left_child = max_pos % self.states + max_pos /= self.states + max_idx = max_pos % (W-1) + l + 1 + max_pos = max_pos / int(W - 1) + max_dir = max_pos + + self.scores[:, l, r, :self.nt_states, l:r] = tmp.rename(None) + tmp_ = tmp + unary_scores[:, l:r, :self.nt_states].align_as(tmp) + tmp_, new_head = torch.max(tmp_, dim='H') + self.scores_[:, l, r, :self.nt_states] = tmp_.rename(None) + + self.bp[:, l, r, :self.nt_states, l:r] = max_idx + self.left_bp[:, l, r, :self.nt_states, l:r] = left_child + self.right_bp[:, l, r, :self.nt_states, l:r] = right_child + self.dir_bp[:, l, r, :self.nt_states, l:r] = max_dir + self.new_head_bp[:, l, r, :self.nt_states] = new_head.rename(None) + l + + max_score = self.scores_[:, 0, N, :self.nt_states] + root_scores + max_score, max_idx = torch.max(max_score, dim='T') + for b in range(B): + self._backtrack(b, 0, N, max_idx[b].item()) + return self.scores, self.argmax, self.spans + + def _backtrack(self, b, s, t, state, head=-1): + if(head == -1): + head = int(self.new_head_bp[b][s][t][state]) + u = int(self.bp[b][s][t][state][head]) + assert(s < t), "s: %d, t %d"%(s, t) + left_state = int(self.left_bp[b][s][t][state][head]) + right_state = int(self.right_bp[b][s][t][state][head]) + direction = int(self.dir_bp[b][s][t][state][head]) + self.argmax[b][s][t-1] = 1 + if s == t-1: + self.spans[b].insert(0, (s, t-1, state, head)) + self.argmax_tags[b][s] = state + return None + else: + self.spans[b].insert(0, (s, t-1, state, head)) + if(direction == 0): + assert head < u, "head: {} < u: {}".format(head, u) + self._backtrack(b, s, u, left_state, head) + self._backtrack(b, u, t, right_state) + else: + assert head >= u, "head: {} >= u: {}".format(head, u) + self._backtrack(b, s, u, left_state) + self._backtrack(b, u, t, right_state, head) + + return None \ No newline at end of file diff --git a/preprocess.py b/preprocess.py new file mode 100755 index 0000000..9800adc --- /dev/null +++ b/preprocess.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Create data files +""" + +import os +import sys +import argparse +import numpy as np +import pickle +import itertools +from collections import defaultdict +import utils +import re + +class Indexer: + def __init__(self, symbols = ["","","",""]): + self.vocab = defaultdict(int) + self.PAD = symbols[0] + self.UNK = symbols[1] + 
self.BOS = symbols[2] + self.EOS = symbols[3] + self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3} + self.idx2word = {} + + def add_w(self, ws): + for w in ws: + if w not in self.d: + self.d[w] = len(self.d) + + def convert(self, w): + return self.d[w] if w in self.d else self.d[self.UNK] + + def convert_sequence(self, ls): + return [self.convert(l) for l in ls] + + def write(self, outfile): + out = open(outfile, "w") + items = [(v, k) for k, v in self.d.items()] + items.sort() + for v, k in items: + out.write(" ".join([k, str(v)]) + "\n") + out.close() + + def prune_vocab(self, k, cnt = False): + vocab_list = [(word, count) for word, count in self.vocab.items()] + if cnt: + self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list if pair[1] > k} + else: + vocab_list.sort(key = lambda x: x[1], reverse=True) + k = min(k, len(vocab_list)) + self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list[:k]} + for word in self.pruned_vocab: + if word not in self.d: + self.d[word] = len(self.d) + for word, idx in self.d.items(): + self.idx2word[idx] = word + + def load_vocab(self, vocab_file): + self.d = {} + for line in open(vocab_file, 'r'): + v, k = line.strip().split() + self.d[v] = int(k) + for word, idx in self.d.items(): + self.idx2word[idx] = word + + +def is_next_open_bracket(line, start_idx): + for char in line[(start_idx + 1):]: + if char == '(': + return True + elif char == ')': + return False + raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket') + +def get_between_brackets(line, start_idx): + output = [] + for char in line[(start_idx + 1):]: + if char == ')': + break + assert not(char == '(') + output.append(char) + return ''.join(output) + +def get_tags_tokens_lowercase(line): + output = [] + line_strip = line.rstrip() + for i in range(len(line_strip)): + if i == 0: + assert line_strip[i] == '(' + if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol + output.append(get_between_brackets(line_strip, i)) + #print 'output:',output + output_tags = [] + output_tokens = [] + output_lowercase = [] + for terminal in output: + terminal_split = terminal.split() + # print(terminal, terminal_split) + assert len(terminal_split) == 2, (terminal_split, output) # each terminal contains a POS tag and word + output_tags.append(terminal_split[0]) + output_tokens.append(terminal_split[1]) + output_lowercase.append(terminal_split[1].lower()) + return [output_tags, output_tokens, output_lowercase] + +def get_nonterminal(line, start_idx): + assert line[start_idx] == '(' # make sure it's an open bracket + output = [] + for char in line[(start_idx + 1):]: + if char == ' ': + break + assert not(char == '(') and not(char == ')') + output.append(char) + return ''.join(output) + + +def get_actions(line): + output_actions = [] + line_strip = line.rstrip() + i = 0 + max_idx = (len(line_strip) - 1) + while i <= max_idx: + assert line_strip[i] == '(' or line_strip[i] == ')' + if line_strip[i] == '(': + if is_next_open_bracket(line_strip, i): # open non-terminal + curr_NT = get_nonterminal(line_strip, i) + output_actions.append('NT(' + curr_NT + ')') + i += 1 + while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal + i += 1 + else: # it's a terminal symbol + output_actions.append('SHIFT') + while line_strip[i] != ')': + i += 1 + i += 1 + while line_strip[i] != ')' and line_strip[i] != '(': + i += 1 + else: + 
output_actions.append('REDUCE') + if i == max_idx: + break + i += 1 + while line_strip[i] != ')' and line_strip[i] != '(': + i += 1 + assert i == max_idx + return output_actions + +def pad(ls, length, symbol): + if len(ls) >= length: + return ls[:length] + return ls + [symbol] * (length -len(ls)) + +def clean_number(w): + new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w) + return new_w + +def get_data(args): + indexer = Indexer(["","","",""]) + + def make_vocab(textfile, seqlength, minseqlength, lowercase, replace_num, + train=1, apply_length_filter=1): + num_sents = 0 + max_seqlength = 0 + for tree in open(textfile, 'r'): + tree = tree.strip() + try: + tags, sent, sent_lower = get_tags_tokens_lowercase(tree) + except: + print(tree) + + assert(len(tags) == len(sent)) + if lowercase == 1: + sent = sent_lower + if replace_num == 1: + sent = [clean_number(w) for w in sent] + if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength: + continue + num_sents += 1 + max_seqlength = max(max_seqlength, len(sent)) + if train == 1: + for word in sent: + indexer.vocab[word] += 1 + return num_sents, max_seqlength + + def convert(textfile, lowercase, replace_num, + batchsize, seqlength, minseqlength, outfile, num_sents, max_sent_l=0, + shuffle=0, include_boundary=1, apply_length_filter=1, conllfile=""): + newseqlength = seqlength + if include_boundary == 1: + newseqlength += 2 #add 2 for EOS and BOS + sents = np.zeros((num_sents, newseqlength), dtype=int) + sent_lengths = np.zeros((num_sents,), dtype=int) + dropped = 0 + sent_id = 0 + other_data = [] + if(conllfile != ""): + deptrees = utils.read_conll(open(conllfile, "r")) + for tree in open(textfile, 'r'): + tree = tree.strip() + action = get_actions(tree) + tags, sent, sent_lower = get_tags_tokens_lowercase(tree) + assert(len(tags) == len(sent)) + if(conllfile != ""): + words, heads = next(deptrees) + if words != sent: + print("Data mismatch, got {} in {}, but {} in {}.".format(sent, textfile, words, conllfile)) + assert(len(words) == len(heads)) + assert(len(heads) == len(sent)) + if lowercase == 1: + sent = sent_lower + sent_str = " ".join(sent) + if replace_num == 1: + sent = [clean_number(w) for w in sent] + if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength: + continue + + if include_boundary == 1: + sent = [indexer.BOS] + sent + [indexer.EOS] + max_sent_l = max(len(sent), max_sent_l) + sent_pad = pad(sent, newseqlength, indexer.PAD) + sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int) + sent_lengths[sent_id] = (sents[sent_id] != 0).sum() + span, binary_actions, nonbinary_actions = utils.get_nonbinary_spans(action) + other_data_item = [sent_str, tags, action, + binary_actions, nonbinary_actions, span, tree] + if(conllfile != ""): + other_data_item.append(heads) + other_data.append(other_data_item) + assert(2*(len(sent)- 2) - 1 == len(binary_actions)) + assert(sum(binary_actions) + 1 == len(sent) - 2) + sent_id += 1 + if sent_id % 100000 == 0: + print("{}/{} sentences processed".format(sent_id, num_sents)) + print(sent_id, num_sents) + if shuffle == 1: + rand_idx = np.random.permutation(sent_id) + sents = sents[rand_idx] + sent_lengths = sent_lengths[rand_idx] + other_data = [other_data[idx] for idx in rand_idx] + + print(len(sents), len(other_data)) + #break up batches based on source lengths + sent_lengths = sent_lengths[:sent_id] + sent_sort = np.argsort(sent_lengths) + sents = sents[sent_sort] + other_data = [other_data[idx] for idx in sent_sort] + sent_l = 
sent_lengths[sent_sort] + curr_l = 1 + l_location = [] #idx where sent length changes + + for j,i in enumerate(sent_sort): + if sent_lengths[i] > curr_l: + curr_l = sent_lengths[i] + l_location.append(j) + l_location.append(len(sents)) + #get batch sizes + curr_idx = 0 + batch_idx = [0] + nonzeros = [] + batch_l = [] + batch_w = [] + for i in range(len(l_location)-1): + while curr_idx < l_location[i+1]: + curr_idx = min(curr_idx + batchsize, l_location[i+1]) + batch_idx.append(curr_idx) + for i in range(len(batch_idx)-1): + batch_l.append(batch_idx[i+1] - batch_idx[i]) + batch_w.append(sent_l[batch_idx[i]]) + + # Write output + f = {} + f["source"] = sents + f["other_data"] = other_data + f["batch_l"] = np.array(batch_l, dtype=int) + f["source_l"] = np.array(batch_w, dtype=int) + f["sents_l"] = np.array(sent_l, dtype = int) + f["batch_idx"] = np.array(batch_idx[:-1], dtype=int) + f["vocab_size"] = np.array([len(indexer.d)]) + f["idx2word"] = indexer.idx2word + f["word2idx"] = {word : idx for idx, word in indexer.idx2word.items()} + + print("Saved {} sentences (dropped {} due to length/unk filter)".format( + len(f["source"]), dropped)) + pickle.dump(f, open(outfile, 'wb')) + return max_sent_l + + print("First pass through data to get vocab...") + num_sents_train, train_seqlength = make_vocab(args.trainfile, args.seqlength, args.minseqlength, + args.lowercase, args.replace_num, 1, 1) + print("Number of sentences in training: {}".format(num_sents_train)) + num_sents_valid, valid_seqlength = make_vocab(args.valfile, args.seqlength, args.minseqlength, + args.lowercase, args.replace_num, 0, 0) + print("Number of sentences in valid: {}".format(num_sents_valid)) + num_sents_test, test_seqlength = make_vocab(args.testfile, args.seqlength, args.minseqlength, + args.lowercase, args.replace_num, 0, 0) + print("Number of sentences in test: {}".format(num_sents_test)) + + if args.vocabminfreq >= 0: + indexer.prune_vocab(args.vocabminfreq, True) + else: + indexer.prune_vocab(args.vocabsize, False) + if args.vocabfile != '': + print('Loading pre-specified source vocab from ' + args.vocabfile) + indexer.load_vocab(args.vocabfile) + indexer.write(args.outputfile + ".dict") + print("Vocab size: Original = {}, Pruned = {}".format(len(indexer.vocab), + len(indexer.d))) + print(train_seqlength, valid_seqlength, test_seqlength) + max_sent_l = 0 + max_sent_l = convert(args.testfile, args.lowercase, args.replace_num, + args.batchsize, test_seqlength, args.minseqlength, + args.outputfile + "-test.pkl", num_sents_test, + max_sent_l, args.shuffle, args.include_boundary, 0, + conllfile=os.path.splitext(args.testfile)[0] + ".conllx" if args.dep else "") + max_sent_l = convert(args.valfile, args.lowercase, args.replace_num, + args.batchsize, valid_seqlength, args.minseqlength, + args.outputfile + "-val.pkl", num_sents_valid, + max_sent_l, args.shuffle, args.include_boundary, 0, + conllfile=os.path.splitext(args.valfile)[0] + ".conllx" if args.dep else "") + max_sent_l = convert(args.trainfile, args.lowercase, args.replace_num, + args.batchsize, args.seqlength, args.minseqlength, + args.outputfile + "-train.pkl", num_sents_train, + max_sent_l, args.shuffle, args.include_boundary, 1, + conllfile=os.path.splitext(args.trainfile)[0] + ".conllx" if args.dep else "") + print("Max sent length (before dropping): {}".format(max_sent_l)) + +def main(arguments): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--vocabsize', help="Size of 
source vocabulary, constructed " + "by taking the top X most frequent words. " + " Rest are replaced with special UNK tokens.", + type=int, default=10000) + parser.add_argument('--vocabminfreq', help="Minimum frequency for vocab. Use this instead of " + "vocabsize if > 0", + type=int, default=-1) + parser.add_argument('--include_boundary', help="Add BOS/EOS tokens", type=int, default=1) + parser.add_argument('--lowercase', help="Lower case", type=int, default=1) + parser.add_argument('--replace_num', help="Replace numbers with N", type=int, default=1) + parser.add_argument('--trainfile', help="Path to training data.", required=True) + parser.add_argument('--valfile', help="Path to validation data.", required=True) + parser.add_argument('--testfile', help="Path to test validation data.", required=True) + parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=4) + parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer " + "than this are dropped.", type=int, default=150) + parser.add_argument('--minseqlength', help="Minimum sequence length. Sequences shorter " + "than this are dropped.", type=int, default=0) + parser.add_argument('--outputfile', help="Prefix of the output file names. ", type=str, + required=True) + parser.add_argument('--vocabfile', help="If working with a preset vocab, " + "then including this will ignore srcvocabsize and use the" + "vocab provided here.", + type = str, default='') + parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (based on " + "source length).", + type = int, default = 0) + parser.add_argument('--dep', action="store_true", help="Including dependency parse files. Their " + "names should be same as data file, but extensions " + "are .conllx.") + args = parser.parse_args(arguments) + np.random.seed(3435) + get_data(args) + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/process_ptb.py b/process_ptb.py new file mode 100644 index 0000000..7eb14fd --- /dev/null +++ b/process_ptb.py @@ -0,0 +1,93 @@ +import os +import re +import sys +import argparse +import nltk +from nltk.corpus import ptb +import os +from pathlib import Path + + +def get_data_ptb(root, output): + # tag filter is from https://github.com/yikangshen/PRPN/blob/master/data_ptb.py + word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', + 'NNS', 'NNP', 'NNPS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', + 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', + 'WDT', 'WP', 'WP$', 'WRB'] + currency_tags_words = ['#', '$', 'C$', 'A$'] + ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*', '*PPA*', '*NOT*'] + punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``'] + punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``', '--', ';', + '-', '?', '!', '...', '-LCB-', '-RCB-'] + train_file_ids = [] + val_file_ids = [] + test_file_ids = [] + train_section = ['02', '03', '04', '05', '06', '07', '08', '09', '10', + '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21'] + val_section = ['22'] + test_section = ['23'] + + for dir_name, _, file_list in os.walk(root, topdown=False): + if dir_name.split("/")[-1] in train_section: + file_ids = train_file_ids + elif dir_name.split("/")[-1] in val_section: + file_ids = val_file_ids + elif dir_name.split("/")[-1] in test_section: + file_ids = test_file_ids + else: + continue + for fname in file_list: + file_ids.append(os.path.join(dir_name, fname)) + 
assert(file_ids[-1].split(".")[-1] == "mrg") + print(len(train_file_ids), len(val_file_ids), len(test_file_ids)) + + def del_tags(tree, word_tags): + for sub in tree.subtrees(): + for n, child in enumerate(sub): + if isinstance(child, str): + continue + if all(leaf_tag not in word_tags for leaf, leaf_tag in child.pos()): + del sub[n] + + def save_file(file_ids, out_file): + sens = [] + trees = [] + tags = [] + f_out = open(out_file, 'w') + for f in file_ids: + sentences = ptb.parsed_sents(f) + for sen_tree in sentences: + orig = sen_tree.pformat(margin=sys.maxsize).strip() + c = 0 + while not all([tag in word_tags for _, tag in sen_tree.pos()]): + del_tags(sen_tree, word_tags) + c += 1 + if c > 10: + assert False + out = sen_tree.pformat(margin=sys.maxsize).strip() + while re.search('\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)', out) is not None: + out = re.sub('\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)', '', out) + out = out.replace(' )', ')') + out = re.sub('\s{2,}', ' ', out) + f_out.write(out + '\n') + f_out.close() + + save_file(train_file_ids, os.path.join(output, "ptb-train.txt")) + save_file(val_file_ids, os.path.join(output, "ptb-valid.txt")) + save_file(test_file_ids, os.path.join(output, "ptb-test.txt")) + +def main(arguments): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--ptb_path', help='Path to parsed/mrg/wsj folder', type=str, + default='PATH-TO-PTB/parsed/mrg/wsj') + parser.add_argument('--output_path', help='Path to save processed files', + type=str, default='data') + args = parser.parse_args(arguments) + get_data_ptb(args.ptb_path, args.output_path) + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) + + diff --git a/train.py b/train.py new file mode 100755 index 0000000..ce76b5b --- /dev/null +++ b/train.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +import sys +import os + +import argparse +import json +import random +import shutil +import copy + +from collections import defaultdict + +import torch +from torch import cuda +import numpy as np +import time +import logging +from data import Dataset +from utils import * +from models import CompPCFG, LexicalizedCompPCFG +from torch.nn.init import xavier_uniform_ +from torch.utils.tensorboard import SummaryWriter + +try: + from apex import amp + APEX_AVAILABLE = True +except ModuleNotFoundError: + APEX_AVAILABLE = False + +import pdb + +import warnings +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser() + +# Program options +parser.add_argument('--mode', default='train', help='train/test') +parser.add_argument('--test_file', default='data/preprocessed/ptb-test.pkl') +# Data path options +parser.add_argument('--train_file', default='data/preprocessed/ptb-train.pkl') +parser.add_argument('--val_file', default='data/preprocessed/ptb-val.pkl') +parser.add_argument('--save_path', default='compound-pcfg.pt', help='where to save the model') +parser.add_argument('--pretrained_word_emb', default="", help="word emb file") +# Model options +parser.add_argument('--model', default='LexicalizedCompPCFG', type=str, help='model name') +parser.add_argument('--load_model', default='', type=str, help='checkpoint file of stored model') +parser.add_argument('--init_gain', default=1., type=float, help='gain of xaviar initialization') +parser.add_argument('--init_model', default='', help='initial lexicalized pcfg with compound pcfg') +# Generative model parameters +parser.add_argument('--z_dim', default=64, type=int, 
help='latent dimension')
+parser.add_argument('--t_states', default=60, type=int, help='number of preterminal states')
+parser.add_argument('--nt_states', default=30, type=int, help='number of nonterminal states')
+parser.add_argument('--state_dim', default=256, type=int, help='symbol embedding dimension')
+parser.add_argument('--nt_emission', action="store_true", help='allow a single-word span to take a nonterminal label')
+parser.add_argument('--scalar_dir_scores', action="store_true", help='use scalar direction scores instead of neural ones')
+parser.add_argument('--seperate_nt_emb_for_emission', action="store_true", help='use separate NT embeddings for the emission probability')
+parser.add_argument('--head_first', action="store_true", help="generate the head and direction first")
+parser.add_argument('--tie_word_emb', action="store_true", help="tie the word embeddings")
+parser.add_argument('--flow_word_emb', action="store_true", help="emit words via an invertible flow")
+parser.add_argument('--freeze_word_emb', action="store_true", help="freeze the word embeddings")
+# Inference network parameters
+parser.add_argument('--h_dim', default=512, type=int, help='hidden dim for variational LSTM')
+parser.add_argument('--w_dim', default=512, type=int, help='embedding dim for variational LSTM')
+# Optimization options
+parser.add_argument('--num_epochs', default=10, type=int, help='number of training epochs')
+parser.add_argument('--lr', default=0.001, type=float, help='starting learning rate')
+parser.add_argument('--delay_step', default=1, type=int, help='number of backward passes to accumulate before an optimizer step')
+parser.add_argument('--max_grad_norm', default=3, type=float, help='gradient clipping parameter')
+parser.add_argument('--max_length', default=30, type=float, help='initial max sentence length cutoff')
+parser.add_argument('--len_incr', default=1, type=int, help='increment of the max length cutoff per epoch')
+parser.add_argument('--final_max_length', default=40, type=int, help='final max length cutoff')
+parser.add_argument('--eval_max_length', default=None, type=int, help='max length in evaluation. 
set to the same as final_max_length by default') +parser.add_argument('--beta1', default=0.75, type=float, help='beta1 for adam') +parser.add_argument('--beta2', default=0.999, type=float, help='beta2 for adam') +parser.add_argument('--gpu', default=0, type=int, help='which gpu to use') +parser.add_argument('--seed', default=3435, type=int, help='random seed') +parser.add_argument('--print_every', type=int, default=1000, help='print stats after N batches') +parser.add_argument('--supervised_signals', nargs="*", default = [], help="supervised signals to use") +parser.add_argument('--opt_level', type=str, default="O0", help="mixed precision") +parser.add_argument('--t_emb_init', type=str, default="", help="initial value of t_emb") +parser.add_argument('--vocab_mlp_identity_init', action='store_true', help="initialize vocab_mlp as identity function") +# Evaluation optiones +parser.add_argument('--evaluate_dep', action='store_true', help='evaluate dependency parsing results') + +parser.add_argument('--log_dir', type=str, default="", help='tensorboard logdir') + +args = parser.parse_args() + +if(args.eval_max_length is None): + args.eval_max_length = args.final_max_length + +# tensorboard +if(args.log_dir == ""): + writer = SummaryWriter() +else: + writer = SummaryWriter(log_dir=args.log_dir) +global_step = 0 + +def add_scalars(main_tag, tag_scalar_dict, global_step): + for tag in tag_scalar_dict: + writer.add_scalar("{}/{}".format(main_tag, tag), tag_scalar_dict[tag], global_step) + +def main(args): + global global_step + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if(args.mode == 'train'): + train_data = Dataset(args.train_file, load_dep=args.evaluate_dep) + val_data = Dataset(args.val_file, load_dep=args.evaluate_dep) + train_sents = train_data.batch_size.sum() + vocab_size = int(train_data.vocab_size) + max_len = max(val_data.sents.size(1), train_data.sents.size(1)) + print('Train: %d sents / %d batches, Val: %d sents / %d batches' % + (train_data.sents.size(0), len(train_data), val_data.sents.size(0), len(val_data))) + if(not args.pretrained_word_emb == ""): + pretrained_word_emb_matrix = get_word_emb_matrix(args.pretrained_word_emb, train_data.idx2word) + else: + pretrained_word_emb_matrix = None + else: + test_data = Dataset(args.test_file, load_dep=args.evaluate_dep) + vocab_size = int(test_data.vocab_size) + max_len = test_data.sents.size(1) + print("Test: %d sents / %d batches" % (test_data.sents.size(0), len(test_data))) + if(not args.pretrained_word_emb == ""): + pretrained_word_emb_matrix = get_word_emb_matrix(args.pretrained_word_emb, test_data.idx2word) + else: + pretrained_word_emb_matrix = None + print('Vocab size: %d, Max Sent Len: %d' % (vocab_size, max_len)) + print('Save Path', args.save_path) + cuda.set_device(args.gpu) + if(args.model == 'CompPCFG'): + model = CompPCFG(vocab = vocab_size, + state_dim = args.state_dim, + t_states = args.t_states, + nt_states = args.nt_states, + h_dim = args.h_dim, + w_dim = args.w_dim, + z_dim = args.z_dim) + init_model = None + elif(args.model == 'LexicalizedCompPCFG'): + if args.init_model != '': + init_model = CompPCFG(vocab = vocab_size, + state_dim = args.state_dim, + t_states = args.t_states, + nt_states = args.nt_states, + h_dim = args.h_dim, + w_dim = args.w_dim, + z_dim = args.z_dim) + init_model.load_state_dict(torch.load(args.init_model)["model"]) + args.supervised_signals = ["phrase", "tag", "nt"] + else: + init_model = None + model = LexicalizedCompPCFG(vocab = vocab_size, + state_dim = args.state_dim, + 
t_states = args.t_states, + nt_states = args.nt_states, + h_dim = args.h_dim, + w_dim = args.w_dim, + z_dim = args.z_dim, + nt_emission=args.nt_emission, + scalar_dir_scores=args.scalar_dir_scores, + seperate_nt_emb_for_emission=args.seperate_nt_emb_for_emission, + head_first=args.head_first, + tie_word_emb=args.tie_word_emb, + flow_word_emb=args.flow_word_emb, + freeze_word_emb=args.freeze_word_emb, + pretrained_word_emb=pretrained_word_emb_matrix, + supervised_signals=args.supervised_signals) + else: + raise NotImplementedError + for name, param in model.named_parameters(): + if param.dim() > 1: + xavier_uniform_(param, args.init_gain) + if(args.t_emb_init != ""): + t_emb_init = np.loadtxt(args.t_emb_init) + model.t_emb.data.copy_(torch.from_numpy(t_emb_init)) + if(args.vocab_mlp_identity_init): + model.vocab_mlp[0].bias.data.copy_(torch.zeros(args.state_dim)) + model.vocab_mlp[0].weight.data.copy_(torch.cat([torch.eye(args.state_dim, args.state_dim), torch.zeros(args.state_dim, args.z_dim)], dim=1)) + if(args.load_model != ''): + print("Loading model from {}.".format(args.load_model)) + model.load_state_dict(torch.load(args.load_model)["model"]) + print("Model loaded from {}.".format(args.load_model)) + print("model architecture") + print(model) + model.train() + model.cuda() + if init_model: + init_model.eval() + init_model.cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas = (args.beta1, args.beta2)) + if args.opt_level != "O0": + model.pcfg.huge = 1e4 + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.opt_level, + keep_batchnorm_fp32=True, loss_scale="dynamic" + ) + if(args.mode == "test"): + print('--------------------------------') + print('Checking validation perf...') + test_ppl, test_f1 = eval(test_data, model) + print('--------------------------------') + return + best_val_ppl = 1e5 + best_val_f1 = 0 + epoch = 0 + while epoch < args.num_epochs: + start_time = time.time() + epoch += 1 + print('Starting epoch %d' % epoch) + train_nll = 0. + train_kl = 0. + num_sents = 0. + num_words = 0. 
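+    # running span statistics for this epoch; filled by update_stats and summarized by get_f1 as a corpus-level F1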
+ all_stats = [[0., 0., 0.]] + if(args.evaluate_dep): + dep_stats = [[0., 0., 0.]] + b = b_ = 0 + optimization_delay_count_down = args.delay_step + for i in np.random.permutation(len(train_data)): + b += 1 + gold_tree = None + if(not args.evaluate_dep): + sents, length, batch_size, _, _, gold_spans, gold_binary_trees, _ = train_data[i] + else: + sents, length, batch_size, gold_tags, gold_actions, gold_spans, gold_binary_trees, _, heads = train_data[i] + if(len(args.supervised_signals)): + gold_tree = [] + for j in range(len(heads)): + gold_tree.append(get_span2head(gold_spans[j], heads[j], gold_actions=gold_actions[j], gold_tags=gold_tags[j])) + for span, (head, label) in gold_tree[j].items(): + if(span[0] == span[1]): + gold_tree[j][span] = (head, PT2ID[label]) + else: + f = lambda x : x[:x.find('-')] if x.find('-') != -1 else x + g = lambda y : y[:y.find('=')] if y.find('=') != -1 else y + gold_tree[j][span] = (head, NT2ID[f(g(label))]) + if length > args.max_length or length == 1: #length filter based on curriculum + continue + b_ += 1 + sents = sents.cuda() + if init_model: + gold_tree = [] + with torch.no_grad(): + _, _, _, argmax_spans = init_model(sents, argmax=True) + for j in range(len(argmax_spans)): + gold_tree.append({}) + for span in argmax_spans[j]: + if(span[0] == span[1]): + gold_tree[j][(span[0], span[1])] = (-1, span[2] - args.nt_states) + else: + gold_tree[j][(span[0], span[1])] = (-1, span[2]) + nll, kl, binary_matrix, argmax_spans = model(sents, argmax=True, gold_tree=gold_tree) + loss = (nll + kl).mean() + if(args.opt_level != "O0"): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + train_nll += nll.sum().item() + train_kl += kl.sum().item() + if(optimization_delay_count_down == 1): + if args.max_grad_norm > 0: + if args.opt_level == "O0": + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(amp.master_params( + optimizer), args.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + optimization_delay_count_down = args.delay_step + else: + optimization_delay_count_down -= 1 + num_sents += batch_size + num_words += batch_size * (length + 1) # we implicitly generate so we explicitly count it + for bb in range(batch_size): + span_b = [(a[0], a[1]) for a in argmax_spans[bb] if a[0] != a[1]] #ignore labels + span_b_set = set(span_b[:-1]) + update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats) + if(args.evaluate_dep): + update_dep_stats(argmax_spans[bb], heads[bb], dep_stats) + if b_ % args.print_every == 0: + all_f1 = get_f1(all_stats) + dir_acc, undir_acc = get_dep_acc(dep_stats) if args.evaluate_dep else (0., 0.) 
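+        # periodic logging: parameter/gradient norms plus running reconstruction PPL, KL, PPL bound, and parsing metrics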
+ param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5 + gparam_norm = sum([p.grad.norm()**2 for p in model.parameters() + if p.grad is not None]).item()**0.5 + log_str = 'Epoch: %d, Batch: %d/%d, |Param|: %.6f, |GParam|: %.2f, LR: %.4f, ' + \ + 'ReconPPL: %.2f, NLLloss: %.4f, KL: %.4f, PPLBound: %.2f, ValPPL: %.2f, ValF1: %.2f, ' + \ + 'CorpusF1: %.2f, DirAcc: %.2f, UndirAcc: %.2f, Throughput: %.2f examples/sec' + print(log_str % + (epoch, b, len(train_data), param_norm, gparam_norm, args.lr, + np.exp(train_nll / num_words), train_nll / num_words, train_kl /num_sents, + np.exp((train_nll + train_kl)/num_words), best_val_ppl, best_val_f1, + all_f1[0], dir_acc, undir_acc, num_sents / (time.time() - start_time))) + # print an example parse + tree = get_tree_from_binary_matrix(binary_matrix[0], length) + action = get_actions(tree) + sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[0].cpu().numpy())] + if(args.evaluate_dep): + print("Pred Tree: %s" % get_tagged_parse(get_tree(action, sent_str), argmax_spans[0])) + else: + print("Pred Tree: %s" % get_tree(action, sent_str)) + print("Gold Tree: %s" % get_tree(gold_binary_trees[0], sent_str)) + + # tensorboard + global_step += args.print_every + add_scalars(main_tag="train", + tag_scalar_dict={"ParamNorm": param_norm, + "ParamGradNorm": gparam_norm, + "ReconPPL": np.exp(train_nll / num_words), + "KL": train_kl /num_sents, + "PPLBound": np.exp((train_nll + train_kl)/num_words), + "CorpusF1": all_f1[0], + "DirAcc": dir_acc, + "UndirAcc": undir_acc, + "Throughput (examples/sec)": num_sents / (time.time() - start_time), + "GPU memory usage": torch.cuda.memory_allocated()}, + global_step=global_step) + if(args.evaluate_dep): + writer.add_text("Pred Tree", get_tagged_parse(get_tree(action, sent_str), argmax_spans[0]), global_step) + else: + writer.add_text("Pred Tree", get_tree(action, sent_str), global_step) + writer.add_text("Gold Tree", get_tree(gold_binary_trees[0], sent_str), global_step) + + args.max_length = min(args.final_max_length, args.max_length + args.len_incr) + print('--------------------------------') + print('Checking validation perf...') + val_ppl, val_f1 = eval(val_data, model) + print('--------------------------------') + if val_ppl < best_val_ppl: + best_val_ppl = val_ppl + best_val_f1 = val_f1 + checkpoint = { + 'args': args.__dict__, + 'model': model.cpu().state_dict(), + 'word2idx': train_data.word2idx, + 'idx2word': train_data.idx2word + } + print('Saving checkpoint to %s' % args.save_path) + torch.save(checkpoint, args.save_path) + model.cuda() + +def eval(data, model): + global global_step + model.eval() + num_sents = 0 + num_words = 0 + total_nll = 0. + total_kl = 0. + corpus_f1 = [0., 0., 0.] 
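+  # corpus_f1 accumulates (tp, fp, fn) over predicted vs. gold spans; corpus_f1_by_cat keeps the same counts per gold category for the recall-by-category report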
+ corpus_f1_by_cat = [defaultdict(int), defaultdict(int), defaultdict(int)] + dep_stats = [[0., 0., 0.]] + sent_f1 = [] + + # f = open("tmp.txt", "w") + + with torch.no_grad(): + for i in range(len(data)): + if(not args.evaluate_dep): + sents, length, batch_size, _, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] + else: + sents, length, batch_size, gold_tags, gold_actions, gold_spans, gold_binary_trees, other_data, heads = data[i] + span_dicts = [] + for j in range(batch_size): + span_dict = {} + for l, r, nt in get_nonbinary_spans_label(gold_actions[j])[0]: + span_dict[(l, r)] = nt + span_dicts.append(span_dict) + if length == 1 or length > args.eval_max_length: + continue + sents = sents.cuda() + # note that for unsuperised parsing, we should do model(sents, argmax=True, use_mean = True) + # but we don't for eval since we want a valid upper bound on PPL for early stopping + # see eval.py for proper MAP inference + nll, kl, binary_matrix, argmax_spans = model(sents, argmax=True) + total_nll += nll.sum().item() + total_kl += kl.sum().item() + num_sents += batch_size + num_words += batch_size*(length +1) # we implicitly generate so we explicitly count it + + gold_tree = [] + for j in range(len(heads)): + gold_tree.append(get_span2head(gold_spans[j], heads[j], gold_actions=gold_actions[j], gold_tags=gold_tags[j])) + for span, (head, label) in gold_tree[j].items(): + if(span[0] == span[1]): + gold_tree[j][span] = (head, PT2ID[label]) + else: + f = lambda x : x[:x.find('-')] if x.find('-') != -1 else x + g = lambda y : y[:y.find('=')] if y.find('=') != -1 else y + gold_tree[j][span] = (head, f(g(label))) + + for b in range(batch_size): + # for a in argmax_spans[b]: + # if((a[0], a[1]) in span_dicts[b]): + # f.write("{}\t{}\n".format(a[2], span_dicts[b][(a[0], a[1])])) + + span_b = [(a[0], a[1]) for a in argmax_spans[b] if a[0] != a[1]] #ignore labels + span_b_set = set(span_b[:-1]) + gold_b_set = set(gold_spans[b][:-1]) + tp, fp, fn = get_stats(span_b_set, gold_b_set) + corpus_f1[0] += tp + corpus_f1[1] += fp + corpus_f1[2] += fn + tp_by_cat, all_by_cat = get_stats_by_cat(span_b_set, gold_b_set, gold_tree[b]) + for j in tp_by_cat: + corpus_f1_by_cat[0][j] += tp_by_cat[j] + for j in all_by_cat: + corpus_f1_by_cat[1][j] += all_by_cat[j] + # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py + + model_out = span_b_set + std_out = gold_b_set + overlap = model_out.intersection(std_out) + prec = float(len(overlap)) / (len(model_out) + 1e-8) + reca = float(len(overlap)) / (len(std_out) + 1e-8) + if len(std_out) == 0: + reca = 1. + if len(model_out) == 0: + prec = 1. + f1 = 2 * prec * reca / (prec + reca + 1e-8) + sent_f1.append(f1) + + if(args.evaluate_dep): + update_dep_stats(argmax_spans[b], heads[b], dep_stats) + tp, fp, fn = corpus_f1 + prec = tp / (tp + fp) + recall = tp / (tp + fn) + corpus_f1 = 2*prec*recall/(prec+recall) if prec+recall > 0 else 0. + for j in corpus_f1_by_cat[1]: + corpus_f1_by_cat[2][j] = corpus_f1_by_cat[0] / corpus_f1_by_cat[1] + sent_f1 = np.mean(np.array(sent_f1)) + dir_acc, undir_acc = get_dep_acc(dep_stats) if args.evaluate_dep else (0., 0.) 
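+  # summary metrics: reconstruction perplexity, per-sentence KL, and the ELBO-based perplexity upper bound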
+  recon_ppl = np.exp(total_nll / num_words)
+  ppl_elbo = np.exp((total_nll + total_kl) / num_words)
+  kl = total_kl / num_sents
+  print('ReconPPL: %.2f, KL: %.4f, NLLloss: %.4f, PPL (Upper Bound): %.2f' %
+        (recon_ppl, kl, total_nll / num_words, ppl_elbo))
+  print('Corpus F1: %.2f, Sentence F1: %.2f' %
+        (corpus_f1*100, sent_f1*100))
+  if(args.evaluate_dep):
+    print('DirAcc: %.2f, UndirAcc: %.2f' % (dir_acc, undir_acc))
+    print('Corpus Recall by Category: {}'.format(corpus_f1_by_cat[2]))
+  # tensorboard
+  add_scalars(main_tag="validation",
+              tag_scalar_dict={"ReconPPL": recon_ppl,
+                               "KL": kl,
+                               "PPL (Upper Bound)": ppl_elbo,
+                               "Corpus F1": corpus_f1 * 100,
+                               "Sentence F1": sent_f1 * 100,
+                               "DirAcc": dir_acc if args.evaluate_dep else 0,
+                               "UndirAcc": undir_acc if args.evaluate_dep else 0},
+              global_step=global_step)
+  model.train()
+  return ppl_elbo, sent_f1*100
+
+if __name__ == '__main__':
+  main(args)
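For reference, a minimal shape-check sketch (not part of this patch) for the unlexicalized PCFG module above. It assumes PCFG.py from this diff is importable and feeds random, unnormalized log-potentials with the shapes documented in _inside/_viterbi; log_Z is therefore just the log-partition of those potentials, not a trained model's likelihood, and the file name and sizes below are arbitrary.

# shape_check.py -- illustrative only, not included in the diff above
import torch
from PCFG import PCFG

B, n, NT, T = 2, 5, 3, 4                    # batch size, sentence length, nonterminals, preterminals
pcfg = PCFG(nt_states=NT, t_states=T)

unary = torch.randn(B, n, T)                # b x n x T                 (preterminal emission scores)
rules = torch.randn(B, NT, NT + T, NT + T)  # b x NT x (NT+T) x (NT+T)  (binary rule scores)
roots = torch.randn(B, NT)                  # b x NT                    (root scores)

log_Z = pcfg._inside(unary, rules, roots)   # shape (B,): log-partition over all binary trees
scores, argmax, spans = pcfg._viterbi(unary, rules, roots)
print(log_Z.shape, spans[0])                # spans[0] lists (left, right, label) for the Viterbi tree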