-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_embeddings.py
43 lines (30 loc) · 1.29 KB
/
model_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.

    Holds one embedding table per language:
      - source: embeddings for the source-language vocabulary
      - target: embeddings for the target-language vocabulary
    """

    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        """
        super().__init__()
        self.embed_size = embed_size

        # Padding-token indices: nn.Embedding keeps the padding row fixed at
        # zero and excludes it from gradient updates.
        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        # One embedding row per vocabulary entry, for each language.
        self.source = nn.Embedding(num_embeddings=len(vocab.src),
                                   embedding_dim=self.embed_size,
                                   padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(num_embeddings=len(vocab.tgt),
                                   embedding_dim=self.embed_size,
                                   padding_idx=tgt_pad_token_idx)