Skip to content

Commit

Permalink
v3.5 hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Feb 26, 2024
1 parent b9a60d6 commit 95b345c
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 17 deletions.
54 changes: 54 additions & 0 deletions eval_llm/WIKITEXT2/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
These are perplexity computed on wikitext2.

Numbers are not comparable to lm-evaluation-harness since they compute word / byte / bit perplexity like this:

hf-auto (pretrained=mistralai/Mistral-7B-Instruct-v0.2), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 8
| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
|--------|------:|------|------|---------------|-----:|---|------|
|wikitext| 2|none |None |word_perplexity|9.8183|± |N/A |
| | |none |None |byte_perplexity|1.5329|± |N/A |
| | |none |None |bits_per_byte |0.6163|± |N/A |


hf-auto (pretrained=meta-llama/Llama-2-7b-hf), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
|--------|------:|------|------|---------------|-----:|---|------|
|wikitext| 2|none |None |word_perplexity|8.7921|± |N/A |
| | |none |None |byte_perplexity|1.5016|± |N/A |
| | |none |None |bits_per_byte |0.5865|± |N/A |


Numbers are not comparable to perplexity reported by llama.cpp because we use a smaller context window but also we detokenize the raw corpus (thing that they shoudl do but they don't)

| 7B Family | | PPL | Time (sec) |
| ---------------- | --------------------- | ----- | ---------- |
| Base | llama2 | 5.78 | 152 |
| | mistral v0.1 | 5.70 | 162 |
| | awq | 5.81 | 165 |
| | Yi-6B-200K | 7.76 | 133 |
| | xgen-7B | 8.64 | 129 |
| | mpt-7B | 8.43 | 147 |
| | | | |
| Instruct / Tuned | llama2-chat | 7.37 | 148 |
| | mistral-instr-v0.2 | 6.98 | 160 |
| | gemm-awq | 7.07 | 164 |
| | gemv-awq | 7.07 | 237 |
| | | | |
| | Alma-7B-R | 6.82 | 156 |
| | TowerInstruct-7B | 6.45 | 157 |
| | codellama-7B | 8.56 | 154 |
| | | | |
| 3B Family | Phi-2 | 9.74 | 52 |
| | Phi-2-psy | 10.44 | 53 |
| | | | |
| 13B Family | llama2 (4-bit) | 5.31 | 296 |
| | llama2-chat (4-bit) | 6.59 | 292 |
| | | | |
| 34B Family | codellama-34B (4-bit) | 6.00 | 706 |


We note that llama2 and Mistral are in fact very close for their base model. However there is a shift between their chat model.

All others are quite below which is surprising for Yi given their results on the Open llm leaderboard.

I need to check why Mistral seems a little slower than llama2, it should be the opposite.
8 changes: 6 additions & 2 deletions onmt/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def load_state_dict(
)
param.data = checkpoint["generator"][keyname]
del checkpoint["generator"][keyname]
elif strict and "lora" not in param_name:
elif strict and (
"lora" not in param_name and "slopes" not in param_name
):
raise ValueError(
"Missing key in checkpoint: %s" % name + "." + param_name
)
Expand Down Expand Up @@ -234,7 +236,9 @@ def load_safe_state_dict(
name, module, param_name, param, buf_list, ckpt_t, offset
)
keyfound[name + "." + param_name] = True
elif strict and "lora" not in param_name:
elif strict and (
"lora" not in param_name and "slopes" not in param_name
):
raise ValueError(
"Missing key in safetensors checkpoint: %s" % name
+ "."
Expand Down
2 changes: 1 addition & 1 deletion onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def forward(
base=self.rotary_theta,
device=query.device,
)
rope = self.rope[start_pos : start_pos + seqlen]
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)
Expand Down
28 changes: 14 additions & 14 deletions onmt/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,6 @@

def prepare_transforms_vocabs(opt, transforms_cls):
"""Prepare or dump transforms before training."""
# if transform + options set in 'valid' we need to copy in main
# transform / options for scoring considered as inference
validset_transforms = opt.data.get("valid", {}).get("transforms", None)
if validset_transforms:
opt.transforms = validset_transforms
if opt.data.get("valid", {}).get("tgt_prefix", None):
opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
opt.tgt_file_prefix = True
if opt.data.get("valid", {}).get("src_prefix", None):
opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
if opt.data.get("valid", {}).get("tgt_suffix", None):
opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
if opt.data.get("valid", {}).get("src_suffix", None):
opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
specials = get_specials(opt, transforms_cls)

vocabs = build_vocab(opt, specials)
Expand Down Expand Up @@ -77,6 +63,20 @@ def _init_train(opt):
"""
ArgumentParser.validate_prepare_opts(opt)
transforms_cls = get_transforms_cls(opt._all_transform)
# if transform + options set in 'valid' we need to copy in main
# transform / options for scoring considered as inference
validset_transforms = opt.data.get("valid", {}).get("transforms", None)
if validset_transforms:
opt.transforms = validset_transforms
if opt.data.get("valid", {}).get("tgt_prefix", None):
opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
opt.tgt_file_prefix = True
if opt.data.get("valid", {}).get("src_prefix", None):
opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
if opt.data.get("valid", {}).get("tgt_suffix", None):
opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
if opt.data.get("valid", {}).get("src_suffix", None):
opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
if opt.train_from:
# Load checkpoint if we resume from a previous training.
checkpoint = load_checkpoint(ckpt_path=opt.train_from)
Expand Down
3 changes: 3 additions & 0 deletions tools/convert_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ def __init__(self, model_path: str):
global_attention_function="softmax",
self_attn_type="scaled-dot",
max_relative_positions=-1,
rotary_interleave=True,
rotary_theta=10000,
rotary_dim=0,
heads=heads,
sliding_window=sliding_window,
transformer_ff=transformer_ff,
Expand Down

0 comments on commit 95b345c

Please sign in to comment.