From 9ee042bc8777449f300f89f74567188bbb74902f Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 16:45:05 -0500 Subject: [PATCH] implement Tensor::coordinate method --- .../src/frontend/container/tensor/base.rs | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-core/src/frontend/container/tensor/base.rs b/crates/cubecl-core/src/frontend/container/tensor/base.rs index 299748e14..5e99a3f38 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/base.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/base.rs @@ -18,7 +18,10 @@ pub struct Tensor { /// Module that contains the implementation details of the metadata functions. mod metadata { use super::*; - use crate::{ir::Instruction, prelude::Array}; + use crate::{ + ir::{BinaryOperator, Instruction, Operator}, + prelude::Array, + }; impl Tensor { /// Obtain the stride of input at dimension dim @@ -31,6 +34,11 @@ mod metadata { unexpanded!() } + /// Obtain the coordinate corresponding to the given `index` of input at dimension `dim`. + pub fn coordinate(&self, _index: I, _dim: D) -> u32 { + unexpanded!() + } + /// The number of vectorized elements in the tensor. /// /// # Warning @@ -76,6 +84,16 @@ mod metadata { expand.__expand_shape_method(context, dim) } + // Expand function of [coordinate](Tensor::coordinate). + pub fn __expand_coordinate( + context: &mut CubeContext, + expand: ExpandElementTyped>, + index: ExpandElementTyped, + dim: ExpandElementTyped, + ) -> ExpandElementTyped { + expand.__expand_coordinate_method(context, index, dim) + } + // Expand function of [len](Tensor::len). pub fn __expand_len( context: &mut CubeContext, @@ -138,6 +156,40 @@ mod metadata { out.into() } + // Expand method of [coordinate](Tensor::coordinate). + pub fn __expand_coordinate_method( + self, + context: &mut CubeContext, + index: ExpandElementTyped, + dim: ExpandElementTyped, + ) -> ExpandElementTyped { + let index: ExpandElement = index.into(); + let stride = self.clone().__expand_stride_method(context, dim.clone()); + let shape = self.clone().__expand_shape_method(context, dim.clone()); + + // Compute `num_strides = index / stride`. + let num_strides = context.create_local_binding(Item::new(u32::as_elem())); + context.register(Instruction::new( + Operator::Div(BinaryOperator { + lhs: *index, + rhs: stride.expand.into(), + }), + num_strides.clone().into(), + )); + + // Compute `coordinate = num_strides % shape `. + let coordinate = context.create_local_binding(Item::new(u32::as_elem())); + context.register(Instruction::new( + Operator::Modulo(BinaryOperator { + lhs: *index, + rhs: shape.expand.into(), + }), + coordinate.clone().into(), + )); + + coordinate.into() + } + // Expand method of [len](Tensor::len). pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped { let elem: ExpandElementTyped> = self.expand.into();