Skip to content

Commit

Permalink
Allow graceful failover when one tunable panics (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Aug 20, 2024
1 parent f8f4c6c commit 23f0d33
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions crates/cubecl-runtime/src/tune/tuner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use web_time::Duration;

#[cfg(not(target_family = "wasm"))]
use core::time::Duration;
use core::{any::Any, mem::ManuallyDrop};
#[cfg(feature = "std")]
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

use alloc::boxed::Box;
use alloc::string::ToString;
Expand All @@ -16,6 +19,13 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa

use super::AutotuneKey;

/// An error that occurred during benchmarking: the payload of a panic caught while running a
/// single autotune operation.
///
/// If other benches succeeded, this bench is ignored and tuning continues gracefully. If all
/// benches fail, the panic is resumed.
///
/// This error cannot be acted on in any way, because it's an opaque unwind object, and must be
/// `ManuallyDrop` because dropping it can cause unwinding to proceed. It can only
/// be passed to `resume_unwind` to continue the panic.
type BenchError = ManuallyDrop<Box<dyn Any + Send>>;

#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
Expand Down Expand Up @@ -67,14 +77,21 @@ impl<K: AutotuneKey> Tuner<K> {
let autotunables = autotune_operation_set.autotunables();
let mut names = Vec::with_capacity(autotunables.len());

let results: Vec<BenchmarkDurations> = autotunables
let results: Vec<Result<BenchmarkDurations, BenchError>> = autotunables
.into_iter()
.map(|op| {
names.push(op.name().to_string());
self.run_benchmark(op, client)
})
.collect();

#[cfg(feature = "std")]
if results.iter().all(|it| it.is_err()) {
let first_error = results.into_iter().next().unwrap().err().unwrap();
resume_unwind(ManuallyDrop::into_inner(first_error));
}
let results = results.into_iter().filter_map(Result::ok).collect();

// Finds the fastest operation, stores it and returns it
let fastest_index = self.find_fastest(results);
let fastest_name = names.get(fastest_index).unwrap();
Expand All @@ -99,12 +116,23 @@ impl<K: AutotuneKey> Tuner<K> {
&mut self,
operation: Box<dyn AutotuneOperation<Out>>,
client: &ComputeClient<S, C>,
) -> BenchmarkDurations
) -> Result<BenchmarkDurations, BenchError>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
TuneBenchmark::new(operation, client.clone()).run()
#[cfg(feature = "std")]
{
catch_unwind(AssertUnwindSafe(|| {
TuneBenchmark::new(operation, client.clone()).run()
}))
.map_err(|e| {
println!("Caught error while benchmarking, falling back to next operation.");
ManuallyDrop::new(e)
})
}
#[cfg(not(feature = "std"))]
Ok(TuneBenchmark::new(operation, client.clone()).run())
}

fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
Expand Down

0 comments on commit 23f0d33

Please sign in to comment.