From 567e5d248db99c7c28a075896ec5e60051f6d7d7 Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Mon, 9 Oct 2023 22:15:30 +0000
Subject: [PATCH] fix distributed init logic for torchbench dynamo runner
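
tp.py previously read LOCAL_RANK / LOCAL_WORLD_SIZE from the
environment and called dist.init_process_group() itself through
maybe_init_dist(). That implicit init does not compose with the dynamo
runner, which manages distributed setup on its own, so apply_tp() now
takes the rank and world size explicitly from the caller. While here,
setup_caches() runs under `with torch.device(device)` so the KV caches
are allocated directly on the target device, and forward() is wrapped
in torch.no_grad() so the benchmark does not build autograd state.

Callers are now responsible for initializing the process group before
calling apply_tp(). A minimal sketch of the expected setup, modeled on
the removed maybe_init_dist() helper and assuming torchrun-provided
environment variables (illustrative only, not part of this diff):

    import os
    import torch
    import torch.distributed as dist

    # torchrun exports these for every worker it launches
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = LLaMA.from_name("7B")
    # apply_tp asserts 0 <= rank < 8 and 1 < world_size <= 8
    apply_tp(model, rank, world_size)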
---
 .../models/simple_gpt_tp_manual/__init__.py | 10 +--
 .../models/simple_gpt_tp_manual/model.py    |  1 +
 .../models/simple_gpt_tp_manual/tp.py       | 63 +++++++++----------
 3 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/torchbenchmark/models/simple_gpt_tp_manual/__init__.py b/torchbenchmark/models/simple_gpt_tp_manual/__init__.py
index bf1abed422..905fb20d75 100644
--- a/torchbenchmark/models/simple_gpt_tp_manual/__init__.py
+++ b/torchbenchmark/models/simple_gpt_tp_manual/__init__.py
@@ -47,17 +47,17 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
         # temporary workarounds
         torch._inductor.config.allow_buffer_reuse = False
         torch._inductor.config.inplace_buffers = False
-        torch.cuda.set_device(self._rank)
 
         model = LLaMA.from_name("7B")
 
         print("Applying tensor parallel to model ...")
-        apply_tp(model)
+        apply_tp(model, self._rank, self._world_size)
 
         max_batch_size = self.batch_size
-        model.setup_caches(
-            max_batch_size=max_batch_size, max_seq_length=model.config.block_size
-        )
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=max_batch_size, max_seq_length=model.config.block_size
+            )
 
         self.model = model.to(device=device, dtype=torch.bfloat16)
 
diff --git a/torchbenchmark/models/simple_gpt_tp_manual/model.py b/torchbenchmark/models/simple_gpt_tp_manual/model.py
index f3e66aa559..023a9922d2 100644
--- a/torchbenchmark/models/simple_gpt_tp_manual/model.py
+++ b/torchbenchmark/models/simple_gpt_tp_manual/model.py
@@ -143,6 +143,7 @@ def _init_weights(self, module: nn.Module) -> None:
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
 
+    @torch.no_grad()
     def forward(
         self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
diff --git a/torchbenchmark/models/simple_gpt_tp_manual/tp.py b/torchbenchmark/models/simple_gpt_tp_manual/tp.py
index 32b7534476..74802e48d4 100644
--- a/torchbenchmark/models/simple_gpt_tp_manual/tp.py
+++ b/torchbenchmark/models/simple_gpt_tp_manual/tp.py
@@ -1,4 +1,3 @@
-import os
 from typing import Optional, List
 
 import torch
@@ -8,28 +7,16 @@
 
 from .model import LLaMA, CausalSelfAttention, MLP
 
-def _get_rank() -> int:
-    return int(os.environ.get("LOCAL_RANK", "0"))
-
+LOCAL_RANK = None
+LOCAL_WORLD_SIZE = None
 
-def _get_world_size() -> int:
-    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
-def maybe_init_dist() -> Optional[int]:
-    try:
-        # provided by torchrun
-        rank = _get_rank()
-        world_size = _get_world_size()
+def _get_rank() -> int:
+    return LOCAL_RANK
 
-        if world_size < 2:
-            # too few gpus to parallelize, tp is no-op
-            return None
-    except KeyError:
-        # not run via torchrun, no-op
-        return None
 
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
-    return rank
+def _get_world_size() -> int:
+    return LOCAL_WORLD_SIZE
 
 
 def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
@@ -47,24 +34,30 @@ def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = [
     # ensure we can shard evenly
     assert getattr(linear, size_attr) % world_size == 0
 
+    def shard(x, dim):
+        assert x.size(dim=dim) % world_size == 0
+        return torch.tensor_split(x, world_size, dim=dim)[rank]
+
+    def shard_qkv(qkv, dim):
+        q, k, v = qkv.split(weight_splits, dim=dim)
+        q = shard(q, dim)
+        k = shard(k, dim)
+        v = shard(v, dim)
+        return torch.cat((q,k,v))
+
     # shard
     if weight_splits:
         # attention
         assert len(weight_splits) == 3
-        q, k, v = linear.weight.split(weight_splits, dim=shard_dim)
-
-        assert q.size(dim=shard_dim) % world_size == 0
-        q = torch.tensor_split(q, world_size, dim=shard_dim)[rank]
-        assert k.size(dim=shard_dim) % world_size == 0
-        k = torch.tensor_split(k, world_size, dim=shard_dim)[rank]
-        assert v.size(dim=shard_dim) % world_size == 0
-        v = torch.tensor_split(v, world_size, dim=shard_dim)[rank]
-
-        sharded_weight = torch.cat((q,k,v))
+        sharded_weight = shard_qkv(linear.weight, shard_dim)
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard_qkv(linear.scales, 0)
     else:
-        sharded_weight = torch.tensor_split(linear.weight, world_size, dim=shard_dim)[rank]
+        sharded_weight = shard(linear.weight, shard_dim)
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard(linear.scales, 0)
 
     # overwrite
     linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
@@ -113,9 +106,15 @@ def _apply_tp_llama(llama: LLaMA) -> None:
     llama.config.n_head = llama.config.n_head // world_size
     llama.config.n_embd = llama.config.n_embd // world_size
     llama.config.n_query_groups = llama.config.n_query_groups // world_size
-
-def apply_tp(model: LLaMA) -> None:
+
+def apply_tp(model: LLaMA, rank: int, world_size: int) -> None:
+    global LOCAL_RANK, LOCAL_WORLD_SIZE
+    LOCAL_RANK = rank
+    LOCAL_WORLD_SIZE = world_size
+    assert LOCAL_RANK >= 0 and LOCAL_RANK < 8
+    assert LOCAL_WORLD_SIZE > 1 and LOCAL_WORLD_SIZE <= 8
+
     _apply_tp_llama(model)
 
     for block in model.transformer.h:
         # Apply to MLP