Allow graceful failover when one tunable panics #74

Merged 4 commits on Aug 20, 2024
34 changes: 31 additions & 3 deletions crates/cubecl-runtime/src/tune/tuner.rs
@@ -3,6 +3,9 @@ use web_time::Duration;

#[cfg(not(target_family = "wasm"))]
use core::time::Duration;
use core::{any::Any, mem::ManuallyDrop};
#[cfg(feature = "std")]
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

use alloc::boxed::Box;
use alloc::string::ToString;
@@ -16,6 +19,13 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa

use super::AutotuneKey;

/// An error that occurred during benchmarking. If other benches succeeded, ignore this bench and
/// continue gracefully. If all benches fail, panic.
/// This error cannot be acted on in any way, because it's an opaque unwind object, and must be
/// `ManuallyDrop` because dropping it can cause unwinding to proceed. It can only
/// be passed to `resume_unwind` to continue the panic.
type BenchError = ManuallyDrop<Box<dyn Any + Send>>;

#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
@@ -66,14 +76,21 @@ impl<K: AutotuneKey> Tuner<K> {
let autotunables = autotune_operation_set.autotunables();
let mut names = Vec::with_capacity(autotunables.len());

let results: Vec<BenchmarkDurations> = autotunables
let results: Vec<Result<BenchmarkDurations, BenchError>> = autotunables
.into_iter()
.map(|op| {
names.push(op.name().to_string());
self.run_benchmark(op, client)
})
.collect();

#[cfg(feature = "std")]
if results.iter().all(|it| it.is_err()) {
let first_error = results.into_iter().next().unwrap().err().unwrap();
resume_unwind(ManuallyDrop::into_inner(first_error));
}
let results = results.into_iter().filter_map(Result::ok).collect();

// Finds the fastest operation, stores it and returns it
let fastest_index = self.find_fastest(results);
let fastest_name = names.get(fastest_index).unwrap();
@@ -98,12 +115,23 @@ impl<K: AutotuneKey> Tuner<K> {
&mut self,
operation: Box<dyn AutotuneOperation>,
client: &ComputeClient<S, C>,
) -> BenchmarkDurations
) -> Result<BenchmarkDurations, BenchError>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
TuneBenchmark::new(operation, client.clone()).run()
#[cfg(feature = "std")]
{
catch_unwind(AssertUnwindSafe(|| {
TuneBenchmark::new(operation, client.clone()).run()
}))
.map_err(|e| {
println!("Caught error while benchmarking, falling back to next operation.");
ManuallyDrop::new(e)
})
}
#[cfg(not(feature = "std"))]
Ok(TuneBenchmark::new(operation, client.clone()).run())
}

fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
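
The failover pattern in this diff (catch each benchmark's panic, continue with the survivors, re-raise only if every candidate failed) can be sketched in isolation with the standard library alone. Everything below is an illustration, not cubecl code: `pick_fastest`, the `Payload` alias, and the `u64` score standing in for `BenchmarkDurations` are hypothetical. Unlike the diff, the sketch keeps each name paired with its result, so the winning entry still refers to the right operation after failed candidates are dropped, and it lets unused payloads drop normally instead of wrapping them in `ManuallyDrop`.

```rust
use std::any::Any;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

/// Opaque panic payload, as handed back by `catch_unwind`.
type Payload = Box<dyn Any + Send>;

/// Runs every candidate, isolating panics, and returns the name and score of
/// the candidate with the lowest score. If every candidate panics, the first
/// panic is re-raised. Hypothetical helper: `u64` stands in for a real
/// benchmark result type.
pub fn pick_fastest(
    candidates: Vec<(&'static str, Box<dyn Fn() -> u64>)>,
) -> (&'static str, u64) {
    let mut succeeded: Vec<(&'static str, u64)> = Vec::new();
    let mut first_panic: Option<Payload> = None;

    for (name, run) in candidates {
        // A boxed closure is not known to be unwind safe, so we assert it,
        // as the diff does for the benchmark call.
        match catch_unwind(AssertUnwindSafe(|| run())) {
            // Keep the name next to its result so filtering cannot
            // shift which name belongs to which score.
            Ok(score) => succeeded.push((name, score)),
            Err(payload) => {
                // Remember only the first failure; later payloads are dropped here.
                if first_panic.is_none() {
                    first_panic = Some(payload);
                }
            }
        }
    }

    match succeeded.into_iter().min_by_key(|&(_, score)| score) {
        Some(best) => best,
        None => match first_panic {
            // Every candidate failed: continue the original panic.
            Some(payload) => resume_unwind(payload),
            None => panic!("no candidates supplied"),
        },
    }
}
```

Calling `pick_fastest` with one panicking candidate and two working ones returns the working candidate with the lower score; calling it with only panicking candidates propagates the first panic to the caller. The default panic hook still prints a message for each caught panic.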