
Integrate new pytorch attention model in CMSSW via ONNX #216

Closed · 3 tasks done
jpata opened this issue Sep 22, 2023 · 19 comments · Fixed by #324
Assignees: jpata
Labels: CMS (Concerns the CMS MLPF model), hard

Comments


jpata commented Sep 22, 2023

  • ONNX export of the model from pytorch: now works for attention (also for GNN-LSH, but that is less important)
  • check the compatibility of the results between ONNX and pytorch inference outputs (a parity-check sketch is given below)
  • re-check the import of the new ONNX model in CMSSW (the onnxruntime version may be a problem, and CMSSW may need to be updated)
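
A minimal sketch of such a parity check, assuming the exported file is named test_fp32.onnx and that the model takes a single padded feature tensor plus a mask; the input names and shapes here are placeholders, not necessarily the real MLPF ones:

import numpy as np
import torch
import onnxruntime

# hypothetical setup: "model" stands for the trained pytorch MLPF model,
# "test_fp32.onnx" for the exported file, and "Xfeat_normed"/"mask" for the
# ONNX input names; the shapes are placeholders as well
model.eval()
dummy_features = torch.randn(1, 256, 55)           # (batch, padded elements, features)
dummy_mask = torch.ones(1, 256, dtype=torch.bool)

with torch.no_grad():
    torch_out = model(dummy_features, dummy_mask)   # assumed to return a single tensor here

sess = onnxruntime.InferenceSession("test_fp32.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"Xfeat_normed": dummy_features.numpy(),
                           "mask": dummy_mask.numpy()})

# element-wise agreement within a loose tolerance; with multiple outputs,
# each pytorch/ONNX output pair should be compared in turn
np.testing.assert_allclose(torch_out.numpy(), onnx_out[0], rtol=1e-3, atol=1e-3)
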
jpata self-assigned this Sep 22, 2023

jpata commented Sep 22, 2023

ONNX export from pytorch currently doesn't work because the MLPF forward function expects pytorch-geometric-style inputs; the padding is done internally when an attention- or GNN-LSH-based model is used.

    def forward(self, batch):

        # unfold the Batch object
        if self.ssl:
            input_ = batch.x.float()[:, : self.input_dim]
            VICReg_embeddings = batch.x.float()[:, self.input_dim :]
        else:
            input_ = batch.x.float()

        batch_idx = batch.batch

        embeddings_id = []
        embeddings_reg = []

We need to break the 3D padded forward function out into a different kind of model.


jpata commented Sep 22, 2023

Actually, never mind; changing the forward function as follows worked:

    def forward(self, element_features, batch_idx):

        # unfold the Batch object
        if self.ssl:
            input_ = element_features.float()[:, : self.input_dim]
            VICReg_embeddings = element_features.float()[:, self.input_dim :]
        else:
            input_ = element_features.float()

        embeddings_id = []
        embeddings_reg = []


jpata commented Sep 27, 2023

ONNX export now works for the GNN-LSH since #215.
Need to train a model with pytorch and test that the import in CMSSW gives reasonable results.

jpata changed the title Revisit ONNX export → ONNX export in pytorch Sep 29, 2023
jpata changed the title ONNX export in pytorch → implement ONNX export in pytorch Sep 29, 2023
jpata changed the title implement ONNX export in pytorch → pytorch model in CMSSW Apr 11, 2024
jpata added the hard and CMS (Concerns the CMS MLPF model) labels Apr 11, 2024

jpata commented Apr 29, 2024

https://indico.cern.ch/event/1388888/contributions/5839133/attachments/2821898/4928058/2024_03_18%20ML%20production%20news.pdf

Here's some material on how to integrate pytorch models directly via TorchScript, rather than via ONNX.
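
For reference, a minimal TorchScript export along those lines could look like the following sketch; the model and example inputs are placeholders, not the actual MLPF code, and the file name is arbitrary:

import torch

# hypothetical: "model" is an eval-mode pytorch module taking a padded feature
# tensor and a mask; tracing records the ops executed for these example inputs
# and saves a self-contained TorchScript archive that can be loaded from C++
example_features = torch.randn(1, 256, 55)
example_mask = torch.ones(1, 256, dtype=torch.bool)

traced = torch.jit.trace(model, (example_features, example_mask))
traced.save("mlpf_traced.pt")

# sanity check: the traced module should reproduce the eager outputs
with torch.no_grad():
    torch.testing.assert_close(model(example_features, example_mask),
                               traced(example_features, example_mask))
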

jpata changed the title pytorch model in CMSSW → Integrate new pytorhch attention model in CMSSW via ONNX May 16, 2024
jpata changed the title Integrate new pytorhch attention model in CMSSW via ONNX → Integrate new pytorch attention model in CMSSW via ONNX May 16, 2024

jpata commented May 24, 2024

Couple of notes:


jpata commented May 25, 2024

Here's the summary of today.

It's possible to export the model (both quantized and unquantized) with dynamic shapes using torch.onnx.export in #324.
However, scaled_dot_product_attention is exported as the inefficient, fully unrolled attention implementation (i.e. the naive/math version), so one attention layer looks something like this:
[screenshot]

This results in somewhat slow runtimes and large memory usage:

timing/cpu_fp32.txt:Nelem=5120 mean_time=17029.90 ms stddev_time=126.80 ms mem_used=16672 MB
timing/gpu_fp16.txt:Nelem=5120 mean_time=85.91 ms stddev_time=10.96 ms mem_used=22884 MB
timing/gpu_fp32.txt:Nelem=5120 mean_time=134.03 ms stddev_time=20.34 ms mem_used=45426 MB
timing/gpu_int8.txt:Nelem=5120 mean_time=144.67 ms stddev_time=20.11 ms mem_used=45426 MB
timing/openvino_fp16.txt:Nelem=5120 mean_time=30045.50 ms stddev_time=2774.41 ms mem_used=31867 MB
timing/openvino_fp32.txt:Nelem=5120 mean_time=14351.32 ms stddev_time=642.40 ms mem_used=31592 MB
timing/openvino_int8.txt:Nelem=5120 mean_time=15503.07 ms stddev_time=60.57 ms mem_used=16661 MB

There is a special MultiHeadAttention op among the onnxruntime contrib operators, but so far I don't know how to convince torch / onnxscript to switch to it:
https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MultiHeadAttention


jpata commented May 25, 2024

Here's a potential example of how to write the model by hand using onnxscript: microsoft/onnxruntime#19924 (comment)


jpata commented May 25, 2024

Here's how the unfused vs. fused MHA looks, based on the example above:
[screenshot]
[screenshot]


jpata commented May 25, 2024

With this code:

import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy


dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}

class Model(torch.nn.Module):
    def forward(
        self, query_states, key_states, value_states, mask
    ):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )


model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
# torch.onnx.export(
#     model,
#     (query_states, key_states, value_states, mask),
#     "sdpa.onnx",
#     verbose=True,
#     opset_version=14,
#     input_names=["query_states", "key_states", "value_states", "mask"],
# )

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"

unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"

pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset13

sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()

ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print(f"Benchmarking PT sdpa and ORT MultiHeadAttention...")


def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)

    total_time = 0
    for _ in range(1000):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time


total_time = run_pt()

print(
    f"PT eager:"
)
print(f"Total time: {total_time:.2f}s")


def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)

    query_states = op.Reshape(op.Transpose(query_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0,2,1,3])
    return output


def unfused_onnx_model(
    query_states, key_states, value_states, mask
):

    scale = op.Constant(value_float=sqrt_head)

    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2])) / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output


def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset13, default_opset=onnxscript.opset13
    )(model_func).to_model_proto()


    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )

    onnx.save(model_proto, model_path)
    return model_proto, model_path


def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper

    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())


def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)

    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))

    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))


def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)

    # Serialize inputs and outputs
    sess = onnxruntime.InferenceSession(model_path)
    ort_outputs = sess.run(None, ort_inputs)

    # Parity
    torch.testing.assert_close(
        torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    )

    serialize_inputs_outputs(model_dir, ort_inputs, ort_outputs)

    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs)

    total_time = 0
    for _ in range(10):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs)
        total_time += time.perf_counter() - start_time

    print(
        f"ORT {model_name}"
    )
    print(f"Total time: {total_time:.2f}s")


run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)

It seems that, at least on an M2 CPU, the regular ONNX unfused scaled_dot_product_attention is fastest for a sequence length of 4096 elements:

Benchmarking PT sdpa and ORT MultiHeadAttention...
PT eager:
Total time: 660.70s
ORT unfused_multihead_attention.onnx
Total time: 11.15s
'MultiHeadAttention' is not a known op in 'com.microsoft'
ORT multihead_attention.onnx
Total time: 18.02s

pytorch sdpa: [screenshot]
onnx unfused: [screenshot]
onnx fused: [screenshot]


jpata commented May 25, 2024

It looks as if the onnxruntime.transformers optimizer, specifically FusionAttention:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_attention.py#L712

should replace the attention block with MultiHeadAttention, which on GPU should support flash attention.
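
A possible way to invoke that optimizer on an exported graph is sketched below; the file names are placeholders, num_heads=32 and hidden_size=2048 match the layout used in the benchmarks above, and whether FusionAttention actually matches the MLPF attention pattern is exactly what would need to be checked:

from onnxruntime.transformers import optimizer

# run the transformer fusion rules (including FusionAttention) over the graph;
# model_type="bert" selects the generic encoder-style fusions
opt_model = optimizer.optimize_model(
    "test_fp32.onnx",         # placeholder input path
    model_type="bert",
    num_heads=32,
    hidden_size=2048,
)
opt_model.save_model_to_file("test_fp32_fused.onnx")
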


jpata commented May 25, 2024

If I can replace this part (SDPA only):
[screenshot]

with this (ignore the shapes):
[image]

then in principle it should be possible to try flash attention on the ONNX model.


jpata commented May 26, 2024

Converting the model with the fused attention layer com.microsoft.MultiHeadAttention to float16 does run flash attention on an A100, with the expected speed and memory improvement.
The following code uses batch size 1, sequence length 4096, num_heads 32, head_dim 64.

import torch
import time
import onnxruntime
import pathlib
import onnxscript
import onnx
import math
import numpy
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

num_iter = 1000
def get_mem_gpu_mb():
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return mem.used / 1000 / 1000

dtype_map = {
    numpy.dtype("float32"): onnx.TensorProto.FLOAT,
    numpy.dtype("bool"): onnx.TensorProto.BOOL,
}

class Model(torch.nn.Module):
    def forward(
        self, query_states, key_states, value_states, mask
    ):
        return torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )


model = Model()
model.eval()

# (B, num_heads, N, head_dim)
query_states = torch.randn(1, 32, 4096, 64)
key_states = torch.randn(1, 32, 4096, 64)
value_states = torch.randn(1, 32, 4096, 64)
mask = torch.randn(1, 32, 4096, 1)

torch_out = model(query_states, key_states, value_states, mask)
print(torch_out.shape)
print(torch_out)

# Another reference perf comparison.
torch.onnx.export(
    model,
    (query_states, key_states, value_states, mask),
    "sdpa.onnx",
    verbose=True,
    opset_version=18,
    input_names=["query_states", "key_states", "value_states", "mask"],
)

model_dir = "multihead_attention"
fused_model_name = "multihead_attention.onnx"
fused_model_path = f"{model_dir}/{fused_model_name}"

unfused_model_name = "unfused_multihead_attention.onnx"
unfused_model_path = f"{model_dir}/{unfused_model_name}"

pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = onnxscript.opset18

sqrt_head = math.sqrt(64)

query_states_ort = query_states.numpy()
key_states_ort = key_states.numpy()
value_states_ort = value_states.numpy()
attention_mask_ort = mask.numpy()

ort_inputs = {
    "query_states": query_states_ort,
    "key_states": key_states_ort,
    "value_states": value_states_ort,
    "mask": attention_mask_ort,
}

print(f"Benchmarking PT sdpa and ORT MultiHeadAttention...")


def run_pt():
    # warmup
    for _ in range(30):
        model(query_states, key_states, value_states, mask)

    total_time = 0
    for _ in range(num_iter):
        start_time = time.perf_counter()
        model(query_states, key_states, value_states, mask)
        total_time += time.perf_counter() - start_time
    return total_time


total_time = run_pt()

print(
    f"PT eager:"
)
print(f"Total time: {total_time:.2f}s")


def mha_onnx_model(query_states, key_states, value_states, mask):
    # query_states = op.Reshape(query_states, shape=[1, 32, 128, 1, 64])
    # key_states = op.Reshape(key_states, shape=[1, 32, 128, 1, 64])
    # value_states = op.Reshape(value_states, shape=[1, 32, 128, 1, 64])
    # qkv = op.Concat(query_states, key_states, value_states, axis=3)

    query_states = op.Reshape(op.Transpose(query_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    key_states = op.Reshape(op.Transpose(key_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    value_states = op.Reshape(op.Transpose(value_states, perm=[0,2,1,3]), shape=[1,4096,2048])
    output, _, _ = msft_op.MultiHeadAttention(
        query_states,
        key_states,
        value_states,
        num_heads=32,
    )
    output = op.Reshape(output, shape=[1, 4096, 32, 64])
    output = op.Transpose(output, perm=[0,2,1,3])
    return output


def unfused_onnx_model(
    query_states, key_states, value_states, mask
):

    scale = op.Constant(value_float=sqrt_head)

    attn_weights = op.MatMul(query_states, op.Transpose(key_states, perm=[0, 1, 3, 2])) # / scale
    # attn_weights = op.Add(attn_weights, mask)
    attn_weights = op.Softmax(attn_weights, axis=-1)
    attn_output = op.MatMul(attn_weights, value_states)
    return attn_output


def serialize_model(model_func, model_name, ort_inputs):
    model_path = f"{model_dir}/{model_name}"
    model_proto = onnxscript.script(
        onnxscript.opset18, default_opset=onnxscript.opset18
    )(model_func).to_model_proto()


    for i, value in enumerate(ort_inputs.values()):
        model_proto.graph.input[i].type.CopyFrom(
            onnx.helper.make_tensor_type_proto(
                shape=value.shape,
                elem_type=dtype_map[value.dtype],
            )
        )
    model_proto.graph.output[0].type.CopyFrom(
        onnx.helper.make_tensor_type_proto(
            shape=[1, 32, 4096, 64],
            elem_type=onnx.TensorProto.FLOAT,
        )
    )

    onnx.save(model_proto, model_path)
    return model_proto, model_path


def save_tensor_data(numpy_tensor, output_path):
    from onnx import numpy_helper

    proto_tensor = numpy_helper.from_array(numpy_tensor)
    with open(output_path, "wb") as f:
        f.write(proto_tensor.SerializeToString())


def serialize_inputs_outputs(model_dir, onnx_inputs, onnx_outputs):
    test_data_dir = pathlib.Path(f"{model_dir}/test_data_set_0")
    test_data_dir.mkdir(parents=True, exist_ok=True)

    for i, onnx_input in enumerate(onnx_inputs.values()):
        save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))

    for i, onnx_output in enumerate(onnx_outputs):
        save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))


def run_ort(model_func, model_name, ort_inputs):
    # Serialize model
    model_proto, model_path = serialize_model(model_func, model_name, ort_inputs)
    
    from onnxconverter_common import float16
    model = onnx.load(model_path)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, model_path)

    # Serialize inputs and outputs
    sess = onnxruntime.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
    ort_inputs_fp16 = {k: float16.convert_np_to_float16(v) for k, v in ort_inputs.items()}
    ort_outputs = sess.run(None, ort_inputs_fp16)

    # Parity
    # torch.testing.assert_close(
    #     torch_out, torch.tensor(ort_outputs[0]), rtol=1e-3, atol=1e-3
    # )

    serialize_inputs_outputs(model_dir, ort_inputs_fp16, ort_outputs)

    # warmup
    for _ in range(30):
        sess.run(None, ort_inputs_fp16)

    total_time = 0
    mem_used = []
    for _ in range(num_iter):
        start_time = time.perf_counter()
        sess.run(None, ort_inputs_fp16)
        mem_used.append(get_mem_gpu_mb())
        total_time += time.perf_counter() - start_time

    print(
        f"ORT {model_name}"
    )
    max_mem = numpy.max(mem_used)
    print(f"Total time: {total_time:.2f}s, mem: {max_mem:.0f}MB")


run_ort(unfused_onnx_model, unfused_model_name, ort_inputs)
run_ort(mha_onnx_model, fused_model_name, ort_inputs)

gives:

#this is on CPU
PT eager:
Total time: 79.24s

#this is on A100
ORT unfused_multihead_attention.onnx
Total time: 11.96s, mem: 5771MB

ORT multihead_attention.onnx
Total time: 9.60s, mem: 1680MB


jpata commented May 26, 2024

In 7a301da I got the MultiHeadAttention op spliced into the graph using the onnxscript and torch.onnx.export (TorchScript) approach.
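
For the record, the general mechanism behind that kind of splice in the TorchScript exporter is a custom symbolic for aten::scaled_dot_product_attention; the sketch below illustrates the idea only (it is not the code in 7a301da, and it assumes the Q/K/V tensors are already in the (batch, seq_len, num_heads*head_dim) layout that the contrib op expects):

import torch
from torch.onnx import register_custom_op_symbolic

NUM_HEADS = 32  # assumed, matching the benchmarks above

def sdpa_symbolic(g, query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None):
    # emit the fused contrib op instead of the unrolled MatMul/Softmax/MatMul pattern
    output, _, _ = g.op("com.microsoft::MultiHeadAttention",
                        query, key, value,
                        num_heads_i=NUM_HEADS, outputs=3)
    return output

# route every SDPA call through the symbolic above during torch.onnx.export
register_custom_op_symbolic("aten::scaled_dot_product_attention", sdpa_symbolic, 17)

With the symbolic registered, a subsequent torch.onnx.export call should place the com.microsoft node wherever the model uses scaled_dot_product_attention.
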


jpata commented May 26, 2024

In e40f5c3 the ONNX export now works with dynamic shapes and fp32/fp16, using com.microsoft.MultiHeadAttention (which can use Flash Attention on GPU), and the pytorch and ONNX versions return the same values.
The timings and memory usage in fp16 are great now (tested on an A100):

timing/gpu_fp16.txt:Nelem=5120 mean_time=14.26 ms stddev_time=0.45 ms mem_used=1946 MB
timing/gpu_fp32.txt:Nelem=5120 mean_time=122.88 ms stddev_time=6.46 ms mem_used=12274 MB

One self-attention block looks something like this:
[screenshot]

and inside SDPA:
[screenshot]

The trick was that pytorch needs (batch, seq_len, num_heads, head_dim), while com.microsoft.MultiHeadAttention needs (batch, seq_len, num_heads*head_dim).
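
This corresponds to a transpose plus a reshape around the fused op; a small pytorch illustration with the shapes used above (it mirrors the Transpose/Reshape pattern in the onnxscript snippets earlier in the thread):

import torch

B, N, H, D = 1, 4096, 32, 64          # batch, seq_len, num_heads, head_dim (as above)

x_sdpa = torch.randn(B, H, N, D)      # layout passed to torch scaled_dot_product_attention
x_mha = x_sdpa.permute(0, 2, 1, 3).reshape(B, N, H * D)   # layout for com.microsoft.MultiHeadAttention

# and back: MHA output (B, N, H*D) -> per-head layout (B, H, N, D)
y = x_mha.reshape(B, N, H, D).permute(0, 2, 1, 3)
assert torch.equal(x_sdpa, y)
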


jpata commented May 27, 2024

Actually ONLY the MultiHeadAttention op needs to run in fp16:

@onnxscript.script(custom_opset)
def SDPA(
    query: TFloat,
    key: TFloat,
    value: TFloat,
) -> TFloat:

    # Unlike pytorch scaled_dot_product_attention,
    # the input here MUST BE (batch, seq_len, num_head*head_dim).
    # Also, for the op to be fast on GPU, it needs to run in float16.
    query = op.Cast(query, to=onnx.TensorProto.FLOAT16)
    key = op.Cast(key, to=onnx.TensorProto.FLOAT16)
    value = op.Cast(value, to=onnx.TensorProto.FLOAT16)
    output, _, _ = msft_op.MultiHeadAttention(query, key, value, num_heads=NUM_HEADS)
    output = op.Cast(output, to=onnx.TensorProto.FLOAT)

    return output

Then the outputs are basically equivalent to the base model (at least on CPU).

jpata linked a pull request May 27, 2024 that will close this issue
jpata reopened this May 27, 2024

jpata commented May 27, 2024

Importing the new model in CMSSW is still to do.
#323 needs to be merged and some results obtained from it first; then the CMSSW updates can be done on top of https://github.com/jpata/cmssw/releases/tag/pfanalysis_caloparticle_CMSSW_14_1_0_pre3_acat2022.


jpata commented May 27, 2024

The required changes on the CMSSW side to import the new ONNX model are here: jpata/cmssw@3d5455b

It runs and produces nonzero, non-garbage outputs. I submitted jobs on CPU and will see validations soon.

A couple of tracks being produced:

ielem=1203 inputs:0=1 1=0.369013 2=0.492823 3=0.240846 4=0.970563 5=0.41474 6=0 7=0 8=1 9=0 10=0 11=0 12=2.98601 13=-2.15722 14=0 15=0 16=0 17=0.358151 18=0.0888752 19=0.18931 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=8 29=0 30=0 31=0 32=0.0194093 33=0.0901159 34=3.56876 35=0.0072579 36=0.00452945 37=0.00699194 38=0.474006 39=0.00403006 40=1.09679 41=0.00403006 42=0 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0 
ielem=1203 pred: pid=211 E=0.434683 pt=0.365148 eta=0.49956 phi=0.233267 charge=1
ielem=1204 inputs:0=1 1=0.25898 2=1.00609 3=0.542937 4=0.839773 5=0.401488 6=0 7=0 8=1 9=0 10=0 11=0 12=2.90243 13=-0.830946 14=0 15=0 16=0 17=0.217484 18=0.14061 19=0.306793 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=5 29=0 30=0 31=0 32=0.0252961 33=0.0149914 34=1.73247 35=0.00628026 36=0.00790802 37=0.00939261 38=0.869708 39=0.00510107 40=0.701089 41=0.00510107 42=0 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0 
ielem=1204 pred: pid=211 E=0.428164 pt=0.261295 eta=1.00512 phi=0.57336 charge=1

and a couple of neutrals:

ielem=1858 inputs:0=8 1=2.61464 2=3.93908 3=-0.258819 4=-0.965926 5=67.1841 6=11 7=1 8=0 9=0 10=0 11=0 12=0 13=0 14=0 15=0 16=0 17=-42.7738 18=-11.4612 19=1137 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=3 29=0 30=-1 31=-1 32=0 33=0 34=0 35=38.7886 36=0.614548 37=0.705366 38=0 39=0 40=0 41=0 42=7.05472 43=0 44=0.487688 45=0 46=0 47=0 48=0.225962 49=0 50=0 51=0 52=30.1196 53=10.016 54=9.34346e-11 
ielem=1858 pred: pid=2 E=65.0437 pt=2.45356 eta=3.97031 phi=-2.88292 charge=0
ielem=1859 inputs:0=8 1=0.774559 2=3.06393 3=-0.573577 4=0.819152 5=8.31035 6=11 7=1 8=0 9=0 10=0 11=0 12=0 13=0 14=0 15=0 16=0 17=87.1876 18=-61.0494 19=1137 20=0 21=0 22=0 23=0 24=0 25=0 26=0 27=0 28=1 29=0 30=-1 31=-1 32=0 33=0 34=0 35=0 36=0 37=0 38=0 39=0 40=0 41=0 42=8.5 43=0 44=0 45=0 46=0 47=0 48=0 49=0 50=0 51=0 52=0 53=0 54=0 
ielem=1859 pred: pid=2 E=8.56867 pt=0.709752 eta=3.18238 phi=-0.612545 charge=0


jpata commented May 28, 2024

Here I managed to make the CMSSW ONNX GPU inference work, I think:
jpata/cmssw@36be715

CPU PF:
log_cpu_pf.txt:TimeModule> 35002 1 particleFlowTmp PFProducer 0.00893436
log_cpu_pf.txt:TimeModule> 35005 1 particleFlowTmp PFProducer 0.00696006
log_cpu_pf.txt:TimeModule> 35001 1 particleFlowTmp PFProducer 0.0205714
log_cpu_pf.txt:TimeModule> 35004 1 particleFlowTmp PFProducer 0.0115013
log_cpu_pf.txt:TimeModule> 35003 1 particleFlowTmp PFProducer 0.010012

CPU MLPF:
log_cpu.txt:TimeModule> 35002 1 particleFlowTmp MLPFProducer 9.4116
log_cpu.txt:TimeModule> 35005 1 particleFlowTmp MLPFProducer 8.02389
log_cpu.txt:TimeModule> 35001 1 particleFlowTmp MLPFProducer 13.4437
log_cpu.txt:TimeModule> 35004 1 particleFlowTmp MLPFProducer 10.4151
log_cpu.txt:TimeModule> 35003 1 particleFlowTmp MLPFProducer 12.1385

GPU MLPF (A100, 1 event per batch):
log_gpu.txt:TimeModule> 35002 1 particleFlowTmp MLPFProducer 0.177305
log_gpu.txt:TimeModule> 35005 1 particleFlowTmp MLPFProducer 0.0156437
log_gpu.txt:TimeModule> 35001 1 particleFlowTmp MLPFProducer 0.0187983
log_gpu.txt:TimeModule> 35004 1 particleFlowTmp MLPFProducer 0.0158696
log_gpu.txt:TimeModule> 35003 1 particleFlowTmp MLPFProducer 0.0171756


jpata commented May 28, 2024

All done, moved to CMSSW_14 and updated the C++ inference code.

jpata closed this as completed May 28, 2024