
Move output swizzling pass before fusions #1651

Open
wants to merge 4 commits into base: develop
Conversation

krzysz00 (Collaborator):

  • Move the output swizzle pass to before the fusion passes (see the pipeline sketch after this list)
  • Due to the tension between LDS consolidation and multibuffering, remove the reuse-lds call before the output swizzle. As a consequence, remove the "increasing total LDS usage" heuristic from output swizzle enablement, since it should probably be fine
  • Fix an issue where fusion traversal wasn't working correctly, resulting in insufficiently vectorized writes to global memory despite previous attempts to fix the issue
  • Fix a test that wasn't using i8 LDS
  • Update the packed arithmetic test to check for vectorized writes
  • Add a guard in case the ExistingOps strictness is still letting LDS writes into the output swizzle rewrite
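
A minimal sketch of the reordering described in the first bullet, assuming hypothetical pass-creation helper names (createRockOutputSwizzlePass, createRockFusionPass); the real rocMLIR pipeline registration code will differ:

#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline fragment, for illustration only.
void buildKernelPipeline(mlir::OpPassManager &pm) {
  // Before this PR: fusion (and a reuse-lds pass) ran first, then the
  // output swizzle. After this PR the swizzle runs before fusion, so fused
  // reads already see the swizzled, coalescing-friendly layout.
  pm.addPass(rock::createRockOutputSwizzlePass()); // hypothetical helper name
  pm.addPass(rock::createRockFusionPass());        // hypothetical helper name
}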

@@ -441,14 +426,6 @@ void RockOutputSwizzlePass::runOnOperation() {
<< ldsRequiredBytes << " bytes, skipping pass\n");
return;
}
// heuristic: if we need more LDS, skip this pass
Contributor:

we should check if there's any performance regression due to this. I'm happy to do this if you are busy with other things.
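
For context, a rough sketch of the kind of heuristic being removed, inferred from the diff above; ldsRequiredBytes appears in that diff, while gemmLdsBytes is a hypothetical stand-in for whatever the pass compared it against:

// Sketch of the removed "increasing total LDS usage" heuristic (not the
// exact original code): bail out of the output-swizzle rewrite whenever the
// swizzle's scratch buffer would raise total LDS consumption.
if (ldsRequiredBytes > gemmLdsBytes) {
  LLVM_DEBUG(llvm::dbgs() << "Output swizzle would increase total LDS usage to "
                          << ldsRequiredBytes << " bytes, skipping pass\n");
  return;
}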

%input_reg = rock.alloc() : memref<16xf32, #gpu.address_space<private>>
%output_reg = rock.alloc() : memref<16xf32, #gpu.address_space<private>>
%ws_lds = rock.alloc() : memref<64xf32, #gpu.address_space<workgroup>>
Contributor:

if rock.alloc() is always supposed to allocate i8, should we add that as a check in GpuAllocOp::verify()?

krzysz00 (Collaborator, Author):

It's more that the LDS reduce pass fails if you don't do this so ... yeah, I'll add a check.
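
Not the actual patch, but a minimal sketch of what such a check could look like, assuming GpuAllocOp is a standard MLIR op with a single memref result; accessor and attribute names in the rock dialect may differ:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"

// Sketch of an i8-only check for workgroup (LDS) allocations in
// GpuAllocOp::verify(); the address-space attribute follows upstream MLIR's
// gpu dialect, the rock dialect may model this differently.
mlir::LogicalResult GpuAllocOp::verify() {
  auto memrefType = llvm::cast<mlir::MemRefType>(getResult().getType());
  auto space = llvm::dyn_cast_or_null<mlir::gpu::AddressSpaceAttr>(
      memrefType.getMemorySpace());
  if (space && space.getValue() == mlir::gpu::AddressSpace::Workgroup &&
      !memrefType.getElementType().isInteger(8))
    return emitOpError("workgroup (LDS) allocations are expected to use i8 elements");
  return mlir::success();
}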

manupak (Contributor) commented Sep 17, 2024:

What problem are you solving here?

As in, what is the motivation for moving this upwards in the pipeline and relying "more" on "utils" to get vectorization data?

krzysz00 and others added 2 commits September 17, 2024 16:48
manupak (Contributor) commented Sep 18, 2024:

After a long discussion with @dhernandez0, I feel we should not be doing this without more analysis of how it affects the threadwise_read_into ops of the other inputs. @krzysz00, the three of us can have a chat when you are working.

Before:

  • We were only modifying the last threadwise_write_all where:
    • All other inputs for post-gemm fused operations are read in the original output layout (of MFMAs).
    • The OutputSwizzle pass only looked at the last threadwise_write_all, which comes after all fusions whose inputs have been threadwise_read_into'd in that original layout.
    • LDS-based swizzling happened after all the fused inputs are read and computed.
    • As a final note, we more or less go with the heuristic that extraViews is representative of how the vector length is going to be decided.
      • but really we should be looking at all the views (extra views + views on dest); however, I'm nervous about exposing removeUpperDims to arbitrary combinations of transform maps, given how many problems I ran into with it recently.

After this PR:

  • We are hoisting the LDS-based swizzling to directly after the gemm.
  • Then we should also be looking at the vector lengths of all the threadwise_read_intos and the threadwise_write_all when deciding to do the output swizzle. I'm not saying this is bad, but just moving the current implementation up will not simply work (tm).
  • Even then, it's not clear to me how we decide to transpose m <-> n when the above global memory operations disagree with each other. If you have a thought, I'm all ears.

Again, coming back to my original question: what problem are you solving here?

krzysz00 (Collaborator, Author):

The problem I wanted to solve here was that we'd be doing badly-formatted reads from global memory, because we'd be reading in the MFMA layout and not the coalesced-read-promoting layout you get after doing the LDS swizzle, on the reasonable assumption that the fusion inputs are stored somewhat like the final output.

Which might've been false.

manupak (Contributor) commented Sep 18, 2024:

That itself sounds like a good idea... however, I think we can pragmatically verify that this is the case: only do the output swizzle if the gemm output buffer's layout agrees with the other inputs' layouts. That way we don't create a potential regression.
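
A rough sketch of the kind of guard being suggested here; vectorLenOfGlobalRead and vectorLenOfGlobalWrite are hypothetical placeholders for whatever vectorization utility rocMLIR actually exposes, and the op class names simply mirror the op names in this thread:

// Sketch only: enable the output swizzle only when none of the fused inputs
// would lose vectorization relative to the swizzled final write. The helper
// functions are placeholders, not real rocMLIR APIs.
static bool outputSwizzleLooksProfitable(
    ThreadwiseWriteAllOp finalWrite,
    llvm::ArrayRef<ThreadwiseReadIntoOp> fusedReads) {
  int64_t writeVecLen = vectorLenOfGlobalWrite(finalWrite); // placeholder
  for (ThreadwiseReadIntoOp read : fusedReads) {
    // If a fused input is laid out differently from the gemm output,
    // swizzling may trade a coalesced write for uncoalesced reads.
    if (vectorLenOfGlobalRead(read) < writeVecLen) // placeholder
      return false;
  }
  return true;
}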

dhernandez0 (Contributor):

I had some time and I've run some performance experiments (see the file attached). There are some performance regressions:
results.xlsx

We should do performance experiments for fusions as well, I think.

manupak (Contributor) commented Oct 1, 2024:

Thanks @dhernandez0...
I must admit I did not expect a regression in the non-fused cases... we need to figure out what's going on there.

dhernandez0 (Contributor) commented Oct 1, 2024:

I've done a quick experiment with fusion (conv+add+relu) on MI300:

module {
  func.func @mlir_transpose_convolution_add_relu(%arg0: !migraphx.shaped<1x512x7x7xf32, 25088x1x3584x512>, %arg1: !migraphx.shaped<1x7x7x2048xf32, 100352x14336x2048x1>, %arg2: !migraphx.shaped<512x2048x1x1xf32, 2048x1x1x1>) -> !migraphx.shaped<1x512x7x7xf32, 25088x1x3584x512> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
    %0 = migraphx.transpose %arg1 {permutation = [0, 3, 1, 2]} : <1x7x7x2048xf32, 100352x14336x2048x1> -> <1x2048x7x7xf32, 100352x1x14336x2048>
    %1 = migraphx.convolution %0, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x2048x7x7xf32, 100352x1x14336x2048>, <512x2048x1x1xf32, 2048x1x1x1> -> <1x512x7x7xf32, 25088x1x3584x512>
    %2 = migraphx.add %1, %arg0 : <1x512x7x7xf32, 25088x1x3584x512>, <1x512x7x7xf32, 25088x1x3584x512> -> <1x512x7x7xf32, 25088x1x3584x512>
    %3 = migraphx.relu %2 : <1x512x7x7xf32, 25088x1x3584x512> -> <1x512x7x7xf32, 25088x1x3584x512>
    return %3 : !migraphx.shaped<1x512x7x7xf32, 25088x1x3584x512>
  }
}

TensorFlow code:

import tensorflow as tf
import numpy as np
import tf2onnx

# Define the model
class ConvAddReluModelNHWC(tf.keras.Model):
    def __init__(self):
        super(ConvAddReluModelNHWC, self).__init__()
        # Convolution layer
        self.conv = tf.keras.layers.Conv2D(
            filters=512,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding='valid',
            use_bias=False
        )
        # ReLU activation
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs):
        arg0, arg1 = inputs
        # Perform convolution
        conv_out = self.conv(arg1)  # arg1: Input tensor to convolution

        # Perform element-wise addition
        add_out = conv_out + arg0  # Add output of conv and arg0

        # Apply ReLU
        relu_out = self.relu(add_out)

        return relu_out

# Create an instance of the model
model = ConvAddReluModelNHWC()

# Define input tensors in NHWC format (TensorFlow's default format)
arg0 = np.random.randn(512, 7, 7, 512).astype(np.float16)  # NHWC
arg1 = np.random.randn(512, 7, 7, 2048).astype(np.float16)  # NHWC
#arg2 = np.random.randn(512, 1, 1, 2048).astype(np.float16)  # NHWC

# Convert inputs to TensorFlow tensors
arg0_tf = tf.convert_to_tensor(arg0)
arg1_tf = tf.convert_to_tensor(arg1)
#arg2_tf = tf.convert_to_tensor(arg2)

# Run a forward pass to ensure the model works
model.build(input_shape=[(None, 7, 7, 512), (None, 7, 7, 2048)])
output = model((arg0_tf, arg1_tf))

# Export the model to ONNX format
spec = (tf.TensorSpec((None, 7, 7, 512), tf.float16, name="arg0"),
        tf.TensorSpec((None, 7, 7, 2048), tf.float16, name="arg1"))

# Export the model to ONNX format using tf2onnx
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

# Save the ONNX model to file
with open("slow_nhwc_tf.onnx", "wb") as f:
    f.write(model_proto.SerializeToString())

print("Model exported to slow_nhwc_tf.onnx")

There is a nice speed up in this case:

develop 0.0541076ms
move outswizzle 0.0463547ms

To reproduce, run:

python3 slow_tf.py
MIGRAPHX_DISABLE_PASSES=auto_contiguous MIGRAPHX_TRACE_BENCHMARKING=3 ./bin/migraphx-driver perf --exhaustive-tune --onnx slow_nhwc_tf.onnx

manupak (Contributor) commented Oct 1, 2024:

Thanks @dhernandez0! So it works out nicely when the layout of the other inputs matches the conv output.

dhernandez0 (Contributor):

> Thanks @dhernandez0! So it works out nicely when the layout of the other inputs matches the conv output.

Yes, I think so; it makes sense. However, I think this is the realistic case for most uses: the tensors of a network generally share the same layout. In the recent ticket https://github.com/ROCm/rocMLIR-internal/issues/1625, the other input is a vector of only 512 elements; in that case I think it probably doesn't matter how you access it, because it's cached? I guess we need to understand what the most typical fusion use cases are (layout and shape of the other input).

After running 5 times, here are the averages of the previous experiment:
develop 0.06248738ms
move output swizzle 0.0570795ms

dhernandez0 (Contributor):

@krzysz00, it just came to my mind: I used the develop branch after the upstream merge for these experiments. I think this PR branch is not up to date.
