diff --git a/crates/cubecl-core/src/id.rs b/crates/cubecl-core/src/id.rs index f0481165..bfe9132f 100644 --- a/crates/cubecl-core/src/id.rs +++ b/crates/cubecl-core/src/id.rs @@ -7,49 +7,75 @@ use std::sync::Arc; #[derive(Hash, PartialEq, Eq, Clone, Debug)] pub struct KernelId { type_id: core::any::TypeId, - info: Info, + info: Option, } impl Display for KernelId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{:?}", self.info)) + match &self.info { + Some(info) => f.write_fmt(format_args!("{}", info)), + None => f.write_str("No info"), + } } } impl KernelId { /// Create a new [kernel id](KernelId) for a type. - pub fn new( - info: I, - ) -> Self { + pub fn new() -> Self { Self { type_id: core::any::TypeId::of::(), - info: Info::new(info), + info: None, } } + + /// Add information to the [kernel id](KernelId). + /// + /// The information is used to differentiate kernels of the same kind but with different + /// configurations, which affect the generated code. + pub fn info( + mut self, + info: I, + ) -> Self { + self.info = Some(Info::new(info)); + self + } } /// Extra information #[derive(Clone, Debug)] struct Info { - id: Arc, + value: Arc, +} + +impl core::fmt::Display for Info { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{:?}", self.value)) + } } impl Info { fn new(id: T) -> Self { - Self { id: Arc::new(id) } + Self { + value: Arc::new(id), + } } } -trait DynId: core::fmt::Debug + Send + Sync { +/// This trait allows various types to be used as keys within a single data structure. +/// +/// The downside is that the hashing method is hardcoded and cannot be configured using the +/// [core::hash::Hash] function. The provided [Hasher] will be modified, but only based on the +/// result of the hash from the [DefaultHasher]. +trait DynKey: core::fmt::Debug + Send + Sync { fn dyn_type_id(&self) -> TypeId; - fn dyn_eq(&self, other: &dyn DynId) -> bool; + fn dyn_eq(&self, other: &dyn DynKey) -> bool; fn dyn_hash(&self, state: &mut dyn Hasher); fn as_any(&self) -> &dyn Any; } impl PartialEq for Info { fn eq(&self, other: &Self) -> bool { - self.id.dyn_eq(other.id.as_ref()) + self.value.dyn_eq(other.value.as_ref()) } } @@ -57,13 +83,13 @@ impl Eq for Info {} impl Hash for Info { fn hash(&self, state: &mut H) { - self.id.dyn_type_id().hash(state); - self.id.dyn_hash(state) + self.value.dyn_type_id().hash(state); + self.value.dyn_hash(state) } } -impl DynId for T { - fn dyn_eq(&self, other: &dyn DynId) -> bool { +impl DynKey for T { + fn dyn_eq(&self, other: &dyn DynKey) -> bool { if let Some(other) = other.as_any().downcast_ref::() { self == other } else { @@ -93,8 +119,8 @@ mod tests { #[test] pub fn kernel_id_hash() { - let value_1 = KernelId::new::<(), _>("1"); - let value_2 = KernelId::new::<(), _>("2"); + let value_1 = KernelId::new::<()>().info("1"); + let value_2 = KernelId::new::<()>().info("2"); let mut set = HashSet::new(); diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs index 99bf9687..2cdbd091 100644 --- a/crates/cubecl-core/src/lib.rs +++ b/crates/cubecl-core/src/lib.rs @@ -46,7 +46,7 @@ pub trait Kernel: Send + Sync + 'static + Sized { fn define(&self) -> KernelDefinition; /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> KernelId { - KernelId::new::(()) + KernelId::new::() } } diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index bcbb4469..660e4f32 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -360,7 +360,7 @@ impl Codegen { } fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::((#args)) + cubecl::KernelId::new::().info((#args)) } } }