diff --git a/byte_infer_perf/llm_perf/README.md b/byte_infer_perf/llm_perf/README.md index 12d2d879..4807ce28 100644 --- a/byte_infer_perf/llm_perf/README.md +++ b/byte_infer_perf/llm_perf/README.md @@ -1,6 +1,6 @@ # Byte LLM Perf -Vendors can refer to this document for guidance on building backend: [Byte LLM Perf](https://bytedance.larkoffice.com/docx/ZoU7dkPXYoKtJtxlrRMcNGMwnTc) +Vendors can refer to this document for guidance on building backend: [Byte LLM Perf](https://bytemlperf.ai/zh/guide/inference_llm_vendor.html) ## Requirements * Python >= 3.8 diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_ckpt_loader.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_ckpt_loader.py new file mode 100644 index 00000000..51f2a113 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_ckpt_loader.py @@ -0,0 +1,46 @@ +import torch +import torch.distributed as dist + +from llm_perf.core.ckpt_loader import CoreCkptLoader + +class GpuCkptLoader(CoreCkptLoader): + def __init__( + self, + prefix, model, + mp_size=1, mp_rank=0, + ckpt_path: str="" + ): + super().__init__(prefix, model, mp_size, mp_rank, ckpt_path) + + + def weight_to_device(self, weight : torch.Tensor, non_blocking=False): + if self.mp_rank == 0: + weight = weight.cuda(non_blocking=non_blocking) + else: + cur_device = torch.cuda.current_device() + weight = torch.empty_like(weight, device=f"cuda:{cur_device}") + return weight + + def broadcast_weight(self, key, device='cpu', non_blocking=False): + weight = self.weight_to_device(self.state_dict[key]) + dist.broadcast(weight, src=0) + dist.barrier() + self.state_dict[key] = weight.to(device, non_blocking=non_blocking) + + def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu', non_blocking=False): + self.broadcast_weight(key, 'cuda') + weight = self.state_dict[key] + + if split_mode == 'default': + weight_split = self.split(weight, dim) + elif split_mode == 'with_outter': + weight_split = self.with_outter_split(weight, dim, outter) + elif split_mode == 'split_outter': + weight_split = self.split(weight, dim, outter) + else: + assert False, f"unknown split mode {split_mode}" + + + weight_split = [x.contiguous() for x in weight_split] + weight = weight_split[self.mp_rank].clone() + self.state_dict[key] = weight.to(device, non_blocking=non_blocking) \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_inferencer.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_inferencer.py index ec45ce67..9c102037 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/gpu_inferencer.py +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_inferencer.py @@ -13,49 +13,105 @@ def __init__(self, model_impl, xpu_cfg): self.tp_size = xpu_cfg["tp_size"] self.pad_token_id = xpu_cfg["pad_token_id"] - self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg) - - def prepare_inputs(self, tasks: List["CoreInferencer.Task"]): - all_input_ids = [] - all_position_ids = [] - all_attention_mask = [] - - max_seq_len = -1 - for task in tasks: - cur_id_len = len(task.request.input_ids) + len(task.generate_ids) - max_seq_len = cur_id_len if cur_id_len > max_seq_len else max_seq_len - - for task in tasks: - cur_id_len = len(task.request.input_ids) + len(task.generate_ids) - pad_len = max_seq_len - cur_id_len - # using left padding - input_ids = ( - [self.pad_token_id] * pad_len + - task.request.input_ids + - task.generate_ids - ) - pos_ids = ( - [i for i in range(max_seq_len)] - ) - attention_mask = ( - [0] * pad_len + - [1] * cur_id_len - ) - 
all_input_ids.append(input_ids) - all_position_ids.append(pos_ids) - all_attention_mask.append(attention_mask) - - # create model_inputs - model_inputs = {} - model_inputs["input_ids"] = all_input_ids - model_inputs["position_ids"] = all_position_ids - model_inputs["attention_mask"] = all_attention_mask + self.max_batch_size = xpu_cfg["max_batch_size"] + self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg) + + def prepare_inputs( + self, + tasks: List[CoreInferencer.Task], + **kwargs + ): + input_dict = { + "input_ids": None, + "position_ids": None, + "attention_mask": None, + "all_q_len": None, + "all_kv_len": None, + "is_context": None, + "valid_slot_ids": None + } + + is_context = kwargs.get("is_context") if "is_context" in kwargs.keys() else False + valid_slot_ids = kwargs.get("valid_slot_ids") if "valid_slot_ids" in kwargs.keys() else [i for i in range(self.max_batch_size)] + + input_dict["is_context"] = is_context + input_dict["valid_slot_ids"] = valid_slot_ids + + if is_context: + q_len = len(tasks[0].request.input_ids) + kv_len = len(tasks[0].request.input_ids) + + input_dict["input_ids"] = [ + tasks[0].request.input_ids + ] + input_dict["position_ids"] = [ + [i for i in range(q_len)] + ] + input_dict["attention_mask"] = [ + [1 for _ in range(q_len)] + ] + input_dict["all_q_len"] = [ + q_len + ] + input_dict["all_kv_len"] = [ + kv_len + ] + else: + all_input_ids = [] + all_position_ids = [] + all_attention_mask = [] + all_q_len = [] + all_kv_len = [] + + for task in tasks: + q_len = 1 + kv_len = 0 - return model_inputs + if task is None: + kv_len = 1 + input_ids = [ + self.pad_token_id + ] + position_ids = [ + 0 + ] + attention_mask = [ + 0 + ] + else: + kv_len = len(task.request.input_ids) + len(task.generate_ids) - 1 - def infer(self, tasks: List["CoreInferencer.Task"]): - input_dict = self.prepare_inputs(tasks) + input_ids = [ + task.generate_ids[-1] + ] + position_ids = [ + kv_len + ] + attention_mask = [ + 1 + ] + all_input_ids.append(input_ids) + all_position_ids.append(position_ids) + all_attention_mask.append(attention_mask) + all_q_len.append(q_len) + all_kv_len.append(kv_len) + + input_dict["input_ids"] = all_input_ids + input_dict["position_ids"] = all_position_ids + input_dict["attention_mask"] = all_attention_mask + input_dict["all_q_len"] = all_q_len + input_dict["all_kv_len"] = all_kv_len + + return input_dict + + + def infer( + self, + tasks: List[CoreInferencer.Task], + **kwargs + ): + input_dict = self.prepare_inputs(tasks, **kwargs) outputs = self.mp_engine.mp_forward(input_dict) input_logits = outputs.logits[..., :-1, :].contiguous() @@ -68,4 +124,5 @@ def infer(self, tasks: List["CoreInferencer.Task"]): return { "input_logits": input_logits, "last_logits": next_tokens_logits, - } \ No newline at end of file + } + diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py index 9808fe8a..621aeeec 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py @@ -1,8 +1,10 @@ import os +import time from multiprocessing import Queue import torch import torch.nn as nn +import torch.distributed as dist from llm_perf.core.mp_engine import CoreMpEngine from llm_perf.utils.logger import logger @@ -13,11 +15,19 @@ def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None: def build_inputs(self, forward_inputs): - forward_inputs["input_ids"] = torch.tensor(forward_inputs["input_ids"]).cuda() - 
forward_inputs["position_ids"] = torch.tensor(forward_inputs["position_ids"]).cuda() - forward_inputs["attention_mask"] = torch.tensor(forward_inputs["attention_mask"]).cuda() + # list --> torch.Tensor --> cuda + forward_inputs["input_ids"] = torch.tensor( + forward_inputs["input_ids"] + ).cuda() + forward_inputs["position_ids"] = torch.tensor( + forward_inputs["position_ids"] + ).cuda() + forward_inputs["attention_mask"] = torch.tensor( + forward_inputs["attention_mask"] + ).cuda() return forward_inputs - + + @torch.no_grad() def mp_loop_worker( self, @@ -36,13 +46,10 @@ def mp_loop_worker( os.environ["LOCAL_RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_WORLD_SIZE"] = str(world_size) - - - # set device - torch.cuda.set_device(local_rank) # create and init model based on model_impl and xpu_config model = model_impl(xpu_config) + model.init_inference() # current rank is ready output_queue.put("ready") @@ -53,14 +60,11 @@ def mp_loop_worker( ( forward_inputs, ) = input_queue.get(block=True) - - # model forward inputs = self.build_inputs(forward_inputs) logits = model.forward(inputs) - + torch.cuda.synchronize() if local_rank == 0: output_queue.put(logits) - torch.cuda.synchronize() except Exception as e: logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}") diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_sampler.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_sampler.py index 4514d413..8993d023 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/gpu_sampler.py +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_sampler.py @@ -109,30 +109,43 @@ def postprocess( generate_result = [] for i in range(len(tasks)): token_id = next_tokens[i] - packet = tasks[i] - - if token_id == packet.request.generate_config.eos_token_id: - finish_reason = "stop" + task = tasks[i] + + if token_id == task.request.generate_config.eos_token_id: + if len(task.generate_ids) + 1 < task.request.generate_config.min_new_tokens: + finish_reason = "" + token_id = task.request.generate_config.eos_token_id + else: + finish_reason = "stop" # take current generated token into account elif ( - len(packet.generate_ids) + 1 - >= packet.request.generate_config.max_new_tokens + len(task.generate_ids) + 1 + >= task.request.generate_config.max_new_tokens ): finish_reason = "max_length" else: finish_reason = "" - if packet.request.generate_config.get_input_logits: + if task.request.generate_config.get_input_logits: last_logits = infer_outputs["last_logits"] input_logits = infer_outputs["input_logits"] gen_res = GenerateResult( token_id=token_id, - finish_reason=finish_reason, + finish_reason=finish_reason, + wait_time=task.wait_time[-1], + model_time=task.model_time[-1], + post_process_time=task.post_process_time[-1], last_logits=last_logits.view(-1).tolist(), input_logits=input_logits.view(-1).tolist(), ) else: - gen_res = GenerateResult(token_id=token_id, finish_reason=finish_reason) + gen_res = GenerateResult( + token_id=token_id, + finish_reason=finish_reason, + wait_time=task.wait_time[-1], + model_time=task.model_time[-1], + post_process_time=task.post_process_time[-1], + ) generate_result.append(gen_res) diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_scheduler.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_scheduler.py index 678ebd55..a4a03c0c 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/gpu_scheduler.py +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_scheduler.py @@ -1,7 +1,7 @@ import sys import time from concurrent.futures import 
ThreadPoolExecutor, as_completed -from typing import List +from typing import List, Set import torch @@ -24,58 +24,118 @@ def __init__( self.max_batch_size = xpu_cfg["max_batch_size"] - @torch.inference_mode() def scheduler_loop(self): - batch: List[CoreInferencer.Task] = [] + task_slots: List[CoreInferencer.Task] = [None] * self.max_batch_size + avail_slots: List[int] = [self.max_batch_size - 1 - i for i in range(self.max_batch_size)] + context_slots: List[int] = [] + while self.started: - # 1. select batch --> batch - batch = self.select_batch(batch) - if not batch: + + while not self.task_queue.empty(): + if len(avail_slots) == 0: + break + slot = avail_slots.pop() + task_slots[slot] = self.task_queue.get() + context_slots.append(slot) + + if len(avail_slots) == self.max_batch_size: with self.task_queue.not_empty: self.task_queue.not_empty.wait(0.1) continue - logger.debug(f"get batch size: {len(batch)}") - - # 2. do inference -> logits - outputs = self.inferencer.infer(batch) - - # 3. sample logits -> tokens - next_tokens, softmax_out = self.sampler.sample( - tasks=batch, logits=outputs["last_logits"] - ) - - # 4.postprocess -> gen result - generation_results = self.sampler.postprocess( - tasks=batch, - infer_outputs=outputs, - next_tokens=next_tokens, - ) - - # 5. add result to task - for i, gen_res in enumerate(generation_results): - batch[i].add_result(gen_res) - if gen_res.finish_reason: - batch[i].finish() - - # 6. is not finished -> remain - remained: List[CoreInferencer.Packet] = [] - for task in batch: - if not task.is_finished(): - remained.append(task) - batch = remained - - def select_batch(self, - batch: CoreInferencer.Task - ): - batching_size: int = len(batch) - new_select_packets: List[CoreInferencer.Task] = [] - - while not self.task_queue.empty(): - if batching_size == self.max_batch_size: - break - batching_size += 1 - new_select_packets.append(self.task_queue.get()) - - return batch + new_select_packets + + # context phase + if len(context_slots) != 0: + # do inference --> logits + select_slot = context_slots.pop(0) + select_slots= [ + select_slot + ] + + cur_task = task_slots[select_slot] + cur_tasks = [ + cur_task + ] + + cur_task.update_st("model_start") + + outputs = self.inferencer.infer( + cur_tasks, + is_context=True, + valid_slot_ids=select_slots + ) + + cur_task.update_st("model_end") + + # sample logits --> tokens + next_tokens, _ = self.sampler.sample( + tasks=cur_tasks, + logits=outputs["last_logits"] + ) + + cur_task.update_st("process_end") + + # postprocess -> gen result + generation_results = self.sampler.postprocess( + tasks=cur_tasks, + infer_outputs=outputs, + next_tokens=next_tokens, + ) + + # add result to task + cur_task.add_result(generation_results[0]) + if generation_results[0].finish_reason: + cur_task.finish() + + + # decode phase + else: + select_slots = [] + valid_tasks = [] + for i, task in enumerate(task_slots): + if task is not None: + select_slots.append(i) + valid_tasks.append(task) + + for task in valid_tasks: + task.update_st("model_start") + + outputs = self.inferencer.infer( + valid_tasks, + is_context=False, + valid_slot_ids=select_slots + ) + + for task in valid_tasks: + task.update_st("model_end") + + + # sample logits --> tokens + next_tokens, _ = self.sampler.sample( + tasks=valid_tasks, + logits=outputs["last_logits"] + ) + + for task in valid_tasks: + task.update_st("process_end") + + # postprocess -> gen result + generation_results = self.sampler.postprocess( + tasks=valid_tasks, + infer_outputs=outputs, + 
next_tokens=next_tokens, + ) + + # add result to task + for i, gen_res in enumerate(generation_results): + valid_tasks[i].add_result(gen_res) + if gen_res.finish_reason: + valid_tasks[i].finish() + + for i, task in enumerate(task_slots): + if task is not None and task.is_finished(): + avail_slots.append(i) + task_slots[i] = None + + avail_slots.sort(reverse=True) \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py index b15939fb..9c152f62 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py @@ -11,33 +11,10 @@ import torch import torch.nn as nn -from .chatglm2 import ChatGLMForConditionalGeneration, ChatGLMConfig +from .gpu_chatglm2 import GPUChatGLM2 from llm_perf.utils.logger import logger - -class GPUChatGLM2(nn.Module): - def __init__(self, xpu_cfg: Dict[str, Any]) -> None: - super().__init__() - - model_config = xpu_cfg["model_config"] - model_name = model_config["model_name"] - model_path = model_config["model_path"] - model_network = model_config["network"] - - self.model = ChatGLMForConditionalGeneration.from_pretrained( - model_path, - config=ChatGLMConfig(**model_network) - ) - self.model.eval() - self.model.half().cuda() - logger.info(f"cuda model {model_path} loaded {self.model}") - - def forward(self, inputs): - outputs = self.model(**inputs) - return outputs - - __all__ = { "chatglm2": GPUChatGLM2 } \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py index 4163233b..d9ac55af 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/chatglm2.py @@ -1,12 +1,16 @@ """ PyTorch ChatGLM model. 
""" +import os import math import copy import warnings import re import sys +import time import torch +import torch.nn as nn +import torch.distributed as dist import torch.utils.checkpoint import torch.nn.functional as F from torch import nn @@ -256,6 +260,11 @@ class CoreAttention(torch.nn.Module): def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: @@ -269,6 +278,8 @@ def __init__(self, config: ChatGLMConfig, layer_number): self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads + self.hidden_size_per_tp = self.hidden_size_per_partition // self.mp_size + coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: @@ -291,7 +302,9 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask) context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_tp,) + + # [s_kv, bs, head_num * head_dim] context_layer = context_layer.reshape(*new_context_layer_shape) else: # Raw attention scores @@ -364,7 +377,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_tp,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer @@ -379,6 +392,12 @@ class SelfAttention(torch.nn.Module): def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(SelfAttention, self).__init__() + + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + # layer index self.layer_number = max(1, layer_number) self.projection_size = config.kv_channels * config.num_attention_heads @@ -394,17 +413,39 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.qkv_hidden_size = ( self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) + + self.hidden_size = self.projection_size + self.num_head = config.num_attention_heads + self.hidden_size_per_tp = self.hidden_size // self.mp_size + self.num_heads_per_tp = self.num_head // self.mp_size + self.head_dim_per_tp = self.hidden_size_per_tp // self.num_heads_per_tp + + self.kv_heads = config.multi_query_group_num + self.kv_heads_per_tp = self.kv_heads // self.mp_size if self.mp_size % self.kv_heads else 1 + + self.qkv_hidden_size = ( + self.num_heads_per_tp * self.head_dim_per_tp + + 2 * self.kv_heads_per_tp * self.head_dim_per_tp + ) 
+ + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config) + ) self.core_attention = CoreAttention(config, self.layer_number) # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) + self.dense = nn.Linear( + self.projection_size // self.mp_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config) + ) def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: @@ -421,7 +462,8 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, ) def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + **kwargs ): # hidden_states: [sq, b, h] @@ -433,26 +475,55 @@ def forward( # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + + # [seq_len, bs, hidden_size] --> [seq_len, batch_size, (32 + 2 + 2) * 128] mixed_x_layer = self.query_key_value(hidden_states) if self.multi_query_attention: + + # (query_layer, key_layer, value_layer) = mixed_x_layer.split( + # [ + # self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + # self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + # self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + # ], + # dim=-1, + # ) + (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_heads_per_tp * self.head_dim_per_tp, + self.kv_heads_per_tp * self.head_dim_per_tp, + self.kv_heads_per_tp * self.head_dim_per_tp, ], dim=-1, ) + + # query: [seq_len, batch_size, 32, 128] + # key: [seq_len, batch_size, 2, 128] + # value: [seq_len, batch_size, 2, 128] + # query_layer = query_layer.view( + # query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + # ) + # key_layer = key_layer.view( + # key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + # ) + # value_layer = value_layer.view( + # value_layer.size()[:-1] + # + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + # ) + query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + query_layer.size()[:-1] + + (self.num_heads_per_tp, self.head_dim_per_tp) ) key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + key_layer.size()[:-1] + + (self.kv_heads_per_tp, self.head_dim_per_tp) ) value_layer = value_layer.view( value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + + (self.kv_heads_per_tp, self.head_dim_per_tp) ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + \ @@ -468,30 +539,80 @@ def forward( query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - # adjust key and 
value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) + # # adjust key and value for inference + # if kv_cache is not None: + # cache_k, cache_v = kv_cache + # key_layer = torch.cat((cache_k, key_layer), dim=0) + # value_layer = torch.cat((cache_v, value_layer), dim=0) + # if use_cache: + # kv_cache = (key_layer, value_layer) + # else: + # kv_cache = None + + + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + all_q_len = kwargs.get("all_q_len") + all_kv_len = kwargs.get("all_kv_len") + + + # kv_cache: 2 * [max_seq_len, max_batch_size, kv_head_num, kv_head_dim] + if is_context: + slot_id = valid_slot_ids[0] + q_len = all_q_len[0] + kv_cache[0][0:q_len, slot_id:slot_id+1, :, :] = key_layer + kv_cache[1][0:q_len, slot_id:slot_id+1, :, :] = value_layer else: - kv_cache = None + q_len, batch_size, _, _ = key_layer.shape + max_qkv_len = q_len + max(all_kv_len) + for i, slot_id in enumerate(valid_slot_ids): + q_len = all_q_len[i] + kv_len = all_kv_len[i] + kv_cache[0][kv_len:kv_len+q_len, slot_id:slot_id+1, :, :] = key_layer[:, i, :, :] + kv_cache[1][kv_len:kv_len+q_len, slot_id:slot_id+1, :, :] = value_layer[:, i, :, :] + cur_k_cache = kv_cache[0][0:max_qkv_len] + cur_v_cache = kv_cache[1][0:max_qkv_len] + select_slots = torch.tensor(valid_slot_ids, device=key_layer.device) + key_layer = torch.index_select(cur_k_cache, 1, select_slots) + value_layer = torch.index_select(cur_v_cache, 1, select_slots) + + # if self.multi_query_attention: + # # [seq_len, batch_size, 2, 1, 128] + # key_layer = key_layer.unsqueeze(-2) + # # [seq_len, batch_size, 2, 16, 128] + # key_layer = key_layer.expand( + # -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + # ) + # # [seq_len, batch_size, 32, 128] + # key_layer = key_layer.contiguous().view( + # key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + # ) + # value_layer = value_layer.unsqueeze(-2) + # value_layer = value_layer.expand( + # -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + # ) + # value_layer = value_layer.contiguous().view( + # value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + # ) + if self.multi_query_attention: + # [seq_len, batch_size, 2, 1, 128] key_layer = key_layer.unsqueeze(-2) + # [seq_len, batch_size, 2, 16, 128] key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + -1, -1, -1, self.num_heads_per_tp // self.kv_heads_per_tp, -1 ) + # [seq_len, batch_size, 32, 128] key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + key_layer.size()[:2] + (self.num_heads_per_tp, self.head_dim_per_tp) ) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + -1, -1, -1, self.num_heads_per_tp // self.kv_heads_per_tp, -1 ) value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + value_layer.size()[:2] + (self.num_heads_per_tp, self.head_dim_per_tp) ) # 
================================== @@ -527,12 +648,16 @@ class MLP(torch.nn.Module): def __init__(self, config: ChatGLMConfig, device=None): super(MLP, self).__init__() + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + self.add_bias = config.add_bias_linear # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf self.dense_h_to_4h = nn.Linear( config.hidden_size, - config.ffn_hidden_size * 2, + config.ffn_hidden_size * 2 // self.mp_size, bias=self.add_bias, device=device, **_config_to_kwargs(config) @@ -546,7 +671,7 @@ def swiglu(x): # Project back to h. self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, + config.ffn_hidden_size // self.mp_size, config.hidden_size, bias=self.add_bias, device=device, @@ -571,6 +696,10 @@ class GLMBlock(torch.nn.Module): def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(GLMBlock, self).__init__() + + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + self.layer_number = layer_number self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm @@ -594,7 +723,13 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.mlp = MLP(config, device=device) def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + **kwargs ): # hidden_states: [s, b, h] @@ -606,8 +741,11 @@ def forward( attention_mask, rotary_pos_emb, kv_cache=kv_cache, - use_cache=use_cache + use_cache=use_cache, + **kwargs ) + if self.mp_size > 1: + dist.all_reduce(attention_output) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -623,6 +761,8 @@ def forward( # MLP. mlp_output = self.mlp(layernorm_output) + if self.mp_size > 1: + dist.all_reduce(mlp_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: @@ -642,6 +782,10 @@ class GLMTransformer(torch.nn.Module): def __init__(self, config: ChatGLMConfig, device=None): super(GLMTransformer, self).__init__() + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + self.fp32_residual_connection = config.fp32_residual_connection self.post_layer_norm = config.post_layer_norm @@ -669,16 +813,12 @@ def forward( self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, use_cache: Optional[bool] = True, output_hidden_states: Optional[bool] = False, + **kwargs ): if not kv_caches: kv_caches = [None for _ in range(self.num_layers)] presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False + all_self_attentions = None all_hidden_states = () if output_hidden_states else None @@ -702,7 +842,8 @@ def forward( attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], - use_cache=use_cache + use_cache=use_cache, + **kwargs ) hidden_states, kv_cache = layer_ret if use_cache: @@ -752,6 +893,53 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): full_attention_mask.unsqueeze_(1) return full_attention_mask + + def get_context_masks( + self, + input_ids : torch.Tensor, + padding_mask : torch.Tensor + ): + # input_ids: [1, q_len] + # padding_mask = [1, q_len] + batch_size, q_len = input_ids.shape + + # [1, q_len, q_len] + full_attention_mask = torch.ones( + 1, q_len, q_len, + device=input_ids.device + ) + full_attention_mask.tril_() + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_decode_masks( + self, + input_ids : torch.Tensor, + all_kv_len: List[int] + ): + # input_ids: [batch_size, 1] + # padding_mask: [batch_size, 1 + max_kv_len] + batch_size, q_len = input_ids.shape + max_qkv_len = q_len + max(all_kv_len) + + # [batch_size, 1, max_qkv_len] + padding_mask = [] + for i in range(batch_size): + cur_qkv_len = q_len + all_kv_len[i] + mask_per_batch = [1] * cur_qkv_len + [0] * (max_qkv_len - cur_qkv_len) + padding_mask.append(mask_per_batch) + full_attention_mask = torch.tensor( + padding_mask, + device=input_ids.device + ).unsqueeze_(1) + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) @@ -854,6 +1042,7 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -874,9 +1063,18 @@ def forward( attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1) - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + # if full_attention_mask is None: + # if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + # full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + + is_context = kwargs.get("is_context") + all_kv_len = kwargs.get("all_kv_len") + if is_context: + full_attention_mask = self.get_context_masks(input_ids, attention_mask) + else: + full_attention_mask = self.get_decode_masks(input_ids, all_kv_len) + # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) @@ -889,7 +1087,8 @@ def forward( # Run encoder. 
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, + **kwargs ) if not return_dict: @@ -902,10 +1101,6 @@ def forward( attentions=all_self_attentions, ) - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): @@ -920,63 +1115,6 @@ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): if self.config.quantization_bit: self.quantize(self.config.quantization_bit, empty_init=True) - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -990,6 +1128,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_last_logit: Optional[bool] = False, + **kwargs ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1003,6 +1142,7 @@ def forward( use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs ) hidden_states = transformer_outputs[0] @@ -1011,335 +1151,157 @@ def forward( lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the 
tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + logits=lm_logits ) - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - return response - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - prompt = tokenizer.build_prompt(query, history=history) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - if history: - prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - input_ids = tokenizer.encode(prompt, add_special_tokens=False) - input_ids = input_ids[1:] - inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) - else: - prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, - do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = self.build_inputs(tokenizer, query, history=history) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, - max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - 
logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None and not return_past_key_values: - inputs = self.build_inputs(tokenizer, query, history=history) - else: - inputs = self.build_stream_inputs(tokenizer, query, history=history) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, - return_past_key_values=return_past_key_values, **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response = self.process_response(response) - new_history = history + [(query, response)] - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - model_kwargs["use_cache"] = generation_config.use_cache - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. 
" - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self + # def _update_model_kwargs_for_generation( + # self, + # outputs: ModelOutput, + # model_kwargs: Dict[str, Any], + # is_encoder_decoder: bool = False, + # standardize_cache_format: bool = False, + # ) -> Dict[str, Any]: + # # update past_key_values + # model_kwargs["past_key_values"] = self._extract_past_from_model_output( + # outputs, standardize_cache_format=standardize_cache_format + # ) + + # # update attention mask + # if "attention_mask" in model_kwargs: + # attention_mask = model_kwargs["attention_mask"] + 
# model_kwargs["attention_mask"] = torch.cat( + # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + # ) + + # # update position ids + # if "position_ids" in model_kwargs: + # position_ids = model_kwargs["position_ids"] + # new_position_id = position_ids[..., -1:].clone() + # new_position_id += 1 + # model_kwargs["position_ids"] = torch.cat( + # [position_ids, new_position_id], dim=-1 + # ) + + # model_kwargs["is_first_forward"] = False + # return model_kwargs + + # def prepare_inputs_for_generation( + # self, + # input_ids: torch.LongTensor, + # past_key_values: Optional[torch.Tensor] = None, + # attention_mask: Optional[torch.Tensor] = None, + # position_ids: Optional[torch.Tensor] = None, + # use_cache: Optional[bool] = None, + # is_first_forward: bool = True, + # **kwargs + # ) -> dict: + # # only last token for input_ids if past is not None + # if position_ids is None: + # position_ids = self.get_position_ids(input_ids, device=input_ids.device) + # if not is_first_forward: + # if past_key_values is not None: + # position_ids = position_ids[..., -1:] + # input_ids = input_ids[:, -1:] + # return { + # "input_ids": input_ids, + # "past_key_values": past_key_values, + # "position_ids": position_ids, + # "attention_mask": attention_mask, + # "return_last_logit": True, + # "use_cache": use_cache + # } + + + + # @staticmethod + # def _reorder_cache( + # past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + # ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + # """ + # This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + # [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + # beam_idx at every generation step. + + # Output shares the same memory storage as `past`. 
+ # """ + # return tuple( + # ( + # layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + # layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + # ) + # for layer_past in past + # ) + + # def process_response(self, response): + # response = response.strip() + # response = response.replace("[[训练时间]]", "2023年") + # return response + + # def build_inputs( + # self, + # tokenizer, + # query: str, + # history: List[Tuple[str, str]] = None + # ): + # prompt = tokenizer.build_prompt(query, history=history) + # inputs = tokenizer([prompt], return_tensors="pt") + # inputs = inputs.to(self.device) + # return inputs + + + + # def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + # if history: + # prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + # input_ids = tokenizer.encode(prompt, add_special_tokens=False) + # input_ids = input_ids[1:] + # inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + # else: + # prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + # inputs = tokenizer([prompt], return_tensors="pt") + # inputs = inputs.to(self.device) + # return inputs + + # @torch.inference_mode() + # def chat( + # self, + # tokenizer, + # query: str, + # history: List[Tuple[str, str]] = None, + # max_length: int = 8192, + # num_beams=1, + # do_sample=True, + # top_p=0.8, + # temperature=0.8, + # logits_processor=None, + # **kwargs + # ): + # if history is None: + # history = [] + # if logits_processor is None: + # logits_processor = LogitsProcessorList() + # logits_processor.append(InvalidScoreLogitsProcessor()) + # gen_kwargs = { + # "max_length": max_length, + # "num_beams": num_beams, + # "do_sample": do_sample, + # "top_p": top_p, + # "temperature": temperature, + # "logits_processor": logits_processor, + # **kwargs + # } + # inputs = self.build_inputs(tokenizer, query, history=history) + + + # outputs = self.generate(**inputs, **gen_kwargs) + + + # outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + # response = tokenizer.decode(outputs) + # response = self.process_response(response) + # history = history + [(query, response)] + # return response, history -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not 
None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py new file mode 100644 index 00000000..05430a82 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py @@ -0,0 +1,213 @@ +import os +import torch +import torch.distributed as dist +import torch.nn as nn + +from typing import Dict, Any +from llm_perf.utils.logger import logger +from llm_perf.utils.ps_utils import check_memory_usage +from llm_perf.utils.dist_utils import check_dist + +from accelerate import init_empty_weights + +from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader + +from .chatglm2 import ChatGLMForConditionalGeneration, ChatGLMModel, ChatGLMConfig + + +class GPUChatGLM2Loader(GpuCkptLoader): + def __init__( + self, + prefix, + model, + mp_size=1, + mp_rank=0, + ckpt_path: str = "" + ): + super().__init__(prefix, model, mp_size, mp_rank, ckpt_path) + + def parallel_loader(self): + self.state_dict = None + if self.mp_rank == 0: + self.state_dict = self.torch_load_wrapper( + self.ckpt_path, map_location=torch.device("cpu")) + + if self.mp_size == 1: + return self.state_dict + + # mp_size > 2 + # broadcast state_dict from rank 0 to other ranks + self.broadcast_meta() + + self.broadcast_weight("transformer.embedding.word_embeddings.weight") + self.broadcast_weight("transformer.output_layer.weight") + self.broadcast_weight("transformer.rotary_pos_emb.inv_freq") + self.broadcast_weight("transformer.encoder.final_layernorm.weight") + + for i, block in enumerate(self.model.transformer.encoder.layers): + 
self.broadcast_weight(f"transformer.encoder.layers.{i}.input_layernorm.weight") + self.broadcast_weight(f"transformer.encoder.layers.{i}.post_attention_layernorm.weight") + + self.scatter_weight(f"transformer.encoder.layers.{i}.mlp.dense_h_to_4h.weight", dim=0, split_mode='with_outter', outter=2) + self.scatter_weight(f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h.weight", dim=-1) + + self.scatter_weight(f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight", dim=0, split_mode='split_outter', outter=[32, 2, 2]) + self.scatter_weight(f"transformer.encoder.layers.{i}.self_attention.query_key_value.bias", dim=0, split_mode='split_outter', outter=[32, 2, 2]) + + self.scatter_weight(f"transformer.encoder.layers.{i}.self_attention.dense.weight", dim=-1) + + return self.state_dict + + def infusion_to_model(self): + self.model.transformer.embedding.word_embeddings.weight = self.to_parameter( + self.state_dict[f"transformer.embedding.word_embeddings.weight"] + ) + self.model.transformer.output_layer.weight = self.to_parameter( + self.state_dict[f"transformer.output_layer.weight"] + ) + self.model.transformer.rotary_pos_emb.inv_freq = self.to_parameter( + self.state_dict[f"transformer.rotary_pos_emb.inv_freq"] + ) + self.model.transformer.encoder.final_layernorm.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.final_layernorm.weight"] + ) + + for i, block in enumerate(self.model.transformer.encoder.layers): + block.input_layernorm.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.input_layernorm.weight"] + ) + + block.mlp.dense_4h_to_h.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h.weight"] + ) + block.mlp.dense_h_to_4h.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.mlp.dense_h_to_4h.weight"] + ) + + block.post_attention_layernorm.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.post_attention_layernorm.weight"] + ) + + block.self_attention.dense.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.self_attention.dense.weight"] + ) + block.self_attention.query_key_value.bias = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.self_attention.query_key_value.bias"] + ) + block.self_attention.query_key_value.weight = self.to_parameter( + self.state_dict[f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight"] + ) + + return self.model + + + + +class GPUChatGLM2(nn.Module): + def __init__(self, xpu_cfg: Dict[str, Any]) -> None: + super().__init__() + + self.xpu_cfg = xpu_cfg + self.model_config = xpu_cfg["model_config"] + + self.model_name = self.model_config["model_name"] + self.model_path = self.model_config["model_path"] + self.model_network = self.model_config["network"] + + self.chatglm_config = ChatGLMConfig(**self.model_network) + # print(self.chatglm_config) + + # dist config + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + + self.prefix = "transformer.encoder.layers" + self.transformer_model : ChatGLMForConditionalGeneration = None + + + + + def init_inference(self): + torch.cuda.set_device(self.local_rank) + + if self.mp_size > 1: + logger.info(f"RANK: {self.local_rank} {self.mp_size} init_process_group...") + dist.init_process_group( + backend="nccl", + world_size=self.mp_size, + rank=self.local_rank + ) + check_dist() + + check_memory_usage("Begin") + + with init_empty_weights(): + 
self.transformer_model = ChatGLMForConditionalGeneration( + self.chatglm_config, empty_init=False + ) + self.transformer_model.eval() + + check_memory_usage("After build model") + + self.load_weight(self.model_path) + + check_memory_usage("After load_weight") + + self.transformer_model.half().cuda() + + check_memory_usage("After model to device") + + self.kv_cache = self.init_kvcache(torch.float16) + + logger.info(f"cuda model {self.model_path} loaded {self.transformer_model}") + + + def load_weight(self, ckpt_path): + p_loader = GPUChatGLM2Loader( + self.prefix, self.transformer_model, + self.mp_size, self.local_rank, + ckpt_path + ) + p_loader.load() + p_loader.infusion_to_model() + + + def init_kvcache(self, dtype): + max_seq_len = 4096 + max_batch_size = self.xpu_cfg["max_batch_size"] + kv_head_num = self.chatglm_config.multi_query_group_num + kv_head_dim = self.chatglm_config.kv_channels + + kv_head_num = kv_head_num // self.mp_size if self.mp_size % kv_head_num else 1 + + past_key_values = () + layer_num = self.chatglm_config.num_layers + for i in range(layer_num): + # [max_seq_len, max_batch_size, kv_head_num, kv_head_dim] + key_cache = torch.zeros( + (max_seq_len, max_batch_size, kv_head_num, kv_head_dim), + dtype=dtype, + device='cuda' + ) + value_cache = torch.zeros( + (max_seq_len, max_batch_size, kv_head_num, kv_head_dim), + dtype=dtype, + device='cuda' + ) + past_key_values += ((key_cache, value_cache),) + + return past_key_values + + + def forward(self, inputs : Dict[str, torch.Tensor]): + outputs = self.transformer_model.forward( + **inputs, + past_key_values=self.kv_cache, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + return_last_logit=False + ) + return outputs diff --git a/byte_infer_perf/llm_perf/benchmark/bench.py b/byte_infer_perf/llm_perf/benchmark/bench.py index 9ae5baa9..7ca21e4b 100644 --- a/byte_infer_perf/llm_perf/benchmark/bench.py +++ b/byte_infer_perf/llm_perf/benchmark/bench.py @@ -88,32 +88,46 @@ def bench_performance( result_queue: mp.Queue, ): result_queue.put("@start") - perf_start = time.time() - perf_time: int = workload["perf_time"] - while perf_start + perf_time > time.time(): + + + accum_time = 0 + perf_time: int = workload["perf_time"] * int(1e9) + + while accum_time < perf_time: # make fake prompt, actual input_ids len may exceed input_tokens prompt = "我" * input_tokens - st = time.time() + st = time.perf_counter_ns() first_token_latency = 0 + min_new_tokens = workload["min_new_tokens"] + max_new_tokens = workload["max_new_tokens"] + output_messages: str = "" + wait_time = [] + model_time = [] + post_process_time = [] + for res in gen_stream_request( stub, index=index, prompt=prompt, - min_new_tokens=workload["min_new_tokens"], - max_new_tokens=workload["max_new_tokens"], + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, top_p=0, top_k=1, get_input_logits=0, ): res = {k: deserialize_value(v) for k, v in res.outputs.items()} output_messages += res["choice"]["message"] + wait_time.append(res["choice"]["wait_time"]) + model_time.append(res["choice"]["model_time"]) + post_process_time.append(res["choice"]["post_process_time"]) if first_token_latency == 0: - first_token_latency = time.time() - st + first_token_latency = (time.perf_counter_ns() - st) / 1e6 - use_time = time.time() - st + use_time = time.perf_counter_ns() - st + accum_time += use_time # record context and decode len prompt_tokens = res["usage"]["prompt_tokens"] @@ -121,9 +135,17 @@ def bench_performance( # seperate context 
and decode latency if completion_tokens > 1: - per_token_latency = (use_time - first_token_latency) / (completion_tokens - 1) + per_token_latency = (use_time - first_token_latency) / (completion_tokens - 1) / 1e6 else: - per_token_latency = first_token_latency + per_token_latency = first_token_latency / 1e6 + + context_wait_time = wait_time[0] + context_model_time = model_time[0] + context_postprocess_time = post_process_time[0] + + decode_wait_time = sum(wait_time[1:]) / len(wait_time[1:]) + decode_model_time = sum(model_time[1:]) / len(model_time[1:]) + decode_postprocess_time = sum(post_process_time[1:]) / len(post_process_time[1:]) result = { "prompt_tokens": prompt_tokens, @@ -131,32 +153,22 @@ def bench_performance( "output_message": output_messages, "first_token_latency": first_token_latency, "per_token_latency": per_token_latency, - } - logger.debug(f"bench_{index} prompt response: {result}") - result_queue.put(result) + "context_wait_time": context_wait_time, + "context_model_time": context_model_time, + "context_postprocess_time": context_postprocess_time, -def test(stub, index: int): - prompt = "中国的首都在哪里?" - output_messages: str = "" - for res in gen_stream_request( - stub, - index=index, - prompt=prompt, - min_new_tokens=1, - max_new_tokens=256, - top_p=0, - top_k=1, - get_input_logits=0, - ): - res = {k: deserialize_value(v) for k, v in res.outputs.items()} - output_messages += res["choice"]["message"] - logger.info(f"bench_{index} prompt response: {output_messages}") - + "decode_wait_time": decode_wait_time, + "decode_model_time": decode_model_time, + "decode_postprocess_time": decode_postprocess_time, + } + logger.debug(f"bench_{index} prompt response: {result}") + result_queue.put(result) def benchmark( index: int, + start_wait: int, workload: Dict[str, Any], report_type: ReportType, input_tokens: int, @@ -165,11 +177,11 @@ def benchmark( ): with grpc.insecure_channel(f"{args.host}:{args.port}") as channel: stub = server_pb2_grpc.InferenceStub(channel) - logger.info(f"{report_type.name} bench_{index} start") - - # test function - # test(stub, index) - + logger.debug(f"{report_type.name} bench_{index} start") + + # wait for start_wait seconds + time.sleep(1 * start_wait) + try: if report_type == ReportType.ACCURACY: bench_accuracy(stub, workload, result_queue) @@ -179,5 +191,5 @@ def benchmark( logger.error(f"{report_type.name} bench_{index} error: {e}") raise e - logger.info(f"{report_type.name} bench_{index} finish") + logger.debug(f"{report_type.name} bench_{index} finish") result_queue.put(None) diff --git a/byte_infer_perf/llm_perf/core/ckpt_loader.py b/byte_infer_perf/llm_perf/core/ckpt_loader.py new file mode 100644 index 00000000..5bdb702f --- /dev/null +++ b/byte_infer_perf/llm_perf/core/ckpt_loader.py @@ -0,0 +1,228 @@ +import os +import sys +import time +import pathlib + +import torch +import torch.nn as nn +import torch.distributed as dist + +from llm_perf.utils.logger import logger + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor + +from typing import Union, List + +class CoreCkptLoader(ABC): + def __init__( + self, + prefix, model, + mp_size=1, mp_rank=0, + ckpt_path: str = "" + ): + self.prefix = prefix + self.model = model + + self.mp_size = mp_size + self.mp_rank = mp_rank + + self.ckpt_path = ckpt_path + + self.state_dict = None + + + def to_parameter( + self, + data : torch.Tensor, + dtype : torch.dtype =None + ): + if dtype is not None: + data = data.to(dtype) + return nn.Parameter(data, requires_grad=False) + + + 
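+    # Walk every per-layer tensor named f"{prefix}.{i}.{suffix}" and make it contiguous,
+    # submitting the .contiguous() calls to a thread pool so the copies can overlap.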
def to_contiguous(self, num_layers, param_suffixes, prefix, state_dict): + result = {} + + with ThreadPoolExecutor() as executor: + for i in range(num_layers): + for suffix in param_suffixes: + # for example: + # "transformer.encoder.layers.0.mlp.dense_4h_to_h.weight" + name = f"{prefix}.{i}.{suffix}" + if name in state_dict: + result[name] = executor.submit(lambda t : t.contiguous(), state_dict[name]) + + for i in range(num_layers): + for suffix in param_suffixes: + name = f"{prefix}.{i}.{suffix}" + if name in state_dict: + state_dict[name] = result[name].result + + + def gqa_split(self, src, dim): + qkv_head_num = src.shape[dim] // self.head_dim + src_split = src.chunk(qkv_head_num, dim=dim) + qkv_cat = [] + for i in range(self.mp_size): + qkv_cat.append( + torch.cat( + [src_split[i * self.mp_size + self.mp_rank] for i in range(qkv_head_num // self.mp_size)], + axis=dim, + ) + ) + + return qkv_cat + + + def qkv_split(self, src, dim): + src_split = torch.split(src.data, src.shape[dim] // 3, dim=dim) + qkv_split = [torch.split(src_s, src_s.shape[dim] // self.mp_size, dim=dim) for src_s in src_split] + qkv_cat = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=dim) for i in range(len(qkv_split[0]))] + return qkv_cat + + + def with_outter_split( + self, + src : torch.Tensor, + dim : int, + outter : int + ): + src_split = torch.split(src.data, src.shape[dim] // outter, dim=dim) + output_split = [torch.split(src_s, src_s.shape[dim] // self.mp_size, dim=dim) for src_s in src_split] + output_tensors = [ + torch.cat( + [output_s[i] for output_s in output_split], + axis=dim + ) for i in range(len(output_split[0])) + ] + return output_tensors + + + def split( + self, + src : torch.Tensor, + dim : int, + chunks : List [int]=[] + ): + if len(chunks) == 0: + split_arg = src.shape[dim] // self.mp_size + output_tensors = torch.split(src, split_arg, dim=dim) + else: + # for example + # chunks = [32, 2, 2], sum_chunks = 36, src.shape[dim] = (32 + 2 + 2) * 128, other_dim = 128 + # mp_size = 8 + # new_chunks = [4, 1, 1] + sum_chunks = sum(chunks) + other_dim_size = src.shape[dim] // sum_chunks + + split_arg = [i * other_dim_size for i in chunks] + split_tensors = torch.split(src, split_arg, dim=dim) + + output_split = [] + for i, tensor in enumerate(split_tensors): + if self.mp_size > chunks[i]: + tensor_shape = tensor.size()[:dim] + (chunks[i], 1, other_dim_size) + tensor.size()[dim+1:] + new_tensor_shape = tensor.size()[:dim] + (chunks[i], self.mp_size // chunks[i], other_dim_size) + tensor.size()[dim+1:] + output_tensor_shape = tensor.size()[:dim] + (self.mp_size * other_dim_size,) + tensor.size()[dim+1:] + + tensor = tensor.view(tensor_shape) + tensor = tensor.expand(*new_tensor_shape) + tensor = tensor.contiguous() + tensor = tensor.view(output_tensor_shape) + + cur_split = torch.split(tensor, tensor.shape[dim] // self.mp_size, dim=dim) + output_split.append(cur_split) + + output_tensors = [] + for i in range(self.mp_size): + temp_tensors = [output_split[j][i] for j in range(len(chunks))] + tp_tensors = torch.concat(temp_tensors, dim=dim) + output_tensors.append(tp_tensors) + + return output_tensors + + + + def broadcast_meta(self): + meta = [ + {k: [v.shape, v.dtype] for k, v in self.state_dict.items()} + ] if self.mp_rank == 0 else [None] + dist.broadcast_object_list(meta, src=0) + dist.barrier() + if self.mp_rank != 0: + self.state_dict = { + k: torch.empty(v[0], dtype=v[1], device='meta') for k, v in meta[0].items() + } + + + @abstractmethod + def broadcast_weight(self, key, device='cpu', 
non_blocking=False): + raise NotImplementedError + + + # split_mode + # default + # with_outter + # split_outter + @abstractmethod + def scatter_weight(self, key, dim, split_mode='default', outter=1, non_blocking=False): + raise NotImplementedError + + + @abstractmethod + def parallel_loader(self): + raise NotImplementedError + + + @abstractmethod + def infusion_to_model(self): + raise NotImplementedError + + + + + + def load(self): + return self.parallel_loader() + + + def torch_load_wrapper( + self, + ckpt_path: str, + map_location: Union[str, torch.device]=torch.device('cpu') + ): + st = time.time() + + state_dict = {} + model_path = pathlib.Path(ckpt_path) + if model_path.is_dir(): + if model_path.joinpath("pytorch_model.bin.index.json").exists(): + file_list = [] + for file in model_path.iterdir(): + if not (file.stem.startswith('pytorch_model-') and file.suffix.endswith('.bin')): + continue + file_list.append(file) + file_list.sort() + + for file in file_list: + state_dict.update(torch.load(file, map_location=map_location)) + + logger.info(f"RANK{self.mp_rank} load {ckpt_path} cost: {time.time() - st}s") + + # for key in state_dict.keys(): + # print(f"{key}, {state_dict[key].shape}") + + return state_dict + + + + + + + + + + + \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/core/generation.py b/byte_infer_perf/llm_perf/core/generation.py index d0761201..dc7d0d54 100644 --- a/byte_infer_perf/llm_perf/core/generation.py +++ b/byte_infer_perf/llm_perf/core/generation.py @@ -26,6 +26,9 @@ class GenerateRequest: class GenerateResult: token_id: int finish_reason: str + wait_time: float + model_time: float + post_process_time: float last_logits: List[float] = field(default_factory=list) input_logits: List[float] = field(default_factory=list) diff --git a/byte_infer_perf/llm_perf/core/inferencer.py b/byte_infer_perf/llm_perf/core/inferencer.py index 84fa6055..e2a955a1 100644 --- a/byte_infer_perf/llm_perf/core/inferencer.py +++ b/byte_infer_perf/llm_perf/core/inferencer.py @@ -1,3 +1,4 @@ +import time from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -32,6 +33,42 @@ def __init__(self, request: GenerateRequest): self.state = PacketStatus.PENDING self.generate_ids = [] self.exception = None + + self.create_st = time.perf_counter_ns() + self.last_model_start_st = time.perf_counter_ns() + self.last_model_end_st = time.perf_counter_ns() + self.last_process_st = time.perf_counter_ns() + + self.wait_time = [] + self.model_time = [] + self.post_process_time = [] + + + def update_st(self, st_name): + if st_name == "model_start": + self.last_model_start_st = time.perf_counter_ns() + self.wait_time.append((self.last_model_start_st - self.last_process_st) / 1e6) + elif st_name == "model_end": + self.last_model_end_st = time.perf_counter_ns() + self.model_time.append((self.last_model_end_st - self.last_model_start_st) / 1e6) + elif st_name == "process_end": + self.last_process_st = time.perf_counter_ns() + self.post_process_time.append((self.last_process_st - self.last_model_end_st) / 1e6) + + # def print_time(self): + + # context_wait_time = self.wait_time[0] + # context_model_time = self.model_time[0] + # context_postprocess_time = self.post_process_time[0] + + # decode_wait_time = sum(self.wait_time[1:]) / len(self.wait_time[1:]) + # decode_model_time = sum(self.model_time[1:]) / len(self.model_time[1:]) + # decode_postprocess_time = sum(self.post_process_time[1:]) / len(self.post_process_time[1:]) + + # print(f"context 
wait/model/postprocess: {context_wait_time}\t{context_model_time}\t{context_postprocess_time}") + # print(f"decode wait/model/postprocess: {decode_wait_time}\t{decode_model_time}\t{decode_postprocess_time}") + + def add_result(self, res: GenerateResult): self.generate_ids.append(res.token_id) @@ -57,7 +94,6 @@ def __init__(self) -> None: super().__init__() - @abstractmethod - def infer(self, tasks: List["CoreInferencer.Task"]): + def infer(self, tasks: List["CoreInferencer.Task"], **kwargs): raise NotImplementedError \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/core/sampler.py b/byte_infer_perf/llm_perf/core/sampler.py index 891e202d..5508eee3 100644 --- a/byte_infer_perf/llm_perf/core/sampler.py +++ b/byte_infer_perf/llm_perf/core/sampler.py @@ -14,7 +14,7 @@ def __init__(self) -> None: @abstractmethod def sample( self, - packets: List[CoreInferencer.Task], + tasks: List[CoreInferencer.Task], logits: torch.FloatTensor ) -> List[int]: """Sample next tokens diff --git a/byte_infer_perf/llm_perf/launch.py b/byte_infer_perf/llm_perf/launch.py index f13b3381..5fa8a9fe 100644 --- a/byte_infer_perf/llm_perf/launch.py +++ b/byte_infer_perf/llm_perf/launch.py @@ -13,6 +13,7 @@ # limitations under the License. import os import sys +import random import argparse import subprocess import json @@ -75,7 +76,7 @@ def get_args(): ) parser.add_argument( "--port", type=int, - default="50052", + default=51000, help="port of the server") args = parser.parse_args() @@ -300,11 +301,16 @@ def start_benchmark( report_type: ReportType, ): clients = 1 if report_type == ReportType.ACCURACY else batch_size + + sleep_units = [i for i in range(batch_size)] + random.shuffle(sleep_units) + for i in range(clients): p = mp.Process( target=benchmark, args=( i, + sleep_units[i], workload, report_type, input_tokens, diff --git a/byte_infer_perf/llm_perf/server/endpoint.py b/byte_infer_perf/llm_perf/server/endpoint.py index 8c4c517d..b32ce79d 100644 --- a/byte_infer_perf/llm_perf/server/endpoint.py +++ b/byte_infer_perf/llm_perf/server/endpoint.py @@ -44,13 +44,12 @@ def __init__(self, xpu_cfg) -> None: self.scheduler : CoreScheduler = setup.setup_scheduler(xpu_cfg) self.scheduler.start() - self.warmup() + self.warmup(xpu_cfg["max_batch_size"]) def __del__(self): self.scheduler.stop() - - def warmup(self): + def warmup(self, max_batch_size): prompt = "中国的首都是哪里?" 
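        # generation settings shared by the single-request warmup and the batched
        # (max_batch_size) warmup below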
generate_config = { "min_new_tokens": 1, @@ -59,7 +58,7 @@ def warmup(self): "temperature": 0.2, "presence_penalty": 1.0, } - logger.info(f"warmup prompt: {prompt}\nconfig: {generate_config}") + logger.info(f"warmup prompt: {prompt}, config: {generate_config}") async def _steram_warmup(): message = "" @@ -68,9 +67,18 @@ async def _steram_warmup(): result["choice"]["message"] = message return result - result = asyncio.run(_steram_warmup()) - logger.info(f"warmup response: {result}") + async def _multiple_warmup(): + tasks = [] + for _ in range(max_batch_size): + tasks.append(_steram_warmup()) + res = await asyncio.gather(*tasks) + return res + + single_result = asyncio.run(_steram_warmup()) + logger.info(f"single warmup response: {single_result}") + multiple_result = asyncio.run(_multiple_warmup()) + logger.info(f"multiple warmup reponse: {multiple_result}") async def prepare_request( self, prompt: str, generate_config: Dict[str, Any] @@ -136,7 +144,12 @@ async def streaming_inference( else: result: GenerateResult = gen_res outputs["choice"].update( - {"message": self.tokenizer.decode(result.token_id)} + { + "message": self.tokenizer.decode(result.token_id), + "wait_time": result.wait_time, + "model_time": result.model_time, + "post_process_time": result.post_process_time + } ) logger.debug(f"steam inference result: {outputs}") diff --git a/byte_infer_perf/llm_perf/server/launch_server.py b/byte_infer_perf/llm_perf/server/launch_server.py index 7d2146b9..eed965f6 100644 --- a/byte_infer_perf/llm_perf/server/launch_server.py +++ b/byte_infer_perf/llm_perf/server/launch_server.py @@ -99,7 +99,7 @@ def parse_args(): ) parser.add_argument( "--port", type=int, - default=50050 + default=51000 ) parser.add_argument( "--log_level", type=str, diff --git a/byte_infer_perf/llm_perf/utils/dist_utils.py b/byte_infer_perf/llm_perf/utils/dist_utils.py new file mode 100644 index 00000000..b2cbe96d --- /dev/null +++ b/byte_infer_perf/llm_perf/utils/dist_utils.py @@ -0,0 +1,16 @@ +import os +import torch +import torch.distributed as dist + + +def check_dist(): + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + buffer = torch.zeros([1], dtype=torch.int32).cuda() + if local_rank == 0: + buffer = buffer + 1 + + print(f"rank={local_rank}, before, {buffer}") + dist.broadcast(buffer, 0) + print(f"rank={local_rank}, after, {buffer}") diff --git a/byte_infer_perf/llm_perf/utils/ps_utils.py b/byte_infer_perf/llm_perf/utils/ps_utils.py new file mode 100644 index 00000000..c2abbe3d --- /dev/null +++ b/byte_infer_perf/llm_perf/utils/ps_utils.py @@ -0,0 +1,28 @@ +import gc +import os + +import psutil +import torch + +from llm_perf.utils.logger import logger + +def check_memory_usage(tag): + # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports + gc.collect() + + vm_stats = psutil.virtual_memory() + used_GB = round((vm_stats.total - vm_stats.available) / (1024**3), 2) + + dev_mem_reserved = 0 + dev_mem_allocated = 0 + if torch.cuda.is_available(): + dev_mem_reserved = round(torch.cuda.memory_reserved() / (1024**3), 2) + dev_mem_allocated = round(torch.cuda.memory_allocated() / (1024**3), 2) + else: + pass + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + msg = f"<<{tag}>> CPU VM State: Used = {used_GB} GB, Percent = {vm_stats.percent}% | "\ + f"DEV MEM State(Rank{local_rank}): Used = {dev_mem_allocated} GB, Reserved = {dev_mem_reserved} GB" + logger.info(msg) + diff --git 
a/byte_infer_perf/llm_perf/utils/reporter.py b/byte_infer_perf/llm_perf/utils/reporter.py index c74f09d7..e9155215 100644 --- a/byte_infer_perf/llm_perf/utils/reporter.py +++ b/byte_infer_perf/llm_perf/utils/reporter.py @@ -18,9 +18,18 @@ # Time To First Token __TTFT_AVG__ = "First Token Latency(AVG)" __TTFT_P90__ = "First Token Latency(P90)" +__CONTEXT_WAIT_AVG__ = "Context Wait Time(AVG)" +__CONTEXT_WAIT_P90__ = "Context Wait Time(P90)" +__CONTEXT_MODEL_AVG__ = "Context Model Time(AVG)" +__CONTEXT_MODEL_P90__ = "Context Model Time(P90)" + # Time Per Output Token __TPOT_AVG__ = "Per Token Latency(AVG)" __TPOT_P90__ = "Per Token Latency(P90)" +__DECODE_WAIT_AVG__ = "Decode Wait Time(AVG)" +__DECODE_WAIT_P90__ = "Decode Wait Time(P90)" +__DECODE_MODEL_AVG__ = "Decode Model Time(AVG)" +__DECODE_MODEL_P90__ = "Decode Model Time(P90)" class ReportType(Enum): @@ -103,7 +112,7 @@ def update_meta(self, tp_size: int, batch_size: int, input_tokens: int): self.batch_size = batch_size self.input_tokens = input_tokens - self.start_time = time.time() + self.start_time = time.perf_counter_ns() self.request = 0 self.performance_datas.clear() logger.info( @@ -116,7 +125,7 @@ def start(self): self._running = True self._worker.start() - self.start_time = time.time() + self.start_time = time.perf_counter_ns() self.request = 0 def stop(self): @@ -133,7 +142,7 @@ def submit(self, data: Dict[str, Any], report_type: ReportType): self._is_performance = True self.performance_datas.append(data) self.request += 1 - self.last_submit_time = time.time() + self.last_submit_time = time.perf_counter_ns() self.cond.notify() def worker(self): @@ -145,17 +154,49 @@ def worker(self): def _calc_performance(self): # Calc avg/p99/sum of data, return result + + completion_tokens = 0 + time_since_start = (self.last_submit_time - self.start_time) / 1e9 + ttfts = [] tpots = [] - completion_tokens = 0 - for i, data in enumerate(self.performance_datas): + + context_wait_time = [] + context_model_time = [] + + decode_wait_time = [] + decode_model_time = [] + + for data in self.performance_datas: + completion_tokens += data["completion_tokens"] + ttfts.append(data["first_token_latency"]) tpots.append(data["per_token_latency"]) - completion_tokens += data["completion_tokens"] - cur_ttft_avg = np.mean(ttfts) - cur_tpot_avg = np.mean(tpots) + + context_wait_time.append(data["context_wait_time"]) + context_model_time.append(data["context_model_time"]) + + decode_wait_time.append(data["decode_wait_time"]) + decode_model_time.append(data["decode_model_time"]) + + + + # context + cur_ttft_avg = np.mean(ttfts) cur_ttft_p90 = np.percentile(ttfts, 90) + cur_context_wait_avg = np.mean(context_wait_time) + cur_context_wait_p90 = np.percentile(context_wait_time, 90) + cur_context_model_avg = np.mean(context_model_time) + cur_context_model_p90 = np.percentile(context_model_time, 90) + + # decode + cur_tpot_avg = np.mean(tpots) cur_tpot_p90 = np.percentile(tpots, 90) + cur_decode_wait_avg = np.mean(decode_wait_time) + cur_decode_wait_p90 = np.percentile(decode_wait_time, 90) + cur_decode_model_avg = np.mean(decode_model_time) + cur_decode_model_p90 = np.percentile(decode_model_time, 90) + performance = None for perf in self.result["Performance"]: @@ -174,20 +215,31 @@ def _calc_performance(self): } self.result["Performance"].append(performance) - performance[__TTFT_AVG__] = cur_ttft_avg - performance[__TPOT_AVG__] = cur_tpot_avg - performance[__TTFT_P90__] = cur_ttft_p90 - performance[__TPOT_P90__] = cur_tpot_p90 - logger.info( + performance["client"] 
= { + __TTFT_AVG__: cur_ttft_avg, + __TTFT_P90__: cur_ttft_p90, + __TPOT_AVG__: cur_tpot_avg, + __TPOT_P90__: cur_tpot_p90, + } + performance["server"] = { + __CONTEXT_WAIT_AVG__ : cur_context_wait_avg, + __CONTEXT_WAIT_P90__ : cur_context_wait_p90, + __CONTEXT_MODEL_AVG__ : cur_context_model_avg, + __CONTEXT_MODEL_P90__ : cur_context_model_p90, + __DECODE_WAIT_AVG__ : cur_decode_wait_avg, + __DECODE_WAIT_P90__ : cur_decode_wait_p90, + __DECODE_MODEL_AVG__ : cur_decode_model_avg, + __DECODE_MODEL_P90__ : cur_decode_model_p90, + } + + logger.debug( f"TTFT(AVG)={cur_ttft_avg}, TTFT(P90)={cur_ttft_p90}, TPOT(AVG)={cur_tpot_avg}, TPOT(P90)={cur_tpot_p90}" ) - performance["Token Throughput"] = completion_tokens / ( - self.last_submit_time - self.start_time - ) + performance["Token Throughput"] = completion_tokens / time_since_start performance["Request Number"] = self.request - performance["QPS"] = self.request / (self.last_submit_time - self.start_time) + performance["QPS"] = self.request / time_since_start logger.info( f"Request Number={performance['Request Number']}, Token Throughput={performance['Token Throughput']}, QPS={performance['QPS']}" diff --git a/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json b/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json index b71ebaee..0b0d950c 100644 --- a/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json +++ b/byte_infer_perf/llm_perf/workloads/chatglm2-torch-fp16-6b.json @@ -2,10 +2,10 @@ "model": "chatglm2-torch-fp16-6b", "test_accuracy": false, "test_perf": true, - "min_new_tokens": 128, - "max_new_tokens": 256, + "min_new_tokens": 200, + "max_new_tokens": 200, "tp_sizes": [1, 2, 4, 8], - "batch_sizes":[8], + "batch_sizes": [1, 8], "input_tokens": [1024, 2048], "dataset": "llm_perf/datasets/merged_52_test.csv", "perf_time": 100 diff --git a/byte_infer_perf/llm_perf/workloads/chinese-llama2-torch-fp16-13b.json b/byte_infer_perf/llm_perf/workloads/chinese-llama2-torch-fp16-13b.json index 4c893608..a3d96ebf 100644 --- a/byte_infer_perf/llm_perf/workloads/chinese-llama2-torch-fp16-13b.json +++ b/byte_infer_perf/llm_perf/workloads/chinese-llama2-torch-fp16-13b.json @@ -2,10 +2,10 @@ "model": "chinese-llama2-torch-fp16-13b", "test_accuracy": false, "test_perf": true, - "min_new_tokens": 128, - "max_new_tokens": 256, + "min_new_tokens": 200, + "max_new_tokens": 200, "tp_sizes": [1, 2, 4, 8], - "batch_sizes":[8], + "batch_sizes": [1, 8], "input_tokens": [1024, 2048], "dataset": "llm_perf/datasets/merged_52_test.csv", "perf_time": 100 diff --git a/byte_infer_perf/llm_perf/workloads/mixtral-torch-fp16-8x7b.json b/byte_infer_perf/llm_perf/workloads/mixtral-torch-fp16-8x7b.json index 8e26092f..90ac561b 100644 --- a/byte_infer_perf/llm_perf/workloads/mixtral-torch-fp16-8x7b.json +++ b/byte_infer_perf/llm_perf/workloads/mixtral-torch-fp16-8x7b.json @@ -2,10 +2,10 @@ "model": "mixtral-torch-fp16-8x7b", "test_accuracy": false, "test_perf": true, - "min_new_tokens": 128, - "max_new_tokens": 256, + "min_new_tokens": 200, + "max_new_tokens": 200, "tp_sizes": [4, 8], - "batch_sizes": [8], + "batch_sizes": [1, 8], "input_tokens": [1024, 2048], "dataset": "llm_perf/datasets/merged_52_test.csv", "perf_time": 100
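
For reference, the split_mode='split_outter' path used above for chatglm2's fused query_key_value projection can be reproduced on toy tensors. The sketch below is an illustrative example only; it assumes chatglm2-6b shapes (32 query heads, 2 key/value groups, 128 channels per head, hidden size 4096) and mp_size=8. Because there are fewer KV groups than ranks, each group is replicated so that every rank ends up with exactly one copy:

import torch

head_dim = 128
chunks   = [32, 2, 2]          # query heads, key groups, value groups (chatglm2-6b)
mp_size  = 8
hidden   = 4096

# fused weight laid out as (32 + 2 + 2) * 128 rows, like query_key_value.weight
fused_w = torch.randn(sum(chunks) * head_dim, hidden)
q_w, k_w, v_w = torch.split(fused_w, [c * head_dim for c in chunks], dim=0)

shards = []
for rank in range(mp_size):
    # the 32 query heads are evenly partitioned: 4 consecutive heads per rank
    q_shard = q_w.view(32, head_dim, hidden).chunk(mp_size, dim=0)[rank].reshape(-1, hidden)
    # only 2 KV groups exist, so each group is replicated across mp_size // 2 ranks
    kv_group = rank // (mp_size // 2)
    k_shard = k_w.view(2, head_dim, hidden)[kv_group]
    v_shard = v_w.view(2, head_dim, hidden)[kv_group]
    shards.append(torch.cat([q_shard, k_shard, v_shard], dim=0).contiguous())

print(shards[0].shape)   # torch.Size([768, 4096]): (4 q heads + 1 k + 1 v) * 128 rows

Each rank keeps only the shard at its own index, which mirrors what the scatter_weight calls with split_mode='split_outter' do to the real checkpoint.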