From 76cfd5f6b8d2e11a4f05f082c1a0b4e0a1934eb6 Mon Sep 17 00:00:00 2001
From: jiangzishan
Date: Fri, 23 Aug 2024 06:28:30 +0000
Subject: [PATCH] support AMD backend for micro_perf.

---
 byte_micro_perf/backends/AMD/backend_amd.py   | 299 ++++++++++++++++++
 byte_micro_perf/backends/AMD/custom_ops.py    |  63 ++++
 byte_micro_perf/backends/AMD/requirements.txt |   2 +
 byte_micro_perf/backends/GPU/backend_gpu.py   |   3 +
 byte_micro_perf/backends/backend.py           |   5 +
 byte_micro_perf/core/perf_engine.py           |   9 +-
 6 files changed, 380 insertions(+), 1 deletion(-)
 create mode 100644 byte_micro_perf/backends/AMD/backend_amd.py
 create mode 100644 byte_micro_perf/backends/AMD/custom_ops.py
 create mode 100644 byte_micro_perf/backends/AMD/requirements.txt

diff --git a/byte_micro_perf/backends/AMD/backend_amd.py b/byte_micro_perf/backends/AMD/backend_amd.py
new file mode 100644
index 00000000..2a109730
--- /dev/null
+++ b/byte_micro_perf/backends/AMD/backend_amd.py
@@ -0,0 +1,299 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import math
+import os
+from datetime import timedelta
+from typing import Any, Dict, List
+
+import torch
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as dist_c10d
+
+from backends.backend import Backend
+from backends.module_store import *
+from backends.utils import get_dtype_bytes
+
+from .custom_ops import GPUGemmOp, GPUBatchGemmOp, GPUGroupGemmOp
+
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("PerfEngine")
+
+
+class BackendAMD(Backend):
+
+    def get_device_count(self):
+        return torch.cuda.device_count()
+
+    def set_device(self, device_index):
+        torch.cuda.set_device(device_index)
+
+    def get_device(self):
+        return torch.cuda.current_device()
+
+    def all_gather_object(self, obj):
+        gather_object_list = [None for _ in range(self.world_size)]
+        dist.all_gather_object(
+            object_list=gather_object_list,
+            obj=obj,
+            group=self.group
+        )
+        return gather_object_list
+
+
+    def get_device_name(self):
+        return torch.cuda.get_device_name(0)
+
+    def get_backend_properties(self):
+        self.memory_limit = int(
+            torch.cuda.get_device_properties(0).total_memory / (1024**3)
+        )
+
+        if self.vendor_path is not None and os.path.exists(self.vendor_path) and (self.vendor_path).endswith(".json"):
+            with open(self.vendor_path, "r") as f:
+                self.hw_info_dict = json.load(f)
+            # if the vendor path does not exist, please set this param manually
+            self.bandwidth_limit = self.hw_info_dict["内存参数"]["内存"]["内存带宽(GB/s)"]
+        else:
+            log.warning(
+                "Vendor_path [ {} ] was not found or does not point to a json file, please check the path. Otherwise, please set the hardware info manually.".format(
+                    self.vendor_path
+                )
+            )
+
+
+    # device/host ops
+    def host2device(self):
+        self.op = Host2DeviceOp(torch.device("cuda"))
+
+    def device2host(self):
+        self.op = Device2HostOp()
+
+
+    # communication ops
+    def allreduce(self):
+        self.op = AllReduceOp(self.group)
+
+    def allgather(self):
+        self.op = AllGatherOp(self.group)
+
+    def reducescatter(self):
+        self.op = ReduceScatterOp(self.group)
+
+    def alltoall(self):
+        self.op = AllToAllOp(self.group)
+
+    def broadcast(self):
+        self.op = BroadcastOp(self.group)
+
+    def p2p(self):
+        self.op = P2POp(self.group, self.ranks, self.rank)
+
+    # compute ops
+    # unary ops
+    def sin(self):
+        self.op = SinOp()
+
+    def cos(self):
+        self.op = CosOp()
+
+    def exp(self):
+        self.op = ExpOp()
+
+    def exponential(self):
+        self.op = ExponentialOp()
+
+    def silu(self):
+        self.op = SiluOp()
+
+    def gelu(self):
+        self.op = GeluOp()
+
+    def swiglu(self):
+        self.op = SwiGLUOp()
+
+    def cast(self):
+        self.op = CastOp()
+
+
+    # binary ops
+    def add(self):
+        self.op = AddOp()
+
+    def mul(self):
+        self.op = MulOp()
+
+    def sub(self):
+        self.op = SubOp()
+
+    def div(self):
+        self.op = DivOp()
+
+
+    # reduce ops
+    def layernorm(self):
+        self.op = LayerNormOp()
+
+    def softmax(self):
+        self.op = SoftmaxOp()
+
+    def reduce_sum(self):
+        self.op = ReduceSumOp()
+
+    def reduce_min(self):
+        self.op = ReduceMinOp()
+
+    def reduce_max(self):
+        self.op = ReduceMaxOp()
+
+
+    # index ops
+    def index_add(self):
+        self.op = IndexAddOp()
+
+    def sort(self):
+        self.op = SortOp()
+
+    def unique(self):
+        self.op = UniqueOp()
+
+    def scatter(self):
+        self.op = ScatterOp()
+
+    def gather(self):
+        self.op = GatherOp()
+
+    # gemm ops
+    def gemm(self):
+        self.op = GPUGemmOp()
+
+    def gemv(self):
+        self.op = GPUGemmOp()
+
+    def batch_gemm(self):
+        self.op = GPUBatchGemmOp()
+
+    def group_gemm(self):
+        self.op = GPUGroupGemmOp()
+
+
+
+    # create input tensors
+    def build_tensor(self, input_shapes, dtype):
+        torch.cuda.empty_cache()
+        torch_dtype = getattr(torch, dtype)
+
+        # compute size of input and output tensors
+        if hasattr(self.op, "compute_size"):
+            bytes_per_cnt = self.op.compute_size(input_shapes, dtype)
+        # default: input_tensors_size == output_tensor_size, all tensors have same dtype
+        else:
+            dtype_size = get_dtype_bytes(dtype)
+            element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
+            bytes_per_cnt = dtype_size * element_num
+
+        # compute max avail tensors for compute
+        avail_bytes = (self.memory_limit - 4) * 1024**3
+        avail_cnts = avail_bytes // bytes_per_cnt
+        max_data_cnt = min(self.iterations, avail_cnts)
+
+        # create input tensors for each op
+        input_tensors_list = []
+        for _ in range(max_data_cnt):
+            # create input tensors
+            if hasattr(self.op, "custom_create_tensors"):
+                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda")
+                input_tensors_list.append(input_tensors)
+            # default: all input tensors have same dtype
+            else:
+                if torch_dtype in [torch.int8, torch.int32]:
+                    input_tensors = [
+                        torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda")
+                        for shape in input_shapes
+                    ]
+                else:
+                    input_tensors = [
+                        torch.randn(shape, dtype=torch_dtype, device="cuda")
+                        for shape in input_shapes
+                    ]
+                input_tensors_list.append(input_tensors)
+        if hasattr(self.op, "process_inputs"):
+            input_tensors_list = [
+                self.op.process_inputs(*(input_tensor))
+                for input_tensor in input_tensors_list
+            ]
+        return input_tensors_list, max_data_cnt, bytes_per_cnt
+
+
+
+    def _run_operation(self, operation, inputs):
+        result = operation(*inputs)
+        return result
+
+    def device_synchronize(self):
+        torch.cuda.synchronize()
+        return True
+
+    def initialize_ccl(self, rank, world_size):
+        """
+        initialize distributed process groups and relevant ENVs
+        """
+        # check device_count
+        device_count = torch.cuda.device_count()
+        if world_size > device_count:
+            world_size = device_count
+        if rank >= world_size:
+            return False
+
+        # set envs
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = "49373"
+        os.environ["LOCAL_RANK"] = str(rank)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+
+        torch.cuda.set_device(rank)
+
+        # Call the init process
+        timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=world_size,
+            rank=rank,
+            store=None,
+            timeout=timedelta(seconds=timeout_seconds),
+        )
+        self.setup_2d_group()
+        log.warning("DIST: rank {}, world_size {}".format(rank, world_size))
+        return True
+
+    def setup_2d_group(self):
+        self.rank = dist.get_rank()
+        torch.cuda.set_device(self.rank)
+        origin_store_based_barrier = dist_c10d._store_based_barrier
+        dist_c10d._store_based_barrier = lambda *a, **kw: None
+        self.world_size = dist.get_world_size()
+        self.ranks = range(0, self.world_size)
+        group = dist.new_group(self.ranks)
+        if self.rank in self.ranks:
+            self.group = group
+        dist_c10d._store_based_barrier = origin_store_based_barrier
+        # wait for all ranks finish group initializing
+        torch.distributed.barrier()
+
+    def destroy_process_group(self):
+        dist.destroy_process_group()
\ No newline at end of file
diff --git a/byte_micro_perf/backends/AMD/custom_ops.py b/byte_micro_perf/backends/AMD/custom_ops.py
new file mode 100644
index 00000000..8a6ee7ce
--- /dev/null
+++ b/byte_micro_perf/backends/AMD/custom_ops.py
@@ -0,0 +1,63 @@
+from typing import List
+
+import torch
+
+from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp
+
+
+# gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
+# gemm(cutlass) int8 --> int32
+class GPUGemmOp(GemmOp):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        input_tensor_a : torch.Tensor,
+        input_tensor_b : torch.Tensor
+    ):
+        compute_dtype = input_tensor_a.dtype
+        if compute_dtype == torch.int8:
+            output_tensor = input_tensor_a
+        else:
+            output_tensor = torch.mm(input_tensor_a, input_tensor_b)
+        return output_tensor
+
+
+# batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
+# batch_gemm(cutlass) int8 --> int32
+class GPUBatchGemmOp(BatchGemmOp):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        input_tensor_a : torch.Tensor,
+        input_tensor_b : torch.Tensor
+    ):
+        compute_dtype = input_tensor_a.dtype
+
+        output_tensor = None
+        if compute_dtype == torch.int8:
+            output_tensor = input_tensor_a
+        else:
+            output_tensor = torch.bmm(input_tensor_a, input_tensor_b)
+        return output_tensor
+
+
+# group_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
+# group_gemm(cutlass) int8 --> int32
+class GPUGroupGemmOp(GroupGemmOp):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self,
+        a_list : List[torch.Tensor],
+        b_list : List[torch.Tensor]
+    ):
+        compute_dtype = a_list[0].dtype
+        if compute_dtype == torch.int8:
+            output_tensors = a_list
+        else:
+            output_tensors = [a @ b for a, b in zip(a_list, b_list)]
+        return output_tensors
\ No newline at end of file
diff --git a/byte_micro_perf/backends/AMD/requirements.txt b/byte_micro_perf/backends/AMD/requirements.txt
new file mode 100644
index 00000000..3c7b834c
--- /dev/null
+++ b/byte_micro_perf/backends/AMD/requirements.txt
@@ -0,0 +1,2 @@
+-i https://download.pytorch.org/whl/rocm6.1
+torch
\ No newline at end of file
diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py
index 430dcb1e..9f09b781 100644
--- a/byte_micro_perf/backends/GPU/backend_gpu.py
+++ b/byte_micro_perf/backends/GPU/backend_gpu.py
@@ -302,3 +302,6 @@ def setup_2d_group(self):
         # wait for all ranks finish group initializing
         torch.distributed.barrier()
 
+
+    def destroy_process_group(self):
+        dist.destroy_process_group()
\ No newline at end of file
diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py
index c712ac6b..d8c041b6 100644
--- a/byte_micro_perf/backends/backend.py
+++ b/byte_micro_perf/backends/backend.py
@@ -88,6 +88,11 @@ def initialize_ccl(self, rank, world_size):
     def setup_2d_group(self):
         pass
 
+    @abstractmethod
+    def destroy_process_group(self):
+        pass
+
+
 
     # communication ops
     def host2device(self):
diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py
index 7edda902..2fb03418 100644
--- a/byte_micro_perf/core/perf_engine.py
+++ b/byte_micro_perf/core/perf_engine.py
@@ -331,6 +331,8 @@ def start_engine(self) -> None:
 
         self.deactivate_venv()
 
+
+
     def perf_func(self, rank: int, *args):
         backend_instance = self.backend_class(self.workload, self.args.vendor_path)
         op_name = self.workload["operator"]
@@ -419,11 +421,12 @@ def perf_func(self, rank: int, *args):
 
             result_list.append(ResultItem(test_instance, reports))
 
+
         if rank == 0:
             print(f"{len(result_list)} tasks finished.")
 
             dtype_results_mapping = {}
-            for result in output_result_list:
+            for result in result_list:
                 if result.config.dtype not in dtype_results_mapping:
                     dtype_results_mapping[result.config.dtype] = []
                 dtype_results_mapping[result.config.dtype].append(result)
@@ -452,6 +455,10 @@ def perf_func(self, rank: int, *args):
 
             with open(filepath, "w") as f:
                 json.dump(base_report, f, indent=4)
 
+        if world_size > 1:
+            destroy_group_func = getattr(backend_instance, "destroy_process_group")
+            destroy_group_func()
+
         return True
 
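
Background note on the patch above: ROCm builds of PyTorch expose HIP devices through the torch.cuda namespace and map the "nccl" process-group backend to RCCL, which is why BackendAMD can reuse the same calls as the GPU backend. For the vendor JSON that BackendAMD.get_backend_properties() reads through vendor_path, only the nested key path used in the code is known from this patch. Below is a minimal sketch of generating such a file; the file name and the bandwidth value are assumptions, not part of the patch.

    import json

    # Hypothetical vendor info for an AMD GPU. Only the key path read by
    # get_backend_properties() is shown; the value is an assumed example.
    hw_info = {
        "内存参数": {
            "内存": {
                "内存带宽(GB/s)": 5300
            }
        }
    }

    # Write the file that would be passed as vendor_path (hypothetical name).
    with open("amd_vendor_info.json", "w", encoding="utf-8") as f:
        json.dump(hw_info, f, ensure_ascii=False, indent=4)

    # get_backend_properties() would then derive:
    #   self.bandwidth_limit = hw_info["内存参数"]["内存"]["内存带宽(GB/s)"]

If the JSON is missing or the path is wrong, the backend only logs the warning shown above and leaves bandwidth_limit unset, so bandwidth-derived metrics would have to be filled in manually.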