[WIP v1 - deprecated] entmax 1.5 for attention and outputs, faster implementation of sparsemax #1541

Open · wants to merge 4 commits into base: master
11 changes: 7 additions & 4 deletions onmt/model_builder.py
@@ -3,6 +3,7 @@
and creates each encoder and decoder accordingly.
"""
import re
from functools import partial
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
@@ -14,6 +15,7 @@
from onmt.decoders import str2dec

from onmt.modules import Embeddings, VecEmbedding, CopyGenerator
from onmt.modules.sparse_activations import LogSparsemax, LogEntmax15
from onmt.modules.util_class import Cast
from onmt.utils.misc import use_gpu
from onmt.utils.logging import logger
@@ -173,10 +175,11 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):

# Build Generator.
if not model_opt.copy_attn:
if model_opt.generator_function == "sparsemax":
gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
else:
gen_func = nn.LogSoftmax(dim=-1)
gen_funcs = {"softmax": nn.LogSoftmax,
"sparsemax": partial(LogSparsemax, k=512),
"entmax15": partial(LogEntmax15, k=512)}
assert model_opt.generator_function in gen_funcs
gen_func = gen_funcs[model_opt.generator_function](dim=-1)
generator = nn.Sequential(
nn.Linear(model_opt.dec_rnn_size,
len(fields["tgt"].base_field.vocab)),
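For context, a minimal self-contained sketch (not the PR's code; the top-k behavior of `k=512` is my reading of the entmax package) of how the new generator lookup behaves. LogSparsemax and LogEntmax15 are thin torch.log wrappers over the entmax modules, so the generator still emits log-probabilities, just with many entries at exactly -inf.

from functools import partial

import torch
import torch.nn as nn
from entmax import Sparsemax, Entmax15


class LogSparsemax(Sparsemax):
    # mirrors the PR's onmt/modules/sparse_activations.py
    def forward(self, *args, **kwargs):
        return torch.log(super().forward(*args, **kwargs))


class LogEntmax15(Entmax15):
    def forward(self, *args, **kwargs):
        return torch.log(super().forward(*args, **kwargs))


gen_funcs = {"softmax": nn.LogSoftmax,
             "sparsemax": partial(LogSparsemax, k=512),
             "entmax15": partial(LogEntmax15, k=512)}

gen_func = gen_funcs["entmax15"](dim=-1)      # model_opt.generator_function in the real code
generator = nn.Sequential(nn.Linear(16, 50),  # toy dec_rnn_size=16, vocab=50
                          gen_func)

log_probs = generator(torch.randn(4, 16))
print(log_probs.exp().sum(dim=-1))   # each row sums to 1; zero-probability words show up as -inf log-probs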
79 changes: 35 additions & 44 deletions onmt/modules/global_attention.py
@@ -1,9 +1,8 @@
"""Global attention modules (Luong / Bahdanau)"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from onmt.modules.sparse_activations import sparsemax
from entmax import Sparsemax, Entmax15
from onmt.utils.misc import aeq, sequence_mask

# This class is mainly used by decoder.py for RNNs but also
@@ -77,9 +76,12 @@ def __init__(self, dim, coverage=False, attn_type="dot",
"Please select a valid attention type (got {:s}).".format(
attn_type))
self.attn_type = attn_type
assert attn_func in ["softmax", "sparsemax"], (
"Please select a valid attention function.")
self.attn_func = attn_func
attn_funcs = {"softmax": nn.Softmax,
"sparsemax": Sparsemax,
"entmax15": Entmax15}
assert attn_func in attn_funcs, \
"Unknown attention function {}".format(attn_func)
self.attn_func = attn_funcs[attn_func](dim=-1)

if self.attn_type == "general":
self.linear_in = nn.Linear(dim, dim, bias=False)
@@ -135,11 +137,11 @@ def score(self, h_t, h_s):

return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
def forward(self, query, memory_bank, memory_lengths=None, coverage=None):
"""

Args:
source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
query (FloatTensor): query vectors ``(batch, tgt_len, dim)``
memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
memory_lengths (LongTensor): the source context lengths ``(batch,)``
coverage (FloatTensor): None (not supported yet)
@@ -152,76 +154,65 @@ def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
``(tgt_len, batch, src_len)``
"""

# one step input
if source.dim() == 2:
one_step = True
source = source.unsqueeze(1)
else:
one_step = False
one_step = query.dim() == 2
if one_step:
# compute attention for only one tgt step on this call
query = query.unsqueeze(1)

batch, source_l, dim = memory_bank.size()
batch_, target_l, dim_ = source.size()
aeq(batch, batch_)
aeq(dim, dim_)
aeq(self.dim, dim)
batch, source_l, memory_dim = memory_bank.size()
query_batch, target_l, query_dim = query.size()
aeq(batch, query_batch)
aeq(self.dim, memory_dim, query_dim)
if coverage is not None:
batch_, source_l_ = coverage.size()
aeq(batch, batch_)
aeq(source_l, source_l_)
coverage_batch, coverage_l = coverage.size()
aeq(batch, coverage_batch)
aeq(source_l, coverage_l)

if coverage is not None:
cover = coverage.view(-1).unsqueeze(1)
memory_bank += self.linear_cover(cover).view_as(memory_bank)
memory_bank = torch.tanh(memory_bank)

# compute attention scores, as in Luong et al.
align = self.score(source, memory_bank)
align = self.score(query, memory_bank)

if memory_lengths is not None:
mask = sequence_mask(memory_lengths, max_len=align.size(-1))
mask = mask.unsqueeze(1) # Make it broadcastable.
align.masked_fill_(~mask, -float('inf'))

# Softmax or sparsemax to normalize attention weights
if self.attn_func == "softmax":
align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
else:
align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
n_vectors = batch * target_l
# normalize attention weights
align_vectors = self.attn_func(align.view(n_vectors, source_l))
align_vectors = align_vectors.view(batch, target_l, source_l)

# each context vector c_t is the weighted average
# over all the source hidden states
c = torch.bmm(align_vectors, memory_bank)

# concatenate
concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
concat_c = torch.cat([c, query], 2).view(n_vectors, self.dim * 2)
attn_h = self.linear_out(concat_c).view(batch, target_l, self.dim)
if self.attn_type in ["general", "dot"]:
attn_h = torch.tanh(attn_h)

if one_step:
attn_h = attn_h.squeeze(1)
align_vectors = align_vectors.squeeze(1)

# Check output sizes
batch_, dim_ = attn_h.size()
aeq(batch, batch_)
aeq(dim, dim_)
batch_, source_l_ = align_vectors.size()
aeq(batch, batch_)
aeq(source_l, source_l_)

align_batch, align_src_l = align_vectors.size()
attn_batch, attn_dim = attn_h.size()
else:
attn_h = attn_h.transpose(0, 1).contiguous()
align_vectors = align_vectors.transpose(0, 1).contiguous()
# Check output sizes
target_l_, batch_, dim_ = attn_h.size()
aeq(target_l, target_l_)
aeq(batch, batch_)
aeq(dim, dim_)
target_l_, batch_, source_l_ = align_vectors.size()
aeq(target_l, target_l_)
aeq(batch, batch_)
aeq(source_l, source_l_)
attn_l, attn_batch, attn_dim = attn_h.size()
align_tgt_l, align_batch, align_src_l = align_vectors.size()
aeq(target_l, attn_l, align_tgt_l)

aeq(batch, attn_batch, align_batch)
aeq(source_l, align_src_l)
aeq(self.dim, attn_dim)

return attn_h, align_vectors
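To illustrate the refactored forward, a small sketch (shapes and values made up, not OpenNMT-py code) of what self.attn_func now does: the normalizer is picked once in __init__ as an nn.Module, padding is hidden with -inf before normalization, and sparsemax/entmax15 send those positions, plus typically some real ones, to exactly zero weight.

import torch
import torch.nn as nn
from entmax import Sparsemax, Entmax15

attn_funcs = {"softmax": nn.Softmax, "sparsemax": Sparsemax, "entmax15": Entmax15}
attn_func = attn_funcs["entmax15"](dim=-1)    # done once in GlobalAttention.__init__

batch, target_l, source_l = 2, 3, 7
align = torch.randn(batch, target_l, source_l)              # raw attention scores
lengths = torch.tensor([7, 4])                              # true source lengths
mask = torch.arange(source_l)[None, :] < lengths[:, None]   # (batch, source_l)
align.masked_fill_(~mask.unsqueeze(1), -float('inf'))       # hide padding, as in forward()

align_vectors = attn_func(align.view(batch * target_l, source_l))
align_vectors = align_vectors.view(batch, target_l, source_l)
print(align_vectors[1, 0])   # padded positions 4, 5, 6 get exactly zero weight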
99 changes: 7 additions & 92 deletions onmt/modules/sparse_activations.py
@@ -1,97 +1,12 @@
"""
An implementation of sparsemax (Martins & Astudillo, 2016). See
:cite:`DBLP:journals/corr/MartinsA16` for detailed description.

By Ben Peters and Vlad Niculae
"""

import torch
from torch.autograd import Function
import torch.nn as nn


def _make_ix_like(input, dim=0):
d = input.size(dim)
rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
view = [1] * input.dim()
view[0] = -1
return rho.view(view).transpose(0, dim)


def _threshold_and_support(input, dim=0):
"""Sparsemax building block: compute the threshold

Args:
input: any dimension
dim: dimension along which to apply the sparsemax

Returns:
the threshold value
"""

input_srt, _ = torch.sort(input, descending=True, dim=dim)
input_cumsum = input_srt.cumsum(dim) - 1
rhos = _make_ix_like(input, dim)
support = rhos * input_srt > input_cumsum

support_size = support.sum(dim=dim).unsqueeze(dim)
tau = input_cumsum.gather(dim, support_size - 1)
tau /= support_size.to(input.dtype)
return tau, support_size


class SparsemaxFunction(Function):

@staticmethod
def forward(ctx, input, dim=0):
"""sparsemax: normalizing sparse transform (a la softmax)

Parameters:
input (Tensor): any shape
dim: dimension along which to apply sparsemax

Returns:
output (Tensor): same shape as input
"""
ctx.dim = dim
max_val, _ = input.max(dim=dim, keepdim=True)
input -= max_val # same numerical stability trick as for softmax
tau, supp_size = _threshold_and_support(input, dim=dim)
output = torch.clamp(input - tau, min=0)
ctx.save_for_backward(supp_size, output)
return output

@staticmethod
def backward(ctx, grad_output):
supp_size, output = ctx.saved_tensors
dim = ctx.dim
grad_input = grad_output.clone()
grad_input[output == 0] = 0

v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
v_hat = v_hat.unsqueeze(dim)
grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
return grad_input, None


sparsemax = SparsemaxFunction.apply


class Sparsemax(nn.Module):

def __init__(self, dim=0):
self.dim = dim
super(Sparsemax, self).__init__()

def forward(self, input):
return sparsemax(input, self.dim)
from entmax import Entmax15, Sparsemax


class LogSparsemax(nn.Module):
class LogSparsemax(Sparsemax):
def forward(self, *args, **kwargs):
return torch.log(super().forward(*args, **kwargs))

def __init__(self, dim=0):
self.dim = dim
super(LogSparsemax, self).__init__()

def forward(self, input):
return torch.log(sparsemax(input, self.dim))
class LogEntmax15(Entmax15):
def forward(self, *args, **kwargs):
return torch.log(super().forward(*args, **kwargs))
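The hand-rolled SparsemaxFunction is dropped in favor of the entmax package. A quick sanity sketch (my assumption about the package's k argument, not something the diff shows): k only swaps the full sort for a top-k partial sort, so the result should match the full computation, and the support is usually tiny compared to the vocabulary.

import torch
from entmax import Sparsemax

x = torch.randn(8, 1000)
full = Sparsemax(dim=-1)(x)
topk = Sparsemax(dim=-1, k=64)(x)        # partial sort over the 64 largest scores
print(torch.allclose(full, topk))        # expected: True
print((full > 0).sum(dim=-1))            # support sizes, typically far below 1000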
76 changes: 0 additions & 76 deletions onmt/modules/sparse_losses.py

This file was deleted.
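The probability-space losses leave OpenNMT-py as well; the entmax package ships loss classes with, as far as I can tell, the same behavior (SparsemaxLoss, Entmax15Loss), so call sites would presumably switch to something like the sketch below. Class names and defaults are taken from the entmax package, not from this diff.

import torch
from entmax import Entmax15Loss   # SparsemaxLoss is also available

logits = torch.randn(4, 10, requires_grad=True)
target = torch.tensor([1, 0, 3, 9])

loss = Entmax15Loss()(logits, target).sum()   # .sum() in case the default reduction is per-sample
loss.backward()   # per-row gradient is proportional to entmax15(logits) - one_hot(target), hence sparse
print(loss.item(), (logits.grad != 0).sum(dim=-1))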

5 changes: 3 additions & 2 deletions onmt/opts.py
@@ -136,7 +136,8 @@ def model_opts(parser):
help="The attention type to use: "
"dotprod or general (Luong) or MLP (Bahdanau)")
group.add('--global_attention_function', '-global_attention_function',
type=str, default="softmax", choices=["softmax", "sparsemax"])
type=str, default="softmax",
choices=["softmax", "sparsemax", "entmax15"])
group.add('--self_attn_type', '-self_attn_type',
type=str, default="scaled-dot",
help='Self attention type in Transformer decoder '
@@ -163,7 +164,7 @@
help="The copy attention type to use. Leave as None to use "
"the same as -global_attention.")
group.add('--generator_function', '-generator_function', default="softmax",
choices=["softmax", "sparsemax"],
choices=["softmax", "sparsemax", "entmax15"],
help="Which function to use for generating "
"probabilities over the target vocabulary (choices: "
"softmax, sparsemax)")
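With the two extra choices registered, both knobs could be switched on at training time; an illustrative invocation (preprocessing and the other required flags are omitted):

python train.py -data data/demo -save_model demo-model \
    -global_attention_function entmax15 -generator_function entmax15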