-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,13 @@ | |
from collections import OrderedDict | ||
|
||
import numpy | ||
import os | ||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'os' is not used.
Check warning Code scanning / lintrunner RUFF/F401 Warning test
os imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
import torch | ||
import torch.nn.functional as F | ||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'F' is not used.
Check warning Code scanning / lintrunner RUFF/F401 Warning test
torch.nn.functional imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
from mpi4py import MPI | ||
from onnx import TensorProto, helper | ||
from torch import nn | ||
import torch.nn.init as init | ||
from typing import Tuple | ||
|
||
|
||
|
@@ -52,6 +54,17 @@ def value_string_of(numpy_array): | |
def print_tensor(name, numpy_array): | ||
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};") | ||
|
||
def save_model_to_disk(model, model_path): | ||
external_data_path = model_path + ".data" | ||
onnx.save_model( | ||
model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path | ||
) | ||
|
||
|
||
#def delete_model_data(external_data): | ||
#os.remove("dbrx_moe.onnx") | ||
#os.remove(external_data) | ||
Check notice Code scanning / CodeQL Commented-out code Note test
This comment appears to contain commented-out code.
|
||
|
||
ORT_DTYPE = TensorProto.FLOAT16 | ||
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 | ||
THRESHOLD = 3e-2 | ||
|
@@ -91,19 +104,13 @@ def create_moe_onnx_graph( | |
), | ||
] | ||
|
||
print("fc1_experts_weights shape:", fc1_experts_weights.shape) | ||
print("fc2_experts_weights shape:", fc2_experts_weights.shape) | ||
print("fc3_experts_weights shape:", fc3_experts_weights.shape) | ||
|
||
|
||
fc1_shape = [num_experts, num_experts * inter_size, hidden_size] | ||
fc2_shape = [num_experts, num_experts * inter_size, hidden_size] | ||
fc3_shape = [num_experts, num_experts * inter_size, hidden_size] | ||
|
||
print("Expected fc1_shape:", fc1_shape) | ||
print("Expected fc2_shape:", fc2_shape) | ||
print("Expected fc3_shape:", fc3_shape) | ||
fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752) | ||
fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2) | ||
fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752) | ||
|
||
fc1_shape = [num_experts, hidden_size, inter_size] | ||
fc2_shape = [num_experts, inter_size, hidden_size] | ||
fc3_shape = [num_experts, hidden_size, inter_size] | ||
|
||
torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 | ||
|
||
|
@@ -160,7 +167,12 @@ def create_moe_onnx_graph( | |
) | ||
|
||
model = helper.make_model(graph) | ||
return model.SerializeToString() | ||
model_path = "dbrx_moe.onnx" | ||
|
||
save_model_to_disk(model, model_path) | ||
|
||
return model_path | ||
#return model.SerializeToString() | ||
|
||
|
||
|
||
|
@@ -181,7 +193,7 @@ class DBRXConfig: | |
def __init__( | ||
self, | ||
hidden_size=6144, | ||
intermediate_size=1500, | ||
intermediate_size=10752, | ||
num_hidden_layers=40, | ||
num_attention_heads=48, | ||
num_key_value_heads=8, | ||
|
@@ -194,7 +206,7 @@ def __init__( | |
num_experts_per_tok=4, | ||
num_local_experts=16, | ||
output_router_logits=False, | ||
router_aux_loss_coef=0.001, | ||
router_aux_loss_coef=0.001 | ||
): | ||
self.hidden_size = hidden_size | ||
self.intermediate_size = intermediate_size | ||
|
@@ -214,6 +226,30 @@ def __init__( | |
self.output_router_logits = output_router_logits | ||
self.router_aux_loss_coef = router_aux_loss_coef | ||
|
||
class DbrxRouter(nn.Module): | ||
def __init__( | ||
self, | ||
hidden_size: int, | ||
moe_num_experts: int, | ||
moe_top_k: int, | ||
config: DBRXConfig, | ||
): | ||
super().__init__() | ||
self.hidden_size = hidden_size | ||
self.moe_num_experts = config.num_local_experts | ||
self.moe_top_k = config.num_experts_per_tok | ||
|
||
self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: | ||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) | ||
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) | ||
|
||
weights = weights.to(hidden_states.dtype) | ||
top_weights = top_weights.to(hidden_states.dtype) | ||
return weights, top_weights, top_experts | ||
|
||
|
||
class DbrxExpertGLU(nn.Module): | ||
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig): | ||
|
@@ -223,9 +259,13 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, | |
self.moe_num_experts = config.num_local_experts | ||
ffn_act_fn = {"name": config.hidden_act} | ||
|
||
self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) | ||
self.v1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) | ||
self.w2 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size)) | ||
self.w1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) | ||
self.v1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) | ||
self.w2 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size)) | ||
|
||
init.xavier_uniform_(self.w1) | ||
init.xavier_uniform_(self.v1) | ||
init.xavier_uniform_(self.w2) | ||
|
||
act_fn_name = ffn_act_fn.get("name", "silu") | ||
self.activation_fn = ACT2FN[act_fn_name] | ||
|
@@ -242,7 +282,13 @@ def forward( | |
|
||
|
||
class DbrxExperts(nn.Module): | ||
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, batch_size: int, sequence_length: int, config: DBRXConfig): | ||
def __init__(self, hidden_size: int, | ||
ffn_hidden_size: int, | ||
moe_num_experts: int, | ||
ffn_act_fn: dict, | ||
batch_size: int, | ||
sequence_length: int, | ||
config: DBRXConfig): | ||
super().__init__() | ||
self.moe_num_experts = config.num_local_experts | ||
self.config = DBRXConfig() | ||
|
@@ -257,26 +303,16 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, | |
config=config | ||
) | ||
|
||
w1_list = [] | ||
v1_list = [] | ||
w2_list = [] | ||
for i in range(self.moe_num_experts): | ||
w1_list.append(self.mlp.w1[i]) | ||
v1_list.append(self.mlp.v1[i]) | ||
w2_list.append(self.mlp.w2[i]) | ||
self.moe_experts_weight1 = torch.stack(w1_list, dim=0) | ||
self.moe_experts_weight2 = torch.stack(v1_list, dim=0) | ||
self.moe_experts_weight3 = torch.stack(w2_list, dim=0) | ||
self.batch_size = batch_size | ||
self.sequence_length = sequence_length | ||
self.moe_onnx_graph = create_moe_onnx_graph( | ||
self.batch_size * self.sequence_length, | ||
self.moe_num_experts, | ||
self.hidden_size, | ||
self.ffn_hidden_size, | ||
self.moe_experts_weight1, | ||
self.moe_experts_weight2, | ||
self.moe_experts_weight3, | ||
self.mlp.w1, | ||
self.mlp.v1, | ||
self.mlp.w2, | ||
self.moe_top_k | ||
) | ||
|
||
|
@@ -399,10 +435,10 @@ def generate_weights_and_initial_model( | |
hidden_size, | ||
inter_size, | ||
): | ||
#s = 0.1 | ||
fc1_experts_weights_all = self.moe_experts_weight1 | ||
fc2_experts_weights_all = self.moe_experts_weight2 | ||
fc3_experts_weights_all = self.moe_experts_weight3 | ||
s = 0.1 | ||
fc1_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) | ||
fc2_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, inter_size, hidden_size)).astype(NP_TYPE) | ||
fc3_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) | ||
|
||
onnx_model_full = create_moe_onnx_graph( | ||
num_rows, | ||
|
@@ -520,41 +556,20 @@ def forward( | |
return out | ||
|
||
|
||
def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: | ||
batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
hidden_states = hidden_states.view(-1, hidden_dim) | ||
# router_logits: (batch * sequence_length, n_experts) | ||
router_logits = self.gate(hidden_states) | ||
ort_inputs = { | ||
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), | ||
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), | ||
} | ||
|
||
ort_output = None | ||
if self.ort_sess is not None: | ||
if not iobinding: | ||
ort_output = self.ort_sess.run(None, ort_inputs) | ||
else: | ||
self.ort_run_with_iobinding(ort_inputs) | ||
return None | ||
|
||
# print_tensor("input", ort_inputs["input"]) | ||
# print_tensor("router_probs", ort_inputs["router_probs"]) | ||
# print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) | ||
# print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) | ||
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) | ||
# print_tensor("output", ort_output[0]) | ||
|
||
return ort_output | ||
|
||
def parity_check(self): | ||
experts = DbrxExperts() | ||
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) | ||
torch_output = self.forward(hidden_state) | ||
final_torch_output = experts.forward(torch_output) | ||
ort_output = self.ort_forward(hidden_state, iobinding=True) | ||
config = DBRXConfig() | ||
ffn = DbrxFFN(config, self.batch_size, self.sequence_length) | ||
router = DbrxRouter(hidden_size=config.hidden_size, | ||
Check notice Code scanning / CodeQL Unused local variable Note test
Variable router is not used.
Check warning Code scanning / lintrunner RUFF/F841 Warning test
Local variable router is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable |
||
moe_num_experts=config.num_local_experts, | ||
moe_top_k=config.num_local_experts, | ||
config=DBRXConfig()) | ||
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size) | ||
torch_output = ffn.forward(hidden_state) | ||
print("forward: ", torch_output) | ||
ort_output = ffn.ort_forward(hidden_state, iobinding=False) | ||
Check notice Code scanning / CodeQL Unused local variable Note test
Variable ort_output is not used.
Check warning Code scanning / lintrunner RUFF/F841 Warning test
Local variable ort\_output is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable |
||
""" | ||
if ort_output is not None: | ||
assert torch.allclose(final_torch_output, ort_output, rtol=1e-04, atol=1e-04) | ||
assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) | ||
print( | ||
"batch_size:", | ||
self.batch_size, | ||
|
@@ -564,69 +579,79 @@ def parity_check(self): | |
(torch_output - ort_output).abs().max(), | ||
" parity: OK", | ||
) | ||
""" | ||
|
||
|
||
|
||
class DbrxFFN(nn.Module): | ||
def __init__(self, config: DBRXConfig): | ||
def __init__(self, config: DBRXConfig, batch_size, sequence_length): | ||
super().__init__() | ||
|
||
self.batch_size = batch_size | ||
self.sequence_length = sequence_length | ||
self.router = DbrxRouter( | ||
hidden_size=config.hidden_size, | ||
moe_num_experts=config.num_local_experts, | ||
moe_top_k=config.num_experts_per_tok, | ||
config=DBRXConfig() | ||
) | ||
|
||
self.experts = DbrxExperts( | ||
hidden_size=config.hidden_size, | ||
ffn_hidden_size=config.intermediate_size, | ||
moe_num_experts=config.num_local_experts, | ||
ffn_act_fn=config.hidden_act, | ||
) | ||
batch_size=batch_size, | ||
sequence_length=sequence_length, | ||
config=DBRXConfig() | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
print("Input:", x) | ||
weights, top_weights, top_experts = self.router(x) | ||
print("After router:", weights, top_weights, top_experts) | ||
out = self.experts(x, weights, top_weights, top_experts) | ||
print("After experts:", out) | ||
return out, weights | ||
|
||
|
||
def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: | ||
batch_size, sequence_length, hidden_dim = hidden_states.shape | ||
hidden_states = hidden_states.view(-1, hidden_dim) | ||
|
||
class DbrxRouter(nn.Module): | ||
def __init__( | ||
self, | ||
hidden_size: int, | ||
config: DBRXConfig, | ||
moe_num_experts: int, | ||
moe_top_k: int, | ||
batch_size: int, | ||
sequence_length: int, | ||
ffn_hidden_size: int, | ||
ffn_act_fn: dict | ||
): | ||
super().__init__() | ||
self.hidden_size = hidden_size | ||
self.moe_num_experts = config.num_local_experts | ||
self.moe_top_k = config.num_experts_per_tok | ||
self.ffn_hidden_size = config.intermediate_size | ||
self.ffn_act_fn = {"name", config.hidden_act} | ||
assert not torch.isnan(hidden_states).any(), "Input hidden_states contains NaN values" | ||
assert not torch.isinf(hidden_states).any(), "Input hidden_states contains Inf values" | ||
|
||
self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False) | ||
# router_logits: (batch * sequence_length, n_experts) | ||
router_logits = self.router.layer(hidden_states) | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: | ||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32) | ||
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) | ||
assert not torch.isnan(router_logits).any(), "router_logits contains NaN values" | ||
assert not torch.isinf(router_logits).any(), "router_logits contains Inf values" | ||
ort_inputs = { | ||
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), | ||
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), | ||
} | ||
|
||
top_weights_scale = ( | ||
torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True) | ||
if self.moe_normalize_expert_weights is not None | ||
else 1.0 | ||
) | ||
top_weights = top_weights / top_weights_scale | ||
#ort_output = None | ||
if self.experts.ort_sess is not None: | ||
if not iobinding: | ||
ort_output = self.experts.ort_sess.run(None, ort_inputs) | ||
else: | ||
ort_output = self.experts.ort_run_with_iobinding(ort_inputs) | ||
#return ort_output | ||
|
||
weights = weights.to(hidden_states.dtype) | ||
top_weights = top_weights.to(hidden_states.dtype) | ||
return weights, top_weights, top_experts | ||
|
||
|
||
# print_tensor("input", ort_inputs["input"]) | ||
# print_tensor("router_probs", ort_inputs["router_probs"]) | ||
# print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) | ||
# print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) | ||
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) | ||
# print_tensor("output", ort_output[0]) | ||
|
||
print("ORT output:", ort_output) | ||
assert not numpy.isnan(ort_output).any(), "ORT output contains NaN values" | ||
assert not numpy.isinf(ort_output).any(), "ORT output contains Inf values" | ||
return ort_output | ||
|
||
|
||
|
||
|
@@ -648,11 +673,14 @@ def test_dbrx_moe_parity(self): | |
batch_size, | ||
sequence_length, | ||
config) | ||
dbrx_moe.test_moe_with_tensor_parallelism(hidden_size, | ||
ffn_hidden_size, | ||
moe_num_experts, | ||
num_rows=batch_size * sequence_length, | ||
threshold=THRESHOLD) | ||
dbrx_moe.parity_check() | ||
#dbrx_moe.test_moe_with_tensor_parallelism(hidden_size, | ||
#ffn_hidden_size, | ||
#moe_num_experts, | ||
#num_rows=batch_size * sequence_length, | ||
#threshold=THRESHOLD) | ||
#external_data_path = "dbrx_moe.onnx" + ".data" | ||
#delete_model_data(external_data_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
|