
Commit

fix distributed init logic for torchbench dynamo runner
xmfan committed Oct 9, 2023
1 parent 320111e commit 567e5d2
Showing 3 changed files with 37 additions and 37 deletions.
10 changes: 5 additions & 5 deletions torchbenchmark/models/simple_gpt_tp_manual/__init__.py
@@ -47,17 +47,17 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
         # temporary workarounds
         torch._inductor.config.allow_buffer_reuse = False
         torch._inductor.config.inplace_buffers = False
+        torch.cuda.set_device(self._rank)

         model = LLaMA.from_name("7B")

         print("Applying tensor parallel to model ...")
-        apply_tp(model)
+        apply_tp(model, self._rank, self._world_size)

         max_batch_size = self.batch_size
-        model.setup_caches(
-            max_batch_size=max_batch_size, max_seq_length=model.config.block_size
-        )
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=max_batch_size, max_seq_length=model.config.block_size
+            )

         self.model = model.to(device=device, dtype=torch.bfloat16)

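Note: the self._rank and self._world_size attributes read above are supplied by the benchmark harness rather than by tp.py's removed maybe_init_dist() helper. A minimal sketch of the caller-side setup this change assumes, with defaults and variable names that are illustrative rather than taken from the commit:

import os

import torch
import torch.distributed as dist

rank = int(os.environ.get("LOCAL_RANK", "0"))              # set by torchrun
world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))  # set by torchrun

if world_size > 1 and not dist.is_initialized():
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)  # mirrors the set_device call added in the diff above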
1 change: 1 addition & 0 deletions torchbenchmark/models/simple_gpt_tp_manual/model.py
@@ -143,6 +143,7 @@ def _init_weights(self, module: nn.Module) -> None:
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

+    @torch.no_grad()
     def forward(
         self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
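The added @torch.no_grad() decorator makes the benchmarked forward pass run without recording autograd history, so activation memory for a backward pass that never happens is not held. A tiny illustrative sketch (the function name is hypothetical):

import torch

@torch.no_grad()
def forward_like(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # ops executed here record no autograd history

y = forward_like(torch.ones(2, requires_grad=True))
assert y.requires_grad is False  # nothing to backprop through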
63 changes: 31 additions & 32 deletions torchbenchmark/models/simple_gpt_tp_manual/tp.py
@@ -1,4 +1,3 @@
-import os
 from typing import Optional, List

 import torch
@@ -8,28 +7,16 @@
 from .model import LLaMA, CausalSelfAttention, MLP


-def _get_rank() -> int:
-    return int(os.environ.get("LOCAL_RANK", "0"))
+LOCAL_RANK = None
+LOCAL_WORLD_SIZE = None

-def _get_world_size() -> int:
-    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
-
-def maybe_init_dist() -> Optional[int]:
-    try:
-        # provided by torchrun
-        rank = _get_rank()
-        world_size = _get_world_size()
+def _get_rank() -> int:
+    return LOCAL_RANK

-        if world_size < 2:
-            # too few gpus to parallelize, tp is no-op
-            return None
-    except KeyError:
-        # not run via torchrun, no-op
-        return None

-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
-    return rank
+def _get_world_size() -> int:
+    return LOCAL_WORLD_SIZE


 def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
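With this change, rank and world size are no longer read from torchrun's environment variables; they are module-level globals that apply_tp() (further down) must populate before _get_rank() / _get_world_size() are called. A minimal sketch of that contract, assuming the torchbenchmark package is importable:

import torchbenchmark.models.simple_gpt_tp_manual.tp as tp

# what apply_tp(model, 0, 2) records before doing any sharding
tp.LOCAL_RANK, tp.LOCAL_WORLD_SIZE = 0, 2
assert tp._get_rank() == 0 and tp._get_world_size() == 2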
@@ -47,24 +34,30 @@ def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:

     # ensure we can shard evenly
     assert getattr(linear, size_attr) % world_size == 0

+    def shard(x, dim):
+        assert x.size(dim=dim) % world_size == 0
+        return torch.tensor_split(x, world_size, dim=dim)[rank]
+
+    def shard_qkv(qkv, dim):
+        q, k, v = qkv.split(weight_splits, dim=dim)
+        q = shard(q, dim)
+        k = shard(k, dim)
+        v = shard(v, dim)
+        return torch.cat((q,k,v))
+
     # shard
     if weight_splits:
         # attention
         assert len(weight_splits) == 3

-        q, k, v = linear.weight.split(weight_splits, dim=shard_dim)
-
-        assert q.size(dim=shard_dim) % world_size == 0
-        q = torch.tensor_split(q, world_size, dim=shard_dim)[rank]
-        assert k.size(dim=shard_dim) % world_size == 0
-        k = torch.tensor_split(k, world_size, dim=shard_dim)[rank]
-        assert v.size(dim=shard_dim) % world_size == 0
-        v = torch.tensor_split(v, world_size, dim=shard_dim)[rank]
-
-        sharded_weight = torch.cat((q,k,v))
+        sharded_weight = shard_qkv(linear.weight, shard_dim)
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard_qkv(linear.scales, 0)
     else:
-        sharded_weight = torch.tensor_split(linear.weight, world_size, dim=shard_dim)[rank]
+        sharded_weight = shard(linear.weight, shard_dim)
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard(linear.scales, 0)


     # overwrite
     linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
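Both new helpers shard with torch.tensor_split, which cuts a dimension into world_size equal pieces and keeps only this rank's piece. A standalone toy example of the colwise case:

import torch

world_size, rank = 2, 0
weight = torch.arange(12.0).reshape(6, 2)   # toy colwise weight, out_features=6
assert weight.size(0) % world_size == 0     # the same even-shard check as shard()
local = torch.tensor_split(weight, world_size, dim=0)[rank]
print(local.shape)                          # torch.Size([3, 2]): this rank's rows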
@@ -113,9 +106,15 @@ def _apply_tp_llama(llama: LLaMA) -> None:
     llama.config.n_head = llama.config.n_head // world_size
     llama.config.n_embd = llama.config.n_embd // world_size
     llama.config.n_query_groups = llama.config.n_query_groups // world_size


-def apply_tp(model: LLaMA) -> None:
+def apply_tp(model: LLaMA, rank: int, world_size: int) -> None:
+    global LOCAL_RANK, LOCAL_WORLD_SIZE
+    LOCAL_RANK = rank
+    LOCAL_WORLD_SIZE = world_size
+    assert LOCAL_RANK >= 0 and LOCAL_RANK < 8
+    assert LOCAL_WORLD_SIZE > 1 and LOCAL_WORLD_SIZE <= 8
+
     _apply_tp_llama(model)
     for block in model.transformer.h:
         # Apply to MLP
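A hedged usage sketch of the new apply_tp signature, assuming a process group was already initialized by the runner (the asserts above restrict it to ranks 0-7 and world sizes 2-8):

import torch.distributed as dist

from torchbenchmark.models.simple_gpt_tp_manual.model import LLaMA
from torchbenchmark.models.simple_gpt_tp_manual.tp import apply_tp

model = LLaMA.from_name("7B")
apply_tp(model, dist.get_rank(), dist.get_world_size())  # world size must be in (1, 8]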
