
optimize llm_perf code.
suisiyuan committed Jul 18, 2024
1 parent fa9f49b commit 5c319c6
Showing 20 changed files with 1,974 additions and 7,020 deletions.
23 changes: 14 additions & 9 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_ckpt_loader.py
@@ -21,14 +21,21 @@ def weight_to_device(self, weight : torch.Tensor, non_blocking=False):
weight = torch.empty_like(weight, device=f"cuda:{cur_device}")
return weight

def broadcast_weight(self, key, device='cpu', non_blocking=False):
weight = self.weight_to_device(self.state_dict[key])
dist.broadcast(weight, src=0)
dist.barrier()
self.state_dict[key] = weight.to(device, non_blocking=non_blocking)

def broadcast_weight(self, key, device='cpu', non_blocking=False):
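# Non-rank-0 ranks only hold {"shape": ..., "dtype": ...} metadata for each key
# (presumably populated by broadcast_meta()), so they first allocate an empty CPU
# tensor of that shape and dtype; rank 0 starts from the real weight. All ranks then
# move the tensor to the current GPU and receive the data via dist.broadcast from rank 0.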
if self.mp_rank != 0:
tensor_shape = self.state_dict[key]["shape"]
tensor_dtype = self.state_dict[key]["dtype"]
tensor = torch.empty(tensor_shape, dtype=tensor_dtype)
else:
tensor = self.state_dict[key].cpu()
tensor_gpu = self.weight_to_device(tensor, non_blocking=non_blocking)
dist.broadcast(tensor_gpu, src=0)
self.state_dict[key] = tensor_gpu


def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu', non_blocking=False):
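# Scatter = broadcast the full weight to every rank, split it along `dim`
# according to `split_mode`, then keep only this rank's shard.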
self.broadcast_weight(key, 'cuda')
self.broadcast_weight(key, non_blocking=non_blocking)
weight = self.state_dict[key]

if split_mode == 'default':
@@ -40,7 +47,5 @@ def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu',
else:
assert False, f"unknown split mode {split_mode}"


weight_split = [x.contiguous() for x in weight_split]
weight = weight_split[self.mp_rank].clone()
self.state_dict[key] = weight.to(device, non_blocking=non_blocking)
self.state_dict[key] = weight_split[self.mp_rank]
4 changes: 3 additions & 1 deletion byte_infer_perf/llm_perf/backends/GPU/model_impl/__init__.py
@@ -12,9 +12,11 @@
import torch.nn as nn

from .gpu_chatglm2 import GPUChatGLM2
from .gpu_falcon import GPUFalcon

from llm_perf.utils.logger import logger

__all__ = {
"chatglm2": GPUChatGLM2
"chatglm2": GPUChatGLM2,
"falcon": GPUFalcon
}
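# Note: __all__ here is a dict that serves as a model registry (model name ->
# GPU implementation class) rather than the conventional list of exported names.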
1,360 changes: 1,360 additions & 0 deletions byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon.py

Large diffs are not rendered by default.

160 changes: 160 additions & 0 deletions byte_infer_perf/llm_perf/backends/GPU/model_impl/falcon_split_model.py
@@ -0,0 +1,160 @@
import os
import sys
import pathlib
import argparse

import torch
import torch.nn as nn
from typing import List

from accelerate import init_empty_weights
from transformers import FalconConfig


FILE_DIR = pathlib.Path(__file__).parent.absolute()


sys.path.insert(0, str(FILE_DIR.parent.parent.parent.parent))
from llm_perf.backends.GPU.model_impl.falcon import FalconForCausalLM
from llm_perf.core.ckpt_loader import Falcon_ModelLoader


def to_parameter(
data : torch.Tensor,
dtype : torch.dtype = None
):
if dtype is not None:
data = data.to(dtype)
return nn.Parameter(data, requires_grad=False)


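# split(): shard `src` along `dim` into mp_size pieces. When `chunks` is given
# (GQA-style fused QKV), the rows are grouped per chunk and small KV groups are
# replicated so that every rank ends up with an equally sized shard.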
def split(
src : torch.Tensor,
mp_size : int,
dim : int,
chunks : List [int]=[]
):
if len(chunks) == 0:
split_arg = src.shape[dim] // mp_size
output_tensors = torch.split(src, split_arg, dim=dim)
else:
# for example
# chunks = [32, 2, 2], sum_chunks = 36, src.shape[dim] = (32 + 2 + 2) * 128, other_dim_size = 128
# mp_size = 8
# new_chunks = [4, 1, 1]
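# i.e. rank r keeps query heads [4r, 4r+4) plus one (replicated) K head and one
# (replicated) V head, so every rank receives an equally sized shard.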
sum_chunks = sum(chunks)
other_dim_size = src.shape[dim] // sum_chunks

split_arg = [i * other_dim_size for i in chunks]
split_tensors = torch.split(src, split_arg, dim=dim)

output_split = []
for i, tensor in enumerate(split_tensors):
if mp_size > chunks[i]:
tensor_shape = tensor.size()[:dim] + (chunks[i], 1, other_dim_size) + tensor.size()[dim+1:]
new_tensor_shape = tensor.size()[:dim] + (chunks[i], mp_size // chunks[i], other_dim_size) + tensor.size()[dim+1:]
output_tensor_shape = tensor.size()[:dim] + (mp_size * other_dim_size,) + tensor.size()[dim+1:]

tensor = tensor.view(tensor_shape)
tensor = tensor.expand(*new_tensor_shape)
tensor = tensor.contiguous()
tensor = tensor.view(output_tensor_shape)

cur_split = torch.split(tensor, tensor.shape[dim] // mp_size, dim=dim)
output_split.append(cur_split)

output_tensors = []
for i in range(mp_size):
temp_tensors = [output_split[j][i] for j in range(len(chunks))]
tp_tensors = torch.concat(temp_tensors, dim=dim)
output_tensors.append(tp_tensors)

output_tensors = [tensor.contiguous() for tensor in output_tensors]

return output_tensors



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--mp_size", type=int, default=8, choices=[2, 4, 8])
args = parser.parse_args()

os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = str(args.mp_size)


model_path = pathlib.Path(args.model_path).absolute()
split_model_path = model_path / f"TP{args.mp_size}"
split_model_path.mkdir(parents=True, exist_ok=True)

config = FalconConfig.from_pretrained(str(model_path))
model_loader = Falcon_ModelLoader(model_path)
state_dict = model_loader.load_weight()
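# load_weight() returns a flat {parameter_name: tensor} mapping of the original,
# unsplit checkpoint; the per-layer entries are replaced below with lists of
# per-rank shards produced by split().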

# for key in state_dict.keys():
# print(key, state_dict[key].shape, state_dict[key].dtype)

# print("")
# print("")
# print("")

for i in range(config.num_hidden_layers):
attn_qkv = f"transformer.h.{i}.self_attention.query_key_value.weight"
attn_dense = f"transformer.h.{i}.self_attention.dense.weight"

dense_h_to_4h = f"transformer.h.{i}.mlp.dense_h_to_4h.weight"
dense_4h_to_h = f"transformer.h.{i}.mlp.dense_4h_to_h.weight"

print(i)
state_dict[attn_qkv] = split(
state_dict[attn_qkv], args.mp_size,
dim=0,
chunks=[config.num_attention_heads, config.num_kv_heads, config.num_kv_heads]
)
state_dict[attn_dense] = split(
state_dict[attn_dense], args.mp_size,
dim=1
)
state_dict[dense_h_to_4h] = split(
state_dict[dense_h_to_4h], args.mp_size,
dim=0
)
state_dict[dense_4h_to_h] = split(
state_dict[dense_4h_to_h], args.mp_size,
dim=1
)

with init_empty_weights():
model = FalconForCausalLM(config)
model.eval()
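# init_empty_weights() builds the model on the meta device, so no real memory is
# allocated here; the actual tensors are attached parameter-by-parameter below.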

for i in range(args.mp_size):
print(f"store model_{i}")

output_dir = split_model_path / f"device_{i}"
output_dir.mkdir(parents=True, exist_ok=True)

model.transformer.word_embeddings.weight = to_parameter(state_dict["transformer.word_embeddings.weight"])
for j in range(config.num_hidden_layers):
model.transformer.h[j].self_attention.query_key_value.weight = to_parameter(state_dict[f"transformer.h.{j}.self_attention.query_key_value.weight"][i])
model.transformer.h[j].self_attention.dense.weight = to_parameter(state_dict[f"transformer.h.{j}.self_attention.dense.weight"][i])
model.transformer.h[j].mlp.dense_h_to_4h.weight = to_parameter(state_dict[f"transformer.h.{j}.mlp.dense_h_to_4h.weight"][i])
model.transformer.h[j].mlp.dense_4h_to_h.weight = to_parameter(state_dict[f"transformer.h.{j}.mlp.dense_4h_to_h.weight"][i])

model.transformer.h[j].ln_attn.weight = to_parameter(state_dict[f"transformer.h.{j}.ln_attn.weight"])
model.transformer.h[j].ln_attn.bias = to_parameter(state_dict[f"transformer.h.{j}.ln_attn.bias"])
model.transformer.h[j].ln_mlp.weight = to_parameter(state_dict[f"transformer.h.{j}.ln_mlp.weight"])
model.transformer.h[j].ln_mlp.bias = to_parameter(state_dict[f"transformer.h.{j}.ln_mlp.bias"])
model.transformer.ln_f.weight = to_parameter(state_dict["transformer.ln_f.weight"])
model.transformer.ln_f.bias = to_parameter(state_dict["transformer.ln_f.bias"])
model.lm_head.weight = to_parameter(state_dict["lm_head.weight"])

model.save_pretrained(str(output_dir))

# small_state_dict = model.state_dict()
# for key in small_state_dict.keys():
# print(key, small_state_dict[key].shape, small_state_dict[key].dtype, small_state_dict[key].device)
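# Example invocation (paths are illustrative):
#   python falcon_split_model.py --model_path /path/to/falcon-40b --mp_size 8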


47 changes: 36 additions & 11 deletions byte_infer_perf/llm_perf/backends/GPU/model_impl/gpu_chatglm2.py
@@ -1,4 +1,7 @@
import os
import json
import pathlib

import torch
import torch.distributed as dist
import torch.nn as nn
@@ -10,6 +13,7 @@

from accelerate import init_empty_weights

from llm_perf.core.ckpt_loader import CoreCkptLoader, ChatGLM2_ModelLoader
from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader

from .chatglm2 import ChatGLMForConditionalGeneration, ChatGLMModel, ChatGLMConfig
@@ -19,24 +23,41 @@ class GPUChatGLM2Loader(GpuCkptLoader):
def __init__(
self,
prefix,
model,
mp_size=1,
mp_rank=0,
model, model_config,
mp_size=1, mp_rank=0,
ckpt_path: str = ""
):
super().__init__(prefix, model, mp_size, mp_rank, ckpt_path)

self.model_config = model_config


def parallel_loader(self):
self.state_dict = None
self.state_dict = {}

# load model
if self.mp_rank == 0:
self.state_dict = self.torch_load_wrapper(
self.ckpt_path, map_location=torch.device("cpu"))
self.state_dict = {}

model_dir = pathlib.Path(self.ckpt_path).absolute()
if not model_dir.exists() or not model_dir.is_dir():
return

weight_index_config = {}
for child in model_dir.iterdir():
if child.name.endswith(".index.json"):
with open(child, "r") as f:
weight_index_config = json.load(f)

model_loader = ChatGLM2_ModelLoader(model_dir, self.model_config, weight_index_config)
model_loader.load_weight()
self.state_dict = model_loader.weight_dict
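# Rank 0 resolves the sharded checkpoint via its *.index.json file and loads the
# full state dict; the other ranks keep an empty dict and are filled in by the
# broadcast/scatter calls below.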

if self.mp_size == 1:
return self.state_dict

# mp_size > 2
# broadcast state_dict from rank 0 to other ranks
# mp_size > 1
# broadcast {key_name: [tensor_shape, tensor]} from rank 0 to other ranks
self.broadcast_meta()

self.broadcast_weight("transformer.embedding.word_embeddings.weight")
@@ -58,6 +79,8 @@ def parallel_loader(self):

return self.state_dict



def infusion_to_model(self):
self.model.transformer.embedding.word_embeddings.weight = self.to_parameter(
self.state_dict[f"transformer.embedding.word_embeddings.weight"]
@@ -121,7 +144,6 @@ def __init__(self, xpu_cfg: Dict[str, Any]) -> None:
self.mp_size = int(os.environ.get("WORLD_SIZE", "1"))
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))


self.prefix = "transformer.encoder.layers"
self.transformer_model : ChatGLMForConditionalGeneration = None

@@ -163,16 +185,19 @@ def init_inference(self):
logger.info(f"cuda model {self.model_path} loaded {self.transformer_model}")



def load_weight(self, ckpt_path):
p_loader = GPUChatGLM2Loader(
self.prefix, self.transformer_model,
self.prefix, self.transformer_model, self.chatglm_config,
self.mp_size, self.local_rank,
ckpt_path
)
p_loader.load()
p_loader.parallel_loader()
p_loader.infusion_to_model()
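# parallel_loader() builds/broadcasts the state dict across the TP ranks;
# infusion_to_model() then copies the (possibly sharded) tensors onto the
# empty-initialized transformer model.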




def init_kvcache(self, dtype):
max_seq_len = 4096
max_batch_size = self.xpu_cfg["max_batch_size"]
(remaining changed files not rendered)
