Commit

add converter
vince62s committed Jul 6, 2023
1 parent f0ed7eb commit b819f77
Showing 4 changed files with 591 additions and 15 deletions.
2 changes: 1 addition & 1 deletion onmt/encoders/transformer.py
@@ -50,7 +50,7 @@ def __init__(
heads,
d_model,
dropout=attention_dropout,
- is_decoder = False,
+ is_decoder=False,
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
attn_type="self",
54 changes: 42 additions & 12 deletions onmt/modules/multi_headed_attn.py
@@ -81,7 +81,9 @@ def gen_relative_positions(
return final_mat


- def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ def _relative_position_bucket(
+     relative_position, bidirectional=True, num_buckets=32, max_distance=128
+ ):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
@@ -107,7 +109,9 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
- relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ relative_position = -torch.min(
+     relative_position, torch.zeros_like(relative_position)
+ )
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
@@ -120,18 +124,32 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
- relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ relative_position_if_large,
+ torch.full_like(relative_position_if_large, num_buckets - 1),
)

- relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ relative_buckets += torch.where(
+     is_small, relative_position, relative_position_if_large
+ )
return relative_buckets


- def compute_bias(query_length, key_length, is_decoder, max_relative_positions, relative_positions_buckets, device=None):
+ def compute_bias(
+     query_length,
+     key_length,
+     is_decoder,
+     max_relative_positions,
+     relative_positions_buckets,
+     device=None,
+ ):
"""Compute binned relative position bias"""
- context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[
+     :, None
+ ]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
- relative_position = memory_position - context_position # shape (query_length, key_length)
+ relative_position = (
+     memory_position - context_position
+ ) # shape (query_length, key_length)
relative_position_bucket = _relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not is_decoder),
@@ -141,7 +159,6 @@ compute_bias(query_length, key_length, is_decoder, max_relative_positions, r
return relative_position_bucket



# Help functions to split model dim per head


@@ -267,7 +284,9 @@ def __init__(
{"keys": torch.tensor([]), "values": torch.tensor([])},
)
if relative_positions_buckets > 0:
- self.relative_attention_bias = nn.Embedding(relative_positions_buckets, head_count)
+ self.relative_attention_bias = nn.Embedding(
+     relative_positions_buckets, head_count
+ )
self.relative_positions_embeddings = None
elif max_relative_positions > 0:
# https://arxiv.org/pdf/1803.02155.pdf
@@ -396,9 +415,20 @@ def forward(
scores = torch.matmul(query, key.transpose(2, 3))

if self.relative_attention_bias is not None:
- relative_position_bucket = compute_bias(query.size(2), key.size(2), self.is_decoder, self.max_relative_positions, self.relative_positions_buckets, device=key.device)
- values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
- position_bias = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
+ relative_position_bucket = compute_bias(
+     query.size(2),
+     key.size(2),
+     self.is_decoder,
+     self.max_relative_positions,
+     self.relative_positions_buckets,
+     device=key.device,
+ )
+ values = self.relative_attention_bias(
+     relative_position_bucket
+ ) # shape (query_length, key_length, num_heads)
+ position_bias = values.permute([2, 0, 1]).unsqueeze(
+     0
+ ) # shape (1, num_heads, query_length, key_length)
scores.add_(position_bias)

elif self.relative_positions_embeddings is not None:
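Taken together, the changes in this file wire in T5-style relative position bias: compute_bias maps every (query, key) offset to one of relative_positions_buckets bucket ids, self.relative_attention_bias (an nn.Embedding) turns each id into one scalar per head, and the resulting (1, heads, query_length, key_length) tensor is added to the raw attention scores before the softmax. A minimal usage sketch of that flow, assuming compute_bias stays importable from onmt.modules.multi_headed_attn; the sizes and hyperparameters below are illustrative, not values taken from the converter:

# Illustrative sketch only: sizes and hyperparameters are made up, and the
# import assumes compute_bias remains a module-level function in
# onmt/modules/multi_headed_attn.py as added in this commit.
import torch
import torch.nn as nn

from onmt.modules.multi_headed_attn import compute_bias

head_count = 8
relative_positions_buckets = 32
query_length = key_length = 10

# (query_length, key_length) tensor of bucket ids, one per position offset
relative_position_bucket = compute_bias(
    query_length,
    key_length,
    is_decoder=False,                  # encoder side -> bidirectional buckets
    max_relative_positions=128,
    relative_positions_buckets=relative_positions_buckets,
    device=torch.device("cpu"),
)

# one learned scalar per (bucket, head), like self.relative_attention_bias
relative_attention_bias = nn.Embedding(relative_positions_buckets, head_count)
values = relative_attention_bias(relative_position_bucket)    # (q_len, k_len, heads)
position_bias = values.permute([2, 0, 1]).unsqueeze(0)        # (1, heads, q_len, k_len)

scores = torch.zeros(1, head_count, query_length, key_length) # stand-in attention logits
scores.add_(position_bias)                                     # bias added before softmax

Because the lookup depends only on sequence lengths, forward() recomputes the bias from query.size(2) and key.size(2), so it stays correct for any input length.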
4 changes: 2 additions & 2 deletions onmt/modules/position_ffn.py
@@ -62,11 +62,11 @@ def __init__(
self.dropout_1 = nn.Dropout(dropout)
self.activation = ACTIVATION_FUNCTIONS[activation_fn]
self.dropout_2 = nn.Dropout(dropout)
- #if activation_fn == "silu": !!!!!!!!!!!!!!!!!!!!!!!!! temporary hack for T5
+ # if activation_fn == "silu": !!!!!!!!!!!!!!!!!!!!!!!!! temporary hack for T5
self.w_3 = skip_init(
nn.Linear, in_features=d_model, out_features=d_ff, bias=add_ffnbias
)
- #else:
+ # else:
# self.w_3 = None
self.maybe_ckpt = checkpoint if "ffn" in use_ckpting else lambda f, x: f(x)

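The unconditionally created w_3 above is flagged as a temporary hack for T5. The forward pass is not part of this diff; a hedged sketch, assuming w_3 acts as the extra projection of a gated feed-forward (the SiLU-gated variant used by T5.1.1-style models), with illustrative sizes:

# Assumption, not shown in this commit: with activation_fn="silu", w_3 would
# gate the hidden state as in T5.1.1-style gated feed-forward blocks.
import torch
import torch.nn as nn

d_model, d_ff = 512, 2048
w_1 = nn.Linear(d_model, d_ff)
w_2 = nn.Linear(d_ff, d_model)
w_3 = nn.Linear(d_model, d_ff)        # the layer this commit always instantiates
activation = nn.SiLU()

x = torch.randn(2, 5, d_model)        # (batch, seq_len, d_model)
gated = activation(w_1(x)) * w_3(x)   # gated hidden state
out = w_2(gated)                      # project back to d_model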
