Skip to content

Commit

Permalink
Gate unwind catch behind std feature
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Aug 17, 2024
1 parent 678a08f commit 3fec002
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions crates/cubecl-runtime/src/tune/tuner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use web_time::Duration;

#[cfg(not(target_family = "wasm"))]
use core::time::Duration;
use core::{any::Any, mem::ManuallyDrop, panic::AssertUnwindSafe};
use std::panic::{catch_unwind, resume_unwind};
use core::{any::Any, mem::ManuallyDrop};
#[cfg(feature = "std")]
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

use alloc::boxed::Box;
use alloc::string::ToString;
Expand Down Expand Up @@ -80,6 +81,7 @@ impl<K: AutotuneKey> Tuner<K> {
})
.collect();

#[cfg(feature = "std")]
if results.iter().all(|it| it.is_err()) {
let first_error = results.into_iter().next().unwrap().err().unwrap();
resume_unwind(ManuallyDrop::into_inner(first_error));
Expand Down Expand Up @@ -115,10 +117,15 @@ impl<K: AutotuneKey> Tuner<K> {
S: ComputeServer,
C: ComputeChannel<S>,
{
catch_unwind(AssertUnwindSafe(|| {
TuneBenchmark::new(operation, client.clone()).run()
}))
.map_err(ManuallyDrop::new)
#[cfg(feature = "std")]
{
catch_unwind(AssertUnwindSafe(|| {
TuneBenchmark::new(operation, client.clone()).run()
}))
.map_err(ManuallyDrop::new)
}
#[cfg(not(feature = "std"))]
Ok(TuneBenchmark::new(operation, client.clone()).run())
}

fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
Expand Down

0 comments on commit 3fec002

Please sign in to comment.