diff --git a/stanza/models/common/char_model.py b/stanza/models/common/char_model.py
index b7b6e993bf..e80da10df6 100644
--- a/stanza/models/common/char_model.py
+++ b/stanza/models/common/char_model.py
@@ -115,6 +115,9 @@ def build_charlm_vocab(path, cutoff=0):
     vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
     return vocab
 
+CHARLM_START = "\n"
+CHARLM_END = " "
+
 class CharacterLanguageModel(nn.Module):
 
     def __init__(self, args, vocab, pad=False, is_forward_lm=True):
@@ -162,13 +165,25 @@ def get_representation(self, chars, charoffsets, charlens, char_orig_idx):
         res = pad_packed_sequence(res, batch_first=True)[0]
         return res
 
+    def per_char_representation(self, words):
+        device = next(self.parameters()).device
+        vocab = self.char_vocab()
+
+        all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]
+        all_data.sort(key=itemgetter(1), reverse=True)
+        chars = [x[0] for x in all_data]
+        char_lens = [x[1] for x in all_data]
+        char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)
+        with torch.no_grad():
+            output, _, _ = self.forward(char_tensor, char_lens)
+            output = [x[:y, :] for x, y in zip(output, char_lens)]
+            output = unsort(output, [x[2] for x in all_data])
+        return output
+
     def build_char_representation(self, sentences):
         """
         Return values from this charlm for a list of list of words
         """
-        CHARLM_START = "\n"
-        CHARLM_END = " "
-
         forward = self.is_forward_lm
         vocab = self.char_vocab()
         device = next(self.parameters()).device
@@ -191,6 +206,7 @@ def build_char_representation(self, sentences):
         all_data.sort(key=itemgetter(2), reverse=True)
         chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
 
+        # TODO: can this be faster?
         chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)
 
         with torch.no_grad():
@@ -250,6 +266,27 @@ def load(cls, filename, finetune=False):
             return cls.from_full_state(state, finetune)
         return cls.from_full_state(state['model'], finetune)
 
+class CharacterLanguageModelWordAdapter(nn.Module):
+    """
+    Adapts a character model to return embeddings for each character in a word
+
+    TODO: multiple charlms, eg, forward & back
+    """
+    def __init__(self, charlm):
+        super().__init__()
+        self.charlm = charlm
+
+    def forward(self, words):
+        words = [CHARLM_START + x + CHARLM_END for x in words]
+        rep = self.charlm.per_char_representation(words)
+        padded_rep = torch.zeros(len(rep), max(x.shape[0] for x in rep), rep[0].shape[1], dtype=rep[0].dtype, device=rep[0].device)
+        for idx, row in enumerate(rep):
+            padded_rep[idx, :row.shape[0], :] = row
+        return padded_rep
+
+    def hidden_dim(self):
+        return self.charlm.hidden_dim()
+
 class CharacterLanguageModelTrainer():
     def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):
         self.model = model
diff --git a/stanza/models/common/seq2seq_model.py b/stanza/models/common/seq2seq_model.py
index 9784d4a23e..66173e1b7a 100644
--- a/stanza/models/common/seq2seq_model.py
+++ b/stanza/models/common/seq2seq_model.py
@@ -19,7 +19,7 @@ class Seq2SeqModel(nn.Module):
     """
     A complete encoder-decoder model, with optional attention.
""" - def __init__(self, args, emb_matrix=None): + def __init__(self, args, emb_matrix=None, contextual_embedding=None): super().__init__() self.vocab_size = args['vocab_size'] self.emb_dim = args['emb_dim'] @@ -32,6 +32,7 @@ def __init__(self, args, emb_matrix=None): self.top = args.get('top', 1e10) self.args = args self.emb_matrix = emb_matrix + self.contextual_embedding = contextual_embedding logger.debug("Building an attentional Seq2Seq model...") logger.debug("Using a Bi-LSTM encoder") @@ -50,7 +51,10 @@ def __init__(self, args, emb_matrix=None): self.emb_drop = nn.Dropout(self.emb_dropout) self.drop = nn.Dropout(self.dropout) self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token) - self.encoder = nn.LSTM(self.emb_dim, self.enc_hidden_dim, self.nlayers, \ + self.input_dim = self.emb_dim + if self.contextual_embedding is not None: + self.input_dim += self.contextual_embedding.hidden_dim() + self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \ bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0) self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \ batch_first=True, attn_type=self.args['attn_type']) @@ -158,7 +162,7 @@ def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None): return log_probs, dec_hidden - def embed(self, src, src_mask, pos): + def embed(self, src, src_mask, pos, raw): enc_inputs = self.emb_drop(self.embedding(src)) batch_size = enc_inputs.size(0) if self.use_pos: @@ -167,12 +171,18 @@ def embed(self, src, src_mask, pos): enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1) pos_src_mask = src_mask.new_zeros([batch_size, 1]) src_mask = torch.cat([pos_src_mask, src_mask], dim=1) + if raw is not None and self.contextual_embedding is not None: + raw_inputs = self.contextual_embedding(raw) + if self.use_pos: + raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2])) + raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1) + enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2) src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1)) return enc_inputs, batch_size, src_lens, src_mask - def forward(self, src, src_mask, tgt_in, pos=None): + def forward(self, src, src_mask, tgt_in, pos=None, raw=None): # prepare for encoder/decoder - enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos) + enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) @@ -194,9 +204,9 @@ def get_log_prob(self, logits): return log_probs return log_probs.view(logits.size(0), logits.size(1), logits.size(2)) - def predict_greedy(self, src, src_mask, pos=None): + def predict_greedy(self, src, src_mask, pos=None, raw=None): """ Predict with greedy decoding. """ - enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos) + enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) @@ -231,12 +241,12 @@ def predict_greedy(self, src, src_mask, pos=None): output_seqs[i].append(token) return output_seqs, edit_logits - def predict(self, src, src_mask, pos=None, beam_size=5): + def predict(self, src, src_mask, pos=None, beam_size=5, raw=None): """ Predict with beam search. 
""" if beam_size == 1: - return self.predict_greedy(src, src_mask, pos=pos) + return self.predict_greedy(src, src_mask, pos, raw) - enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos) + enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # (1) encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) diff --git a/stanza/models/lemma/data.py b/stanza/models/lemma/data.py index 1f23125289..1e258220cb 100644 --- a/stanza/models/lemma/data.py +++ b/stanza/models/lemma/data.py @@ -77,7 +77,7 @@ def preprocess(self, data, char_vocab, pos_vocab, args): tgt = list(d[2]) tgt_in = char_vocab.map([constant.SOS] + tgt) tgt_out = char_vocab.map(tgt + [constant.EOS]) - processed += [[src, tgt_in, tgt_out, pos, edit_type]] + processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]] return processed def __len__(self): @@ -92,7 +92,7 @@ def __getitem__(self, key): batch = self.data[key] batch_size = len(batch) batch = list(zip(*batch)) - assert len(batch) == 5 + assert len(batch) == 6 # sort all fields by lens for easy RNN operations lens = [len(x) for x in batch[0]] @@ -106,8 +106,9 @@ def __getitem__(self, key): tgt_out = get_long_tensor(batch[2], batch_size) pos = torch.LongTensor(batch[3]) edits = torch.LongTensor(batch[4]) + text = batch[5] assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match." - return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx + return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text def __iter__(self): for i in range(self.__len__()): @@ -124,4 +125,4 @@ def resolve_none(self, data): for feat_idx in range(len(data[tok_idx])): if data[tok_idx][feat_idx] is None: data[tok_idx][feat_idx] = '_' - return data \ No newline at end of file + return data diff --git a/stanza/models/lemma/trainer.py b/stanza/models/lemma/trainer.py index d0a9d365d8..2730be7e05 100644 --- a/stanza/models/lemma/trainer.py +++ b/stanza/models/lemma/trainer.py @@ -12,7 +12,9 @@ import torch.nn.init as init import stanza.models.common.seq2seq_constant as constant +from stanza.models.common.foundation_cache import load_charlm from stanza.models.common.seq2seq_model import Seq2SeqModel +from stanza.models.common.char_model import CharacterLanguageModelWordAdapter from stanza.models.common import utils, loss from stanza.models.lemma import edit from stanza.models.lemma.vocab import MultiVocab @@ -23,18 +25,24 @@ def unpack_batch(batch, device): """ Unpack a batch from the data loader. """ inputs = [b.to(device) if b is not None else None for b in batch[:6]] orig_idx = batch[6] - return inputs, orig_idx + text = batch[7] + return inputs, orig_idx, text class Trainer(object): """ A trainer for training models. 
""" - def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None): + def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None): + self.unsaved_modules = [] if model_file is not None: # load everything from file - self.load(model_file) + self.load(model_file, args, foundation_cache) else: # build model from scratch self.args = args - self.model = None if args['dict_only'] else Seq2SeqModel(args, emb_matrix=emb_matrix) + if args['dict_only']: + self.model = None + else: + self.model, charmodel = self.build_seq2seq(args, emb_matrix, foundation_cache) + self.add_unsaved_module("charmodel", charmodel) self.vocab = vocab # dict-based components self.word_dict = dict() @@ -48,9 +56,21 @@ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, devi self.crit = loss.SequenceLoss(self.vocab['char'].size).to(device) self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr']) + def build_seq2seq(self, args, emb_matrix, foundation_cache): + charmodel = None + if args is not None and args.get('charlm_forward_file', None): + charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache) + charmodel = CharacterLanguageModelWordAdapter(charmodel_forward) + model = Seq2SeqModel(args, emb_matrix=emb_matrix, contextual_embedding=charmodel) + return model, charmodel + + def add_unsaved_module(self, name, module): + self.unsaved_modules += [name] + setattr(self, name, module) + def update(self, batch, eval=False): device = next(self.model.parameters()).device - inputs, orig_idx = unpack_batch(batch, device) + inputs, orig_idx, text = unpack_batch(batch, device) src, src_mask, tgt_in, tgt_out, pos, edits = inputs if eval: @@ -58,7 +78,7 @@ def update(self, batch, eval=False): else: self.model.train() self.optimizer.zero_grad() - log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos) + log_probs, edit_logits = self.model(src, src_mask, tgt_in, pos, raw=text) if self.args.get('edit', False): assert edit_logits is not None loss = self.crit(log_probs.view(-1, self.vocab['char'].size), tgt_out.view(-1), \ @@ -76,12 +96,12 @@ def update(self, batch, eval=False): def predict(self, batch, beam_size=1): device = next(self.model.parameters()).device - inputs, orig_idx = unpack_batch(batch, device) + inputs, orig_idx, text = unpack_batch(batch, device) src, src_mask, tgt, tgt_mask, pos, edits = inputs self.model.eval() batch_size = src.size(0) - preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size) + preds, edit_logits = self.model.predict(src, src_mask, pos=pos, beam_size=beam_size, raw=text) pred_seqs = [self.vocab['char'].unmap(ids) for ids in preds] # unmap to tokens pred_seqs = utils.prune_decoded_seqs(pred_seqs) pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens @@ -182,7 +202,13 @@ def ensemble(self, pairs, other_preds): lemmas.append(lemma) return lemmas - def save(self, filename): + def save(self, filename, skip_modules=True): + model_state = self.model.state_dict() + # skip saving modules like the pretrained charlm + if skip_modules: + skipped = [k for k in model_state.keys() if k.split('.')[0] in self.unsaved_modules] + for k in skipped: + del model_state[k] params = { 'model': self.model.state_dict() if self.model is not None else None, 'dicts': (self.word_dict, self.composite_dict), @@ -193,16 +219,20 @@ def save(self, filename): torch.save(params, filename, 
         logger.info("Model saved to {}".format(filename))
 
-    def load(self, filename):
+    def load(self, filename, args, foundation_cache):
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
         self.args = checkpoint['config']
+        if args is not None:
+            self.args['charlm_forward_file'] = args['charlm_forward_file']
+            self.args['charlm_backward_file'] = args['charlm_backward_file']
         self.word_dict, self.composite_dict = checkpoint['dicts']
         if not self.args['dict_only']:
-            self.model = Seq2SeqModel(self.args)
+            self.model, charmodel = self.build_seq2seq(self.args, None, foundation_cache)
+            self.add_unsaved_module("charmodel", charmodel)
             # could remove strict=False after rebuilding all models,
             # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
             self.model.load_state_dict(checkpoint['model'], strict=False)
diff --git a/stanza/models/lemmatizer.py b/stanza/models/lemmatizer.py
index 19dd2e441a..88914cc14f 100644
--- a/stanza/models/lemmatizer.py
+++ b/stanza/models/lemmatizer.py
@@ -61,6 +61,9 @@ def build_argparse():
     parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')
     parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.')
 
+    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
+
     parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
     parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
     parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
@@ -236,7 +239,7 @@ def evaluate(args):
     model_file = build_model_filename(args)
 
     # load model
-    trainer = Trainer(model_file=model_file, device=args['device'])
+    trainer = Trainer(model_file=model_file, device=args['device'], args=args)
     loaded_args, vocab = trainer.args, trainer.vocab
 
     for k in args:
diff --git a/stanza/pipeline/lemma_processor.py b/stanza/pipeline/lemma_processor.py
index b04e3909f2..d2706d2d3d 100644
--- a/stanza/pipeline/lemma_processor.py
+++ b/stanza/pipeline/lemma_processor.py
@@ -50,7 +50,9 @@ def _set_up_model(self, config, pipeline, device):
             # we make this an option, not the default
             self.store_results = config.get('store_results', False)
             self._use_identity = False
-            self._trainer = Trainer(model_file=config['model_path'], device=device)
+            args = {'charlm_forward_file': config.get('forward_charlm_path', None),
+                    'charlm_backward_file': config.get('backward_charlm_path', None)}
+            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)
 
     def _set_up_requires(self):
         self._pretagged = self._config.get('pretagged', None)
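
Usage sketch (not part of the patch; the charlm path below is a placeholder): CharacterLanguageModelWordAdapter wraps a pretrained forward charlm and returns one vector per character of each input word, padded across the batch. This is the tensor that Seq2SeqModel.embed concatenates onto its character embeddings when raw text is passed.

    from stanza.models.common.char_model import CharacterLanguageModel, CharacterLanguageModelWordAdapter

    # load a pretrained forward character LM (example path only)
    charlm = CharacterLanguageModel.load("saved_models/charlm/en_forward.pt", finetune=False)
    adapter = CharacterLanguageModelWordAdapter(charlm)

    # each word is wrapped in CHARLM_START / CHARLM_END internally, so the middle
    # dimension is the longest word length plus the two boundary characters
    reps = adapter(["unban", "banned"])
    # reps.shape == (2, 8, adapter.hidden_dim())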