
[Performance] CoreML not being used to its fullest capacity - custom transformer #19887

Open
pfeatherstone opened this issue Mar 13, 2024 · 10 comments
Labels: platform:mobile (issues related to ONNX Runtime mobile; typically submitted using template), stale (issues that have not been addressed in a while; categorized by a bot)

Comments

pfeatherstone commented Mar 13, 2024

Describe the issue

I am converting a PyTorch model to ONNX and running it with ONNX Runtime on a MacBook Pro using the CoreML EP.
My model is a custom transformer.
Only 25% of the nodes can run on CoreML, so performance is about the same as running on the CPU.

To reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime
from einops import rearrange, repeat


def exists(val):
    return val is not None


class Attention2(nn.Module):
    def __init__(self, dim, heads, drop, causal):
        super().__init__()
        dim_head    = dim // heads
        self.heads  = heads
        self.causal = causal
        self.drop   = drop
        self.to_qkv = nn.Linear(dim, dim_head*heads*3, bias=False)
        self.to_out = nn.Linear(dim_head*heads, dim,   bias=False)
      
    def forward(self, x, input_mask):
        T       = x.shape[1]
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))

        # Build mask
        mask = repeat(input_mask, 'b j -> b 1 i j', i=T)
        if self.causal:
            mask = mask.tril()
            
        # Attention is all you need
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.drop if self.training else 0.0, is_causal=False)
        x = rearrange(x, 'b h t d -> b t (h d)')
        x = self.to_out(x) 
        return x
    

def FeedForward2(dim, ff_mult, drop=0.0):
    return nn.Sequential(nn.Linear(dim, dim*ff_mult),
                         nn.GELU(),
                         nn.Dropout(drop),
                         nn.Linear(dim*ff_mult, dim),
                         nn.Dropout(drop))


class TransformerEncoderLayer2(nn.Module):
    def __init__(self, dim, heads, drop, causal, ff_mult):
        super().__init__()
        self.ff     = FeedForward2(dim, ff_mult, drop)
        self.attn   = Attention2(dim, heads, drop, causal)
        self.norms  = nn.ModuleList([nn.LayerNorm(dim), nn.LayerNorm(dim)])

    def forward(self, x, input_mask):
        x = x + self.attn(self.norms[0](x), input_mask)
        x = x + self.ff(self.norms[1](x))
        return x


class TransformerEncoder2(nn.Module):
    def __init__(self, dim, heads, drop, causal, ff_mult, depth, vocab_size=None, dim_in=None, dim_out=None):
        super().__init__()
        self.proj_in    = nn.Embedding(vocab_size, dim) if exists(vocab_size) else nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
        self.proj_out   = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
        self.layers     = nn.ModuleList([TransformerEncoderLayer2(dim, heads, drop, causal, ff_mult) for _ in range(depth)])
        self.norm       = nn.LayerNorm(dim)
    
    def forward(self, input, input_mask):
        x = self.proj_in(input)

        for l in self.layers:
            x = l(x, input_mask)

        return self.proj_out(self.norm(x))

net = TransformerEncoder2(
    dim         = 512,
    heads       = 4,
    depth       = 4,
    ff_mult     = 2,
    drop        = 0.0,
    causal      = False,
    vocab_size  = 259,
    dim_out     = None
).eval()

x = torch.randint(0, 259, size=(4,1024))
m = torch.full_like(x, fill_value=True).bool()
y = net(x, m)
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.onnx.export(
    net, 
    (x,m), 
    'net.onnx',
    opset_version=15,
    input_names=['data', 'mask'],
    output_names=['embeddings'],
    dynamic_axes={'data'        : {0: 'B', 1: 'T'},
                  'mask'        : {0: 'B', 1: 'T'},
                  'embeddings'  : {0: 'B', 1: 'T'}})

ort     = onnxruntime.InferenceSession('net.onnx', providers=['CPUExecutionProvider'])
x       = torch.randint(0, 259, size=(2,2000))
m       = torch.full_like(x, fill_value=True).bool()
y1      = net(x, m) 
y2,     = ort.run(None, {'data': x.numpy(), 'mask': m.numpy()})
torch.testing.assert_close(y1, torch.from_numpy(y2))
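
For reference, on the MacBook the session is created with the CoreML EP rather than the CPU EP used above; a minimal sketch (the verbose logging is only there to show which EP each node is assigned to, and is not part of the repro):

# On macOS: same model, but with the CoreML EP first and the CPU EP as fallback.
# log_severity_level=0 (VERBOSE) prints the node-to-EP assignments, which makes
# the partition split visible in the session logs.
so = onnxruntime.SessionOptions()
so.log_severity_level = 0
ort_coreml = onnxruntime.InferenceSession(
    'net.onnx',
    sess_options=so,
    providers=['CoreMLExecutionProvider', 'CPUExecutionProvider'])
y3, = ort_coreml.run(None, {'data': x.numpy(), 'mask': m.numpy()})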

Then run in a terminal:

python3 -m onnxruntime.tools.check_onnx_model_mobile_usability --log_level debug net.onnx

You will see something like:

INFO:  CoreML is not recommended with this model as there are 16 partitions covering 25.0% of the nodes in the model. This will most likely result in worse performance than just using the CPU EP.

Urgency

Not super urgent, but this would be a massive win for me if I could get the performance on CoreML to within 25% of my NVIDIA card.

Platform

Linux

OS Version

Ubuntu 22

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CoreML

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No

github-actions bot added the platform:mobile label on Mar 13, 2024
pfeatherstone (Author)

#17654 is related.

pfeatherstone (Author) commented Mar 13, 2024

Note: my actual model is more complicated. It uses rotary embeddings, XL recurrence, KV memories, and a few other things. I've stripped it back massively to produce a minimal example.

natke (Contributor) commented Mar 15, 2024

Thanks @pfeatherstone! Are you able to attach the output from the usability checker?

pfeatherstone (Author)

INFO:  Checking net.onnx for usability with ORT Mobile.
INFO:  Checking NNAPI
INFO:  12 partitions with a total of 104/344 nodes can be handled by the NNAPI EP.
INFO:  Partition sizes: [14, 6, 9, 10, 6, 9, 10, 6, 9, 10, 6, 9]
INFO:  Unsupported nodes due to operator=37
INFO:  Unsupported nodes due to input having a dynamic shape=203
INFO:  Unsupported ops: ai.onnx:Equal,ai.onnx:Erf,ai.onnx:Expand,ai.onnx:Shape,ai.onnx:Where
DEBUG:  Caveats that have not been checked and may result in a node not being supported:  
     ai.onnx:Gather:Input indices should be constant if not int32 type.
     ai.onnx:Unsqueeze:Input axes should be constant.
INFO:  NNAPI is not recommended with this model as there are 12 partitions covering 30.2% of the nodes in the model. This will most likely result in worse performance than just using the CPU EP.
INFO:  Model should perform well with NNAPI as is: NO
INFO:  Checking if model will perform better if the dynamic shapes are fixed...
INFO:  Partition information if the model was updated to make the shapes fixed:
INFO:  21 partitions with a total of 307/344 nodes can be handled by the NNAPI EP.
INFO:  Partition sizes: [16, 19, 10, 4, 24, 16, 19, 10, 4, 24, 16, 19, 10, 4, 24, 16, 19, 10, 4, 24, 15]
INFO:  Unsupported nodes due to operator=37
INFO:  Unsupported ops: ai.onnx:Equal,ai.onnx:Erf,ai.onnx:Expand,ai.onnx:Shape,ai.onnx:Where
DEBUG:  Caveats that have not been checked and may result in a node not being supported:  
     ai.onnx:Gather:Input indices should be constant if not int32 type.
     ai.onnx:Unsqueeze:Input axes should be constant.
INFO:  NNAPI is not recommended with this model as there are 21 partitions covering 89.2% of the nodes in the model. This will most likely result in worse performance than just using the CPU EP.
INFO:  Model should perform well with NNAPI if modified to have fixed input shapes: NO
INFO:  Checking CoreML
INFO:  16 partitions with a total of 86/344 nodes can be handled by the CoreML EP.
INFO:  Partition sizes: [8, 9, 5, 1, 6, 9, 5, 1, 6, 9, 5, 1, 6, 9, 5, 1]
INFO:  Unsupported nodes due to operator=57
INFO:  Unsupported nodes due to input having a dynamic shape=201
INFO:  Unsupported ops: ai.onnx:Equal,ai.onnx:Erf,ai.onnx:Expand,ai.onnx:ReduceMean,ai.onnx:Unsqueeze,ai.onnx:Where
DEBUG:  Caveats that have not been checked and may result in a node not being supported:  
     ai.onnx:Gather:Input `indices` with scalar value is not supported.
     ai.onnx:MatMul:Input B should be constant.
     ai.onnx:Pow:Only supports cases when both inputs are fp32.
     ai.onnx:Shape:Attribute `start` with non-default value is not supported. Attribute `end` is not supported.
     ai.onnx:Slice:Inputs `starts`, `ends`, `axes`, and `steps` should be constant. Empty slice is not supported.
INFO:  CoreML is not recommended with this model as there are 16 partitions covering 25.0% of the nodes in the model. This will most likely result in worse performance than just using the CPU EP.
INFO:  Model should perform well with CoreML as is: NO
INFO:  Checking if model will perform better if the dynamic shapes are fixed...
INFO:  Partition information if the model was updated to make the shapes fixed:
INFO:  39 partitions with a total of 287/344 nodes can be handled by the CoreML EP.
INFO:  Partition sizes: [4, 2, 33, 3, 1, 10, 4, 2, 8, 6, 2, 33, 3, 1, 10, 4, 2, 8, 6, 2, 33, 3, 1, 10, 4, 2, 8, 6, 2, 33, 3, 1, 10, 4, 2, 8, 6, 2, 5]
INFO:  Unsupported nodes due to operator=57
INFO:  Unsupported ops: ai.onnx:Equal,ai.onnx:Erf,ai.onnx:Expand,ai.onnx:ReduceMean,ai.onnx:Unsqueeze,ai.onnx:Where
DEBUG:  Caveats that have not been checked and may result in a node not being supported:  
     ai.onnx:Gather:Input `indices` with scalar value is not supported.
     ai.onnx:MatMul:Input B should be constant.
     ai.onnx:Pow:Only supports cases when both inputs are fp32.
     ai.onnx:Shape:Attribute `start` with non-default value is not supported. Attribute `end` is not supported.
     ai.onnx:Slice:Inputs `starts`, `ends`, `axes`, and `steps` should be constant. Empty slice is not supported.
INFO:  CoreML is not recommended with this model as there are 39 partitions covering 83.4% of the nodes in the model. This will most likely result in worse performance than just using the CPU EP.
INFO:  Model should perform well with CoreML if modified to have fixed input shapes: NO
INFO:  ---------------
INFO:  Checking if pre-built ORT Mobile package can be used with net.onnx once model is converted from ONNX to ORT format using onnxruntime.tools.convert_onnx_models_to_ort...
DEBUG:  Checking if the data types and operators used in the model are supported in the pre-built ORT package...
INFO:  Model should work with the pre-built package.
INFO:  ---------------

INFO:  Run `python -m onnxruntime.tools.convert_onnx_models_to_ort ...` to convert the ONNX model to ORT format. By default, the conversion tool will create an ORT format model with saved optimizations which can potentially be applied at runtime (with a .with_runtime_opt.ort file extension) for use with NNAPI or CoreML, and a fully optimized ORT format model (with a .ort file extension) for use with the CPU EP.
INFO:  For optimal performance the <model>.ort model should be used with the CPU EP.
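
For completeness, the fixed-shape variant the checker simulates above can be produced with the onnxruntime.tools.onnx_model_utils helpers; a minimal sketch, where the values chosen for the 'B' and 'T' dim params are purely illustrative:

import onnx
from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed, fix_output_shapes

# Pin the symbolic dims created at export time ('B' and 'T'); values are illustrative.
model = onnx.load('net.onnx')
make_dim_param_fixed(model.graph, 'B', 1)
make_dim_param_fixed(model.graph, 'T', 1024)
fix_output_shapes(model)  # propagate the now-fixed input shapes to the outputs
onnx.save(model, 'net.fixed.onnx')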

pfeatherstone (Author)

Hi @natke, have you had a chance to look at this?
I know the obvious thing to do is to use static shapes, but that's not really a good option for me.
Do you know if this is likely to be improved in the near future?
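
For context, the static-shape route being declined here would mean exporting with a fixed maximum length and padding shorter inputs, with the existing mask input hiding the padding. A hypothetical sketch (max_T and the helper are illustrative, not from the original report):

# Hypothetical helper, continuing from the repro above: pad 'data' and 'mask'
# out to a fixed max_T so the export can use static shapes; padded key positions
# are hidden from attention by the (all-False) mask entries.
import torch

max_T = 1024  # arbitrary illustrative maximum sequence length

def pad_to_fixed(x, m, max_T):
    pad = max_T - x.shape[1]
    x = torch.cat([x, torch.zeros(x.shape[0], pad, dtype=x.dtype)], dim=1)
    m = torch.cat([m, torch.zeros(m.shape[0], pad, dtype=torch.bool)], dim=1)
    return x, m

x_fixed, m_fixed = pad_to_fixed(x, m, max_T)  # shapes are now (B, max_T)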

natke (Contributor) commented Mar 18, 2024

Hi @pfeatherstone, we are looking into it!

github-actions bot commented Apr 18, 2024
This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions bot added the stale label on Apr 18, 2024
Rikyf3 commented Apr 25, 2024

Any news? I am facing a similar performance issue. LayerNorm and MultiHeadAttention seem not to be implemented as operators in the CoreML EP. Are there any plans to support them?

pfeatherstone (Author)

I just checked with onnxruntime 1.18 and it's exactly the same.

pfeatherstone (Author)

Any updates?
