Integrate new pytorch attention model in CMSSW via ONNX #216
ONNX export from pytorch currently doesn't work because the MLPF forward function expects pytorch-geometric style inputs; the padding is done internally if an attention/gnn-lsh based model is used.
We need to break out the 3D padded forward function into a different kind of model.
Actually nevermind, changing the forward function as follows worked:
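As a rough, hypothetical sketch of that kind of change (placeholder names, not the actual MLPF code): expose a forward() that takes already-padded dense tensors plus a mask instead of pytorch-geometric inputs, so that torch.onnx.export can trace it.

```python
import torch

class PaddedForwardWrapper(torch.nn.Module):
    # Hypothetical wrapper: forwards already-padded 3D tensors to the network.
    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net

    def forward(self, features: torch.Tensor, mask: torch.Tensor):
        # features: (batch, num_elements, num_features), already padded
        # mask: (batch, num_elements), True for real (non-padded) elements
        return self.net(features, mask)

# usage (hypothetical): onnx_ready = PaddedForwardWrapper(mlpf_model)
```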
ONNX export now works for the GNN-LSH since #215.
Here's some material on how to integrate pytorch models directly via torchscript, rather than via ONNX.
A couple of notes:
Here's the summary of today. It's possible to export the model (both quantized and unquantized) with dynamic shapes, but this results in somewhat slow runtimes and large memory usage.
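For reference, a dynamic-shape export via torch.onnx.export typically looks like the sketch below; the toy model, input names, and opset version are assumptions, not the actual MLPF export call.

```python
import torch

class ToyModel(torch.nn.Module):
    # stand-in for the real network: zeroes out padded elements
    def forward(self, features, mask):
        return features * mask.unsqueeze(-1)

model = ToyModel().eval()
features = torch.randn(1, 256, 17)   # (batch, num_elements, num_features)
mask = torch.ones(1, 256)

torch.onnx.export(
    model,
    (features, mask),
    "model_dynamic.onnx",
    opset_version=17,
    input_names=["features", "mask"],
    output_names=["preds"],
    dynamic_axes={
        "features": {0: "batch", 1: "num_elements"},
        "mask": {0: "batch", 1: "num_elements"},
        "preds": {0: "batch", 1: "num_elements"},
    },
)
```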
There is a special MultiHeadAttention op in ONNX contrib, but so far, I don't know how to convince torch / onnxscript to switch to it.
Here's a potential example of how to write the model by hand using onnxscript: microsoft/onnxruntime#19924 (comment)
With this code:

```python
import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy

dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}

class Model(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )

model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
# torch.onnx.export(
#     model,
#     (query_states, key_states, value_states, mask),
#     "sdpa.onnx",
#     verbose=True,
#     opset_version=14,
#     input_names=["query_states", "key_states", "value_states", "mask"],
# )

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"
unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset13
sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()
ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print("Benchmarking PT sdpa and ORT MultiHeadAttention...")

def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)
    total_time = 0
    for _ in range(1000):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time

total_time = run_pt()
print("PT eager:")
print(f"Total time: {total_time:.2f}s")

def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)
    query_states = op.Reshape(op.Transpose(query_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0, 2, 1, 3])
    return output

def unfused_onnx_model(query_states, key_states, value_states, mask):
    scale = op.Constant(value_float=sqrt_head)
    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2])) / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output

def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset13, default_opset=onnxscript.opset13
    )(model_func).to_model_proto()
    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )
    onnx.save(model_proto, model_path)
    return model_proto, model_path

def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper
    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())

def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)
    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))

def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)
    sess = onnxruntime.InferenceSession(model_path)
    ort_outputs = sess.run(None, ort_inputs)
    # Parity
    torch.testing.assert_close(
        torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    )
    # Serialize inputs and outputs
    serialize_inputs_outputs(model_dir, ort_inputs, ort_outputs)
    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs)
    total_time = 0
    for _ in range(10):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs)
        total_time += time.perf_counter() - start_time
    print(f"ORT {model_name}")
    print(f"Total time: {total_time:.2f}s")

run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)
```

It seems like, at least on an M2 CPU, the regular ONNX unfused scaled_dot_product_attention is the fastest for a sequence length of 4096 elements:
It looks as if the onnxruntime.transformers optimizer, specifically FusionAttention, should replace the attention block with MultiHeadAttention, which on GPU should support flash attention.
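A sketch of how that optimizer might be invoked; the file names, head count, and hidden size are placeholders matching the benchmark shapes above, not a verified recipe for the MLPF model.

```python
from onnxruntime.transformers import optimizer

# model_type="bert" selects the attention fusion rules (FusionAttention and friends)
optimized = optimizer.optimize_model(
    "sdpa.onnx",
    model_type="bert",
    num_heads=32,
    hidden_size=2048,
)
optimized.save_model_to_file("sdpa_fused.onnx")
```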
Converting the model with the fused attention layer with the following code:

```python
import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

num_iter = 1000

def get_mem_gpu_mb():
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return mem.used / 1000 / 1000

dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}

class Model(torch.nn.Module):
    def forward(self, query_states, key_states, value_states, mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )

model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
torch.onnx.export(
    model,
    (query_states, key_states, value_states, mask),
    "sdpa.onnx",
    verbose=True,
    opset_version=18,
    input_names=["query_states", "key_states", "value_states", "mask"],
)

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"
unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset18
sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()
ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print("Benchmarking PT sdpa and ORT MultiHeadAttention...")

def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)
    total_time = 0
    for _ in range(num_iter):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time

total_time = run_pt()
print("PT eager:")
print(f"Total time: {total_time:.2f}s")

def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)
    query_states = op.Reshape(op.Transpose(query_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), shape=[1, 4096, 2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0, 2, 1, 3])
    return output

def unfused_onnx_model(query_states, key_states, value_states, mask):
    scale = op.Constant(value_float=sqrt_head)
    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2]))  # / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output

def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset18, default_opset=onnxscript.opset18
    )(model_func).to_model_proto()
    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )
    onnx.save(model_proto, model_path)
    return model_proto, model_path

def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper
    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())

def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)
    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))

def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)
    from onnxconverter_common import float16
    model = onnx.load(model_path)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, model_path)
    sess = onnxruntime.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
    ort_inputs_fp16 = {k: float16.convert_np_to_float16(v) for k, v in ort_inputs.items()}
    ort_outputs = sess.run(None, ort_inputs_fp16)
    # Parity
    # torch.testing.assert_close(
    #     torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    # )
    # Serialize inputs and outputs
    serialize_inputs_outputs(model_dir, ort_inputs_fp16, ort_outputs)
    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs_fp16)
    total_time = 0
    mem_used = []
    for _ in range(num_iter):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs_fp16)
        mem_used.append(get_mem_gpu_mb())
        total_time += time.perf_counter() - start_time
    print(f"ORT {model_name}")
    max_mem = numpy.max(mem_used)
    print(f"Total time: {total_time:.2f}s, mem: {max_mem:.0f}MB")

run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)
```

gives:
In 7a301da I got the MultiHeadAttention op spliced into the graph using the onnxscript and torch.onnx.export (TorchScript) approach.
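In general terms, the TorchScript export route can be pointed at the contrib op with a custom symbolic; the sketch below is a generic illustration, not the code from 7a301da. The head count, head dimension, and opset are assumptions, and the aten signature varies between torch versions.

```python
import torch
import torch.onnx

NUM_HEADS = 32   # assumed, matching the benchmark shapes above
HEAD_DIM = 64

def sdpa_to_mha(g, query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, *_):
    # (B, num_heads, N, head_dim) -> (B, N, num_heads*head_dim), the layout
    # expected by com.microsoft::MultiHeadAttention
    def to_bnh(x):
        x = g.op("Transpose", x, perm_i=[0, 2, 1, 3])
        shape = g.op("Constant", value_t=torch.tensor([0, 0, NUM_HEADS * HEAD_DIM], dtype=torch.int64))
        return g.op("Reshape", x, shape)

    out = g.op("com.microsoft::MultiHeadAttention",
               to_bnh(query), to_bnh(key), to_bnh(value),
               num_heads_i=NUM_HEADS)
    # back to (B, num_heads, N, head_dim), the layout sdpa would have returned
    shape = g.op("Constant", value_t=torch.tensor([0, 0, NUM_HEADS, HEAD_DIM], dtype=torch.int64))
    out = g.op("Reshape", out, shape)
    return g.op("Transpose", out, perm_i=[0, 2, 1, 3])

torch.onnx.register_custom_op_symbolic(
    "aten::scaled_dot_product_attention", sdpa_to_mha, 17
)
# torch.onnx.export(..., custom_opsets={"com.microsoft": 1}) can then emit the node.
```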
In e40f5c3 the ONNX export now works with dynamic shapes, f32/fp16, using com.microsoft.MultiHeadAttention (which can use Flash Attention on GPU), and the pytorch and onnx versions return the same values.
One self-attention block looks something like this. The trick was that pytorch needs ...
Actually ONLY the MultiHeadAttention op needs to run in fp16:
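One way this could be done (a hypothetical helper, not the actual change): keep the graph in fp32 and splice Cast nodes around each MultiHeadAttention node so that only the attention itself runs in fp16.

```python
import onnx
from onnx import helper, TensorProto

def run_mha_in_fp16(path_in, path_out):
    # Hypothetical helper: wrap each com.microsoft MultiHeadAttention node with
    # Cast(fp32->fp16) on its inputs and Cast(fp16->fp32) on its output, leaving
    # the rest of the graph in fp32.
    model = onnx.load(path_in)
    new_nodes = []
    for node in model.graph.node:
        if node.op_type != "MultiHeadAttention":
            new_nodes.append(node)
            continue
        fp16_inputs = []
        for i, name in enumerate(node.input):
            if name == "":  # optional input left empty
                fp16_inputs.append(name)
                continue
            casted = f"{name}_fp16_in_{i}"
            new_nodes.append(helper.make_node("Cast", [name], [casted], to=TensorProto.FLOAT16))
            fp16_inputs.append(casted)
        fp16_out = node.output[0] + "_fp16_out"
        mha = helper.make_node("MultiHeadAttention", fp16_inputs, [fp16_out],
                               name=node.name, domain=node.domain)
        mha.attribute.extend(node.attribute)
        new_nodes.append(mha)
        # cast the attention output back to fp32 for the rest of the graph
        new_nodes.append(helper.make_node("Cast", [fp16_out], [node.output[0]], to=TensorProto.FLOAT))
    del model.graph.node[:]
    model.graph.node.extend(new_nodes)
    onnx.save(model, path_out)
```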
Then the outputs are basically equivalent to the base model (at least on CPU).
Importing the new model in CMSSW is still to do.
The required changes on the CMSSW side to import the new ONNX model are here: jpata/cmssw@3d5455b. It runs and produces nonzero, non-garbage outputs. Submitted jobs on CPU; will see the validations soon. A couple of tracks are being produced, and a couple of neutrals.
Here I managed to make the CMSSW ONNX GPU inference work, I think:
All done, moved to CMSSW_14 and updated the C++ inference code.
attention (also gnnlsh, but less important)