From 1b90823c9b0187c19b0d1e5ba0e6a604d9e82e8a Mon Sep 17 00:00:00 2001
From: Marten van Schijndel
Date: Wed, 4 Sep 2019 23:05:34 -0400
Subject: [PATCH] freeze embeddings

---
 main.py  |  5 ++++-
 model.py | 12 ++++++++----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 7fcbc90..0694727 100644
--- a/main.py
+++ b/main.py
@@ -90,6 +90,8 @@
                     help='test a trained LM')
 parser.add_argument('--load_checkpoint', action='store_true',
                     help='continue training a pre-trained LM')
+parser.add_argument('--freeze_embedding', action='store_true',
+                    help='do not train embedding weights')
 parser.add_argument('--single', action='store_true',
                     help='use only a single GPU (even if more are available)')
 parser.add_argument('--multisentence_test', action='store_true',
@@ -238,7 +240,8 @@ def batchify(data, bsz):
 ntokens = len(corpus.dictionary)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers,
                        embedding_file=args.embedding_file,
-                       dropout=args.dropout, tie_weights=args.tied).to(device)
+                       dropout=args.dropout, tie_weights=args.tied,
+                       freeze_embedding=args.freeze_embedding).to(device)
 
 # after load the rnn params are not a continuous chunk of memory
 # this makes them a continuous chunk, and will speed up forward pass
diff --git a/model.py b/model.py
index a276505..07351bc 100644
--- a/model.py
+++ b/model.py
@@ -8,7 +8,7 @@ class RNNModel(nn.Module):
     """Container module with an encoder, a recurrent module, and a decoder."""
 
     def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
-                 embedding_file=None, dropout=0.5, tie_weights=False):
+                 embedding_file=None, dropout=0.5, tie_weights=False, freeze_embedding=False):
         super(RNNModel, self).__init__()
         self.drop = nn.Dropout(dropout)
         if embedding_file:
@@ -28,7 +28,10 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
             self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
         self.decoder = nn.Linear(nhid, ntoken)
 
-        self.init_weights()
+        self.init_weights(freeze_embedding)
+        if freeze_embedding:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
 
         # Optionally tie weights as in:
         # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2017)
@@ -46,10 +49,11 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
         self.nhid = nhid
         self.nlayers = nlayers
 
-    def init_weights(self):
+    def init_weights(self, freeze_embedding):
+        """ Initialize encoder and decoder weights """
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
+        if not freeze_embedding:
+            self.encoder.weight.data.uniform_(-initrange, initrange)
         self.decoder.bias.data.fill_(0)
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
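
A quick sanity check of the new flag (a minimal sketch, not part of the patch;
it assumes model.py is importable as `model` and that the plain nn.Embedding
branch is taken when no embedding_file is given; the hyperparameters below are
illustrative only):

    import torch
    import model as m

    # Hypothetical sizes, chosen only for illustration.
    lm = m.RNNModel('LSTM', 1000, 200, 200, 2, freeze_embedding=True)

    # The encoder (embedding) weights are now excluded from training ...
    assert lm.encoder.weight.requires_grad is False
    # ... while the decoder still trains as before.
    assert lm.decoder.weight.requires_grad is True

    # Any optimizer or manual update loop should skip frozen parameters,
    # whose .grad stays None after backward():
    optimizer = torch.optim.SGD(
        (p for p in lm.parameters() if p.requires_grad), lr=20)

One interaction worth noting: assuming the standard weight-tying assignment
(decoder.weight = encoder.weight), combining --tied with --freeze_embedding
would freeze the tied decoder weights as well, since both names point at the
same Parameter.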