diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py
index d7c65ca9..a6bbefaf 100644
--- a/byte_micro_perf/backends/GPU/backend_gpu.py
+++ b/byte_micro_perf/backends/GPU/backend_gpu.py
@@ -88,6 +88,11 @@ def initialize_ccl(self, rank, world_size):
         self.setup_2d_group()
         return True
 
+    def log(self):
+        self.op = LogOp()
+
+    def sqrt(self):
+        self.op = SqrtOp()
 
     def setup_2d_group(self):
         # get rank and set device
diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py
index 4a57599b..b6f8f1dc 100644
--- a/byte_micro_perf/backends/backend.py
+++ b/byte_micro_perf/backends/backend.py
@@ -301,6 +301,11 @@ def swiglu(self):
     def cast(self):
         pass
 
+    def log(self):
+        pass
+
+    def sqrt(self):
+        pass
 
     # binary ops
     def add(self):
diff --git a/byte_micro_perf/backends/module_store.py b/byte_micro_perf/backends/module_store.py
index 1655a8cb..1c1015e2 100644
--- a/byte_micro_perf/backends/module_store.py
+++ b/byte_micro_perf/backends/module_store.py
@@ -404,6 +404,24 @@ def forward(self, input_tensors):
         return result
 
 
+class LogOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = torch.log(input_tensors)
+        return result
+
+
+class SqrtOp(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input_tensors):
+        result = torch.sqrt(input_tensors)
+        return result
+
+
 class AddOp(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/byte_micro_perf/launch.py b/byte_micro_perf/launch.py
index 6f58927e..88599a35 100644
--- a/byte_micro_perf/launch.py
+++ b/byte_micro_perf/launch.py
@@ -119,7 +119,7 @@ def parse_task(task_dir):
         if task in ["gemm", "gemv", "batch_gemm", "group_gemm"]:
             task_mapping["gemm_ops"].append(task)
 
-        if task in ["sin", "cos", "exp", "exponential", "silu", "gelu", "swiglu", "cast"]:
+        if task in ["sin", "cos", "exp", "exponential", "silu", "gelu", "swiglu", "cast", "log", "sqrt"]:
            task_mapping["unary_ops"].append(task)
 
         if task in ["add", "mul", "sub", "div"]:
diff --git a/byte_micro_perf/workloads/log.json b/byte_micro_perf/workloads/log.json
new file mode 100644
index 00000000..22d99972
--- /dev/null
+++ b/byte_micro_perf/workloads/log.json
@@ -0,0 +1,17 @@
+{
+    "operator": "log",
+    "iterations": 100,
+    "input_shape_groups": {
+        "inputs": [
+            [
+                [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072],
+                [8192]
+            ]
+        ]
+    },
+    "dtype": [
+        "float32",
+        "bfloat16",
+        "half"
+    ]
+}
\ No newline at end of file
diff --git a/byte_micro_perf/workloads/sqrt.json b/byte_micro_perf/workloads/sqrt.json
new file mode 100644
index 00000000..428d9897
--- /dev/null
+++ b/byte_micro_perf/workloads/sqrt.json
@@ -0,0 +1,17 @@
+{
+    "operator": "sqrt",
+    "iterations": 100,
+    "input_shape_groups": {
+        "inputs": [
+            [
+                [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072],
+                [8192]
+            ]
+        ]
+    },
+    "dtype": [
+        "float32",
+        "bfloat16",
+        "half"
+    ]
+}
\ No newline at end of file
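
For quick local verification, here is a minimal standalone sketch (not part of the diff) that exercises the new `LogOp` / `SqrtOp` modules as they are defined in `backends/module_store.py`. The `[4, 8192]` shape and the dtype list mirror the new workload JSONs; the positive input offset and the CPU/CUDA device fallback are illustrative assumptions, not repo code.

```python
import torch

# Inlined copies of the new modules from backends/module_store.py,
# so the sketch runs without the repo on PYTHONPATH.
class LogOp(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_tensors):
        result = torch.log(input_tensors)
        return result


class SqrtOp(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_tensors):
        result = torch.sqrt(input_tensors)
        return result


device = "cuda" if torch.cuda.is_available() else "cpu"

# Shapes/dtypes mirror workloads/log.json and workloads/sqrt.json:
# [batch, 8192] with float32 / bfloat16 / half.
for dtype in (torch.float32, torch.bfloat16, torch.half):
    # Small positive offset keeps inputs in the valid domain of log/sqrt.
    x = (torch.rand(4, 8192) + 1e-3).to(device=device, dtype=dtype)
    assert torch.equal(LogOp()(x), torch.log(x))
    assert torch.equal(SqrtOp()(x), torch.sqrt(x))

print("LogOp/SqrtOp forward matches torch.log/torch.sqrt")
```

Once merged, both operators should also be runnable through the normal workload path, since `parse_task` in `launch.py` now routes `log` and `sqrt` into `unary_ops`.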