From 95b345c75cdce463b71fe1f9c85d84c239e77f10 Mon Sep 17 00:00:00 2001
From: vince62s
Date: Mon, 26 Feb 2024 18:44:42 +0100
Subject: [PATCH] v3.5 hotfix

---
 eval_llm/WIKITEXT2/readme.md      | 54 +++++++++++++++++++++++++++++++
 onmt/models/model.py              |  8 +++--
 onmt/modules/multi_headed_attn.py |  2 +-
 onmt/train_single.py              | 28 ++++++++--------
 tools/convert_llama.py            |  3 ++
 5 files changed, 78 insertions(+), 17 deletions(-)
 create mode 100644 eval_llm/WIKITEXT2/readme.md

diff --git a/eval_llm/WIKITEXT2/readme.md b/eval_llm/WIKITEXT2/readme.md
new file mode 100644
index 0000000000..ae99d65c5a
--- /dev/null
+++ b/eval_llm/WIKITEXT2/readme.md
@@ -0,0 +1,54 @@
+These are perplexities computed on wikitext2.
+
+Numbers are not comparable to lm-evaluation-harness, since it computes word / byte / bit perplexity, like this:
+
+hf-auto (pretrained=mistralai/Mistral-7B-Instruct-v0.2), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 8
+| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
+|--------|------:|------|------|---------------|-----:|---|------|
+|wikitext| 2|none |None |word_perplexity|9.8183|± |N/A |
+| | |none |None |byte_perplexity|1.5329|± |N/A |
+| | |none |None |bits_per_byte |0.6163|± |N/A |
+
+
+hf-auto (pretrained=meta-llama/Llama-2-7b-hf), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
+| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
+|--------|------:|------|------|---------------|-----:|---|------|
+|wikitext| 2|none |None |word_perplexity|8.7921|± |N/A |
+| | |none |None |byte_perplexity|1.5016|± |N/A |
+| | |none |None |bits_per_byte |0.5865|± |N/A |
+
+
+Numbers are not comparable to the perplexity reported by llama.cpp either, because we use a smaller context window, and we also detokenize the raw corpus (something they should do, but don't).
+
+| 7B Family        |                       | PPL   | Time (sec) |
+| ---------------- | --------------------- | ----- | ---------- |
+| Base             | llama2                | 5.78  | 152        |
+|                  | mistral v0.1          | 5.70  | 162        |
+|                  | awq                   | 5.81  | 165        |
+|                  | Yi-6B-200K            | 7.76  | 133        |
+|                  | xgen-7B               | 8.64  | 129        |
+|                  | mpt-7B                | 8.43  | 147        |
+|                  |                       |       |            |
+| Instruct / Tuned | llama2-chat           | 7.37  | 148        |
+|                  | mistral-instr-v0.2    | 6.98  | 160        |
+|                  | gemm-awq              | 7.07  | 164        |
+|                  | gemv-awq              | 7.07  | 237        |
+|                  |                       |       |            |
+|                  | Alma-7B-R             | 6.82  | 156        |
+|                  | TowerInstruct-7B      | 6.45  | 157        |
+|                  | codellama-7B          | 8.56  | 154        |
+|                  |                       |       |            |
+| 3B Family        | Phi-2                 | 9.74  | 52         |
+|                  | Phi-2-psy             | 10.44 | 53         |
+|                  |                       |       |            |
+| 13B Family       | llama2 (4-bit)        | 5.31  | 296        |
+|                  | llama2-chat (4-bit)   | 6.59  | 292        |
+|                  |                       |       |            |
+| 34B Family       | codellama-34B (4-bit) | 6.00  | 706        |
+
+
+We note that llama2 and Mistral are in fact very close for their base models. However, there is a gap between their chat models.
+
+All the other models score noticeably worse, which is surprising for Yi given its results on the Open LLM Leaderboard.
+
+I need to check why Mistral seems a little slower than llama2; it should be the opposite.
diff --git a/onmt/models/model.py b/onmt/models/model.py
index 40f0ce534d..c295c66720 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -157,7 +157,9 @@ def load_state_dict(
             )
             param.data = checkpoint["generator"][keyname]
             del checkpoint["generator"][keyname]
-        elif strict and "lora" not in param_name:
+        elif strict and (
+            "lora" not in param_name and "slopes" not in param_name
+        ):
             raise ValueError(
                 "Missing key in checkpoint: %s" % name + "." + param_name
             )
@@ -234,7 +236,9 @@ def load_safe_state_dict(
                 name, module, param_name, param, buf_list, ckpt_t, offset
             )
             keyfound[name + "." + param_name] = True
-        elif strict and "lora" not in param_name:
+        elif strict and (
+            "lora" not in param_name and "slopes" not in param_name
+        ):
             raise ValueError(
                 "Missing key in safetensors checkpoint: %s" % name + "."
                 + param_name
             )
diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index 1a09567e2f..21c38d2f4e 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -599,7 +599,7 @@ def forward(
                 base=self.rotary_theta,
                 device=query.device,
             )
-        rope = self.rope[start_pos : start_pos + seqlen]
+        rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
         query, key = apply_rotary_emb(
             query, key, rope, interleave=self.rotary_interleave
         )
diff --git a/onmt/train_single.py b/onmt/train_single.py
index efd76a2752..76ab3bef66 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -27,20 +27,6 @@
 
 def prepare_transforms_vocabs(opt, transforms_cls):
     """Prepare or dump transforms before training."""
-    # if transform + options set in 'valid' we need to copy in main
-    # transform / options for scoring considered as inference
-    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
-    if validset_transforms:
-        opt.transforms = validset_transforms
-    if opt.data.get("valid", {}).get("tgt_prefix", None):
-        opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
-        opt.tgt_file_prefix = True
-    if opt.data.get("valid", {}).get("src_prefix", None):
-        opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
-    if opt.data.get("valid", {}).get("tgt_suffix", None):
-        opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
-    if opt.data.get("valid", {}).get("src_suffix", None):
-        opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
     specials = get_specials(opt, transforms_cls)
 
     vocabs = build_vocab(opt, specials)
@@ -77,6 +63,20 @@ def _init_train(opt):
     """
     ArgumentParser.validate_prepare_opts(opt)
    transforms_cls = get_transforms_cls(opt._all_transform)
+    # if transforms + options are set in 'valid', we need to copy them into the
+    # main transforms / options, since scoring is considered as inference
+    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
+    if validset_transforms:
+        opt.transforms = validset_transforms
+    if opt.data.get("valid", {}).get("tgt_prefix", None):
+        opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
+        opt.tgt_file_prefix = True
+    if opt.data.get("valid", {}).get("src_prefix", None):
+        opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
+    if opt.data.get("valid", {}).get("tgt_suffix", None):
+        opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
+    if opt.data.get("valid", {}).get("src_suffix", None):
+        opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
     if opt.train_from:
         # Load checkpoint if we resume from a previous training.
         checkpoint = load_checkpoint(ckpt_path=opt.train_from)
diff --git a/tools/convert_llama.py b/tools/convert_llama.py
index 1d7571540b..26f4ee6140 100644
--- a/tools/convert_llama.py
+++ b/tools/convert_llama.py
@@ -428,6 +428,9 @@ def __init__(self, model_path: str):
             global_attention_function="softmax",
             self_attn_type="scaled-dot",
             max_relative_positions=-1,
+            rotary_interleave=True,
+            rotary_theta=10000,
+            rotary_dim=0,
             heads=heads,
             sliding_window=sliding_window,
             transformer_ff=transformer_ff,
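
A note on the metric mismatch called out in the new readme: lm-evaluation-harness normalizes the summed token negative log-likelihood by word or byte counts rather than by token counts, which is why its numbers cannot be compared with the token-level PPL in the readme table. Below is a minimal sketch of that relationship; the function names are illustrative, not the harness's actual API, and it assumes the total NLL is accumulated in nats over the whole corpus.

```python
import math


def harness_style_metrics(total_nll_nats: float, corpus: str) -> dict:
    """Word/byte-level metrics in the style reported by lm-evaluation-harness."""
    n_words = len(corpus.split())
    n_bytes = len(corpus.encode("utf-8"))
    return {
        # exp(NLL / word count) -> word_perplexity
        "word_perplexity": math.exp(total_nll_nats / n_words),
        # exp(NLL / byte count) -> byte_perplexity
        "byte_perplexity": math.exp(total_nll_nats / n_bytes),
        # NLL in bits, per byte; equal to log2(byte_perplexity)
        "bits_per_byte": total_nll_nats / (n_bytes * math.log(2)),
    }


def token_perplexity(total_nll_nats: float, n_tokens: int) -> float:
    """Token-level PPL, as in the readme table. The token count depends on
    the tokenizer, so this is not comparable across vocabularies or to the
    word/byte metrics above."""
    return math.exp(total_nll_nats / n_tokens)
```

As a sanity check against the tables above, bits_per_byte is just log2 of byte_perplexity: log2(1.5329) ≈ 0.6163 for Mistral-7B-Instruct-v0.2 and log2(1.5016) ≈ 0.5865 for Llama-2-7b, matching the harness output quoted in the readme.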