diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index 5ada1aa9..b794568a 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -3,8 +3,9 @@ use web_time::Duration; #[cfg(not(target_family = "wasm"))] use core::time::Duration; -use core::{any::Any, mem::ManuallyDrop, panic::AssertUnwindSafe}; -use std::panic::{catch_unwind, resume_unwind}; +use core::{any::Any, mem::ManuallyDrop}; +#[cfg(feature = "std")] +use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use alloc::boxed::Box; use alloc::string::ToString; @@ -80,6 +81,7 @@ impl Tuner { }) .collect(); + #[cfg(feature = "std")] if results.iter().all(|it| it.is_err()) { let first_error = results.into_iter().next().unwrap().err().unwrap(); resume_unwind(ManuallyDrop::into_inner(first_error)); @@ -115,10 +117,15 @@ impl Tuner { S: ComputeServer, C: ComputeChannel, { - catch_unwind(AssertUnwindSafe(|| { - TuneBenchmark::new(operation, client.clone()).run() - })) - .map_err(ManuallyDrop::new) + #[cfg(feature = "std")] + { + catch_unwind(AssertUnwindSafe(|| { + TuneBenchmark::new(operation, client.clone()).run() + })) + .map_err(ManuallyDrop::new) + } + #[cfg(not(feature = "std"))] + Ok(TuneBenchmark::new(operation, client.clone()).run()) } fn find_fastest(&self, results: Vec) -> usize {