
[Performance] MultiHeadAttention CPU kernel slower than unfused #19924

Open
BowenBao opened this issue Mar 15, 2024 · 4 comments
Labels
performance (issues related to performance regressions) · stale (issues that have not been addressed in a while; categorized by a bot)

Comments

@BowenBao
Contributor

BowenBao commented Mar 15, 2024

Describe the issue

As titled. The repro script below runs a quick benchmark and also saves the ONNX model files to disk for further analysis.

Benchmarking PT sdpa and ORT MultiHeadAttention...
PT eager:
Total time: 0.26s
ORT unfused_multihead_attention.onnx
Total time: 0.26s
'MultiHeadAttention' is not a known op in 'com.microsoft'  # Please ignore, irrelevant warning from onnxscript
ORT multihead_attention.onnx
Total time: 0.88s

To reproduce

import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math

IS_CAUSAL = False


class Model(torch.nn.Module):
    def forward(
        self, query_states, key_states, value_states, mask, past_key, past_value
    ):
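        # Reference PyTorch path: reshape Q/K/V to (batch, num_heads, seq_len, head_dim),
        # append this step's key/value to the cache, then run scaled_dot_product_attention.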
        query_states = query_states.view(1, 1, 32, 128).transpose(1, 2)
        key_states = key_states.view(1, 1, 32, 128).transpose(1, 2)
        value_states = value_states.view(1, 1, 32, 128).transpose(1, 2)
        key_states = torch.cat([past_key, key_states], dim=2)
        value_states = torch.cat([past_value, value_states], dim=2)
        return (
            torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=IS_CAUSAL,
            )
            .transpose(1, 2)
            .reshape(1, 1, -1)
        )


model = Model()
model.eval()
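
# Single-token decode step: hidden size 4096 = 32 heads x 128 head_dim, with a
# past KV cache of length 512, so the total attended length is 513 (matching the mask).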

query_states = torch.randn(1, 1, 4096)
key_states = torch.randn(1, 1, 4096)
value_states = torch.randn(1, 1, 4096)
mask = torch.randn(1, 32, 1, 513)
past_key = torch.randn(1, 32, 512, 128)
past_value = torch.randn(1, 32, 512, 128)

torch_out = model(query_states, key_states, value_states, mask, past_key, past_value)

# Another reference perf comparison.
# torch.onnx.export(
#     model,
#     (query_states, key_states, value_states, mask, past_key, past_value),
#     "sdpa.onnx",
#     verbose=True,
#     opset_version=14,
#     input_names=["query_states", "key_states", "value_states", "mask", "past_key", "past_value"],
# )

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"

unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"

pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset13

sqrt_head = math.sqrt(128)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()
past_key_ort = past_key.numpy()
past_value_ort = past_value.numpy()
ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
    "past_key": past_key_ort,
    "past_value": past_value_ort,
}

print(f"Benchmarking PT sdpa and ORT MultiHeadAttention...")


def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask, past_key, past_value)

    total_time = 0
    for _ in range(1000):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask, past_key, past_value)
        total_time += time.perf_counter() - start_time
    return total_time


total_time = run_pt()

print("PT eager:")
print(f"Total time: {total_time:.2f}s")


def mha_onnx_model(query_states, key_states, value_states, mask, past_key, past_value):
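    # Fused path: a single com.microsoft MultiHeadAttention node consumes the raw
    # Q/K/V projections plus the past KV cache; the two None inputs are, to my
    # reading of the contrib op's input order, the optional bias and key_padding_mask.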
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        None,
        None,
        mask,
        past_key,
        past_value,
        num_heads=32,
    )
    return output


def unfused_onnx_model(
    query_states, key_states, value_states, mask, past_key, past_value
):
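    # Unfused baseline: the same attention expressed with primitive ONNX ops
    # (Reshape/Transpose/Concat/MatMul/Softmax), mirroring the PyTorch SDPA path above.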
    query_states = op.Transpose(
        op.Reshape(query_states, shape=[1, 1, 32, 128]), perm=[0, 2, 1, 3]
    )
    key_states = op.Transpose(
        op.Reshape(key_states, shape=[1, 1, 32, 128]), perm=[0, 2, 1, 3]
    )
    value_states = op.Transpose(
        op.Reshape(value_states, shape=[1, 1, 32, 128]), perm=[0, 2, 1, 3]
    )

    key_states = op.Concat(past_key, key_states, axis=2)
    value_states = op.Concat(past_value, value_states, axis=2)
    scale = op.Constant(value_float=sqrt_head)

    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2])) / scale
    attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    attn_output = op.Reshape(attn_output, shape=[1, 1, 4096])
    # present_key = op.Concat(past_key, key_states, axis=2)
    # present_value = op.Concat(past_value, value_states, axis=2)
    return attn_output


def serialize_model(model_func, model_name, ort_inputs):
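    # Convert the onnxscript function into a ModelProto, then pin concrete float32
    # input/output shapes before saving so the exported graph is fully static.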
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset13, default_opset=onnxscript.opset13
    )(model_func).to_model_proto()

    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=onnx.TensorProto.FLOAT,
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 1, 4096],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )

    onnx.save(model_proto, model_path)
    return model_proto, model_path


def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper

    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())


def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)

    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))

    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))


def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)

    # Serialize inputs and outputs
    sess = onnxruntime.InferenceSession(model_path)
    ort_outputs = sess.run(None, ort_inputs)

    # Parity
    torch.testing.assert_close(
        torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    )

    serialize_inputs_outputs(model_dir, ort_inputs, ort_outputs)

    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs)

    total_time = 0
    for _ in range(1000):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs)
        total_time += time.perf_counter() - start_time

    print(f"ORT {model_name}")
    print(f"Total time: {total_time:.2f}s")


run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)
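
# Optional extension (not part of the original repro; a sketch only): the fused
# kernel's behavior depends on ORT's intra-op thread pool, so pinning the thread
# count can make comparisons across machines more controlled. SessionOptions and
# intra_op_num_threads are standard onnxruntime APIs; the value 16 is illustrative.
so = onnxruntime.SessionOptions()
so.intra_op_num_threads = 16
sess_fused = onnxruntime.InferenceSession(fused_model_path, sess_options=so)
sess_fused.run(None, ort_inputs)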

Urgency

Negatively affecting LLM inference on CPU with ORT.

Platform

Linux

OS Version

20.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

33578cc

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No

@pranavsharma added the performance (issues related to performance regressions) label Mar 15, 2024
@yihonglyu
Contributor

yihonglyu commented Mar 16, 2024

@BowenBao Which version of PyTorch are you currently using?

@BowenBao
Contributor Author

BowenBao commented Mar 18, 2024

PyTorch is 2.2.1+cpu. onnxscript and onnx are also the most recent versions.

(Updated the PyTorch version; I made a mistake previously.)

@BowenBao
Contributor Author

The CPU for the repro was an Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz on an Azure Standard F64s v2 VM.

yufenglee added a commit that referenced this issue Apr 2, 2024
### Description
The cost computation of ComputeVxAttentionScore is wrong. It should be
sequence_length * v_head_size * total_sequence_length instead of
sequence_length * v_head_size * sequence_length.
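
To make this concrete with the shapes from the repro above (sequence_length = 1, v_head_size = 128, total_sequence_length = 512 + 1 = 513), a rough back-of-the-envelope check (my own illustration, not part of the PR):

old_cost = 1 * 128 * 1    # sequence_length * v_head_size * sequence_length       = 128
new_cost = 1 * 128 * 513  # sequence_length * v_head_size * total_sequence_length = 65664

The old estimate understates the scores-times-V matmul by roughly 513x, which plausibly made the thread pool treat the fused kernel's work as too small to split across threads.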

The PR also fine-tuned the cost computation.

On my local box with an i9 CPU, the performance is the same as the unfused version, but it is much faster on an Azure VM with 16 threads.

### Motivation and Context

#19924

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions bot added the stale (issues that have not been addressed in a while; categorized by a bot) label Apr 18, 2024
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this issue May 7, 2024