diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py index b20d0912..24af158e 100644 --- a/byte_micro_perf/core/perf_engine.py +++ b/byte_micro_perf/core/perf_engine.py @@ -249,6 +249,8 @@ def start_engine(self) -> None: output_dir.mkdir(parents=True, exist_ok=True) + op_name = self.workload["operator"] + # get input shape info target_group_list = self.workload.get("group", [1]) target_group_list.sort() @@ -304,6 +306,8 @@ def signal_handler(signum, frame): # get actual instance num instance_num = min(device_count, max(1, self.args.parallel)) if group == 1 else group + if group == 1 and op_name in ["host2device", "device2host"]: + instance_num = 1 input_queues = mp.Queue() output_queues = mp.Queue(maxsize=1)