Skip to content

Commit

Permalink
Renamed the new function and added a new one to delete the model and …
Browse files Browse the repository at this point in the history
…the model data
  • Loading branch information
Your Name committed Jun 27, 2024
1 parent ad45401 commit 2de42ed
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ def print_tensor(name, numpy_array):
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def create_onnx_graph(model, model_path):
def save_model_to_disk(model, model_path):
external_data_path = "mixtral_moe.onnx" + ".data"
onnx.save_model(
model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=external_data_path
)

return model_path

def delete_model_data(external_data):
os.remove("mixtral_moe.onnx")
os.remove(external_data)


def create_moe_onnx_graph(
Expand Down Expand Up @@ -139,9 +142,9 @@ def create_moe_onnx_graph(
model = helper.make_model(graph)
model_path = "mixtral_moe.onnx"

save_model = create_onnx_graph(model, model_path)
save_model_to_disk(model, model_path)

return save_model
return model_path


class ClassInstantier(OrderedDict):
Expand Down Expand Up @@ -434,9 +437,8 @@ def test_mixtral_moe_benchmark(self):
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.benchmark()

os.remove("mixtral_moe.onnx")
external_data_path = "mixtral_moe.onnx" + ".data"
os.remove(external_data_path)
delete_model_data(external_data_path)


if __name__ == "__main__":
Expand Down

0 comments on commit 2de42ed

Please sign in to comment.