add built-in llama_ffn; add helloworld_custom_expert_sharded; #235

Merged · 1 commit · May 14, 2024
3 changes: 2 additions & 1 deletion tutel/examples/helloworld.py
@@ -37,6 +37,7 @@
parser.add_argument('--capacity_factor', type=float, default=1.0) # 0.0 for dMoE (dropless-MoE), negative for no-padded capacity.
parser.add_argument('--megablocks_size', type=int, default=0)
parser.add_argument('--use_tensorcore', default=False, action='store_true')
+parser.add_argument('--expert_type', type=str, default='ffn')

args = parser.parse_args()

@@ -74,7 +75,7 @@ def __init__(self):

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate, 'capacity_factor': args.capacity_factor},
-            experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
+            experts = {'type': args.expert_type, 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
            model_dim = model_dim,
            scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds = (1, dist_rank + 1, 1),
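For illustration, a minimal usage sketch (not part of this diff) of how the new flag selects the built-in expert added by this PR, assuming the experts `type` string is resolved to a module under `tutel/experts/` (so `--expert_type llama_ffn` would map to `tutel/experts/llama_ffn.py`); the gate setting and sizes below are example values, not fixed by the PR:

# Hypothetical sketch: build an MoE layer with the new built-in expert type.
from tutel import moe as tutel_moe

moe = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2},
    experts={
        'type': 'llama_ffn',             # e.g. passed through --expert_type llama_ffn
        'count_per_node': 2,
        'hidden_size_per_expert': 2048,
    },
    model_dim=2048,
)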
3 changes: 2 additions & 1 deletion tutel/examples/helloworld_custom_expert.py
@@ -68,7 +68,8 @@ def __init__(self, model_dim, local_experts, sharded_count, my_config):
        self.reset_parameters()

    def reset_parameters(self):
-        pass
+        with torch.no_grad():
+            self.W.normal_(0, 0.001)

    def forward(self, x, ctx):
        if ctx.sharded_count > 1:
176 changes: 176 additions & 0 deletions tutel/examples/helloworld_custom_expert_sharded.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import argparse

from tutel import system
from tutel import moe as tutel_moe
from tutel import net

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=512)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
parser.add_argument('--l_aux_wt', type=float, default=0.0)
parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
parser.add_argument('--allreduce_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--checkpoint_path', type=str, default='')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
parser.add_argument('--capacity_factor', type=float, default=1.0) # 0.0 for dMoE (dropless-MoE), negative for no-padded capacity.

args = parser.parse_args()

parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
args.local_rank = parallel_env.local_device.index

batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
num_local_experts = args.num_local_experts
top_value = args.top
a2a_ffn_overlap_degree = args.a2a_ffn_overlap_degree
device = parallel_env.local_device

if args.dtype == 'float32':
    torch.set_default_dtype(torch.float32)
elif args.dtype == 'float64':
    torch.set_default_dtype(torch.float64)
elif args.dtype == 'float16':
    torch.set_default_dtype(torch.float16)
elif args.dtype == 'bfloat16':
    torch.set_default_dtype(torch.bfloat16)
else:
    raise Exception('Unrecognized data type specified: %s' % args.dtype)


class CustomExpertDemo(torch.nn.Module):

    def _create_sharded_param(self, *full_shape, **kwargs):
        # Each rank only materializes a flat slice of ceil(numel / sharded_count) elements of the full parameter.
        full_shape = torch.Size(full_shape)
        sharded_shape = (full_shape.numel() + self.sharded_count - 1) // self.sharded_count
        return torch.nn.Parameter(torch.empty(sharded_shape, **kwargs)), full_shape

    def _get_gathered_param(self, param, full_shape):
        # Gather the per-rank slices across the sharded group, drop any trailing padding, and restore the full shape.
        sharded_group = net.create_groups_from_world(group_count=-self.sharded_count).model_group
        return net.zero_gather(param, group=sharded_group).view(-1).narrow(0, 0, full_shape.numel()).view(full_shape)

    def __init__(self, model_dim, local_experts, sharded_count, my_config):
        super().__init__()
        self.sharded_count = sharded_count
        self.W, self.W_full_shape = self._create_sharded_param(local_experts, model_dim, model_dim)
        self.my_activation = torch.nn.functional.relu if my_config == 'relu' else None
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            self.W.normal_(0, 0.001)

    def forward(self, x, ctx):
        W_full = self._get_gathered_param(self.W, self.W_full_shape)
        y = torch.matmul(x, W_full)
        if self.my_activation is not None:
            y = self.my_activation(y)
        return y


class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate, 'capacity_factor': args.capacity_factor},
            experts = {'type': 'custom', 'module': CustomExpertDemo, 'count_per_node': num_local_experts, 'my_config': None},
            model_dim = model_dim,
            scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds = (1, dist_rank + 1, 1),
            a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
            use_2dh=args.use_2dh,
        )

        # Summary of different parameter types: gate, local_experts
        local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
        shared_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='gate')])
        dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))

    def forward(self, input):
        result = self._moe_layer(input)
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result

model = ExampleModel().to(device)
dist_print(model)

if args.checkpoint_path:
    checkpoint_path = system.apply_rank_size_from_pattern(args.checkpoint_path, rank=parallel_env.global_rank, size=parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('Checkpoint not loaded: file `%s` is not found. Will train the model from start.' % checkpoint_path)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

torch.manual_seed(0)
x = torch.tensor(torch.randn([batch_size, num_tokens, model_dim], dtype=torch.float32, device='cpu').detach().numpy(), dtype=torch.get_default_dtype(), requires_grad=False, device=device)
y = torch.LongTensor(batch_size).random_(1).to(device)

tuples = (dist_world_size, args.dtype, model_dim, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, device = `%s`' % tuples)

average_time, num_steps = 0, args.num_steps

if args.allreduce_degree == -1:
    params_for_all_reduce = []
else:
    params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'skip_allreduce') and getattr(p, 'requires_grad', False)]

for i in range(num_steps):
    t_start = system.record_time()

    if not args.eval:
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        if args.l_aux_wt:
            loss += args.l_aux_wt * model._moe_layer.l_aux
        loss.backward()
        if dist_world_size > 1:
            for p in params_for_all_reduce:
                p.grad /= dist_world_size
                p.grad = net.simple_all_reduce(p.grad)
        optimizer.step()
    else:
        with torch.no_grad():
            output = model(x)
            loss = F.nll_loss(output, y)

    t_stop = system.record_time()

    num_global_experts = tutel_moe.moe_layer.global_expert_count(num_local_experts, group=system.get_local_session().model_group)
    mm_ceof, cap_ceof = 1 if args.eval else 3, min(args.top, num_global_experts)
    dist_print('STEP-%s: loss = %.5f, step_time = %.6f sec.' % (i, float(loss.data), t_stop - t_start))

    if i + 10 >= num_steps:
        average_time += t_stop - t_start

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

if args.checkpoint_path:
    torch.save(model.state_dict(), checkpoint_path)
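
For illustration, a minimal sketch of the ZeRO-style sharding arithmetic used by _create_sharded_param / _get_gathered_param above: each of the `sharded_count` ranks owns a flat slice of ceil(numel / sharded_count) elements, and gathering the slices then narrowing to the true element count recovers the full parameter. The concatenation below is a stand-in for `net.zero_gather`, which is assumed to return the per-rank slices joined together; the sizes are example values only.

import torch

local_experts, model_dim, sharded_count = 2, 2048, 4
full_shape = torch.Size((local_experts, model_dim, model_dim))

# Per-rank slice size, rounded up so sharded_count slices always cover the full tensor.
sharded_numel = (full_shape.numel() + sharded_count - 1) // sharded_count
shards = [torch.empty(sharded_numel) for _ in range(sharded_count)]   # one slice per rank

gathered = torch.cat(shards)                                          # stand-in for net.zero_gather(...)
full = gathered.view(-1).narrow(0, 0, full_shape.numel()).view(full_shape)
assert full.shape == full_shape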
45 changes: 45 additions & 0 deletions tutel/experts/llama_ffn.py
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from .. import net

class LlamaFFNNetwork(torch.nn.Module):

    def _create_sharded_param(self, *full_shape, **kwargs):
        full_shape = torch.Size(full_shape)
        sharded_shape = (full_shape.numel() + self.sharded_count - 1) // self.sharded_count
        return torch.nn.Parameter(torch.empty(sharded_shape, **kwargs)), full_shape

    def _get_gathered_param(self, param, full_shape):
        sharded_group = net.create_groups_from_world(group_count=-self.sharded_count).model_group
        return net.zero_gather(param, group=sharded_group).view(-1).narrow(0, 0, full_shape.numel()).view(full_shape)

    def __init__(self, model_dim, hidden_size_per_expert, local_experts, sharded_count, activation_fn=torch.nn.functional.silu):
        super().__init__()
        self.sharded_count = sharded_count
        self.W_fc1, self.W_fc1_full_shape = self._create_sharded_param(local_experts, model_dim, hidden_size_per_expert)
        self.W_fc2, self.W_fc2_full_shape = self._create_sharded_param(local_experts, model_dim, hidden_size_per_expert)
        self.W_fc3, self.W_fc3_full_shape = self._create_sharded_param(local_experts, hidden_size_per_expert, model_dim)
        self.activation_fn = activation_fn
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            self.W_fc1.normal_(0, 0.01)
            self.W_fc2.normal_(0, 0.01)
            self.W_fc3.normal_(0, 0.01)

    def forward(self, x, ctx):
        W_fc1_full = self._get_gathered_param(self.W_fc1, self.W_fc1_full_shape)
        W_fc2_full = self._get_gathered_param(self.W_fc2, self.W_fc2_full_shape)
        W_fc3_full = self._get_gathered_param(self.W_fc3, self.W_fc3_full_shape)

        # Gated FFN: apply the activation to the elementwise product of the two input projections, then project back to model_dim.
        y1 = torch.matmul(x, W_fc1_full)
        y2 = torch.matmul(x, W_fc2_full)
        y = self.activation_fn(y1 * y2)
        y = torch.matmul(y, W_fc3_full)
        return y


ExpertModule = LlamaFFNNetwork
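
For illustration, an unsharded, single-expert sketch of the forward pass implemented by LlamaFFNNetwork above (the token count and sizes are assumed example values, not fixed by the PR): two projections of the input are multiplied elementwise, the activation is applied to the product, and the result is projected back to model_dim, mirroring y = act(x@W1 * x@W2) @ W3 as written in forward().

import torch
import torch.nn.functional as F

tokens, model_dim, hidden = 8, 2048, 2048
x = torch.randn(tokens, model_dim)
W1 = torch.randn(model_dim, hidden) * 0.01   # corresponds to one expert's slice of W_fc1
W2 = torch.randn(model_dim, hidden) * 0.01   # ... of W_fc2
W3 = torch.randn(hidden, model_dim) * 0.01   # ... of W_fc3

y = F.silu((x @ W1) * (x @ W2)) @ W3         # activation over the elementwise product, as in forward()
assert y.shape == (tokens, model_dim)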