Merge pull request bytedance#91 from bytedance/xuzhenglin/fix_gemm
Fix gemm: fix workloads of batch/group gemm; add task_dir in perf_engine; multiply qps by 1000
YJessicaGao authored Aug 19, 2024
2 parents b028576 + da12e24 commit fe25ade
Showing 6 changed files with 25 additions and 39 deletions.
2 changes: 1 addition & 1 deletion byte_micro_perf/backends/utils.py
@@ -174,7 +174,7 @@ def dump_computation_ops_report(
     batch_size, total_io_amount, read_io_amount, write_io_amount = get_io_amount(op_name, input_shapes, dtype)

     if error == "":
-        qps = round(1000 / latency * batch_size, 2)
+        qps = round(1e6 / latency * batch_size, 2)
         algo_bw = total_io_amount / latency / 1e3

         bandwidth_utils = None
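The fix above changes the latency-to-QPS conversion factor from 1000 to 1e6, which is consistent with `latency` being measured in microseconds (note that `algo_bw` divides by 1e3 on the following line); the old factor effectively treated microseconds as milliseconds, underreporting QPS by 1000x. A minimal sketch of the corrected arithmetic, using a made-up latency value:

    # Sketch of the corrected QPS formula, assuming `latency` is the measured
    # per-iteration latency in microseconds (the value below is hypothetical).
    latency = 250.0   # us per iteration
    batch_size = 32

    qps_old = round(1000 / latency * batch_size, 2)  # 128.0    -- treats us as ms
    qps_new = round(1e6 / latency * batch_size, 2)   # 128000.0 -- 1e6 us per second
    print(qps_old, qps_new)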
24 changes: 13 additions & 11 deletions byte_micro_perf/core/perf_engine.py
@@ -47,6 +47,11 @@ def get_args():
         default="gemm",
         help="The task going to be evaluated, refs to workloads/",
     )
+    parser.add_argument(
+        "--task_dir",
+        default=os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/workloads",
+        help="The directory of tasks going to be evaluated, e.g., set to workloads"
+    )
     parser.add_argument(
         "--hardware_type",
         default="GPU",
@@ -65,14 +70,14 @@ def get_args():
     return args


-def load_workload(task: str) -> Dict[str, Any]:
+def load_workload(task: str, task_dir: str) -> Dict[str, Any]:
     """
     Return a list of dictionaries with the model configuration
     Args: List[str]
     Returns: List[dict]
     """
     modules_dir = (
-        os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/workloads"
+        task_dir
     )

     for file in os.listdir(modules_dir):
@@ -83,8 +88,7 @@ def load_workload(task: str) -> Dict[str, Any]:
             and (file.endswith(".json") or os.path.isdir(path))
             and file[: file.find(".json")] == task
         ):
-            module_name = file
-            with open("workloads/" + module_name, "r") as f:
+            with open(path, "r") as f:
                 workload_dict = json.load(f)
             return workload_dict
     else:
@@ -160,15 +164,14 @@ def parse_workload(workload):
     # group gemm
     elif "gemm_group" in input_shape_group:
         groups = input_shape_group.get("gemm_group", [])
+        batches = input_shape_group.get("batch", [])
         kn = input_shape_group.get("KN", [])
         if k and n:
             kn.append([list(shape) for shape in itertools.product(k, n)])
         for group in groups:
-            for _kn in kn:
-                group_input_shape_list = []
-                for m in group:
-                    group_input_shape_list.append([[m, _kn[0]], [_kn[0], _kn[1]]])
-                shape_list.append(group_input_shape_list)
+            for batch in batches:
+                for _kn in kn:
+                    shape_list.append([[[group * batch, _kn[0]], [_kn[0], _kn[1]]]])
     # gemm
     else:
         if m and n and k:
@@ -190,7 +193,7 @@ class PerfEngine:
     def __init__(self) -> None:
         super().__init__()
         self.args = get_args()
-        self.workload = load_workload(self.args.task)
+        self.workload = load_workload(self.args.task, self.args.task_dir)
         self.backend_type = self.args.hardware_type
         self.old_os_path = os.environ["PATH"]
         self.prev_sys_path = list(sys.path)
@@ -280,7 +283,6 @@ def start_perf(self, workload: Dict[str, Any]) -> bool:

         # dtype list
         dtype_list = self.workload["dtype"]
-
         for dtype in dtype_list:
             perf_reports = []
             base_report["Performance"] = {}
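The reworked group-gemm branch of `parse_workload` now sweeps the full cross product of `gemm_group`, `batch`, and `KN`, emitting one `[group * batch, K] x [K, N]` matmul per combination instead of building a per-group list of per-M shapes. A standalone sketch of the new expansion (names mirror the diff; the inputs are a subset of `workloads/group_gemm.json`):

    # Standalone sketch of the new group-gemm shape expansion.
    input_shape_group = {
        "gemm_group": [1, 2, 4, 8],          # subset of group_gemm.json
        "batch": [1, 2, 4],
        "KN": [[32, 16384], [16384, 32]],
    }

    shape_list = []
    groups = input_shape_group.get("gemm_group", [])
    batches = input_shape_group.get("batch", [])
    kn = input_shape_group.get("KN", [])
    for group in groups:
        for batch in batches:
            for _kn in kn:
                # One [group * batch, K] x [K, N] matmul per combination.
                shape_list.append([[[group * batch, _kn[0]], [_kn[0], _kn[1]]]])

    print(len(shape_list))   # 4 * 3 * 2 = 24 workloads
    print(shape_list[0])     # [[[1, 32], [32, 16384]]]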
4 changes: 2 additions & 2 deletions byte_micro_perf/launch.py
@@ -133,8 +133,8 @@ def parse_task(task_dir):
     )

     for task in tasks:
-        cmd = "python3 core/perf_engine.py --hardware_type {} --task {} --vendor_path {}".format(
-            args.hardware_type, task, args.vendor_path
+        cmd = "python3 core/perf_engine.py --hardware_type {} --task {} --vendor_path {} --task_dir {}".format(
+            args.hardware_type, task, args.vendor_path, args.task_dir
         )
         exit_code = subprocess.call(cmd, shell=True)
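With the extra flag appended, the per-task command that `launch.py` shells out now carries the workload directory through to `perf_engine.py`. Rendered with illustrative argument values:

    # Rendering of the command template from launch.py; the argument values
    # ("GPU", "gemm", and the two paths) are illustrative.
    cmd = "python3 core/perf_engine.py --hardware_type {} --task {} --vendor_path {} --task_dir {}".format(
        "GPU", "gemm", "/path/to/vendor", "byte_micro_perf/workloads"
    )
    print(cmd)
    # python3 core/perf_engine.py --hardware_type GPU --task gemm --vendor_path /path/to/vendor --task_dir byte_micro_perf/workloads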
27 changes: 5 additions & 22 deletions byte_micro_perf/workloads/batch_gemm.json
@@ -1,28 +1,11 @@
 {
     "operator": "batch_gemm",
     "iterations": 100,
-    "input_shape_groups": [
-        {
-            "batch_size": [4, 8, 16, 32, 64, 128, 256, 512, 1024],
-            "MN": [[1, 1], [1, 1024], [1, 2048], [1, 4096]],
-            "K": [128, 256, 512]
-        },
-        {
-            "batch_size": [4, 8, 16, 32, 64, 128, 256],
-            "MN": [[1, 8192], [1, 16384], [1, 32768], [1, 65536], [1, 131072]],
-            "K": [128, 256, 512]
-        },
-        {
-            "batch_size": [1, 2, 4, 8, 16, 32],
-            "MN": [[1, 1], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192]],
-            "K": [128, 256, 512]
-        },
-        {
-            "batch_size": [1, 2, 4],
-            "MN": [[16384, 16384], [32768, 32768], [65536, 65536], [131072, 131072]],
-            "K": [128, 256, 512]
-        }
-    ],
+    "input_shape_groups": {
+        "batch_size": [8, 12, 16, 20, 24, 28, 32, 36],
+        "M": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768],
+        "KN": [[1024, 1024], [4096, 4096], [8192, 8192], [12288, 12288], [16384, 32], [32, 16384], [16384, 1024], [1024, 16384]]
+    },
     "dtype": [
         "float32",
         "bfloat16",
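The batch_gemm workload is flattened from four ad-hoc shape groups into a single dict of `batch_size`, `M`, and `KN` lists. The batch_gemm branch of `parse_workload` is not part of this diff, but assuming it sweeps the cross product the way the group-gemm branch above does, the expansion would look roughly like:

    # Hypothetical expansion of the new batch_gemm input_shape_groups; the
    # actual parse_workload branch for batch_gemm is not shown in this diff.
    import itertools

    batch_sizes = [8, 12, 16, 20, 24, 28, 32, 36]
    ms = [4, 8, 16, 32]                          # subset of the "M" list
    kns = [[1024, 1024], [16384, 32], [32, 16384]]

    shapes = [
        [[b, m, k], [b, k, n]]                   # A: [b, m, k], B: [b, k, n]
        for b, m, (k, n) in itertools.product(batch_sizes, ms, kns)
    ]
    print(len(shapes))   # 8 * 4 * 3 = 96 batched GEMM workloads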
2 changes: 1 addition & 1 deletion byte_micro_perf/workloads/gemm.json
@@ -3,7 +3,7 @@
     "iterations": 100,
     "input_shape_groups": {
         "M": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072],
-        "KN": [[1024, 1024], [16384, 1024], [16384, 32], [1024, 16384], [4096, 4096], [8192, 8192], [12288, 12288]]
+        "KN": [[1024, 1024], [16384, 1024], [16384, 32], [32, 16384], [1024, 16384], [4096, 4096], [8192, 8192], [12288, 12288]]
     },
     "dtype": [
         "float32",
5 changes: 3 additions & 2 deletions byte_micro_perf/workloads/group_gemm.json
@@ -2,8 +2,9 @@
     "operator": "group_gemm",
     "iterations": 100,
     "input_shape_groups": {
-        "gemm_group": [[1, 16, 32, 64, 128, 256, 512, 1024]],
-        "KN": [[4096, 4096], [7168, 7168], [16384, 16384]]
+        "gemm_group": [1, 2, 3, 4, 5, 6, 7, 8],
+        "batch": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
+        "KN": [[32, 16384], [16384, 32], [16384, 16384]]
     },
     "dtype": [
         "float32",
