diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py index fa6e8983..54f17e23 100644 --- a/byte_micro_perf/backends/GPU/backend_gpu.py +++ b/byte_micro_perf/backends/GPU/backend_gpu.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import json import logging -import math -import os + from datetime import timedelta from typing import Any, Dict, List @@ -25,287 +25,71 @@ from backends import module_store from backends.backend import Backend -from backends.utils import get_dtype_bytes logging.basicConfig(level=logging.INFO) log = logging.getLogger("PerfEngine") + class BackendGPU(Backend): - def __init__(self, workload_dict: Dict[str, Any], vendor_path: str): + def __init__(self, workload_dict, vendor_path): super().__init__(workload_dict, vendor_path) - def get_device_count(self): - return torch.cuda.device_count() + + def get_device_name(self): return torch.cuda.get_device_name(0) + def get_torch_device_name(self): + return "cuda" + + def get_device_properties(self): + return torch.cuda.get_device_properties(0) + + def get_device_count(self): + return torch.cuda.device_count() + def set_device(self, device_index : int): torch.cuda.set_device(device_index) def get_device(self): return torch.cuda.current_device() - + def device_synchronize(self): torch.cuda.synchronize() - def get_backend_properties(self): - self.memory_limit = int( - torch.cuda.get_device_properties(0).total_memory / (1024**3) - ) + def empty_cache(self): + torch.cuda.empty_cache() - if self.vendor_path is not None and os.path.exists(self.vendor_path) and (self.vendor_path).endswith(".json"): - with open(self.vendor_path, "r") as f: - self.hw_info_dict = json.load(f) - # if the vendor path does not exist, please set this param manaually - self.bandwidth_limit = self.hw_info_dict["内存参数"]["内存"]["内存带宽(GB/s)"] - else: - log.warning( - "Vendor_path: [ {} ] was not found or not a full path points to json, please check your path!!! Otherwise, please set the hardware info manaually.".format( - self.vendor_path - ) - ) - def initialize_ccl(self, rank, world_size): - torch.cuda.set_device(rank) + def get_dist_module(self): + return dist + def initialize_ccl(self, rank, world_size): + # check device_count + device_count = self.get_device_count() + if world_size > device_count: + world_size = device_count + if rank >= world_size: + return False + self.set_device(rank) + + # set envs and internal vars os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "49373" os.environ["LOCAL_RANK"] = str(rank) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - dist.init_process_group( + # init process group + self.get_dist_module().init_process_group( backend="nccl", world_size=world_size, rank=rank, timeout=timedelta(seconds=1800) ) - - self.setup_2d_group() return True - - - def setup_2d_group(self): - # get rank and set device - self.rank = dist.get_rank() - torch.cuda.set_device(self.rank) - - origin_store_based_barrier = dist_c10d._store_based_barrier - dist_c10d._store_based_barrier = lambda *a, **kw: None - - self.world_size = dist.get_world_size() - self.ranks = range(0, self.world_size) - group = dist.new_group(self.ranks) - if self.rank in self.ranks: - self.group = group - - dist_c10d._store_based_barrier = origin_store_based_barrier - - # wait for all ranks finish group initializing - dist.barrier() - - def destroy_process_group(self): - dist.destroy_process_group() - - def barrier(self): - dist.barrier(self.group) - - - def all_gather_object(self, obj): - gather_object_list = [None for _ in range(self.world_size)] - dist.all_gather_object( - object_list=gather_object_list, - obj=obj, - group=self.group - ) - return gather_object_list - - - # create input tensors - def build_tensor(self, input_shapes, dtype): - torch.cuda.empty_cache() - torch_dtype = getattr(torch, dtype) - - # compute size of input and output tensors - if hasattr(self.op, "compute_size"): - bytes_per_cnt = self.op.compute_size(input_shapes, dtype) - # default: input_tensors_size == output_tensor_size, all tensors have same dtype - else: - dtype_size = get_dtype_bytes(dtype) - element_num = 2 * sum([math.prod(shape) for shape in input_shapes]) - bytes_per_cnt = dtype_size * element_num - - - # avoid use L2 Cache: assume max 1GB currently - # data_per_cnt > 1GB, use two buffers - # data_per_cnt < 1GB, malloc multiple buffer to exceed 1GB - assume_l2_cache_size = 1 * 1024**3 - assume_avail_bytes = self.memory_limit * 0.9 * 1024**3 - - if bytes_per_cnt > assume_avail_bytes: - return [], 0, bytes_per_cnt - elif 2 * bytes_per_cnt > assume_avail_bytes: - max_data_cnt = 1 - elif bytes_per_cnt > assume_l2_cache_size: - max_data_cnt = 2 - else: - max_data_cnt = math.ceil(assume_l2_cache_size / bytes_per_cnt) - - # create input tensors for each op - input_tensors_list = [] - for _ in range(max_data_cnt): - # create input tensors - if hasattr(self.op, "custom_create_tensors"): - input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda") - input_tensors_list.append(input_tensors) - # default: all input tensors have same dtype - else: - if torch_dtype in [torch.int8, torch.int32]: - input_tensors = [ - torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda") - for shape in input_shapes - ] - else: - input_tensors = [ - torch.randn(shape, dtype=torch_dtype, device="cuda") - for shape in input_shapes - ] - input_tensors_list.append(input_tensors) - if hasattr(self.op, "process_inputs"): - input_tensors_list = [ - self.op.process_inputs(*(input_tensor)) - for input_tensor in input_tensors_list - ] - - return input_tensors_list, max_data_cnt, bytes_per_cnt - - - def _run_operation(self, operation, inputs): - result = operation(*inputs) - return result - - - - # device/host ops - def host2device(self): - self.op = module_store.Host2DeviceOp() - - def device2host(self): - self.op = module_store.Device2HostOp() - - # communication ops - def allreduce(self): - self.op = module_store.AllReduceOp(self.group) - - def allgather(self): - self.op = module_store.AllGatherOp(self.group) - - def reducescatter(self): - self.op = module_store.ReduceScatterOp(self.group) - - def alltoall(self): - self.op = module_store.AllToAllOp(self.group) - - def broadcast(self): - self.op = module_store.BroadcastOp(self.group) - - def p2p(self): - self.op = module_store.P2POp(self.group, self.ranks, self.rank) - - # compute ops - # unary ops - def sin(self): - self.op = module_store.SinOp() - - def cos(self): - self.op = module_store.CosOp() - - def exp(self): - self.op = module_store.ExpOp() - - def exponential(self): - self.op = module_store.ExponentialOp() - - def silu(self): - self.op = module_store.SiluOp() - - def gelu(self): - self.op = module_store.GeluOp() - - def swiglu(self): - self.op = module_store.SwiGLUOp() - - def cast(self): - self.op = module_store.CastOp() - - def log(self): - self.op = module_store.LogOp() - - def sqrt(self): - self.op = module_store.SqrtOp() - - # binary ops - def add(self): - self.op = module_store.AddOp() - - def mul(self): - self.op = module_store.MulOp() - - def sub(self): - self.op = module_store.SubOp() - - def div(self): - self.op = module_store.DivOp() - - - # reduce ops - def layernorm(self): - self.op = module_store.LayerNormOp() - - def softmax(self): - self.op = module_store.SoftmaxOp() - - def reduce_sum(self): - self.op = module_store.ReduceSumOp() - - def reduce_min(self): - self.op = module_store.ReduceMinOp() - - def reduce_max(self): - self.op = module_store.ReduceMaxOp() - - - # index ops - def index_add(self): - self.op = module_store.IndexAddOp() - - def sort(self): - self.op = module_store.SortOp() - - def unique(self): - self.op = module_store.UniqueOp() - - def scatter(self): - self.op = module_store.ScatterOp() - - def gather(self): - self.op = module_store.GatherOp() - - - # gemm ops - def gemm(self): - self.op = module_store.GemmOp() - - def gemv(self): - self.op = module_store.GemmOp() - - def batch_gemm(self): - self.op = module_store.BatchGemmOp() - - def group_gemm(self): - self.op = module_store.GroupGemmOp() diff --git a/byte_micro_perf/backends/GPU/custom_ops.py b/byte_micro_perf/backends/GPU/custom_ops.py deleted file mode 100644 index 6f4a6b9a..00000000 --- a/byte_micro_perf/backends/GPU/custom_ops.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import List - -import torch -import cutlass - -from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp - - -# gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 -# gemm(cutlass) int8 --> int32 -class GPUGemmOp(GemmOp): - def __init__(self): - super().__init__() - - try: - import cutlass - dtype = torch.int8 - accum_dtype=torch.int32 - self.plan = cutlass.op.Gemm( - alpha=1, beta=0, - element_A=dtype, - element_B=dtype, - element_C=accum_dtype, - element_D=accum_dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.RowMajor, - layout_C=cutlass.LayoutType.RowMajor - ) - self.op = self.plan.construct() - self.gemm_op_int8 = cutlass.emit.pytorch( - self.op, name='gemm', cc=self.plan.cc, - jit=True, sourcedir='out' - ) - except: - self.gemm_op_int8 = None - raise Exception("GPUGemmOp cutlass error") - - def forward( - self, - input_tensor_a : torch.Tensor, - input_tensor_b : torch.Tensor - ): - compute_dtype = input_tensor_a.dtype - if compute_dtype == torch.int8: - output_tensor = self.gemm_op_int8.run(input_tensor_a, input_tensor_b) - else: - output_tensor = torch.mm(input_tensor_a, input_tensor_b) - return output_tensor - - -# batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 -# batch_gemm(cutlass) int8 --> int32 -class GPUBatchGemmOp(BatchGemmOp): - def __init__(self): - super().__init__() - - try: - import cutlass - except: - raise Exception("GPUBatchGemmOp import cutlass error") - - def forward( - self, - input_tensor_a : torch.Tensor, - input_tensor_b : torch.Tensor - ): - compute_dtype = input_tensor_a.dtype - - output_tensor = None - if compute_dtype == torch.int8: - bs, m, n = input_tensor_a.shape[0], input_tensor_a.shape[1], input_tensor_b.shape[2] - c_tensor = torch.randint(-3, 3, [bs, m, n], dtype=torch.int32, device="cuda") - output_tensor = torch.randint(-3, 3, [bs, m, n], dtype=torch.int32, device="cuda") - plan = cutlass.op.Gemm(A=input_tensor_a, B=input_tensor_b, C=c_tensor, D=output_tensor, element_accumulator=cutlass.DataType.s32) - plan.run(input_tensor_a, input_tensor_b, c_tensor, output_tensor, 1, 0) - else: - output_tensor = torch.bmm(input_tensor_a, input_tensor_b) - return output_tensor - - -# group_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 -# group_gemm(cutlass) int8 --> int32 -class GPUGroupGemmOp(GroupGemmOp): - def __init__(self): - super().__init__() - - try: - import cutlass - dtype = torch.int8 - accum_dtype=torch.int32 - self.plan = cutlass.op.GroupedGemm( - alpha=1, beta=0, - element_A=dtype, - element_B=dtype, - element_C=accum_dtype, - element_D=accum_dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.RowMajor, - layout_C=cutlass.LayoutType.RowMajor - ) - self.op = self.plan.construct() - self.gemm_op_int8 = cutlass.emit.pytorch( - self.op, name='group_gemm', cc=self.plan.cc, - jit=True, sourcedir='out' - ) - except: - self.gemm_op_int8 = None - raise Exception("GPUGroupGemmOp cutlass error") - - def forward(self, - a_list : List[torch.Tensor], - b_list : List[torch.Tensor] - ): - compute_dtype = a_list[0].dtype - if compute_dtype == torch.int8: - output_tensors = self.gemm_op_int8.run(a_list, b_list) - else: - output_tensors = [a @ b for a, b in zip(a_list, b_list)] - return output_tensors \ No newline at end of file diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py index b6f8f1dc..ca1436e1 100644 --- a/byte_micro_perf/backends/backend.py +++ b/byte_micro_perf/backends/backend.py @@ -14,190 +14,209 @@ import os import time +import json +import math +import random import logging import traceback from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any, Dict, List +from typing import Any, Dict, List, final import torch import torch.distributed as dist import torch.distributed.distributed_c10d as dist_c10d +from backends import module_store from backends.utils import dump_communication_ops_report, dump_computation_ops_report logging.basicConfig(level=logging.INFO) log = logging.getLogger("PerfEngine") +default_op_registry = module_store.op_registry.copy() +default_op_compute_size_registry = module_store.op_compute_size_funcs.copy() +default_op_create_tensors_registry = module_store.op_create_tensors_funcs.copy() + + class Backend(ABC): def __init__(self, workload_dict: Dict[str, Any], vendor_path: str): self.op_name = workload_dict["operator"] self.iterations = workload_dict["iterations"] - self.warmup = int(0.1 * workload_dict["iterations"]) - self.vendor_path = vendor_path self.op = None - + # communication params - self.rank = None self.world_size = None - self.group = None + self.rank = None # hardware info - self.hw_info_dict = None - self.memory_limit = None - self.bandwidth_limit = None - self.get_backend_properties() + self.device_name = self.get_device_name() + + self.memory_limit = int( + self.get_device_properties().total_memory / (1024**3) + ) + if vendor_path is not None and os.path.exists(vendor_path) and (vendor_path).endswith(".json"): + with open(vendor_path, "r") as f: + self.hw_info_dict = json.load(f) + # if the vendor path does not exist, please set this param manaually + self.bandwidth_limit = self.hw_info_dict["内存参数"]["内存"]["内存带宽(GB/s)"] + + """ + op + """ + def get_op_instance(self): + if self.op_name in default_op_registry: + self.op = default_op_registry[self.op_name] + else: + raise NotImplementedError + + def get_op_compute_size_func(self): + if self.op_name in default_op_compute_size_registry: + return default_op_compute_size_registry[self.op_name] + else: + raise NotImplementedError + + def get_op_create_tensors_func(self): + if self.op_name in default_op_create_tensors_registry: + return default_op_create_tensors_registry[self.op_name] + else: + raise NotImplementedError + - self.target_dtype = None """ device management related """ - @abstractmethod - def get_device_count(self): + # torch.get_device_name() + def get_device_name(self): + raise NotImplementedError + + # "cuda" + def get_torch_device_name(self): raise NotImplementedError - @abstractmethod - def get_device_name(self): + def get_device_properties(self): raise NotImplementedError - @abstractmethod + def get_device_count(self): + raise NotImplementedError + def set_device(self, device_index : int): raise NotImplementedError - @abstractmethod def get_device(self): raise NotImplementedError - @abstractmethod def device_synchronize(self): raise NotImplementedError - @abstractmethod - def get_backend_properties(self): + def empty_cache(self): + raise NotImplementedError + + """ + ccl related + """ + def get_dist_module(self): raise NotImplementedError - - @abstractmethod def initialize_ccl(self, rank, world_size): - torch.cuda.set_device(rank) - - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "49373" - os.environ["LOCAL_RANK"] = str(rank) - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - - dist.init_process_group( - backend="nccl", - world_size=world_size, - rank=rank, - timeout=timedelta(seconds=1800) - ) - - self.setup_2d_group() - return True - - - @abstractmethod - def setup_2d_group(self): - # get dist info - self.world_size = dist.get_world_size() - self.rank = dist.get_rank() - self.ranks = range(0, self.world_size) - - # set device - torch.cuda.set_device(self.rank) - - # get original store_based_barrier - origin_store_based_barrier = dist_c10d._store_based_barrier - dist_c10d._store_based_barrier = lambda *a, **kw: None - group = dist.new_group(self.ranks) - if self.rank in self.ranks: - self.group = group - dist_c10d._store_based_barrier = origin_store_based_barrier - - # wait for all ranks finish group initializing - torch.barrier() - - @abstractmethod + raise NotImplementedError + def destroy_process_group(self): - dist.destroy_process_group() + dist = self.get_dist_module() + if dist.is_initialized(): + dist.destroy_process_group() - @abstractmethod def barrier(self): - dist.barrier(self.group) - - @abstractmethod - def all_gather_object(self, obj): - if dist.is_initialized() and self.world_size is not None and self.group is not None: - gather_object_list = [None for _ in range(self.world_size)] - dist.all_gather_object( - object_list=gather_object_list, - obj=obj, - group=self.group - ) - return gather_object_list + dist = self.get_dist_module() + if dist.is_initialized(): + dist.barrier() - @abstractmethod - def build_tensor(self, input_shapes: List[List[int]], dtype): - raise NotImplementedError - @abstractmethod def _run_operation(self, operation, inputs): - return operation(*inputs) + result = operation(*inputs) + return result + + + def build_tensor(self, input_shapes, torch_dtype): + # get funcs + compute_size_func = self.get_op_compute_size_func() + create_tensors_func = self.get_op_create_tensors_func() + _, tensor_size, _, _ = compute_size_func(input_shapes, torch_dtype) + # avoid use cache, assume cache size is 1 GiB, and use 80% of device memory + assume_cache_size = 1 * 1024**3 + assume_avail_bytes = self.memory_limit * 0.9 * 1024**3 + + if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p", "device2host", "host2device"]: + if tensor_size > assume_avail_bytes: + return [] + else: + max_data_cnt = 1 + else: + if tensor_size > assume_avail_bytes: + return [], 0, tensor_size + elif 2 * tensor_size > assume_avail_bytes: + max_data_cnt = 1 + elif tensor_size > assume_cache_size: + max_data_cnt = 2 + else: + max_data_cnt = min(math.floor(assume_avail_bytes / tensor_size), self.iterations) + + # create tensor_list for each op + tensor_list = [ + create_tensors_func(input_shapes, torch_dtype, self.get_torch_device_name()) for _ in range(max_data_cnt) + ] + return tensor_list - # perf specify input_shape for def perf(self, input_shapes: List[List[int]], dtype): error = "" - # create input tensors based on input_shapes and dtype - tensor_list, tensor_cnt, tensor_size_perc_cnt = self.build_tensor( - input_shapes, dtype) + # create necessary tensors + torch_dtype = getattr(torch, dtype) + tensor_list = self.build_tensor(input_shapes, torch_dtype) + - if tensor_cnt > 0: + if len(tensor_list) > 0: try: + warm_iterations = 5 + test_iterations = 5 + max_total_duration = 10. + prefer_iterations = self.iterations + # warmup - num_warm_up = 10 - for _ in range(num_warm_up): - self._run_operation(self.op, tensor_list[0]) + for _ in range(warm_iterations): + self._run_operation(self.op, random.choice(tensor_list)) # test perf - num_test_perf = 10 self.device_synchronize() start_time = time.perf_counter_ns() - for i in range(num_test_perf): - self._run_operation(self.op, tensor_list[0]) + for i in range(test_iterations): + self._run_operation(self.op, random.choice(tensor_list)) self.device_synchronize() end_time = time.perf_counter_ns() + avg_op_duration = (end_time - start_time) / 1e9 / test_iterations - prefer_iterations = self.iterations - max_perf_seconds = 10.0 - op_duration = (end_time - start_time) / num_test_perf / 1e9 - if op_duration > max_perf_seconds: - prefer_iterations = 5 - else: - prefer_iterations = min(max(int(max_perf_seconds // op_duration), 10), self.iterations) - # ccl ops need barrier - if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]: - self.barrier() + if avg_op_duration > max_total_duration: + prefer_iterations = 2 + else: + prefer_iterations = min(math.ceil(max_total_duration / avg_op_duration), self.iterations) # perf self.device_synchronize() start_time = time.perf_counter_ns() for i in range(prefer_iterations): - self._run_operation(self.op, tensor_list[i % tensor_cnt]) + self._run_operation(self.op, tensor_list[i % len(tensor_list)]) self.device_synchronize() end_time = time.perf_counter_ns() # time in us total_exec_time = (end_time - start_time) / 1e3 latency = round(total_exec_time / prefer_iterations, 2) + except Exception as e: traceback.print_exc() latency = 0 @@ -205,152 +224,32 @@ def perf(self, input_shapes: List[List[int]], dtype): else: latency = 0 error = "OOM" - - tensor_list = [] - if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]: + # clean tensors and device memory + del tensor_list + self.empty_cache() + + + # create report for communication ops and computation ops + if self.op_name in [ + "allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p", + "device2host", "host2device" + ]: report = dump_communication_ops_report( - self.op_name, - dtype, - input_shapes, - self.group.size(), + self.op_name, torch_dtype, input_shapes, + self.get_op_compute_size_func(), + self.world_size, None, latency, error ) else: report = dump_computation_ops_report( - self.op_name, - dtype, - input_shapes, - self.bandwidth_limit, + self.op_name, torch_dtype, input_shapes, + self.get_op_compute_size_func(), + None, latency, error ) - return report - - - """ - gemm ops - """ - def gemm(self): - pass - - def gemv(self): - pass - - def batch_gemm(self): - pass - - def group_gemm(self): - pass - - - """ - communication ops - """ - def host2device(self): - pass - - def device2host(self): - pass - - def allreduce(self): - pass - - def allgather(self): - pass - - def reducescatter(self): - pass - - def alltoall(self): - pass - - def broadcast(self): - pass - - def p2p(self): - pass - - # compute ops - # unary ops - def sin(self): - pass - - def cos(self): - pass - - def exp(self): - pass - - def exponential(self): - pass - - def silu(self): - pass - - def gelu(self): - pass - - def swiglu(self): - pass - - def cast(self): - pass - - def log(self): - pass - - def sqrt(self): - pass - - # binary ops - def add(self): - pass - - def mul(self): - pass - - def sub(self): - pass - - def div(self): - pass - - - # reduce ops - def layernorm(self): - pass - - def softmax(self): - pass - - def reduce_sum(self): - pass - - def reduce_min(self): - pass - - def reduce_max(self): - pass - - - # index ops - def index_add(self): - pass - - def sort(self): - pass - - def unique(self): - pass - - def scatter(self): - pass - - def gather(self): - pass - diff --git a/byte_micro_perf/backends/module_store.py b/byte_micro_perf/backends/module_store.py index 1c1015e2..c94dca65 100644 --- a/byte_micro_perf/backends/module_store.py +++ b/byte_micro_perf/backends/module_store.py @@ -18,621 +18,1018 @@ import torch import torch.distributed as dist -from .utils import get_dtype_bytes - - -""" -gemm ops -""" -class GemmOp(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def compute_size(self, input_shapes, dtype): - # input_shapes: [[M, K], [K, N]] - torch_dtype = getattr(torch, dtype) - a_shape, b_shape = input_shapes - M, K = a_shape - K, N = b_shape +def gemm_compute_size(input_shapes, torch_dtype): + # input_shapes: [[M, K], [K, N]] + a_shape, b_shape = input_shapes + M, _ = a_shape + _, N = b_shape + d_shape = [M, N] + + # get element_size and dtype_size + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + + input_tensor_size = dtype_size * input_element_num + if torch_dtype == torch.int8: + output_tensor_size = 4 * output_element_num + else: + output_tensor_size = dtype_size * output_element_num + batch_size = M + tensor_size = input_tensor_size + output_tensor_size + return (batch_size, tensor_size, input_tensor_size, output_tensor_size) + +def gemm_create_tensors(input_shapes, torch_dtype, xpu_device): + # input_shapes: [[M, K], [K, N]] + a_shape, b_shape = input_shapes + M, _ = a_shape + _, N = b_shape + d_shape = [M, N] + + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + b_tensor = torch.randint(0, 7, b_shape, dtype=torch_dtype, device=xpu_device) + + # create output tensors + d_tensor = torch.randint(0, 7, d_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor, b_tensor, d_tensor] + + + + + + +def batch_gemm_compute_size(input_shapes, torch_dtype): + # input_shapes: [[bs, M, K], [bs, K, N]] + a_shape, b_shape = input_shapes + bs, M, _ = a_shape + bs, _, N = b_shape + d_shape = [bs, M, N] + + # get element_size and dtype_size + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + + input_tensor_size = dtype_size * input_element_num + if torch_dtype == torch.int8: + output_tensor_size = 4 * output_element_num + else: + output_tensor_size = dtype_size * output_element_num + batch_size = bs + tensor_size = input_tensor_size + output_tensor_size + return (batch_size, tensor_size, input_tensor_size, output_tensor_size) + +def batch_gemm_create_tensors(input_shapes, torch_dtype, xpu_device): + # input_shapes: [[bs, M, K], [bs, K, N]] + a_shape, b_shape = input_shapes + bs, M, _ = a_shape + bs, _, N = b_shape + d_shape = [bs, M, N] + + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + b_tensor = torch.randint(0, 7, b_shape, dtype=torch_dtype, device=xpu_device) + + # create output tensors + d_tensor = torch.randint(0, 7, d_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor, b_tensor, d_tensor] + + + +def group_gemm_compute_size(input_shapes, torch_dtype): + """ + [ + [[M1, K1], [K1, N1]], + [[M2, K2], [K2, N2]] + ] + """ + + input_tensor_size = 0 + output_tensor_size = 0 + + for problem_shape in input_shapes: + a_shape, b_shape = problem_shape + M, _ = a_shape + _, N = b_shape d_shape = [M, N] - dtype_size = get_dtype_bytes(dtype) + + # get element_size and dtype_size input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + + input_tensor_size += dtype_size * input_element_num if torch_dtype == torch.int8: - bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num + output_tensor_size += 4 * output_element_num else: - bytes_per_cnt = dtype_size * (input_element_num + output_element_num) - return bytes_per_cnt + output_tensor_size += dtype_size * output_element_num + batch_size = 1 + tensor_size = input_tensor_size + output_tensor_size + + return batch_size, tensor_size, input_tensor_size, output_tensor_size + +def group_gemm_create_tensors(input_shapes, torch_dtype, xpu_device): + """ + [ + [[M1, K1], [K1, N1]], + [[M2, K2], [K2, N2]] + ] + """ + left_tensors = [] + right_tensors = [] + output_tensors = [] + + for problem_shape in input_shapes: + a_shape, b_shape = problem_shape + M, _ = a_shape + _, N = b_shape + d_shape = [M, N] - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a.dtype - output_tensor = None - if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.mm(input_tensor_a, input_tensor_b) - else: - raise Exception(f"GemmOp with dtype {compute_dtype} is not implemented") - return output_tensor + # create input tensors + left_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + right_tensor = torch.randint(0, 7, b_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + output_tensor = torch.randint(0, 7, d_shape, dtype=torch_dtype, device=xpu_device) -class GemvOp(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def compute_size(self, input_shapes, dtype): - # input_shapes: [[M, K], [K, N]] - torch_dtype = getattr(torch, dtype) - a_shape, b_shape = input_shapes - M, K = a_shape - K, N = b_shape - d_shape = [M, N] - dtype_size = get_dtype_bytes(dtype) - input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) - output_element_num = sum([math.prod(shape) for shape in [d_shape]]) - if torch_dtype == torch.int8: - bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num - else: - bytes_per_cnt = dtype_size * (input_element_num + output_element_num) - return bytes_per_cnt + left_tensors.append(left_tensor) + right_tensors.append(right_tensor) + output_tensors.append(output_tensor) - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a.dtype - output_tensor = None - if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.mm(input_tensor_a, input_tensor_b) - else: - raise Exception(f"GemvOp with dtype {compute_dtype} is not implemented") - return output_tensor + return [left_tensors, right_tensors, output_tensors] -class BatchGemmOp(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def compute_size(self, input_shapes, dtype): - # input_shapes: [[bs, M, K], [bs, K, N]] - torch_dtype = getattr(torch, dtype) - a_shape, b_shape = input_shapes - bs, M, K = a_shape - bs, K, N = b_shape - d_shape = [bs, M, N] - dtype_size = get_dtype_bytes(dtype) - input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) - output_element_num = sum([math.prod(shape) for shape in [d_shape]]) - if torch_dtype == torch.int8: - bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("int32") * output_element_num * 2 - else: - bytes_per_cnt = dtype_size * (input_element_num + output_element_num) - return bytes_per_cnt - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a.dtype - output_tensor = None - if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.bmm(input_tensor_a, input_tensor_b) - else: - raise Exception(f"BatchGemmOp with dtype {compute_dtype} is not implemented") - return output_tensor +def sin_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + c_shape = a_shape + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + batch_size = c_shape[0] + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class GroupGemmOp(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def compute_size(self, input_shapes, dtype): - """ - [ - [[M1, K1], [K1, N1]], - [[M2, K2], [K2, N2]] - ] - """ - torch_dtype = getattr(torch, dtype) - bytes_per_cnt = 0 - for problem_shape in input_shapes: - a_shape, b_shape = problem_shape - M, K = a_shape - K, N = b_shape - d_shape = [M, N] - dtype_size = get_dtype_bytes(dtype) - input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) - output_element_num = sum([math.prod(shape) for shape in [d_shape]]) - if torch_dtype == torch.int8: - bytes_per_cnt += dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num - else: - bytes_per_cnt += dtype_size * (input_element_num + output_element_num) - return bytes_per_cnt - - def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): - """ - [ - [[M1, K1], [K1, N1]], - [[M2, K2], [K2, N2]] - ] - """ - left_tensors = [] - right_tensors = [] - - for problem_shape in input_shapes: - a_shape, b_shape = problem_shape - if torch_dtype in [torch.int8, torch.int32]: - left_tensor = torch.randint(-3, 3, size=a_shape, dtype=torch_dtype, device=xpu_device) - right_tensor = torch.randint(-3, 3, size=b_shape, dtype=torch_dtype, device=xpu_device) - else: - left_tensor = torch.randn(a_shape, dtype=torch_dtype, device=xpu_device) - right_tensor = torch.randn(b_shape, dtype=torch_dtype, device=xpu_device) - left_tensors.append(left_tensor) - right_tensors.append(right_tensor) +def sin_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + c_shape = a_shape - return [left_tensors, right_tensors] + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a[0].dtype - output_tensor_list = [] - for a, b in zip(input_tensor_a, input_tensor_b): - if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.mm(a, b) - output_tensor_list.append(output_tensor) - else: - raise Exception(f"GroupGemmOp with dtype {compute_dtype} is not implemented") - return output_tensor_list + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor, c_tensor] +def cast_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + c_shape = a_shape + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) -""" -communication ops -""" -class Host2DeviceOp(torch.nn.Module): - def __init__(self): - super().__init__() + if torch_dtype == torch.float32: + dst_torch_dtype = torch.bfloat16 + elif torch_dtype == torch.bfloat16 or torch_dtype == torch.float16: + dst_torch_dtype = torch.float32 + elif torch_dtype == torch.int8: + dst_torch_dtype = torch.int32 + else: + dst_torch_dtype = torch_dtype - def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): - host_tensor = torch.zeros(input_shapes[0], dtype=torch_dtype, device="cpu").pin_memory() - device_tensor = host_tensor.to(xpu_device) - return [host_tensor, device_tensor] + src_dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + dst_dtype_size = torch.tensor([], dtype=dst_torch_dtype).element_size() - def forward(self, host_tensor, device_tensor): - device_tensor.copy_(host_tensor) - return device_tensor + input_tensor_size = src_dtype_size * input_element_num + output_tensor_size = dst_dtype_size * output_element_num + batch_size = c_shape[0] + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class Device2HostOp(torch.nn.Module): - def __init__(self): - super().__init__() +def cast_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + c_shape = a_shape - def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): - host_tensor = torch.zeros(input_shapes[0], dtype=torch_dtype, device="cpu").pin_memory() - device_tensor = host_tensor.to(xpu_device) - return [device_tensor, host_tensor] + if torch_dtype == torch.float32: + dst_torch_dtype = torch.bfloat16 + elif torch_dtype == torch.bfloat16 or torch_dtype == torch.float16: + dst_torch_dtype = torch.float32 + elif torch_dtype == torch.int8: + dst_torch_dtype = torch.int32 + else: + dst_torch_dtype = torch_dtype - def forward(self, device_tensor, host_tensor): - host_tensor.copy_(device_tensor) - return host_tensor + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=dst_torch_dtype, device=xpu_device) -class AllReduceOp(torch.nn.Module): - def __init__(self, group): - super().__init__() - self.group = group + return [a_tensor, c_tensor] - def forward(self, input_tensors): - dist.all_reduce(input_tensors, group=self.group) - return True +def swiglu_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape -class AllGatherOp(torch.nn.Module): - def __init__(self, group): - super().__init__() - self.group = group + input_tensor_shape = [batch_size, hidden_size] + output_tensor_shape = [batch_size, hidden_size] - def process_inputs(self, input_tensors): - input_tensor_list = list( - torch.chunk(input_tensors, dist.get_world_size(self.group)) - ) - return [input_tensor_list] + input_element_num = sum([math.prod(shape) for shape in [input_tensor_shape]]) + output_element_num = sum([math.prod(shape) for shape in [output_tensor_shape]]) - def forward(self, input_tensor_list): - dist.all_gather( - input_tensor_list, - input_tensor_list[dist.get_rank(self.group)], - group=self.group, - ) - return True + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class ReduceScatterOp(torch.nn.Module): - def __init__(self, group): - super().__init__() - self.group = group +def swiglu_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape - def process_inputs(self, input_tensors): - input_tensor_list = list( - torch.chunk(input_tensors, dist.get_world_size(self.group)) - ) - return [input_tensor_list] + input_tensor_shape = [batch_size, hidden_size] + output_tensor_shape = [batch_size, hidden_size] - def forward(self, input_tensor_list): - dist.reduce_scatter( - input_tensor_list[dist.get_rank(self.group)], - input_tensor_list, - group=self.group, - ) - return True + # create input tensors + input_tensor = torch.randint(0, 7, input_tensor_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + output_tensor = torch.randint(0, 7, output_tensor_shape, dtype=torch_dtype, device=xpu_device) -class AllToAllOp(torch.nn.Module): - def __init__(self, group): - super().__init__() - self.group = group + return [input_tensor, output_tensor] - def process_inputs(self, input_tensor, output_tensor): - input_tensor_list = list( - torch.chunk(input_tensor, dist.get_world_size(self.group)) - ) - output_tensor_list = list( - torch.chunk(output_tensor, dist.get_world_size(self.group)) - ) - return [input_tensor_list, output_tensor_list] - def forward(self, in_tensors_list, out_tensors_list): - dist.all_to_all(out_tensors_list, in_tensors_list, group=self.group) - return True +def add_compute_size(input_shapes, torch_dtype): + a_shape, b_shape = input_shapes + c_shape = a_shape + batch_size, hidden_size = a_shape -class BroadcastOp(torch.nn.Module): - def __init__(self, group): - super().__init__() - self.group = group + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() - def forward(self, input_tensors): - dist.broadcast(input_tensors, 0, self.group) - return True + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size +def add_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, b_shape = input_shapes + c_shape = a_shape -class P2POp(torch.nn.Module): - def __init__(self, group, ranks, rank): - super().__init__() - self.group = group - self.group_size = self.group.size() - self.rank = rank - self.ranks = ranks - self.rank_size = len(ranks) + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + b_tensor = torch.randint(0, 7, b_shape, dtype=torch_dtype, device=xpu_device) - def next_rank(self): - return self.ranks[(self.rank + 1) % self.rank_size] + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor, b_tensor, c_tensor] - def prev_rank(self): - return self.ranks[(self.rank - 1) % self.rank_size] - def forward(self, send_tensor, recv_tensor): - reqs = [] - if self.rank != (self.group_size - 1): - send_req = dist.isend(send_tensor, self.next_rank(), self.group) - reqs.append(send_req) - if self.rank != 0: - recv_req = dist.irecv(recv_tensor, self.prev_rank(), self.group) - reqs.append(recv_req) +def layer_norm_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + w_shape = a_shape[-1:] - for req in reqs: - req.wait() - return True + input_element_num = sum([math.prod(shape) for shape in [a_shape, w_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class SinOp(torch.nn.Module): - def __init__(self): - super().__init__() +def layer_norm_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + w_shape = a_shape[-1:] - def forward(self, input_tensors): - result = torch.sin(input_tensors) - return result + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) -class CosOp(torch.nn.Module): - def __init__(self): - super().__init__() + # create weight tensors + w_tensor = torch.randint(0, 7, w_shape, dtype=torch_dtype, device=xpu_device) - def forward(self, input_tensors): - result = torch.cos(input_tensors) - return result + return [a_tensor, c_tensor, w_tensor] -class ExpOp(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, input_tensors): - result = torch.exp(input_tensors) - return result +def softmax_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() -class ExponentialOp(torch.nn.Module): - def __init__(self): - super().__init__() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size - def forward(self, input_tensors): - result = input_tensors.exponential_() - return result +def softmax_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) -class SiluOp(torch.nn.Module): - def __init__(self): - super().__init__() + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor, c_tensor] - def forward(self, input_tensors): - result = torch.nn.functional.silu(input_tensors) - return result -class GeluOp(torch.nn.Module): - def __init__(self): - super().__init__() +def reduce_sum_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = [batch_size, 1] - def forward(self, input_tensors): - result = torch.nn.functional.gelu(input_tensors) - return result + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class SwiGLUOp(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.w = 1 - self.v = 2 +def reduce_sum_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = [batch_size, 1] - def forward(self, input_tensors): - result = (torch.nn.functional.sigmoid(input_tensors) * self.w) + (input_tensors * self.v) - return result + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) -class CastOp(torch.nn.Module): - def __init__(self): - super().__init__() - - def set_dtype(self, src_dtype: str): - target_dtype = "bfloat16" if src_dtype == "float32" else "float32" - self.target_dtype = target_dtype - self.target_torch_dtype = getattr(torch, target_dtype) - - def compute_size(self, input_shapes, dtype): - torch_dtype = getattr(torch, dtype) - self.set_dtype(dtype) - dtype_size = get_dtype_bytes(dtype) - target_dtype_size = get_dtype_bytes(self.target_dtype) - element_num = sum([math.prod(shape) for shape in input_shapes]) - bytes_per_cnt = dtype_size * element_num + target_dtype_size * element_num - return bytes_per_cnt + return [a_tensor, c_tensor] - def forward(self, input_tensors): - result = input_tensors.to(self.target_torch_dtype) - return result +def reduce_min_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape -class LogOp(torch.nn.Module): - def __init__(self): - super().__init__() + values_shape = [batch_size, 1] + indices_shape = [batch_size, 1] + + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + values_element_num = sum([math.prod(shape) for shape in [values_shape]]) + indices_element_num = sum([math.prod(shape) for shape in [indices_shape]]) + + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + indices_dtype_size = torch.tensor([], dtype=torch.int64).element_size() - def forward(self, input_tensors): - result = torch.log(input_tensors) - return result + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * values_element_num + indices_dtype_size * indices_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size +def reduce_min_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + values_shape = [batch_size, 1] + indices_shape = [batch_size, 1] -class SqrtOp(torch.nn.Module): - def __init__(self): - super().__init__() + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) - def forward(self, input_tensors): - result = torch.sqrt(input_tensors) - return result + # create output tensors + values_tensor = torch.randint(0, 7, values_shape, dtype=torch_dtype, device=xpu_device) + indices_tensor = torch.randint(0, 7, indices_shape, dtype=torch.int64, device=xpu_device) + return [a_tensor, values_tensor, indices_tensor] -class AddOp(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, input_tensor_a, input_tensor_b): - result = input_tensor_a + input_tensor_b - return result -class MulOp(torch.nn.Module): - def __init__(self): - super().__init__() +def index_add_compute_size(input_shapes, torch_dtype): + # src_tensor -->(index_tensor) dst_tensor + dst_shape, src_shape = input_shapes - def forward(self, input_tensor_a, input_tensor_b): - result = input_tensor_a * input_tensor_b - return result + src_batch_size = src_shape[0] + dst_batch_size = dst_shape[0] + index_shape = [src_batch_size] -class SubOp(torch.nn.Module): - def __init__(self): - super().__init__() + + src_element_num = sum([math.prod(shape) for shape in [src_shape]]) + index_element_num = sum([math.prod(shape) for shape in [index_shape]]) - def forward(self, input_tensor_a, input_tensor_b): - result = input_tensor_a - input_tensor_b - return result + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + index_dtype_size = torch.tensor([], dtype=torch.int64).element_size() + src_tensor_size = dtype_size * src_element_num + index_tensor_size = index_dtype_size * index_element_num -class DivOp(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, input_tensor_a, input_tensor_b): - result = input_tensor_a / input_tensor_b - return result + input_tensor_size = 2 * src_tensor_size + index_tensor_size + output_tensor_size = src_tensor_size + tensor_size = input_tensor_size + output_tensor_size + return src_batch_size, tensor_size, input_tensor_size, output_tensor_size -class LayerNormOp(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, input_tensors): - result = torch.nn.functional.layer_norm( - input_tensors, (input_tensors.shape[-1],) - ) - return result +def index_add_create_tensors(input_shapes, torch_dtype, xpu_device): + # src_tensor -->(index_tensor) dst_tensor + dst_shape, src_shape = input_shapes + src_batch_size = src_shape[0] + dst_batch_size = dst_shape[0] -class SoftmaxOp(torch.nn.Module): - def __init__(self): - super().__init__() + index_shape = [src_batch_size] - def forward(self, input_tensors): - result = torch.nn.functional.softmax(input_tensors, dim=-1) - return result + # create output tensors + dst_tensor = torch.randint(0, 7, dst_shape, dtype=torch_dtype, device=xpu_device) + # create input tensors + src_tensor = torch.randint(0, 7, src_shape, dtype=torch_dtype, device=xpu_device) + index_tensor = torch.randint(0, dst_batch_size, index_shape, dtype=torch.int64, device=xpu_device) + + return [dst_tensor, src_tensor, index_tensor] -class ReduceSumOp(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, input_tensors): - result = torch.sum(input_tensors, dim=-1) - return result +def sort_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + indice_element_num = output_element_num -class ReduceMinOp(torch.nn.Module): - def __init__(self): - super().__init__() + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + indice_dtype_size = torch.tensor([], dtype=torch.int64).element_size() - def forward(self, input_tensors): - result = torch.min(input_tensors, dim=-1) - return result + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + indice_dtype_size * indice_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size -class ReduceMaxOp(torch.nn.Module): - def __init__(self): - super().__init__() +def sort_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape - def forward(self, input_tensors): - result = torch.max(input_tensors, dim=-1) - return result + # create input tensors + a_tensor = torch.randint(0, 7, a_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + c_tensor = torch.randint(0, 7, c_shape, dtype=torch_dtype, device=xpu_device) + indice_tensor = torch.randint(0, 7, c_shape, dtype=torch.int64, device=xpu_device) + return [a_tensor, c_tensor, indice_tensor] -class IndexAddOp(torch.nn.Module): - def __init__(self): - super().__init__() - def process_inputs(self, input_tensor, source_tensor): - index = torch.randint(0, input_tensor.shape[0], (source_tensor.shape[0],)).to( - input_tensor.device - ) - return [input_tensor, index, source_tensor] +def unique_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape - def forward(self, input_tensor, index, source_tensor): - result = input_tensor.index_add_(0, index, source_tensor) - return result + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [c_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + indice_dtype_size = torch.tensor([], dtype=torch.int64).element_size() -class SortOp(torch.nn.Module): - def __init__(self): - super().__init__() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + indice_dtype_size * output_element_num + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size - def forward(self, input_tensors): - result = torch.sort(input_tensors) - return result +def unique_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + c_shape = a_shape + # create input tensors + torch.manual_seed(1) + a_tensor = torch.randint(0, 1024, a_shape, dtype=torch_dtype, device="cpu").to(device=xpu_device) -class UniqueOp(torch.nn.Module): - def __init__(self): - super().__init__() + # create output tensors + c_tensor = torch.empty(c_shape, dtype=torch_dtype, device=xpu_device) + count_tensor = torch.empty(c_shape, dtype=torch.int64, device=xpu_device) + return [a_tensor, c_tensor, count_tensor] - def forward(self, input_tensors): - result = torch.unique(input_tensors, return_counts=True) - return result -class ScatterOp(torch.nn.Module): - def __init__(self): - super().__init__() - def compute_size(self, input_shapes, dtype): - # dst: [batch_size, len], dtype - # index: [batch_size, len], int64 - # src: [batch_size, len], dtype - tensor_shape = input_shapes[0] +def scatter_compute_size(input_shapes, torch_dtype): + tensor_shape = input_shapes[0] + batch_size, hidden_size = tensor_shape + index_shape = [batch_size] - tensor_dtype_size = get_dtype_bytes(dtype) - index_dtype_size = get_dtype_bytes("int64") + input_element_num = sum([math.prod(shape) for shape in [tensor_shape]]) + output_element_num = sum([math.prod(shape) for shape in [tensor_shape]]) + index_element_num = sum([math.prod(shape) for shape in [index_shape]]) - shape_func = lambda shape: math.prod(shape) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + index_dtype_size = torch.tensor([], dtype=torch.int64).element_size() - bytes_per_cnt = ( - shape_func(tensor_shape) * tensor_dtype_size - + shape_func(tensor_shape) * index_dtype_size - + shape_func(tensor_shape) * tensor_dtype_size - ) - - return bytes_per_cnt + input_element_num = dtype_size * input_element_num + index_dtype_size * index_element_num + output_element_num = dtype_size * output_element_num + tensor_size = input_element_num + output_element_num + return batch_size, tensor_size, input_element_num, output_element_num - def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): - # dst: [batch_size, len], dtype - # index: [batch_size, len], int64 - # src: [batch_size, len], dtype - tensor_shape = input_shapes[0] +def scatter_create_tensors(input_shapes, torch_dtype, xpu_device): + tensor_shape = input_shapes[0] + batch_size, hidden_size = tensor_shape + index_shape = [batch_size] - dst_tensor = torch.empty(tensor_shape, dtype=torch_dtype, device=xpu_device) - src_tensor = torch.empty(tensor_shape, dtype=torch_dtype, device=xpu_device) + # create output tensors + dst_tensor = torch.randint(0, 7, tensor_shape, dtype=torch_dtype, device=xpu_device) - # dim = 0 - # dst[index[i, j], j] = src[i, j] - batch_size = tensor_shape[0] - tensor_len = tensor_shape[1] + # create input tensors + src_tensor = torch.randint(0, 7, tensor_shape, dtype=torch_dtype, device=xpu_device) + index = [i for i in range(batch_size)] + random.shuffle(index) + index_tensor = torch.tensor(index, dtype=torch.int64, device=xpu_device) + index_tensor = index_tensor.reshape(-1, 1).expand(-1, hidden_size) - index = [i for i in range(batch_size)] - random.shuffle(index) - index_tensor = torch.tensor(index, dtype=torch.int64, device=xpu_device) - index_tensor = index_tensor.reshape(-1, 1).expand(-1, tensor_len) + return [dst_tensor, src_tensor, index_tensor] - return [dst_tensor, index_tensor, src_tensor] - def forward(self, dst_tensor, index_tensor, src_tensor): - dst_tensor.scatter_(0, index_tensor, src_tensor) - return dst_tensor -class GatherOp(torch.nn.Module): - def __init__(self): - super().__init__() - def compute_size(self, input_shapes, dtype): - # dst: [batch_size, len], dtype - # index: [batch_size, len], int64 - # src: [batch_size, len], dtype - tensor_shape = input_shapes[0] - tensor_dtype_size = get_dtype_bytes(dtype) - index_dtype_size = get_dtype_bytes("int64") - shape_func = lambda shape: math.prod(shape) +def host2device_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + + output_element_num = sum([math.prod(shape) for shape in [a_shape]]) + + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + output_tensor_size = dtype_size * output_element_num + + tensor_size = output_tensor_size + return batch_size, tensor_size, 0, output_tensor_size + +def host2device_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + + host_tensor = torch.empty(a_shape, dtype=torch_dtype, device="cpu").pin_memory() + device_tensor = torch.empty(a_shape, dtype=torch_dtype, device=xpu_device) + + return [host_tensor, device_tensor] + + +def allreduce_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + a_tensor = torch.zeros(a_shape, dtype=torch_dtype, device=xpu_device) + return [a_tensor] + + +def allgather_compute_size(input_shapes, torch_dtype): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + + output_element_num = sum([math.prod(shape) for shape in [a_shape]]) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + output_tensor_size = dtype_size * output_element_num + tensor_size = output_tensor_size + return batch_size, tensor_size, 0, output_tensor_size + + + +def allgather_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, = input_shapes + batch_size, hidden_size = a_shape + + world_size = dist.get_world_size() + tensor = torch.empty([batch_size, hidden_size], dtype=torch_dtype, device=xpu_device) + tensors = list(torch.chunk(tensor, world_size, dim=0)) + + return [tensors] + + +def alltoall_compute_size(input_shapes, torch_dtype): + a_shape, b_shape = input_shapes + batch_size, hidden_size = a_shape + + world_size = dist.get_world_size() + output_element_num = sum([math.prod(shape) for shape in [a_shape]]) * 2 + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + output_tensor_size = dtype_size * output_element_num + tensor_size = output_tensor_size + return batch_size, tensor_size, 0, output_tensor_size + + +def alltoall_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, b_shape = input_shapes + batch_size, hidden_size = a_shape + + world_size = dist.get_world_size() + input_tensor = torch.empty([batch_size, hidden_size], dtype=torch_dtype, device=xpu_device) + input_tensors = list(torch.chunk(input_tensor, world_size, dim=0)) - bytes_per_cnt = ( - shape_func(tensor_shape) * tensor_dtype_size - + shape_func(tensor_shape) * index_dtype_size - + shape_func(tensor_shape) * tensor_dtype_size + output_tensor = torch.empty([batch_size, hidden_size], dtype=torch_dtype, device=xpu_device) + output_tensors = list(torch.chunk(output_tensor, world_size, dim=0)) + + return [input_tensors, output_tensors] + + +def p2p_compute_size(input_shapes, torch_dtype): + a_shape, b_shape = input_shapes + batch_size, hidden_size = a_shape + + input_element_num = sum([math.prod(shape) for shape in [a_shape]]) + output_element_num = sum([math.prod(shape) for shape in [b_shape]]) + + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + input_tensor_size = dtype_size * input_element_num + output_tensor_size = dtype_size * output_element_num + + tensor_size = input_tensor_size + output_tensor_size + return batch_size, tensor_size, input_tensor_size, output_tensor_size + +def p2p_create_tensors(input_shapes, torch_dtype, xpu_device): + a_shape, b_shape = input_shapes + batch_size, hidden_size = a_shape + + a_tensor = torch.empty(a_shape, dtype=torch_dtype, device=xpu_device) + b_tensor = torch.empty(b_shape, dtype=torch_dtype, device=xpu_device) + + return [a_tensor, b_tensor] + + + +""" +gemm ops +""" +class GemmOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_d): + compute_dtype = input_tensor_a.dtype + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + torch.mm(input_tensor_a, input_tensor_b, out=input_tensor_d) + else: + raise Exception(f"GemmOp with dtype {compute_dtype} is not implemented") + +class BatchGemmOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_d): + compute_dtype = input_tensor_a.dtype + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + torch.bmm(input_tensor_a, input_tensor_b, out=input_tensor_d) + else: + raise Exception(f"BatchGemmOp with dtype {compute_dtype} is not implemented") + +class GroupGemmOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_d): + compute_dtype = input_tensor_a[0].dtype + for a, b, d in zip(input_tensor_a, input_tensor_b, input_tensor_d): + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + torch.mm(a, b, out=d) + else: + raise Exception(f"GroupGemmOp with dtype {compute_dtype} is not implemented") + + + +""" +unary ops +""" +class SinOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.sin(input_tensor, out=output_tensor) + +class CosOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.cos(input_tensor, out=output_tensor) + +class ExpOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.exp(input_tensor, out=output_tensor) + +class ExponentialOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + input_tensor.exponential_() + +class LogOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.log(input_tensor, out=output_tensor) + +class SqrtOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.sqrt(input_tensor, out=output_tensor) + +class CastOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + output_tensor = input_tensor.to(output_tensor.dtype) + +class SiluOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + output_tensor = torch.nn.functional.silu(input_tensor) + +class GeluOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + output_tensor = torch.nn.functional.gelu(input_tensor) + +class SwiGLUOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.mul(torch.nn.functional.silu(input_tensor), input_tensor, out=output_tensor) + +""" +Binary ops +""" +class AddOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_c): + torch.add(input_tensor_a, input_tensor_b, out=input_tensor_c) + +class MulOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_c): + torch.mul(input_tensor_a, input_tensor_b, out=input_tensor_c) + +class SubOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_c): + torch.sub(input_tensor_a, input_tensor_b, out=input_tensor_c) + +class DivOp(torch.nn.Module): + def forward(self, input_tensor_a, input_tensor_b, input_tensor_c): + torch.div(input_tensor_a, input_tensor_b, out=input_tensor_c) + + + +""" +reduction ops +""" +class LayerNormOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor, weight_tensor): + output_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), weight_tensor) + +class SoftmaxOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + output_tensor = torch.nn.functional.softmax(input_tensor, dim=-1, dtype=output_tensor.dtype) + +class ReduceSumOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor): + torch.sum(input_tensor, dim=-1, keepdim=True, dtype=output_tensor.dtype, out=output_tensor) + +class ReduceMinOp(torch.nn.Module): + def forward(self, input_tensor, value_tensor, indice_tensor): + torch.min(input_tensor, dim=-1, keepdim=True, out=(value_tensor, indice_tensor)) + +class ReduceMaxOp(torch.nn.Module): + def forward(self, input_tensor, value_tensor, indice_tensor): + torch.max(input_tensor, dim=-1, keepdim=True, out=(value_tensor, indice_tensor)) + + + +""" +index_ops +""" +class IndexAddOp(torch.nn.Module): + def forward(self, dst_tensor, src_tensor, index_tensor): + dst_tensor.index_add_(0, index_tensor, src_tensor) + +class SortOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor, indice_tensor): + torch.sort(input_tensor, dim=-1, out=(output_tensor, indice_tensor)) + +class UniqueOp(torch.nn.Module): + def forward(self, input_tensor, output_tensor, count_tensor): + output_tensor, count_tensor = torch.unique( + input=input_tensor, + sorted=False, + return_counts=True, + return_inverse=False ) - - return bytes_per_cnt - - def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): - # dst: [batch_size, len], dtype - # index: [batch_size, len], int64 - # src: [batch_size, len], dtype - tensor_shape = input_shapes[0] - - dst_tensor = torch.empty(tensor_shape, dtype=torch_dtype, device=xpu_device) - src_tensor = torch.empty(tensor_shape, dtype=torch_dtype, device=xpu_device) - - # dim = 0 - # dst[index[i, j], j] = src[i, j] - batch_size = tensor_shape[0] - tensor_len = tensor_shape[1] - - index = [i for i in range(batch_size)] - random.shuffle(index) - index_tensor = torch.tensor(index, dtype=torch.int64, device=xpu_device) - index_tensor = index_tensor.reshape(-1, 1).expand(-1, tensor_len) - - return [dst_tensor, index_tensor, src_tensor] +class ScatterOp(torch.nn.Module): + def forward(self, dst_tensor, src_tensor, index_tensor): + dst_tensor.scatter_(0, index_tensor, src_tensor) - def forward(self, dst_tensor, index_tensor, src_tensor): +class GatherOp(torch.nn.Module): + def forward(self, dst_tensor, src_tensor, index_tensor): torch.gather(src_tensor, 0, index_tensor, out=dst_tensor) - return dst_tensor \ No newline at end of file + + + +""" +h2d_ops +""" +class Host2DeviceOp(torch.nn.Module): + def forward(self, host_tensor, device_tensor): + device_tensor.copy_(host_tensor) + + +class Device2HostOp(torch.nn.Module): + def forward(self, host_tensor, device_tensor): + host_tensor.copy_(device_tensor) + + + +""" +communication ops +""" +class AllReduceOp(torch.nn.Module): + def forward(self, input_tensor): + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + +class AllGatherOp(torch.nn.Module): + def forward(self, input_tensors): + dist.all_gather(input_tensors, input_tensors[dist.get_rank()]) + +class ReduceScatterOp(torch.nn.Module): + def forward(self, input_tensors): + dist.reduce_scatter(input_tensors[dist.get_rank()], input_tensors) + +class AllToAllOp(torch.nn.Module): + def forward(self, input_tensors, output_tensors): + dist.all_to_all(output_tensors, input_tensors) + + +class BroadcastOp(torch.nn.Module): + def forward(self, input_tensor): + dist.broadcast(input_tensor, 0) + + +class P2POp(torch.nn.Module): + def forward(self, send_tensor, recv_tensor): + world_size = dist.get_world_size() + rank = dist.get_rank() + + reqs = [] + if rank != world_size - 1: + reqs.append(dist.isend(send_tensor, (rank + 1) % world_size)) + if rank != 0: + reqs.append(dist.irecv(recv_tensor, (rank - 1 + world_size) % world_size)) + for req in reqs: + req.wait() + + +op_registry = { + # gemm ops + "gemm": GemmOp(), + "gemv": GemmOp(), + "batch_gemm": BatchGemmOp(), + "group_gemm": GroupGemmOp(), + + # unary ops + "sin": SinOp(), + "cos": CosOp(), + "exp": ExpOp(), + "exponential": ExponentialOp(), + "log": LogOp(), + "sqrt": SqrtOp(), + "cast": CastOp(), + "silu": SiluOp(), + "gelu": GeluOp(), + "swiglu": SwiGLUOp(), + + # binary ops + "add": AddOp(), + "sub": SubOp(), + "mul": MulOp(), + "div": DivOp(), + + # reduction ops + "layernorm": LayerNormOp(), + "softmax": SoftmaxOp(), + "reduce_sum": ReduceSumOp(), + "reduce_max": ReduceMaxOp(), + "reduce_min": ReduceMinOp(), + + # index_ops + "index_add": IndexAddOp(), + "sort": SortOp(), + "unique": UniqueOp(), + "scatter": ScatterOp(), + "gather": GatherOp(), + + # h2d_ops + "device2host": Device2HostOp(), + "host2device": Host2DeviceOp(), + + # ccl ops + "broadcast": BroadcastOp(), + "allreduce": AllReduceOp(), + "allgather": AllGatherOp(), + "alltoall": AllToAllOp(), + "reducescatter": ReduceScatterOp(), + "p2p": P2POp(), +} + + +op_compute_size_funcs = { + # gemm_ops + "gemm": gemm_compute_size, + "gemv": gemm_compute_size, + "batch_gemm": batch_gemm_compute_size, + "group_gemm": group_gemm_compute_size, + + # unary_ops + "sin": sin_compute_size, + "cos": sin_compute_size, + "exp": sin_compute_size, + "exponential": sin_compute_size, + "log": sin_compute_size, + "sqrt": sin_compute_size, + "cast": cast_compute_size, + "silu": sin_compute_size, + "gelu": sin_compute_size, + "swiglu": swiglu_compute_size, + + # binary_ops + "add": add_compute_size, + "mul": add_compute_size, + "sub": add_compute_size, + "div": add_compute_size, + + # reduction_ops + "layernorm": layer_norm_compute_size, + "softmax": softmax_compute_size, + "reduce_sum": reduce_sum_compute_size, + "reduce_min": reduce_min_compute_size, + "reduce_max": reduce_min_compute_size, + + # index_ops + "index_add": index_add_compute_size, + "sort": sort_compute_size, + "unique": unique_compute_size, + "scatter": scatter_compute_size, + "gather": scatter_compute_size, + + # h2d_ops + "host2device": host2device_compute_size, + "device2host": host2device_compute_size, + + # ccl_ops + "broadcast": host2device_compute_size, + "allreduce": host2device_compute_size, + "allgather": allgather_compute_size, + "alltoall": alltoall_compute_size, + "reducescatter": allgather_compute_size, + "p2p": p2p_compute_size, +} + +op_create_tensors_funcs = { + # gemm ops + "gemm": gemm_create_tensors, + "gemv": gemm_create_tensors, + "batch_gemm": batch_gemm_create_tensors, + "group_gemm": group_gemm_create_tensors, + + # unary ops + "sin": sin_create_tensors, + "cos": sin_create_tensors, + "exp": sin_create_tensors, + "exponential": sin_create_tensors, + "log": sin_create_tensors, + "sqrt": sin_create_tensors, + "cast": cast_create_tensors, + "silu": sin_create_tensors, + "gelu": sin_create_tensors, + "swiglu": swiglu_create_tensors, + + # binary ops + "add": add_create_tensors, + "mul": add_create_tensors, + "sub": add_create_tensors, + "div": add_create_tensors, + + # reduction ops + "layernorm": layer_norm_create_tensors, + "softmax": softmax_create_tensors, + "reduce_sum": reduce_sum_create_tensors, + "reduce_min": reduce_min_create_tensors, + "reduce_max": reduce_min_create_tensors, + + # index ops + "index_add": index_add_create_tensors, + "sort": sort_create_tensors, + "unique": unique_create_tensors, + "scatter": scatter_create_tensors, + "gather": scatter_create_tensors, + + # h2d_ops + "host2device": host2device_create_tensors, + "device2host": host2device_create_tensors, + + # ccl_ops + "broadcast": allreduce_create_tensors, + "allreduce": allreduce_create_tensors, + "allgather": allgather_create_tensors, + "alltoall": alltoall_create_tensors, + "reducescatter": allgather_create_tensors, + "p2p": p2p_create_tensors, +} diff --git a/byte_micro_perf/backends/utils.py b/byte_micro_perf/backends/utils.py index ab36d0aa..4fa1469c 100644 --- a/byte_micro_perf/backends/utils.py +++ b/byte_micro_perf/backends/utils.py @@ -18,108 +18,29 @@ import numpy as np import torch - -def get_dtype_bytes(dtype: str): - torch_dtype = getattr(torch, dtype) - dtype_size = 0 - if torch_dtype in [torch.int64, torch.int32, torch.int8]: - dtype_size = torch.iinfo(torch_dtype).bits // 8 - elif torch_dtype in [torch.float32, torch.float16, torch.bfloat16]: - dtype_size = torch.finfo(torch_dtype).bits // 8 - else: - # not supported yet - pass - return dtype_size - - -def get_io_amount(op_name, input_shapes, dtype): - batch_size = input_shapes[0][0] - dtype_size = get_dtype_bytes(dtype) - if op_name in ["add", "mul", "sub", "div"]: - # c = a + b - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = dtype_size * math.prod(input_shapes[0]) - elif op_name == "gemm": - M = input_shapes[0][0] - K = input_shapes[0][1] - N = input_shapes[1][1] - read_io_amount = dtype_size * (M * K + K * N) - if dtype != torch.int8: - write_io_amount = dtype_size * (M * N) - else: - write_io_amount = get_dtype_bytes("int32") * (M * N) - elif op_name == "batch_gemm": - bs = input_shapes[0][0] - M = input_shapes[0][1] - K = input_shapes[0][2] - N = input_shapes[1][2] - read_io_amount = dtype_size * bs * (M * K + K * N) - if dtype != torch.int8: - write_io_amount = dtype_size * bs * (M * N) - else: - write_io_amount = get_dtype_bytes("int32") * bs * (M * N) - elif op_name == "group_gemm": - in_size_list = [] - out_size_list = [] - m_list = [] - for problem_shape in input_shapes: - M = problem_shape[0][0] - K = problem_shape[0][1] - N = problem_shape[1][1] - in_size_list.append(M * K + K * N) - out_size_list.append(M * N) - m_list.append(M) - batch_size = sum(m_list) - read_io_amount = dtype_size * sum(in_size_list) - if dtype != torch.int8: - write_io_amount = dtype_size * sum(out_size_list) - else: - write_io_amount = get_dtype_bytes("int32") * sum(out_size_list) - elif op_name in ["device2host"]: - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = 0 - elif op_name in ["host2device"]: - read_io_amount = 0 - write_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - elif op_name in ["reduce_sum", "reduce_max", "reduce_min"]: - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = dtype_size * sum([math.prod(shape[:-1]) for shape in input_shapes]) - elif op_name in ["unqiue", "sort"]: - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = 2 * dtype_size * sum([math.prod(shape) for shape in input_shapes]) - elif op_name in ["scatter", "gather"]: - tensor_shape = input_shapes[0] - read_io_amount = (dtype_size + get_dtype_bytes("int64")) * math.prod(tensor_shape) - write_io_amount = dtype_size * math.prod(tensor_shape) - elif op_name == "cast": - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = read_io_amount / 2 if dtype == torch.float32 else read_io_amount * 2 - elif op_name in ["index_add"]: - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + get_dtype_bytes("int32") * input_shapes[1][0] - write_io_amount = dtype_size * math.prod(input_shapes[0]) - else: - read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - write_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) - - total_io_amount = read_io_amount + write_io_amount - - return batch_size, total_io_amount, read_io_amount, write_io_amount +from backends import module_store def dump_communication_ops_report( op_name: str, - dtype: str, + torch_dtype, input_shapes: List[List[int]], - group_size: List[int], + compute_size_func, + group_size: int, bandwidth_limit: float, latency: float, error: str = "" ): - size = math.prod(input_shapes[0]) - dtype_size = get_dtype_bytes(dtype) - mb = dtype_size * size / 1024 / 1024 + # get dtype name and dtype_size + dtype_name = str(torch_dtype).split(".")[-1] + + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + element_num = math.prod(input_shapes[0]) + tensor_size = dtype_size * element_num + + mb = tensor_size / 1024 / 1024 if error == "": - algo_bw = dtype_size * size / latency / 1e3 + algo_bw = tensor_size / latency / 1e3 """ allreduce: 2 * (group_size - 1) * (tensor_size / group_size) @@ -129,17 +50,18 @@ def dump_communication_ops_report( broadcast: tensor_size p2p: tensor_size """ - bus_bw = algo_bw * (group_size - 1) / group_size - if op_name in ["broadcast", "p2p"]: + if op_name in ["allgather", "reducescatter", "alltoall"]: + bus_bw = algo_bw * (group_size - 1) / group_size + elif op_name in ["allreduce"]: + bus_bw = 2 * algo_bw * (group_size - 1) / group_size + elif op_name in ["broadcast", "p2p", "device2host", "host2device"]: bus_bw = algo_bw - if op_name == "allreduce": - bus_bw *= 2 bandwidth_utils = None if bandwidth_limit is not None: bandwidth_utils = round((algo_bw / bandwidth_limit) * 1e2, 2) report = { - "Dtype": str(dtype), + "Dtype": str(dtype_name), "Tensor Shapes": input_shapes, "Memory Size(MB)": round(mb, 2), "Group": group_size, @@ -150,7 +72,7 @@ def dump_communication_ops_report( } else: report = { - "Dtype": str(dtype), + "Dtype": str(dtype_name), "Tensor Shapes": input_shapes, "Memory Size(MB)": round(mb, 2), "Group": group_size, @@ -165,27 +87,30 @@ def dump_communication_ops_report( def dump_computation_ops_report( op_name: str, - dtype: str, - input_shapes: List[List[int]], + torch_dtype: str, + input_shapes: List[List[int]], + compute_size_func, bandwidth_limit: float, latency: float, error: str = "" ): - batch_size, total_io_amount, read_io_amount, write_io_amount = get_io_amount(op_name, input_shapes, dtype) - + # get dtype name and dtype_size + dtype_name = str(torch_dtype).split(".")[-1] + batch_size, tensor_size, input_tensor_size, output_tensor_size = compute_size_func(input_shapes, torch_dtype) + if error == "": qps = round(1e6 / latency * batch_size, 2) - algo_bw = total_io_amount / latency / 1e3 + algo_bw = tensor_size / latency / 1e3 bandwidth_utils = None if bandwidth_limit is not None: bandwidth_utils = round((algo_bw / bandwidth_limit) * 1e2, 2) report = { - "Dtype": str(dtype), + "Dtype": str(dtype_name), "Tensor Shapes": input_shapes, - "Read IO Size(MB)": round(read_io_amount / 1024 / 1024, 2), - "Write IO Size(MB)": round(write_io_amount / 1024 / 1024, 2), - "Memory Size(MB)": round(total_io_amount / 1024 / 1024, 2), + "Read IO Size(MB)": round(input_tensor_size / 1024 / 1024, 2), + "Write IO Size(MB)": round(output_tensor_size / 1024 / 1024, 2), + "Memory Size(MB)": round(tensor_size / 1024 / 1024, 2), "Kernel bandwidth(GB/s)": round(algo_bw, 2), "Bandwidth Utilization(%)": bandwidth_utils, "Avg latency(us)": round(latency, 2), @@ -193,11 +118,11 @@ def dump_computation_ops_report( } else: report = { - "Dtype": str(dtype), + "Dtype": str(dtype_name), "Tensor Shapes": input_shapes, - "Read IO Size(MB)": round(read_io_amount / 1024 / 1024, 2), - "Write IO Size(MB)": round(write_io_amount / 1024 / 1024, 2), - "Memory Size(MB)": round(total_io_amount / 1024 / 1024, 2), + "Read IO Size(MB)": round(input_tensor_size / 1024 / 1024, 2), + "Write IO Size(MB)": round(output_tensor_size / 1024 / 1024, 2), + "Memory Size(MB)": round(tensor_size / 1024 / 1024, 2), "Kernel bandwidth(GB/s)": 0, "Bandwidth Utilization(%)": None, "Avg latency(us)": 0, diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py index d123f3ef..fbccf47b 100644 --- a/byte_micro_perf/core/perf_engine.py +++ b/byte_micro_perf/core/perf_engine.py @@ -29,6 +29,7 @@ import itertools from collections import namedtuple +import torch.distributed import torch.multiprocessing as mp import virtualenv @@ -116,36 +117,6 @@ def load_workload(task: str, task_dir: str) -> Dict[str, Any]: def parse_workload(workload): shape_list = [] - if "input_shape_list" in workload: - shape_list.extend(workload["input_shape_list"]) - # gemm or batch_gemm - elif "M/K/N" in workload: - if "batch_size" in workload: - for batch_size in workload["batch_size"]: - for M, K, N in workload["M/K/N"]: - shape_list.append([ - [batch_size, M, K], - [batch_size, K, N] - ]) - else: - for M, K, N in workload["M/K/N"]: - shape_list.append([[M, K], [K, N]]) - # group_gemm - elif "MKN_choices" in workload: - seed = workload["seed"] - MKN_list = workload["MKN_choices"] - problems_list = workload["problems"] - - random.seed(seed) - for problems in problems_list: - cur_inputs = [] - for _ in range(problems): - M, K, N = [random.choice(MKN_list) for _ in range(3)] - cur_shapes = [[M, K], [K, N]] - cur_inputs.append(cur_shapes) - shape_list.append(cur_inputs) - - if "input_shape_groups" in workload: input_shape_groups = workload["input_shape_groups"] if isinstance(workload["input_shape_groups"], list) else [workload["input_shape_groups"]] @@ -211,7 +182,7 @@ def parse_workload(workload): -ConfigInstance = namedtuple("ConfigInstance", ["dtype", "tensor_shapes", "index"]) +ConfigInstance = namedtuple("ConfigInstance", ["dtype", "tensor_shapes", "index", "total"]) ResultItem = namedtuple("ResultItem", ["config", "report"]) @@ -257,13 +228,6 @@ def start_engine(self) -> None: self.backend_class = getattr(backend_module, "Backend" + hardware_type) self.backend = self.backend_class(self.workload, self.args.vendor_path) - # start process task - logger.info( - "******************************************* Start to test op: [{}]. *******************************************".format( - self.workload["operator"] - ) - ) - # create output dir based on task # {BYTEMLPERF_ROOT}/byte_micro_perf/reports/{backend_type}/{task_name} hardware_reports_dir = BYTE_MLPERF_ROOT.joinpath( @@ -297,7 +261,7 @@ def start_engine(self) -> None: case_index = 0 for dtype in dtype_list: for shape in shape_list: - test_list.append(ConfigInstance(dtype, shape, case_index)) + test_list.append(ConfigInstance(dtype, shape, case_index + 1, len(dtype_list) * len(shape_list))) case_index = case_index + 1 try: @@ -341,24 +305,60 @@ def signal_handler(signum, frame): ) subprocess_pids = _subprocesses.pids() - logger.info(f"subprocess pids: {subprocess_pids}") - - - - logger.info("waiting for ranks to be ready") for _ in range(instance_num): assert "ready" == output_queues.get() logger.info("all ranks are ready and listening, init done") - start_time = time.perf_counter_ns() - if group == 1: for test_instance in test_list: - input_queues.put(test_instance, True) + input_queues.put(test_instance, False) + for _ in range(instance_num): + input_queues.put(None, False) + + + result_list = [] + if group == 1: for _ in range(instance_num): - input_queues.put("end", True) + result_list.extend(output_queues.get()) + elif group > 1: + result_list.extend(output_queues.get()) + result_list = sorted(result_list, key=lambda x: x.config.index) + + + dtype_results_mapping = {} + for result in result_list: + if result.config.dtype not in dtype_results_mapping: + dtype_results_mapping[result.config.dtype] = [] + dtype_results_mapping[result.config.dtype].append(result) + + for dtype, results in dtype_results_mapping.items(): + dtype_results_mapping[dtype] = sorted(results, key=lambda x: x.config.index) + + base_report = { + "Operator": self.workload["operator"].upper(), + "Backend": self.backend_type, + "Host Info": self.get_cpu_name(), + "Device Info": getattr(self.backend, "get_device_name")(), + "Version": self.version, + "Execution Date": time.strftime("%Y-%m-%d %H:%M:%S"), + "Performance": [result.report for result in dtype_results_mapping[dtype]] + } + + filename = ( + f"result-{str(dtype)}" + + ( + f"-group{group}" + if group > 1 + else "" + ) + + ".json" + ) + filepath = output_dir.joinpath(filename) + with open(filepath, "w") as f: + json.dump(base_report, f, indent=4) + for process in _subprocesses.processes: process.join() @@ -386,9 +386,9 @@ def signal_handler(signum, frame): traceback.print_exc() logger.error(f"Execute task: {self.args.task} failed, group: {group}, error msg: {e}") + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") with open(f"{hardware_reports_dir}/_run_report.log", "a") as f: - print(f"[error] {self.args.task}, group_size={group}, {current_time}, {duration} s", file=f) - + print(f"[error] {self.args.task}, group_size={group}, {current_time}", file=f) subprocess_pids = [] time.sleep(1) @@ -396,25 +396,22 @@ def signal_handler(signum, frame): if self.args.activate_venv: self.deactivate_venv() + + def perf_func(self, rank: int, *args): - backend_instance = self.backend_class(self.workload, self.args.vendor_path) - op_name = self.workload["operator"] - world_size, group_size, output_dir, test_list, input_queues, output_queues = args + + backend_instance = self.backend_class(self.workload, self.args.vendor_path) + backend_instance.rank = rank + backend_instance.world_size = world_size - # set device accroding to local_rank - set_device_func = getattr(backend_instance, "set_device") - set_device_func(rank) + backend_instance.set_device(rank) - if world_size > 1: - init_ccl_func = getattr(backend_instance, "initialize_ccl") - init_ccl_func(rank, world_size) + if group_size > 1: + backend_instance.initialize_ccl(rank, world_size) - op = getattr(backend_instance, op_name.lower(), None) - if op is not None and callable(op): - op() - else: - raise ValueError(f"Unknown operation: {op_name.lower()}") + op_name = self.workload["operator"] + backend_instance.get_op_instance() output_queues.put("ready") @@ -422,14 +419,12 @@ def perf_func(self, rank: int, *args): if group_size == 1: while True: test_instance = input_queues.get() - if test_instance == "end": + if test_instance is None: break - - - start_time = time.perf_counter_ns() test_dtype = test_instance.dtype test_shape = test_instance.tensor_shapes + """ input_shape could be: List[int]: single shape. cos @@ -448,28 +443,21 @@ def perf_func(self, rank: int, *args): if reports and "Error" not in reports: result_list.append(ResultItem(test_instance, reports)) - duration = (time.perf_counter_ns() - start_time) / 1e9 - duration = round(duration, 3) - print(f"rank {rank}: {test_instance.index + 1} / {len(test_list)}, duration: {duration} s") - - output_result_list = [] - if world_size > 1: - all_gather_object_func = getattr(backend_instance, "all_gather_object") - all_result_list = all_gather_object_func(result_list) - for data in all_result_list: - output_result_list.extend(data) - else: - output_result_list = result_list + latency = reports.get("Avg latency(us)", 0) + kernel_bw = reports.get("Kernel bandwidth(GB/s)", 0) + bus_bw = reports.get("Bus bandwidth(GB/s)", 0) - result_list = sorted(output_result_list, key=lambda x: x.config.index) + print(f"rank {rank}, {test_instance}, latency: {latency}\nkernel_bw: {kernel_bw}, bus_bw: {bus_bw}") + else: + print(f"rank {rank}, {test_instance}, error") - elif group_size > 1: - for i, test_instance in enumerate(test_list): - - start_time = time.perf_counter_ns() + output_queues.put(result_list) + elif group_size > 1: + for test_instance in test_list: test_dtype = test_instance.dtype test_shape = test_instance.tensor_shapes + """ input_shape could be: List[int]: single shape. cos @@ -485,55 +473,24 @@ def perf_func(self, rank: int, *args): logger.error(f"Execute op: {op_name.lower()} failed, input_shape: {test_shape}, dtype: {test_dtype}, error msg: {e}") reports = {} - result_list.append(ResultItem(test_instance, reports)) - - end_time = time.perf_counter_ns() - duration = (end_time - start_time) / 1e9 - duration = round(duration, 3) - - if rank == 0: - print(f"rank {rank}: {test_instance.index + 1} / {len(test_list)}, duration: {duration} s") + if reports and "Error" not in reports: + result_list.append(ResultItem(test_instance, reports)) - if rank == 0: - print(f"{len(result_list)} tasks finished.") + latency = reports.get("Avg latency(us)", 0) + kernel_bw = reports.get("Kernel bandwidth(GB/s)", 0) + bus_bw = reports.get("Bus bandwidth(GB/s)", 0) + if rank == 0: + print(f"rank {rank}, {test_instance}, latency: {latency}\nkernel_bw: {kernel_bw}, bus_bw: {bus_bw}") + else: + if rank == 0: + print(f"rank {rank}, {test_instance}, error") + if rank == 0: + output_queues.put(result_list) - dtype_results_mapping = {} - for result in result_list: - if result.config.dtype not in dtype_results_mapping: - dtype_results_mapping[result.config.dtype] = [] - dtype_results_mapping[result.config.dtype].append(result) + if group_size > 1: + backend_instance.destroy_process_group() - for dtype, results in dtype_results_mapping.items(): - dtype_results_mapping[dtype] = sorted(results, key=lambda x: x.config.index) - - base_report = { - "Operator": op_name.upper(), - "Backend": self.backend_type, - "Host Info": self.get_cpu_name(), - "Device Info": getattr(self.backend, "get_device_name")(), - "Version": self.version, - "Execution Date": time.strftime("%Y-%m-%d %H:%M:%S"), - "Performance": [result.report for result in dtype_results_mapping[dtype]] - } - - filename = ( - f"result-{str(dtype)}" - + ( - f"-group{group_size}" - if group_size > 1 - else "" - ) - + ".json" - ) - filepath = output_dir.joinpath(filename) - with open(filepath, "w") as f: - json.dump(base_report, f, indent=4) - - if world_size > 1: - destroy_group_func = getattr(backend_instance, "destroy_process_group") - destroy_group_func() - return True def activate_venv(self, hardware_type: str) -> bool: if os.path.exists("backends/" + hardware_type + "/requirements.txt"): diff --git a/byte_micro_perf/launch.py b/byte_micro_perf/launch.py index 88599a35..7adc10a3 100644 --- a/byte_micro_perf/launch.py +++ b/byte_micro_perf/launch.py @@ -112,14 +112,14 @@ def parse_task(task_dir): "binary_ops": [], "reduction_ops": [], "index_ops": [], - "ccl_ops": [], - "h2d_ops": [] + "h2d_ops": [], + "ccl_ops": [] } for task in task_list: if task in ["gemm", "gemv", "batch_gemm", "group_gemm"]: task_mapping["gemm_ops"].append(task) - if task in ["sin", "cos", "exp", "exponential", "silu", "gelu", "swiglu", "cast", "log", "sqrt"]: + if task in ["sin", "cos", "exp", "exponential", "log", "sqrt", "cast", "silu", "gelu", "swiglu"]: task_mapping["unary_ops"].append(task) if task in ["add", "mul", "sub", "div"]: @@ -131,12 +131,12 @@ def parse_task(task_dir): if task in ["index_add", "sort", "unique", "gather", "scatter"]: task_mapping["index_ops"].append(task) - if task in ["allgather", "allreduce", "alltoall", "broadcast", "p2p", "reduce_scatter"]: - task_mapping["ccl_ops"].append(task) - if task in ["host2device", "device2host", "device2device"]: task_mapping["h2d_ops"].append(task) - + + if task in ["allgather", "allreduce", "alltoall", "broadcast", "p2p", "reduce_scatter"]: + task_mapping["ccl_ops"].append(task) + if args.show_task_list: logger.info("******************* Supported Task *******************") @@ -226,15 +226,16 @@ def signal_handler(signum, frame): if args.activate_venv: cmds.append("--activate_venv") + print(f"******************************************* Start to test op: [{task}]. *******************************************") process = subprocess.Popen(cmds) subprocess_pid = process.pid - logger.info(f"start subprocess: {subprocess_pid}") ret = process.wait() if ret != 0: failed_ops.append(task) - - subprocess_pid = -1 + print("") + + if failed_ops: logger.error(f"Failed ops: {failed_ops}") diff --git a/byte_micro_perf/workloads/add.json b/byte_micro_perf/workloads/add.json index 5885cc87..c8c07b9f 100644 --- a/byte_micro_perf/workloads/add.json +++ b/byte_micro_perf/workloads/add.json @@ -16,6 +16,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/allgather.json b/byte_micro_perf/workloads/allgather.json index a7d0b0a6..0e8c2221 100644 --- a/byte_micro_perf/workloads/allgather.json +++ b/byte_micro_perf/workloads/allgather.json @@ -4,7 +4,7 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -12,7 +12,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/allreduce.json b/byte_micro_perf/workloads/allreduce.json index d81356cc..e7dbf96a 100644 --- a/byte_micro_perf/workloads/allreduce.json +++ b/byte_micro_perf/workloads/allreduce.json @@ -4,7 +4,7 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -12,7 +12,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/alltoall.json b/byte_micro_perf/workloads/alltoall.json index 7550fa71..a97e0e88 100644 --- a/byte_micro_perf/workloads/alltoall.json +++ b/byte_micro_perf/workloads/alltoall.json @@ -4,11 +4,11 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ], [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -16,7 +16,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/batch_gemm.json b/byte_micro_perf/workloads/batch_gemm.json index f6d72209..c7ca53f2 100644 --- a/byte_micro_perf/workloads/batch_gemm.json +++ b/byte_micro_perf/workloads/batch_gemm.json @@ -20,7 +20,7 @@ "dtype": [ "float32", "bfloat16", - "half", + "float16", "int8" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/broadcast.json b/byte_micro_perf/workloads/broadcast.json index b815360a..a127a6bb 100644 --- a/byte_micro_perf/workloads/broadcast.json +++ b/byte_micro_perf/workloads/broadcast.json @@ -4,7 +4,7 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -12,7 +12,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/cast.json b/byte_micro_perf/workloads/cast.json index 07ab85dd..50b6e96e 100644 --- a/byte_micro_perf/workloads/cast.json +++ b/byte_micro_perf/workloads/cast.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/cos.json b/byte_micro_perf/workloads/cos.json index 62725bca..18cdc318 100644 --- a/byte_micro_perf/workloads/cos.json +++ b/byte_micro_perf/workloads/cos.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/device2host.json b/byte_micro_perf/workloads/device2host.json index 3bb34dab..382f06b3 100644 --- a/byte_micro_perf/workloads/device2host.json +++ b/byte_micro_perf/workloads/device2host.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/div.json b/byte_micro_perf/workloads/div.json index bb55608b..3e8fe0d1 100644 --- a/byte_micro_perf/workloads/div.json +++ b/byte_micro_perf/workloads/div.json @@ -17,6 +17,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/exp.json b/byte_micro_perf/workloads/exp.json index 17d00d0f..8a2bd601 100644 --- a/byte_micro_perf/workloads/exp.json +++ b/byte_micro_perf/workloads/exp.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/exponential.json b/byte_micro_perf/workloads/exponential.json index 967a58e9..6952ba4e 100644 --- a/byte_micro_perf/workloads/exponential.json +++ b/byte_micro_perf/workloads/exponential.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/gather.json b/byte_micro_perf/workloads/gather.json index 6def4c14..36e6e69b 100644 --- a/byte_micro_perf/workloads/gather.json +++ b/byte_micro_perf/workloads/gather.json @@ -11,7 +11,7 @@ }, "dtype": [ "float32", - "float16", - "bfloat16" + "bfloat16", + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/gelu.json b/byte_micro_perf/workloads/gelu.json index 65574955..cdafbfce 100644 --- a/byte_micro_perf/workloads/gelu.json +++ b/byte_micro_perf/workloads/gelu.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/gemm.json b/byte_micro_perf/workloads/gemm.json index 6c166655..e8330dbe 100644 --- a/byte_micro_perf/workloads/gemm.json +++ b/byte_micro_perf/workloads/gemm.json @@ -21,7 +21,7 @@ "dtype": [ "float32", "bfloat16", - "half", + "float16", "int8" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/gemv.json b/byte_micro_perf/workloads/gemv.json index 2f81b9f5..92706949 100644 --- a/byte_micro_perf/workloads/gemv.json +++ b/byte_micro_perf/workloads/gemv.json @@ -16,7 +16,7 @@ "dtype": [ "float32", "bfloat16", - "half", + "float16", "int8" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/group_gemm.json b/byte_micro_perf/workloads/group_gemm.json index 0b0f0c05..6e37eafb 100644 --- a/byte_micro_perf/workloads/group_gemm.json +++ b/byte_micro_perf/workloads/group_gemm.json @@ -9,7 +9,7 @@ "dtype": [ "float32", "bfloat16", - "half", + "float16", "int8" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/host2device.json b/byte_micro_perf/workloads/host2device.json index 8982c678..8ab49a18 100644 --- a/byte_micro_perf/workloads/host2device.json +++ b/byte_micro_perf/workloads/host2device.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/index_add.json b/byte_micro_perf/workloads/index_add.json index 64744d14..ee4df7dc 100644 --- a/byte_micro_perf/workloads/index_add.json +++ b/byte_micro_perf/workloads/index_add.json @@ -15,7 +15,7 @@ }, "dtype": [ "float32", - "half", - "bfloat16" + "bfloat16", + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/layernorm.json b/byte_micro_perf/workloads/layernorm.json index 87711ee2..7d9b83ee 100644 --- a/byte_micro_perf/workloads/layernorm.json +++ b/byte_micro_perf/workloads/layernorm.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/log.json b/byte_micro_perf/workloads/log.json index 22d99972..ee31e9ce 100644 --- a/byte_micro_perf/workloads/log.json +++ b/byte_micro_perf/workloads/log.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/mul.json b/byte_micro_perf/workloads/mul.json index c7935637..1c6b32e7 100644 --- a/byte_micro_perf/workloads/mul.json +++ b/byte_micro_perf/workloads/mul.json @@ -17,6 +17,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/p2p.json b/byte_micro_perf/workloads/p2p.json index 7d0c5310..e0e68b3f 100644 --- a/byte_micro_perf/workloads/p2p.json +++ b/byte_micro_perf/workloads/p2p.json @@ -4,11 +4,11 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ], [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -16,7 +16,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/reduce_max.json b/byte_micro_perf/workloads/reduce_max.json index ae311a3e..99300ced 100644 --- a/byte_micro_perf/workloads/reduce_max.json +++ b/byte_micro_perf/workloads/reduce_max.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/reduce_min.json b/byte_micro_perf/workloads/reduce_min.json index 7b7edb04..8d839e0d 100644 --- a/byte_micro_perf/workloads/reduce_min.json +++ b/byte_micro_perf/workloads/reduce_min.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/reduce_sum.json b/byte_micro_perf/workloads/reduce_sum.json index 56cf77d8..948fbc83 100644 --- a/byte_micro_perf/workloads/reduce_sum.json +++ b/byte_micro_perf/workloads/reduce_sum.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/reducescatter.json b/byte_micro_perf/workloads/reducescatter.json index 228f1f6d..5fcf022d 100644 --- a/byte_micro_perf/workloads/reducescatter.json +++ b/byte_micro_perf/workloads/reducescatter.json @@ -4,7 +4,7 @@ "input_shape_groups": { "inputs": [ [ - [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152], [1024] ] ] @@ -12,7 +12,7 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ], "group": [ 2, diff --git a/byte_micro_perf/workloads/scatter.json b/byte_micro_perf/workloads/scatter.json index 63b86a83..ded92fd7 100644 --- a/byte_micro_perf/workloads/scatter.json +++ b/byte_micro_perf/workloads/scatter.json @@ -11,7 +11,7 @@ }, "dtype": [ "float32", - "float16", - "bfloat16" + "bfloat16", + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/silu.json b/byte_micro_perf/workloads/silu.json index 3770218c..722e11a0 100644 --- a/byte_micro_perf/workloads/silu.json +++ b/byte_micro_perf/workloads/silu.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/sin.json b/byte_micro_perf/workloads/sin.json index bf2bacda..465767a8 100644 --- a/byte_micro_perf/workloads/sin.json +++ b/byte_micro_perf/workloads/sin.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/softmax.json b/byte_micro_perf/workloads/softmax.json index a90f294d..d529e138 100644 --- a/byte_micro_perf/workloads/softmax.json +++ b/byte_micro_perf/workloads/softmax.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/sort.json b/byte_micro_perf/workloads/sort.json index a30222a0..ed3c9797 100644 --- a/byte_micro_perf/workloads/sort.json +++ b/byte_micro_perf/workloads/sort.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/sqrt.json b/byte_micro_perf/workloads/sqrt.json index 428d9897..086a4064 100644 --- a/byte_micro_perf/workloads/sqrt.json +++ b/byte_micro_perf/workloads/sqrt.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/sub.json b/byte_micro_perf/workloads/sub.json index 0b6a46c6..4251081f 100644 --- a/byte_micro_perf/workloads/sub.json +++ b/byte_micro_perf/workloads/sub.json @@ -17,6 +17,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/swiglu.json b/byte_micro_perf/workloads/swiglu.json index 9982a2c9..9e52be67 100644 --- a/byte_micro_perf/workloads/swiglu.json +++ b/byte_micro_perf/workloads/swiglu.json @@ -12,6 +12,6 @@ "dtype": [ "float32", "bfloat16", - "half" + "float16" ] } \ No newline at end of file diff --git a/byte_micro_perf/workloads/unique.json b/byte_micro_perf/workloads/unique.json index ba88ea4e..ed72f21c 100644 --- a/byte_micro_perf/workloads/unique.json +++ b/byte_micro_perf/workloads/unique.json @@ -11,7 +11,7 @@ }, "dtype": [ "float32", - "half", + "float16", "int32" ] } \ No newline at end of file