diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 9d9fa963d..a31415666 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -2022,10 +2022,34 @@ def trilu(self, rhs: Any, k: int, lower: bool) -> None: def repeat( self, repeats: Any, axis: int, scalar_repeats: bool ) -> DeferredArray: - out = self.runtime.create_unbound_thunk(self.base.type, ndim=self.ndim) task = self.context.create_auto_task(CuNumericOpCode.REPEAT) - task.add_input(self.base) - task.add_output(out.base) + if scalar_repeats: + out_shape = tuple( + self.shape[dim] * repeats if dim == axis else self.shape[dim] + for dim in range(self.ndim) + ) + out = cast( + DeferredArray, + self.runtime.create_empty_thunk( + out_shape, + dtype=self.base.type, + inputs=[self], + ), + ) + p_in = task.declare_partition(self.base) + p_out = task.declare_partition(out.base) + task.add_input(self.base, partition=p_in) + task.add_output(out.base, partition=p_out) + scale = tuple( + repeats if dim == axis else 1 for dim in range(self.ndim) + ) + task.add_constraint(p_out <= p_in * scale) + else: + out = self.runtime.create_unbound_thunk( + self.base.type, ndim=self.ndim + ) + task.add_input(self.base) + task.add_output(out.base) # We pass axis now but don't use for 1D case (will use for ND case task.add_scalar_arg(axis, ty.int32) task.add_scalar_arg(scalar_repeats, ty.bool_) diff --git a/src/cunumeric/index/repeat.cc b/src/cunumeric/index/repeat.cc index 9222d7c5f..f663718f8 100644 --- a/src/cunumeric/index/repeat.cc +++ b/src/cunumeric/index/repeat.cc @@ -31,12 +31,8 @@ struct RepeatImplBody { const int32_t axis, const Rect& in_rect) const { - Point extents = in_rect.hi - in_rect.lo + Point::ONES(); - extents[axis] *= repeats; - - auto out = out_array.create_output_buffer(extents, true); - - Rect out_rect(Point::ZEROES(), extents - Point::ONES()); + auto out_rect = out_array.shape(); + auto out = out_array.read_write_accessor(out_rect); Pitches pitches; auto out_volume = pitches.flatten(out_rect); @@ -44,7 +40,6 @@ struct RepeatImplBody { auto out_p = pitches.unflatten(idx, out_rect.lo); auto in_p = out_p; in_p[axis] /= repeats; - in_p += in_rect.lo; out[out_p] = in[in_p]; } } diff --git a/src/cunumeric/index/repeat.cu b/src/cunumeric/index/repeat.cu index 634050b9d..b8063b01d 100644 --- a/src/cunumeric/index/repeat.cu +++ b/src/cunumeric/index/repeat.cu @@ -50,7 +50,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) template __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - repeat_kernel(Buffer out, + repeat_kernel(AccessorRW out, const AccessorRO in, int64_t repeats, const int32_t axis, @@ -63,7 +63,6 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) auto out_p = pitches.unflatten(idx, Point::ZEROES()); auto in_p = out_p; in_p[axis] /= repeats; - in_p += in_lo; out[out_p] = in[in_p]; } @@ -103,12 +102,8 @@ struct RepeatImplBody { const int32_t axis, const Rect& in_rect) const { - Point extents = in_rect.hi - in_rect.lo + Point::ONES(); - extents[axis] *= repeats; - - auto out = out_array.create_output_buffer(extents, true); - - Rect out_rect(Point::ZEROES(), extents - Point::ONES()); + auto out_rect = out_array.shape(); + auto out = out_array.read_write_accessor(out_rect); Pitches pitches; auto out_volume = pitches.flatten(out_rect); diff --git a/src/cunumeric/index/repeat_omp.cc b/src/cunumeric/index/repeat_omp.cc index 9ff130634..6bfd7e551 100644 --- a/src/cunumeric/index/repeat_omp.cc +++ b/src/cunumeric/index/repeat_omp.cc @@ -36,12 +36,8 @@ struct RepeatImplBody { const int32_t axis, const Rect& in_rect) const { - Point extents = in_rect.hi - in_rect.lo + Point::ONES(); - extents[axis] *= repeats; - - auto out = out_array.create_output_buffer(extents, true); - - Rect out_rect(Point::ZEROES(), extents - Point::ONES()); + auto out_rect = out_array.shape(); + auto out = out_array.read_write_accessor(out_rect); Pitches pitches; auto out_volume = pitches.flatten(out_rect); @@ -50,7 +46,6 @@ struct RepeatImplBody { auto out_p = pitches.unflatten(idx, out_rect.lo); auto in_p = out_p; in_p[axis] /= repeats; - in_p += in_rect.lo; out[out_p] = in[in_p]; } } diff --git a/src/cunumeric/index/repeat_template.inl b/src/cunumeric/index/repeat_template.inl index d6173dde8..8c0cc026a 100644 --- a/src/cunumeric/index/repeat_template.inl +++ b/src/cunumeric/index/repeat_template.inl @@ -37,7 +37,7 @@ struct RepeatImpl { auto input_arr = args.input.read_accessor(input_rect); if (input_rect.empty()) { - args.output.bind_empty_data(); + if (!args.scalar_repeats) { args.output.bind_empty_data(); } return; }