Merge pull request bytedance#105 from bytedance/jzs/fix_llm_perf
add single_query and bench_model scripts to debug and perf llm models.
suisiyuan authored Sep 25, 2024
2 parents c5d85fe + 15299f9 commit 1113caa
Showing 27 changed files with 1,409 additions and 792 deletions.
27 changes: 26 additions & 1 deletion byte_infer_perf/llm_perf/README.md
@@ -12,7 +12,7 @@ pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip3 install -r requirements.txt
```

## Quick Start
## Quick Start (run accuracy and performance tests)
Please be sure to complete the installation steps before proceeding with the following steps:
1. Modify task workload, for example, [chatglm2-torch-fp16-6b.json](https://github.com/bytedance/ByteMLPerf/blob/main/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json)
2. Download model weights using prepare_model.sh or huggingface-cli.
@@ -24,6 +24,31 @@ You can run the following command to automate all steps with the chatglm2 model on the GPU backend:
python3 byte_infer_perf/llm_perf/launch.py --hardware_type GPU --task chatglm2-torch-fp16-6b
```

## Test accuracy (single query with a specified prompt)
Launch a server running mixtral-8x22b (tp_size=8, max_batch_size=8) with the following command:
```shell
cd byte_infer_perf/llm_perf
python3 ./server/launch_server.py --hardware_type GPU --model_config ./model_zoo/mixtral-torch-bf16-8x22b.json --tp_size 8 --max_batch_size 8
```

Test the server with a single prompt; this produces the inference result, a logits numpy file, and the model forward time. Output files will be located in `./reports/single_query/`.
```shell
python3 ./script/single_query.py --prompt "What is 7 multiplied by 7?" --batch_size 8
```
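
To sanity-check accuracy, the saved logits can be compared against another run or backend with numpy. This is a minimal sketch; the `.npy` file names under `./reports/single_query/` are assumptions, so check what the script actually writes:
```python
import numpy as np

# File names below are illustrative assumptions -- check what single_query.py
# actually writes under ./reports/single_query/.
logits_a = np.load("reports/single_query/logits_gpu.npy")
logits_b = np.load("reports/single_query/logits_reference.npy")

# Loose tolerances are appropriate when comparing fp16/bf16 runs across backends.
diff = np.abs(logits_a.astype(np.float64) - logits_b.astype(np.float64))
print("max abs diff:", diff.max())
print("allclose:", np.allclose(logits_a, logits_b, rtol=1e-2, atol=1e-2))
```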

## Test model_impl forward performance
This test only needs to instantiate an MpEngine running mixtral-8x22b (tp_size=8, max_batch_size=8) and feed it proper inputs. Running the following command produces the performance outputs. Currently you can modify the test cases in `./bench_model.py`.
```shell
python3 ./bench_model.py --hardware_type GPU --model_config ./model_zoo/mixtral-torch-bf16-8x22b.json --tp_size 8 --max_batch_size 8
```

The output will be located in `./reports/{hardware_type}/{model_config}/bench_model`:
- **config.json**: perf config
- **context_perf.csv**: prefill latency for each specified {batch_size, seq_len}
- **decode_perf.csv**: decode latency for each specified {batch_size, seq_len}
- **output.txt**: raw latency data
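
The CSVs above can be post-processed directly, for example with pandas. A minimal sketch; the report directory follows the layout above, and the column names (batch_size, seq_len, latency) are assumptions, so inspect the CSV headers first:
```python
import pandas as pd

# Column names used below are assumptions -- check the headers that
# bench_model actually writes before relying on them.
report_dir = "reports/GPU/mixtral-torch-bf16-8x22b/bench_model"
context = pd.read_csv(f"{report_dir}/context_perf.csv")
decode = pd.read_csv(f"{report_dir}/decode_perf.csv")

# Prefill latency versus batch size, and decode latency versus sequence length.
print(context.groupby("batch_size")["latency"].mean())
print(decode.groupby("seq_len")["latency"].mean())
```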


## Demo Project
[GPU Backend](https://github.com/bytedance/ByteMLPerf/tree/main/byte_infer_perf/llm_perf/backends/GPU) provides a demo project that implements LLM inference of chatglm2-6b on A100 with the following features:
- Separate functional components:
65 changes: 59 additions & 6 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -1,5 +1,8 @@
import os
import sys
import time
import signal
import pathlib
from multiprocessing import Queue
from typing import List

@@ -11,7 +14,6 @@
from llm_perf.utils.logger import logger



# context:
# input_ids: [1, s_q]
# attention_mask = [1, s_q]
@@ -65,6 +67,13 @@ def get_decode_masks(
return full_attention_mask



# basic TP (tensor parallel) mp engine
# 1. the main process sends the same inputs to every subprocess
# 2. subprocesses run the same logic simultaneously, collaborating via the TP mechanism
# 3. e.g. with tp = 8, ranks 0-7 receive the same data and each computes its own part,
#    gathering partial results with allreduce or allgather;
#    rank 0 then sends the final data back to the main process
class GpuMpEngine(CoreMpEngine):
def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None:
super().__init__(world_size, model_impl, xpu_cfg)
@@ -117,29 +126,73 @@ def mp_loop_worker(

# create and init model based on model_impl and xpu_config
model = model_impl(xpu_config)
model.init_inference()
if hasattr(model, 'init_inference'):
model.init_inference()

def signal_handler(signum, frame):
logger.info(f"rank {local_rank} received signal {signum}, exiting...")
if hasattr(model, 'finalize_inference'):
model.finalize_inference()
os._exit(0)

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

# current rank is ready
output_queue.put("ready")
output_queue.put("ready", block=True)
logger.info(f"{local_rank}/{world_size} rank is ready")

# model process loop
while True:
(
forward_inputs,
) = input_queue.get(block=True)

log = forward_inputs.get("log", False)
workspace = forward_inputs.get("workspace", None)

forward_inputs["log_file"] = None
if log and workspace is not None:
workspace_dir = workspace / f"rank_{local_rank}"
workspace_dir.mkdir(exist_ok=True, parents=True)
forward_inputs["log_file"] = open(workspace_dir / "run.log", "w")


inputs_dict = self.build_inputs(forward_inputs)
start_time = time.perf_counter_ns()

output_dict = model.forward(inputs_dict)

torch.cuda.synchronize()
end_time = time.perf_counter_ns()
duration_ms = round((end_time - start_time) / 1e6, 3)
output_dict["duration_ms"] = duration_ms

# TP: only rank 0 sends the result back to the main process
if local_rank == 0:
output_queue.put(output_dict)

if log and workspace is not None:
forward_inputs["log_file"].close()

except Exception as e:
logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}")
output_queue.put(RuntimeError("[BUG] fatal exception in model subprocess"))


def mp_forward(self, *args):
for i in range(self.world_size):
self._input_queues.put(args, True)
return self._output_queues.get(True)
# extra args
# workspace: pathlib.Path, where to save files for each rank
# log: bool, whether to save logs to file
# override_hidden_states: bool, whether to override hidden_states
# random_seed: int, random seed for torch.manual_seed

# send inputs to all subprocesses
for _ in range(self.world_size):
self._input_queues.put(args, block=True)

# wait for one subprocess to send its result back to the main process
output_dict = self._output_queues.get(block=True)

return output_dict
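
The queue protocol described in the engine's comments (the main process broadcasts the same inputs to every rank; only rank 0 returns a result) can be illustrated with plain `multiprocessing`. This is an illustrative sketch, not the actual `GpuMpEngine`/`CoreMpEngine` API; in particular it uses one input queue per rank, which may differ from the real engine's queue layout:
```python
import multiprocessing as mp

def worker(rank, in_q, out_q):
    out_q.put("ready")                      # signal readiness, like the real engine
    while True:
        forward_inputs = in_q.get(block=True)
        # stand-in for build_inputs(...), model.forward(...) and the TP allreduce
        result = {"rank": rank, "echo": forward_inputs}
        if rank == 0:                       # only rank 0 reports back
            out_q.put(result)

def mp_forward(in_qs, out_q, forward_inputs):
    for q in in_qs:                         # send the same inputs to every rank
        q.put(forward_inputs, block=True)
    return out_q.get(block=True)            # wait for rank 0's result

if __name__ == "__main__":
    world_size = 2
    in_qs = [mp.Queue() for _ in range(world_size)]
    out_q = mp.Queue()
    procs = [mp.Process(target=worker, args=(r, in_qs[r], out_q), daemon=True)
             for r in range(world_size)]
    for p in procs:
        p.start()
    for _ in range(world_size):             # wait until every rank reports ready
        assert out_q.get(block=True) == "ready"
    print(mp_forward(in_qs, out_q, {"input_ids": [1, 2, 3]}))
```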

(next changed file; path not shown)
@@ -180,7 +180,12 @@ def init_inference(self):

logger.info(f"cuda model {self.model_path} loaded {self.transformer_model}")

if self.mp_size > 1:
dist.barrier()

def finalize_inference(self):
if self.mp_size > 1 and dist.is_initialized():
dist.destroy_process_group()

def load_weight(self, ckpt_path):
p_loader = GPUChatGLM2Loader(
@@ -192,8 +197,6 @@ def load_weight(self, ckpt_path):
p_loader.infusion_to_model()




def init_kvcache(self, dtype):
max_seq_len = 4096
max_batch_size = self.xpu_cfg["max_batch_size"]
(next changed file; path not shown)
@@ -123,6 +123,9 @@ def init_inference(self):
dist.barrier()


def finalize_inference(self):
if self.mp_size > 1 and dist.is_initialized():
dist.destroy_process_group()

def load_weight(self, ckpt_path):
p_loader = GPUFalconLoader(self.transformer_model, self.falcon_config, ckpt_path)
(next changed file; path not shown)
@@ -122,6 +122,10 @@ def init_inference(self):
if self.mp_size > 1:
dist.barrier()

def finalize_inference(self):
if self.mp_size > 1 and dist.is_initialized():
dist.destroy_process_group()

def load_weight(self, ckpt_path):
p_loader = GPULlamaLoader(self.transformer_model, self.llama_config, ckpt_path)
p_loader.parallel_loader()
(next changed file; path not shown)
@@ -124,6 +124,10 @@ def init_inference(self):
if self.mp_size > 1:
dist.barrier()

def finalize_inference(self):
if self.mp_size > 1 and dist.is_initialized():
dist.destroy_process_group()

def load_weight(self, ckpt_path):
p_loader = GPUMixtralLoader(self.transformer_model, self.mixtral_config, ckpt_path)
p_loader.parallel_loader()
@@ -143,8 +147,9 @@ def init_kvcache(self, dtype):
past_key_values = ()
for i in range(num_layers):
kv_shape = (max_batch_size, kv_head_num // self.mp_size, max_seq_len, head_dim)
key_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device)
value_cache = torch.empty(kv_shape, dtype=dtype, device=cur_device)
torch.manual_seed(1)
key_cache = torch.randn(size=kv_shape, dtype=torch.float32, device="cpu").to(dtype=dtype).to(device=cur_device)
value_cache = torch.randn(size=kv_shape, dtype=torch.float32, device="cpu").to(dtype=dtype).to(device=cur_device)
past_key_values += ((key_cache, value_cache),)
return past_key_values

(next changed file; path not shown)
@@ -894,7 +894,11 @@ def __init__(self, config):
# Jitter parameters
self.jitter_noise = config.router_jitter_noise

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
**kwargs
) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
@@ -918,11 +922,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

log_file = kwargs.get("log_file", None)
non_zero_num = 0
tokens_list = []

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

if top_x.shape[0] > 0:
non_zero_num += 1
tokens_list.append(top_x.shape[0])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
@@ -932,6 +944,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

if log_file is not None:
print(f"num_enabled_experts={non_zero_num}, tokens_distribution={tokens_list}", file=log_file, flush=True)

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

@@ -1005,7 +1021,7 @@ def forward(
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states, **kwargs)
if self.mp_size > 1:
dist.all_reduce(hidden_states)
hidden_states = residual + hidden_states
@@ -1177,21 +1193,43 @@ def forward(
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, MoeModelOutputWithPast]:

hidden_states = self.embed_tokens(input_ids)

for decoder_layer in self.layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
output_router_logits=False,
use_cache=False,
**kwargs,
)
hidden_states = layer_outputs[0]
if kwargs.pop("override_hidden_states", False):
random_seed = kwargs.pop("random_seed", 1)
layer_index = kwargs.pop("fixed_layer_index", -1)
layer_index = layer_index % len(self.layers)

torch.manual_seed(random_seed)
hidden_states = torch.randn(
size=hidden_states.shape,
dtype=hidden_states.dtype,
device="cpu"
).to(device=hidden_states.device)

for _ in self.layers:
layer_outputs = self.layers[layer_index](
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
output_router_logits=False,
use_cache=False,
**kwargs,
)
else:
for decoder_layer in self.layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
output_router_logits=False,
use_cache=False,
**kwargs,
)
hidden_states = layer_outputs[0]

hidden_states = self.norm(hidden_states)

28 changes: 28 additions & 0 deletions byte_infer_perf/llm_perf/backends/GPU/setup.py
@@ -1,12 +1,40 @@
import torch
import importlib
from typing import Any, Dict

from llm_perf.core.scheduler import CoreScheduler
from llm_perf.backends.GPU.gpu_inferencer import GpuInferencer
from llm_perf.backends.GPU.gpu_sampler import GpuSampler
from llm_perf.backends.GPU.gpu_scheduler import GpuScheduler
from llm_perf.backends.GPU.gpu_mp_engine import GpuMpEngine
from llm_perf.utils.logger import logger

def get_device_name():
return torch.cuda.get_device_name(0)


def get_engine(xpu_cfg) -> GpuMpEngine:
# get model impl
hardware_type = xpu_cfg["hardware_type"]
model_config = xpu_cfg["model_config"]
model_name = model_config["model_name"]

vendor_model_path = f"llm_perf/backends/{hardware_type}/model_impl"
vendor_model_impl = importlib.import_module(
".", package=vendor_model_path.replace("/", ".")
)
vendor_model = vendor_model_impl.__all__[model_name]

# create mp engine
mp_engine = GpuMpEngine(
world_size=xpu_cfg["tp_size"],
model_impl=vendor_model,
xpu_cfg=xpu_cfg
)

return mp_engine


def setup_scheduler(xpu_cfg) -> CoreScheduler:

# get model impl
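
For reference, a hypothetical call into the new `get_engine` helper above; the `xpu_cfg` keys are inferred from this snippet and the README commands, not from a documented schema, and the real engine will likely need more fields than shown:
```python
from llm_perf.backends.GPU.setup import get_engine

# xpu_cfg keys are inferred from get_engine (hardware_type, tp_size,
# model_config["model_name"]) and the README (max_batch_size);
# the model_name value is an assumption -- use whatever the real config defines.
xpu_cfg = {
    "hardware_type": "GPU",
    "tp_size": 8,
    "max_batch_size": 8,
    "model_config": {"model_name": "mixtral-torch-bf16-8x22b"},
}

engine = get_engine(xpu_cfg)   # builds a GpuMpEngine with world_size == tp_size
```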
(remaining changed files not shown)
