Skip to content

Commit

Permalink
Add optional output to autotuned operations (#73)
Browse files (browse the repository at this point in the history)
  • Loading branch information
wingertge authored Aug 19, 2024
1 parent f245f4b commit ccde038
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 33 deletions.
12 changes: 6 additions & 6 deletions crates/cubecl-runtime/src/tune/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ impl<AK: AutotuneKey, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK
}

/// Execute the best operation in the provided [autotune operation set](AutotuneOperationSet)
pub fn execute<S, C>(
pub fn execute<S, C, Out>(
&self,
id: &ID,
client: &ComputeClient<S, C>,
autotune_operation_set: Box<dyn AutotuneOperationSet<AK>>,
) where
autotune_operation_set: Box<dyn AutotuneOperationSet<AK, Out>>,
) -> Out
where
S: ComputeServer,
C: ComputeChannel<S>,
{
Expand All @@ -61,8 +62,7 @@ impl<AK: AutotuneKey, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK
let key = autotune_operation_set.key();
if let Some(index) = tuner.autotune_fastest(&key) {
let op = autotune_operation_set.fastest(index);
op.execute();
return;
return op.execute();
}
}
}
Expand All @@ -80,7 +80,7 @@ impl<AK: AutotuneKey, ID: Hash + PartialEq + Eq + Clone + Display> LocalTuner<AK
map.get_mut(id).unwrap()
};

tuner.execute_autotune(autotune_operation_set, client);
tuner.execute_autotune(autotune_operation_set, client)
}

/// Return the autotune result given a key.
Expand Down
14 changes: 7 additions & 7 deletions crates/cubecl-runtime/src/tune/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::hash::Hash;

/// Default checksum for an operation set
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
pub fn compute_checksum<Out>(autotunables: &[Box<dyn AutotuneOperation<Out>>]) -> String {
let mut checksum = String::new();
autotunables.iter().for_each(|op| {
checksum += op.name();
Expand All @@ -15,17 +15,17 @@ pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
}

/// Groups operations of the same type for autotune
pub trait AutotuneOperationSet<K>: Send {
pub trait AutotuneOperationSet<K, Output = ()>: Send {
/// The key used in the tune cache
fn key(&self) -> K;

/// All candidate operations for autotuning this operation type.
/// Operations can run on toy tensors of relevant size.
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation<Output>>>;

/// Returns the operation for the given index, matching the order
/// returned by autotunables. The operation obtained here runs on the original tensors.
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation<Output>>;

/// Compute a checksum that can invalidate outdated cached auto-tune results.
#[cfg(autotune_persistent_cache)]
Expand All @@ -35,17 +35,17 @@ pub trait AutotuneOperationSet<K>: Send {
}

/// Contains operation to run and inputs on which to run it
pub trait AutotuneOperation {
pub trait AutotuneOperation<Output = ()> {
/// Runs the operation
fn execute(self: Box<Self>);
fn execute(self: Box<Self>) -> Output;

/// The name of the operation.
fn name(&self) -> &str {
core::any::type_name::<Self>()
}

/// Clones the operation and inputs
fn clone(&self) -> Box<dyn AutotuneOperation>;
fn clone(&self) -> Box<dyn AutotuneOperation<Output>>;
}

#[cfg(autotune_persistent_cache)]
Expand Down
10 changes: 5 additions & 5 deletions crates/cubecl-runtime/src/tune/tune_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ use alloc::string::{String, ToString};

/// A benchmark that runs on server handles
#[derive(new)]
pub struct TuneBenchmark<S: ComputeServer, C> {
operation: Box<dyn AutotuneOperation>,
pub struct TuneBenchmark<S: ComputeServer, C, Out = ()> {
operation: Box<dyn AutotuneOperation<Out>>,
client: ComputeClient<S, C>,
}

impl Clone for Box<dyn AutotuneOperation> {
impl<Out> Clone for Box<dyn AutotuneOperation<Out>> {
fn clone(&self) -> Self {
self.as_ref().clone()
}
}

impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
type Args = Box<dyn AutotuneOperation>;
impl<S: ComputeServer, C: ComputeChannel<S>, Out> Benchmark for TuneBenchmark<S, C, Out> {
type Args = Box<dyn AutotuneOperation<Out>>;

fn prepare(&self) -> Self::Args {
self.operation.clone()
Expand Down
12 changes: 6 additions & 6 deletions crates/cubecl-runtime/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ pub(crate) struct TuneCache<K> {
}

/// Result of a cache lookup attempt
pub enum TuneCacheResult<K> {
pub enum TuneCacheResult<K, Out = ()> {
/// An operation is found and given
Hit(Box<dyn AutotuneOperation>),
Hit(Box<dyn AutotuneOperation<Out>>),
/// No operation is found, and ownership of the set is handed back
Miss(Box<dyn AutotuneOperationSet<K>>),
Miss(Box<dyn AutotuneOperationSet<K, Out>>),
}

impl<K: AutotuneKey> TuneCache<K> {
Expand Down Expand Up @@ -108,10 +108,10 @@ impl<K: AutotuneKey> TuneCache<K> {
Some(val.fastest_index)
}

pub(crate) fn try_cache(
pub(crate) fn try_cache<Out>(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
) -> TuneCacheResult<K> {
autotune_operation_set: Box<dyn AutotuneOperationSet<K, Out>>,
) -> TuneCacheResult<K, Out> {
let key = autotune_operation_set.key();
let result = self.in_memory_cache.get_mut(&key);

Expand Down
19 changes: 10 additions & 9 deletions crates/cubecl-runtime/src/tune/tuner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ impl<K: AutotuneKey> Tuner<K> {
}

/// Execute the fastest autotune operation if known; otherwise, perform some benchmarks first.
pub fn execute_autotune<S, C>(
pub fn execute_autotune<S, C, Out>(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
autotune_operation_set: Box<dyn AutotuneOperationSet<K, Out>>,
client: &ComputeClient<S, C>,
) where
) -> Out
where
S: ComputeServer,
C: ComputeChannel<S>,
{
Expand All @@ -50,14 +51,14 @@ impl<K: AutotuneKey> Tuner<K> {
super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
};

AutotuneOperation::execute(operation);
AutotuneOperation::<Out>::execute(operation)
}

fn autotuning<S, C>(
fn autotuning<S, C, Out>(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
autotune_operation_set: Box<dyn AutotuneOperationSet<K, Out>>,
client: &ComputeClient<S, C>,
) -> Box<dyn AutotuneOperation>
) -> Box<dyn AutotuneOperation<Out>>
where
S: ComputeServer,
C: ComputeChannel<S>,
Expand Down Expand Up @@ -94,9 +95,9 @@ impl<K: AutotuneKey> Tuner<K> {
}
}

fn run_benchmark<S, C>(
fn run_benchmark<S, C, Out>(
&mut self,
operation: Box<dyn AutotuneOperation>,
operation: Box<dyn AutotuneOperation<Out>>,
client: &ComputeClient<S, C>,
) -> BenchmarkDurations
where
Expand Down

0 comments on commit ccde038

Please sign in to comment.