Skip to content

Commit

Permalink
Merge pull request bytedance#102 from bytedance/gyj/add_log_sqrt_op
Browse files Browse the repository at this point in the history
[micro perf]: add log and sqrt op
  • Loading branch information
suisiyuan authored Sep 25, 2024
2 parents 1113caa + 16093a1 commit c3d23ad
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 1 deletion.
5 changes: 5 additions & 0 deletions byte_micro_perf/backends/GPU/backend_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def initialize_ccl(self, rank, world_size):
self.setup_2d_group()
return True

def log(self):
    """Select the natural-logarithm benchmark op for this backend.

    Stores a fresh ``LogOp`` (defined in ``module_store``) on ``self.op``;
    the perf harness presumably invokes ``self.op`` later — confirm against
    the runner.
    """
    self.op = LogOp()

def sqrt(self):
    """Select the square-root benchmark op for this backend.

    Stores a fresh ``SqrtOp`` (defined in ``module_store``) on ``self.op``;
    the perf harness presumably invokes ``self.op`` later — confirm against
    the runner.
    """
    self.op = SqrtOp()

def setup_2d_group(self):
# get rank and set device
Expand Down
5 changes: 5 additions & 0 deletions byte_micro_perf/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def swiglu(self):
def cast(self):
    """No-op base hook for the ``cast`` unary op.

    Concrete backends (e.g. the GPU backend in this change) override the
    sibling hooks to assign ``self.op``; this default deliberately does
    nothing so unsupported backends simply skip the op.
    """
    pass

def log(self):
    """No-op base hook for the ``log`` unary op.

    Overridden by concrete backends (e.g. ``backend_gpu`` assigns
    ``self.op = LogOp()``); the default deliberately does nothing.
    """
    pass

def sqrt(self):
    """No-op base hook for the ``sqrt`` unary op.

    Overridden by concrete backends (e.g. ``backend_gpu`` assigns
    ``self.op = SqrtOp()``); the default deliberately does nothing.
    """
    pass

# binary ops
def add(self):
Expand Down
18 changes: 18 additions & 0 deletions byte_micro_perf/backends/module_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,24 @@ def forward(self, input_tensors):
return result


class LogOp(torch.nn.Module):
    """Elementwise natural logarithm, wrapped as a module for micro-perf runs.

    ``forward`` takes a single tensor (named ``input_tensors`` to match the
    sibling unary ops in this file) and returns ``torch.log`` of it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensors):
        # Delegate straight to the elementwise kernel; no intermediate needed.
        return torch.log(input_tensors)


class SqrtOp(torch.nn.Module):
    """Elementwise square root, wrapped as a module for micro-perf runs.

    ``forward`` takes a single tensor (named ``input_tensors`` to match the
    sibling unary ops in this file) and returns ``torch.sqrt`` of it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensors):
        # Delegate straight to the elementwise kernel; no intermediate needed.
        return torch.sqrt(input_tensors)


class AddOp(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion byte_micro_perf/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def parse_task(task_dir):
if task in ["gemm", "gemv", "batch_gemm", "group_gemm"]:
task_mapping["gemm_ops"].append(task)

if task in ["sin", "cos", "exp", "exponential", "silu", "gelu", "swiglu", "cast"]:
if task in ["sin", "cos", "exp", "exponential", "silu", "gelu", "swiglu", "cast", "log", "sqrt"]:
task_mapping["unary_ops"].append(task)

if task in ["add", "mul", "sub", "div"]:
Expand Down
17 changes: 17 additions & 0 deletions byte_micro_perf/workloads/log.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"operator": "log",
"iterations": 100,
"input_shape_groups": {
"inputs": [
[
[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072],
[8192]
]
]
},
"dtype": [
"float32",
"bfloat16",
"half"
]
}
17 changes: 17 additions & 0 deletions byte_micro_perf/workloads/sqrt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"operator": "sqrt",
"iterations": 100,
"input_shape_groups": {
"inputs": [
[
[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072],
[8192]
]
]
},
"dtype": [
"float32",
"bfloat16",
"half"
]
}

0 comments on commit c3d23ad

Please sign in to comment.