Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
vansky committed Jun 14, 2019
2 parents dc34627 + 8078162 commit 40e49b6
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@
parser.add_argument('--interact', action='store_true',
help='run a trained network interactively')
parser.add_argument('--view_layer', type=int, default=-1,
help='which layer should output top n guesses')
help='which layer should output cell states')
parser.add_argument('--view_hidden', action='store_true',
help='output the hidden state rather than the cell state')

parser.add_argument('--words', action='store_true',
help='evaluate word-level complexities (instead of sentence-level loss)')
Expand Down Expand Up @@ -166,7 +168,7 @@ def batchify(data, bsz):
For instance, with the alphabet as the sequence and batch size 4, we'd get
a g m s
b h n t
c i state u
c i o u
d j p v
e k q w
f l r x
Expand Down Expand Up @@ -441,7 +443,12 @@ def test_evaluate(test_sentences, data_source):
if targ_word != '<eos>':
# don't output <eos> markers to align with input
# output raw activations
print(*list(hidden[0][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')
if args.view_hidden:
# output hidden state
print(*list(hidden[0][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')
else:
# output cell state
print(*list(hidden[1][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')
else:
data = data.unsqueeze(1) # only needed when a single sentence is being processed
output, hidden = model(data, hidden)
Expand Down Expand Up @@ -603,6 +610,15 @@ def train():
except NameError:
pass

n_rnn_param = sum([p.nelement() for p in model.rnn.parameters()])
n_enc_param = sum([p.nelement() for p in model.encoder.parameters()])
n_dec_param = sum([p.nelement() for p in model.decoder.parameters()])

print('#rnn params = {}'.format(n_rnn_param))
print('#enc params = {}'.format(n_enc_param))
print('#dec params = {}'.format(n_dec_param))


# Then run interactively
print('Running in interactive mode. Ctrl+c to exit')
if '<unk>' not in corpus.dictionary.word2idx:
Expand Down

0 comments on commit 40e49b6

Please sign in to comment.