diff --git a/chsmm.py b/chsmm.py
index 2bf42eb..7e768ef 100644
--- a/chsmm.py
+++ b/chsmm.py
@@ -10,7 +10,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable

 import labeled_data
 from utils import logsumexp1, make_fwd_constr_idxs, make_bwd_constr_idxs, backtrace3, backtrace
 from data.utils import get_wikibio_poswrds, get_e2e_poswrds
@@ -39,7 +38,7 @@ def __init__(self, wordtypes, gentypes, opt):
         self.yes_self_trans = opt.yes_self_trans
         if not self.yes_self_trans:
             selfmask = torch.Tensor(opt.K*opt.Kmul).fill_(-float("inf"))
-            self.register_buffer('selfmask', Variable(torch.diag(selfmask), requires_grad=False))
+            self.register_buffer('selfmask', torch.diag(selfmask))

         self.max_pool = opt.max_pool
         self.emb_size, self.layers, self.hid_size = opt.emb_size, opt.layers, opt.hid_size
@@ -203,7 +202,7 @@ def len_logprobs(self):
             len_scores = self.len_scores.expand(K, self.L)
         else:
             len_scores = self.len_decoder(state_embs) # K x L
-        lplist = [Variable(len_scores.data.new(1, K).zero_())]
+        lplist = [len_scores.data.new(1, K).zero_()]
         for l in xrange(2, self.L+1):
             lplist.append(self.lsm(len_scores.narrow(1, 0, l)).t())
         return lplist, len_scores
@@ -299,7 +298,7 @@ def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz):
         # bsz x dim -> bsz x seqlen x dim -> bsz*seqlen x dim -> layers x bsz*seqlen x dim
         inits = self.h0_lin(srcenc) # bsz x 2*dim
         h0, c0 = inits[:, :rnn_size], inits[:, rnn_size:] # (bsz x dim, bsz x dim)
-        h0 = F.tanh(h0).unsqueeze(1).expand(bsz, seqlen, rnn_size).contiguous().view(
+        h0 = torch.tanh(h0).unsqueeze(1).expand(bsz, seqlen, rnn_size).contiguous().view(
             -1, rnn_size).unsqueeze(0).expand(layers, -1, rnn_size).contiguous()
         c0 = c0.unsqueeze(1).expand(bsz, seqlen, rnn_size).contiguous().view(
             -1, rnn_size).unsqueeze(0).expand(layers, -1, rnn_size).contiguous()
@@ -325,7 +324,7 @@ def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz):
                 Lp1, bsz, seqlen, -1)
             # L+1 x bsz x seqlen x rnn_size -> bsz x (L+1)seqlen x rnn_size
             attnin1 = attnin1.transpose(0, 1).contiguous().view(bsz, Lp1*seqlen, -1)
-            attnin1 = F.tanh(attnin1)
+            attnin1 = torch.tanh(attnin1)
             ascores = torch.bmm(attnin1, srcfieldenc.transpose(1, 2)) # bsz x (L+1)slen x nfield
             ascores = ascores + fieldmask.unsqueeze(1).expand_as(ascores)
             aprobs = F.softmax(ascores, dim=2)
@@ -337,7 +336,7 @@ def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz):
             out_hid_sz = rnn_size + encdim
             cat_ctx = cat_ctx.view(Lp1, -1, out_hid_sz)
             # now linear to get L+1 x bsz*seqlen x rnn_size
-            states_k = F.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx)
+            states_k = torch.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx)
                               + self.state_out_biases[k].expand_as(cat_ctx)).view(
                                   Lp1, -1, out_hid_sz)
@@ -347,7 +346,7 @@ def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz):
                     Lp1, bsz, seqlen, -1)
                 # L+1 x bsz x seqlen x rnn_size -> bsz x (L+1)seqlen x emb_size
                 attnin2 = attnin2.transpose(0, 1).contiguous().view(bsz, Lp1*seqlen, -1)
-                attnin2 = F.tanh(attnin2)
+                attnin2 = torch.tanh(attnin2)
                 ascores = torch.bmm(attnin2, srcfieldenc.transpose(1, 2)) # bsz x (L+1)slen x nfield
                 ascores = ascores + fieldmask.unsqueeze(1).expand_as(ascores)
@@ -356,7 +355,7 @@ def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz):
                 ascores.view(bsz, Lp1, seqlen, nfields).transpose(
                     0, 1).contiguous().view(-1, nfields)], 1), dim=1)
             # concatenate on dummy column for when only a single answer...
-            wlps_k = torch.cat([wlps_k, Variable(self.zeros.expand(wlps_k.size(0), 1))], 1)
+            wlps_k = torch.cat([wlps_k, self.zeros.expand(wlps_k.size(0), 1)], 1)
             # get scores for predicted next-words (but not for last words in each segment as usual)
             psk = wlps_k.narrow(0, 0, self.L*bszsl).gather(1, combotargs.view(self.L*bszsl, -1))
             if self.lse_obj:
@@ -403,7 +402,7 @@ def encode(self, src, avgmask, uniqfields):
             masked = embs.view(bsz, nfields, emb_size)
             srcenc = F.max_pool1d(masked.transpose(1, 2), nfields).squeeze(2) # bsz x emb_size
         else:
-            embs = F.tanh(embs.sum(1) + self.src_bias.expand(bsz*nfields, emb_size))
+            embs = torch.tanh(embs.sum(1) + self.src_bias.expand(bsz*nfields, emb_size))
             # average it manually, bleh
             if avgmask is not None:
                 srcenc = (embs.view(bsz, nfields, emb_size)
@@ -433,18 +432,18 @@ def get_next_word_dist(self, hid, k, srcfieldenc):
         srcfldenc = srcfieldenc.expand(bsz, nfields, rnn_size)
         attnin1 = (hid * self.state_att_gates[k].expand_as(hid)
                    + self.state_att_biases[k].expand_as(hid)) # 1 x bsz x rnn_size
-        attnin1 = F.tanh(attnin1)
+        attnin1 = torch.tanh(attnin1)
         ascores = torch.bmm(attnin1.transpose(0, 1), srcfldenc.transpose(1, 2)) # bsz x 1 x nfields
         aprobs = F.softmax(ascores, dim=2)
         ctx = torch.bmm(aprobs, srcfldenc) # bsz x 1 x rnn_size
         cat_ctx = torch.cat([hid, ctx.transpose(0, 1)], 2) # 1 x bsz x rnn_size
-        state_k = F.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx)
+        state_k = torch.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx)
                          + self.state_out_biases[k].expand_as(cat_ctx)) # 1 x bsz x rnn_size
         if self.sep_attn:
             attnin2 = (hid * self.state_att2_gates[k].expand_as(hid)
                        + self.state_att2_biases[k].expand_as(hid))
-            attnin2 = F.tanh(attnin2)
+            attnin2 = torch.tanh(attnin2)
             ascores = torch.bmm(attnin2.transpose(0, 1), srcfldenc.transpose(1, 2)) # bsz x 1 x nfld
         wlps_k = F.softmax(torch.cat([self.decoder(state_k.squeeze(0)),
@@ -502,79 +501,80 @@ def temp_bs(self, corpus, ss, start_inp, exh0, exc0, srcfieldenc,
         # N.B. we assume we have a single feature row for each timestep rather than avg
         # over them as at training time. probably better, but could conceivably average like
         # at training time.
-        inps = Variable(torch.LongTensor(K, 4), volatile=True)
-        for ell in xrange(self.L):
+        with torch.no_grad():
+            inps = torch.LongTensor(K, 4)
+            for ell in xrange(self.L):
+                wrd_dist = self.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
+                # disallow unks
+                wrd_dist[:, unk_idx].zero_()
+                if not final_state:
+                    wrd_dist[:, eos_idx].zero_()
+                self.collapse_word_probs(row2tblent, wrd_dist, corpus)
+                wrd_dist.log_()
+                if ell > 0: # add previous scores
+                    wrd_dist.add_(curr_scores.expand_as(wrd_dist))
+                maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
+                cols = wrd_dist.size(1)
+                # we'll break as soon as <eos> is at the top of the beam.
+                # this ignores <eop> but whatever
+                if top2k[0] == eos_idx:
+                    final_hyp = backtrace(curr_hyps[0])
+                    final_hyp.append(eos_idx)
+                    return final_hyp, maxprobs[0], len_lps[ss][ell]
+
+                new_hyps, anc_hs, anc_cs = [], [], []
+                #inps.data.fill_(pad_idx)
+                inps.data[:, 1].fill_(w2i["<ncf1>"])
+                inps.data[:, 2].fill_(w2i["<ncf2>"])
+                inps.data[:, 3].fill_(w2i["<ncf3>"])
+                for k in xrange(2*K):
+                    anc, wrd = top2k[k] / cols, top2k[k] % cols
+                    # check if any of the maxes are eop
+                    if wrd == self.eop_idx and ell > 0:
+                        # add len score (and avg over num words incl eop i guess)
+                        wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1]
+                        if wlenscore > best_hyp_score:
+                            best_hyp_score = wlenscore
+                            best_hyp = backtrace(curr_hyps[anc])
+                            best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1]
+                    else:
+                        curr_scores[len(new_hyps)][0] = maxprobs[k]
+                        if wrd >= self.decoder.out_features: # a copy
+                            tblidx = wrd - self.decoder.out_features
+                            inps.data[len(new_hyps)].copy_(row2feats[tblidx.item()])
+                        else:
+                            inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
+                        new_hyps.append((wrd, curr_hyps[anc]))
+                        anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                        anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                    if len(new_hyps) == K:
+                        break
+                assert len(new_hyps) == K
+                curr_hyps = new_hyps
+                if self.lut.weight.data.is_cuda:
+                    inps = inps.cuda()
+                embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
+                if self.mlpinp:
+                    embs = self.inpmlp(embs) # 1 x K x rnninsz
+                if self.one_rnn:
+                    cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2)
+                    hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
+                else:
+                    hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
+            # hypotheses of length L still need their end probs added
+            # N.B. if the <eos> falls off the beam we could end up with situations
+            # where we take an L-length phrase w/ a lower score than 1-word followed by eos.
             wrd_dist = self.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
-            # disallow unks
-            wrd_dist[:, unk_idx].zero_()
-            if not final_state:
-                wrd_dist[:, eos_idx].zero_()
-            self.collapse_word_probs(row2tblent, wrd_dist, corpus)
             wrd_dist.log_()
-            if ell > 0: # add previous scores
-                wrd_dist.add_(curr_scores.expand_as(wrd_dist))
-            maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
-            cols = wrd_dist.size(1)
-            # we'll break as soon as <eos> is at the top of the beam.
-            # this ignores <eop> but whatever
-            if top2k[0] == eos_idx:
-                final_hyp = backtrace(curr_hyps[0])
-                final_hyp.append(eos_idx)
-                return final_hyp, maxprobs[0], len_lps[ss][ell]
-
-            new_hyps, anc_hs, anc_cs = [], [], []
-            #inps.data.fill_(pad_idx)
-            inps.data[:, 1].fill_(w2i["<ncf1>"])
-            inps.data[:, 2].fill_(w2i["<ncf2>"])
-            inps.data[:, 3].fill_(w2i["<ncf3>"])
-            for k in xrange(2*K):
-                anc, wrd = top2k[k] / cols, top2k[k] % cols
-                # check if any of the maxes are eop
-                if wrd == self.eop_idx and ell > 0:
-                    # add len score (and avg over num words incl eop i guess)
-                    wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1]
-                    if wlenscore > best_hyp_score:
-                        best_hyp_score = wlenscore
-                        best_hyp = backtrace(curr_hyps[anc])
-                        best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1]
-                else:
-                    curr_scores[len(new_hyps)][0] = maxprobs[k]
-                    if wrd >= self.decoder.out_features: # a copy
-                        tblidx = wrd - self.decoder.out_features
-                        inps.data[len(new_hyps)].copy_(row2feats[tblidx])
-                    else:
-                        inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
-                    new_hyps.append((wrd, curr_hyps[anc]))
-                    anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                    anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                if len(new_hyps) == K:
-                    break
-            assert len(new_hyps) == K
-            curr_hyps = new_hyps
-            if self.lut.weight.data.is_cuda:
-                inps = inps.cuda()
-            embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
-            if self.mlpinp:
-                embs = self.inpmlp(embs) # 1 x K x rnninsz
-            if self.one_rnn:
-                cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2)
-                hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
-            else:
-                hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
-        # hypotheses of length L still need their end probs added
-        # N.B. if the <eos> falls off the beam we could end up with situations
-        # where we take an L-length phrase w/ a lower score than 1-word followed by eos.
-        wrd_dist = self.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
-        wrd_dist.log_()
-        wrd_dist.add_(curr_scores.expand_as(wrd_dist))
-        for k in xrange(K):
-            wlenscore = wrd_dist[k][self.eop_idx]/(self.L+1) + len_lps[ss][self.L-1]
-            if wlenscore > best_hyp_score:
-                best_hyp_score = wlenscore
-                best_hyp = backtrace(curr_hyps[k])
-                best_wscore, best_lscore = wrd_dist[k][self.eop_idx], len_lps[ss][self.L-1]
-
-        return best_hyp, best_wscore, best_lscore
+            wrd_dist.add_(curr_scores.expand_as(wrd_dist))
+            for k in xrange(K):
+                wlenscore = wrd_dist[k][self.eop_idx]/(self.L+1) + len_lps[ss][self.L-1]
+                if wlenscore > best_hyp_score:
+                    best_hyp_score = wlenscore
+                    best_hyp = backtrace(curr_hyps[k])
+                    best_wscore, best_lscore = wrd_dist[k][self.eop_idx], len_lps[ss][self.L-1]
+
+        return best_hyp, best_wscore, best_lscore


     def gen_one(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats):
@@ -604,7 +604,7 @@ def gen_one(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats):
                     phrs.append(i2w[phrs_idxs[ii]])
                 else:
                     tblidx = phrs_idxs[ii] - nout_wrds
-                    _, _, wordstr = row2tblent[tblidx]
+                    _, _, wordstr = row2tblent[tblidx.item()]
                     if args.verbose:
                         phrs.append(wordstr + " (c)")
                     else:
@@ -633,140 +633,142 @@ def temp_ar_bs(self, templt, row2tblent, row2feats, h0, c0, srcfieldenc, len_lps
         curr_hyps = [(None, None, None)]
         nfeats = 4
-        inps = Variable(torch.LongTensor(K, nfeats), volatile=True)
-        curr_scores, curr_lens, nulens = torch.zeros(K, 1), torch.zeros(K, 1), torch.zeros(K, 1)
-        if self.lut.weight.data.is_cuda:
-            inps = inps.cuda()
-            curr_scores, curr_lens, nulens = curr_scores.cuda(), curr_lens.cuda(), nulens.cuda()
-
-        # start ar rnn; hackily use bos_idx
-        rnnsz = self.ar_rnn.hidden_size
-        thid, (thc, tcc) = self.ar_rnn(self.lut.weight[2].view(1, 1, -1)) # 1 x 1 x rnn_size
-
-        for stidx, ss in enumerate(templt):
-            final_state = (stidx == len(templt) - 1)
-            minq = [] # so we can compare stuff of different lengths
-            rul_ss = ss % self.K
+
+        with torch.no_grad():
+            inps = torch.LongTensor(K, nfeats)
+            curr_scores, curr_lens, nulens = torch.zeros(K, 1), torch.zeros(K, 1), torch.zeros(K, 1)
+            if self.lut.weight.data.is_cuda:
+                inps = inps.cuda()
+                curr_scores, curr_lens, nulens = curr_scores.cuda(), curr_lens.cuda(), nulens.cuda()
-            if self.one_rnn:
-                cond_start_inp = torch.cat([start_inp, self.state_embs[rul_ss]], 2) # 1x1x cat_size
-                hid, (hc, cc) = self.seg_rnns[0](cond_start_inp, (exh0, exc0)) # 1 x 1 x rnn_size
-            else:
-                hid, (hc, cc) = self.seg_rnns[rul_ss](start_inp, (exh0, exc0)) # 1 x 1 x rnn_size
-            hid = hid.expand_as(thid)
-            hc = hc.expand_as(thc)
-            cc = cc.expand_as(tcc)
+            # start ar rnn; hackily use bos_idx
+            rnnsz = self.ar_rnn.hidden_size
+            thid, (thc, tcc) = self.ar_rnn(self.lut.weight[2].view(1, 1, -1)) # 1 x 1 x rnn_size
-            for ell in xrange(self.L+1):
-                new_hyps, anc_hs, anc_cs, anc_ths, anc_tcs = [], [], [], [], []
-                inps.data[:, 1].fill_(w2i["<ncf1>"])
-                inps.data[:, 2].fill_(w2i["<ncf2>"])
-                inps.data[:, 3].fill_(w2i["<ncf3>"])
+            for stidx, ss in enumerate(templt):
+                final_state = (stidx == len(templt) - 1)
+                minq = [] # so we can compare stuff of different lengths
+                rul_ss = ss % self.K
-                wrd_dist = self.get_next_word_dist(hid + thid, rul_ss, srcfieldenc) # K x nwords
-                currK = wrd_dist.size(0)
-                # disallow unks and eos's
-                wrd_dist[:, unk_idx].zero_()
-                if not final_state:
-                    wrd_dist[:, eos_idx].zero_()
-                self.collapse_word_probs(row2tblent, wrd_dist, corpus)
-                wrd_dist.log_()
-                curr_scores[:currK].mul_(curr_lens[:currK])
-                wrd_dist.add_(curr_scores[:currK].expand_as(wrd_dist))
-                wrd_dist.div_((curr_lens[:currK]+1).expand_as(wrd_dist))
-                maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
-                cols = wrd_dist.size(1)
-                # used to check for eos here, but maybe we shouldn't
+                if self.one_rnn:
+                    cond_start_inp = torch.cat([start_inp, self.state_embs[rul_ss]], 2) # 1x1x cat_size
+                    hid, (hc, cc) = self.seg_rnns[0](cond_start_inp, (exh0, exc0)) # 1 x 1 x rnn_size
+                else:
+                    hid, (hc, cc) = self.seg_rnns[rul_ss](start_inp, (exh0, exc0)) # 1 x 1 x rnn_size
+                hid = hid.expand_as(thid)
+                hc = hc.expand_as(thc)
+                cc = cc.expand_as(tcc)
+
+                for ell in xrange(self.L+1):
+                    new_hyps, anc_hs, anc_cs, anc_ths, anc_tcs = [], [], [], [], []
+                    inps.data[:, 1].fill_(w2i["<ncf1>"])
+                    inps.data[:, 2].fill_(w2i["<ncf2>"])
+                    inps.data[:, 3].fill_(w2i["<ncf3>"])
+
+                    wrd_dist = self.get_next_word_dist(hid + thid, rul_ss, srcfieldenc) # K x nwords
+                    currK = wrd_dist.size(0)
+                    # disallow unks and eos's
+                    wrd_dist[:, unk_idx].zero_()
+                    if not final_state:
+                        wrd_dist[:, eos_idx].zero_()
+                    self.collapse_word_probs(row2tblent, wrd_dist, corpus)
+                    wrd_dist.log_()
+                    curr_scores[:currK].mul_(curr_lens[:currK])
+                    wrd_dist.add_(curr_scores[:currK].expand_as(wrd_dist))
+                    wrd_dist.div_((curr_lens[:currK]+1).expand_as(wrd_dist))
+                    maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
+                    cols = wrd_dist.size(1)
+                    # used to check for eos here, but maybe we shouldn't
+
+                    for k in xrange(2*K):
+                        anc, wrd = top2k[k] / cols, top2k[k] % cols
+                        # check if any of the maxes are eop
+                        if wrd == self.eop_idx and ell > 0 and (not final_state or curr_hyps[anc][0] == eos_idx):
+                            ## add len score (and avg over num words *incl eop*)
+                            ## actually ignoring len score for now
+                            #wlenscore = maxprobs[k]/(ell+1) # + len_lps[ss][ell-1]
+                            #assert not final_state or curr_hyps[anc][0] == eos_idx # seems like should hold...
+                            heapitem = (maxprobs[k], curr_lens[anc][0]+1, curr_hyps[anc],
+                                        thc.narrow(1, anc, 1), tcc.narrow(1, anc, 1))
+                            if len(minq) < K:
+                                heapq.heappush(minq, heapitem)
+                            else:
+                                heapq.heappushpop(minq, heapitem)
+                        elif ell < self.L: # only allow non-eop if < L so far
+                            curr_scores[len(new_hyps)][0] = maxprobs[k]
+                            nulens[len(new_hyps)][0] = curr_lens[anc][0]+1
+                            if wrd >= self.decoder.out_features: # a copy
+                                tblidx = wrd - self.decoder.out_features
+                                inps.data[len(new_hyps)].copy_(row2feats[tblidx.item()])
+                            else:
+                                inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
-                for k in xrange(2*K):
-                    anc, wrd = top2k[k] / cols, top2k[k] % cols
-                    # check if any of the maxes are eop
-                    if wrd == self.eop_idx and ell > 0 and (not final_state or curr_hyps[anc][0] == eos_idx):
-                        ## add len score (and avg over num words *incl eop*)
-                        ## actually ignoring len score for now
-                        #wlenscore = maxprobs[k]/(ell+1) # + len_lps[ss][ell-1]
-                        #assert not final_state or curr_hyps[anc][0] == eos_idx # seems like should hold...
-                        heapitem = (maxprobs[k], curr_lens[anc][0]+1, curr_hyps[anc],
-                                    thc.narrow(1, anc, 1), tcc.narrow(1, anc, 1))
-                        if len(minq) < K:
-                            heapq.heappush(minq, heapitem)
-                        else:
-                            heapq.heappushpop(minq, heapitem)
-                    elif ell < self.L: # only allow non-eop if < L so far
-                        curr_scores[len(new_hyps)][0] = maxprobs[k]
-                        nulens[len(new_hyps)][0] = curr_lens[anc][0]+1
-                        if wrd >= self.decoder.out_features: # a copy
-                            tblidx = wrd - self.decoder.out_features
-                            inps.data[len(new_hyps)].copy_(row2feats[tblidx])
-                        else:
-                            inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
+                            new_hyps.append((wrd, ss, curr_hyps[anc]))
+                            anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                            anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                            anc_ths.append(thc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                            anc_tcs.append(tcc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                            if len(new_hyps) == K:
+                                break
-                        new_hyps.append((wrd, ss, curr_hyps[anc]))
-                        anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                        anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                        anc_ths.append(thc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                        anc_tcs.append(tcc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                        if len(new_hyps) == K:
+                    if ell >= self.L: # don't want to put in eops
                         break
-                if ell >= self.L: # don't want to put in eops
-                    break
+                    assert len(new_hyps) == K
+                    curr_hyps = new_hyps
+                    curr_lens.copy_(nulens)
+                    embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
+                    if self.word_ar:
+                        ar_embs = embs.view(1, K, nfeats, -1)[:, :, 0] # 1 x K x emb_size
+                    else: # ar on fields
+                        ar_embs = embs.view(1, K, nfeats, -1)[:, :, 1] # 1 x K x emb_size
+                    if self.mlpinp:
+                        embs = self.inpmlp(embs) # 1 x K x rnninsz
+                    if self.one_rnn:
+                        cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand(
+                            1, K, state_emb_sz)], 2)
+                        hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1),
+                                                                     torch.cat(anc_cs, 1)))
+                    else:
+                        hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1),
+                                                                     torch.cat(anc_cs, 1)))
+                    thid, (thc, tcc) = self.ar_rnn(ar_embs, (torch.cat(anc_ths, 1),
+                                                             torch.cat(anc_tcs, 1)))
+
+                # retrieve topk for this segment (in reverse order)
+                seghyps = [heapq.heappop(minq) for _ in xrange(len(minq))]
+                if len(seghyps) == 0:
+                    return -float("inf"), None
-                assert len(new_hyps) == K
-                curr_hyps = new_hyps
-                curr_lens.copy_(nulens)
-                embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
-                if self.word_ar:
-                    ar_embs = embs.view(1, K, nfeats, -1)[:, :, 0] # 1 x K x emb_size
-                else: # ar on fields
-                    ar_embs = embs.view(1, K, nfeats, -1)[:, :, 1] # 1 x K x emb_size
-                if self.mlpinp:
-                    embs = self.inpmlp(embs) # 1 x K x rnninsz
-                if self.one_rnn:
-                    cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand(
-                        1, K, state_emb_sz)], 2)
-                    hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1),
-                                                                 torch.cat(anc_cs, 1)))
-                else:
-                    hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1),
-                                                                 torch.cat(anc_cs, 1)))
-                thid, (thc, tcc) = self.ar_rnn(ar_embs, (torch.cat(anc_ths, 1),
-                                                         torch.cat(anc_tcs, 1)))
-
-            # retrieve topk for this segment (in reverse order)
-            seghyps = [heapq.heappop(minq) for _ in xrange(len(minq))]
-            if len(seghyps) == 0:
-                return -float("inf"), None
-
-            if len(seghyps) < K and not final_state:
-                # haaaaaaaaaaaaaaack
-                ugh = []
-                for ick in xrange(K-len(seghyps)):
-                    scoreick, lenick, hypick, thcick, tccick = seghyps[0]
-                    ugh.append((scoreick - 9999999.0 + ick, lenick, hypick, thcick, tccick))
-                # break ties for the comparison
-                ugh.extend(seghyps)
-                seghyps = ugh
-
-            #assert final_state or len(seghyps) == K
-
-            if final_state:
-                if len(seghyps) > 0:
-                    scoreb, lenb, hypb, thcb, tccb = seghyps[-1]
-                    return scoreb, backtrace3(hypb)
+                if len(seghyps) < K and not final_state:
+                    # haaaaaaaaaaaaaaack
+                    ugh = []
+                    for ick in xrange(K-len(seghyps)):
+                        scoreick, lenick, hypick, thcick, tccick = seghyps[0]
+                        ugh.append((scoreick - 9999999.0 + ick, lenick, hypick, thcick, tccick))
+                    # break ties for the comparison
+                    ugh.extend(seghyps)
+                    seghyps = ugh
+
+                #assert final_state or len(seghyps) == K
+
+                if final_state:
+                    if len(seghyps) > 0:
+                        scoreb, lenb, hypb, thcb, tccb = seghyps[-1]
+                        return scoreb, backtrace3(hypb)
+                    else:
+                        return -float("inf"), None
                 else:
-                    return -float("inf"), None
-            else:
-                thidlst, thclst, tcclst = [], [], []
-                for i in xrange(K):
-                    scorei, leni, hypi, thci, tcci = seghyps[K-i-1]
-                    curr_scores[i][0], curr_lens[i][0], curr_hyps[i] = scorei, leni, hypi
-                    thidlst.append(thci[-1:, :, :]) # each is 1 x 1 x rnn_size
-                    thclst.append(thci) # each is layers x 1 x rnn_size
-                    tcclst.append(tcci) # each is layers x 1 x rnn_size
+                    thidlst, thclst, tcclst = [], [], []
+                    for i in xrange(K):
+                        scorei, leni, hypi, thci, tcci = seghyps[K-i-1]
+                        curr_scores[i][0], curr_lens[i][0], curr_hyps[i] = scorei, leni, hypi
+                        thidlst.append(thci[-1:, :, :]) # each is 1 x 1 x rnn_size
+                        thclst.append(thci) # each is layers x 1 x rnn_size
+                        tcclst.append(tcci) # each is layers x 1 x rnn_size
-                # we already have the state for the next word b/c we put it thru to also predict eop
-                thid, (thc, tcc) = torch.cat(thidlst, 1), (torch.cat(thclst, 1), torch.cat(tcclst, 1))
+                    # we already have the state for the next word b/c we put it thru to also predict eop
+                    thid, (thc, tcc) = torch.cat(thidlst, 1), (torch.cat(thclst, 1), torch.cat(tcclst, 1))


     def gen_one_ar(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats):
@@ -798,7 +800,7 @@ def gen_one_ar(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats
                     phrs.append(i2w[widx])
                 else:
                     tblidx = widx - nout_wrds
-                    _, _, wordstr = row2tblent[tblidx]
+                    _, _, wordstr = row2tblent[tblidx.item()]
                     if args.verbose:
                         phrs.append(wordstr + " (c)")
                     else:
@@ -955,7 +957,12 @@ def make_masks(src, pad_idx, max_pool=False):
 saved_args, saved_state = None, None
 if len(args.load) > 0:
-    saved_stuff = torch.load(args.load)
+    saved_stuff = ''
+    if args.cuda:
+        saved_stuff = torch.load(args.load)
+    else:
+        saved_stuff = torch.load(args.load, map_location=torch.device('cpu'))
+        saved_stuff["opt"].cuda = False
     saved_args, saved_state = saved_stuff["opt"], saved_stuff["state_dict"]
     for k, v in args.__dict__.iteritems():
         if k not in saved_args.__dict__:
@@ -1021,20 +1028,20 @@ def train(epoch):
             fmask, amask = fmask.cuda(), amask.cuda()
             uniqfields = uniqfields.cuda()

-        srcenc, srcfieldenc, uniqenc = net.encode(Variable(src), Variable(amask), # bsz x hid
-                                                  Variable(uniqfields))
+        srcenc, srcfieldenc, uniqenc = net.encode(src, amask, # bsz x hid
+                                                  uniqfields)
         init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
         len_logprobs, _ = net.len_logprobs()
-        fwd_obs_logps = net.obs_logprobs(Variable(inps), srcenc, srcfieldenc, Variable(fmask),
-                                         Variable(combotargs), bsz) # L x T x bsz x K
+        fwd_obs_logps = net.obs_logprobs(inps, srcenc, srcfieldenc, fmask,
+                                         combotargs, bsz) # L x T x bsz x K
         # get T+1 x bsz x K beta quantities
         beta, beta_star = infc.just_bwd(trans_logps, fwd_obs_logps,
                                         len_logprobs, constraints=cidxs)
         log_marg = logsumexp1(beta_star[0] + init_logps).sum() # bsz x 1 -> 1
-        neglogev -= log_marg.data[0]
+        neglogev -= log_marg.data.item()
         lossvar = -log_marg/bsz
         lossvar.backward()
-        torch.nn.utils.clip_grad_norm(net.parameters(), args.clip)
+        torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip)
         if optalg is not None:
             optalg.step()
         else:
@@ -1079,23 +1086,22 @@ def test(epoch):
             inps = inps.cuda()
             fmask, amask = fmask.cuda(), amask.cuda()
             uniqfields = uniqfields.cuda()
-
-        srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid
-                                                  Variable(amask, volatile=True),
-                                                  Variable(uniqfields, volatile=True))
-        init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
-        len_logprobs, _ = net.len_logprobs()
-        fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc,
-                                         srcfieldenc, Variable(fmask, volatile=True),
-                                         Variable(combotargs, volatile=True),
-                                         bsz) # L x T x bsz x K
-
-        # get T+1 x bsz x K beta quantities
-        beta, beta_star = infc.just_bwd(trans_logps, fwd_obs_logps,
-                                        len_logprobs, constraints=cidxs)
-        log_marg = logsumexp1(beta_star[0] + init_logps).sum() # bsz x 1 -> 1
-        neglogev -= log_marg.data[0]
-        nsents += bsz
+
+        with torch.no_grad():
+            srcenc, srcfieldenc, uniqenc = net.encode(src, # bsz x hid
+                                                      amask, uniqfields)
+            init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
+            len_logprobs, _ = net.len_logprobs()
+            fwd_obs_logps = net.obs_logprobs(inps, srcenc,
+                                             srcfieldenc, fmask,
+                                             combotargs, bsz) # L x T x bsz x K
+
+            # get T+1 x bsz x K beta quantities
+            beta, beta_star = infc.just_bwd(trans_logps, fwd_obs_logps,
+                                            len_logprobs, constraints=cidxs)
+            log_marg = logsumexp1(beta_star[0] + init_logps).sum() # bsz x 1 -> 1
+            neglogev -= log_marg.data.item()
+            nsents += bsz
     print "epoch %d | valid ev %g" % (epoch, neglogev/nsents)
     return neglogev/nsents
@@ -1125,22 +1131,23 @@ def label_train():
             fmask, amask = fmask.cuda(), amask.cuda()
             uniqfields = uniqfields.cuda()

-        srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid
-                                                  Variable(amask, volatile=True),
-                                                  Variable(uniqfields, volatile=True))
-        init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
-        len_logprobs, _ = net.len_logprobs()
-        fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc,
-                                         srcfieldenc, Variable(fmask, volatile=True),
-                                         Variable(combotargs, volatile=True), bsz) # LxTxbsz x K
-        bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data)
-        seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs,
-                            [t.data for t in len_logprobs], constraints=fwd_cidxs)
-        for b in xrange(bsz):
-            words = [corpus.dictionary.idx2word[w] for w in x[:, b]]
-            for (start, end, label) in seqs[b]:
-                print "%s|%d" % (" ".join(words[start:end]), label),
-            print
+        with torch.no_grad():
+            srcenc, srcfieldenc, uniqenc = net.encode(src, # bsz x hid
+                                                      amask,
+                                                      uniqfields)
+            init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
+            len_logprobs, _ = net.len_logprobs()
+            fwd_obs_logps = net.obs_logprobs(inps, srcenc,
+                                             srcfieldenc, fmask,
+                                             combotargs, bsz) # LxTxbsz x K
+            bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data)
+            seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs,
+                                [t.data for t in len_logprobs], constraints=fwd_cidxs)
+            for b in xrange(bsz):
+                words = [corpus.dictionary.idx2word[w] for w in x[:, b]]
+                for (start, end, label) in seqs[b]:
+                    print "%s|%d" % (" ".join(words[start:end]), label),
+                print

 def gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=None):
     net.ar = saved_args.ar_after_decay
@@ -1157,96 +1164,98 @@ def gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=None):
     if args.cuda:
         src_b = src_b.cuda()
         uniq_b = uniq_b.cuda()
+
+    with torch.no_grad():
+        srcenc, srcfieldenc, uniqenc = net.encode(src_b, None, uniq_b)
+        init_logps, trans_logps = net.trans_logprobs(uniqenc, 2)
+        _, len_scores = net.len_logprobs()
+        len_lps = net.lsm(len_scores).data.cpu()
+        init_logps, trans_logps = init_logps.data.cpu(), trans_logps.data[0].cpu()
+        inits = net.h0_lin(srcenc)
+        h0, c0 = torch.tanh(inits[:, :inits.size(1)/2]), inits[:, inits.size(1)/2:]
+
+        nfields = src_b.size(1)
+        row2tblent = {}
+        for ff in xrange(nfields):
+            field, idx = i2w[src_b[0][ff][0]], i2w[src_b[0][ff][1]]
+            if (field, idx) in src_tbl:
+                row2tblent[ff] = (field, idx, src_tbl[field, idx])
+            else:
+                row2tblent[ff] = (None, None, None)
+
+        # get row to input feats
+        row2feats = {}
+        # precompute wrd stuff
+        fld_cntr = Counter([key for key, _ in src_tbl])
+        for row, (k, idx, wrd) in row2tblent.iteritems():
+            if k in w2i:
+                widx = w2i[wrd] if wrd in w2i else w2i["<unk>"]
+                keyidx = w2i[k] if k in w2i else w2i["<unk>"]
+                idxidx = w2i[idx]
+                cheatfeat = w2i["<stop>"] if fld_cntr[k] == idx else w2i["<go>"]
+                #row2feats[row] = torch.LongTensor([keyidx, idxidx, cheatfeat])
+                row2feats[row] = torch.LongTensor([widx, keyidx, idxidx, cheatfeat])
+
+        constr_sat = False
+        # search over all templates
+        for templt in top_temps:
+            #print "templt is", templt
+            # get templt transition prob
+            tscores = [init_logps[0][templt[0]]]
+            [tscores.append(trans_logps[0][templt[tt-1]][templt[tt]])
+             for tt in xrange(1, len(templt))]
+
+            if net.ar:
+                phrases, wscore, tokes = net.gen_one_ar(templt, h0[0], c0[0], srcfieldenc,
+                                                        len_lps, row2tblent, row2feats)
+                rul_tokes = tokes
+            else:
+                phrases, wscore, lscore, tokes, segs = net.gen_one(templt, h0[0], c0[0],
+                    srcfieldenc, len_lps, row2tblent, row2feats)
+                rul_tokes = tokes - segs # subtract imaginary toke for each <eop>
+                wscore /= tokes
+            segs = len(templt)
+            if (rul_tokes < args.min_gen_tokes or segs < args.min_gen_states) and constr_sat:
+                continue
+            if rul_tokes >= args.min_gen_tokes and segs >= args.min_gen_states:
+                constr_sat = True # satisfied our constraint
+            tscore = sum(tscores[:int(segs)])/segs
+            if not net.unif_lenps:
+                tscore += lscore/segs
+
+            gscore = wscore
+            if not isinstance(wscore, float):
+                gscore = wscore.cpu()
+            ascore = coeffs[0]*tscore + coeffs[1]*gscore
+            if (constr_sat and ascore > best_score) or (not constr_sat and rul_tokes > best_len) or (not constr_sat and rul_tokes == best_len and ascore > best_score):
+                # take if improves score or not long enough yet and this is longer...
+                #if ascore > best_score: #or (not constr_sat and rul_tokes > best_len):
+                best_score, best_tscore, best_gscore = ascore, tscore, gscore
+                best_phrases, best_templt = phrases, templt
+                best_len = rul_tokes
+                #str_phrases = [" ".join(phrs) for phrs in phrases]
+                #tmpltd = ["%s|%d" % (phrs, templt[k]) for k, phrs in enumerate(str_phrases)]
+                #statstr = "a=%.2f t=%.2f g=%.2f" % (ascore, tscore, gscore)
+                #print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd)), statstr
+                #assert False
+        #assert False
-    srcenc, srcfieldenc, uniqenc = net.encode(Variable(src_b, volatile=True), None,
-                                              Variable(uniq_b, volatile=True))
-    init_logps, trans_logps = net.trans_logprobs(uniqenc, 2)
-    _, len_scores = net.len_logprobs()
-    len_lps = net.lsm(len_scores).data
-    init_logps, trans_logps = init_logps.data.cpu(), trans_logps.data[0].cpu()
-    inits = net.h0_lin(srcenc)
-    h0, c0 = F.tanh(inits[:, :inits.size(1)/2]), inits[:, inits.size(1)/2:]
-
-    nfields = src_b.size(1)
-    row2tblent = {}
-    for ff in xrange(nfields):
-        field, idx = i2w[src_b[0][ff][0]], i2w[src_b[0][ff][1]]
-        if (field, idx) in src_tbl:
-            row2tblent[ff] = (field, idx, src_tbl[field, idx])
-        else:
-            row2tblent[ff] = (None, None, None)
-
-    # get row to input feats
-    row2feats = {}
-    # precompute wrd stuff
-    fld_cntr = Counter([key for key, _ in src_tbl])
-    for row, (k, idx, wrd) in row2tblent.iteritems():
-        if k in w2i:
-            widx = w2i[wrd] if wrd in w2i else w2i["<unk>"]
-            keyidx = w2i[k] if k in w2i else w2i["<unk>"]
-            idxidx = w2i[idx]
-            cheatfeat = w2i["<stop>"] if fld_cntr[k] == idx else w2i["<go>"]
-            #row2feats[row] = torch.LongTensor([keyidx, idxidx, cheatfeat])
-            row2feats[row] = torch.LongTensor([widx, keyidx, idxidx, cheatfeat])
-
-    constr_sat = False
-    # search over all templates
-    for templt in top_temps:
-        #print "templt is", templt
-        # get templt transition prob
-        tscores = [init_logps[0][templt[0]]]
-        [tscores.append(trans_logps[0][templt[tt-1]][templt[tt]])
-         for tt in xrange(1, len(templt))]
-
-        if net.ar:
-            phrases, wscore, tokes = net.gen_one_ar(templt, h0[0], c0[0], srcfieldenc,
-                                                    len_lps, row2tblent, row2feats)
-            rul_tokes = tokes
-        else:
-            phrases, wscore, lscore, tokes, segs = net.gen_one(templt, h0[0], c0[0],
-                srcfieldenc, len_lps, row2tblent, row2feats)
-            rul_tokes = tokes - segs # subtract imaginary toke for each <eop>
-            wscore /= tokes
-        segs = len(templt)
-        if (rul_tokes < args.min_gen_tokes or segs < args.min_gen_states) and constr_sat:
-            continue
-        if rul_tokes >= args.min_gen_tokes and segs >= args.min_gen_states:
-            constr_sat = True # satisfied our constraint
-        tscore = sum(tscores[:int(segs)])/segs
-        if not net.unif_lenps:
-            tscore += lscore/segs
-
-        gscore = wscore
-        ascore = coeffs[0]*tscore + coeffs[1]*gscore
-        if (constr_sat and ascore > best_score) or (not constr_sat and rul_tokes > best_len) or (not constr_sat and rul_tokes == best_len and ascore > best_score):
-            # take if improves score or not long enough yet and this is longer...
-            #if ascore > best_score: #or (not constr_sat and rul_tokes > best_len):
-            best_score, best_tscore, best_gscore = ascore, tscore, gscore
-            best_phrases, best_templt = phrases, templt
-            best_len = rul_tokes
-            #str_phrases = [" ".join(phrs) for phrs in phrases]
-            #tmpltd = ["%s|%d" % (phrs, templt[k]) for k, phrs in enumerate(str_phrases)]
-            #statstr = "a=%.2f t=%.2f g=%.2f" % (ascore, tscore, gscore)
-            #print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd)), statstr
+        try:
+            str_phrases = [" ".join(phrs) for phrs in best_phrases]
+        except TypeError:
+            # sometimes it puts an actual number in
+            str_phrases = [" ".join([str(n) if type(n) is int else n for n in phrs]) for phrs in best_phrases]
+        tmpltd = ["%s|%d" % (phrs, best_templt[kk]) for kk, phrs in enumerate(str_phrases)]
+        if args.verbose:
+            print src_line
+            #print src_tbl
+
+        print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd))
+        if args.verbose:
+            statstr = "a=%.2f t=%.2f g=%.2f" % (best_score, best_tscore, best_gscore)
+            print statstr
+            print
     #assert False
-    #assert False
-
-    try:
-        str_phrases = [" ".join(phrs) for phrs in best_phrases]
-    except TypeError:
-        # sometimes it puts an actual number in
-        str_phrases = [" ".join([str(n) if type(n) is int else n for n in phrs]) for phrs in best_phrases]
-    tmpltd = ["%s|%d" % (phrs, best_templt[kk]) for kk, phrs in enumerate(str_phrases)]
-    if args.verbose:
-        print src_line
-        #print src_tbl
-
-    print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd))
-    if args.verbose:
-        statstr = "a=%.2f t=%.2f g=%.2f" % (best_score, best_tscore, best_gscore)
-        print statstr
-        print
-    #assert False


 def gen_from_src():
     from template_extraction import extract_from_tagged_data, align_cntr
@@ -1281,7 +1290,7 @@ def gen_from_src():
         for b in xrange(bsz):
             src_line = src_lines[corpus.val_mb2linenos[i][b]]
-            if "wiki" in args.data:
+            if "wiki" in args.data or "wb" in args.data:
                 src_tbl = get_wikibio_poswrds(src_line.strip().split())
             else:
                 src_tbl = get_e2e_poswrds(src_line.strip().split())
@@ -1289,7 +1298,7 @@ def gen_from_src():
             gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=src_line)
     else:
        for ll, src_line in enumerate(src_lines):
-            if "wiki" in args.data:
+            if "wiki" in args.data or "wb" in args.data:
                 src_tbl = get_wikibio_poswrds(src_line.strip().split())
             else:
                 src_tbl = get_e2e_poswrds(src_line.strip().split())
@@ -1335,36 +1344,36 @@ def align_stuff():
            inps = inps.cuda()
            fmask, amask = fmask.cuda(), amask.cuda()
            uniqfields = uniqfields.cuda()
-
-        srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid
-                                                  Variable(amask, volatile=True),
-                                                  Variable(uniqfields, volatile=True))
-        init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
-        len_logprobs, _ = net.len_logprobs()
-        fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc,
-                                         srcfieldenc, Variable(fmask, volatile=True),
-                                         Variable(combotargs, volatile=True), bsz) # LxTxbsz x K
-        bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data)
-        seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs,
-                            [t.data for t in len_logprobs], constraints=fwd_cidxs)
-        # get rid of stuff not in our top_temps
-        for bidx in xrange(bsz):
-            if tuple(labe for (start, end, labe) in seqs[bidx]) in top_temps:
-                lineno = corpus.train_mb2linenos[i][bidx]
-                tgttokes = tgtlines[lineno]
-                if "wiki" in args.data:
-                    src_tbl = get_wikibio_poswrds(srclines[lineno])
-                else:
-                    src_tbl = get_e2e_poswrds(srclines[lineno]) # field, idx -> wrd
-                wrd2fields = defaultdict(list)
-                for (field, idx), wrd in src_tbl.iteritems():
-                    wrd2fields[wrd].append(field)
-                for (start, end, labe) in seqs[bidx]:
-                    for wrd in tgttokes[start:end]:
-                        if wrd in wrd2fields:
-                            cop_counters[labe].update(wrd2fields[wrd])
-                        else:
-                            cop_counters[labe]["other"] += 1
+
+        with torch.no_grad():
+            srcenc, srcfieldenc, uniqenc = net.encode(src, # bsz x hid
+                                                      amask, uniqfields)
+            init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK
+            len_logprobs, _ = net.len_logprobs()
+            fwd_obs_logps = net.obs_logprobs(inps, srcenc,
+                                             srcfieldenc, fmask,
+                                             combotargs, bsz) # LxTxbsz x K
+            bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data)
+            seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs,
+                                [t.data for t in len_logprobs], constraints=fwd_cidxs)
+            # get rid of stuff not in our top_temps
+            for bidx in xrange(bsz):
+                if tuple(labe for (start, end, labe) in seqs[bidx]) in top_temps:
+                    lineno = corpus.train_mb2linenos[i][bidx]
+                    tgttokes = tgtlines[lineno]
+                    if "wiki" in args.data or "wb" in args.data:
+                        src_tbl = get_wikibio_poswrds(srclines[lineno])
+                    else:
+                        src_tbl = get_e2e_poswrds(srclines[lineno]) # field, idx -> wrd
+                    wrd2fields = defaultdict(list)
+                    for (field, idx), wrd in src_tbl.iteritems():
+                        wrd2fields[wrd].append(field)
+                    for (start, end, labe) in seqs[bidx]:
+                        for wrd in tgttokes[start:end]:
+                            if wrd in wrd2fields:
+                                cop_counters[labe].update(wrd2fields[wrd])
+                            else:
+                                cop_counters[labe]["other"] += 1

     return cop_counters
diff --git a/infc.py b/infc.py
index 23b61f5..a0b1607 100644
--- a/infc.py
+++ b/infc.py
@@ -4,7 +4,6 @@
 import math

 import torch
-from torch.autograd import Variable

 from utils import logsumexp0, logsumexp2
@@ -20,7 +19,7 @@ def recover_bps(delt, bps, bps_star):
     for b in xrange(bsz):
         seq = []
         _, last_lab = delt[seqlen][b].max(0)
-        last_lab = last_lab[0]
+        last_lab = last_lab.item()
         curr_idx = seqlen # 1-indexed
         while True:
             last_len = bps[curr_idx][b][last_lab]
@@ -140,7 +139,7 @@ def just_fwd(pi, trans_logprobs, bwd_obs_logprobs, constraints=None):
                       - bwd_maxlens[t-steps_back:t].expand(steps_back, bsz, K))
         if constraints is not None and constraints[t] is not None:
-            alph_terms = alph_terms + tmask #Variable(tmask)
+            alph_terms = alph_terms + tmask
         alph[t] = logsumexp0(alph_terms) # bsz x K
@@ -167,7 +166,7 @@ def just_bwd(trans_logprobs, fwd_obs_logprobs, len_logprobs, constraints=None):
     # we'll be 1-indexed for alphas and betas
     beta = [None]*(seqlen+1)
     beta_star = [None]*(seqlen+1)
-    beta[seqlen] = Variable(trans_logprobs.data.new(bsz, K).zero_())
+    beta[seqlen] = trans_logprobs.data.new(bsz, K).zero_()
     mask = trans_logprobs.data.new(L, bsz, K)
     for t in xrange(1, seqlen+1):
@@ -187,7 +186,7 @@ def just_bwd(trans_logprobs, fwd_obs_logprobs, len_logprobs, constraints=None):
                            + len_terms.unsqueeze(1).expand(steps_fwd, bsz, K))
         if constraints is not None and constraints[seqlen-t+1] is not None:
-            beta_star_terms = beta_star_terms + Variable(tmask)
+            beta_star_terms = beta_star_terms + tmask
         beta_star[seqlen-t] = logsumexp0(beta_star_terms)
         if seqlen-t > 0:
diff --git a/labeled_data.py b/labeled_data.py
index 9e40d40..6db6312 100644
--- a/labeled_data.py
+++ b/labeled_data.py
@@ -45,7 +45,7 @@ def __init__(self, path, bsz, thresh=0, add_bos=False, add_eos=False,
                  test=False):
         self.dictionary = Dictionary()
         self.bsz = bsz
-        self.wiki = "wiki" in path
+        self.wiki = "wiki" in path or "wb" in path

         train_src = os.path.join(path, "src_train.txt")
diff --git a/utils.py b/utils.py
index d4845b5..9573936 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,6 @@
 import math
 from collections import defaultdict, Counter
 import torch
-from torch.autograd import Variable

 def logsumexp0(X):
     """
@@ -165,82 +164,83 @@ def beam_search2(net, corpus, ss, start_inp, exh0, exc0, srcfieldenc,
     # N.B. we assume we have a single feature row for each timestep rather than avg
     # over them as at training time. probably better, but could conceivably average like
     # at training time.
-    inps = Variable(torch.LongTensor(K, 4), volatile=True)
-    for ell in xrange(net.L):
-        wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
-        # disallow unks
-        wrd_dist[:, unk_idx].zero_()
-        if not final_state:
-            wrd_dist[:, eos_idx].zero_()
-        #if not ss == 25 or not ell == 3:
-        net.collapse_word_probs(row2tblent, wrd_dist)
-        wrd_dist.log_()
-        if ell > 0: # add previous scores
-            wrd_dist.add_(curr_scores.expand_as(wrd_dist))
-        maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
-        cols = wrd_dist.size(1)
-        # we'll break as soon as <eos> is at the top of the beam.
-        # this ignores <eop> but whatever
-        if top2k[0] == eos_idx:
-            final_hyp = backtrace(curr_hyps[0])
-            final_hyp.append(eos_idx)
-            return final_hyp, maxprobs[0], len_lps[ss][ell]
+    with torch.no_grad():
+        inps = torch.LongTensor(K, 4)
+        for ell in xrange(net.L):
+            wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
+            # disallow unks
+            wrd_dist[:, unk_idx].zero_()
+            if not final_state:
+                wrd_dist[:, eos_idx].zero_()
+            #if not ss == 25 or not ell == 3:
+            net.collapse_word_probs(row2tblent, wrd_dist)
+            wrd_dist.log_()
+            if ell > 0: # add previous scores
+                wrd_dist.add_(curr_scores.expand_as(wrd_dist))
+            maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K)
+            cols = wrd_dist.size(1)
+            # we'll break as soon as <eos> is at the top of the beam.
+            # this ignores <eop> but whatever
+            if top2k[0] == eos_idx:
+                final_hyp = backtrace(curr_hyps[0])
+                final_hyp.append(eos_idx)
+                return final_hyp, maxprobs[0], len_lps[ss][ell]
-        new_hyps, anc_hs, anc_cs = [], [], []
-        #inps.data.fill_(pad_idx)
-        inps.data[:, 1].fill_(w2i["<ncf1>"])
-        inps.data[:, 2].fill_(w2i["<ncf2>"])
-        inps.data[:, 3].fill_(w2i["<ncf3>"])
-        for k in xrange(2*K):
-            anc, wrd = top2k[k] / cols, top2k[k] % cols
-            # check if any of the maxes are eop
-            if wrd == net.eop_idx and ell > 0:
-                # add len score (and avg over num words incl eop i guess)
-                wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1]
-                if wlenscore > best_hyp_score:
-                    best_hyp_score = wlenscore
-                    best_hyp = backtrace(curr_hyps[anc])
-                    best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1]
+            new_hyps, anc_hs, anc_cs = [], [], []
+            #inps.data.fill_(pad_idx)
+            inps.data[:, 1].fill_(w2i["<ncf1>"])
+            inps.data[:, 2].fill_(w2i["<ncf2>"])
+            inps.data[:, 3].fill_(w2i["<ncf3>"])
+            for k in xrange(2*K):
+                anc, wrd = top2k[k] / cols, top2k[k] % cols
+                # check if any of the maxes are eop
+                if wrd == net.eop_idx and ell > 0:
+                    # add len score (and avg over num words incl eop i guess)
+                    wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1]
+                    if wlenscore > best_hyp_score:
+                        best_hyp_score = wlenscore
+                        best_hyp = backtrace(curr_hyps[anc])
+                        best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1]
                 else:
-                curr_scores[len(new_hyps)][0] = maxprobs[k]
-                if wrd >= net.decoder.out_features: # a copy
-                    tblidx = wrd - net.decoder.out_features
-                    inps.data[len(new_hyps)].copy_(row2feats[tblidx])
-                else:
-                    inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
-                new_hyps.append((wrd, curr_hyps[anc]))
-                anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-                anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
-            if len(new_hyps) == K:
-                break
-        assert len(new_hyps) == K
-        curr_hyps = new_hyps
-        if net.lut.weight.data.is_cuda:
-            inps = inps.cuda()
-        embs = net.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
-        if net.mlpinp:
-            embs = net.inpmlp(embs) # 1 x K x rnninsz
-        if net.one_rnn:
-            cond_embs = torch.cat([embs, net.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2)
-            hid, (hc, cc) = net.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
-        else:
-            hid, (hc, cc) = net.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
-    # hypotheses of length L still need their end probs added
-    # N.B. if the <eos> falls off the beam we could end up with situations
-    # where we take an L-length phrase w/ a lower score than 1-word followed by eos.
-    wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
-    #wrd_dist = net.get_next_word_dist(hid, ss, srcfieldenc).cpu() # K x nwords
-    wrd_dist.log_()
-    wrd_dist.add_(curr_scores.expand_as(wrd_dist))
-    for k in xrange(K):
-        wlenscore = wrd_dist[k][net.eop_idx]/(net.L+1) + len_lps[ss][net.L-1]
-        if wlenscore > best_hyp_score:
-            best_hyp_score = wlenscore
-            best_hyp = backtrace(curr_hyps[k])
-            best_wscore, best_lscore = wrd_dist[k][net.eop_idx], len_lps[ss][net.L-1]
-    #if ss == 80:
-    #    print "going with", best_hyp
-    return best_hyp, best_wscore, best_lscore
+                    curr_scores[len(new_hyps)][0] = maxprobs[k]
+                    if wrd >= net.decoder.out_features: # a copy
+                        tblidx = wrd - net.decoder.out_features
+                        inps.data[len(new_hyps)].copy_(row2feats[tblidx])
+                    else:
+                        inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx
+                    new_hyps.append((wrd, curr_hyps[anc]))
+                    anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                    anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size
+                if len(new_hyps) == K:
+                    break
+            assert len(new_hyps) == K
+            curr_hyps = new_hyps
+            if net.lut.weight.data.is_cuda:
+                inps = inps.cuda()
+            embs = net.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size
+            if net.mlpinp:
+                embs = net.inpmlp(embs) # 1 x K x rnninsz
+            if net.one_rnn:
+                cond_embs = torch.cat([embs, net.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2)
+                hid, (hc, cc) = net.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
+            else:
+                hid, (hc, cc) = net.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1)))
+        # hypotheses of length L still need their end probs added
+        # N.B. if the <eos> falls off the beam we could end up with situations
+        # where we take an L-length phrase w/ a lower score than 1-word followed by eos.
+        wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords
+        #wrd_dist = net.get_next_word_dist(hid, ss, srcfieldenc).cpu() # K x nwords
+        wrd_dist.log_()
+        wrd_dist.add_(curr_scores.expand_as(wrd_dist))
+        for k in xrange(K):
+            wlenscore = wrd_dist[k][net.eop_idx]/(net.L+1) + len_lps[ss][net.L-1]
+            if wlenscore > best_hyp_score:
+                best_hyp_score = wlenscore
+                best_hyp = backtrace(curr_hyps[k])
+                best_wscore, best_lscore = wrd_dist[k][net.eop_idx], len_lps[ss][net.L-1]
+        #if ss == 80:
+        #    print "going with", best_hyp
+        return best_hyp, best_wscore, best_lscore


 def calc_pur(counters):