Commit
freeze embeddings
vansky committed Sep 5, 2019
1 parent e9d8b82 commit 1b90823
Showing 2 changed files with 12 additions and 5 deletions.
main.py: 5 changes (4 additions, 1 deletion)
@@ -90,6 +90,8 @@
                     help='test a trained LM')
 parser.add_argument('--load_checkpoint', action='store_true',
                     help='continue training a pre-trained LM')
+parser.add_argument('--freeze_embedding', action='store_true',
+                    help='do not train embedding weights')
 parser.add_argument('--single', action='store_true',
                     help='use only a single GPU (even if more are available)')
 parser.add_argument('--multisentence_test', action='store_true',
@@ -238,7 +240,8 @@ def batchify(data, bsz):
 ntokens = len(corpus.dictionary)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, embedding_file=args.embedding_file,
-                       dropout=args.dropout, tie_weights=args.tied).to(device)
+                       dropout=args.dropout, tie_weights=args.tied,
+                       freeze_embedding=args.freeze_embedding).to(device)

 # After loading, the RNN params are not a contiguous chunk of memory;
 # this makes them contiguous, which speeds up the forward pass.
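Because the new argument is an argparse store_true flag, it defaults to False and existing invocations of main.py are unaffected; it is most useful together with --embedding_file, so that pretrained vectors are loaded and then held fixed. A minimal sketch (not part of the commit) of constructing the model directly with the new argument; the hyperparameter values and the lm name are illustrative:

# Hypothetical direct construction mirroring the updated call in main.py.
# The hyperparameter values below are made up for illustration.
import model as model_module

ntokens = 10000
lm = model_module.RNNModel('LSTM', ntokens, 200, 200, 2,
                           embedding_file=None, dropout=0.5,
                           tie_weights=False, freeze_embedding=True)

# With freeze_embedding=True, the encoder weights no longer require gradients:
assert all(not p.requires_grad for p in lm.encoder.parameters())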
model.py: 12 changes (8 additions, 4 deletions)
@@ -8,7 +8,7 @@ class RNNModel(nn.Module):
     """Container module with an encoder, a recurrent module, and a decoder."""

     def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
-                 embedding_file=None, dropout=0.5, tie_weights=False):
+                 embedding_file=None, dropout=0.5, tie_weights=False, freeze_embedding=False):
         super(RNNModel, self).__init__()
         self.drop = nn.Dropout(dropout)
         if embedding_file:
@@ -28,7 +28,10 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
             self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
         self.decoder = nn.Linear(nhid, ntoken)

-        self.init_weights()
+        self.init_weights(freeze_embedding)
+        if freeze_embedding:
+            for param in self.encoder.parameters():
+                param.requires_grad = False

         # Optionally tie weights as in:
         # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2017)
@@ -46,10 +49,11 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
         self.nhid = nhid
         self.nlayers = nlayers

-    def init_weights(self):
+    def init_weights(self, freeze_embedding):
         """ Initialize encoder and decoder weights """
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
+        if not freeze_embedding:
+            self.encoder.weight.data.uniform_(-initrange, initrange)
         self.decoder.bias.data.fill_(0)
         self.decoder.weight.data.uniform_(-initrange, initrange)
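Two mechanics are at work in this change. Setting requires_grad = False on the encoder (embedding) parameters stops gradients from being computed for them, so no optimizer step can update them; and the new guard in init_weights skips re-randomizing the embedding matrix, which would otherwise clobber weights loaded from embedding_file. A self-contained sketch of the freezing mechanism, independent of this repo (module sizes are arbitrary):

import torch
import torch.nn as nn

encoder = nn.Embedding(100, 16)  # stand-in for the model's encoder
decoder = nn.Linear(16, 100)     # stand-in for the decoder

for param in encoder.parameters():
    param.requires_grad = False  # freeze: no gradients accumulate here

# Hand the optimizer only the parameters that are still trainable.
trainable = [p for p in decoder.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

before = encoder.weight.clone()
loss = decoder(encoder(torch.tensor([1, 2, 3]))).sum()
loss.backward()
optimizer.step()

assert torch.equal(before, encoder.weight)  # embedding unchanged by training

Filtering the optimizer's parameter list is a convention rather than a necessity here: frozen parameters never receive a .grad, and PyTorch optimizers skip parameters whose gradient is None.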
