Skip to content

Commit

Permalink
Make autotune_persistent_cache a config that's automatically only des…
Browse files Browse the repository at this point in the history
…ktop enabled.
  • Loading branch information
ArthurBrussee committed Aug 13, 2024
1 parent 4e17724 commit 555775b
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 42 deletions.
1 change: 0 additions & 1 deletion crates/cubecl-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ default = [
"cubecl-common/default",
"cubecl-core/default",
]
autotune = []
std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"]

[dependencies]
Expand Down
19 changes: 12 additions & 7 deletions crates/cubecl-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,41 @@ default = [
"channel-mpsc",
"channel-cell",
"storage-bytes",
"autotune-persistent-cache",
"cubecl-common/default"
"cubecl-common/default",
]
std = ["cubecl-common/std"]
channel-mutex = []
channel-cell = []
channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
storage-bytes = []
autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std

[dependencies]
cubecl-common = { path = "../cubecl-common", version = "0.1.1", default-features = false }
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
hashbrown = { workspace = true }
dirs = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["std"], optional = true }
md5 = { workspace = true, optional = true }

pollster = { workspace = true, optional = true }
async-channel = { workspace = true, optional = true }

# Persistent cache deps - the target cfg below has to match the `autotune_persistent_cache` cfg alias defined in build.rs.
[target.'cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))'.dependencies]
dirs = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true, features = ["std"] }
md5 = { workspace = true }

[target.'cfg(target_family = "wasm")'.dependencies]
web-time = { workspace = true }

[dev-dependencies]
serial_test = { workspace = true }
rand = { workspace = true }

[build-dependencies]
cfg_aliases = "0.2.1"

[[bench]]
name = "dynamic"
harness = false
8 changes: 8 additions & 0 deletions crates/cubecl-runtime/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use cfg_aliases::cfg_aliases;

/// Build script entry point: registers the custom `cfg` aliases used by this crate.
///
/// Defines the `autotune_persistent_cache` alias so the sources can write
/// `#[cfg(autotune_persistent_cache)]` instead of repeating the full
/// `target_os` predicate at every use site.
fn main() {
    // The aliases depend only on this script and on the compilation target, so
    // tell Cargo not to re-run the script every time another file in the
    // package changes (which is the default when no `rerun-if-*` is emitted).
    println!("cargo:rerun-if-changed=build.rs");

    // Setup cfg aliases
    cfg_aliases! {
        // Set when the target OS is Windows, Linux or macOS.
        autotune_persistent_cache: { any(target_os = "windows", target_os = "linux", target_os = "macos") },
    }
}
9 changes: 5 additions & 4 deletions crates/cubecl-runtime/src/tune/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::fmt::{Debug, Display};
use core::hash::Hash;

/// Default checksum for an operation set
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
let mut checksum = String::new();
autotunables.iter().for_each(|op| {
Expand All @@ -28,7 +28,7 @@ pub trait AutotuneOperationSet<K>: Send {
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;

/// Compute a checksum that can invalidate outdated cached auto-tune results.
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
fn compute_checksum(&self) -> String {
compute_checksum(&self.autotunables())
}
Expand All @@ -48,7 +48,7 @@ pub trait AutotuneOperation {
fn clone(&self) -> Box<dyn AutotuneOperation>;
}

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
/// Trait alias with support for persistent caching
pub trait AutotuneKey:
Clone
Expand All @@ -63,7 +63,8 @@ pub trait AutotuneKey:
+ Sync
{
}
#[cfg(not(feature = "autotune-persistent-cache"))]
#[cfg(not(autotune_persistent_cache))]
/// Trait alias
pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}

impl AutotuneKey for String {}
44 changes: 21 additions & 23 deletions crates/cubecl-runtime/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
mod std_imports {
pub use std::fs;
pub use std::fs::File;
pub use std::io;
pub use std::path::Path;
pub use std::path::PathBuf;
}
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
use std_imports::*;

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
use serde::{Deserialize, Serialize};

use super::AutotuneKey;
Expand All @@ -18,7 +18,7 @@ use super::AutotuneOperationSet;
use alloc::boxed::Box;
use hashbrown::HashMap;

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
/// Return the file path for the persistent cache on disk
/// prefix should be the device id computed at the backend level
pub fn get_persistent_cache_file_path(prefix: &str) -> PathBuf {
Expand All @@ -31,13 +31,13 @@ pub fn get_persistent_cache_file_path(prefix: &str) -> PathBuf {
/// In-memory cache entry
#[derive(Debug)]
pub(crate) struct InMemoryCacheEntry {
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
checksum_checked: bool,
fastest_index: usize,
}

/// Persistent cache entry
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct PersistentCacheEntry {
checksum: String,
Expand All @@ -48,11 +48,11 @@ pub(crate) struct PersistentCacheEntry {
#[derive(Debug)]
pub(crate) struct TuneCache<K> {
in_memory_cache: HashMap<K, InMemoryCacheEntry>,
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
persistent_cache: HashMap<K, PersistentCacheEntry>,
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
device_id: String,
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
name: String,
}

Expand All @@ -66,11 +66,10 @@ pub enum TuneCacheResult<K> {

impl<K: AutotuneKey> TuneCache<K> {
pub(crate) fn new(
#[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str,
#[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))]
device_id: &str,
#[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] name: &str,
#[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] device_id: &str,
) -> Self {
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
{
let mut cache = TuneCache {
in_memory_cache: HashMap::new(),
Expand All @@ -87,7 +86,7 @@ impl<K: AutotuneKey> TuneCache<K> {
cache
}

#[cfg(not(feature = "autotune-persistent-cache"))]
#[cfg(not(autotune_persistent_cache))]
{
TuneCache {
in_memory_cache: HashMap::new(),
Expand All @@ -98,14 +97,14 @@ impl<K: AutotuneKey> TuneCache<K> {
pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
let val = self.in_memory_cache.get(key)?;

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
if val.checksum_checked {
Some(val.fastest_index)
} else {
None
}

#[cfg(not(feature = "autotune-persistent-cache"))]
#[cfg(not(autotune_persistent_cache))]
Some(val.fastest_index)
}

Expand All @@ -116,7 +115,7 @@ impl<K: AutotuneKey> TuneCache<K> {
let key = autotune_operation_set.key();
let result = self.in_memory_cache.get_mut(&key);

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
{
if let Some(InMemoryCacheEntry {
checksum_checked,
Expand All @@ -138,7 +137,7 @@ impl<K: AutotuneKey> TuneCache<K> {
}
}

#[cfg(not(feature = "autotune-persistent-cache"))]
#[cfg(not(autotune_persistent_cache))]
{
if let Some(InMemoryCacheEntry { fastest_index, .. }) = result {
return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index));
Expand All @@ -152,14 +151,16 @@ impl<K: AutotuneKey> TuneCache<K> {
self.in_memory_cache.insert(
key,
InMemoryCacheEntry {
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
checksum_checked: true,
fastest_index,
},
);
}
}

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
impl<K: AutotuneKey> TuneCache<K> {
pub(crate) fn persistent_cache_insert(
&mut self,
key: K,
Expand All @@ -176,7 +177,6 @@ impl<K: AutotuneKey> TuneCache<K> {
}

/// Load the persistent cache data from disk
#[cfg(feature = "autotune-persistent-cache")]
pub(crate) fn load(&mut self) -> Result<(), io::Error> {
let file_path = self.get_persistent_cache_file_path();
// note: reading file from memory is faster than using
Expand Down Expand Up @@ -212,7 +212,6 @@ impl<K: AutotuneKey> TuneCache<K> {
}

/// Save the persistent cache on disk
#[cfg(feature = "autotune-persistent-cache")]
pub(crate) fn save(&self) {
let file_path = self.get_persistent_cache_file_path();
if let Some(parent_dir) = file_path.parent() {
Expand All @@ -236,7 +235,6 @@ impl<K: AutotuneKey> TuneCache<K> {
}

/// Return the file path for the persistent cache on disk
#[cfg(feature = "autotune-persistent-cache")]
pub fn get_persistent_cache_file_path(&self) -> PathBuf {
get_persistent_cache_file_path(&format!("{}/{}", self.name, self.device_id))
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/src/tune/tuner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl<K: AutotuneKey> Tuner<K> {
log::info!("Fastest result {fastest_name}-{key}");

self.tune_cache.cache_insert(key.clone(), fastest_index);
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
{
let checksum = autotune_operation_set.compute_checksum();
self.tune_cache
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-runtime/tests/dummy/tune/operation_sets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
use rand::{distributions::Alphanumeric, Rng};
use std::sync::Arc;

#[cfg(feature = "autotune-persistent-cache")]
#[cfg(autotune_persistent_cache)]
use cubecl_runtime::tune::compute_checksum;
use cubecl_runtime::{
server::Binding,
Expand Down Expand Up @@ -167,7 +167,7 @@ impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
self.autotunables()[fastest_index].clone()
}

#[cfg(feature = "std")]
#[cfg(autotune_persistent_cache)]
fn compute_checksum(&self) -> String {
if self.generate_random_checksum {
let rand_string: String = rand::thread_rng()
Expand Down
9 changes: 6 additions & 3 deletions crates/cubecl-runtime/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ mod dummy;
use std::sync::Arc;

use crate::dummy::autotune_execute;
use crate::dummy::TEST_TUNER;
use crate::dummy::{client, DummyDevice, DummyElementwiseAddition};
use crate::dummy::{TEST_TUNER, TUNER_DEVICE_ID, TUNER_PREFIX};

#[cfg(autotune_persistent_cache)]
use crate::dummy::{TUNER_DEVICE_ID, TUNER_PREFIX};

use cubecl_runtime::ComputeRuntime;

Expand Down Expand Up @@ -139,7 +142,7 @@ fn autotune_cache_same_key_return_a_cache_hit() {

#[test]
#[serial]
#[cfg(feature = "std")]
#[cfg(autotune_persistent_cache)]
fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
TEST_TUNER.clear();

Expand Down Expand Up @@ -184,7 +187,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {

#[test]
#[serial]
#[cfg(feature = "std")]
#[cfg(autotune_persistent_cache)]
fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
TEST_TUNER.clear();
// delete the cache file
Expand Down

0 comments on commit 555775b

Please sign in to comment.