Support for Sliding window attention mask (Mistral architecture) (OpenNMT#2487)

* support sliding window
* trim kv cache to sliding window
* apex fusedRMSNorm
vince62s authored Oct 19, 2023
1 parent fb2efc8 commit cb35810
Showing 9 changed files with 73 additions and 42 deletions.
28 changes: 8 additions & 20 deletions eval_llm/MMLU-FR/run_mmlu_opennmt_fr.py
@@ -1,16 +1,13 @@
"""
Code taken and adapted from https://github.com/FranxYao/chain-of-thought-hub
"""

import json
import os
import time
import pandas as pd
from onmt.utils.logging import init_logger
from onmt.translate.translator import build_translator
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls
from onmt.constants import CorpusTask
from onmt.inference_engine import InferenceEnginePY
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
@@ -148,9 +145,7 @@ def evaluate(opt):
output_filename = os.path.join(dir_name, "mmlu_results_%s.fr.json" % base_name[:-3])

# Build the translator (along with the model)
translator = build_translator(opt, logger=logger, report_score=True)
# Build the transforms (along with the tokenizer)
transforms_cls = get_transforms_cls(opt._all_transform)
engine = InferenceEnginePY(opt)

data_dir = "eval_llm/MMLU-FR/data/"
ntrain = 5 # nshots from dev
@@ -180,30 +175,23 @@ def evaluate(opt):
records.append({"prompt": prompt, "answer": label})
src.append(prompt.replace("\n", "⦅newline⦆"))

infer_iter = build_dynamic_dataset_iter(
opt, transforms_cls, translator.vocabs, task=CorpusTask.INFER, src=src
)

infer_iter = IterOnDevice(infer_iter, opt.gpu)
scores, preds = engine.infer_list(src)

scores, preds = translator._translate(
infer_iter,
transform=infer_iter.transform,
attn_debug=opt.attn_debug,
align_debug=opt.align_debug,
)
pred_answers = [
x.lstrip() for sublist in preds for x in sublist
] # flatten the list of list

gold_answers = [record["answer"] for record in records]
run_results[task] = {"pred_answers": pred_answers, "gold_answers": gold_answers}

engine.terminate()

with open(output_filename, "w") as f:
json.dump(run_results, f, ensure_ascii=False, indent=2)

compute_metric(output_filename)
end_time = time.time()
print("total run time %.2f" % (end_time - start_time))
logger.info("total run time %.2f" % (end_time - start_time))


def _get_parser():
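Note: the MMLU-FR script now delegates model construction, tokenization transforms and batching to `InferenceEnginePY` instead of wiring up the translator and dataset iterator by hand. A minimal sketch of the simplified flow used above (the wrapper function and prompt list are illustrative; `opt` is assumed to be a parsed OpenNMT-py options object):

```python
# Minimal sketch of the simplified inference flow; the prompts are placeholders,
# and `opt` is assumed to be a parsed OpenNMT-py options object.
from onmt.inference_engine import InferenceEnginePY

def run_prompts(opt, prompts):
    # The engine builds the model, transforms and iterator internally.
    engine = InferenceEnginePY(opt)
    # infer_list returns per-example scores and n-best prediction lists.
    scores, preds = engine.infer_list(prompts)
    engine.terminate()  # release model / GPU resources
    # flatten the list of n-best lists, as the MMLU script does
    return [p.lstrip() for nbest in preds for p in nbest]
```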
32 changes: 25 additions & 7 deletions onmt/decoders/transformer.py
@@ -10,7 +10,11 @@
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
from onmt.modules.rmsnorm import RMSNorm

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm


class TransformerDecoderLayerBase(nn.Module):
@@ -37,6 +41,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
):
"""
Args:
@@ -118,6 +123,7 @@ def __init__(
self.dropout_p = dropout
self.full_context_alignment = full_context_alignment
self.alignment_heads = alignment_heads
self.sliding_window = sliding_window

def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass:
@@ -171,12 +177,12 @@ def _compute_dec_mask(self, tgt_pad_mask, future):
device=tgt_pad_mask.device,
dtype=torch.uint8,
)
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
# BoolTensor was introduced in pytorch 1.2
try:
future_mask = future_mask.bool()
except AttributeError:
pass
future_mask = future_mask.tril_(0)
if self.sliding_window > 0:
future_mask = future_mask.triu_(-self.sliding_window)
future_mask = future_mask.bool()
future_mask = ~future_mask.view(1, tgt_len, tgt_len)

dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
else: # only mask padding, result mask in (B, 1, T)
dec_mask = tgt_pad_mask
@@ -189,6 +195,7 @@ def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
norm_layer_in,
norm_layer_in,
mask=dec_mask,
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
)
@@ -230,6 +237,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
):
"""
Args:
@@ -257,6 +265,7 @@ def __init__(
norm_eps=norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
)
self.context_attn = MultiHeadedAttention(
heads,
@@ -414,6 +423,7 @@ def from_opt(cls, opt, embeddings):
parallel_gpu=opt.world_size
if opt.parallel_mode == "tensor_parallel"
else 1,
sliding_window=opt.sliding_window,
)

def init_state(self, src, enc_out, enc_final_hs):
@@ -507,6 +517,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
):
super(TransformerDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -536,6 +547,7 @@ def __init__(
norm_eps=norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
)
for i in range(num_layers)
]
@@ -639,6 +651,8 @@ def _init_cache(self, enc_out):
"values": torch.tensor([], device=enc_out.device),
},
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(enc_out.device)


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
@@ -749,6 +763,7 @@ def __init__(
norm_eps=1e-6,
use_ckpting=[],
parallel_gpu=1,
sliding_window=0,
):
super(TransformerLMDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -777,6 +792,7 @@ def __init__(
norm_eps=norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
)
for i in range(num_layers)
]
@@ -841,3 +857,5 @@ def _init_cache(self, tgt=None):
"values": torch.tensor([], device=tgt.device),
},
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
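Note: the new `_compute_dec_mask` path builds the causal mask with `tril_` and, when `sliding_window > 0`, additionally applies `triu_(-sliding_window)` so each query attends to itself and at most `sliding_window` positions before it; the result is then inverted so `True` marks disallowed positions. A standalone sketch mirroring that logic (function name is illustrative, not the onmt API):

```python
# Standalone sketch of the sliding-window causal mask built in _compute_dec_mask.
# True marks positions that must NOT be attended to.
import torch

def sliding_window_dec_mask(tgt_len, sliding_window=0, device="cpu"):
    future_mask = torch.ones(
        (tgt_len, tgt_len), device=device, dtype=torch.uint8
    )
    # keep the causal (lower-triangular) part ...
    future_mask = future_mask.tril_(0)
    if sliding_window > 0:
        # ... and within it, only the last `sliding_window` positions per query
        future_mask = future_mask.triu_(-sliding_window)
    future_mask = future_mask.bool()
    # invert: allowed positions become False, masked positions True
    return ~future_mask.view(1, tgt_len, tgt_len)

# e.g. with tgt_len=5 and sliding_window=2, query 4 may only see keys 2..4
print(sliding_window_dec_mask(5, sliding_window=2).int())
```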
6 changes: 5 additions & 1 deletion onmt/encoders/transformer.py
@@ -9,7 +9,11 @@
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
from onmt.modules.rmsnorm import RMSNorm

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm


class TransformerEncoderLayer(nn.Module):
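Note: the encoder, decoder and feed-forward modules now prefer Apex's `FusedRMSNorm` kernel when `apex` is importable and fall back to `onmt.modules.rmsnorm.RMSNorm` otherwise; both compute the same root-mean-square normalization. An illustrative, non-fused version of that computation, assuming the usual RMSNorm definition (not the onmt module itself):

```python
# Illustrative (non-fused) RMSNorm: scale by the reciprocal root mean square of
# the last dimension, with no mean centering. Apex's FusedRMSNorm computes the
# same thing with a fused CUDA kernel.
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)
```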
6 changes: 4 additions & 2 deletions onmt/inputters/inputter.py
@@ -110,8 +110,10 @@ def _read_vocab_file(vocab_path, min_count):
if not os.path.exists(vocab_path):
raise RuntimeError("Vocabulary not found at {}".format(vocab_path))
else:
with codecs.open(vocab_path, "rb", "utf-8") as f:
lines = [line.strip("\n") for line in f if line.strip("\n")]
with codecs.open(vocab_path, "rb") as f:
lines = [line.decode("utf-8") for line in f.read().split(b"\n")]
lines = lines[:-1]

first_line = lines[0].split(None, 1)
has_count = len(first_line) == 2 and first_line[-1].isdigit()
if has_count:
14 changes: 13 additions & 1 deletion onmt/modules/multi_headed_attn.py
@@ -363,6 +363,7 @@ def forward(
value: Tensor,
query: Tensor,
mask: Optional[Tensor] = None,
sliding_window: Optional[int] = 0,
step: Optional[int] = 0,
return_attn: Optional[bool] = False,
) -> Tuple[Tensor, Tensor]:
@@ -401,12 +402,16 @@
if self.max_relative_positions == -1: # Rotary Embeddings
start_pos = step
seqlen = query.size(2)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(query, key, rope=rope)

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value
elif self.attn_type == "context":
@@ -466,12 +471,19 @@
):
causal = self.is_decoder and self.attn_type == "self" and mask is not None
if self.is_decoder and self.attn_type == "self" and flash2:
if causal:
window_size = (
(-1, -1) if sliding_window == 0 else (sliding_window, 0)
)
else:
window_size = (-1, -1)
attn_output = self.flash_attn_func(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
dropout_p=self.dropout_p,
causal=causal,
window_size=window_size,
).transpose(1, 2)
else:
with torch.backends.cuda.sdp_kernel(
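Note: during step-wise decoding with rotary embeddings, the self-attention cache is now trimmed once it grows past the sliding window (one token is decoded per step, so dropping the oldest position keeps the cache bounded), and for the full-prompt pass the same window is forwarded to flash-attention as `window_size=(sliding_window, 0)`. A minimal standalone sketch of the rolling cache (shapes follow the `(batch, heads, seq, head_dim)` layout used in the diff; names are illustrative, not the onmt API):

```python
# Minimal sketch of a rolling KV cache bounded by the sliding window.
import torch

def update_kv_cache(cache, key, value, sliding_window=0):
    """Append this step's key/value and drop the oldest position
    once the cache grows past the sliding window."""
    if cache["keys"].numel() != 0:
        key = torch.cat((cache["keys"], key), dim=2)
        value = torch.cat((cache["values"], value), dim=2)
        if sliding_window > 0 and key.size(2) > sliding_window:
            # drop the oldest cached position (one token decoded per step)
            key = key[:, :, 1:, :]
            value = value[:, :, 1:, :]
    cache["keys"], cache["values"] = key, value
    return key, value

# toy usage: batch=1, heads=2, head_dim=4, window of 3
cache = {"keys": torch.tensor([]), "values": torch.tensor([])}
for _ in range(5):
    k, v = torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)
    update_kv_cache(cache, k, v, sliding_window=3)
print(cache["keys"].size(2))  # 3 -- the cache never exceeds the window
```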
6 changes: 5 additions & 1 deletion onmt/modules/position_ffn.py
@@ -4,7 +4,11 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from onmt.modules.rmsnorm import RMSNorm

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm
from torch.nn.utils import skip_init
import torch.distributed as dist

7 changes: 7 additions & 0 deletions onmt/opts.py
@@ -869,6 +869,13 @@ def model_opts(parser):
default=8,
help="Number of heads for transformer self-attention",
)
group.add(
"--sliding_window",
"-sliding_window",
type=int,
default=0,
help="sliding window for transformer self-attention",
)
group.add(
"--transformer_ff",
"-transformer_ff",
11 changes: 1 addition & 10 deletions onmt/utils/misc.py
@@ -74,16 +74,7 @@ def tile(x, count, dim=0):
x = x.permute(perm)
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = (
x.contiguous()
.view(batch, -1)
.transpose(0, 1)
.repeat(count, 1)
.transpose(0, 1)
.contiguous()
.view(*out_size)
)
x = x.contiguous().view(x.size(0), -1).repeat(1, count).view(*out_size)
if dim != 0:
x = x.permute(perm).contiguous()
return x
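Note: `tile` now repeats each batch row in place with `repeat(1, count)` instead of the old transpose/repeat/transpose sequence; both produce the layout where every batch element appears `count` times consecutively. A quick standalone equivalence check (a sketch, not the onmt function itself, which also handles the `dim` argument):

```python
# Sanity check that the simplified tiling matches the old
# transpose/repeat/transpose version on the batch dimension.
import torch

def tile_old(x, count):
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    return (
        x.contiguous().view(batch, -1).transpose(0, 1)
        .repeat(count, 1).transpose(0, 1)
        .contiguous().view(*out_size)
    )

def tile_new(x, count):
    out_size = list(x.size())
    out_size[0] *= count
    return x.contiguous().view(x.size(0), -1).repeat(1, count).view(*out_size)

x = torch.arange(24).view(2, 3, 4)
assert torch.equal(tile_old(x, 3), tile_new(x, 3))
```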
5 changes: 5 additions & 0 deletions tools/convert_llama.py
@@ -71,6 +71,10 @@ def __init__(self, model_path: str):
num_kv = params["n_kv_heads"]
else:
num_kv = 0
if "sliding_window" in params.keys():
sliding_window = params["sliding_window"]
else:
sliding_window = 0

for shard in range(opt.nshards):

@@ -425,6 +429,7 @@ def __init__(self, model_path: str):
self_attn_type="scaled-dot",
max_relative_positions=-1,
heads=heads,
sliding_window=sliding_window,
transformer_ff=transformer_ff,
aan_useffn=False,
add_qkvbias=False,
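Note: the converter reads `sliding_window` from the checkpoint's `params.json` when present (Mistral checkpoints define it) and defaults to 0 otherwise, then forwards it into the converted model's options. A minimal sketch of that lookup (the file path is an assumption):

```python
# Minimal sketch: read the sliding window size from a (Mistral-style) params.json,
# defaulting to 0 (disabled) for checkpoints that do not define it.
import json

def read_sliding_window(params_path="params.json"):  # path is an assumption
    with open(params_path) as f:
        params = json.load(f)
    return params.get("sliding_window", 0)
```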
