feat(micro): support more ops cast, silu, swiglu, div, mul, sub, gemv, reducemax, reducemin, reducesum, p2p; modify workloads
YJessicaGao committed Jul 12, 2024
1 parent 51bd4a7 commit f17657a
Showing 42 changed files with 1,067 additions and 1,194 deletions.
5 changes: 1 addition & 4 deletions .gitignore
@@ -25,7 +25,4 @@ init_env.sh

 byte_infer_perf/llm_perf/download
 byte_infer_perf/llm_perf/model_zoo/sota
-byte_infer_perf/llm_perf/reports
-
-out/
-*.db
+byte_infer_perf/llm_perf/reports
19 changes: 10 additions & 9 deletions byte_micro_perf/README.md
@@ -46,25 +46,26 @@ Example:
"Operator": "EXP",
"Backend": "GPU",
"Host Info": "Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz",
"Device Info": "A100-PCIE-40GB",
"Device Info": "NVIDIA A800-SXM4-80GB",
"Performance": [
{
"Dtype": "float32",
"Tensor Shapes": [
[
2,
512,
512
256,
8192
]
],
"Memory Size(MB)": 4.0,
"Kernel bandwidth(GB/s)": 271.83,
"Bandwidth Utilization(%)": 0.17,
"Avg latency(us)": 15.43
"Read IO Size(MB)": 8.0,
"Write IO Size(MB)": 8.0,
"Memory Size(MB)": 16.0,
"Kernel bandwidth(GB/s)": 1790.52,
"Bandwidth Utilization(%)": 87.81,
"Avg latency(us)": 9.37,
"QPS": 27321.24
}
]
}
```
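The new fields in this example are mutually consistent: the kernel bandwidth equals the total memory traffic divided by the average latency, and the utilization compares that figure against the device's peak. A quick sanity check in Python, assuming MB here means MiB and an A800-SXM4-80GB peak memory bandwidth of about 2039 GB/s (neither assumption is stated in the diff):

```python
# Verify the example report's derived metrics.
# Assumptions (not stated in the diff): MB = 2**20 bytes, and the
# A800-SXM4-80GB peak memory bandwidth is ~2039 GB/s.
memory_size_mb = 16.0   # Read IO 8.0 MB + Write IO 8.0 MB
avg_latency_us = 9.37
peak_bw_gbps = 2039.0   # assumed device peak

kernel_bw_gbps = memory_size_mb * 2**20 / (avg_latency_us * 1e-6) / 1e9
utilization_pct = 100 * kernel_bw_gbps / peak_bw_gbps

print(round(kernel_bw_gbps, 2))   # ~1790.52, matches "Kernel bandwidth(GB/s)"
print(round(utilization_pct, 2))  # ~87.81, matches "Bandwidth Utilization(%)"
```

How "QPS" is derived is not shown in this hunk, so it is left unverified here.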

## Trouble Shooting
113 changes: 78 additions & 35 deletions byte_micro_perf/backends/GPU/backend_gpu.py
@@ -56,17 +56,6 @@ def get_backend_properties(self):
         )
 
 
-    # gemm ops
-    def gemm(self):
-        self.op = GPUGemmOp()
-
-    def batch_gemm(self):
-        self.op = BatchGemmOp()
-
-    def group_gemm(self):
-        self.op = GPUGroupGemmOp()
-
-
     # device/host ops
     def host2device(self):
         self.op = Host2DeviceOp(torch.device("cuda"))
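Each micro-benchmark op is exposed as a setter method like `host2device` above, and the new op names in this commit (`gemv`, `swiglu`, `p2p`, ...) follow the same convention. A plausible sketch of how a harness could dispatch a workload's operator name to these setters (hypothetical glue code, not part of this commit):

```python
# Hypothetical dispatch glue: look up the setter named after the workload's
# "Operator" field (e.g. "gemv", "swiglu", "p2p") and call it so that
# backend.op ends up holding the matching *Op instance.
def select_op(backend, op_name: str) -> None:
    setter = getattr(backend, op_name.lower(), None)
    if setter is None or not callable(setter):
        raise NotImplementedError(f"unsupported op: {op_name}")
    setter()
```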
@@ -75,7 +64,6 @@ def device2host(self):
         self.op = Device2HostOp()
 
 
-
     # communication ops
     def allreduce(self):
         self.setup_2d_group()
@@ -97,12 +85,12 @@ def broadcast(self):
         self.setup_2d_group()
         self.op = BroadcastOp(self.group)
 
+    def p2p(self):
+        self.setup_2d_group()
+        self.op = P2POp(self.group, self.ranks, self.rank)
+
 
-    # other compute ops
-    def add(self):
-        self.op = AddOp()
-
+    # compute ops
+    # unary ops
     def sin(self):
         self.op = SinOp()

@@ -115,37 +103,87 @@ def exp(self):
     def exponential(self):
         self.op = ExponentialOp()
 
+    def silu(self):
+        self.op = SiluOp()
+
     def gelu(self):
         self.op = GeluOp()
 
+    def swiglu(self):
+        self.op = SwiGLUOp()
+
+    def cast(self):
+        self.op = CastOp()
+
+
+    # binary ops
+    def add(self):
+        self.op = AddOp()
+
+    def mul(self):
+        self.op = MulOp()
+
+    def sub(self):
+        self.op = SubOp()
+
+    def div(self):
+        self.op = DivOp()
+
+
+    # reduce ops
+    def layernorm(self):
+        self.op = LayerNormOp()
+
+    def softmax(self):
+        self.op = SoftmaxOp()
+
+    def reduce_sum(self):
+        self.op = ReduceSumOp()
+
+    def reduce_min(self):
+        self.op = ReduceMinOp()
+
+    def reduce_max(self):
+        self.op = ReduceMaxOp()
+
+
+    # index ops
+    def index_add(self):
+        self.op = IndexAddOp()
+
     def sort(self):
         self.op = SortOp()
 
     def unique(self):
         self.op = UniqueOp()
 
-    def indexadd(self):
-        self.op = IndexAddOp()
-
-    def softmax(self):
-        self.op = SoftmaxOp()
+    # gemm ops
+    def gemm(self):
+        self.op = GPUGemmOp()
 
-    def layernorm(self):
-        self.op = LayerNormOp()
+    def gemv(self):
+        self.op = GPUGemmOp()
+
+    def batch_gemm(self):
+        self.op = GPUBatchGemmOp()
+
+    def group_gemm(self):
+        self.op = GPUGroupGemmOp()
+
+
 
     # create input tensors
-    def build_tensor(self, input_shapes, torch_dtype):
+    def build_tensor(self, input_shapes, dtype):
         torch.cuda.empty_cache()
+        torch_dtype = getattr(torch, dtype)
 
         # compute size of input and output tensors
         if hasattr(self.op, "compute_size"):
-            bytes_per_cnt = self.op.compute_size(input_shapes, torch_dtype)
+            bytes_per_cnt = self.op.compute_size(input_shapes, dtype)
         # default: input_tensors_size == output_tensor_size, all tensors have same dtype
         else:
-            dtype_size = get_dtype_bytes(torch_dtype)
+            dtype_size = get_dtype_bytes(dtype)
             element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
             bytes_per_cnt = dtype_size * element_num
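The default branch above doubles the input element count to budget for an equally sized output. A worked check using the README example's shape `[256, 8192]` in float32 (the shape choice is taken from the README diff, not from this file):

```python
import math

# Default size estimate from build_tensor: input elements are doubled to
# account for an equally sized output, all tensors sharing one dtype.
input_shapes = [[256, 8192]]   # shape from the README example
dtype_size = 4                 # bytes per float32 element

element_num = 2 * sum(math.prod(shape) for shape in input_shapes)
bytes_per_cnt = dtype_size * element_num
print(bytes_per_cnt / 2**20)   # 16.0 MiB per iteration's tensor set
```

`build_tensor` then divides the free device memory by this figure to cap how many iterations' worth of tensors it preallocates.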

@@ -154,20 +192,25 @@ def build_tensor(self, input_shapes, torch_dtype):
         avail_cnts = avail_bytes // bytes_per_cnt
         max_data_cnt = min(self.iterations, avail_cnts)
 
+
         # create input tensors for each op
         input_tensors_list = []
         for _ in range(max_data_cnt):
             # create input tensors
             if hasattr(self.op, "custom_create_tensors"):
-                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype)
+                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda")
                 input_tensors_list.append(input_tensors)
             # default: all input tensors have same dtype
             else:
-                input_tensors = [
-                    torch.randint(0, 3, size=shape).type(torch_dtype).to(torch.device("cuda"))
-                    for shape in input_shapes
-                ]
+                if torch_dtype in [torch.int8, torch.int32]:
+                    input_tensors = [
+                        torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda")
+                        for shape in input_shapes
+                    ]
+                else:
+                    input_tensors = [
+                        torch.randn(shape, dtype=torch_dtype, device="cuda")
+                        for shape in input_shapes
+                    ]
                 input_tensors_list.append(input_tensors)
         if hasattr(self.op, "process_inputs"):
             input_tensors_list = [
@@ -225,9 +268,9 @@ def setup_2d_group(self):
         origin_store_based_barrier = dist_c10d._store_based_barrier
         dist_c10d._store_based_barrier = lambda *a, **kw: None
         self.world_size = dist.get_world_size()
-        ranks = range(0, self.world_size)
-        group = dist.new_group(ranks)
-        if self.rank in ranks:
+        self.ranks = range(0, self.world_size)
+        group = dist.new_group(self.ranks)
+        if self.rank in self.ranks:
             self.group = group
         dist_c10d._store_based_barrier = origin_store_based_barrier
         # wait for all ranks finish group initializing
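The change above promotes `ranks` to `self.ranks` so the new `P2POp` can pick a peer rank. The op's implementation is not part of this page; a minimal sketch of the send/recv pattern such a point-to-point op might use with `torch.distributed` (pairing scheme assumed, not taken from the repo):

```python
import torch
import torch.distributed as dist

# Hypothetical P2P op body: each even rank sends a tensor to the next odd
# rank, which receives it. The actual P2POp in the repo may pair ranks
# differently; this only illustrates the send/recv pattern.
def p2p_once(tensor: torch.Tensor, rank: int, world_size: int) -> None:
    peer = rank + 1 if rank % 2 == 0 else rank - 1
    if peer < 0 or peer >= world_size:
        return  # no peer for this rank
    if rank % 2 == 0:
        dist.send(tensor, dst=peer)
    else:
        dist.recv(tensor, src=peer)
```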
[diff truncated: the remaining changed files are not shown]
