
Commit

script changes
Your Name committed Jul 12, 2024
1 parent 216dcd7 commit 842001b
Showing 1 changed file with 137 additions and 109 deletions.
246 changes: 137 additions & 109 deletions onnxruntime/test/python/transformers/test_parity_dbrx_moe.py
@@ -14,11 +14,13 @@
from collections import OrderedDict

import numpy
import os

Check notice (Code scanning / CodeQL): Unused import (test). Import of 'os' is not used.
Check warning (Code scanning / lintrunner): RUFF/F401 (test).

import torch
import torch.nn.functional as F

Check notice (Code scanning / CodeQL): Unused import (test). Import of 'F' is not used.
Check warning (Code scanning / lintrunner): RUFF/F401 (test). torch.nn.functional imported but unused. See https://docs.astral.sh/ruff/rules/unused-import
from mpi4py import MPI
from onnx import TensorProto, helper
from torch import nn
import torch.nn.init as init
from typing import Tuple


@@ -52,6 +54,17 @@ def value_string_of(numpy_array):
def print_tensor(name, numpy_array):
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")

def save_model_to_disk(model, model_path):
external_data_path = model_path + ".data"
onnx.save_model(
model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path
)
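
The helper writes two files: the .onnx graph and a companion .data file holding the external tensors. For test cleanup (cf. the commented-out delete_model_data below), a minimal sketch that would also give the os import above a use; the helper name is hypothetical:

def remove_saved_model(model_path):
    # Delete the ONNX file plus the external-data file that
    # save_model_to_disk wrote alongside it.
    os.remove(model_path)
    os.remove(model_path + ".data")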


#def delete_model_data(external_data):
#os.remove("dbrx_moe.onnx")
#os.remove(external_data)

Check notice (Code scanning / CodeQL): Commented-out code (test). This comment appears to contain commented-out code.

ORT_DTYPE = TensorProto.FLOAT16
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
THRESHOLD = 3e-2
@@ -91,19 +104,13 @@ def create_moe_onnx_graph(
),
]

print("fc1_experts_weights shape:", fc1_experts_weights.shape)
print("fc2_experts_weights shape:", fc2_experts_weights.shape)
print("fc3_experts_weights shape:", fc3_experts_weights.shape)


fc1_shape = [num_experts, num_experts * inter_size, hidden_size]
fc2_shape = [num_experts, num_experts * inter_size, hidden_size]
fc3_shape = [num_experts, num_experts * inter_size, hidden_size]

print("Expected fc1_shape:", fc1_shape)
print("Expected fc2_shape:", fc2_shape)
print("Expected fc3_shape:", fc3_shape)
fc1_experts_weights = fc1_experts_weights.view(16, 6144, 10752)
fc2_experts_weights = fc2_experts_weights.view(16, 6144, 10752).transpose(1, 2)
fc3_experts_weights = fc3_experts_weights.view(16, 6144, 10752)

fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
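
The hard-coded view/transpose above bakes in this test's DBRXConfig values (16 experts, hidden_size 6144, intermediate_size 10752). A quick sanity sketch of the shape arithmetic; note that a view only reinterprets storage, so the element counts must match exactly:

num_experts, hidden_size, inter_size = 16, 6144, 10752

# Fused parameter as allocated in DbrxExpertGLU: (num_experts * inter_size, hidden_size)
fused = (num_experts * inter_size) * hidden_size
# Element count implied by the view above: (num_experts, hidden_size, inter_size)
viewed = num_experts * hidden_size * inter_size
assert fused == viewed  # the view is size-consistent

fc1_shape = [num_experts, hidden_size, inter_size]  # after view
fc2_shape = [num_experts, inter_size, hidden_size]  # after view + transpose(1, 2)
fc3_shape = [num_experts, hidden_size, inter_size]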

torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32

@@ -160,7 +167,12 @@ def create_moe_onnx_graph(
)

model = helper.make_model(graph)
return model.SerializeToString()
model_path = "dbrx_moe.onnx"

save_model_to_disk(model, model_path)

return model_path
#return model.SerializeToString()
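
With the function now returning a path rather than serialized bytes, the caller presumably builds the session from disk; a sketch, assuming onnxruntime is installed and the .data file sits next to the model (which is how save_model_to_disk writes it):

import onnxruntime

# ORT resolves dbrx_moe.onnx.data relative to the model's directory.
sess = onnxruntime.InferenceSession("dbrx_moe.onnx", providers=["CPUExecutionProvider"])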



@@ -181,7 +193,7 @@ class DBRXConfig:
def __init__(
self,
hidden_size=6144,
intermediate_size=1500,
intermediate_size=10752,
num_hidden_layers=40,
num_attention_heads=48,
num_key_value_heads=8,
@@ -194,7 +206,7 @@ def __init__(
num_experts_per_tok=4,
num_local_experts=16,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_aux_loss_coef=0.001
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
@@ -214,6 +226,30 @@ def __init__(
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef

class DbrxRouter(nn.Module):
def __init__(
self,
hidden_size: int,
moe_num_experts: int,
moe_top_k: int,
config: DBRXConfig,
):
super().__init__()
self.hidden_size = hidden_size
self.moe_num_experts = config.num_local_experts
self.moe_top_k = config.num_experts_per_tok

self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)

weights = weights.to(hidden_states.dtype)
top_weights = top_weights.to(hidden_states.dtype)
return weights, top_weights, top_experts
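
A toy illustration of the routing math above, using 4 experts and top-2 rather than this test's 16/4 (hypothetical sizes):

import torch

torch.manual_seed(0)
logits = torch.randn(3, 4)  # 3 tokens, 4 experts
weights = logits.softmax(dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, k=2, dim=-1)
# Each token keeps its two largest routing probabilities plus the expert ids.
assert top_weights.shape == (3, 2) and top_experts.shape == (3, 2)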


class DbrxExpertGLU(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, config: DBRXConfig):
Expand All @@ -223,9 +259,13 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
self.moe_num_experts = config.num_local_experts
ffn_act_fn = {"name": config.hidden_act}

self.w1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size))
self.v1 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size))
self.w2 = nn.Parameter(torch.empty(moe_num_experts, moe_num_experts * ffn_hidden_size, hidden_size))
self.w1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size))
self.v1 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size))
self.w2 = nn.Parameter(torch.randn(moe_num_experts * ffn_hidden_size, hidden_size))

init.xavier_uniform_(self.w1)
init.xavier_uniform_(self.v1)
init.xavier_uniform_(self.w2)

act_fn_name = ffn_act_fn.get("name", "silu")
self.activation_fn = ACT2FN[act_fn_name]
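
The parameters above are fused across experts along dim 0. A sketch of recovering one expert's weight from that layout (toy sizes; assumes the row-major packing used in the allocation above):

import torch

num_experts, ffn_hidden, hidden = 4, 8, 6  # toy sizes
w1 = torch.randn(num_experts * ffn_hidden, hidden)
w1_expert2 = w1.view(num_experts, ffn_hidden, hidden)[2]  # rows 16..23
assert w1_expert2.shape == (ffn_hidden, hidden)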
Expand All @@ -242,7 +282,13 @@ def forward(


class DbrxExperts(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict, batch_size: int, sequence_length: int, config: DBRXConfig):
def __init__(self, hidden_size: int,
ffn_hidden_size: int,
moe_num_experts: int,
ffn_act_fn: dict,
batch_size: int,
sequence_length: int,
config: DBRXConfig):
super().__init__()
self.moe_num_experts = config.num_local_experts
self.config = DBRXConfig()
Expand All @@ -257,26 +303,16 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
config=config
)

w1_list = []
v1_list = []
w2_list = []
for i in range(self.moe_num_experts):
w1_list.append(self.mlp.w1[i])
v1_list.append(self.mlp.v1[i])
w2_list.append(self.mlp.w2[i])
self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(v1_list, dim=0)
self.moe_experts_weight3 = torch.stack(w2_list, dim=0)
self.batch_size = batch_size
self.sequence_length = sequence_length
self.moe_onnx_graph = create_moe_onnx_graph(
self.batch_size * self.sequence_length,
self.moe_num_experts,
self.hidden_size,
self.ffn_hidden_size,
self.moe_experts_weight1,
self.moe_experts_weight2,
self.moe_experts_weight3,
self.mlp.w1,
self.mlp.v1,
self.mlp.w2,
self.moe_top_k
)
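
The stacking loop deleted above only rebuilt a tensor that already carried the expert dimension in front, at the cost of an extra copy; with the fused 2-D parameters, the view inside create_moe_onnx_graph now produces the per-expert layout instead. A minimal demonstration of why the old loop was redundant:

import torch

p = torch.randn(4, 8, 6)  # fused parameter with a leading expert dim
restacked = torch.stack([p[i] for i in range(p.shape[0])], dim=0)
assert torch.equal(restacked, p)  # same values; the loop was a no-op copy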

@@ -399,10 +435,10 @@ def generate_weights_and_initial_model(
hidden_size,
inter_size,
):
#s = 0.1
fc1_experts_weights_all = self.moe_experts_weight1
fc2_experts_weights_all = self.moe_experts_weight2
fc3_experts_weights_all = self.moe_experts_weight3
s = 0.1
fc1_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE)
fc2_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, inter_size, hidden_size)).astype(NP_TYPE)
fc3_experts_weights_all = numpy.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE)
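
Switching to freshly drawn N(0, 0.1^2) weights keeps the fp16 test numerics tame; a small sketch of the draw-and-cast these lines perform (illustrative sizes, not the test's):

import numpy

s = 0.1
w = numpy.random.normal(scale=s, size=(4, 8, 6)).astype(numpy.float16)
# Small-scale weights help keep the MoE matmul outputs well inside fp16
# range, in line with the 3e-2 THRESHOLD used by this test.
print(w.dtype, float(numpy.abs(w).max()))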

onnx_model_full = create_moe_onnx_graph(
num_rows,
@@ -520,41 +556,20 @@ def forward(
return out


def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
ort_inputs = {
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
}

ort_output = None
if self.ort_sess is not None:
if not iobinding:
ort_output = self.ort_sess.run(None, ort_inputs)
else:
self.ort_run_with_iobinding(ort_inputs)
return None

# print_tensor("input", ort_inputs["input"])
# print_tensor("router_probs", ort_inputs["router_probs"])
# print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
# print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
# print_tensor("output", ort_output[0])

return ort_output

def parity_check(self):
experts = DbrxExperts()
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
torch_output = self.forward(hidden_state)
final_torch_output = experts.forward(torch_output)
ort_output = self.ort_forward(hidden_state, iobinding=True)
config = DBRXConfig()
ffn = DbrxFFN(config, self.batch_size, self.sequence_length)
router = DbrxRouter(hidden_size=config.hidden_size,
moe_num_experts=config.num_local_experts,
moe_top_k=config.num_experts_per_tok,
config=DBRXConfig())

Check notice (Code scanning / CodeQL): Unused local variable (test). Variable router is not used.
Check warning (Code scanning / lintrunner): RUFF/F841 (test). Local variable router is assigned to but never used. See https://docs.astral.sh/ruff/rules/unused-variable
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_size)
torch_output = ffn.forward(hidden_state)
print("forward: ", torch_output)
ort_output = ffn.ort_forward(hidden_state, iobinding=False)

Check notice (Code scanning / CodeQL): Unused local variable (test). Variable ort_output is not used.
Check warning (Code scanning / lintrunner): RUFF/F841 (test). Local variable ort_output is assigned to but never used. See https://docs.astral.sh/ruff/rules/unused-variable
"""
if ort_output is not None:
assert torch.allclose(final_torch_output, ort_output, rtol=1e-04, atol=1e-04)
assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
print(
"batch_size:",
self.batch_size,
@@ -564,69 +579,79 @@ def parity_check(self):
(torch_output - ort_output).abs().max(),
" parity: OK",
)
"""



class DbrxFFN(nn.Module):
def __init__(self, config: DBRXConfig):
def __init__(self, config: DBRXConfig, batch_size, sequence_length):
super().__init__()

self.batch_size = batch_size
self.sequence_length = sequence_length
self.router = DbrxRouter(
hidden_size=config.hidden_size,
moe_num_experts=config.num_local_experts,
moe_top_k=config.num_experts_per_tok,
config=DBRXConfig()
)

self.experts = DbrxExperts(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
moe_num_experts=config.num_local_experts,
ffn_act_fn=config.hidden_act,
)
batch_size=batch_size,
sequence_length=sequence_length,
config=DBRXConfig()
)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
print("Input:", x)
weights, top_weights, top_experts = self.router(x)
print("After router:", weights, top_weights, top_experts)
out = self.experts(x, weights, top_weights, top_experts)
print("After experts:", out)
return out, weights
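
A sketch of exercising the module end to end, mirroring parity_check above (toy batch and sequence sizes; with the real config the three fused weights alone run to ~10^9 elements, so this is expensive to instantiate):

config = DBRXConfig()
ffn = DbrxFFN(config, batch_size=1, sequence_length=8)
x = torch.randn(1, 8, config.hidden_size)
torch_out, router_weights = ffn(x)             # PyTorch reference path
ort_out = ffn.ort_forward(x, iobinding=False)  # ONNX Runtime path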


def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

class DbrxRouter(nn.Module):
def __init__(
self,
hidden_size: int,
config: DBRXConfig,
moe_num_experts: int,
moe_top_k: int,
batch_size: int,
sequence_length: int,
ffn_hidden_size: int,
ffn_act_fn: dict
):
super().__init__()
self.hidden_size = hidden_size
self.moe_num_experts = config.num_local_experts
self.moe_top_k = config.num_experts_per_tok
self.ffn_hidden_size = config.intermediate_size
self.ffn_act_fn = {"name", config.hidden_act}
assert not torch.isnan(hidden_states).any(), "Input hidden_states contains NaN values"
assert not torch.isinf(hidden_states).any(), "Input hidden_states contains Inf values"

self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.router.layer(hidden_states)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
assert not torch.isnan(router_logits).any(), "router_logits contains NaN values"
assert not torch.isinf(router_logits).any(), "router_logits contains Inf values"
ort_inputs = {
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
}

top_weights_scale = (
torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
if self.moe_normalize_expert_weights is not None
else 1.0
)
top_weights = top_weights / top_weights_scale
#ort_output = None
if self.experts.ort_sess is not None:
if not iobinding:
ort_output = self.experts.ort_sess.run(None, ort_inputs)
else:
ort_output = self.experts.ort_run_with_iobinding(ort_inputs)
#return ort_output

weights = weights.to(hidden_states.dtype)
top_weights = top_weights.to(hidden_states.dtype)
return weights, top_weights, top_experts


# print_tensor("input", ort_inputs["input"])
# print_tensor("router_probs", ort_inputs["router_probs"])
# print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
# print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
# print_tensor("output", ort_output[0])

print("ORT output:", ort_output)
assert not numpy.isnan(ort_output).any(), "ORT output contains NaN values"
assert not numpy.isinf(ort_output).any(), "ORT output contains Inf values"
return ort_output
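
For reference, the iobinding branch above relies on ort_run_with_iobinding, which is not shown in this diff; a minimal sketch of what such a helper typically looks like with the ORT Python API (assumed implementation and output name, not this file's actual code):

def ort_run_with_iobinding(self, ort_inputs):
    binding = self.ort_sess.io_binding()
    for name, arr in ort_inputs.items():
        binding.bind_cpu_input(name, arr)  # bind numpy arrays as CPU inputs
    binding.bind_output("output")          # assumed graph output name
    self.ort_sess.run_with_iobinding(binding)
    return binding.copy_outputs_to_cpu()[0]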



@@ -648,11 +673,14 @@ def test_dbrx_moe_parity(self):
batch_size,
sequence_length,
config)
dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
ffn_hidden_size,
moe_num_experts,
num_rows=batch_size * sequence_length,
threshold=THRESHOLD)
dbrx_moe.parity_check()
#dbrx_moe.test_moe_with_tensor_parallelism(hidden_size,
#ffn_hidden_size,
#moe_num_experts,
#num_rows=batch_size * sequence_length,
#threshold=THRESHOLD)
#external_data_path = "dbrx_moe.onnx" + ".data"
#delete_model_data(external_data_path)


if __name__ == "__main__":