diff --git a/.gitignore b/.gitignore
index 2e06b074..eb9726f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,4 +25,7 @@
 init_env.sh
 byte_infer_perf/llm_perf/download
 byte_infer_perf/llm_perf/model_zoo/sota
-byte_infer_perf/llm_perf/reports
\ No newline at end of file
+byte_infer_perf/llm_perf/reports
+
+out/
+*.db
\ No newline at end of file
diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py
index a8da5cbb..649547b3 100644
--- a/byte_micro_perf/backends/GPU/backend_gpu.py
+++ b/byte_micro_perf/backends/GPU/backend_gpu.py
@@ -22,8 +22,13 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as dist_c10d
+
 from backends.backend import Backend
 from backends.module_store import *
+from backends.utils import get_dtype_bytes
+
+from .custom_ops import GPUGemmOp, GPUBatchGemmOp, GPUGroupGemmOp
+
 
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger("PerfEngine")
@@ -50,9 +55,51 @@ def get_backend_properties(self):
             )
         )
 
+
+    # gemm ops
     def gemm(self):
-        self.op = GemmOp()
+        self.op = GPUGemmOp()
+
+    def batch_gemm(self):
+        self.op = BatchGemmOp()
+
+    def group_gemm(self):
+        self.op = GPUGroupGemmOp()
+
+
+    # device/host ops
+    def host2device(self):
+        self.op = Host2DeviceOp(torch.device("cuda"))
+
+    def device2host(self):
+        self.op = Device2HostOp()
+
+
+
+    # communication ops
+    def allreduce(self):
+        self.setup_2d_group()
+        self.op = AllReduceOp(self.group)
+
+    def allgather(self):
+        self.setup_2d_group()
+        self.op = AllGatherOp(self.group)
+
+    def reducescatter(self):
+        self.setup_2d_group()
+        self.op = ReduceScatterOp(self.group)
+
+    def alltoall(self):
+        self.setup_2d_group()
+        self.op = AllToAllOp(self.group)
+
+    def broadcast(self):
+        self.setup_2d_group()
+        self.op = BroadcastOp(self.group)
+
+
+    # other compute ops
     def add(self):
         self.op = AddOp()
@@ -86,57 +133,50 @@ def softmax(self):
     def layernorm(self):
         self.op = LayerNormOp()
 
-    def allreduce(self):
-        self.setup_2d_group()
-        self.op = AllReduceOp(self.group)
 
-    def allgather(self):
-        self.setup_2d_group()
-        self.op = AllGatherOp(self.group)
 
-    def reducescatter(self):
-        self.setup_2d_group()
-        self.op = ReduceScatterOp(self.group)
 
-    def alltoall(self):
-        self.setup_2d_group()
-        self.op = AllToAllOp(self.group)
 
-    def broadcast(self):
-        self.setup_2d_group()
-        self.op = BroadcastOp(self.group)
+
+    # create input tensors
+    def build_tensor(self, input_shapes, torch_dtype):
 
-    def host2device(self):
-        self.op = Host2DeviceOp(torch.device("cuda"))
+        # compute size of input and output tensors
+        if hasattr(self.op, "compute_size"):
+            bytes_per_cnt = self.op.compute_size(input_shapes, torch_dtype)
+        # default: input_tensors_size == output_tensor_size, all tensors have same dtype
+        else:
+            dtype_size = get_dtype_bytes(torch_dtype)
+            element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
+            bytes_per_cnt = dtype_size * element_num
 
-    def device2host(self):
-        self.op = Device2HostOp()
+        # compute max avail tensors for compute
+        avail_bytes = (self.memory_limit - 4) * 1024**3
+        avail_cnts = avail_bytes // bytes_per_cnt
+        max_data_cnt = min(self.iterations, avail_cnts)
 
-    def build_tensor(self, input_shapes, dtype):
-        torch_type = getattr(torch, dtype)
-        if torch_type == torch.int32:
-            dtype_size = torch.iinfo(torch_type).bits // 8
-        else:
-            dtype_size = torch.finfo(torch_type).bits // 8
-        size = sum([math.prod(shape) for shape in input_shapes])
-        data_amount = size * 2 * dtype_size
-        data_cnt = (self.memory_limit - 4) * 1024**3 // data_amount
-        data_cnt = min(data_cnt, self.iterations)
 
-        input_tensors_list = []
-        for _ in range(data_cnt):
-            input_tensors = [
-                torch.randn(shape).type(torch_type).to(torch.device("cuda"))
-                for shape in input_shapes
-            ]
-            input_tensors_list.append(input_tensors)
+        # create input tensors for each op
+        input_tensors_list = []
+        for _ in range(max_data_cnt):
+            # create input tensors
+            if hasattr(self.op, "custom_create_tensors"):
+                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype)
+                input_tensors_list.append(input_tensors)
+            # default: all input tensors have same dtype
+            else:
+                input_tensors = [
+                    torch.randint(0, 3, size=shape).type(torch_dtype).to(torch.device("cuda"))
+                    for shape in input_shapes
+                ]
+                input_tensors_list.append(input_tensors)
 
         if hasattr(self.op, "process_inputs"):
             input_tensors_list = [
                 self.op.process_inputs(*(input_tensor))
                 for input_tensor in input_tensors_list
             ]
+        return input_tensors_list, max_data_cnt, bytes_per_cnt
+
 
-        return input_tensors_list, data_cnt
 
     def _run_operation(self, operation, inputs):
         result = operation(*inputs)
@@ -150,6 +190,14 @@ def initialize_ccl(self, rank, world_size):
         """
         initialize distributed process groups and relevant ENVs
         """
+        # check device_count
+        device_count = torch.cuda.device_count()
+        if world_size > device_count:
+            world_size = device_count
+        if rank >= world_size:
+            return False
+
+        # set envs
         os.environ["MASTER_ADDR"] = "127.0.0.1"
         os.environ["MASTER_PORT"] = "49373"
         os.environ["LOCAL_RANK"] = str(rank)
@@ -157,6 +205,7 @@
         os.environ["WORLD_SIZE"] = str(world_size)
 
         torch.cuda.set_device(rank)
 
+        # Call the init process
         timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
         torch.distributed.init_process_group(
@@ -168,6 +217,7 @@
         )
         self.setup_2d_group()
         log.warning("DIST: rank {}, world_size {}".format(rank, world_size))
+        return True
 
     def setup_2d_group(self):
         self.rank = dist.get_rank()
diff --git a/byte_micro_perf/backends/GPU/custom_ops.py b/byte_micro_perf/backends/GPU/custom_ops.py
new file mode 100644
index 00000000..a13ec25d
--- /dev/null
+++ b/byte_micro_perf/backends/GPU/custom_ops.py
@@ -0,0 +1,185 @@
+from typing import List
+
+import torch
+import cutlass
+
+from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp
+
+
+
+
+# gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
+# gemm(cutlass) int8 --> int32
+class GPUGemmOp(GemmOp):
+    def __init__(self):
+        super().__init__()
+
+        # cutlass int8 gemm
+        dtype = torch.int8
+        accum_dtype=torch.int32
+        self.plan = cutlass.op.Gemm(
+            alpha=1, beta=0,
+            element_A=dtype,
+            element_B=dtype,
+            element_C=accum_dtype,
+            element_D=accum_dtype,
+            layout_A=cutlass.LayoutType.ColumnMajorInterleaved32,
+            layout_B=cutlass.LayoutType.RowMajorInterleaved32,
+            layout_C=cutlass.LayoutType.RowMajor
+        )
+        self.op = self.plan.construct(
+            alignment_A=16,
+            alignment_B=16,
+            alignment_C=8
+        )
+        self.gemm_op_int8 = cutlass.emit.pytorch(
+            self.op, name='gemm', cc=self.plan.cc,
+            jit=True, sourcedir='out'
+        )
+
+
+    def forward(
+        self,
+        input_tensor_a : torch.Tensor,
+        input_tensor_b : torch.Tensor
+    ):
+        compute_dtype = input_tensor_a.dtype
+        if compute_dtype == torch.int8 and self.gemm_op_int8 is not None:
+            output_tensor = self.gemm_op_int8.run(
+                input_tensor_a, input_tensor_b
+            )
+        else:
+            output_tensor = torch.mm(
+                input_tensor_a, input_tensor_b
+            )
+        return output_tensor
+
+
+
+
+# batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
+# batch_gemm(cutlass) int8 --> int32
+class GPUBatchGemmOp(BatchGemmOp):
+    def __init__(self):
+        super().__init__()
+
+        # TODO: cutlass int8 batch_gemm
+        pass
+
+    def forward(
+        self,
+        input_tensor_a : torch.Tensor,
+        input_tensor_b : torch.Tensor
+    ):
+        compute_dtype = input_tensor_a.dtype
+
+        output_tensor = None
+        if compute_dtype == torch.int8:
+            # TODO
+            pass
+        else:
+            output_tensor = torch.bmm(
+                input_tensor_a, input_tensor_b
+            )
+        return output_tensor
+
+
+
+
+# group_gemm(cutlass) float32/float16/bfloat16 --> float32
+# group_gemm(cutlass) int8 --> int32
+class GPUGroupGemmOp(GroupGemmOp):
+    def __init__(self):
+        super().__init__()
+
+        self.group_gemm_fp32 = GPUGroupGemmOp.compile_mod(
+            dtype=torch.float32,
+            accum_dtype=torch.float32,
+            mod_name="group_gemm_fp32"
+        )
+
+        self.group_gemm_fp16 = GPUGroupGemmOp.compile_mod(
+            dtype=torch.float16,
+            accum_dtype=torch.float32,
+            mod_name="group_gemm_fp16"
+        )
+
+        self.group_gemm_bf16 = GPUGroupGemmOp.compile_mod(
+            dtype=torch.bfloat16,
+            accum_dtype=torch.float32,
+            mod_name="group_gemm_bf16"
+        )
+
+        # TODO: cutlass int8 group_gemm
+        self.group_gemm_int8 = None
+        # if "int8" in dtype_list:
+        #     self.group_gemm_int8 = GroupGemmOp.compile_mod(
+        #         dtype=torch.int8,
+        #         accum_dtype=torch.int32,
+        #         mod_name="group_gemm_int8"
+        #     )
+
+    @staticmethod
+    def compile_mod(dtype, accum_dtype, mod_name):
+
+        if dtype == torch.int8:
+            # TODO
+            pass
+            # plan = cutlass.op.Gemm(
+            #     alpha=1, beta=0,
+            #     element_A=dtype,
+            #     element_B=dtype,
+            #     element_C=accum_dtype,
+            #     element_D=accum_dtype,
+            #     layout_A=cutlass.LayoutType.ColumnMajorInterleaved32,
+            #     layout_B=cutlass.LayoutType.RowMajorInterleaved32,
+            #     layout_C=cutlass.LayoutType.RowMajor
+            # )
+            # op = plan.construct(
+            #     alignment_A=16,
+            #     alignment_B=16,
+            #     alignment_C=8
+            # )
+            # grouped_gemm = cutlass.emit.pytorch(
+            #     op, name=mod_name,
+            #     cc=plan.cc, jit=True,
+            #     sourcedir='out'
+            # )
+        else:
+            plan = cutlass.op.GroupedGemm(
+                alpha=1, beta=0,
+                element_A=dtype,
+                element_B=dtype,
+                element_C=accum_dtype,
+                element_D=accum_dtype,
+                layout_A=cutlass.LayoutType.RowMajor,
+                layout_B=cutlass.LayoutType.RowMajor,
+                layout_C=cutlass.LayoutType.RowMajor
+            )
+            op = plan.construct()
+            grouped_gemm = cutlass.emit.pytorch(
+                op, name=mod_name,
+                cc=plan.cc, jit=True,
+                sourcedir='./out'
+            )
+        return grouped_gemm
+
+
+    def forward(self,
+        a_list : List[torch.Tensor],
+        b_list : List[torch.Tensor]
+    ):
+        compute_dtype = a_list[0].dtype
+        if compute_dtype == torch.float32 and self.group_gemm_fp32 is not None:
+            output_tensors = self.group_gemm_fp32.run(a_list, b_list)
+        elif compute_dtype == torch.float16 and self.group_gemm_fp16 is not None:
+            output_tensors = self.group_gemm_fp16.run(a_list, b_list)
+        elif compute_dtype == torch.bfloat16 and self.group_gemm_bf16 is not None:
+            output_tensors = self.group_gemm_bf16.run(a_list, b_list)
+        elif compute_dtype == torch.int8 and self.group_gemm_int8 is not None:
+            # TODO
+            pass
+            # output_tensors = self.group_gemm_int8.run(a_list, b_list)
+        else:
+            output_tensors = []
+        return output_tensors
\ No newline at end of file
diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py
index d6d00963..79258b1b 100644
--- a/byte_micro_perf/backends/backend.py
+++ b/byte_micro_perf/backends/backend.py
@@ -27,10 +27,12 @@ def __init__(self, workload_dict: Dict[str, Any], vendor_path: str):
         self.warmup = int(0.1 * workload_dict["iterations"])
         self.vendor_path = vendor_path
         self.op = None
 
+        # communication params
         self.rank = None
         self.world_size = None
         self.group = None
 
+        # hardware info
         self.hw_info_dict = None
         self.memory_limit = None
@@ -65,9 +67,45 @@ def initialize_ccl(self, rank, world_size):
     def setup_2d_group(self):
         pass
 
+
+
+    # gemm ops
     def gemm(self):
         pass
 
+    def batch_gemm(self):
+        pass
+
+    def group_gemm(self):
+        pass
+
+
+    # device/host ops
+    def host2device(self):
+        pass
+
+    def device2host(self):
+        pass
+
+
+    # communication ops
+    def allreduce(self):
+        pass
+
+    def allgather(self):
+        pass
+
+    def reducescatter(self):
+        pass
+
+    def alltoall(self):
+        pass
+
+    def broadcast(self):
+        pass
+
+
+    # other compute ops
     def add(self):
         pass
@@ -101,57 +139,53 @@ def softmax(self):
     def layernorm(self):
         pass
 
-    def allreduce(self):
-        pass
 
-    def allgather(self):
-        pass
 
-    def reducescatter(self):
-        pass
 
-    def alltoall(self):
-        pass
 
-    def broadcast(self):
-        pass
 
-    def host2device(self):
-        pass
 
-    def device2host(self):
-        pass
+
+
+
+    # perf with the specified input_shapes and dtype
     def perf(self, input_shapes: List[List[int]], dtype):
         error = ""
 
-        inputs_list, data_cnt = self.build_tensor(input_shapes, dtype)
+        # create input tensors based on input_shapes and dtype
+        tensor_list, tensor_cnt, tensor_size_per_cnt = self.build_tensor(
+            input_shapes, dtype
+        )
 
-        if data_cnt > 0:
+        if tensor_cnt > 0:
+            # random select input tensors
             input_index_list = [
-                random.randint(0, data_cnt - 1) for _ in range(self.iterations)
+                random.randint(0, tensor_cnt - 1) for _ in range(self.iterations)
             ]
 
             # warmup
             num_warm_up = 10
             for _ in range(num_warm_up):
-                self._run_operation(self.op, inputs_list[0])
+                self._run_operation(self.op, tensor_list[0])
 
             # perf
             self.device_synchronize()
             start_time = time.perf_counter_ns()
             for i in range(self.iterations):
-                result = self._run_operation(self.op, inputs_list[input_index_list[i]])
+                result = self._run_operation(
+                    self.op,
+                    tensor_list[input_index_list[i]]
+                )
             self.device_synchronize()
             end_time = time.perf_counter_ns()
 
             # time in us
-            exec_time = (end_time - start_time) / 1e3
-            latency = round(exec_time / self.iterations, 2)
+            total_exec_time = (end_time - start_time) / 1e3
+            latency = round(total_exec_time / self.iterations, 2)
         else:
             latency = 0
             error = "OOM"
 
+
         if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast"]:
             report = dump_communication_ops_report(
                 self.op_name,
@@ -164,6 +198,12 @@ def perf(self, input_shapes: List[List[int]], dtype):
             )
         else:
             report = dump_computation_ops_report(
-                self.op_name, dtype, input_shapes, self.bandwidth_limit, latency, error
+                self.op_name,
+                dtype,
+                input_shapes,
+                self.bandwidth_limit,
+                latency,
+                error
             )
         return report
+
diff --git a/byte_micro_perf/backends/module_store.py b/byte_micro_perf/backends/module_store.py
index a8a5ff09..fcda6f81 100644
--- a/byte_micro_perf/backends/module_store.py
+++ b/byte_micro_perf/backends/module_store.py
@@ -12,124 +12,153 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
+from typing import List
+
 import torch
 import torch.distributed as dist
 
-
-class AddOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input_tensor_a, input_tensor_b):
-        result = input_tensor_a + input_tensor_b
-        return result
-
-
-class SinOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input_tensors):
-        result = torch.sin(input_tensors)
-        return result
+from .utils import get_dtype_bytes
 
 
-class CosOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class GemmOp(torch.nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
 
-    def forward(self, input_tensors):
-        result = torch.cos(input_tensors)
-        return result
+    def forward(self, input_tensor_a, input_tensor_b):
+        compute_dtype = input_tensor_a.dtype
+        output_tensor = None
+        if compute_dtype == torch.int8:
+            # to be realized
+            pass
+        elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            output_tensor = torch.mm(input_tensor_a, input_tensor_b)
+        return output_tensor
+
+    def compute_size(self, input_shapes, torch_dtype):
+        # input_shapes: [[M, K], [K, N]]
+        a_shape, b_shape = input_shapes
+        M, K = a_shape
+        K, N = b_shape
+        d_shape = [M, N]
+        dtype_size = get_dtype_bytes(torch_dtype)
+        element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]])
+        bytes_per_cnt = dtype_size * element_num
+        return bytes_per_cnt
 
 
-class GeluOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input_tensors):
-        result = torch.nn.functional.gelu(input_tensors)
-        return result
+class BatchGemmOp(torch.nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
 
+    def forward(self, input_tensor_a, input_tensor_b):
+        compute_dtype = input_tensor_a.dtype
+        output_tensor = None
+        if compute_dtype == torch.int8:
+            # to be realized
+            pass
+        elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            output_tensor = torch.bmm(input_tensor_a, input_tensor_b)
+        return output_tensor
+
+    def compute_size(self, input_shapes, torch_dtype):
+        # input_shapes: [[bs, M, K], [bs, K, N]]
+        a_shape, b_shape = input_shapes
+        bs, M, K = a_shape
+        bs, K, N = b_shape
+        d_shape = [bs, M, N]
+        dtype_size = get_dtype_bytes(torch_dtype)
+        element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]])
+        bytes_per_cnt = dtype_size * element_num
+        return bytes_per_cnt
+
 
-class ExponentialOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input_tensors):
-        result = input_tensors.exponential_()
-        return result
+class GroupGemmOp(torch.nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
 
+    def forward(self, input_tensor_a, input_tensor_b):
+        compute_dtype = input_tensor_a.dtype
+        output_tensor_list = []
+        for a, b in zip(input_tensor_a, input_tensor_b):
+            if compute_dtype == torch.int8:
+                # to be realized
+                pass
+            elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                output_tensor = torch.mm(a, b)
+                output_tensor_list.append(output_tensor)
+        return output_tensor_list
+
+
+    def compute_size(self, input_shapes, torch_dtype):
+        """
+        [
+            [[M1, K1], [K1, N1]],
+            [[M2, K2], [K2, N2]]
+        ]
+        """
+        bytes_per_cnt = 0
+        for problem_shape in input_shapes:
+            a_shape, b_shape = problem_shape
+            M, K = a_shape
+            K, N = b_shape
+            d_shape = [M, N]
+            dtype_size = get_dtype_bytes(torch_dtype)
+            element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]])
+            bytes_per_cnt += dtype_size * element_num
+        return bytes_per_cnt
+
+    def custom_create_tensors(self, input_shapes, torch_dtype):
+        """
+        [
+            [[M1, K1], [K1, N1]],
+            [[M2, K2], [K2, N2]]
+        ]
+        """
+        left_tensors = []
+        right_tensors = []
+
+        for problem_shape in input_shapes:
+            a_shape, b_shape = problem_shape
+            left_tensors.append(
+                torch.randint(0, 3, size=a_shape).type(torch_dtype).to(torch.device("cuda"))
+            )
+            right_tensors.append(
+                torch.randint(0, 3, size=b_shape).type(torch_dtype).to(torch.device("cuda"))
+            )
+        return [left_tensors, right_tensors]
 
 
-class IndexAddOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def process_inputs(self, input_tensor, source_tensor):
-        index = torch.randint(0, input_tensor.shape[0], (source_tensor.shape[0],)).to(
-            input_tensor.device
-        )
-        return [input_tensor, index, source_tensor]
-
-    def forward(self, input_tensor, index, source_tensor):
-        result = input_tensor.index_add_(0, index, source_tensor)
-        return result
-
-
-class SortOp(torch.nn.Module):
-    def __init__(self):
+class Host2DeviceOp(torch.nn.Module):
+    def __init__(self, xpu_device):
         super().__init__()
+        self.xpu_device = xpu_device
 
-    def forward(self, input_tensors):
-        result = torch.sort(input_tensors)
-        return result
-
-
-class UniqueOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+    def process_inputs(self, input_tensors):
+        new_inputs = input_tensors.cpu()
+        return [new_inputs]
 
     def forward(self, input_tensors):
-        result = torch.unique(input_tensors, return_counts=True)
-        return result
+        assert input_tensors.device.type == "cpu"
+        output_xpu = input_tensors.to(self.xpu_device)
+        return output_xpu
 
 
-class ExpOp(torch.nn.Module):
+class Device2HostOp(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     def forward(self, input_tensors):
-        result = torch.exp(input_tensors)
-        return result
-
-
-class GemmOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input_tensor_a, input_tensor_b):
-        logits = torch.matmul(input_tensor_a, input_tensor_b)
-        return logits
-
-
-class SoftmaxOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, hidden_states):
-        logits = torch.nn.functional.softmax(hidden_states, dim=-1)
-        return logits
-
+        assert input_tensors.device.type != "cpu"
+        output_cpu = input_tensors.cpu()
+        return output_cpu
 
-class LayerNormOp(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
 
-    def forward(self, hidden_states):
-        logits = torch.nn.functional.layer_norm(
-            hidden_states, (hidden_states.shape[-1],)
-        )
-        return logits
 
 
 class AllReduceOp(torch.nn.Module):
@@ -211,26 +240,116 @@ def forward(self, input_tensors):
         return True
 
 
-class Device2HostOp(torch.nn.Module):
+
+
+
+class AddOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensor_a, input_tensor_b):
+        result = input_tensor_a + input_tensor_b
+        return result
+
+
+class SinOp(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     def forward(self, input_tensors):
-        assert input_tensors.device.type != "cpu"
-        output_cpu = input_tensors.cpu()
-        return output_cpu
+        result = torch.sin(input_tensors)
+        return result
 
 
-class Host2DeviceOp(torch.nn.Module):
-    def __init__(self, xpu_device):
+class CosOp(torch.nn.Module):
+    def __init__(self):
         super().__init__()
-        self.xpu_device = xpu_device
 
-    def process_inputs(self, input_tensors):
-        new_inputs = input_tensors.cpu()
-        return [new_inputs]
+    def forward(self, input_tensors):
+        result = torch.cos(input_tensors)
+        return result
+
+
+class GeluOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
 
     def forward(self, input_tensors):
-        assert input_tensors.device.type == "cpu"
-        output_xpu = input_tensors.to(self.xpu_device)
-        return output_xpu
+        result = torch.nn.functional.gelu(input_tensors)
+        return result
+
+
+class ExponentialOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = input_tensors.exponential_()
+        return result
+
+
+class IndexAddOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def process_inputs(self, input_tensor, source_tensor):
+        index = torch.randint(0, input_tensor.shape[0], (source_tensor.shape[0],)).to(
+            input_tensor.device
+        )
+        return [input_tensor, index, source_tensor]
+
+    def forward(self, input_tensor, index, source_tensor):
+        result = input_tensor.index_add_(0, index, source_tensor)
+        return result
+
+
+class SortOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = torch.sort(input_tensors)
+        return result
+
+
+class UniqueOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = torch.unique(input_tensors, return_counts=True)
+        return result
+
+
+class ExpOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = torch.exp(input_tensors)
+        return result
+
+
+
+class SoftmaxOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, hidden_states):
+        logits = torch.nn.functional.softmax(hidden_states, dim=-1)
+        return logits
+
+
+class LayerNormOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, hidden_states):
+        logits = torch.nn.functional.layer_norm(
+            hidden_states, (hidden_states.shape[-1],)
+        )
+        return logits
+
+
+
+
diff --git a/byte_micro_perf/backends/utils.py b/byte_micro_perf/backends/utils.py
index 6c0531ce..9265ed38 100644
--- a/byte_micro_perf/backends/utils.py
+++ b/byte_micro_perf/backends/utils.py
@@ -19,9 +19,21 @@
 import torch
 
 
+def get_dtype_bytes(torch_type):
+    dtype_size = 0
+    if torch_type in [torch.int32, torch.int8]:
+        dtype_size = torch.iinfo(torch_type).bits // 8
+    elif torch_type in [torch.float32, torch.float16, torch.bfloat16]:
+        dtype_size = torch.finfo(torch_type).bits // 8
+    else:
+        # not supported yet
+        pass
+    return dtype_size
+
+
 def dump_communication_ops_report(
     op_name: str,
-    dtype: str,
+    dtype: torch.dtype,
     input_shapes: List[List[int]],
     group_size: List[int],
     bandwidth_limit: float,
@@ -29,16 +41,19 @@
     error: str = ""
 ):
     size = math.prod(input_shapes[0])
-    torch_type = getattr(torch, dtype)
-    if torch_type == torch.int32:
-        dtype_size = torch.iinfo(torch_type).bits // 8
-    else:
-        dtype_size = torch.finfo(torch_type).bits // 8
+    dtype_size = get_dtype_bytes(dtype)
     mb = dtype_size * size / 1024 / 1024
     if error == "":
         algo_bw = dtype_size * size / latency / 1e3
-        bus_bw = algo_bw * (group_size - 1) / group_size
+        """
+        allreduce: 2 * (group_size - 1) * (tensor_size / group_size)
+        allgather: 1 * (group_size - 1) * (tensor_size / group_size)
+        reducescatter: 1 * (group_size - 1) * (tensor_size / group_size)
+        alltoall: 1 * (group_size - 1) * (tensor_size / group_size)
+        broadcast: tensor_size
+        """
+        bus_bw = algo_bw * (group_size - 1) / group_size
         if op_name == "broadcast":
             bus_bw = algo_bw
         if op_name == "allreduce":
@@ -74,7 +89,7 @@
 def dump_computation_ops_report(
     op_name: str,
-    dtype: str,
+    dtype: torch.dtype,
     input_shapes: List[List[int]],
     bandwidth_limit: float,
     latency: float,
@@ -93,6 +108,22 @@ def dump_computation_ops_report(
         K = input_shapes[0][1]
         N = input_shapes[1][1]
         size = M * K + K * N + M * N
+    elif op_name == "batch_gemm":
+        # c = batch_gemm(a, b)
+        bs = input_shapes[0][0]
+        M = input_shapes[0][1]
+        K = input_shapes[0][2]
+        N = input_shapes[1][2]
+        size = bs * (M * K + K * N + M * N)
+    elif op_name == "group_gemm":
+        # c_list = group_gemm(a_list, b_list)
+        size_list = []
+        for problem_shape in input_shapes:
+            M = problem_shape[0][0]
+            K = problem_shape[0][1]
+            N = problem_shape[1][1]
+            size_list.append(M * K + K * N + M * N)
+        size = sum(size_list)
     elif op_name in ["unique", "device2host", "host2device"]:
         size = sum([math.prod(shape) for shape in input_shapes])
     else:
@@ -100,11 +131,7 @@
         # MAC_total = MAC_in + MAC_out
         size = sum([math.prod(shape) for shape in input_shapes]) * 2
 
-    torch_type = getattr(torch, dtype)
-    if torch_type == torch.int32:
-        dtype_size = torch.iinfo(torch_type).bits // 8
-    else:
-        dtype_size = torch.finfo(torch_type).bits // 8
+    dtype_size = get_dtype_bytes(dtype)
     mb = dtype_size * size / 1024 / 1024
     if error == "":
         algo_bw = dtype_size * size / latency / 1e3
diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py
index d0ac2ec6..94258120 100644
--- a/byte_micro_perf/core/perf_engine.py
+++ b/byte_micro_perf/core/perf_engine.py
@@ -20,8 +20,11 @@
 import os
 import subprocess
 import sys
-from typing import Any, Dict, List
+import pathlib
 import traceback
+import random
+from typing import Any, Dict, List
+
 
 import torch
 import torch.multiprocessing as mp
@@ -106,8 +109,14 @@ def init_process(self, rank: int, world_size: int):
         """
         initialize_func = getattr(self.backend, "initialize_ccl")
-        initialize_func(rank, world_size)
+
+        # world_size may exceed the available device count
+        ret = initialize_func(rank, world_size)
+        if ret is not None and not ret:
+            return
+
         status = self.start_perf(self.workload)
+        return status
 
     def init_backend(self, hardware_type: str) -> Backend:
         """
@@ -150,9 +159,7 @@ def start_perf(self, workload: Dict[str, Any]) -> bool:
         )
 
         # Initalize Output Dir and Reports
-        output_dir = os.path.abspath(
-            "reports/" + self.backend_type + "/" + workload["operator"]
-        )
+        output_dir = pathlib.Path("reports").joinpath(self.backend_type).joinpath(workload["operator"])
         os.makedirs(output_dir, exist_ok=True)
 
         op_name = workload["operator"]
@@ -169,22 +176,60 @@ def start_perf(self, workload: Dict[str, Any]) -> bool:
         else:
             raise ValueError(f"Unknown operation: {op_name.lower()}")
 
-        perf_reports = []
+        # get input shape info
+        shape_list = []
+
+        # normal ops
         if "input_shape_list" in self.workload:
             shape_list = self.workload["input_shape_list"]
-        else:
-            shape_list = []
-            for M, N, K in self.workload["M/N/K"]:
-                shape_list.append([[M, K], [K, N]])
+        # gemm or batch_gemm
+        elif "M/N/K" in self.workload:
+            if "batch_size" in self.workload:
+                for batch_size in self.workload["batch_size"]:
+                    for M, K, N in self.workload["M/N/K"]:
+                        shape_list.append([
+                            [batch_size, M, K],
+                            [batch_size, K, N]
+                        ])
+            else:
+                for M, K, N in self.workload["M/N/K"]:
+                    shape_list.append([[M, K], [K, N]])
+        # group_gemm
+        elif "MNK_choices" in self.workload:
+            seed = workload["seed"]
+            MNK_list = self.workload["MNK_choices"]
+            problems_list = workload["problems"]
+
+            random.seed(seed)
+            for problems in problems_list:
+                cur_inputs = []
+                for _ in range(problems):
+                    M, N, K = [random.choice(MNK_list) for _ in range(3)]
+                    cur_shapes = [[M, K], [K, N]]
+                    cur_inputs.append(cur_shapes)
+                shape_list.append(cur_inputs)
+
+        # dtype list
+        dtype_list = self.workload["dtype"]
+
+        for dtype in dtype_list:
+            torch_dtype = getattr(torch, dtype)
 
-        for dtype in self.workload["dtype"]:
             perf_reports = []
             base_report["Performance"] = {}
+
             for input_shape in shape_list:
+                """
+                input_shape could be:
+                    List[int]: single shape. cos
+                    List[List[int]]: multiple inputs. add
+                    List[List[List[int]]]: multiple inputs with multiple problems. group_gemm
+                """
+
                 if isinstance(input_shape[0], int):
                     input_shape = [input_shape]
                 try:
-                    reports = self.backend.perf(input_shape, dtype)
+                    reports = self.backend.perf(input_shape, torch_dtype)
                 except Exception as e:
                     traceback.print_exc()
                     log.error(f"Execute op: {op_name.lower()} failed, input_shape: {input_shape}, dtype: {dtype}, error msg: {e}")
diff --git a/byte_micro_perf/workloads/batch_gemm.json b/byte_micro_perf/workloads/batch_gemm.json
new file mode 100644
index 00000000..59b5ac71
--- /dev/null
+++ b/byte_micro_perf/workloads/batch_gemm.json
@@ -0,0 +1,22 @@
+{
+    "operator": "batch_gemm",
+    "iterations": 100,
+    "M/N/K": [
+        [
+            1024,
+            1024,
+            1024
+        ]
+    ],
+    "batch_size": [
+        2,
+        8,
+        32
+    ],
+    "dtype": [
+        "float32",
+        "bfloat16",
+        "half",
+        "int8"
+    ]
+}
\ No newline at end of file
diff --git a/byte_micro_perf/workloads/gemm.json b/byte_micro_perf/workloads/gemm.json
index 94bd370a..82350510 100644
--- a/byte_micro_perf/workloads/gemm.json
+++ b/byte_micro_perf/workloads/gemm.json
@@ -221,6 +221,7 @@
     "dtype": [
         "float32",
         "bfloat16",
-        "half"
+        "half",
+        "int8"
     ]
 }
\ No newline at end of file
diff --git a/byte_micro_perf/workloads/group_gemm.json b/byte_micro_perf/workloads/group_gemm.json
new file mode 100644
index 00000000..51dbfe68
--- /dev/null
+++ b/byte_micro_perf/workloads/group_gemm.json
@@ -0,0 +1,20 @@
+{
+    "operator": "group_gemm",
+    "iterations": 20,
+    "MNK_choices": [
+        128,
+        256,
+        512,
+        1024
+    ],
+    "seed": 2024,
+    "problems": [
+        20
+    ],
+    "dtype": [
+        "float32",
+        "bfloat16",
+        "half",
+        "int8"
+    ]
+}
\ No newline at end of file