onnx export with dynamic shapes, fast attention (jpata#324)
* enable onnx export via dynamo with dynamic shapes

* added standalone export script

* fp16 quantization sort of works also

* use sdpa

* MultiheadAttention op runs

* update timing study

* cleanup

* model closes

* update timing study

* onnx is factorized

* update onnx script

* revert main model code

* move to notebook
jpata authored and farakiko committed Aug 26, 2024
1 parent 6f25fa3 commit d002adc
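
The headline change is switching the ONNX export to the dynamo-based path so that batch and element counts stay dynamic. A rough, self-contained sketch of that export style follows; the model below is an illustrative stand-in, not the actual MLPF network, whose real export lives in the standalone script added by this commit:

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # stand-in for the MLPF model: takes (features, mask), returns per-element outputs
    def __init__(self, input_dim=55, width=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, width), nn.ReLU(), nn.Linear(width, 5))

    def forward(self, features, mask):
        # zero out padded elements so they do not contribute to the outputs
        return self.net(features) * mask.unsqueeze(-1)


model = TinyModel().eval()
dummy_features = torch.randn(1, 256, 55)
dummy_mask = torch.ones(1, 256, dtype=torch.bool)

# dynamo-based export; dynamic_shapes=True keeps batch and element counts symbolic
onnx_program = torch.onnx.dynamo_export(
    model,
    dummy_features,
    dummy_mask,
    export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
)
onnx_program.save("test.onnx")

The older torch.onnx.export path needed explicit dynamic_axes per input and output (see the block removed from training.py below); with dynamic_shapes=True the dynamo path keeps those dimensions symbolic instead.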
Showing 8 changed files with 940 additions and 62 deletions.
30 changes: 1 addition & 29 deletions mlpf/pyg/training.py
@@ -812,40 +812,12 @@ def run(rank, world_size, config, args, outdir, logfile):
             dir_name=testdir_name,
         )

-    if (rank == 0) or (rank == "cpu"): # make plots and export to onnx only on a single machine
+    if (rank == 0) or (rank == "cpu"): # make plots only on a single machine
         if args.make_plots:
             for sample in args.test_datasets:
                 _logger.info(f"Plotting distributions for {sample}")
                 make_plots(outdir, sample, config["dataset"], testdir_name)

-        if args.export_onnx:
-            try:
-                dummy_features = torch.randn(1, 8192, model_kwargs["input_dim"], device=rank)
-                dummy_mask = torch.zeros(1, 8192, dtype=torch.bool, device=rank)
-
-                # Torch ONNX export in the old way
-                torch.onnx.export(
-                    model,
-                    (dummy_features, dummy_mask),
-                    "test.onnx",
-                    verbose=False,
-                    input_names=["features", "mask"],
-                    output_names=["id", "momentum"],
-                    dynamic_axes={
-                        "features": {0: "num_batch", 1: "num_elements"},
-                        "mask": [0, 1],
-                        "id": [0, 1],
-                        "momentum": [0, 1],
-                        # "charge": [0, 1],
-                    },
-                )
-
-                # Torch ONNX export in the new way
-                # onnx_program = torch.onnx.dynamo_export(model, (dummy_features, dummy_mask))
-                # onnx_program.save("test.onnx")
-            except Exception as e:
-                print("ONNX export failed: {}".format(e))
-
     if world_size > 1:
         dist.destroy_process_group()

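The inline export above was removed in favor of the standalone script, and the "fast attention" part of the commit lives in the model files not shown here. The move to sdpa amounts to replacing hand-written attention with torch's fused kernel; a minimal equivalence sketch (shapes illustrative):

import torch
import torch.nn.functional as F

# q, k, v: (batch, heads, elements, head_dim)
q = torch.randn(1, 4, 256, 16)
k = torch.randn(1, 4, 256, 16)
v = torch.randn(1, 4, 256, 16)

# naive attention: materializes the full (elements x elements) score matrix
scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
out_naive = scores.softmax(dim=-1) @ v

# fused scaled-dot-product attention: same math, dispatched to a fast kernel
out_sdpa = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_naive, out_sdpa, atol=1e-5))

The fused op also maps onto ONNX's attention patterns, which is presumably what the "MultiheadAttention op runs" bullet in the commit message refers to.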
2 changes: 1 addition & 1 deletion mlpf/pyg/utils.py
@@ -297,7 +297,7 @@ def get_lr_schedule(config, opt, epochs=None, steps_per_epoch=None, last_epoch=-


 def count_parameters(model):
-    column_names = ["Modules", "Trainable parameters", "Non-tranable parameters"]
+    column_names = ["Modules", "Trainable parameters", "Non-trainable parameters"]
     table = pd.DataFrame(columns=column_names)
     trainable_params = 0
     nontrainable_params = 0
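For context, the counting logic behind those columns reduces to the following sketch (the repo's version additionally fills the per-module pandas table shown above):

import torch.nn as nn


def count_parameters_simple(model: nn.Module):
    # split parameter counts by whether gradients flow through them
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    nontrainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, nontrainable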
76 changes: 46 additions & 30 deletions mlpf/timing.py
@@ -1,44 +1,37 @@
-import sys
 import time

 import numpy as np
 import onnxruntime as rt
-import pynvml
 import resource
 import argparse


 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--bin-size", type=int, default=256)
-    parser.add_argument("--num-features", type=int, default=17)
-    parser.add_argument("--batch-size", type=int, default=20)
+    parser.add_argument("--num-features", type=int, default=55)
+    parser.add_argument("--batch-size", type=int, default=1)
     parser.add_argument("--num-threads", type=int, default=1)
-    parser.add_argument("--use-gpu", type=bool, action="store_true")
+    parser.add_argument("--input-dtype", type=str, default="float32")
+    parser.add_argument("--use-gpu", action="store_true")
+    parser.add_argument("--model", type=str, default="test.onnx")
+    parser.add_argument(
+        "--execution-provider",
+        type=str,
+        default="CPUExecutionProvider",
+        choices=["CPUExecutionProvider", "CUDAExecutionProvider", "OpenVINOExecutionProvider"],
+    )
     args = parser.parse_args()
     return args


-# for GPU testing, you need to
-# pip install only onnxruntime_gpu, not onnxruntime!
-args = parse_args()
-
-bin_size = args.bin_size
-num_features = args.num_features
-use_gpu = args.use_gpu
-batch_size = args.batch_size
-num_threads = args.num_threads
-
-if use_gpu:
-    pynvml.nvmlInit()
-    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-
-
 def get_mem_cpu_mb():
     return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000


 def get_mem_gpu_mb():
+    import pynvml
+
     mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
     return mem.used / 1000 / 1000

@@ -51,16 +44,29 @@ def get_mem_mb(use_gpu):


 if __name__ == "__main__":
+    # for GPU testing, you need to
+    # pip install only onnxruntime_gpu, not onnxruntime!
+    args = parse_args()
+
+    bin_size = args.bin_size
+    num_features = args.num_features
+    use_gpu = args.execution_provider == "CUDAExecutionProvider"
+    batch_size = args.batch_size
+    num_threads = args.num_threads
+
+    if use_gpu:
+        import pynvml
+
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+
     print(
         "batch_size={} bin_size={} num_features={} use_gpu={} num_threads={}".format(
             batch_size, bin_size, num_features, use_gpu, num_threads
         )
     )

-    if use_gpu:
-        EP_list = ["CUDAExecutionProvider"]
-    else:
-        EP_list = ["CPUExecutionProvider"]
+    EP_list = [args.execution_provider]

     time.sleep(5)

@@ -74,25 +80,35 @@ def get_mem_mb(use_gpu):
     sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
     sess_options.add_session_config_entry("session.intra_op.allow_spinning", "1")

-    onnx_sess = rt.InferenceSession(sys.argv[1], sess_options, providers=EP_list)
-    time.sleep(5)
+    onnx_sess = rt.InferenceSession(args.model, sess_options, providers=EP_list)
+    # warmup

     mem_onnx = get_mem_mb(use_gpu)
     print("mem_onnx", mem_onnx)

-    for num_elems in [bin_size, 2 * bin_size, 10 * bin_size, 20 * bin_size, 40 * bin_size]:
+    X = np.array(np.random.randn(batch_size, bin_size, num_features), getattr(np, args.input_dtype))
+    for i in range(10):
+        onnx_sess.run(None, {"Xfeat_normed": X, "mask": X[..., 0] != 0})
+
+    for bin_mul in [
+        10,
+        20,
+        40,
+    ]:
+        num_elems = bin_size * bin_mul
         times = []
         mem_used = []

         # average over 100 events
-        for i in range(10):
+        for i in range(100):

             # allocate array in system memory
-            X = np.array(np.random.randn(batch_size, num_elems, num_features), np.float32)
+            X = np.array(np.random.randn(batch_size, num_elems, num_features), getattr(np, args.input_dtype))

             # transfer data to GPU, run model, transfer data back
             t0 = time.time()
-            pred_onx = onnx_sess.run(None, {"x:0": X})
+            # pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "l_mask_": X[..., 0]==0})
+            pred_onx = onnx_sess.run(None, {"Xfeat_normed": X, "mask": X[..., 0] != 0})
             t1 = time.time()
             dt = (t1 - t0) / batch_size
             times.append(dt)
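The new --input-dtype flag ties into the "fp16 quantization sort of works" bullet above: once the exported graph is converted to float16, its inputs expect float16 arrays. A minimal conversion sketch, assuming the onnxconverter-common package (which is not part of this commit):

import onnx
from onnxconverter_common import float16

model = onnx.load("test.onnx")
# weights, activations, and (by default) graph inputs/outputs become float16,
# so the benchmark is then run with --input-dtype float16
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, "test_fp16.onnx")

A benchmark run would then look like python mlpf/timing.py --model test_fp16.onnx --input-dtype float16.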
2 changes: 1 addition & 1 deletion notebooks/clic/mlpf-pytorch-transformer-standalone.ipynb
@@ -1605,7 +1605,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.14"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion notebooks/cms/cms-mlpf.ipynb
@@ -2580,7 +2580,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.14"
}
},
"nbformat": 4,
(diffs for the remaining 3 changed files not loaded)
