
Add optional output to autotuned operations #73

Merged · merged 1 commit into tracel-ai:main from tune-output on Aug 19, 2024

Conversation

wingertge
Contributor

Adds an optional output type to tuner executions.
The output type is defined at the point of calling execute, not on the Tuner itself, because the Tuner is static and cannot depend on the generic types passed to the tensor operation function. This is the only way to make types like JitTensor<R, E, 4> work as output types.
The new output type defaults to (), so no changes to existing tuners are necessary.
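
As a rough, self-contained sketch of that design point (hypothetical names throughout — Tuner, execute, and the closure-based operation are stand-ins, not the actual API from this PR), a static tuner can still serve a per-call-site output type by making it a method-level generic:

// Hypothetical stand-in, not the real tuner API: the output type is a
// generic parameter of `execute`, chosen fresh at each call site.
struct Tuner;

impl Tuner {
    // A `static TUNER` cannot be generic over the caller's tensor types,
    // but its method can be.
    fn execute<Out>(&self, operation: Box<dyn FnOnce() -> Out>) -> Out {
        operation()
    }
}

static TUNER: Tuner = Tuner;

fn main() {
    // Existing tuners keep the unit output and need no changes.
    let _: () = TUNER.execute(Box::new(|| ()));

    // New usage: the selected operation returns its output directly
    // (a Vec stands in for a JitTensor here).
    let output: Vec<f32> = TUNER.execute(Box::new(|| vec![0.0; 16]));
    assert_eq!(output.len(), 16);
}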

Testing

All unit tests pass, and existing tuners in burn compile and work exactly the same. I also tested the new output type by autotuning different conv2d algorithms; it works as expected, with all tests passing.

@nathanielsimard
Member

@wingertge Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type. The implementation looks good though.

@wingertge
Contributor Author

wingertge commented Aug 18, 2024

> Can you provide an example of how the output type can be used? I'm having trouble figuring out why we need an output type.

The way this is currently done in existing tuned operations is to initialize an output tensor up front and pass it into the kernel:

// Pre-allocate the output and hand a clone of its handle to the operation set.
let output = init_matmul_output(&lhs, &rhs);

static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();

TUNER.execute(
    &JitTuneId::new::<R>(&lhs.device),
    &client,
    Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())),
);

// Return the pre-allocated tensor, which the tuned kernel wrote into.
output

This works fine as long as the operation you're passing it to ignores the can_mut flag. However, any postprocessing on the output tensor through functions that do respect can_mut (e.g. float_matmul, float_slice_assign) copies the tensor contents into a new buffer, so the result is no longer written into the pre-allocated output tensor. This makes things like im2col impossible without circumventing the backend and manually launching the low-level kernel for each operation. Mutable references don't work either, because operations must have a 'static lifetime.
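
To illustrate that last constraint (a minimal sketch with hypothetical names — launch stands in for the real dispatch machinery, and a Vec for a tensor), a boxed 'static operation cannot borrow a buffer from the caller's frame:

// Hypothetical sketch, not the real burn API: boxed operations carry a
// 'static bound, so they cannot borrow a caller-local output buffer.
fn launch(operation: Box<dyn FnOnce() + 'static>) {
    operation();
}

fn caller() -> Vec<f32> {
    let mut output = vec![0.0_f32; 4];

    // An operation that owns everything it touches is fine:
    launch(Box::new(|| println!("kernel launch")));

    // Does not compile: the closure borrows `output` mutably, and that
    // borrow is not 'static:
    // launch(Box::new(|| output[0] = 1.0));
    //
    // Adding `move` satisfies 'static, but then `output` is consumed by
    // the operation and can no longer be returned from here.
    output[0] = 1.0;
    output
}

fn main() {
    assert_eq!(caller()[0], 1.0);
}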

The Output type instead allows us to return the output tensor from whatever postprocessing operations we need, without unsafe workarounds like swapping the inner handles. Algorithms can then just return their output as normal:

/// Executes autotune on conv2d operations
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement + Element, I: IntElement>(
    input: JitTensor<R, E, 4>,
    weights: JitTensor<R, E, 4>,
    bias: Option<JitTensor<R, E, 1>>,
    options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
    let client = input.client.clone();

    static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!("conv2d");

    TUNER.execute(
        &JitTuneId::new::<R>(&input.device),
        &client,
        Box::new(Conv2dOperationsSet::<R, E, I>::new(
            input, weights, bias, options,
        )),
    )
}

This gives a lot more implementation flexibility.

@nathanielsimard
Member

@wingertge Awesome, thanks a lot for the detailed explanation. I agree that this simplifies some workflows.

@nathanielsimard merged commit ccde038 into tracel-ai:main on Aug 19, 2024
1 of 2 checks passed
@wingertge deleted the tune-output branch on August 20, 2024