diff --git a/cunumeric/config.py b/cunumeric/config.py
index bdea334a1..635544bd8 100644
--- a/cunumeric/config.py
+++ b/cunumeric/config.py
@@ -32,6 +32,7 @@ class _CunumericSharedLib:
     CUNUMERIC_ADVANCED_INDEXING: int
     CUNUMERIC_ARANGE: int
     CUNUMERIC_ARGWHERE: int
+    CUNUMERIC_BATCHED_CHOLESKY: int
     CUNUMERIC_BINARY_OP: int
     CUNUMERIC_BINARY_RED: int
     CUNUMERIC_BINCOUNT: int
@@ -333,6 +334,7 @@ class CuNumericOpCode(IntEnum):
     ADVANCED_INDEXING = _cunumeric.CUNUMERIC_ADVANCED_INDEXING
     ARANGE = _cunumeric.CUNUMERIC_ARANGE
     ARGWHERE = _cunumeric.CUNUMERIC_ARGWHERE
+    BATCHED_CHOLESKY = _cunumeric.CUNUMERIC_BATCHED_CHOLESKY
    BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP
    BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED
    BINCOUNT = _cunumeric.CUNUMERIC_BINCOUNT
diff --git a/cunumeric/linalg/cholesky.py b/cunumeric/linalg/cholesky.py
index 9bba03361..4ff4fe212 100644
--- a/cunumeric/linalg/cholesky.py
+++ b/cunumeric/linalg/cholesky.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2022 NVIDIA Corporation
+# Copyright 2023 NVIDIA Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -202,11 +202,47 @@ def tril(context: Context, p_output: StorePartition, n: int) -> None:
     task.execute()
 
 
+def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None:
+    # The only feasible implementation for now requires each Cholesky
+    # submatrix to fit on a single proc. Available memory varies wildly
+    # across systems, so use a fixed cutoff to provide a sensible warning.
+    # TODO: find a better way to inform the user dims are too big
+    context: Context = output.context
+    task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY)
+    task.add_input(input.base)
+    task.add_output(output.base)
+    ndim = input.base.ndim
+    task.add_broadcast(input.base, (ndim - 2, ndim - 1))
+    task.add_broadcast(output.base, (ndim - 2, ndim - 1))
+    task.add_alignment(input.base, output.base)
+    task.throws_exception(LinAlgError)
+    task.execute()
+
+
 def cholesky(
     output: DeferredArray, input: DeferredArray, no_tril: bool
 ) -> None:
     runtime = output.runtime
-    context = output.context
+    context: Context = output.context
+    if len(input.base.shape) > 2:
+        if no_tril:
+            raise NotImplementedError(
+                "batched cholesky expects to only "
+                "produce the lower triangular matrix"
+            )
+        size = input.base.shape[-1]
+        # Choose 32768 as the dimension cutoff for the warning, so that
+        # for float64 any submatrix larger than 8 GiB produces a warning
+        if size > 32768:
+            runtime.warn(
+                "batched cholesky is only valid"
+                " when the square submatrices fit"
+                f" on a single proc, n = {size} may be too large",
+                category=UserWarning,
+            )
+        return _batched_cholesky(output, input)
     if runtime.num_procs == 1:
         transpose_copy_single(context, input.base, output.base)
diff --git a/cunumeric/linalg/linalg.py b/cunumeric/linalg/linalg.py
index f3f7eb9fb..d1c0498b2 100644
--- a/cunumeric/linalg/linalg.py
+++ b/cunumeric/linalg/linalg.py
@@ -82,10 +82,6 @@ def cholesky(a: ndarray) -> ndarray:
     elif shape[-1] != shape[-2]:
         raise ValueError("Last 2 dimensions of the array must be square")
 
-    if len(shape) > 2:
-        raise NotImplementedError(
-            "cuNumeric needs to support stacked 2d arrays"
-        )
     return _cholesky(a)
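For context on what the removed guard now permits: `num.linalg.cholesky` accepts stacked inputs and factors each trailing square submatrix independently, matching NumPy. The 32768 cutoff above corresponds to a single float64 submatrix of 32768 × 32768 × 8 B = 8 GiB. A minimal NumPy sketch of the intended per-block semantics (the helper name is illustrative, not part of this change):

```python
import numpy as np

def batched_cholesky_reference(a):
    # Reference semantics: one Cholesky factorization per trailing
    # (n, n) submatrix, which is what the new BATCHED_CHOLESKY task computes.
    flat = a.reshape(-1, a.shape[-2], a.shape[-1])
    out = np.stack([np.linalg.cholesky(block) for block in flat])
    return out.reshape(a.shape)

spd = np.eye(8) * 8.0 + np.ones((8, 8))                    # symmetric positive definite
stacked = np.einsum("i,jk->ijk", np.arange(4) + 1.0, spd)  # shape (4, 8, 8)
expected = np.stack([np.linalg.cholesky((i + 1.0) * spd) for i in range(4)])
assert np.allclose(batched_cholesky_reference(stacked), expected)
```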
diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake
index 4270962ba..f7feee620 100644
--- a/cunumeric_cpp.cmake
+++ b/cunumeric_cpp.cmake
@@ -143,6 +143,7 @@ list(APPEND cunumeric_SOURCES
   src/cunumeric/index/putmask.cc
   src/cunumeric/item/read.cc
   src/cunumeric/item/write.cc
+  src/cunumeric/matrix/batched_cholesky.cc
   src/cunumeric/matrix/contract.cc
   src/cunumeric/matrix/diag.cc
   src/cunumeric/matrix/gemm.cc
@@ -195,6 +196,7 @@ if(Legion_USE_OpenMP)
   src/cunumeric/index/repeat_omp.cc
   src/cunumeric/index/wrap_omp.cc
   src/cunumeric/index/zip_omp.cc
+  src/cunumeric/matrix/batched_cholesky_omp.cc
   src/cunumeric/matrix/contract_omp.cc
   src/cunumeric/matrix/diag_omp.cc
   src/cunumeric/matrix/gemm_omp.cc
@@ -245,6 +247,7 @@ if(Legion_USE_CUDA)
   src/cunumeric/index/putmask.cu
   src/cunumeric/item/read.cu
   src/cunumeric/item/write.cu
+  src/cunumeric/matrix/batched_cholesky.cu
   src/cunumeric/matrix/contract.cu
   src/cunumeric/matrix/diag.cu
   src/cunumeric/matrix/gemm.cu
diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h
index b5b392835..99d9bea19 100644
--- a/src/cunumeric/cunumeric_c.h
+++ b/src/cunumeric/cunumeric_c.h
@@ -29,6 +29,7 @@ enum CuNumericOpCode {
   CUNUMERIC_ADVANCED_INDEXING,
   CUNUMERIC_ARANGE,
   CUNUMERIC_ARGWHERE,
+  CUNUMERIC_BATCHED_CHOLESKY,
   CUNUMERIC_BINARY_OP,
   CUNUMERIC_BINARY_RED,
   CUNUMERIC_BINCOUNT,
diff --git a/src/cunumeric/mapper.cc b/src/cunumeric/mapper.cc
index 247ded4fd..ba7114e45 100644
--- a/src/cunumeric/mapper.cc
+++ b/src/cunumeric/mapper.cc
@@ -145,6 +145,25 @@ std::vector<StoreMapping> CuNumericMapper::store_mappings(
       }
       return std::move(mappings);
     }
+    // CHANGE: If this code is changed, make sure all layouts are
+    // consistent with those assumed in batched_cholesky.cu, etc
+    case CUNUMERIC_BATCHED_CHOLESKY: {
+      std::vector<StoreMapping> mappings;
+      auto& inputs  = task.inputs();
+      auto& outputs = task.outputs();
+      mappings.reserve(inputs.size() + outputs.size());
+      for (auto& input : inputs) {
+        mappings.push_back(StoreMapping::default_mapping(input, options.front()));
+        mappings.back().policy.exact = true;
+        mappings.back().policy.ordering.set_c_order();
+      }
+      for (auto& output : outputs) {
+        mappings.push_back(StoreMapping::default_mapping(output, options.front()));
+        mappings.back().policy.exact = true;
+        mappings.back().policy.ordering.set_c_order();
+      }
+      return std::move(mappings);
+    }
     case CUNUMERIC_TRILU: {
       if (task.scalars().size() == 2) return {};
       // If we're here, this task was the post-processing for Cholesky.
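Why the mapper insists on exact, C-ordered instances here: the task addresses each submatrix as `r * n + c` and hands the block straight to potrf, which assumes Fortran (column-major) order. A NumPy sketch of that layout argument (my reading of the code, with illustrative names):

```python
import numpy as np

# Reading row-major bytes as column-major is an implicit transpose;
# because the input matrix is symmetric this is harmless, but the
# factor potrf leaves behind reads back in C order as L.T (upper
# triangular). The transpose kernels in this diff fold it back down.
a = np.array([[4.0, 2.0], [2.0, 3.0]])   # SPD and symmetric
l = np.linalg.cholesky(a)                # the lower factor potrf computes
fortran_bytes = l.ravel(order="F")       # memory as potrf leaves it
c_view = fortran_bytes.reshape(l.shape)  # how the C-ordered store reads it
assert np.allclose(c_view, l.T)          # upper triangular, hence the fix-up
assert np.allclose(np.tril(c_view.T), l) # what the transpose pass recovers
```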
diff --git a/src/cunumeric/matrix/batched_cholesky.cc b/src/cunumeric/matrix/batched_cholesky.cc
new file mode 100644
index 000000000..30dbe3c53
--- /dev/null
+++ b/src/cunumeric/matrix/batched_cholesky.cc
@@ -0,0 +1,85 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/matrix/batched_cholesky.h"
+#include "cunumeric/cunumeric.h"
+#include "cunumeric/matrix/batched_cholesky_template.inl"
+
+#include <algorithm>
+#include <cblas.h>
+#include <cstring>
+
+namespace cunumeric {
+
+using namespace legate;
+
+template <>
+void CopyBlockImpl<VariantKind::CPU>::operator()(void* dst, const void* src, size_t size)
+{
+  ::memcpy(dst, src, size);
+}
+
+template <Type::Code CODE>
+struct BatchedTransposeImplBody<VariantKind::CPU, CODE> {
+  using VAL = legate_type_of<CODE>;
+
+  static constexpr int tile_size = 64;
+
+  void operator()(VAL* out, int n) const
+  {
+    VAL tile[tile_size][tile_size];
+    int nblocks = (n + tile_size - 1) / tile_size;
+
+    for (int rb = 0; rb < nblocks; ++rb) {
+      for (int cb = 0; cb < nblocks; ++cb) {
+        int r_start = rb * tile_size;
+        int r_stop  = std::min(r_start + tile_size, n);
+        int c_start = cb * tile_size;
+        int c_stop  = std::min(c_start + tile_size, n);
+        for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) {
+          for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) {
+            if (r <= c) {
+              tile[tr][tc] = out[r * n + c];
+            } else {
+              tile[tr][tc] = 0;
+            }
+          }
+        }
+        for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) {
+          for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; }
+        }
+      }
+    }
+  }
+};
+
+/*static*/ void BatchedCholeskyTask::cpu_variant(TaskContext& context)
+{
+#ifdef LEGATE_USE_OPENMP
+  openblas_set_num_threads(1);  // make sure this isn't overzealous
+#endif
+  batched_cholesky_task_context_dispatch<VariantKind::CPU>(context);
+}
+
+namespace  // unnamed
+{
+static void __attribute__((constructor)) register_tasks(void)
+{
+  BatchedCholeskyTask::register_variants();
+}
+}  // namespace
+
+}  // namespace cunumeric
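The CPU transpose above walks tiles of the upper triangle and writes each tile, transposed, into the mirrored lower position. A NumPy model of the same in-place pass, assuming a square C-ordered matrix (the function name is illustrative); the net effect is `tril` of the transpose:

```python
import numpy as np

def transpose_upper_to_lower(out, tile_size=64):
    # Model of BatchedTransposeImplBody<CPU>: in place, keep only the
    # transposed upper triangle (incl. diagonal) in the lower triangle,
    # zeroing everything else. Same tile iteration order as the C++.
    n = out.shape[0]
    for r0 in range(0, n, tile_size):
        for c0 in range(0, n, tile_size):
            r1 = min(r0 + tile_size, n)
            c1 = min(c0 + tile_size, n)
            rr, cc = np.meshgrid(np.arange(r0, r1), np.arange(c0, c1), indexing="ij")
            tile = np.where(rr <= cc, out[r0:r1, c0:c1], 0)  # read upper entries only
            out[c0:c1, r0:r1] = tile.T                       # mirrored, transposed write

rng = np.random.default_rng(0)
m = rng.standard_normal((100, 100))
expected = np.tril(m.T)
transpose_upper_to_lower(m, tile_size=16)
assert np.allclose(m, expected)
```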
diff --git a/src/cunumeric/matrix/batched_cholesky.cu b/src/cunumeric/matrix/batched_cholesky.cu
new file mode 100644
index 000000000..26fe3058f
--- /dev/null
+++ b/src/cunumeric/matrix/batched_cholesky.cu
@@ -0,0 +1,111 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/matrix/batched_cholesky.h"
+#include "cunumeric/matrix/potrf.h"
+#include "cunumeric/matrix/batched_cholesky_template.inl"
+
+#include "cunumeric/cuda_help.h"
+
+namespace cunumeric {
+
+using namespace legate;
+
+#define TILE_DIM 32
+#define BLOCK_ROWS 8
+
+template <>
+void CopyBlockImpl<VariantKind::GPU>::operator()(void* dst, const void* src, size_t size)
+{
+  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, get_cached_stream());
+}
+
+template <typename VAL>
+__global__ static void __launch_bounds__((TILE_DIM * BLOCK_ROWS), MIN_CTAS_PER_SM)
+  transpose_2d_lower(VAL* out, int n)
+{
+  __shared__ VAL tile[TILE_DIM][TILE_DIM + 1 /*avoid bank conflicts*/];
+
+  // The y dim is the fast-moving index for coalescing
+  auto r_block = blockIdx.x * TILE_DIM;
+  auto c_block = blockIdx.y * TILE_DIM;
+  auto r       = blockIdx.x * TILE_DIM + threadIdx.x;
+  auto c       = blockIdx.y * TILE_DIM + threadIdx.y;
+  auto stride  = BLOCK_ROWS;
+  // The tile coordinates
+  auto tr     = threadIdx.x;
+  auto tc     = threadIdx.y;
+  auto offset = r * n + c;
+
+  // Only execute on or above the diagonal: a single thread block
+  // stores its upper-triangle block into temp shared memory,
+  // then sets that block to zeros
+  if (c_block >= r_block) {
+#pragma unroll
+    for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
+      if (r < n && (c + i) < n) {
+        if (r <= (c + i)) {
+          tile[tr][tc + i] = out[offset];
+          // clear the upper-triangle entry
+          out[offset] = 0;
+        } else {
+          tile[tr][tc + i] = 0;
+        }
+      }
+    }
+
+    // Make sure all the data is in shared memory
+    __syncthreads();
+
+    // Transpose the global coordinates, keep y the fast-moving index
+    r      = blockIdx.y * TILE_DIM + threadIdx.x;
+    c      = blockIdx.x * TILE_DIM + threadIdx.y;
+    offset = r * n + c;
+
+#pragma unroll
+    for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
+      if (r < n && (c + i) < n) {
+        if (r >= (c + i)) { out[offset] = tile[tc + i][tr]; }
+      }
+    }
+  }
+}
+
+template <Type::Code CODE>
+struct BatchedTransposeImplBody<VariantKind::GPU, CODE> {
+  using VAL = legate_type_of<CODE>;
+
+  void operator()(VAL* out, int n) const
+  {
+    const dim3 blocks((n + TILE_DIM - 1) / TILE_DIM, (n + TILE_DIM - 1) / TILE_DIM, 1);
+    const dim3 threads(TILE_DIM, BLOCK_ROWS, 1);
+
+    auto stream = get_cached_stream();
+
+    // potrf on the GPU produces the full matrix; we only want
+    // the lower triangular part
+    transpose_2d_lower<VAL><<<blocks, threads, 0, stream>>>(out, n);
+
+    CHECK_CUDA_STREAM(stream);
+  }
+};
+
+/*static*/ void BatchedCholeskyTask::gpu_variant(TaskContext& context)
+{
+  batched_cholesky_task_context_dispatch<VariantKind::GPU>(context);
+}
+
+}  // namespace cunumeric
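The kernel leans on the `c_block >= r_block` early exit: blocks strictly below the diagonal do no work because their contents are fully overwritten by the mirrored blocks' second-phase stores. A NumPy model of that control flow (block-serial, so it ignores the shared-memory tile and the bank-conflict padding, which only affect performance):

```python
import numpy as np

TILE_DIM = 32

def transpose_2d_lower_model(out):
    # Only upper-triangle blocks run; every element of the matrix still
    # ends up either zeroed or overwritten, matching tril(out.T).
    n = out.shape[0]
    nb = (n + TILE_DIM - 1) // TILE_DIM
    for bx in range(nb):              # blockIdx.x -> row block
        for by in range(bx, nb):      # blocks below the diagonal exit early
            r0, c0 = bx * TILE_DIM, by * TILE_DIM
            r1, c1 = min(r0 + TILE_DIM, n), min(c0 + TILE_DIM, n)
            rr, cc = np.meshgrid(np.arange(r0, r1), np.arange(c0, c1), indexing="ij")
            keep = rr <= cc
            tile = np.where(keep, out[r0:r1, c0:c1], 0)  # phase 1: load upper entries
            out[r0:r1, c0:c1][keep] = 0                  # ... and clear them
            dr, dc = np.meshgrid(np.arange(c0, c1), np.arange(r0, r1), indexing="ij")
            lower = dr >= dc
            out[c0:c1, r0:r1][lower] = tile.T[lower]     # phase 2: store transposed

rng = np.random.default_rng(1)
m = rng.standard_normal((70, 70))
expected = np.tril(m.T)
transpose_2d_lower_model(m)
assert np.allclose(m, expected)
```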
diff --git a/src/cunumeric/matrix/batched_cholesky.h b/src/cunumeric/matrix/batched_cholesky.h
new file mode 100644
index 000000000..fceba2a9f
--- /dev/null
+++ b/src/cunumeric/matrix/batched_cholesky.h
@@ -0,0 +1,38 @@
+/* Copyright 2021-2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include "cunumeric/cunumeric.h"
+#include "cunumeric/cunumeric_c.h"
+
+namespace cunumeric {
+
+class BatchedCholeskyTask : public CuNumericTask<BatchedCholeskyTask> {
+ public:
+  static const int TASK_ID = CUNUMERIC_BATCHED_CHOLESKY;
+
+ public:
+  static void cpu_variant(legate::TaskContext& context);
+#ifdef LEGATE_USE_OPENMP
+  static void omp_variant(legate::TaskContext& context);
+#endif
+#ifdef LEGATE_USE_CUDA
+  static void gpu_variant(legate::TaskContext& context);
+#endif
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/batched_cholesky_omp.cc b/src/cunumeric/matrix/batched_cholesky_omp.cc
new file mode 100644
index 000000000..84b311ff2
--- /dev/null
+++ b/src/cunumeric/matrix/batched_cholesky_omp.cc
@@ -0,0 +1,83 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/cunumeric.h"
+#include "cunumeric/matrix/batched_cholesky.h"
+#include "cunumeric/matrix/batched_cholesky_template.inl"
+
+#include <algorithm>
+#include <cblas.h>
+#include <cstring>
+#include <omp.h>
+
+namespace cunumeric {
+
+using namespace legate;
+
+template <>
+void CopyBlockImpl<VariantKind::OMP>::operator()(void* dst, const void* src, size_t n)
+{
+  ::memcpy(dst, src, n);
+}
+
+template <Type::Code CODE>
+struct BatchedTransposeImplBody<VariantKind::OMP, CODE> {
+  using VAL = legate_type_of<CODE>;
+
+  static constexpr int tile_size = 64;
+
+  void operator()(VAL* out, int n) const
+  {
+    int nblocks = (n + tile_size - 1) / tile_size;
+
+#pragma omp parallel for
+    for (int rb = 0; rb < nblocks; ++rb) {
+      // Only loop over the upper-triangle blocks:
+      // transpose the elements that are there and
+      // zero out the elements after reading them
+      for (int cb = rb; cb < nblocks; ++cb) {
+        VAL tile[tile_size][tile_size];
+        int r_start = rb * tile_size;
+        int r_stop  = std::min(r_start + tile_size, n);
+        int c_start = cb * tile_size;
+        int c_stop  = std::min(c_start + tile_size, n);
+
+        for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) {
+          for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) {
+            if (r <= c) {
+              auto offset  = r * n + c;
+              tile[tr][tc] = out[offset];
+              out[offset]  = 0;
+            } else {
+              tile[tr][tc] = 0;
+            }
+          }
+        }
+
+        for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) {
+          for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; }
+        }
+      }
+    }
+  }
+};
+
+/*static*/ void BatchedCholeskyTask::omp_variant(TaskContext& context)
+{
+  openblas_set_num_threads(omp_get_max_threads());
+  batched_cholesky_task_context_dispatch<VariantKind::OMP>(context);
+}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/batched_cholesky_template.inl b/src/cunumeric/matrix/batched_cholesky_template.inl
new file mode 100644
index 000000000..8d266e3f0
--- /dev/null
+++ b/src/cunumeric/matrix/batched_cholesky_template.inl
@@ -0,0 +1,147 @@
+/* Copyright 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+// Useful for IDEs
+#include <cassert>
+#include "cunumeric/cunumeric.h"
+#include "cunumeric/matrix/batched_cholesky.h"
+#include "cunumeric/matrix/potrf_template.inl"
+#include "cunumeric/matrix/transpose_template.inl"
+#include "cunumeric/pitches.h"
+
+namespace cunumeric {
+
+using namespace legate;
+
+template <VariantKind KIND, Type::Code CODE>
+struct BatchedCholeskyImplBody {
+  template <typename T>
+  void operator()(T* array, int32_t m, int32_t n)
+  {
+    PotrfImplBody<KIND, CODE>()(array, m, n);
+  }
+};
+
+template <VariantKind KIND>
+struct CopyBlockImpl {
+  void operator()(void* dst, const void* src, size_t n);
+};
+
+template <VariantKind KIND, Type::Code CODE>
+struct BatchedTransposeImplBody {
+  using VAL = legate_type_of<CODE>;
+
+  void operator()(VAL* array, int32_t n);
+};
+
+template <Type::Code CODE>
+struct _cholesky_supported {
+  static constexpr bool value = CODE == Type::Code::FLOAT64 || CODE == Type::Code::FLOAT32 ||
+                                CODE == Type::Code::COMPLEX64 || CODE == Type::Code::COMPLEX128;
+};
+
+template <VariantKind KIND>
+struct BatchedCholeskyImpl {
+  template <Type::Code CODE, int DIM>
+  void operator()(Array& input_array, Array& output_array) const
+  {
+    using VAL = legate_type_of<CODE>;
+
+    auto shape = input_array.shape<DIM>();
+    if (shape != output_array.shape<DIM>()) {
+      throw legate::TaskException(
+        "Batched cholesky is not supported when input/output shapes differ");
+    }
+
+    Pitches<DIM - 1> pitches;
+    size_t volume = pitches.flatten(shape);
+
+    if (volume == 0) return;
+
+    auto ncols = shape.hi[DIM - 1] - shape.lo[DIM - 1] + 1;
+
+    size_t in_strides[DIM];
+    size_t out_strides[DIM];
+
+    auto input = input_array.read_accessor<VAL, DIM>(shape).ptr(shape, in_strides);
+    if (in_strides[DIM - 2] != ncols || in_strides[DIM - 1] != 1) {
+      throw legate::TaskException(
+        "Bad input accessor in batched cholesky, last two dimensions must be non-transformed and "
+        "dense with stride == 1");
+    }
+
+    auto output = output_array.write_accessor<VAL, DIM>(shape).ptr(shape, out_strides);
+    if (out_strides[DIM - 2] != ncols || out_strides[DIM - 1] != 1) {
+      throw legate::TaskException(
+        "Bad output accessor in batched cholesky, last two dimensions must be non-transformed and "
+        "dense with stride == 1");
+    }
+
+    if (shape.empty()) return;
+
+    int num_blocks = 1;
+    for (int i = 0; i < (DIM - 2); ++i) { num_blocks *= (shape.hi[i] - shape.lo[i] + 1); }
+
+    auto m = static_cast<int32_t>(shape.hi[DIM - 2] - shape.lo[DIM - 2] + 1);
+    auto n = static_cast<int32_t>(shape.hi[DIM - 1] - shape.lo[DIM - 1] + 1);
+    assert(m > 0 && n > 0);
+
+    auto block_stride = m * n;
+
+    for (int i = 0; i < num_blocks; ++i) {
+      if constexpr (_cholesky_supported<CODE>::value) {
+        CopyBlockImpl<KIND>()(output, input, sizeof(VAL) * block_stride);
+        PotrfImplBody<KIND, CODE>()(output, m, n);
+        // Implicit assumption about the layout of the Cholesky result:
+        // the output store has C layout, but potrf generates each
+        // subblock in Fortran layout. Transpose the Fortran subblock
+        // back into C layout.
+        // CHANGE: If this code is changed, please make sure all changes
+        // are consistent with those found in mapper.cc.
+        BatchedTransposeImplBody<KIND, CODE>()(output, n);
+        input += block_stride;
+        output += block_stride;
+      }
+    }
+  }
+};
+
+template <VariantKind KIND>
+static void batched_cholesky_task_context_dispatch(TaskContext& context)
+{
+  auto& batched_input  = context.inputs()[0];
+  auto& batched_output = context.outputs()[0];
+  if (batched_input.code() != batched_output.code()) {
+    throw legate::TaskException(
+      "batched cholesky is not yet supported when input/output types differ");
+  }
+  if (batched_input.dim() != batched_output.dim()) {
+    throw legate::TaskException("input/output have different dims in batched cholesky");
+  }
+  if (batched_input.dim() <= 2) {
+    throw legate::TaskException(
+      "internal error: batched cholesky input does not have more than 2 dims");
+  }
+  double_dispatch(batched_input.dim(),
+                  batched_input.code(),
+                  BatchedCholeskyImpl<KIND>{},
+                  batched_input,
+                  batched_output);
+}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/potrf.cc b/src/cunumeric/matrix/potrf.cc
index 02ae06246..46ed58b6a 100644
--- a/src/cunumeric/matrix/potrf.cc
+++ b/src/cunumeric/matrix/potrf.cc
@@ -25,48 +25,48 @@ namespace cunumeric {
 
 using namespace legate;
 
 template <>
-struct PotrfImplBody<VariantKind::CPU, Type::Code::FLOAT32> {
-  void operator()(float* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_spotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::CPU, Type::Code::FLOAT32>::operator()(float* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_spotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::CPU, Type::Code::FLOAT64> {
-  void operator()(double* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_dpotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::CPU, Type::Code::FLOAT64>::operator()(double* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_dpotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::CPU, Type::Code::COMPLEX64> {
-  void operator()(complex<float>* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::CPU, Type::Code::COMPLEX64>::operator()(complex<float>* array,
+                                                                        int32_t m,
+                                                                        int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::CPU, Type::Code::COMPLEX128> {
-  void operator()(complex<double>* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::CPU, Type::Code::COMPLEX128>::operator()(complex<double>* array,
+                                                                         int32_t m,
+                                                                         int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 /*static*/ void PotrfTask::cpu_variant(TaskContext& context)
 {
diff --git a/src/cunumeric/matrix/potrf.cu b/src/cunumeric/matrix/potrf.cu
index 68616525f..8f13a5168 100644
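The dispatch loop above flattens all leading dimensions into `num_blocks` and advances flat pointers by `block_stride = m * n` per factorization. A NumPy model of that walk, with `np.linalg.cholesky` standing in for the copy/potrf/transpose pipeline (function name is illustrative):

```python
import numpy as np

def batched_cholesky_model(a):
    # Model of the BatchedCholeskyImpl loop: num_blocks is the product of
    # the leading extents; each block of block_stride elements in the flat
    # C-ordered buffer is one (m, n) submatrix, factored independently.
    *batch, m, n = a.shape
    num_blocks = int(np.prod(batch)) if batch else 1
    block_stride = m * n
    flat_in = np.ascontiguousarray(a).reshape(num_blocks * block_stride)
    flat_out = np.empty_like(flat_in)
    for i in range(num_blocks):
        block = flat_in[i * block_stride:(i + 1) * block_stride].reshape(m, n)
        flat_out[i * block_stride:(i + 1) * block_stride] = np.linalg.cholesky(block).ravel()
    return flat_out.reshape(a.shape)

spd = np.eye(4) * 4.0 + np.ones((4, 4))
batch = np.einsum("ij,kl->ijkl", np.outer([1.0, 2.0], [1.0, 2.0]), spd)  # 4-D input
assert np.allclose(batched_cholesky_model(batch), np.linalg.cholesky(batch))
```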
--- a/src/cunumeric/matrix/potrf.cu
+++ b/src/cunumeric/matrix/potrf.cu
@@ -49,41 +49,38 @@ static inline void potrf_template(
 }
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32> {
-  void operator()(float* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnSpotrf_bufferSize, cusolverDnSpotrf, array, m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32>::operator()(float* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  potrf_template(cusolverDnSpotrf_bufferSize, cusolverDnSpotrf, array, m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64> {
-  void operator()(double* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnDpotrf_bufferSize, cusolverDnDpotrf, array, m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64>::operator()(double* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  potrf_template(cusolverDnDpotrf_bufferSize, cusolverDnDpotrf, array, m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64> {
-  void operator()(complex<float>* array, int32_t m, int32_t n)
-  {
-    potrf_template(
-      cusolverDnCpotrf_bufferSize, cusolverDnCpotrf, reinterpret_cast<cuComplex*>(array), m, n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64>::operator()(complex<float>* array,
+                                                                        int32_t m,
+                                                                        int32_t n)
+{
+  potrf_template(
+    cusolverDnCpotrf_bufferSize, cusolverDnCpotrf, reinterpret_cast<cuComplex*>(array), m, n);
+}
 
 template <>
-struct PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128> {
-  void operator()(complex<double>* array, int32_t m, int32_t n)
-  {
-    potrf_template(cusolverDnZpotrf_bufferSize,
-                   cusolverDnZpotrf,
-                   reinterpret_cast<cuDoubleComplex*>(array),
-                   m,
-                   n);
-  }
-};
+void PotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128>::operator()(complex<double>* array,
+                                                                         int32_t m,
+                                                                         int32_t n)
+{
+  potrf_template(
+    cusolverDnZpotrf_bufferSize, cusolverDnZpotrf, reinterpret_cast<cuDoubleComplex*>(array), m, n);
+}
 
 /*static*/ void PotrfTask::gpu_variant(TaskContext& context)
 {
diff --git a/src/cunumeric/matrix/potrf_omp.cc b/src/cunumeric/matrix/potrf_omp.cc
index d26143a6f..36b32968d 100644
--- a/src/cunumeric/matrix/potrf_omp.cc
+++ b/src/cunumeric/matrix/potrf_omp.cc
@@ -26,48 +26,48 @@ namespace cunumeric {
 
 using namespace legate;
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT32> {
-  void operator()(float* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_spotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT32>::operator()(float* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_spotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT64> {
-  void operator()(double* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_dpotrf(&uplo, &n, array, &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::FLOAT64>::operator()(double* array,
+                                                                      int32_t m,
+                                                                      int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_dpotrf(&uplo, &n, array, &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX64> {
-  void operator()(complex<float>* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX64>::operator()(complex<float>* array,
+                                                                        int32_t m,
+                                                                        int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_cpotrf(&uplo, &n, reinterpret_cast<__complex__ float*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 template <>
-struct PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX128> {
-  void operator()(complex<double>* array, int32_t m, int32_t n)
-  {
-    char uplo    = 'L';
-    int32_t info = 0;
-    LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
-    if (info != 0) throw legate::TaskException("Matrix is not positive definite");
-  }
-};
+void PotrfImplBody<VariantKind::OMP, Type::Code::COMPLEX128>::operator()(complex<double>* array,
+                                                                         int32_t m,
+                                                                         int32_t n)
+{
+  char uplo    = 'L';
+  int32_t info = 0;
+  LAPACK_zpotrf(&uplo, &n, reinterpret_cast<__complex__ double*>(array), &m, &info);
+  if (info != 0) throw legate::TaskException("Matrix is not positive definite");
+}
 
 /*static*/ void PotrfTask::omp_variant(TaskContext& context)
 {
diff --git a/src/cunumeric/matrix/potrf_template.inl b/src/cunumeric/matrix/potrf_template.inl
index 55c782ad0..7e4252189 100644
--- a/src/cunumeric/matrix/potrf_template.inl
+++ b/src/cunumeric/matrix/potrf_template.inl
@@ -26,6 +26,26 @@ using namespace legate;
 template <VariantKind KIND, Type::Code CODE>
 struct PotrfImplBody;
 
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::FLOAT32> {
+  void operator()(float* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::FLOAT64> {
+  void operator()(double* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::COMPLEX64> {
+  void operator()(complex<float>* array, int32_t m, int32_t n);
+};
+
+template <VariantKind KIND>
+struct PotrfImplBody<KIND, Type::Code::COMPLEX128> {
+  void operator()(complex<double>* array, int32_t m, int32_t n);
+};
+
 template <Type::Code CODE>
 struct support_potrf : std::false_type {};
 template <>
diff --git a/tests/integration/test_cholesky.py b/tests/integration/test_cholesky.py
index 91edbaa7e..5b2659a16 100644
--- a/tests/integration/test_cholesky.py
+++ b/tests/integration/test_cholesky.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2022 NVIDIA Corporation
+# Copyright 2023 NVIDIA Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
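The potrf refactor above only moves the specializations from local struct definitions to out-of-line member definitions against the declarations added in `potrf_template.inl`; the LAPACK calls are unchanged. For reference, the `uplo`/`info` contract those calls rely on, shown through SciPy's LAPACK bindings (SciPy is used here purely as an illustration):

```python
import numpy as np
from scipy.linalg.lapack import dpotrf  # same LAPACK routine as LAPACK_dpotrf above

a = np.array([[4.0, 2.0], [2.0, 3.0]])
c, info = dpotrf(a, lower=1)             # uplo = 'L', as in the C++ variants
assert info == 0
assert np.allclose(np.tril(c), np.linalg.cholesky(a))

not_spd = np.array([[1.0, 2.0], [2.0, 1.0]])  # indefinite matrix
_, info = dpotrf(not_spd, lower=1)
assert info > 0                          # maps to "Matrix is not positive definite"
```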
@@ -56,10 +56,14 @@ def test_diagonal():
     assert allclose(b**2.0, a)
 
 
+def _get_real_symm_posdef(n):
+    a = num.random.rand(n, n)
+    return a + a.T + num.eye(n) * n
+
+
 @pytest.mark.parametrize("n", SIZES)
 def test_real(n):
-    a = num.random.rand(n, n)
-    b = a + a.T + num.eye(n) * n
+    b = _get_real_symm_posdef(n)
     c = num.linalg.cholesky(b)
     c_np = np.linalg.cholesky(b.__array__())
     assert allclose(c, c_np)
@@ -80,6 +84,45 @@ def test_complex(n):
     assert allclose(c, c_np)
 
 
+@pytest.mark.parametrize("n", SIZES)
+def test_batched_3d(n):
+    batch = 4
+    a = _get_real_symm_posdef(n)
+    np_a = a.__array__()
+    a_batched = num.einsum("i,jk->ijk", np.arange(batch) + 1, a)
+    test_c = num.linalg.cholesky(a_batched)
+    for i in range(batch):
+        correct = np.linalg.cholesky(np_a * (i + 1))
+        test = test_c[i, :]
+        assert allclose(correct, test)
+
+
+def test_batched_empty():
+    batch = 4
+    a = _get_real_symm_posdef(8)
+    a_batched = num.einsum("i,jk->ijk", np.arange(batch) + 1, a)
+    a_sliced = a_batched[0:0, :, :]
+    empty = num.linalg.cholesky(a_sliced)
+    assert empty.shape == a_sliced.shape
+
+
+@pytest.mark.parametrize("n", SIZES)
+def test_batched_4d(n):
+    batch = 2
+    a = _get_real_symm_posdef(n)
+    np_a = a.__array__()
+
+    outer = np.einsum("i,j->ij", np.arange(batch) + 1, np.arange(batch) + 1)
+
+    a_batched = num.einsum("ij,kl->ijkl", outer, a)
+    test_c = num.linalg.cholesky(a_batched)
+    for i in range(batch):
+        for j in range(batch):
+            correct = np.linalg.cholesky(np_a * (i + 1) * (j + 1))
+            test = test_c[i, j, :]
+            assert allclose(correct, test)
+
+
 if __name__ == "__main__":
     import sys
diff --git a/tests/unit/cunumeric/test_config.py b/tests/unit/cunumeric/test_config.py
index 5e85ccfde..6f8f43df5 100644
--- a/tests/unit/cunumeric/test_config.py
+++ b/tests/unit/cunumeric/test_config.py
@@ -117,6 +117,7 @@ def test_CuNumericOpCode() -> None:
         "ADVANCED_INDEXING",
         "ARANGE",
         "ARGWHERE",
+        "BATCHED_CHOLESKY",
         "BINARY_OP",
         "BINARY_RED",
         "BINCOUNT",
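End to end, the feature added by this diff can be exercised the same way the new tests do. A short usage sketch (assuming `cunumeric` is importable as `num`, with the same SPD recipe as `_get_real_symm_posdef`):

```python
import numpy as np
import cunumeric as num

base = num.random.rand(8, 8)
spd = base + base.T + num.eye(8) * 8                       # symmetric positive definite
stacked = num.einsum("i,jk->ijk", np.arange(4) + 1.0, spd)  # batch of 4 SPD matrices
low = num.linalg.cholesky(stacked)                          # one factorization per leading index
for i in range(4):
    li = low[i].__array__()
    assert np.allclose(li @ li.T, stacked[i].__array__())   # L @ L.T reconstructs each block
```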