From f17657ac9e778354edb5d4fac469a2ad28814ee4 Mon Sep 17 00:00:00 2001 From: "gaoyujia.01" Date: Wed, 10 Jul 2024 17:08:40 +0800 Subject: [PATCH] feat(micro): support more ops cast, silu, swiglu, div, mul, sub, gemv, reducemax, reducemin, reducesum, p2p; modify workloads --- .gitignore | 5 +- byte_micro_perf/README.md | 19 +- byte_micro_perf/backends/GPU/backend_gpu.py | 113 +++++-- byte_micro_perf/backends/GPU/custom_ops.py | 168 +++------- byte_micro_perf/backends/backend.py | 134 +++++--- byte_micro_perf/backends/module_store.py | 323 ++++++++++++++----- byte_micro_perf/backends/utils.py | 147 ++++++--- byte_micro_perf/core/perf_engine.py | 162 +++++++--- byte_micro_perf/requirements.txt | 2 + byte_micro_perf/workloads/add.json | 64 +--- byte_micro_perf/workloads/allgather.json | 42 +-- byte_micro_perf/workloads/allreduce.json | 42 +-- byte_micro_perf/workloads/alltoall.json | 86 +---- byte_micro_perf/workloads/batch_gemm.json | 38 ++- byte_micro_perf/workloads/broadcast.json | 44 +-- byte_micro_perf/workloads/cast.json | 17 + byte_micro_perf/workloads/cos.json | 23 +- byte_micro_perf/workloads/device2host.json | 38 +-- byte_micro_perf/workloads/div.json | 22 ++ byte_micro_perf/workloads/exp.json | 23 +- byte_micro_perf/workloads/exponential.json | 13 +- byte_micro_perf/workloads/gelu.json | 23 +- byte_micro_perf/workloads/gemm.json | 221 +------------ byte_micro_perf/workloads/gemv.json | 22 ++ byte_micro_perf/workloads/group_gemm.json | 16 +- byte_micro_perf/workloads/host2device.json | 38 +-- byte_micro_perf/workloads/index_add.json | 21 ++ byte_micro_perf/workloads/indexadd.json | 31 -- byte_micro_perf/workloads/layernorm.json | 40 +-- byte_micro_perf/workloads/mul.json | 22 ++ byte_micro_perf/workloads/p2p.json | 26 ++ byte_micro_perf/workloads/reduce_max.json | 17 + byte_micro_perf/workloads/reduce_min.json | 17 + byte_micro_perf/workloads/reduce_sum.json | 17 + byte_micro_perf/workloads/reducescatter.json | 42 +-- byte_micro_perf/workloads/silu.json | 17 + byte_micro_perf/workloads/sin.json | 23 +- byte_micro_perf/workloads/softmax.json | 52 +-- byte_micro_perf/workloads/sort.json | 26 +- byte_micro_perf/workloads/sub.json | 22 ++ byte_micro_perf/workloads/swiglu.json | 17 + byte_micro_perf/workloads/unique.json | 26 +- 42 files changed, 1067 insertions(+), 1194 deletions(-) create mode 100644 byte_micro_perf/workloads/cast.json create mode 100644 byte_micro_perf/workloads/div.json create mode 100644 byte_micro_perf/workloads/gemv.json create mode 100644 byte_micro_perf/workloads/index_add.json delete mode 100644 byte_micro_perf/workloads/indexadd.json create mode 100644 byte_micro_perf/workloads/mul.json create mode 100644 byte_micro_perf/workloads/p2p.json create mode 100644 byte_micro_perf/workloads/reduce_max.json create mode 100644 byte_micro_perf/workloads/reduce_min.json create mode 100644 byte_micro_perf/workloads/reduce_sum.json create mode 100644 byte_micro_perf/workloads/silu.json create mode 100644 byte_micro_perf/workloads/sub.json create mode 100644 byte_micro_perf/workloads/swiglu.json diff --git a/.gitignore b/.gitignore index eb9726f4..2e06b074 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,4 @@ init_env.sh byte_infer_perf/llm_perf/download byte_infer_perf/llm_perf/model_zoo/sota -byte_infer_perf/llm_perf/reports - -out/ -*.db \ No newline at end of file +byte_infer_perf/llm_perf/reports \ No newline at end of file diff --git a/byte_micro_perf/README.md b/byte_micro_perf/README.md index ae803775..6033f40e 100644 --- a/byte_micro_perf/README.md +++ 
b/byte_micro_perf/README.md @@ -46,25 +46,26 @@ Example: "Operator": "EXP", "Backend": "GPU", "Host Info": "Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz", - "Device Info": "A100-PCIE-40GB", + "Device Info": "NVIDIA A800-SXM4-80GB", "Performance": [ { "Dtype": "float32", "Tensor Shapes": [ [ - 2, - 512, - 512 + 256, + 8192 ] ], - "Memory Size(MB)": 4.0, - "Kernel bandwidth(GB/s)": 271.83, - "Bandwidth Utilization(%)": 0.17, - "Avg latency(us)": 15.43 + "Read IO Size(MB)": 8.0, + "Write IO Size(MB)": 8.0, + "Memory Size(MB)": 16.0, + "Kernel bandwidth(GB/s)": 1790.52, + "Bandwidth Utilization(%)": 87.81, + "Avg latency(us)": 9.37, + "QPS": 27321.24 } ] } - ``` ## Trouble Shooting diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py index 649547b3..3afc3af1 100644 --- a/byte_micro_perf/backends/GPU/backend_gpu.py +++ b/byte_micro_perf/backends/GPU/backend_gpu.py @@ -56,17 +56,6 @@ def get_backend_properties(self): ) - # gemm ops - def gemm(self): - self.op = GPUGemmOp() - - def batch_gemm(self): - self.op = BatchGemmOp() - - def group_gemm(self): - self.op = GPUGroupGemmOp() - - # device/host ops def host2device(self): self.op = Host2DeviceOp(torch.device("cuda")) @@ -75,7 +64,6 @@ def device2host(self): self.op = Device2HostOp() - # communication ops def allreduce(self): self.setup_2d_group() @@ -97,12 +85,12 @@ def broadcast(self): self.setup_2d_group() self.op = BroadcastOp(self.group) + def p2p(self): + self.setup_2d_group() + self.op = P2POp(self.group, self.ranks, self.rank) - - # other compute ops - def add(self): - self.op = AddOp() - + # compute ops + # unary ops def sin(self): self.op = SinOp() @@ -115,37 +103,87 @@ def exp(self): def exponential(self): self.op = ExponentialOp() + def silu(self): + self.op = SiluOp() + def gelu(self): self.op = GeluOp() + def swiglu(self): + self.op = SwiGLUOp() + + def cast(self): + self.op = CastOp() + + + # binary ops + def add(self): + self.op = AddOp() + + def mul(self): + self.op = MulOp() + + def sub(self): + self.op = SubOp() + + def div(self): + self.op = DivOp() + + + # reduce ops + def layernorm(self): + self.op = LayerNormOp() + + def softmax(self): + self.op = SoftmaxOp() + + def reduce_sum(self): + self.op = ReduceSumOp() + + def reduce_min(self): + self.op = ReduceMinOp() + + def reduce_max(self): + self.op = ReduceMaxOp() + + + # index ops + def index_add(self): + self.op = IndexAddOp() + def sort(self): self.op = SortOp() def unique(self): self.op = UniqueOp() - def indexadd(self): - self.op = IndexAddOp() - def softmax(self): - self.op = SoftmaxOp() + # gemm ops + def gemm(self): + self.op = GPUGemmOp() - def layernorm(self): - self.op = LayerNormOp() + def gemv(self): + self.op = GPUGemmOp() + def batch_gemm(self): + self.op = GPUBatchGemmOp() + def group_gemm(self): + self.op = GPUGroupGemmOp() # create input tensors - def build_tensor(self, input_shapes, torch_dtype): + def build_tensor(self, input_shapes, dtype): + torch.cuda.empty_cache() + torch_dtype = getattr(torch, dtype) # compute size of input and output tensors if hasattr(self.op, "compute_size"): - bytes_per_cnt = self.op.compute_size(input_shapes, torch_dtype) + bytes_per_cnt = self.op.compute_size(input_shapes, dtype) # default: input_tensors_size == output_tensor_size, all tensors have same dtype else: - dtype_size = get_dtype_bytes(torch_dtype) + dtype_size = get_dtype_bytes(dtype) element_num = 2 * sum([math.prod(shape) for shape in input_shapes]) bytes_per_cnt = dtype_size * element_num @@ -154,20 +192,25 @@ def 
build_tensor(self, input_shapes, torch_dtype): avail_cnts = avail_bytes // bytes_per_cnt max_data_cnt = min(self.iterations, avail_cnts) - # create input tensors for each op input_tensors_list = [] for _ in range(max_data_cnt): # create input tensors if hasattr(self.op, "custom_create_tensors"): - input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype) + input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda") input_tensors_list.append(input_tensors) # default: all input tensors have same dtype else: - input_tensors = [ - torch.randint(0, 3, size=shape).type(torch_dtype).to(torch.device("cuda")) - for shape in input_shapes - ] + if torch_dtype in [torch.int8, torch.int32]: + input_tensors = [ + torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda") + for shape in input_shapes + ] + else: + input_tensors = [ + torch.randn(shape, dtype=torch_dtype, device="cuda") + for shape in input_shapes + ] input_tensors_list.append(input_tensors) if hasattr(self.op, "process_inputs"): input_tensors_list = [ @@ -225,9 +268,9 @@ def setup_2d_group(self): origin_store_based_barrier = dist_c10d._store_based_barrier dist_c10d._store_based_barrier = lambda *a, **kw: None self.world_size = dist.get_world_size() - ranks = range(0, self.world_size) - group = dist.new_group(ranks) - if self.rank in ranks: + self.ranks = range(0, self.world_size) + group = dist.new_group(self.ranks) + if self.rank in self.ranks: self.group = group dist_c10d._store_based_barrier = origin_store_based_barrier # wait for all ranks finish group initializing diff --git a/byte_micro_perf/backends/GPU/custom_ops.py b/byte_micro_perf/backends/GPU/custom_ops.py index a13ec25d..6f4a6b9a 100644 --- a/byte_micro_perf/backends/GPU/custom_ops.py +++ b/byte_micro_perf/backends/GPU/custom_ops.py @@ -6,37 +6,34 @@ from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp - - # gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 # gemm(cutlass) int8 --> int32 class GPUGemmOp(GemmOp): def __init__(self): super().__init__() - # cutlass int8 gemm - dtype = torch.int8 - accum_dtype=torch.int32 - self.plan = cutlass.op.Gemm( - alpha=1, beta=0, - element_A=dtype, - element_B=dtype, - element_C=accum_dtype, - element_D=accum_dtype, - layout_A=cutlass.LayoutType.ColumnMajorInterleaved32, - layout_B=cutlass.LayoutType.RowMajorInterleaved32, - layout_C=cutlass.LayoutType.RowMajor - ) - self.op = self.plan.construct( - alignment_A=16, - alignment_B=16, - alignment_C=8 - ) - self.gemm_op_int8 = cutlass.emit.pytorch( - self.op, name='gemm', cc=self.plan.cc, - jit=True, sourcedir='out' - ) - + try: + import cutlass + dtype = torch.int8 + accum_dtype=torch.int32 + self.plan = cutlass.op.Gemm( + alpha=1, beta=0, + element_A=dtype, + element_B=dtype, + element_C=accum_dtype, + element_D=accum_dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.RowMajor, + layout_C=cutlass.LayoutType.RowMajor + ) + self.op = self.plan.construct() + self.gemm_op_int8 = cutlass.emit.pytorch( + self.op, name='gemm', cc=self.plan.cc, + jit=True, sourcedir='out' + ) + except: + self.gemm_op_int8 = None + raise Exception("GPUGemmOp cutlass error") def forward( self, @@ -44,27 +41,23 @@ def forward( input_tensor_b : torch.Tensor ): compute_dtype = input_tensor_a.dtype - if compute_dtype == torch.int8 and self.gemm_op_int8 is not None: - output_tensor = self.gemm_op_int8.run( - input_tensor_a, input_tensor_b - ) + if compute_dtype == torch.int8: + output_tensor = 
self.gemm_op_int8.run(input_tensor_a, input_tensor_b) else: - output_tensor = torch.mm( - input_tensor_a, input_tensor_b - ) + output_tensor = torch.mm(input_tensor_a, input_tensor_b) return output_tensor - - # batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 # batch_gemm(cutlass) int8 --> int32 class GPUBatchGemmOp(BatchGemmOp): def __init__(self): super().__init__() - # TODO: cutlass int8 batch_gemm - pass + try: + import cutlass + except: + raise Exception("GPUBatchGemmOp import cutlass error") def forward( self, @@ -75,78 +68,27 @@ def forward( output_tensor = None if compute_dtype == torch.int8: - # TODO - pass + bs, m, n = input_tensor_a.shape[0], input_tensor_a.shape[1], input_tensor_b.shape[2] + c_tensor = torch.randint(-3, 3, [bs, m, n], dtype=torch.int32, device="cuda") + output_tensor = torch.randint(-3, 3, [bs, m, n], dtype=torch.int32, device="cuda") + plan = cutlass.op.Gemm(A=input_tensor_a, B=input_tensor_b, C=c_tensor, D=output_tensor, element_accumulator=cutlass.DataType.s32) + plan.run(input_tensor_a, input_tensor_b, c_tensor, output_tensor, 1, 0) else: - output_tensor = torch.bmm( - input_tensor_a, input_tensor_b - ) + output_tensor = torch.bmm(input_tensor_a, input_tensor_b) return output_tensor - - -# group_gemm(cutlass) float32/float16/bfloat16 --> float32 +# group_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16 # group_gemm(cutlass) int8 --> int32 class GPUGroupGemmOp(GroupGemmOp): def __init__(self): super().__init__() - self.group_gemm_fp32 = GPUGroupGemmOp.compile_mod( - dtype=torch.float32, - accum_dtype=torch.float32, - mod_name="groupd_gemm_fp32" - ) - - self.group_gemm_fp16 = GPUGroupGemmOp.compile_mod( - dtype=torch.float16, - accum_dtype=torch.float32, - mod_name="groupd_gemm_fp16" - ) - - self.group_gemm_bf16 = GPUGroupGemmOp.compile_mod( - dtype=torch.bfloat16, - accum_dtype=torch.float32, - mod_name="groupd_gemm_bf16" - ) - - # TODO: cutlass int8 group_gemm - self.group_gemm_int8 = None - # if "int8" in dtype_list: - # self.group_gemm_int8 = GroupGemmOp.compile_mod( - # dtype=torch.int8, - # accum_dtype=torch.int32, - # mod_name="group_gemm_int8" - # ) - - @staticmethod - def compile_mod(dtype, accum_dtype, mod_name): - - if dtype == torch.int8: - # TODO - pass - # plan = cutlass.op.Gemm( - # alpha=1, beta=0, - # element_A=dtype, - # element_B=dtype, - # element_C=accum_dtype, - # element_D=accum_dtype, - # layout_A=cutlass.LayoutType.ColumnMajorInterleaved32, - # layout_B=cutlass.LayoutType.RowMajorInterleaved32, - # layout_C=cutlass.LayoutType.RowMajor - # ) - # op = plan.construct( - # alignment_A=16, - # alignment_B=16, - # alignment_C=8 - # ) - # grouped_gemm = cutlass.emit.pytorch( - # op, name=mod_name, - # cc=plan.cc, jit=True, - # sourcedir='out' - # ) - else: - plan = cutlass.op.GroupedGemm( + try: + import cutlass + dtype = torch.int8 + accum_dtype=torch.int32 + self.plan = cutlass.op.GroupedGemm( alpha=1, beta=0, element_A=dtype, element_B=dtype, @@ -156,30 +98,22 @@ def compile_mod(dtype, accum_dtype, mod_name): layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor ) - op = plan.construct() - grouped_gemm = cutlass.emit.pytorch( - op, name=mod_name, - cc=plan.cc, jit=True, - sourcedir='./out' + self.op = self.plan.construct() + self.gemm_op_int8 = cutlass.emit.pytorch( + self.op, name='group_gemm', cc=self.plan.cc, + jit=True, sourcedir='out' ) - return grouped_gemm - + except: + self.gemm_op_int8 = None + raise Exception("GPUGroupGemmOp cutlass error") def forward(self, a_list 
: List[torch.Tensor], b_list : List[torch.Tensor] ): compute_dtype = a_list[0].dtype - if compute_dtype == torch.float32 and self.group_gemm_fp32 is not None: - output_tensors = self.group_gemm_fp32.run(a_list, b_list) - elif compute_dtype == torch.float16 and self.group_gemm_fp16 is not None: - output_tensors = self.group_gemm_fp16.run(a_list, b_list) - elif compute_dtype == torch.bfloat16 and self.group_gemm_bf16 is not None: - output_tensors = self.group_gemm_bf16.run(a_list, b_list) - elif compute_dtype == torch.int8 and self.group_gemm_int8 is not None: - # TODO - pass - # output_tensors = self.group_gemm_int8.run(a_list, b_list) + if compute_dtype == torch.int8: + output_tensors = self.gemm_op_int8.run(a_list, b_list) else: - output_tensors = [] + output_tensors = [a @ b for a, b in zip(a_list, b_list)] return output_tensors \ No newline at end of file diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py index 79258b1b..eb156f85 100644 --- a/byte_micro_perf/backends/backend.py +++ b/byte_micro_perf/backends/backend.py @@ -15,6 +15,7 @@ import os import time import random +import traceback from abc import ABC, abstractmethod from typing import Any, Dict, List @@ -39,6 +40,8 @@ def __init__(self, workload_dict: Dict[str, Any], vendor_path: str): self.bandwidth_limit = None self.get_backend_properties() + self.target_dtype = None + @abstractmethod def get_device_name(self): pass @@ -68,27 +71,13 @@ def setup_2d_group(self): pass - - # gemm ops - def gemm(self): - pass - - def batch_gemm(self): - pass - - def group_gemm(self): - pass - - - # device/host ops + # communication ops def host2device(self): pass def device2host(self): pass - - # communication ops def allreduce(self): pass @@ -104,11 +93,11 @@ def alltoall(self): def broadcast(self): pass - - # other compute ops - def add(self): + def p2p(self): pass + # compute ops + # unary ops def sin(self): pass @@ -121,30 +110,73 @@ def exp(self): def exponential(self): pass + def silu(self): + pass + def gelu(self): pass - def indexadd(self): + def swiglu(self): pass - def sort(self): + def cast(self): pass - def unique(self): + + # binary ops + def add(self): pass - def softmax(self): + def mul(self): + pass + + def sub(self): pass + def div(self): + pass + + + # reduce ops def layernorm(self): pass + def softmax(self): + pass + + def reduce_sum(self): + pass + def reduce_min(self): + pass + def reduce_max(self): + pass + # index ops + def index_add(self): + pass + def sort(self): + pass + def unique(self): + pass + + + # gemm ops + def gemm(self): + pass + + def gemv(self): + pass + + def batch_gemm(self): + pass + + def group_gemm(self): + pass # perf specify input_shape for @@ -157,42 +189,48 @@ def perf(self, input_shapes: List[List[int]], dtype): ) if tensor_cnt > 0: - # random select input tensors - input_index_list = [ - random.randint(0, tensor_cnt - 1) for _ in range(self.iterations) - ] - - # warmup - num_warm_up = 10 - for _ in range(num_warm_up): - self._run_operation(self.op, tensor_list[0]) - - # perf - self.device_synchronize() - start_time = time.perf_counter_ns() - for i in range(self.iterations): - result = self._run_operation( - self.op, - tensor_list[input_index_list[i]] - ) - self.device_synchronize() - end_time = time.perf_counter_ns() - - # time in us - total_exec_time = (end_time - start_time) / 1e3 - latency = round(total_exec_time / self.iterations, 2) + try: + # random select input tensors + input_index_list = [ + random.randint(0, tensor_cnt - 1) for _ in 
range(self.iterations) + ] + + # warmup + num_warm_up = 10 + for _ in range(num_warm_up): + self._run_operation(self.op, tensor_list[0]) + + # perf + self.device_synchronize() + start_time = time.perf_counter_ns() + for i in range(self.iterations): + self._run_operation( + self.op, + tensor_list[input_index_list[i]] + ) + self.device_synchronize() + end_time = time.perf_counter_ns() + + # time in us + total_exec_time = (end_time - start_time) / 1e3 + latency = round(total_exec_time / self.iterations, 2) + except Exception as e: + traceback.print_exc() + latency = 0 + error = "RUN_OP_ERROR" else: latency = 0 error = "OOM" + tensor_list = [] - if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast"]: + if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]: report = dump_communication_ops_report( self.op_name, dtype, input_shapes, self.group.size(), - self.bandwidth_limit, + None, latency, error ) diff --git a/byte_micro_perf/backends/module_store.py b/byte_micro_perf/backends/module_store.py index fcda6f81..3a8b8e7e 100644 --- a/byte_micro_perf/backends/module_store.py +++ b/byte_micro_perf/backends/module_store.py @@ -25,92 +25,120 @@ class GemmOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + def compute_size(self, input_shapes, dtype): + # input_shapes: [[M, K], [K, N]] + torch_dtype = getattr(torch, dtype) + a_shape, b_shape = input_shapes + M, K = a_shape + K, N = b_shape + d_shape = [M, N] + dtype_size = get_dtype_bytes(dtype) + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + if torch_dtype == torch.int8: + bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num + else: + bytes_per_cnt = dtype_size * (input_element_num + output_element_num) + return bytes_per_cnt + def forward(self, input_tensor_a, input_tensor_b): compute_dtype = input_tensor_a.dtype output_tensor = None - if compute_dtype == torch.int8: - # to be realized - pass - elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: output_tensor = torch.mm(input_tensor_a, input_tensor_b) + else: + raise Exception(f"GemmOp with dtype {compute_dtype} is not implemented") return output_tensor + + +class GemvOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - def compute_size(self, input_shapes, torch_dtype): + def compute_size(self, input_shapes, dtype): # input_shapes: [[M, K], [K, N]] + torch_dtype = getattr(torch, dtype) a_shape, b_shape = input_shapes M, K = a_shape K, N = b_shape d_shape = [M, N] - dtype_size = get_dtype_bytes(torch_dtype) - element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]]) - bytes_per_cnt = dtype_size * element_num + dtype_size = get_dtype_bytes(dtype) + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + if torch_dtype == torch.int8: + bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num + else: + bytes_per_cnt = dtype_size * (input_element_num + output_element_num) return bytes_per_cnt + def forward(self, input_tensor_a, input_tensor_b): + compute_dtype = input_tensor_a.dtype + output_tensor = None + if compute_dtype in [torch.float32, 
torch.float16, torch.bfloat16]: + output_tensor = torch.mm(input_tensor_a, input_tensor_b) + else: + raise Exception(f"GemvOp with dtype {compute_dtype} is not implemented") + return output_tensor class BatchGemmOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a.dtype - output_tensor = None - if compute_dtype == torch.int8: - # to be realized - pass - elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.bmm(input_tensor_a, input_tensor_b) - return output_tensor - - def compute_size(self, input_shapes, torch_dtype): + def compute_size(self, input_shapes, dtype): # input_shapes: [[bs, M, K], [bs, K, N]] + torch_dtype = getattr(torch, dtype) a_shape, b_shape = input_shapes bs, M, K = a_shape bs, K, N = b_shape d_shape = [bs, M, N] - dtype_size = get_dtype_bytes(torch_dtype) - element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]]) - bytes_per_cnt = dtype_size * element_num + dtype_size = get_dtype_bytes(dtype) + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + if torch_dtype == torch.int8: + bytes_per_cnt = dtype_size * input_element_num + get_dtype_bytes("int32") * output_element_num * 2 + else: + bytes_per_cnt = dtype_size * (input_element_num + output_element_num) return bytes_per_cnt - + + def forward(self, input_tensor_a, input_tensor_b): + compute_dtype = input_tensor_a.dtype + output_tensor = None + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + output_tensor = torch.bmm(input_tensor_a, input_tensor_b) + else: + raise Exception(f"BatchGemmOp with dtype {compute_dtype} is not implemented") + return output_tensor class GroupGemmOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def forward(self, input_tensor_a, input_tensor_b): - compute_dtype = input_tensor_a.dtype - output_tensor_list = [] - for a, b in zip(input_tensor_a, input_tensor_b): - if compute_dtype == torch.int8: - # to be realized - pass - elif compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: - output_tensor = torch.mm(a, b) - output_tensor_list.append(output_tensor) - return output_tensor_list - - - def compute_size(self, input_shapes, torch_dtype): + def compute_size(self, input_shapes, dtype): """ [ [[M1, K1], [K1, N1]], [[M2, K2], [K2, N2]] ] """ + torch_dtype = getattr(torch, dtype) bytes_per_cnt = 0 for problem_shape in input_shapes: a_shape, b_shape = problem_shape M, K = a_shape K, N = b_shape d_shape = [M, N] - dtype_size = get_dtype_bytes(torch_dtype) - element_num = sum([math.prod(shape) for shape in [a_shape, b_shape, d_shape]]) - bytes_per_cnt += dtype_size * element_num + dtype_size = get_dtype_bytes(dtype) + input_element_num = sum([math.prod(shape) for shape in [a_shape, b_shape]]) + output_element_num = sum([math.prod(shape) for shape in [d_shape]]) + if torch_dtype == torch.int8: + bytes_per_cnt += dtype_size * input_element_num + get_dtype_bytes("float32") * output_element_num + else: + bytes_per_cnt += dtype_size * (input_element_num + output_element_num) return bytes_per_cnt - def custom_create_tensors(self, input_shapes, torch_dtype): + def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device): """ [ [[M1, K1], [K1, N1]], @@ -122,16 +150,27 @@ def custom_create_tensors(self, input_shapes, torch_dtype): for 
problem_shape in input_shapes: a_shape, b_shape = problem_shape - left_tensors.append( - torch.randint(0, 3, size=a_shape).type(torch_dtype).to(torch.device("cuda")) - ) - right_tensors.append( - torch.randint(0, 3, size=b_shape).type(torch_dtype).to(torch.device("cuda")) - ) - return [left_tensors, right_tensors] - + if torch_dtype in [torch.int8, torch.int32]: + left_tensor = torch.randint(-3, 3, size=a_shape, dtype=torch_dtype, device=xpu_device) + right_tensor = torch.randint(-3, 3, size=b_shape, dtype=torch_dtype, device=xpu_device) + else: + left_tensor = torch.randn(a_shape, dtype=torch_dtype, device=xpu_device) + right_tensor = torch.randn(b_shape, dtype=torch_dtype, device=xpu_device) + left_tensors.append(left_tensor) + right_tensors.append(right_tensor) + return [left_tensors, right_tensors] + def forward(self, input_tensor_a, input_tensor_b): + compute_dtype = input_tensor_a.dtype + output_tensor_list = [] + for a, b in zip(input_tensor_a, input_tensor_b): + if compute_dtype in [torch.float32, torch.float16, torch.bfloat16]: + output_tensor = torch.mm(a, b) + output_tensor_list.append(output_tensor) + else: + raise Exception(f"GroupGemmOp with dtype {compute_dtype} is not implemented") + return output_tensor_list class Host2DeviceOp(torch.nn.Module): @@ -159,8 +198,6 @@ def forward(self, input_tensors): return output_cpu - - class AllReduceOp(torch.nn.Module): def __init__(self, group): super().__init__() @@ -238,35 +275,79 @@ def __init__(self, group): def forward(self, input_tensors): dist.broadcast(input_tensors, 0, self.group) return True - +class P2POp(torch.nn.Module): + def __init__(self, group, ranks, rank): + super().__init__() + self.group = group + self.group_size = self.group.size() + self.rank = rank + self.ranks = ranks + self.rank_size = len(ranks) + + def next_rank(self): + return self.ranks[(self.rank + 1) % self.rank_size] + + def prev_rank(self): + return self.ranks[(self.rank - 1) % self.rank_size] + + def forward(self, send_tensor, recv_tensor): + reqs = [] + if self.rank != (self.group_size - 1): + send_req = dist.isend(send_tensor, self.next_rank(), self.group) + reqs.append(send_req) + if self.rank != 0: + recv_req = dist.irecv(recv_tensor, self.prev_rank(), self.group) + reqs.append(recv_req) + + for req in reqs: + req.wait() + return True -class AddOp(torch.nn.Module): +class SinOp(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, input_tensor_a, input_tensor_b): - result = input_tensor_a + input_tensor_b + def forward(self, input_tensors): + result = torch.sin(input_tensors) return result -class SinOp(torch.nn.Module): +class CosOp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensors): - result = torch.sin(input_tensors) + result = torch.cos(input_tensors) return result -class CosOp(torch.nn.Module): +class ExpOp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensors): - result = torch.cos(input_tensors) + result = torch.exp(input_tensors) + return result + + +class ExponentialOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_tensors): + result = input_tensors.exponential_() + return result + + +class SiluOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_tensors): + result = torch.nn.functional.silu(input_tensors) return result @@ -279,77 +360,151 @@ def forward(self, input_tensors): return result -class ExponentialOp(torch.nn.Module): +class SwiGLUOp(torch.nn.Module): + 
def __init__(self) -> None: + super().__init__() + self.w = 1 + self.v = 2 + + def forward(self, input_tensors): + result = (torch.nn.functional.sigmoid(input_tensors) * self.w) + (input_tensors * self.v) + return result + + +class CastOp(torch.nn.Module): def __init__(self): super().__init__() + def set_dtype(self, src_dtype: str): + target_dtype = "bfloat16" if src_dtype == "float32" else "float32" + self.target_dtype = target_dtype + self.target_torch_dtype = getattr(torch, target_dtype) + + def compute_size(self, input_shapes, dtype): + torch_dtype = getattr(torch, dtype) + self.set_dtype(dtype) + dtype_size = get_dtype_bytes(dtype) + target_dtype_size = get_dtype_bytes(self.target_dtype) + element_num = sum([math.prod(shape) for shape in input_shapes]) + bytes_per_cnt = dtype_size * element_num + target_dtype_size * element_num + return bytes_per_cnt + def forward(self, input_tensors): - result = input_tensors.exponential_() + result = input_tensors.to(self.target_torch_dtype) return result -class IndexAddOp(torch.nn.Module): +class AddOp(torch.nn.Module): def __init__(self): super().__init__() - def process_inputs(self, input_tensor, source_tensor): - index = torch.randint(0, input_tensor.shape[0], (source_tensor.shape[0],)).to( - input_tensor.device - ) - return [input_tensor, index, source_tensor] + def forward(self, input_tensor_a, input_tensor_b): + result = input_tensor_a + input_tensor_b + return result - def forward(self, input_tensor, index, source_tensor): - result = input_tensor.index_add_(0, index, source_tensor) + +class MulOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_tensor_a, input_tensor_b): + result = input_tensor_a * input_tensor_b return result -class SortOp(torch.nn.Module): +class SubOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_tensor_a, input_tensor_b): + result = input_tensor_a - input_tensor_b + return result + + +class DivOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input_tensor_a, input_tensor_b): + result = input_tensor_a / input_tensor_b + return result + + +class LayerNormOp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensors): - result = torch.sort(input_tensors) + result = torch.nn.functional.layer_norm( + input_tensors, (input_tensors.shape[-1],) + ) return result -class UniqueOp(torch.nn.Module): +class SoftmaxOp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensors): - result = torch.unique(input_tensors, return_counts=True) + result = torch.nn.functional.softmax(input_tensors, dim=-1) return result -class ExpOp(torch.nn.Module): +class ReduceSumOp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input_tensors): - result = torch.exp(input_tensors) + result = torch.sum(input_tensors, dim=-1) return result +class ReduceMinOp(torch.nn.Module): + def __init__(self): + super().__init__() -class SoftmaxOp(torch.nn.Module): + def forward(self, input_tensors): + result = torch.min(input_tensors, dim=-1) + return result + + +class ReduceMaxOp(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, hidden_states): - logits = torch.nn.functional.softmax(hidden_states, dim=-1) - return logits + def forward(self, input_tensors): + result = torch.max(input_tensors, dim=-1) + return result -class LayerNormOp(torch.nn.Module): +class IndexAddOp(torch.nn.Module): def __init__(self): super().__init__() - def 
forward(self, hidden_states): - logits = torch.nn.functional.layer_norm( - hidden_states, (hidden_states.shape[-1],) + def process_inputs(self, input_tensor, source_tensor): + index = torch.randint(0, input_tensor.shape[0], (source_tensor.shape[0],)).to( + input_tensor.device ) - return logits + return [input_tensor, index, source_tensor] + + def forward(self, input_tensor, index, source_tensor): + result = input_tensor.index_add_(0, index, source_tensor) + return result + + +class SortOp(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, input_tensors): + result = torch.sort(input_tensors) + return result +class UniqueOp(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, input_tensors): + result = torch.unique(input_tensors, return_counts=True) + return result diff --git a/byte_micro_perf/backends/utils.py b/byte_micro_perf/backends/utils.py index 9265ed38..1d5e8811 100644 --- a/byte_micro_perf/backends/utils.py +++ b/byte_micro_perf/backends/utils.py @@ -19,21 +19,92 @@ import torch -def get_dtype_bytes(torch_type): +def get_dtype_bytes(dtype: str): + torch_dtype = getattr(torch, dtype) dtype_size = 0 - if torch_type in [torch.int32, torch.int8]: - dtype_size = torch.iinfo(torch_type).bits // 8 - elif torch_type in [torch.float32, torch.float16, torch.bfloat16]: - dtype_size = torch.finfo(torch_type).bits // 8 + if torch_dtype in [torch.int32, torch.int8]: + dtype_size = torch.iinfo(torch_dtype).bits // 8 + elif torch_dtype in [torch.float32, torch.float16, torch.bfloat16]: + dtype_size = torch.finfo(torch_dtype).bits // 8 else: # not supported yet pass return dtype_size +def get_io_amount(op_name, input_shapes, dtype): + batch_size = input_shapes[0][0] + dtype_size = get_dtype_bytes(dtype) + if op_name in ["add", "mul", "sub", "div"]: + # c = a + b + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = dtype_size * math.prod(input_shapes[0]) + elif op_name == "gemm": + M = input_shapes[0][0] + K = input_shapes[0][1] + N = input_shapes[1][1] + read_io_amount = dtype_size * (M * K + K * N) + if dtype != torch.int8: + write_io_amount = dtype_size * (M * N) + else: + write_io_amount = get_dtype_bytes("int32") * (M * N) + elif op_name == "batch_gemm": + bs = input_shapes[0][0] + M = input_shapes[0][1] + K = input_shapes[0][2] + N = input_shapes[1][2] + read_io_amount = dtype_size * bs * (M * K + K * N) + if dtype != torch.int8: + write_io_amount = dtype_size * bs * (M * N) + else: + write_io_amount = get_dtype_bytes("int32") * bs * (M * N) + elif op_name == "group_gemm": + in_size_list = [] + out_size_list = [] + m_list = [] + for problem_shape in input_shapes: + M = problem_shape[0][0] + K = problem_shape[0][1] + N = problem_shape[1][1] + in_size_list.append(M * K + K * N) + out_size_list.append(M * N) + m_list.append(M) + batch_size = sum(m_list) + read_io_amount = dtype_size * sum(in_size_list) + if dtype != torch.int8: + write_io_amount = dtype_size * sum(out_size_list) + else: + write_io_amount = get_dtype_bytes("int32") * sum(out_size_list) + elif op_name in ["device2host"]: + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = 0 + elif op_name in ["host2device"]: + read_io_amount = 0 + write_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + elif op_name in ["reduce_sum", "reduce_max", "reduce_min"]: + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = 
dtype_size * sum([math.prod(shape[:-1]) for shape in input_shapes]) + elif op_name in ["unqiue", "sort"]: + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = 2 * dtype_size * sum([math.prod(shape) for shape in input_shapes]) + elif op_name == "cast": + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = read_io_amount / 2 if dtype == torch.float32 else read_io_amount * 2 + elif op_name in ["index_add"]: + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + get_dtype_bytes("int32") * input_shapes[1][0] + write_io_amount = dtype_size * math.prod(input_shapes[0]) + else: + read_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + write_io_amount = dtype_size * sum([math.prod(shape) for shape in input_shapes]) + + total_io_amount = read_io_amount + write_io_amount + + return batch_size, total_io_amount, read_io_amount, write_io_amount + + def dump_communication_ops_report( op_name: str, - dtype: torch.dtype, + dtype: str, input_shapes: List[List[int]], group_size: List[int], bandwidth_limit: float, @@ -52,9 +123,10 @@ def dump_communication_ops_report( reducescatter: 1 * (group_size - 1) * (tensor_size / group_size) alltoall: 1 * (group_size - 1) * (tensor_size / group_size) broadcast: tensor_size + p2p: tensor_size """ bus_bw = algo_bw * (group_size - 1) / group_size - if op_name == "broadcast": + if op_name in ["broadcast", "p2p"]: bus_bw = algo_bw if op_name == "allreduce": bus_bw *= 2 @@ -63,7 +135,7 @@ def dump_communication_ops_report( if bandwidth_limit is not None: bandwidth_utils = round((algo_bw / bandwidth_limit) * 1e2, 2) report = { - "Dtype": dtype, + "Dtype": str(dtype), "Tensor Shapes": input_shapes, "Memory Size(MB)": round(mb, 2), "Group": group_size, @@ -74,7 +146,7 @@ def dump_communication_ops_report( } else: report = { - "Dtype": dtype, + "Dtype": str(dtype), "Tensor Shapes": input_shapes, "Memory Size(MB)": round(mb, 2), "Group": group_size, @@ -89,72 +161,43 @@ def dump_communication_ops_report( def dump_computation_ops_report( op_name: str, - dtype: torch.dtype, + dtype: str, input_shapes: List[List[int]], bandwidth_limit: float, latency: float, error: str = "" ): - if op_name == "add": - # c = a + b - # MAC_total = MAC_a + MAC_b + MAC_c - size = sum( - [math.prod(shape) for shape in input_shapes], math.prod(input_shapes[0]) - ) - elif op_name == "gemm": - # c = gemm(a, b) - # MAC_total = MAC_a + MAC_b + MAC_c - M = input_shapes[0][0] - K = input_shapes[0][1] - N = input_shapes[1][1] - size = M * K + K * N + M * N - elif op_name == "batch_gemm": - # c = batch_gemm(a, b) - bs = input_shapes[0][0] - M = input_shapes[0][1] - K = input_shapes[0][2] - N = input_shapes[1][2] - size = bs * (M * K + K * N + M * N) - elif op_name == "group_gemm": - # c_list = group_gemm(a_list, b_list) - size_list = [] - for problem_shape in input_shapes: - M = problem_shape[0][0] - K = problem_shape[0][1] - N = problem_shape[1][1] - size_list.append(M * K + K * N + M * N) - size = sum(size_list) - elif op_name in ["unique", "device2host", "host2device"]: - size = sum([math.prod(shape) for shape in input_shapes]) - else: - # out = func(in) - # MAC_total = MAC_in + MAC_out - size = sum([math.prod(shape) for shape in input_shapes]) * 2 + batch_size, total_io_amount, read_io_amount, write_io_amount = get_io_amount(op_name, input_shapes, dtype) - dtype_size = get_dtype_bytes(dtype) - mb = dtype_size * size / 1024 / 1024 if error == "": - algo_bw = 
dtype_size * size / latency / 1e3 + qps = round(1000 / latency * batch_size, 2) + algo_bw = total_io_amount / latency / 1e3 bandwidth_utils = None if bandwidth_limit is not None: bandwidth_utils = round((algo_bw / bandwidth_limit) * 1e2, 2) report = { - "Dtype": dtype, + "Dtype": str(dtype), "Tensor Shapes": input_shapes, - "Memory Size(MB)": round(mb, 2), + "Read IO Size(MB)": round(read_io_amount / 1024 / 1024, 2), + "Write IO Size(MB)": round(write_io_amount / 1024 / 1024, 2), + "Memory Size(MB)": round(total_io_amount / 1024 / 1024, 2), "Kernel bandwidth(GB/s)": round(algo_bw, 2), "Bandwidth Utilization(%)": bandwidth_utils, "Avg latency(us)": round(latency, 2), + "QPS": qps, } else: report = { - "Dtype": dtype, + "Dtype": str(dtype), "Tensor Shapes": input_shapes, - "Memory Size(MB)": round(mb, 2), + "Read IO Size(MB)": round(read_io_amount / 1024 / 1024, 2), + "Write IO Size(MB)": round(write_io_amount / 1024 / 1024, 2), + "Memory Size(MB)": round(total_io_amount / 1024 / 1024, 2), "Kernel bandwidth(GB/s)": 0, "Bandwidth Utilization(%)": None, "Avg latency(us)": 0, + "QPS": 0, "Error": error, } return report diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py index 94258120..e28c1aef 100644 --- a/byte_micro_perf/core/perf_engine.py +++ b/byte_micro_perf/core/perf_engine.py @@ -24,9 +24,9 @@ import traceback import random from typing import Any, Dict, List +import itertools -import torch import torch.multiprocessing as mp import virtualenv @@ -92,6 +92,99 @@ def load_workload(task: str) -> Dict[str, Any]: "Task name: [ {} ] was not found, please check your task name".format(task) ) +def parse_workload(workload): + shape_list = [] + if "input_shape_list" in workload: + shape_list.extend(workload["input_shape_list"]) + # gemm or batch_gemm + elif "M/K/N" in workload: + if "batch_size" in workload: + for batch_size in workload["batch_size"]: + for M, K, N in workload["M/K/N"]: + shape_list.append([ + [batch_size, M, K], + [batch_size, K, N] + ]) + else: + for M, K, N in workload["M/K/N"]: + shape_list.append([[M, K], [K, N]]) + # group_gemm + elif "MKN_choices" in workload: + seed = workload["seed"] + MKN_list = workload["MKN_choices"] + problems_list = workload["problems"] + + random.seed(seed) + for problems in problems_list: + cur_inputs = [] + for _ in range(problems): + M, K, N = [random.choice(MKN_list) for _ in range(3)] + cur_shapes = [[M, K], [K, N]] + cur_inputs.append(cur_shapes) + shape_list.append(cur_inputs) + + + if "input_shape_groups" in workload: + input_shape_groups = workload["input_shape_groups"] if isinstance(workload["input_shape_groups"], list) else [workload["input_shape_groups"]] + + for input_shape_group in input_shape_groups: + if "inputs" in input_shape_group: + input_shape_list = [] + for input_shapes in input_shape_group["inputs"]: + input_shape_list.append([list(shape) for shape in itertools.product(*input_shapes)]) + if len(input_shape_list) == 1: + shape_list.extend(input_shape_list[0]) + else: + shape_list.extend([list(input_shape) for input_shape in zip(*input_shape_list)]) + + else: + gemm_keys = ["M", "K", "N", "MN", "MK", "KN"] + gemm_values = [input_shape_group.get(k, []) for k in gemm_keys] + if any(gemm_values): + m ,k, n, mn, mk, kn = gemm_values + # batch gemm + if "batch_size" in input_shape_group: + bs = input_shape_group.get("batch_size", []) + if m and n and k: + for p in itertools.product(bs, m, k, n): + shape_list.append([[p[0], p[1], p[2]], [p[0], p[2], p[3]]]) + if mn and k: + for p in 
itertools.product(bs, mn, k): + shape_list.append([[p[0], p[1][0], p[2]], [p[0], p[2], p[1][1]]]) + if mk and n: + for p in itertools.product(bs, mk, n): + shape_list.append([[p[0], p[1][0], p[1][1]], [p[0], p[1][1], p[2]]]) + if m and kn: + for p in itertools.product(bs, m, kn): + shape_list.append([[p[0], p[1], p[2][0]], [p[0], p[2][0], p[2][1]]]) + # group gemm + elif "gemm_group" in input_shape_group: + groups = input_shape_group.get("gemm_group", []) + kn = input_shape_group.get("KN", []) + if k and n: + kn.append([list(shape) for shape in itertools.product(k, n)]) + for group in groups: + for _kn in kn: + group_input_shape_list = [] + for m in group: + group_input_shape_list.append([[m, _kn[0]], [_kn[0], _kn[1]]]) + shape_list.append(group_input_shape_list) + # gemm + else: + if m and n and k: + for p in itertools.product(m, k, n): + shape_list.append([[p[0], p[1]], [p[1], p[2]]]) + if mn and k: + for p in itertools.product(mn, k): + shape_list.append([[p[0][0], p[1]], [p[1], p[0][1]]]) + if mk and n: + for p in itertools.product(mk, n): + shape_list.append([[p[0][0], p[0][1]], [p[0][1], p[1]]]) + if m and kn: + for p in itertools.product(m, kn): + shape_list.append([[p[0], p[1][0]], [p[1][0], p[1][1]]]) + return shape_list + class PerfEngine: def __init__(self) -> None: @@ -143,20 +236,26 @@ def start_engine(self) -> None: output_dir = os.path.abspath("reports/" + self.backend_type) os.makedirs(output_dir, exist_ok=True) - if self.args.task in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast"]: + if self.args.task in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]: for group in self.workload["group"]: - mp.spawn(fn=self.init_process, args=(group,), nprocs=group) + try: + mp.spawn(fn=self.init_process, args=(group,), nprocs=group) + except Exception as e: + traceback.print_exc() + log.error(f"Execute task: {self.args.task} failed, group: {group}, error msg: {e}") else: status = self.start_perf(self.workload) self.deactivate_venv() def start_perf(self, workload: Dict[str, Any]) -> bool: - log.info( - "******************************************* Start to test op: {}. *******************************************".format( - workload["operator"] + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if local_rank == 0: + log.info( + "******************************************* Start to test op: [{}]. 
*******************************************".format( + workload["operator"] + ) ) - ) # Initalize Output Dir and Reports output_dir = pathlib.Path("reports").joinpath(self.backend_type).joinpath(workload["operator"]) @@ -177,44 +276,12 @@ def start_perf(self, workload: Dict[str, Any]) -> bool: raise ValueError(f"Unknown operation: {op_name.lower()}") # get input shape info - shape_list = [] - - # normal ops - if "input_shape_list" in self.workload: - shape_list = self.workload["input_shape_list"] - # gemm or batch_gemm - elif "M/N/K" in self.workload: - if "batch_size" in self.workload: - for batch_size in self.workload["batch_size"]: - for M, K, N in self.workload["M/N/K"]: - shape_list.append([ - [batch_size, M, K], - [batch_size, K, N] - ]) - else: - for M, K, N in self.workload["M/N/K"]: - shape_list.append([[M, K], [K, N]]) - # group_gemm - elif "MNK_choices" in self.workload: - seed = workload["seed"] - MNK_list = self.workload["MNK_choices"] - problems_list = workload["problems"] - - random.seed(seed) - for problems in problems_list: - cur_inputs = [] - for _ in range(problems): - M, N, K = [random.choice(MNK_list) for _ in range(3)] - cur_shapes = [[M, K], [K, N]] - cur_inputs.append(cur_shapes) - shape_list.append(cur_inputs) + shape_list = parse_workload(self.workload) # dtype list dtype_list = self.workload["dtype"] for dtype in dtype_list: - torch_dtype = getattr(torch, dtype) - perf_reports = [] base_report["Performance"] = {} @@ -225,11 +292,12 @@ def start_perf(self, workload: Dict[str, Any]) -> bool: List[List[int]]: multiple inputs. add List[List[List[in]]]: multiple inputs with multiple problems. group_gemm """ - + if local_rank == 0: + log.info(f"Execute op: [{op_name.lower()}], input_shape: {input_shape}, dtype: {dtype}") if isinstance(input_shape[0], int): input_shape = [input_shape] try: - reports = self.backend.perf(input_shape, torch_dtype) + reports = self.backend.perf(input_shape, dtype) except Exception as e: traceback.print_exc() log.error(f"Execute op: {op_name.lower()} failed, input_shape: {input_shape}, dtype: {dtype}, error msg: {e}") @@ -249,16 +317,16 @@ def start_perf(self, workload: Dict[str, Any]) -> bool: + ".json" ) output_report_path = os.path.join(output_dir, output_report_path) - local_rank = int(os.environ.get("LOCAL_RANK", 0)) if local_rank == 0: # logging.info(base_report["Performance"]) with open(output_report_path, "w") as file: json.dump(base_report, file, indent=4) - log.info( - "******************************************* End to test op: {}. *******************************************".format( - workload["operator"] + if local_rank == 0: + log.info( + "******************************************* Test op: [{}] SUCCESS. 
*******************************************".format( + workload["operator"] + ) ) - ) return True def get_cpu_name(self): diff --git a/byte_micro_perf/requirements.txt b/byte_micro_perf/requirements.txt index 5011e26e..ce7a7fc9 100644 --- a/byte_micro_perf/requirements.txt +++ b/byte_micro_perf/requirements.txt @@ -11,3 +11,5 @@ fpdf attrs decorator typing-extensions +torch==2.1.0 +nvidia-cutlass diff --git a/byte_micro_perf/workloads/add.json b/byte_micro_perf/workloads/add.json index c8a5591f..5885cc87 100644 --- a/byte_micro_perf/workloads/add.json +++ b/byte_micro_perf/workloads/add.json @@ -1,68 +1,18 @@ { "operator": "add", "iterations": 100, - "input_shape_list": [ - [ + "input_shape_groups": { + "inputs": [ [ - 1, - 65536 + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] ], [ - 1, - 65536 - ] - ], - [ - [ - 16, - 65536 - ], - [ - 16, - 65536 - ] - ], - [ - [ - 32, - 65536 - ], - [ - 32, - 65536 - ] - ], - [ - [ - 64, - 65536 - ], - [ - 64, - 65536 - ] - ], - [ - [ - 128, - 65536 - ], - [ - 128, - 65536 - ] - ], - [ - [ - 65536, - 65536 - ], - [ - 65536, - 65536 + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/allgather.json b/byte_micro_perf/workloads/allgather.json index efc59392..a7d0b0a6 100644 --- a/byte_micro_perf/workloads/allgather.json +++ b/byte_micro_perf/workloads/allgather.json @@ -1,42 +1,14 @@ { "operator": "allgather", "iterations": 100, - "input_shape_list": [ - [ - 1024, - 1024 - ], - [ - 8, - 1024, - 1024 - ], - [ - 32, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 - ], - [ - 512, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/allreduce.json b/byte_micro_perf/workloads/allreduce.json index 94cdc404..d81356cc 100644 --- a/byte_micro_perf/workloads/allreduce.json +++ b/byte_micro_perf/workloads/allreduce.json @@ -1,42 +1,14 @@ { "operator": "allreduce", "iterations": 100, - "input_shape_list": [ - [ - 1024, - 1024 - ], - [ - 8, - 1024, - 1024 - ], - [ - 32, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 - ], - [ - 512, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/alltoall.json b/byte_micro_perf/workloads/alltoall.json index f7b0fa23..7550fa71 100644 --- a/byte_micro_perf/workloads/alltoall.json +++ b/byte_micro_perf/workloads/alltoall.json @@ -1,90 +1,18 @@ { "operator": "alltoall", "iterations": 100, - "input_shape_list": [ - [ + "input_shape_groups": { + "inputs": [ [ - 1024, - 1024 + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] ], [ - 1024, - 1024 - ] - ], - [ - [ - 8, - 1024, - 1024 - ], - [ - 8, - 1024, - 1024 - ] - ], - [ - [ - 32, - 1024, - 1024 - ], - [ - 32, - 1024, - 1024 - ] - ], - [ - [ - 64, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ] - ], - [ - [ - 128, - 
1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ] - ], - [ - [ - 256, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 - ] - ], - [ - [ - 512, - 1024, - 1024 - ], - [ - 512, - 1024, - 1024 + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/batch_gemm.json b/byte_micro_perf/workloads/batch_gemm.json index 59b5ac71..13c3773e 100644 --- a/byte_micro_perf/workloads/batch_gemm.json +++ b/byte_micro_perf/workloads/batch_gemm.json @@ -1,21 +1,31 @@ { "operator": "batch_gemm", "iterations": 100, - "M/N/K": [ - [ - 1024, - 1024, - 1024 - ] - ], - "batch_size": [ - 2, - 8, - 32 - ], + "input_shape_groups": [ + { + "batch_size": [4, 8, 16, 32, 64, 128, 256, 512, 1024], + "MN": [[1, 1], [1, 1024], [1, 2048], [1, 4096]], + "K": [128, 256, 512] + }, + { + "batch_size": [4, 8, 16, 32, 64, 128, 256], + "MN": [[1, 8192],[1, 16384], [1, 32768], [1, 65536], [1, 131072]], + "K": [128, 256, 512] + }, + { + "batch_size": [1, 2, 4, 8, 16, 32], + "MN": [[1, 1], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192]], + "K": [128, 256, 512] + }, + { + "batch_size": [1, 2, 4], + "MN": [[16384, 16384], [32768, 32768], [65536, 65536], [131072, 131072]], + "K": [128, 256, 512] + } + ], "dtype": [ - "float32", - "bfloat16", + "float32", + "bfloat16", "half", "int8" ] diff --git a/byte_micro_perf/workloads/broadcast.json b/byte_micro_perf/workloads/broadcast.json index 40a15139..b815360a 100644 --- a/byte_micro_perf/workloads/broadcast.json +++ b/byte_micro_perf/workloads/broadcast.json @@ -1,42 +1,14 @@ { - "operator": "broadcast", + "operator": "broadcast", "iterations": 100, - "input_shape_list": [ - [ - 1024, - 1024 - ], - [ - 8, - 1024, - 1024 - ], - [ - 32, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 - ], - [ - 512, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/cast.json b/byte_micro_perf/workloads/cast.json new file mode 100644 index 00000000..07ab85dd --- /dev/null +++ b/byte_micro_perf/workloads/cast.json @@ -0,0 +1,17 @@ +{ + "operator": "cast", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] +}, +"dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/cos.json b/byte_micro_perf/workloads/cos.json index 7203d99a..62725bca 100644 --- a/byte_micro_perf/workloads/cos.json +++ b/byte_micro_perf/workloads/cos.json @@ -1,23 +1,14 @@ { "operator": "cos", "iterations": 100, - "input_shape_list": [ - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/device2host.json b/byte_micro_perf/workloads/device2host.json index 971956aa..3bb34dab 100644 --- a/byte_micro_perf/workloads/device2host.json +++ b/byte_micro_perf/workloads/device2host.json @@ -1,38 +1,14 @@ { "operator": "device2host", 
"iterations": 100, - "input_shape_list": [ - [ - 1, - 1024, - 1024 - ], - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/div.json b/byte_micro_perf/workloads/div.json new file mode 100644 index 00000000..bb55608b --- /dev/null +++ b/byte_micro_perf/workloads/div.json @@ -0,0 +1,22 @@ +{ + "operator": "div", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ], + + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/exp.json b/byte_micro_perf/workloads/exp.json index 025d0303..17d00d0f 100644 --- a/byte_micro_perf/workloads/exp.json +++ b/byte_micro_perf/workloads/exp.json @@ -1,23 +1,14 @@ { "operator": "exp", "iterations": 100, - "input_shape_list": [ - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/exponential.json b/byte_micro_perf/workloads/exponential.json index d5ce4832..967a58e9 100644 --- a/byte_micro_perf/workloads/exponential.json +++ b/byte_micro_perf/workloads/exponential.json @@ -1,13 +1,14 @@ { "operator": "exponential", "iterations": 100, - "input_shape_list": [ - [ - 8, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/gelu.json b/byte_micro_perf/workloads/gelu.json index 2f0984d1..65574955 100644 --- a/byte_micro_perf/workloads/gelu.json +++ b/byte_micro_perf/workloads/gelu.json @@ -1,23 +1,14 @@ { "operator": "gelu", "iterations": 100, - "input_shape_list": [ - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/gemm.json b/byte_micro_perf/workloads/gemm.json index 82350510..d287481d 100644 --- a/byte_micro_perf/workloads/gemm.json +++ b/byte_micro_perf/workloads/gemm.json @@ -1,223 +1,10 @@ { "operator": "gemm", "iterations": 100, - "M/N/K": [ - [ - 64, - 65536, - 2048 - ], - [ - 64, - 2048, - 65536 - ], - [ - 2048, - 65536, - 64 - ], - [ - 2048, - 64, - 65536 - ], - [ - 65536, - 64, - 2048 - ], - [ - 65536, - 2048, - 64 - ], - [ - 64, - 65280, - 2048 - ], - [ - 64, - 2048, - 65280 - ], - [ - 2048, - 65280, - 64 - ], - [ - 2048, - 64, - 65280 - ], - [ - 65280, - 64, - 2048 - ], - [ - 65280, - 2048, - 64 - ], - [ - 800, - 1536, - 12288 - ], - [ - 128, - 1536, - 12288 - ], - [ - 800, - 12288, - 1536 - ], - [ - 128, - 12288, - 1536 - ], - [ - 64, - 64, - 65536 - ], - [ - 64, - 65536, - 65536 - ], - [ - 65536, - 64, - 65536 - ], - [ - 64, - 64, - 
64 - ], - [ - 65536, - 64, - 64 - ], - [ - 64, - 65536, - 64 - ], - [ - 65536, - 65536, - 64 - ], - [ - 65536, - 65536, - 65536 - ], - [ - 1, - 16, - 7168 - ], - [ - 1, - 32, - 7168 - ], - [ - 1, - 64, - 7168 - ], - [ - 1, - 128, - 7168 - ], - [ - 1, - 16, - 8192 - ], - [ - 1, - 32, - 8192 - ], - [ - 1, - 64, - 8192 - ], - [ - 1, - 128, - 8192 - ], - [ - 1, - 7168, - 16 - ], - [ - 1, - 7168, - 32 - ], - [ - 1, - 7168, - 64 - ], - [ - 1, - 7168, - 128 - ], - [ - 1, - 8192, - 16 - ], - [ - 1, - 8192, - 32 - ], - [ - 1, - 8192, - 64 - ], - [ - 1, - 8192, - 128 - ], - [ - 1, - 7168, - 7168 - ], - [ - 1, - 8192, - 8192 - ], - [ - 1, - 65536, - 65536 - ] - ], + "input_shape_groups": { + "M": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + "KN": [[1024, 1024], [16384, 1024], [16384, 32], [1024, 16384]] + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/gemv.json b/byte_micro_perf/workloads/gemv.json new file mode 100644 index 00000000..dcf6a1d0 --- /dev/null +++ b/byte_micro_perf/workloads/gemv.json @@ -0,0 +1,22 @@ +{ + "operator": "gemv", + "iterations": 100, + "input_shape_groups": [ + { + "M": [1], + "K": [16, 32, 64, 128, 256, 512], + "N": [4096, 8192] + }, + { + "M": [1], + "K": [4096, 8192], + "N": [16, 32, 64, 128, 256, 512] + } + ], + "dtype": [ + "float32", + "bfloat16", + "half", + "int8" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/group_gemm.json b/byte_micro_perf/workloads/group_gemm.json index 51dbfe68..745d66a6 100644 --- a/byte_micro_perf/workloads/group_gemm.json +++ b/byte_micro_perf/workloads/group_gemm.json @@ -1,16 +1,10 @@ { "operator": "group_gemm", - "iterations": 20, - "MNK_choices": [ - 128, - 256, - 512, - 1024 - ], - "seed": 2024, - "problems": [ - 20 - ], + "iterations": 100, + "input_shape_groups": { + "gemm_group": [[1, 16, 32, 64, 128, 256, 512, 1024]], + "KN": [[4096, 4096], [7168, 7168], [16384, 16384]] + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/host2device.json b/byte_micro_perf/workloads/host2device.json index 417cf8d6..8982c678 100644 --- a/byte_micro_perf/workloads/host2device.json +++ b/byte_micro_perf/workloads/host2device.json @@ -1,38 +1,14 @@ { "operator": "host2device", "iterations": 100, - "input_shape_list": [ - [ - 1, - 1024, - 1024 - ], - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/index_add.json b/byte_micro_perf/workloads/index_add.json new file mode 100644 index 00000000..64744d14 --- /dev/null +++ b/byte_micro_perf/workloads/index_add.json @@ -0,0 +1,21 @@ +{ + "operator": "index_add", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ], + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] + ] + }, + "dtype": [ + "float32", + "half", + "bfloat16" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/indexadd.json b/byte_micro_perf/workloads/indexadd.json deleted file mode 100644 index 5617d6b2..00000000 --- a/byte_micro_perf/workloads/indexadd.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "operator": "indexadd", - "iterations": 100, - 
"input_shape_list": [ - [ - [ - 4, - 7168 - ], - [ - 20, - 7168 - ] - ], - [ - [ - 2048, - 7168 - ], - [ - 10240, - 7168 - ] - ] - ], - "dtype": [ - "float32", - "half", - "bfloat16" - ] -} \ No newline at end of file diff --git a/byte_micro_perf/workloads/layernorm.json b/byte_micro_perf/workloads/layernorm.json index 80146428..87711ee2 100644 --- a/byte_micro_perf/workloads/layernorm.json +++ b/byte_micro_perf/workloads/layernorm.json @@ -1,40 +1,14 @@ { "operator": "layernorm", "iterations": 100, - "input_shape_list": [ - [ - 131072, - 32 - ], - [ - 131072, - 64 - ], - [ - 131072, - 128 - ], - [ - 131072, - 512 - ], - [ - 131072, - 1024 - ], - [ - 131072, - 4096 - ], - [ - 131072, - 16384 - ], - [ - 131072, - 32768 + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/mul.json b/byte_micro_perf/workloads/mul.json new file mode 100644 index 00000000..c7935637 --- /dev/null +++ b/byte_micro_perf/workloads/mul.json @@ -0,0 +1,22 @@ +{ + "operator": "mul", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ], + + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/p2p.json b/byte_micro_perf/workloads/p2p.json new file mode 100644 index 00000000..7d0c5310 --- /dev/null +++ b/byte_micro_perf/workloads/p2p.json @@ -0,0 +1,26 @@ +{ + "operator": "p2p", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ], + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ], + "group": [ + 2, + 4, + 8 + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/reduce_max.json b/byte_micro_perf/workloads/reduce_max.json new file mode 100644 index 00000000..ae311a3e --- /dev/null +++ b/byte_micro_perf/workloads/reduce_max.json @@ -0,0 +1,17 @@ +{ + "operator": "reduce_max", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/reduce_min.json b/byte_micro_perf/workloads/reduce_min.json new file mode 100644 index 00000000..7b7edb04 --- /dev/null +++ b/byte_micro_perf/workloads/reduce_min.json @@ -0,0 +1,17 @@ +{ + "operator": "reduce_min", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/reduce_sum.json b/byte_micro_perf/workloads/reduce_sum.json new file mode 100644 index 00000000..56cf77d8 --- /dev/null +++ b/byte_micro_perf/workloads/reduce_sum.json @@ -0,0 +1,17 @@ +{ + "operator": "reduce_sum", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + 
[1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/reducescatter.json b/byte_micro_perf/workloads/reducescatter.json index e472e25a..228f1f6d 100644 --- a/byte_micro_perf/workloads/reducescatter.json +++ b/byte_micro_perf/workloads/reducescatter.json @@ -1,42 +1,14 @@ { "operator": "reducescatter", "iterations": 100, - "input_shape_list": [ - [ - 1024, - 1024 - ], - [ - 8, - 1024, - 1024 - ], - [ - 32, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 - ], - [ - 128, - 1024, - 1024 - ], - [ - 256, - 1024, - 1024 - ], - [ - 512, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608], + [1024] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/silu.json b/byte_micro_perf/workloads/silu.json new file mode 100644 index 00000000..3770218c --- /dev/null +++ b/byte_micro_perf/workloads/silu.json @@ -0,0 +1,17 @@ +{ + "operator": "silu", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/sin.json b/byte_micro_perf/workloads/sin.json index 00d37b94..bf2bacda 100644 --- a/byte_micro_perf/workloads/sin.json +++ b/byte_micro_perf/workloads/sin.json @@ -1,23 +1,14 @@ { "operator": "sin", "iterations": 100, - "input_shape_list": [ - [ - 4, - 1024, - 1024 - ], - [ - 16, - 1024, - 1024 - ], - [ - 64, - 1024, - 1024 + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/softmax.json b/byte_micro_perf/workloads/softmax.json index fae9cc69..a90f294d 100644 --- a/byte_micro_perf/workloads/softmax.json +++ b/byte_micro_perf/workloads/softmax.json @@ -1,52 +1,14 @@ { "operator": "softmax", "iterations": 100, - "input_shape_list": [ - [ - 131072, - 32 - ], - [ - 131072, - 64 - ], - [ - 131072, - 128 - ], - [ - 131072, - 512 - ], - [ - 131072, - 1024 - ], - [ - 131072, - 2048 - ], - [ - 131072, - 4096 - ], - [ - 131072, - 8192 - ], - [ - 131072, - 16384 - ], - [ - 131072, - 32768 - ], - [ - 131072, - 65536 + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/sort.json b/byte_micro_perf/workloads/sort.json index f19dd2b2..a30222a0 100644 --- a/byte_micro_perf/workloads/sort.json +++ b/byte_micro_perf/workloads/sort.json @@ -1,26 +1,14 @@ { "operator": "sort", "iterations": 100, - "input_shape_list": [ - [ - 20 - ], - [ - 128 - ], - [ - 1024 - ], - [ - 10240 - ], - [ - 61440 - ], - [ - 102400 + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] ] - ], + }, "dtype": [ "float32", "bfloat16", diff --git a/byte_micro_perf/workloads/sub.json b/byte_micro_perf/workloads/sub.json new file mode 100644 index 00000000..0b6a46c6 --- /dev/null +++ b/byte_micro_perf/workloads/sub.json @@ -0,0 +1,22 @@ +{ + "operator": "sub", + "iterations": 
100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ], + + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/swiglu.json b/byte_micro_perf/workloads/swiglu.json new file mode 100644 index 00000000..9982a2c9 --- /dev/null +++ b/byte_micro_perf/workloads/swiglu.json @@ -0,0 +1,17 @@ +{ + "operator": "swiglu", + "iterations": 100, + "input_shape_groups": { + "inputs": [ + [ + [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], + [8192] + ] + ] + }, + "dtype": [ + "float32", + "bfloat16", + "half" + ] +} \ No newline at end of file diff --git a/byte_micro_perf/workloads/unique.json b/byte_micro_perf/workloads/unique.json index 452e243b..ba88ea4e 100644 --- a/byte_micro_perf/workloads/unique.json +++ b/byte_micro_perf/workloads/unique.json @@ -1,26 +1,14 @@ { "operator": "unique", "iterations": 100, - "input_shape_list": [ - [ - 20 - ], - [ - 128 - ], - [ - 1024 - ], - [ - 10240 - ], - [ - 61440 - ], - [ - 102400 + "input_shape_groups": { + "inputs": [ + [ + [1024], + [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288] + ] ] - ], + }, "dtype": [ "float32", "half",
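
Note on the `input_shape_groups` convention these workloads switch to: element-wise ops pair a sweep list with a fixed inner dimension under `"inputs"`, while gemm-style ops separate `"M"` from `"KN"` pairs. The sketch below is a hypothetical illustration of how such groups could expand into concrete shapes; the real parsing lives in `core/perf_engine.py`, and the helper names (`expand_inputs`, `expand_gemm`) and the cartesian-product behaviour are assumptions made here for clarity, not the project's actual API.

```python
# Hypothetical sketch only: illustrates how "input_shape_groups" entries from
# the workload JSONs above might expand into concrete tensor shapes.
# The actual expansion logic in core/perf_engine.py may differ.
import itertools
from typing import Dict, List

def expand_inputs(groups: Dict) -> List[List[List[int]]]:
    """Expand the element-wise form {"inputs": [[batch_list, [hidden]], ...]}.

    Each input group lists the choices for every dimension; concrete shapes are
    the cartesian product of those choices, taken per input tensor.
    """
    expanded = []
    for group in groups.get("inputs", []):
        # e.g. group == [[4, 8, 16], [8192]] -> [4, 8192], [8, 8192], [16, 8192]
        expanded.append([list(shape) for shape in itertools.product(*group)])
    return expanded

def expand_gemm(groups: Dict) -> List[List[int]]:
    """Expand the gemm form {"M": [...], "KN": [[K, N], ...]} into [M, K, N] triples."""
    return [[m, k, n] for m in groups["M"] for k, n in groups["KN"]]

if __name__ == "__main__":
    add_groups = {"inputs": [[[4, 8, 16], [8192]], [[4, 8, 16], [8192]]]}
    print(expand_inputs(add_groups)[0])   # [[4, 8192], [8, 8192], [16, 8192]]

    gemm_groups = {"M": [4, 8], "KN": [[1024, 1024], [16384, 32]]}
    print(expand_gemm(gemm_groups))       # [[4, 1024, 1024], [4, 16384, 32], ...]
```

Under this reading, one workload file drives a full sweep (e.g. batch sizes 4 through 131072 at a fixed hidden size) instead of enumerating every shape by hand, which is what the replaced `input_shape_list` blocks did.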