support llama3-70b, falcon-180b, mixtral-8x22b.
suisiyuan committed Jul 23, 2024
1 parent 5c319c6 commit 2ffa5be
Showing 33 changed files with 4,702 additions and 1,699 deletions.
8 changes: 3 additions & 5 deletions byte_infer_perf/llm_perf/README.md
@@ -54,10 +54,8 @@ Vendors can refer to this document for guidance on building backend: [Byte LLM P
## Models
The following models are planned to be supported:
* [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
* [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
* [shenzhi-wang/Llama3-70B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-70B-Chinese-Chat)
* [tiiuae/falcon-180B](https://huggingface.co/tiiuae/falcon-180B)
- test_accuracy is unavailable temporarily.
* [mistralai/Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1)

The following models are outdated and will be removed in future versions:
* [hfl/chinese-llama-2-13b](https://huggingface.co/hfl/chinese-llama-2-13b)

68 changes: 68 additions & 0 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -1,6 +1,7 @@
import os
import time
from multiprocessing import Queue
from typing import List

import torch
import torch.nn as nn
@@ -9,6 +10,61 @@
from llm_perf.core.mp_engine import CoreMpEngine
from llm_perf.utils.logger import logger



# context:
# input_ids: [1, s_q]
# attention_mask = [1, s_q]
# full_attention_mask = [1, 1, s_q, s_kv] (s_q == s_kv)
def get_context_masks(
input_ids : torch.Tensor,
padding_mask : torch.Tensor
):
# input_ids: [1, q_len]
# padding_mask = [1, q_len]
_, q_len = input_ids.shape

# [1, q_len, q_len]
full_attention_mask = torch.ones(
1, q_len, q_len,
device=input_ids.device
)
    full_attention_mask.tril_()                                   # keep causal lower triangle
    full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)  # zero out padded key columns
    full_attention_mask -= padding_mask.unsqueeze(-1) - 1         # padded query rows become all-visible
    full_attention_mask = (full_attention_mask < 0.5).bool()      # True marks masked positions
    full_attention_mask.unsqueeze_(1)                             # [1, 1, q_len, q_len]
return full_attention_mask
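
# A minimal sanity check (hypothetical, not part of this commit): with q_len = 3
# and no padding, the returned bool mask has shape [1, 1, 3, 3] and is True
# strictly above the diagonal (masked) and False on/below it (attended):
#   ids = torch.ones(1, 3, dtype=torch.long)
#   pad = torch.ones(1, 3)
#   mask = get_context_masks(ids, pad)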


# decode
# input_ids: [bs, 1]
# attention_mask = [bs, 1]
# full_attention_mask = [bs, 1, 1, s_kv]
def get_decode_masks(
input_ids : torch.Tensor,
all_kv_len: List[int]
):
# input_ids: [batch_size, 1]
# padding_mask: [batch_size, 1 + max_kv_len]
batch_size, q_len = input_ids.shape
max_qkv_len = q_len + max(all_kv_len)

# [batch_size, 1, max_qkv_len]
padding_mask = []
for i in range(batch_size):
cur_qkv_len = q_len + all_kv_len[i]
mask_per_batch = [1] * cur_qkv_len + [0] * (max_qkv_len - cur_qkv_len)
padding_mask.append(mask_per_batch)
full_attention_mask = torch.tensor(
padding_mask,
device=input_ids.device
).unsqueeze_(1)
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
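
# A minimal sanity check (hypothetical, not part of this commit): two requests
# with 2 and 4 cached tokens each decode one new token, so rows are padded to
# max_qkv_len = 5 and the mask has shape [2, 1, 1, 5]; True marks the padded
# tail of the shorter row:
#   ids = torch.ones(2, 1, dtype=torch.long)
#   mask = get_decode_masks(ids, all_kv_len=[2, 4])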


class GpuMpEngine(CoreMpEngine):
def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None:
super().__init__(world_size, model_impl, xpu_cfg)
@@ -25,6 +81,18 @@ def build_inputs(self, forward_inputs):
forward_inputs["attention_mask"] = torch.tensor(
forward_inputs["attention_mask"]
).cuda()

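        # Context (prefill) consumes the full prompt and needs a causal
        # [1, 1, s_q, s_q] mask; decode appends one token per request and needs
        # a [bs, 1, 1, 1 + max_kv_len] mask padded to the longest KV cache in
        # the batch.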
is_context = forward_inputs["is_context"]
if is_context:
forward_inputs["full_attention_mask"] = get_context_masks(
forward_inputs["input_ids"],
forward_inputs["attention_mask"]
)
else:
forward_inputs["full_attention_mask"] = get_decode_masks(
forward_inputs["input_ids"],
forward_inputs["all_kv_len"]
)
return forward_inputs


6 changes: 5 additions & 1 deletion byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py
@@ -12,11 +12,15 @@
import torch.nn as nn

from .gpu_chatglm2 import GPUChatGLM2
from .gpu_llama3 import GPULlama
from .gpu_falcon import GPUFalcon
from .gpu_mixtral import GPUMixtral

from llm_perf.utils.logger import logger

__all__ = {
"chatglm2": GPUChatGLM2,
"falcon": GPUFalcon
"llama3": GPULlama,
"falcon": GPUFalcon,
"mixtral": GPUMixtral
}
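
# A hedged usage sketch (the lookup site is not shown in this diff): __all__
# doubles as a name -> class registry, so a loader can presumably resolve a
# backend class by model name, e.g.:
#   model_cls = __all__["llama3"]   # -> GPULlama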