Skip to content

Commit

Permalink
[llm_perf] add tp, kvcache, seperate schedule support for chatglm2 mo…
Browse files Browse the repository at this point in the history
…del on GPU backend.
  • Loading branch information
suisiyuan committed Jun 18, 2024
1 parent 426191d commit 161ecd7
Show file tree
Hide file tree
Showing 21 changed files with 1,196 additions and 608 deletions.
2 changes: 1 addition & 1 deletion byte_infer_perf/llm_perf/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Byte LLM Perf

Vendors can refer to this document for guidance on building backend: [Byte LLM Perf](https://bytedance.larkoffice.com/docx/ZoU7dkPXYoKtJtxlrRMcNGMwnTc)
Vendors can refer to this document for guidance on building backend: [Byte LLM Perf](https://bytemlperf.ai/zh/guide/inference_llm_vendor.html)

## Requirements
* Python >= 3.8
Expand Down
46 changes: 46 additions & 0 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_ckpt_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
import torch.distributed as dist

from llm_perf.core.ckpt_loader import CoreCkptLoader

class GpuCkptLoader(CoreCkptLoader):
def __init__(
self,
prefix, model,
mp_size=1, mp_rank=0,
ckpt_path: str=""
):
super().__init__(prefix, model, mp_size, mp_rank, ckpt_path)


def weight_to_device(self, weight : torch.Tensor, non_blocking=False):
if self.mp_rank == 0:
weight = weight.cuda(non_blocking=non_blocking)
else:
cur_device = torch.cuda.current_device()
weight = torch.empty_like(weight, device=f"cuda:{cur_device}")
return weight

def broadcast_weight(self, key, device='cpu', non_blocking=False):
weight = self.weight_to_device(self.state_dict[key])
dist.broadcast(weight, src=0)
dist.barrier()
self.state_dict[key] = weight.to(device, non_blocking=non_blocking)

def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu', non_blocking=False):
self.broadcast_weight(key, 'cuda')
weight = self.state_dict[key]

if split_mode == 'default':
weight_split = self.split(weight, dim)
elif split_mode == 'with_outter':
weight_split = self.with_outter_split(weight, dim, outter)
elif split_mode == 'split_outter':
weight_split = self.split(weight, dim, outter)
else:
assert False, f"unknown split mode {split_mode}"


weight_split = [x.contiguous() for x in weight_split]
weight = weight_split[self.mp_rank].clone()
self.state_dict[key] = weight.to(device, non_blocking=non_blocking)
139 changes: 98 additions & 41 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,105 @@ def __init__(self, model_impl, xpu_cfg):

self.tp_size = xpu_cfg["tp_size"]
self.pad_token_id = xpu_cfg["pad_token_id"]
self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg)

def prepare_inputs(self, tasks: List["CoreInferencer.Task"]):
all_input_ids = []
all_position_ids = []
all_attention_mask = []

max_seq_len = -1
for task in tasks:
cur_id_len = len(task.request.input_ids) + len(task.generate_ids)
max_seq_len = cur_id_len if cur_id_len > max_seq_len else max_seq_len

for task in tasks:
cur_id_len = len(task.request.input_ids) + len(task.generate_ids)
pad_len = max_seq_len - cur_id_len
# using left padding
input_ids = (
[self.pad_token_id] * pad_len +
task.request.input_ids +
task.generate_ids
)
pos_ids = (
[i for i in range(max_seq_len)]
)
attention_mask = (
[0] * pad_len +
[1] * cur_id_len
)
all_input_ids.append(input_ids)
all_position_ids.append(pos_ids)
all_attention_mask.append(attention_mask)

# create model_inputs
model_inputs = {}
model_inputs["input_ids"] = all_input_ids
model_inputs["position_ids"] = all_position_ids
model_inputs["attention_mask"] = all_attention_mask
self.max_batch_size = xpu_cfg["max_batch_size"]
self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg)

def prepare_inputs(
self,
tasks: List[CoreInferencer.Task],
**kwargs
):
input_dict = {
"input_ids": None,
"position_ids": None,
"attention_mask": None,
"all_q_len": None,
"all_kv_len": None,
"is_context": None,
"valid_slot_ids": None
}

is_context = kwargs.get("is_context") if "is_context" in kwargs.keys() else False
valid_slot_ids = kwargs.get("valid_slot_ids") if "valid_slot_ids" in kwargs.keys() else [i for i in range(self.max_batch_size)]

input_dict["is_context"] = is_context
input_dict["valid_slot_ids"] = valid_slot_ids

if is_context:
q_len = len(tasks[0].request.input_ids)
kv_len = len(tasks[0].request.input_ids)

input_dict["input_ids"] = [
tasks[0].request.input_ids
]
input_dict["position_ids"] = [
[i for i in range(q_len)]
]
input_dict["attention_mask"] = [
[1 for _ in range(q_len)]
]
input_dict["all_q_len"] = [
q_len
]
input_dict["all_kv_len"] = [
kv_len
]
else:
all_input_ids = []
all_position_ids = []
all_attention_mask = []
all_q_len = []
all_kv_len = []

for task in tasks:
q_len = 1
kv_len = 0

return model_inputs
if task is None:
kv_len = 1

input_ids = [
self.pad_token_id
]
position_ids = [
0
]
attention_mask = [
0
]
else:
kv_len = len(task.request.input_ids) + len(task.generate_ids) - 1

def infer(self, tasks: List["CoreInferencer.Task"]):
input_dict = self.prepare_inputs(tasks)
input_ids = [
task.generate_ids[-1]
]
position_ids = [
kv_len
]
attention_mask = [
1
]
all_input_ids.append(input_ids)
all_position_ids.append(position_ids)
all_attention_mask.append(attention_mask)
all_q_len.append(q_len)
all_kv_len.append(kv_len)

input_dict["input_ids"] = all_input_ids
input_dict["position_ids"] = all_position_ids
input_dict["attention_mask"] = all_attention_mask
input_dict["all_q_len"] = all_q_len
input_dict["all_kv_len"] = all_kv_len

return input_dict


def infer(
self,
tasks: List[CoreInferencer.Task],
**kwargs
):
input_dict = self.prepare_inputs(tasks, **kwargs)
outputs = self.mp_engine.mp_forward(input_dict)

input_logits = outputs.logits[..., :-1, :].contiguous()
Expand All @@ -68,4 +124,5 @@ def infer(self, tasks: List["CoreInferencer.Task"]):
return {
"input_logits": input_logits,
"last_logits": next_tokens_logits,
}
}

27 changes: 16 additions & 11 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import time
from multiprocessing import Queue

import torch
import torch.nn as nn
import torch.distributed as dist

from llm_perf.core.mp_engine import CoreMpEngine
from llm_perf.utils.logger import logger
Expand All @@ -13,11 +15,19 @@ def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None:


def build_inputs(self, forward_inputs):
forward_inputs["input_ids"] = torch.tensor(forward_inputs["input_ids"]).cuda()
forward_inputs["position_ids"] = torch.tensor(forward_inputs["position_ids"]).cuda()
forward_inputs["attention_mask"] = torch.tensor(forward_inputs["attention_mask"]).cuda()
# list --> torch.Tensor --> cuda
forward_inputs["input_ids"] = torch.tensor(
forward_inputs["input_ids"]
).cuda()
forward_inputs["position_ids"] = torch.tensor(
forward_inputs["position_ids"]
).cuda()
forward_inputs["attention_mask"] = torch.tensor(
forward_inputs["attention_mask"]
).cuda()
return forward_inputs



@torch.no_grad()
def mp_loop_worker(
self,
Expand All @@ -36,13 +46,10 @@ def mp_loop_worker(
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


# set device
torch.cuda.set_device(local_rank)

# create and init model based on model_impl and xpu_config
model = model_impl(xpu_config)
model.init_inference()

# current rank is ready
output_queue.put("ready")
Expand All @@ -54,13 +61,11 @@ def mp_loop_worker(
forward_inputs,
) = input_queue.get(block=True)

# model forward
inputs = self.build_inputs(forward_inputs)
logits = model.forward(inputs)

torch.cuda.synchronize()
if local_rank == 0:
output_queue.put(logits)
torch.cuda.synchronize()

except Exception as e:
logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}")
Expand Down
6 changes: 5 additions & 1 deletion byte_infer_perf/llm_perf/backends/GPU/gpu_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ def postprocess(
packet = tasks[i]

if token_id == packet.request.generate_config.eos_token_id:
finish_reason = "stop"
if len(packet.generate_ids) + 1 < packet.request.generate_config.min_new_tokens:
finish_reason = ""
token_id = packet.request.generate_config.eos_token_id
else:
finish_reason = "stop"
# take current generated token into account
elif (
len(packet.generate_ids) + 1
Expand Down
Loading

0 comments on commit 161ecd7

Please sign in to comment.