[Performance] mlx.core.conv_general is really slow #1409
Indeed, we are aware that there are performance cliffs in our convolutions; see e.g. #1313. Thanks for the benchmark, though! We will make sure to include the 3D case in our optimizations. FYI, there are a few problems with your benchmark:
Here is an improved version:

```python
import time

import mlx.core as mx
import torch


def mlx_sample():
    # MLX conv layouts are channels-last: input (N, D, H, W, C_in), weight (C_out, kD, kH, kW, C_in)
    x = mx.random.normal([8, 16, 128, 128, 32], dtype=mx.float32)
    weight = mx.random.normal([4, 1, 1, 1, 32], dtype=mx.float32)
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup
    for _ in range(5):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        mx.eval(out)

    start = time.time()
    n = 10
    for _ in range(n):
        out = mx.conv_general(x, weight, stride, padding, dilation, stream=mx.gpu)
        # MLX is lazy: evaluate on every iteration so each convolution actually runs
        mx.eval(out)
    print(f'MLX time: {(time.time() - start) * 1000 / n:0.2f}ms')


def torch_sample():
    # PyTorch conv layouts are channels-first: input (N, C_in, D, H, W), weight (C_out, C_in, kD, kH, kW)
    x = torch.randn([8, 32, 16, 128, 128], dtype=torch.float32, device='mps')
    weight = torch.randn([4, 32, 1, 1, 1], dtype=torch.float32, device='mps')
    bias = torch.randn([4], dtype=torch.float32, device='mps')
    stride = [1, 1, 1]
    padding = [0, 0, 0]
    dilation = [1, 1, 1]

    # Warmup (positional args after dilation: transposed=False, output_padding, groups=1)
    for _ in range(5):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
    torch.mps.synchronize()

    start = time.time()
    n = 10
    for _ in range(n):
        out = torch.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0, 0], 1)
    # MPS kernels run asynchronously: wait for them before stopping the clock
    torch.mps.synchronize()
    print(f'MPS time: {(time.time() - start) * 1000 / n:0.2f}ms')


mlx_sample()
torch_sample()
```

Finally, for a 1x1x1 convolution, I'd encourage you to use a Linear layer / matmul. That will be way faster for now. I ran the revised benchmark on an M2 Ultra:
MLX with matmul instead:
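That looks like this:

```python
import mlx.core as mx

# A 1x1x1 conv mixes channels only, so flatten (D, H, W) into one axis and matmul
x = mx.random.normal([8, 16 * 128 * 128, 32], dtype=mx.float32)
weight = mx.random.normal([4, 32], dtype=mx.float32)
out = x @ weight.T
mx.eval(out)
```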
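Why the matmul works: a 1x1x1 convolution only combines channels at a single voxel, so it is a matmul over the channel axis. The sketch below checks this in NumPy (used instead of MLX so it runs anywhere). `conv3d_channels_last` is a hypothetical, deliberately naive helper, not an MLX API; it assumes MLX's channels-last layouts, input (N, D, H, W, C_in) and weight (C_out, kD, kH, kW, C_in), and computes an unpadded, stride-1 cross-correlation, which is what deep-learning "convolutions" usually compute.

```python
import numpy as np


def conv3d_channels_last(x, weight):
    """Naive 'valid' 3D convolution (stride 1, no padding, no dilation).

    Hypothetical helper for illustration.
    x:      (N, D, H, W, C_in)        channels-last input, as in MLX
    weight: (C_out, kD, kH, kW, C_in) weight layout, as in MLX
    returns (N, D-kD+1, H-kH+1, W-kW+1, C_out)
    """
    n, d, h, w, c_in = x.shape
    c_out, kd, kh, kw, c_in_w = weight.shape
    assert c_in == c_in_w, "input channels must match"
    od, oh, ow = d - kd + 1, h - kh + 1, w - kw + 1
    out = np.zeros((n, od, oh, ow, c_out), dtype=x.dtype)
    # Accumulate one kernel tap at a time; each tap is a matmul over channels.
    for i in range(kd):
        for j in range(kh):
            for k in range(kw):
                patch = x[:, i:i + od, j:j + oh, k:k + ow, :]
                out += patch @ weight[:, i, j, k, :].T
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 8, 32)).astype(np.float32)
weight = rng.standard_normal((4, 1, 1, 1, 32)).astype(np.float32)

conv_out = conv3d_channels_last(x, weight)
# Same computation as a single matmul: flatten the spatial axes, multiply, reshape back.
mm_out = (x.reshape(2, -1, 32) @ weight.reshape(4, 32).T).reshape(2, 4, 8, 8, 4)
print(np.allclose(conv_out, mm_out, atol=1e-5))  # True
```

With a 1x1x1 kernel the loop runs a single tap, so the whole layer reduces to one (N·D·H·W, C_in) by (C_in, C_out) product, which is what the `x @ weight.T` snippet computes. Larger kernels add one such product per tap.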
Thanks a lot!
After some digging, the main issue is the channel count of the input and the first dimension of the weight. PyTorch implements this convolution natively using MPSGraphConvolution3DOpDescriptor. For example:

```python
import mlx.core as mx

# Shapes as written, in PyTorch's channels-first layout:
# input (N, C_in, D, H, W), weight (C_out, C_in, kD, kH, kW).
# MLX's conv_general expects channels-last instead, i.e.
# input (8, 16, 64, 64, 142) and weight (22, 7, 7, 7, 142).
x = mx.random.normal([8, 142, 16, 64, 64], dtype=mx.float32)
weight = mx.random.normal([22, 142, 7, 7, 7], dtype=mx.float32)
bias = mx.random.normal([22], dtype=mx.float32)
stride = [1, 1, 1]
padding = [3, 3, 3]
dilation = [1, 1, 1]
```

Is there a way to get a Metal buffer from an array? Then I'll probably be able to use the native function to compute the 3D convolution.
You have a couple of options:
More info is in the docs on converting to other frameworks.
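The point about channel counts in the MPSGraphConvolution3DOpDescriptor comment (a 7x7x7 kernel with 142 input and 22 output channels) can be made concrete with a rough FLOP count. Below is a plain-Python sketch; `conv_output_size` and `conv3d_flops` are hypothetical helpers, the count treats one multiply-add as 2 FLOPs and ignores bias, and it assumes those shapes are in PyTorch's channels-first layout.

```python
def conv_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis (the usual convolution formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv3d_flops(n, c_in, spatial, c_out, kernel,
                 stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)):
    """Return (output spatial shape, FLOPs) for a dense 3D convolution.

    Hypothetical helper: counts one multiply-add as 2 FLOPs, ignores bias.
    """
    out = [
        conv_output_size(s, k, st, p, d)
        for s, k, st, p, d in zip(spatial, kernel, stride, padding, dilation)
    ]
    # Multiply-adds per output element: C_in * kD * kH * kW
    macs_per_output = c_in
    for k in kernel:
        macs_per_output *= k
    # Number of output elements: N * C_out * D_out * H_out * W_out
    n_outputs = n * c_out
    for o in out:
        n_outputs *= o
    return out, 2 * n_outputs * macs_per_output


# 7x7x7 case, read as input (8, 142, 16, 64, 64) and weight (22, 142, 7, 7, 7), padding 3
out_7, flops_7 = conv3d_flops(8, 142, (16, 64, 64), 22, (7, 7, 7), padding=(3, 3, 3))

# 1x1x1 case from the improved benchmark: 32 -> 4 channels on a 16x128x128 volume
out_1, flops_1 = conv3d_flops(8, 32, (16, 128, 128), 4, (1, 1, 1))

print(out_7, f"{flops_7:.3e}")  # [16, 64, 64] 1.124e+12
print(out_1, f"{flops_1:.3e}")  # [16, 128, 128] 5.369e+08
print(f"ratio: {flops_7 / flops_1:.0f}x")  # ratio: 2093x
```

So the 7x7x7 case does roughly 2,000 times more arithmetic than the 1x1x1 benchmark, and almost all of that factor comes from C_in x C_out x kernel volume rather than from the input size.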
Describe the bug
The method mlx.core.conv_general is significantly slower than its PyTorch counterpart, anywhere from 10x to 150x slower.
To Reproduce
Just run the attached code.
Include code snippet
Output:
MLX time: 20.65ms
MPS time: 0.93ms
Expected behavior
At least the same speed as in PyTorch.
Desktop (please complete the following information):