
Merge pull request #174 from laekov/fix-test
Fix tests
laekov authored Sep 11, 2023
2 parents 3a41edb + 1aedcdb commit feeac05
Showing 8 changed files with 98 additions and 42 deletions.
3 changes: 1 addition & 2 deletions cuda/fastermoe/smart_schedule.h
@@ -338,7 +338,7 @@ void fmoe_cuda_fused_backward_impl(
collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert);
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
}
++si;
}
@@ -367,7 +367,6 @@ void fmoe_cuda_fused_backward_impl(
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
set_grad_fn(si, i % num_expert);
}
2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@
if __name__ == '__main__':
setuptools.setup(
name='fastmoe',
version='1.0.2',
version='1.1.0',
description='An efficient Mixture-of-Experts system for PyTorch',
author=', '.join(authors),
author_email='[email protected]',
4 changes: 2 additions & 2 deletions tests/moe.py
@@ -78,7 +78,7 @@ def forward(self, inp, gate_idx, gate_score):
class NaiveExpert(nn.Module):
def __init__(self, d_model):
super(NaiveExpert, self).__init__()
self.linear = nn.Linear(d_model, d_model).cuda()
self.linear = nn.Linear(d_model, d_model)

def forward(self, x, fec=None):
return self.linear(x)
@@ -89,7 +89,7 @@ def __init__(self, d_model):
super(LinearExpert, self).__init__()
self.model = nn.Sequential(
nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model),
).cuda()
)

def forward(self, x, fec=None):
return self.model(x)
2 changes: 1 addition & 1 deletion tests/test.sh
@@ -30,4 +30,4 @@ fi

export CUDA_VISIBLE_DEVICES=$localrank

exec $@
exec $@ 2>&1 | tee $RANK.log
42 changes: 35 additions & 7 deletions tests/test_ddp.py
@@ -4,6 +4,7 @@
import sys
from typing import Dict
import random
import socket as sock

import pytest
import torch
@@ -24,6 +25,8 @@ def _ensure_initialized():
dist.init_process_group(backend="nccl")


port_count = 0

def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
device_count = torch.cuda.device_count()
if device_count < world_size:
@@ -33,7 +36,9 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):

ps = []
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = str(random.randint(50000, 60000))
global port_count
env["MASTER_PORT"] = str(9010 + port_count)
port_count += 1
env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH")

@@ -58,7 +63,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_linear_distributed(
num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
@@ -83,7 +88,8 @@ def test_fmoe_linear_distributed(
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size, data_type):
_run_distributed(
"_test_fmoe",
mp_size * 2,
@@ -94,6 +100,7 @@ def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_siz
"d_model": d_model,
"expert": expert,
"mp_size": mp_size,
"data_type": data_type,
},
)

@@ -137,8 +144,29 @@ def test_fmoe_local_ddp(mp_size):
del args["mp_size"]
locals()[sys.argv[1]](**args)
else:
test_fmoe_local_ddp(mp_size=2)
test_fmoe_linear_distributed(
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
data_type="torch.HalfTensor"
torch.distributed.init_process_group(backend="nccl")
args = dict(mp_size=1, data_type='torch.float16')
args["rank"] = torch.distributed.get_rank()
args["world_size"] = torch.distributed.get_world_size()
args["mp_group"] = [
torch.distributed.new_group(
ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
backend="nccl",
)
for j in range(args["world_size"] // args["mp_size"])
][args["rank"] // args["mp_size"]]
args["dp_group"] = [
torch.distributed.new_group(
ranks=[
i * args["mp_size"] + j
for i in range(args["world_size"] // args["mp_size"])
],
backend="nccl",
)
for j in range(args["mp_size"])
][args["rank"] % args["mp_size"]]
args["world_group"] = torch.distributed.new_group(
ranks=list(range(args["world_size"])), backend="nccl",
)
del args["mp_size"]
_test_fmoe(4, 2, 16, 2, 'NaiveExpert', **args)
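Note on the MASTER_PORT change above: the launcher now hands each spawned group a deterministic port (9010 plus a module-level counter) instead of a random one in 50000-60000, so successive launches within one test session get distinct ports. The newly imported `socket as sock` is not used in the hunks shown here; a minimal sketch of how it could pick a genuinely free port, offered only as a hypothetical alternative, not as part of this commit:

# Hypothetical helper, not part of this commit: let the OS pick an unused TCP port.
import socket as sock

def _find_free_port() -> int:
    with sock.socket(sock.AF_INET, sock.SOCK_STREAM) as s:
        s.bind(("localhost", 0))      # port 0 asks the kernel for any free port
        return s.getsockname()[1]

# env["MASTER_PORT"] = str(_find_free_port())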
41 changes: 26 additions & 15 deletions tests/test_faster_schedule.py
@@ -18,7 +18,7 @@
@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("group_sz", [1, 2, 4])
def test_faster_schedule(n_process, d_model, batch_size, n_expert, group_sz):
_run_distributed('_test_faster_schedule',
@@ -45,28 +45,39 @@ def _test_faster_schedule(d_model, batch_size, n_expert):
x2 = x1.data.clone()
x2.requires_grad = True
topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
m1 = torch.nn.Linear(d_model, d_model).cuda()
m2 = torch.nn.Linear(d_model, d_model).cuda()
m1s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)]
m2s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)]
with torch.no_grad():
m2.weight.copy_(m1.weight)
m2.bias.copy_(m1.bias)
for m1, m2 in zip(m1s, m2s):
m2.weight.copy_(m1.weight)
m2.bias.copy_(m1.bias)

def ef1(x, fec, eidx):
return m1s[eidx](x)

def ef1(x, fec):
y = m1(x)
return y
def ef2(x, fec):
y = m2(x)
o = 0
ys = []
for m, i in zip(m2s, fec):
if i > 0:
ys.append(m(x[o:o + i]))
o += i
y = torch.cat(ys)
return y

ensure_comm(x1, None)
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size)
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1s)
y1.sum().backward()

y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size)
y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2s)
y2.sum().backward()
_assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
[y1, x1.grad, m1.bias.grad, m1.weight.grad],
[y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
_assert_numerical(['out', 'grad_in'],
[y1, x1.grad],
[y2, x2.grad], rank)
for i in range(n_expert):
_assert_numerical([f'grad_bias_{i}', f'grad_weight_{i}'],
[m1s[i].bias.grad, m1s[i].weight.grad],
[m2s[i].bias.grad, m2s[i].weight.grad], rank)


if __name__ == '__main__':
Expand All @@ -75,4 +86,4 @@ def ef2(x, fec):
locals()[sys.argv[1]](**args)
else:
# test_faster_schedule(8, 16, 16, 1, 2)
_test_faster_schedule(4, 2, 1)
_test_faster_schedule(4, 2, 4)
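The rewritten ef2 above dispatches a flattened batch to several local experts by walking `fec`, the per-expert token counts. A minimal standalone sketch of that slicing pattern (names and shapes here are illustrative assumptions, not part of the test):

# Sketch of the count-based slicing used in ef2 above; assumes torch and a list of
# expert modules. Tokens for expert k occupy a contiguous slice of length fec[k].
import torch

def apply_experts_by_count(x, fec, experts):
    outputs, offset = [], 0
    for expert, count in zip(experts, fec):
        if count > 0:
            outputs.append(expert(x[offset:offset + count]))
        offset += count
    return torch.cat(outputs)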
19 changes: 10 additions & 9 deletions tests/test_faster_shadow.py
@@ -20,7 +20,7 @@
@pytest.mark.parametrize("batch_size", [16, 512])
@pytest.mark.parametrize("n_expert", [1])
@pytest.mark.parametrize("group_sz", [1, 2, 4])
@pytest.mark.parametrize("pass_stored", [False, True])
@pytest.mark.parametrize("pass_stored", [True, False])
def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored):
_run_distributed('_test_faster_shadow',
n_process,
@@ -54,30 +54,31 @@ def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
m2.weight.copy_(m1.weight)
m2.bias.copy_(m1.bias)

def ef1(x, fec):
def ef1(x, fec, eidx):
y = m1(x)
return y
def ef2(x, fec):
y = m2(x)
return y

if pass_stored:
stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda()
while stored_models.sum().item() == 0:
stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda()
stored_models[-1] = True
dist.broadcast(stored_models, 0)
stored_models = stored_models.cpu()

# if rank == 0:
# print('stored models {}'.format(stored_models))
print(stored_models)

ensure_comm(x1, None)
if pass_stored:
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1,
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1],
stored_models=stored_models)
else:
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1)
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1])
y1.sum().backward()

y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=[m2])
y2.sum().backward()
_assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
[y1, x1.grad, m1.bias.grad, m1.weight.grad],
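For reference, the stored_models setup above amounts to: draw one boolean entry per (rank, expert) pair, make sure at least one expert is shadowed, and broadcast rank 0's draw so all ranks agree. A condensed sketch (it folds the retry loop into the forced last entry and assumes an initialized NCCL process group):

# Condensed sketch of the stored_models setup above; not a verbatim copy of the test.
import torch
import torch.distributed as dist

def make_stored_models(world_size, n_expert):
    mask = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda()
    mask[-1] = True               # guarantee at least one shadowed expert
    dist.broadcast(mask, src=0)   # rank 0's draw wins, so every rank sees the same mask
    return mask.cpu()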
27 changes: 22 additions & 5 deletions tests/test_numerical.py
@@ -50,8 +50,12 @@ def _perform_forward(

def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
for name, mo, ro in zip(names, moe_out_list, raw_out_list):
if mo is None and ro is None:
continue
if mo is None or ro is None:
assert False
err = (mo - ro).abs().max()
if err.dtype == torch.bfloat16:
if err.dtype == torch.bfloat16 or err.dtype == torch.float16:
precision *= 100
print("Rank {} {} abs err {}".format(rank, name, err))
if err > precision:
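The hunk above loosens the tolerance by 100x for float16 and bfloat16, whose machine epsilon is orders of magnitude coarser than float32's. A sketch of deriving the tolerance from the dtype instead of a fixed factor, offered as an alternative rather than what the commit does:

# Alternative sketch, not in this commit: scale the tolerance from the dtype's eps
# (float32 eps ~1.2e-07, float16 ~9.8e-04, bfloat16 ~7.8e-03).
import torch

def tolerance_for(dtype, base=1e-3):
    return max(base, float(torch.finfo(dtype).eps) * 10)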
@@ -93,7 +97,7 @@ def __init__(
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_linear(
num_expert,
top_k,
@@ -111,6 +115,9 @@ def test_fmoe_linear(
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)

if isinstance(data_type, str):
data_type = eval(data_type)

moe = MyMoE(
num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
).type(data_type).cuda()
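Several hunks in this file turn the parametrized `data_type` string (e.g. 'torch.float32') into a dtype with `eval`. An eval-free lookup would also work; a minimal sketch, offered as an assumed-equivalent alternative rather than part of the commit:

# Hypothetical alternative to eval(data_type); resolves 'torch.float16' -> torch.float16.
import torch

def parse_dtype(name: str) -> torch.dtype:
    dtype = getattr(torch, name.split(".")[-1])
    assert isinstance(dtype, torch.dtype), f"not a torch dtype: {name}"
    return dtype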
@@ -238,6 +245,9 @@ def test_fmoe(

if isinstance(expert, str):
expert = globals()[expert]
assert(expert is not None)
if isinstance(data_type, str):
data_type = eval(data_type)

moe = FMoE(
num_expert=num_expert,
@@ -247,15 +257,15 @@ def test_fmoe(
mp_group=mp_group,
expert=expert,
top_k=top_k,
).cuda().to(data_type)
).cuda().type(data_type)

moe_raw = BruteForceMoE(
expert=expert,
num_expert=num_expert,
d_model=d_model,
world_size=world_size,
top_k=top_k,
).cuda().to(data_type)
).cuda().type(data_type)

if world_size == 1:
for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
@@ -266,9 +276,11 @@ def test_fmoe(
else:
assert len(moe.experts) >= 1
for idx, para in enumerate(moe.experts[0].parameters()):
assert(para.device.type == 'cuda')
para_tensor = torch.cat(
[list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
)
assert(para_tensor.device.type == 'cuda')
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tensor_gathered = torch.cat(para_array, dim=0)
@@ -419,6 +431,8 @@ def test_fmoe_experts(

if isinstance(expert, str):
expert = globals()[expert]
if isinstance(data_type, str):
data_type = eval(data_type)

moe = FMoE(
num_expert=num_expert,
@@ -428,7 +442,7 @@ def test_fmoe_experts(
mp_group=mp_group,
expert=expert,
top_k=top_k,
).cuda().to(data_type)
).cuda().type(data_type)

moe_raw = BruteForceMoE(
expert=expert,
@@ -447,9 +461,12 @@ def test_fmoe_experts(
else:
assert len(moe.experts) >= 1
for idx, para in enumerate(moe.experts[0].parameters()):
for ep in expert.parameters():
assert(ep.device.type == 'cuda')
para_tensor = torch.cat(
[list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
)
assert(para_tensor.device.type == 'cuda')
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tensor_gathered = torch.cat(para_array, dim=0)
