From 678a08f4dd34906fe79f5b76d55aa7252e462196 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Fri, 16 Aug 2024 22:14:58 +0200 Subject: [PATCH 1/4] Allow graceful failover when one tunable panics. --- crates/cubecl-runtime/src/tune/tuner.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index c4f936d3..5ada1aa9 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -3,6 +3,8 @@ use web_time::Duration; #[cfg(not(target_family = "wasm"))] use core::time::Duration; +use core::{any::Any, mem::ManuallyDrop, panic::AssertUnwindSafe}; +use std::panic::{catch_unwind, resume_unwind}; use alloc::boxed::Box; use alloc::string::ToString; @@ -16,6 +18,10 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa use super::AutotuneKey; +/// An error that occured during benchmarking. If other benches succeeded, ignore this bench and +/// continue gracefully. If all benches fail, panic. 
+type BenchError = ManuallyDrop<Box<dyn Any + Send>>; + #[derive(Debug)] /// Executes autotune benchmarking and caching pub struct Tuner<K: AutotuneKey> { @@ -66,7 +72,7 @@ impl<K: AutotuneKey> Tuner<K> { let autotunables = autotune_operation_set.autotunables(); let mut names = Vec::with_capacity(autotunables.len()); - let results: Vec<BenchmarkDurations> = autotunables + let results: Vec<Result<BenchmarkDurations, BenchError>> = autotunables .into_iter() .map(|op| { names.push(op.name().to_string()); @@ -74,6 +80,12 @@ impl<K: AutotuneKey> Tuner<K> { }) .collect(); + if results.iter().all(|it| it.is_err()) { + let first_error = results.into_iter().next().unwrap().err().unwrap(); + resume_unwind(ManuallyDrop::into_inner(first_error)); + } + let results = results.into_iter().filter_map(Result::ok).collect(); + // Finds the fastest operation, stores it and returns it let fastest_index = self.find_fastest(results); let fastest_name = names.get(fastest_index).unwrap(); @@ -98,12 +110,15 @@ impl<K: AutotuneKey> Tuner<K> { &mut self, operation: Box<dyn AutotuneOperation>, client: &ComputeClient<S, C>, - ) -> BenchmarkDurations + ) -> Result<BenchmarkDurations, BenchError> where S: ComputeServer, C: ComputeChannel<S>, { - TuneBenchmark::new(operation, client.clone()).run() + catch_unwind(AssertUnwindSafe(|| { + TuneBenchmark::new(operation, client.clone()).run() + })) + .map_err(ManuallyDrop::new) } fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize { From 3fec0028d1b2fdf57737059c643681c5f10e7749 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Sat, 17 Aug 2024 12:48:52 +0200 Subject: [PATCH 2/4] Gate unwind catch behind `std` feature --- crates/cubecl-runtime/src/tune/tuner.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index 5ada1aa9..b794568a 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -3,8 +3,9 @@ use web_time::Duration; #[cfg(not(target_family = "wasm"))] use core::time::Duration; -use core::{any::Any, mem::ManuallyDrop, panic::AssertUnwindSafe}; -use std::panic::{catch_unwind, resume_unwind}; +use 
core::{any::Any, mem::ManuallyDrop}; +#[cfg(feature = "std")] +use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use alloc::boxed::Box; use alloc::string::ToString; @@ -80,6 +81,7 @@ impl<K: AutotuneKey> Tuner<K> { }) .collect(); + #[cfg(feature = "std")] if results.iter().all(|it| it.is_err()) { let first_error = results.into_iter().next().unwrap().err().unwrap(); resume_unwind(ManuallyDrop::into_inner(first_error)); @@ -115,10 +117,15 @@ impl<K: AutotuneKey> Tuner<K> { S: ComputeServer, C: ComputeChannel<S>, { - catch_unwind(AssertUnwindSafe(|| { - TuneBenchmark::new(operation, client.clone()).run() - })) - .map_err(ManuallyDrop::new) + #[cfg(feature = "std")] + { + catch_unwind(AssertUnwindSafe(|| { + TuneBenchmark::new(operation, client.clone()).run() + })) + .map_err(ManuallyDrop::new) + } + #[cfg(not(feature = "std"))] + Ok(TuneBenchmark::new(operation, client.clone()).run()) } fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize { From 76df6e47f70d1af487655a5d215f90283e754af9 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Sat, 17 Aug 2024 14:30:17 +0200 Subject: [PATCH 3/4] Print message to inform user an error is handled gracefully. 
--- crates/cubecl-runtime/src/tune/tuner.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index b794568a..1d25c95f 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -122,7 +122,10 @@ impl<K: AutotuneKey> Tuner<K> { catch_unwind(AssertUnwindSafe(|| { TuneBenchmark::new(operation, client.clone()).run() })) - .map_err(ManuallyDrop::new) + .map_err(|e| { + println!("Caught error while benchmarking, falling back to next operation."); + ManuallyDrop::new(e) + }) } #[cfg(not(feature = "std"))] Ok(TuneBenchmark::new(operation, client.clone()).run()) From 9ab16cc43e142df50500e34f38dc84afb57fc691 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Mon, 19 Aug 2024 21:30:07 +0200 Subject: [PATCH 4/4] Improve docs on `BenchError` --- crates/cubecl-runtime/src/tune/tuner.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index 1d25c95f..88b9b497 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -21,6 +21,9 @@ use super::AutotuneKey; /// An error that occured during benchmarking. If other benches succeeded, ignore this bench and /// continue gracefully. If all benches fail, panic. +/// This error cannot be acted on in any way, because it's an opaque unwind object, and must be +/// `ManuallyDrop` because dropping it can cause unwinding to proceed. It can only +/// be passed to `resume_unwind` to continue the panic. type BenchError = ManuallyDrop<Box<dyn Any + Send>>; #[derive(Debug)]