Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jul 28, 2024
1 parent 43419c4 commit e53d544
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 19 deletions.
60 changes: 43 additions & 17 deletions crates/cubecl-core/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,63 +7,89 @@ use std::sync::Arc;
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub struct KernelId {
type_id: core::any::TypeId,
info: Info,
info: Option<Info>,
}

impl Display for KernelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{:?}", self.info))
match &self.info {
Some(info) => f.write_fmt(format_args!("{}", info)),
None => f.write_str("No info"),
}
}
}

impl KernelId {
/// Create a new [kernel id](KernelId) for a type.
pub fn new<T: 'static, I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
info: I,
) -> Self {
pub fn new<T: 'static>() -> Self {
Self {
type_id: core::any::TypeId::of::<T>(),
info: Info::new(info),
info: None,
}
}

/// Add information to the [kernel id](KernelId).
///
/// The information is used to differentiate kernels of the same kind but with different
/// configurations, which affect the generated code.
pub fn info<I: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(
mut self,
info: I,
) -> Self {
self.info = Some(Info::new(info));
self
}
}

/// Extra information
#[derive(Clone, Debug)]
struct Info {
id: Arc<dyn DynId>,
value: Arc<dyn DynKey>,
}

impl core::fmt::Display for Info {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{:?}", self.value))
}
}

impl Info {
fn new<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync>(id: T) -> Self {
Self { id: Arc::new(id) }
Self {
value: Arc::new(id),
}
}
}

trait DynId: core::fmt::Debug + Send + Sync {
/// This trait allows various types to be used as keys within a single data structure.
///
/// The downside is that the hashing method is hardcoded and cannot be configured using the
/// [core::hash::Hash] function. The provided [Hasher] will be modified, but only based on the
/// result of the hash from the [DefaultHasher].
trait DynKey: core::fmt::Debug + Send + Sync {
fn dyn_type_id(&self) -> TypeId;
fn dyn_eq(&self, other: &dyn DynId) -> bool;
fn dyn_eq(&self, other: &dyn DynKey) -> bool;
fn dyn_hash(&self, state: &mut dyn Hasher);
fn as_any(&self) -> &dyn Any;
}

impl PartialEq for Info {
fn eq(&self, other: &Self) -> bool {
self.id.dyn_eq(other.id.as_ref())
self.value.dyn_eq(other.value.as_ref())
}
}

impl Eq for Info {}

impl Hash for Info {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.dyn_type_id().hash(state);
self.id.dyn_hash(state)
self.value.dyn_type_id().hash(state);
self.value.dyn_hash(state)
}
}

impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynId for T {
fn dyn_eq(&self, other: &dyn DynId) -> bool {
impl<T: 'static + PartialEq + Eq + Hash + core::fmt::Debug + Send + Sync> DynKey for T {
fn dyn_eq(&self, other: &dyn DynKey) -> bool {
if let Some(other) = other.as_any().downcast_ref::<T>() {
self == other
} else {
Expand Down Expand Up @@ -93,8 +119,8 @@ mod tests {

#[test]
pub fn kernel_id_hash() {
let value_1 = KernelId::new::<(), _>("1");
let value_2 = KernelId::new::<(), _>("2");
let value_1 = KernelId::new::<()>().info("1");
let value_2 = KernelId::new::<()>().info("2");

let mut set = HashSet::new();

Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub trait Kernel: Send + Sync + 'static + Sized {
fn define(&self) -> KernelDefinition;
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> KernelId {
KernelId::new::<Self, _>(())
KernelId::new::<Self>()
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-macros/src/codegen_function/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ impl Codegen {
}

fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self, _>((#args))
cubecl::KernelId::new::<Self>().info((#args))
}
}
}
Expand Down

0 comments on commit e53d544

Please sign in to comment.