add batch_gemm, group_gemm; add int8 dtype to gemm ops; fix situation that world_size exceeds available devices.
suisiyuan committed May 16, 2024
1 parent c1d1835 commit 51bd4a7
Showing 10 changed files with 700 additions and 188 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -25,4 +25,7 @@ init_env.sh

byte_infer_perf/llm_perf/download
byte_infer_perf/llm_perf/model_zoo/sota
byte_infer_perf/llm_perf/reports
byte_infer_perf/llm_perf/reports

out/
*.db
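(The new out/ and *.db entries presumably cover the cutlass JIT artifacts: custom_ops.py below emits kernels with sourcedir='out', and the cutlass Python JIT keeps a compilation cache database alongside them.)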
126 changes: 88 additions & 38 deletions byte_micro_perf/backends/GPU/backend_gpu.py
@@ -22,8 +22,13 @@
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d

from backends.backend import Backend
from backends.module_store import *
from backends.utils import get_dtype_bytes

from .custom_ops import GPUGemmOp, GPUBatchGemmOp, GPUGroupGemmOp


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("PerfEngine")
@@ -50,9 +55,51 @@ def get_backend_properties(self):
)
)


# gemm ops
def gemm(self):
self.op = GemmOp()
self.op = GPUGemmOp()

def batch_gemm(self):
self.op = BatchGemmOp()

def group_gemm(self):
self.op = GPUGroupGemmOp()


# device/host ops
def host2device(self):
self.op = Host2DeviceOp(torch.device("cuda"))

def device2host(self):
self.op = Device2HostOp()



# communication ops
def allreduce(self):
self.setup_2d_group()
self.op = AllReduceOp(self.group)

def allgather(self):
self.setup_2d_group()
self.op = AllGatherOp(self.group)

def reducescatter(self):
self.setup_2d_group()
self.op = ReduceScatterOp(self.group)

def alltoall(self):
self.setup_2d_group()
self.op = AllToAllOp(self.group)

def broadcast(self):
self.setup_2d_group()
self.op = BroadcastOp(self.group)



# other compute ops
def add(self):
self.op = AddOp()

@@ -86,57 +133,50 @@ def softmax(self):
def layernorm(self):
self.op = LayerNormOp()

def allreduce(self):
self.setup_2d_group()
self.op = AllReduceOp(self.group)

def allgather(self):
self.setup_2d_group()
self.op = AllGatherOp(self.group)

def reducescatter(self):
self.setup_2d_group()
self.op = ReduceScatterOp(self.group)

def alltoall(self):
self.setup_2d_group()
self.op = AllToAllOp(self.group)

def broadcast(self):
self.setup_2d_group()
self.op = BroadcastOp(self.group)
# create input tensors
def build_tensor(self, input_shapes, torch_dtype):

def host2device(self):
self.op = Host2DeviceOp(torch.device("cuda"))
# compute size of input and output tensors
if hasattr(self.op, "compute_size"):
bytes_per_cnt = self.op.compute_size(input_shapes, torch_dtype)
# default: input_tensors_size == output_tensor_size, all tensors have same dtype
else:
dtype_size = get_dtype_bytes(torch_dtype)
element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
bytes_per_cnt = dtype_size * element_num

def device2host(self):
self.op = Device2HostOp()
# compute max avail tensors for compute
avail_bytes = (self.memory_limit - 4) * 1024**3
avail_cnts = avail_bytes // bytes_per_cnt
max_data_cnt = min(self.iterations, avail_cnts)

def build_tensor(self, input_shapes, dtype):
torch_type = getattr(torch, dtype)
if torch_type == torch.int32:
dtype_size = torch.iinfo(torch_type).bits // 8
else:
dtype_size = torch.finfo(torch_type).bits // 8
size = sum([math.prod(shape) for shape in input_shapes])
data_amount = size * 2 * dtype_size
data_cnt = (self.memory_limit - 4) * 1024**3 // data_amount
data_cnt = min(data_cnt, self.iterations)
input_tensors_list = []
for _ in range(data_cnt):
input_tensors = [
torch.randn(shape).type(torch_type).to(torch.device("cuda"))
for shape in input_shapes
]
input_tensors_list.append(input_tensors)

# create input tensors for each op
input_tensors_list = []
for _ in range(max_data_cnt):
# create input tensors
if hasattr(self.op, "custom_create_tensors"):
input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype)
input_tensors_list.append(input_tensors)
# default: all input tensors have same dtype
else:
input_tensors = [
torch.randint(0, 3, size=shape).type(torch_dtype).to(torch.device("cuda"))
for shape in input_shapes
]
input_tensors_list.append(input_tensors)
if hasattr(self.op, "process_inputs"):
input_tensors_list = [
self.op.process_inputs(*(input_tensor))
for input_tensor in input_tensors_list
]
return input_tensors_list, max_data_cnt, bytes_per_cnt


return input_tensors_list, data_cnt
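# --- Illustrative sketch (not part of this commit): the default sizing path in
# the new build_tensor above, written out standalone. Assumes memory_limit is
# expressed in GiB; element_size() stands in for backends.utils.get_dtype_bytes.
import math
import torch

def default_budget(input_shapes, torch_dtype, memory_limit_gib, iterations):
    dtype_size = torch.tensor([], dtype=torch_dtype).element_size()
    element_num = 2 * sum(math.prod(shape) for shape in input_shapes)  # inputs + equally sized output
    bytes_per_cnt = dtype_size * element_num
    avail_cnts = (memory_limit_gib - 4) * 1024**3 // bytes_per_cnt     # keep 4 GiB of headroom
    return bytes_per_cnt, min(iterations, avail_cnts)

# Two [4096, 4096] float16 operands on a 32 GiB card: bytes_per_cnt = 128 MiB,
# so at most min(iterations, 224) tensor sets are kept resident on the device.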

def _run_operation(self, operation, inputs):
result = operation(*inputs)
@@ -150,13 +190,22 @@ def initialize_ccl(self, rank, world_size):
"""
initialize distributed process groups and relevant ENVs
"""
# check device_count
device_count = torch.cuda.device_count()
if world_size > device_count:
world_size = device_count
if rank >= world_size:
return False

# set envs
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "49373"
os.environ["LOCAL_RANK"] = str(rank)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)

torch.cuda.set_device(rank)

# Call the init process
timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
torch.distributed.init_process_group(
@@ -168,6 +217,7 @@ def initialize_ccl(self, rank, world_size):
)
self.setup_2d_group()
log.warning("DIST: rank {}, world_size {}".format(rank, world_size))
return True

def setup_2d_group(self):
self.rank = dist.get_rank()
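The initialize_ccl change clamps world_size to the number of visible GPUs and lets surplus ranks opt out. A minimal sketch of that behaviour (hypothetical helper, not from the repo):

import torch

def clamp_world_size(rank: int, world_size: int):
    # Mirror the new guard: shrink world_size to the device count and
    # report whether this rank should participate at all.
    device_count = torch.cuda.device_count()
    world_size = min(world_size, device_count)
    return rank < world_size, world_size

For example, requesting 8 ranks on a 4-GPU host leaves ranks 0-3 running with world_size clamped to 4, while ranks 4-7 get False back (initialize_ccl returns False and the caller can skip them).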
185 changes: 185 additions & 0 deletions byte_micro_perf/backends/GPU/custom_ops.py
@@ -0,0 +1,185 @@
from typing import List

import torch
import cutlass

from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp




# gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
# gemm(cutlass) int8 --> int32
class GPUGemmOp(GemmOp):
def __init__(self):
super().__init__()

# cutlass int8 gemm
dtype = torch.int8
accum_dtype=torch.int32
self.plan = cutlass.op.Gemm(
alpha=1, beta=0,
element_A=dtype,
element_B=dtype,
element_C=accum_dtype,
element_D=accum_dtype,
layout_A=cutlass.LayoutType.ColumnMajorInterleaved32,
layout_B=cutlass.LayoutType.RowMajorInterleaved32,
layout_C=cutlass.LayoutType.RowMajor
)
self.op = self.plan.construct(
alignment_A=16,
alignment_B=16,
alignment_C=8
)
self.gemm_op_int8 = cutlass.emit.pytorch(
self.op, name='gemm', cc=self.plan.cc,
jit=True, sourcedir='out'
)


def forward(
self,
input_tensor_a : torch.Tensor,
input_tensor_b : torch.Tensor
):
compute_dtype = input_tensor_a.dtype
if compute_dtype == torch.int8 and self.gemm_op_int8 is not None:
output_tensor = self.gemm_op_int8.run(
input_tensor_a, input_tensor_b
)
else:
output_tensor = torch.mm(
input_tensor_a, input_tensor_b
)
return output_tensor
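# --- Illustrative usage (not part of this commit; assumes a CUDA device): the
# dispatch above routes int8 inputs through the emitted cutlass kernel with
# int32 accumulation, and everything else through torch.mm.
gemm_op = GPUGemmOp()
a = torch.randn(512, 256, dtype=torch.float16, device="cuda")
b = torch.randn(256, 128, dtype=torch.float16, device="cuda")
c = gemm_op.forward(a, b)   # torch.mm path: float16 output of shape (512, 128)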




# batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
# batch_gemm(cutlass) int8 --> int32
class GPUBatchGemmOp(BatchGemmOp):
def __init__(self):
super().__init__()

# TODO: cutlass int8 batch_gemm
pass

def forward(
self,
input_tensor_a : torch.Tensor,
input_tensor_b : torch.Tensor
):
compute_dtype = input_tensor_a.dtype

output_tensor = None
if compute_dtype == torch.int8:
# TODO
pass
else:
output_tensor = torch.bmm(
input_tensor_a, input_tensor_b
)
return output_tensor
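# --- Illustrative usage (not part of this commit; assumes a CUDA device): the
# non-int8 path above is plain torch.bmm, so both inputs must be 3-D with a
# shared leading batch dimension: [B, M, K] @ [B, K, N] -> [B, M, N].
batch_op = GPUBatchGemmOp()
batched_out = batch_op.forward(
    torch.randn(8, 128, 64, dtype=torch.float16, device="cuda"),
    torch.randn(8, 64, 32, dtype=torch.float16, device="cuda"),
)   # float16 output of shape (8, 128, 32)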




# group_gemm(cutlass) float32/float16/bfloat16 --> float32
# group_gemm(cutlass) int8 --> int32
class GPUGroupGemmOp(GroupGemmOp):
def __init__(self):
super().__init__()

self.group_gemm_fp32 = GPUGroupGemmOp.compile_mod(
dtype=torch.float32,
accum_dtype=torch.float32,
mod_name="groupd_gemm_fp32"
)

self.group_gemm_fp16 = GPUGroupGemmOp.compile_mod(
dtype=torch.float16,
accum_dtype=torch.float32,
mod_name="groupd_gemm_fp16"
)

self.group_gemm_bf16 = GPUGroupGemmOp.compile_mod(
dtype=torch.bfloat16,
accum_dtype=torch.float32,
mod_name="groupd_gemm_bf16"
)

# TODO: cutlass int8 group_gemm
self.group_gemm_int8 = None
# if "int8" in dtype_list:
# self.group_gemm_int8 = GroupGemmOp.compile_mod(
# dtype=torch.int8,
# accum_dtype=torch.int32,
# mod_name="group_gemm_int8"
# )

@staticmethod
def compile_mod(dtype, accum_dtype, mod_name):

if dtype == torch.int8:
# TODO
pass
# plan = cutlass.op.Gemm(
# alpha=1, beta=0,
# element_A=dtype,
# element_B=dtype,
# element_C=accum_dtype,
# element_D=accum_dtype,
# layout_A=cutlass.LayoutType.ColumnMajorInterleaved32,
# layout_B=cutlass.LayoutType.RowMajorInterleaved32,
# layout_C=cutlass.LayoutType.RowMajor
# )
# op = plan.construct(
# alignment_A=16,
# alignment_B=16,
# alignment_C=8
# )
# grouped_gemm = cutlass.emit.pytorch(
# op, name=mod_name,
# cc=plan.cc, jit=True,
# sourcedir='out'
# )
else:
plan = cutlass.op.GroupedGemm(
alpha=1, beta=0,
element_A=dtype,
element_B=dtype,
element_C=accum_dtype,
element_D=accum_dtype,
layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.RowMajor,
layout_C=cutlass.LayoutType.RowMajor
)
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(
op, name=mod_name,
cc=plan.cc, jit=True,
sourcedir='./out'
)
return grouped_gemm


def forward(self,
a_list : List[torch.Tensor],
b_list : List[torch.Tensor]
):
compute_dtype = a_list[0].dtype
if compute_dtype == torch.float32 and self.group_gemm_fp32 is not None:
output_tensors = self.group_gemm_fp32.run(a_list, b_list)
elif compute_dtype == torch.float16 and self.group_gemm_fp16 is not None:
output_tensors = self.group_gemm_fp16.run(a_list, b_list)
elif compute_dtype == torch.bfloat16 and self.group_gemm_bf16 is not None:
output_tensors = self.group_gemm_bf16.run(a_list, b_list)
elif compute_dtype == torch.int8 and self.group_gemm_int8 is not None:
# TODO
pass
# output_tensors = self.group_gemm_int8.run(a_list, b_list)
else:
output_tensors = []
return output_tensors
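A short usage sketch for the grouped path (illustrative only; assumes a CUDA device, a working cutlass Python install, the in-repo import path backends.GPU.custom_ops, and that the emitted module's run() accepts lists of tensors, as forward() above does):

import torch
from backends.GPU.custom_ops import GPUGroupGemmOp

op = GPUGroupGemmOp()   # JIT-compiles the fp32/fp16/bf16 grouped-GEMM modules into ./out

shapes = [(128, 256, 64), (64, 512, 32), (32, 128, 16)]   # (M, K, N) per problem
a_list = [torch.randn(m, k, dtype=torch.float16, device="cuda") for m, k, _ in shapes]
b_list = [torch.randn(k, n, dtype=torch.float16, device="cuda") for _, k, n in shapes]

outputs = op.forward(a_list, b_list)   # one output per problem; float32 results per the header comment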
