Merge pull request bytedance#94 from bytedance/jzs/support_multi_card_perf

support multi-card perf for micro_perf.
suisiyuan authored Aug 22, 2024
2 parents 9e4c3c7 + fbe99ad commit aca559b
Showing 7 changed files with 354 additions and 186 deletions.
44 changes: 34 additions & 10 deletions byte_micro_perf/backends/GPU/backend_gpu.py
@@ -35,6 +35,26 @@


class BackendGPU(Backend):

def get_device_count(self):
return torch.cuda.device_count()

def set_device(self, device_index):
torch.cuda.set_device(device_index)

def get_device(self):
return torch.cuda.current_device()

def all_gather_object(self, obj):
gather_object_list = [None for _ in range(self.world_size)]
dist.all_gather_object(
object_list=gather_object_list,
obj=obj,
group=self.group
)
return gather_object_list


def get_device_name(self):
return torch.cuda.get_device_name(0)
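
The all_gather_object helper above wraps dist.all_gather_object so that every rank can contribute an arbitrary picklable object (for example, its per-device measurement dict) and receive the full list back for aggregation. The standalone sketch below is not part of this diff: it shows the same pattern on CPU with the gloo backend, and the worker function, port, and fake latency record are illustrative assumptions.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # minimal process-group setup so the example runs on CPU without NCCL/GPUs
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # each rank contributes one picklable object, e.g. a fake latency record
    local_result = {"rank": rank, "latency_us": 100.0 + rank}

    gather_object_list = [None for _ in range(world_size)]
    dist.all_gather_object(gather_object_list, local_result)

    if rank == 0:
        # rank 0 now holds one entry per rank and can aggregate them
        print(gather_object_list)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)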

@@ -234,16 +254,14 @@ def device_synchronize(self):
torch.cuda.synchronize()
return True


def initialize_ccl(self, rank, world_size):
"""
initialize distributed process groups and relevant ENVs
"""
# check device_count
device_count = torch.cuda.device_count()
if world_size > device_count:
world_size = device_count
if rank >= world_size:
return False

# set cuda device
torch.cuda.set_device(rank)

# set envs
os.environ["MASTER_ADDR"] = "127.0.0.1"
@@ -252,31 +270,37 @@ def setup_2d_group(self):
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)

torch.cuda.set_device(rank)

# Call the init process
timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
torch.distributed.init_process_group(
backend="nccl",
world_size=world_size,
rank=rank,
store=None,
timeout=timedelta(seconds=timeout_seconds),
)

# create group
self.setup_2d_group()
log.warning("DIST: rank {}, world_size {}".format(rank, world_size))

log.info(f"DIST: rank {rank}, world_size {world_size}")
return True


def setup_2d_group(self):
# get rank and set device
self.rank = dist.get_rank()
torch.cuda.set_device(self.rank)

origin_store_based_barrier = dist_c10d._store_based_barrier
dist_c10d._store_based_barrier = lambda *a, **kw: None

self.world_size = dist.get_world_size()
self.ranks = range(0, self.world_size)
group = dist.new_group(self.ranks)
if self.rank in self.ranks:
self.group = group

dist_c10d._store_based_barrier = origin_store_based_barrier

# wait for all ranks finish group initializing
torch.distributed.barrier()
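
In the new flow, initialize_ccl binds each rank to its own GPU with torch.cuda.set_device(rank) before creating the NCCL process group, and setup_2d_group ends with a barrier so no rank starts benchmarking before group creation finishes everywhere. The sketch below is a hedged illustration of the spawn-one-process-per-device driver pattern these hooks support, not the repo's actual perf engine; perf_worker, the port, and the world_size cap are assumptions, and it needs a machine with at least two CUDA devices.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def perf_worker(rank, world_size):
    # mirror initialize_ccl: one process per GPU, all joined into one NCCL group
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "49373"
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # ... run and time the op under test on this device ...

    dist.barrier()  # like the barrier at the end of setup_2d_group
    dist.destroy_process_group()


if __name__ == "__main__":
    # clamp world_size to the number of visible devices, as the
    # device_count check shown earlier in this file does
    world_size = min(8, torch.cuda.device_count())
    if world_size >= 2:
        mp.spawn(perf_worker, args=(world_size,), nprocs=world_size)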
3 changes: 2 additions & 1 deletion byte_micro_perf/backends/GPU/requirements.txt
@@ -1,2 +1,3 @@
-torch==2.1.0
+-i https://download.pytorch.org/whl/cu118
+torch==2.1.2
 nvidia-cutlass
18 changes: 18 additions & 0 deletions byte_micro_perf/backends/backend.py
@@ -42,6 +42,24 @@ def __init__(self, workload_dict: Dict[str, Any], vendor_path: str):

self.target_dtype = None

@abstractmethod
def get_device_count(self):
pass

@abstractmethod
def set_device(self, device_index):
pass

@abstractmethod
def get_device(self):
pass

@abstractmethod
def all_gather_object(self, obj):
pass



@abstractmethod
def get_device_name(self):
pass
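
Taken together, these abstract methods mean every vendor backend now has to expose basic device management (count, select, query) plus an object all-gather for collecting per-rank results. The toy sketch below is illustrative only: it re-declares just the four new methods in a standalone interface and implements them for a fake single-device CPU backend, instead of subclassing the repo's real Backend class (which has other abstract methods not shown here).

from abc import ABC, abstractmethod
from typing import Any, List


class DeviceInterface(ABC):
    # standalone stand-in for the four new abstract methods on Backend
    @abstractmethod
    def get_device_count(self) -> int: ...

    @abstractmethod
    def set_device(self, device_index: int) -> None: ...

    @abstractmethod
    def get_device(self) -> int: ...

    @abstractmethod
    def all_gather_object(self, obj: Any) -> List[Any]: ...


class FakeSingleCpuBackend(DeviceInterface):
    """Toy implementation that pretends there is exactly one device and no peers."""

    def __init__(self):
        self._current = 0

    def get_device_count(self) -> int:
        return 1

    def set_device(self, device_index: int) -> None:
        self._current = device_index

    def get_device(self) -> int:
        return self._current

    def all_gather_object(self, obj: Any) -> List[Any]:
        # with a single "rank" there is nothing to gather from peers
        return [obj]


if __name__ == "__main__":
    backend = FakeSingleCpuBackend()
    print(backend.get_device_count(), backend.all_gather_object({"latency_us": 1.0}))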
(Diff truncated here; the remaining 4 changed files are not shown.)