Skip to content

Commit

Permalink
Merge pull request #3 from forrestdavis/master
Browse files Browse the repository at this point in the history
Added embedding flag --view_emb. Gets embeddings for input words
  • Loading branch information
vansky authored Apr 30, 2020
2 parents bf40b9e + f0a1b0b commit 4b85c72
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@
help='adapt model weights during evaluation')
parser.add_argument('--interact', action='store_true',
help='run a trained network interactively')

#For getting embeddings
parser.add_argument('--view_emb', action='store_true',
help='output the word embedding rather than the cell state')

parser.add_argument('--view_layer', type=int, default=-1,
help='which layer should output cell states')
parser.add_argument('--view_hidden', action='store_true',
Expand Down Expand Up @@ -459,6 +464,13 @@ def test_evaluate(test_sentences, data_source):
if args.view_hidden:
# output hidden state
print(*list(hidden[0][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')

elif args.view_emb:
#Get embedding for input word
emb = model.encoder(word_input)
# output embedding
print(*list(emb[0].view(1,-1).data.cpu().numpy().flatten()), sep=' ')

else:
# output cell state
print(*list(hidden[1][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')
Expand Down

0 comments on commit 4b85c72

Please sign in to comment.