diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index c4f936d3..88b9b497 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -3,6 +3,9 @@ use web_time::Duration; #[cfg(not(target_family = "wasm"))] use core::time::Duration; +use core::{any::Any, mem::ManuallyDrop}; +#[cfg(feature = "std")] +use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use alloc::boxed::Box; use alloc::string::ToString; @@ -16,6 +19,13 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa use super::AutotuneKey; +/// An error that occurred during benchmarking. If other benches succeeded, ignore this bench and +/// continue gracefully. If all benches fail, panic. +/// This error cannot be acted on in any way, because it's an opaque unwind object, and must be +/// `ManuallyDrop` because dropping it can cause unwinding to proceed. It can only +/// be passed to `resume_unwind` to continue the panic. 
+type BenchError = ManuallyDrop>; + #[derive(Debug)] /// Executes autotune benchmarking and caching pub struct Tuner { @@ -66,7 +76,7 @@ impl Tuner { let autotunables = autotune_operation_set.autotunables(); let mut names = Vec::with_capacity(autotunables.len()); - let results: Vec = autotunables + let results: Vec> = autotunables .into_iter() .map(|op| { names.push(op.name().to_string()); @@ -74,6 +84,13 @@ impl Tuner { }) .collect(); + #[cfg(feature = "std")] + if results.iter().all(|it| it.is_err()) { + let first_error = results.into_iter().next().unwrap().err().unwrap(); + resume_unwind(ManuallyDrop::into_inner(first_error)); + } + let results = results.into_iter().filter_map(Result::ok).collect(); + // Finds the fastest operation, stores it and returns it let fastest_index = self.find_fastest(results); let fastest_name = names.get(fastest_index).unwrap(); @@ -98,12 +115,23 @@ impl Tuner { &mut self, operation: Box, client: &ComputeClient, - ) -> BenchmarkDurations + ) -> Result where S: ComputeServer, C: ComputeChannel, { - TuneBenchmark::new(operation, client.clone()).run() + #[cfg(feature = "std")] + { + catch_unwind(AssertUnwindSafe(|| { + TuneBenchmark::new(operation, client.clone()).run() + })) + .map_err(|e| { + println!("Caught error while benchmarking, falling back to next operation."); + ManuallyDrop::new(e) + }) + } + #[cfg(not(feature = "std"))] + Ok(TuneBenchmark::new(operation, client.clone()).run()) } fn find_fastest(&self, results: Vec) -> usize {