Add batched cholesky implementation and tests #1029

Merged (9 commits) on Nov 8, 2023
Changes from all commits
2 changes: 2 additions & 0 deletions cunumeric/config.py
@@ -32,6 +32,7 @@ class _CunumericSharedLib:
    CUNUMERIC_ADVANCED_INDEXING: int
    CUNUMERIC_ARANGE: int
    CUNUMERIC_ARGWHERE: int
    CUNUMERIC_BATCHED_CHOLESKY: int
    CUNUMERIC_BINARY_OP: int
    CUNUMERIC_BINARY_RED: int
    CUNUMERIC_BINCOUNT: int
@@ -333,6 +334,7 @@ class CuNumericOpCode(IntEnum):
    ADVANCED_INDEXING = _cunumeric.CUNUMERIC_ADVANCED_INDEXING
    ARANGE = _cunumeric.CUNUMERIC_ARANGE
    ARGWHERE = _cunumeric.CUNUMERIC_ARGWHERE
    BATCHED_CHOLESKY = _cunumeric.CUNUMERIC_BATCHED_CHOLESKY
    BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP
    BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED
    BINCOUNT = _cunumeric.CUNUMERIC_BINCOUNT
40 changes: 38 additions & 2 deletions cunumeric/linalg/cholesky.py
@@ -1,4 +1,4 @@
# Copyright 2021-2022 NVIDIA Corporation
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -202,11 +202,47 @@ def tril(context: Context, p_output: StorePartition, n: int) -> None:
    task.execute()


def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None:
    # The only feasible implementation right now requires each Cholesky
    # submatrix to fit on a single processor. Available memory varies
    # wildly from system to system, so just use a fixed cutoff to provide
    # a sensible warning.
    # TODO: find a better way to inform the user the dims are too big
    context: Context = output.context
    task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY)
    task.add_input(input.base)
    task.add_output(output.base)
    ndim = input.base.ndim
    task.add_broadcast(input.base, (ndim - 2, ndim - 1))
    task.add_broadcast(output.base, (ndim - 2, ndim - 1))
    task.add_alignment(input.base, output.base)
    task.throws_exception(LinAlgError)
    task.execute()


def cholesky(
    output: DeferredArray, input: DeferredArray, no_tril: bool
) -> None:
    runtime = output.runtime
    context = output.context
    context: Context = output.context
    if len(input.base.shape) > 2:
        if no_tril:
            raise NotImplementedError(
                "batched cholesky expects to only "
                "produce the lower triangular matrix"
            )
        size = input.base.shape[-1]
        # Choose 32768 as the dimension cutoff for the warning so that,
        # for float64, any submatrix larger than 8 GiB
        # (32768 * 32768 * 8 bytes) triggers it
        if size > 32768:
            runtime.warn(
                "batched cholesky is only valid"
                " when the square submatrices fit"
                f" on a single proc, n > {size} may be too large",
                category=UserWarning,
            )
        return _batched_cholesky(output, input)

    if runtime.num_procs == 1:
        transpose_copy_single(context, input.base, output.base)
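A minimal usage sketch of the new batched path (illustrative only, not part of the diff). It assumes the usual "import cunumeric as num" alias, a stack of symmetric positive-definite inputs built with NumPy, and the standard NumPy-compatible conversion back via np.asarray; only the call to num.linalg.cholesky on an ndim > 2 array reflects the change above.

import numpy as np
import cunumeric as num

# Build a small batch of symmetric positive-definite 4x4 matrices.
rng = np.random.default_rng(0)
a = rng.standard_normal((8, 4, 4))
spd = a @ a.transpose(0, 2, 1) + 4 * np.eye(4)  # SPD by construction

# With this PR, stacked (ndim > 2) inputs dispatch to _batched_cholesky:
# the last two dims are broadcast so each n x n submatrix is factored
# whole on a single processor, and only the lower triangle is produced.
l = num.linalg.cholesky(num.array(spd))

# Spot-check one factor against NumPy.
assert np.allclose(np.asarray(l[0]), np.linalg.cholesky(spd[0]))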
4 changes: 0 additions & 4 deletions cunumeric/linalg/linalg.py
@@ -82,10 +82,6 @@ def cholesky(a: ndarray) -> ndarray:
    elif shape[-1] != shape[-2]:
        raise ValueError("Last 2 dimensions of the array must be square")

    if len(shape) > 2:
        raise NotImplementedError(
            "cuNumeric needs to support stacked 2d arrays"
        )
    return _cholesky(a)


3 changes: 3 additions & 0 deletions cunumeric_cpp.cmake
@@ -143,6 +143,7 @@ list(APPEND cunumeric_SOURCES
src/cunumeric/index/putmask.cc
src/cunumeric/item/read.cc
src/cunumeric/item/write.cc
src/cunumeric/matrix/batched_cholesky.cc
src/cunumeric/matrix/contract.cc
src/cunumeric/matrix/diag.cc
src/cunumeric/matrix/gemm.cc
@@ -195,6 +196,7 @@ if(Legion_USE_OpenMP)
src/cunumeric/index/repeat_omp.cc
src/cunumeric/index/wrap_omp.cc
src/cunumeric/index/zip_omp.cc
src/cunumeric/matrix/batched_cholesky_omp.cc
src/cunumeric/matrix/contract_omp.cc
src/cunumeric/matrix/diag_omp.cc
src/cunumeric/matrix/gemm_omp.cc
@@ -245,6 +247,7 @@ if(Legion_USE_CUDA)
src/cunumeric/index/putmask.cu
src/cunumeric/item/read.cu
src/cunumeric/item/write.cu
src/cunumeric/matrix/batched_cholesky.cu
src/cunumeric/matrix/contract.cu
src/cunumeric/matrix/diag.cu
src/cunumeric/matrix/gemm.cu
1 change: 1 addition & 0 deletions src/cunumeric/cunumeric_c.h
@@ -29,6 +29,7 @@ enum CuNumericOpCode {
CUNUMERIC_ADVANCED_INDEXING,
CUNUMERIC_ARANGE,
CUNUMERIC_ARGWHERE,
CUNUMERIC_BATCHED_CHOLESKY,
CUNUMERIC_BINARY_OP,
CUNUMERIC_BINARY_RED,
CUNUMERIC_BINCOUNT,
19 changes: 19 additions & 0 deletions src/cunumeric/mapper.cc
@@ -145,6 +145,25 @@ std::vector<StoreMapping> CuNumericMapper::store_mappings(
      }
      return std::move(mappings);
    }
    // CHANGE: If this code is changed, make sure all layouts are
    // consistent with those assumed in batched_cholesky.cu, etc
    case CUNUMERIC_BATCHED_CHOLESKY: {
      std::vector<StoreMapping> mappings;
      auto& inputs = task.inputs();
      auto& outputs = task.outputs();
      mappings.reserve(inputs.size() + outputs.size());
      for (auto& input : inputs) {
        mappings.push_back(StoreMapping::default_mapping(input, options.front()));
        mappings.back().policy.exact = true;
        mappings.back().policy.ordering.set_c_order();
      }
      for (auto& output : outputs) {
        mappings.push_back(StoreMapping::default_mapping(output, options.front()));
        mappings.back().policy.exact = true;
        mappings.back().policy.ordering.set_c_order();
      }
      return std::move(mappings);
    }
    case CUNUMERIC_TRILU: {
      if (task.scalars().size() == 2) return {};
      // If we're here, this task was the post-processing for Cholesky.
Expand Down
85 changes: 85 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.cc
@@ -0,0 +1,85 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/batched_cholesky.h"
#include "cunumeric/cunumeric.h"
#include "cunumeric/matrix/batched_cholesky_template.inl"

#include <cblas.h>
#include <core/type/type_info.h>
#include <lapack.h>

namespace cunumeric {

using namespace legate;

template <>
void CopyBlockImpl<VariantKind::CPU>::operator()(void* dst, const void* src, size_t size)
{
::memcpy(dst, src, size);
}

template <Type::Code CODE>
struct BatchedTransposeImplBody<VariantKind::CPU, CODE> {
  using VAL = legate_type_of<CODE>;

  static constexpr int tile_size = 64;

  // In-place, tile by tile: keep the upper triangle (including the
  // diagonal), write its transpose into the lower triangle, and zero
  // everything above the diagonal.
  void operator()(VAL* out, int n) const
  {
    VAL tile[tile_size][tile_size];
    int nblocks = (n + tile_size - 1) / tile_size;

    for (int rb = 0; rb < nblocks; ++rb) {
      for (int cb = 0; cb < nblocks; ++cb) {
        int r_start = rb * tile_size;
        int r_stop = std::min(r_start + tile_size, n);
        int c_start = cb * tile_size;
        int c_stop = std::min(c_start + tile_size, n);
        // Stage the (rb, cb) block, keeping only entries on or above the diagonal
        for (int r = r_start, tr = 0; r < r_stop; ++r, ++tr) {
          for (int c = c_start, tc = 0; c < c_stop; ++c, ++tc) {
            if (r <= c) {
              tile[tr][tc] = out[r * n + c];
            } else {
              tile[tr][tc] = 0;
            }
          }
        }
        // Write the transposed tile into the mirrored (cb, rb) block
        for (int r = c_start, tr = 0; r < c_stop; ++r, ++tr) {
          for (int c = r_start, tc = 0; c < r_stop; ++c, ++tc) { out[r * n + c] = tile[tc][tr]; }
        }
      }
    }
  }
};

/*static*/ void BatchedCholeskyTask::cpu_variant(TaskContext& context)
{
#ifdef LEGATE_USE_OPENMP
openblas_set_num_threads(1); // make sure this isn't overzealous
[Review thread on openblas_set_num_threads(1)]

Contributor: Do we need this even if we don't call any OpenMP pragmas inside the CPU variant?

Contributor (author): Umm, not sure. I copied this straight from potrf.cc, which was @magnatelee, who usually has good reasons for including things :) I would say leave it for now until we get clarification? I don't think it's hurting anything, but I also don't understand why it would be necessary.

Contributor: I will keep this comment open so we don't lose it and can ask @magnatelee when he is back.

@magnatelee (Contributor, Aug 31, 2023): If a program doesn't use both a CPU task and an OpenMP task that call OpenBLAS, this isn't strictly necessary. But we would never know, so to be absolutely safe, any persistent state, like the number of OpenMP threads for OpenBLAS, should be set by each task to make sure it matches the task's assumptions.
#endif
batched_cholesky_task_context_dispatch<VariantKind::CPU>(context);
}

namespace // unnamed
{
static void __attribute__((constructor)) register_tasks(void)
{
BatchedCholeskyTask::register_variants();
}
} // namespace

} // namespace cunumeric
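For reference, the tiled CPU transpose above is equivalent to keeping the upper triangle of the row-major buffer and transposing it into the lower triangle (presumably because the factorization routine sees the row-major data as column-major, so the computed lower factor lands in the upper triangle). A small NumPy sketch of that invariant, with made-up helper names, not part of the PR:

import numpy as np

def reference_lower_from_upper(a):
    # Whole-array equivalent of BatchedTransposeImplBody<CPU>: keep the
    # upper triangle, transpose it into the lower triangle, zero the rest.
    return np.triu(a).T

def tiled_lower_from_upper(a, tile=64):
    # A direct transcription of the tiled C++ loop above, for comparison.
    out = a.copy()
    n = out.shape[0]
    nblocks = -(-n // tile)  # ceiling division
    for rb in range(nblocks):
        for cb in range(nblocks):
            r0, r1 = rb * tile, min(rb * tile + tile, n)
            c0, c1 = cb * tile, min(cb * tile + tile, n)
            rows = np.arange(r0, r1)[:, None]
            cols = np.arange(c0, c1)[None, :]
            # Keep only entries on or above the diagonal of this block...
            kept = np.where(rows <= cols, out[r0:r1, c0:c1], 0)
            # ...and write its transpose into the mirrored block.
            out[c0:c1, r0:r1] = kept.T
    return out

x = np.arange(100 * 100, dtype=np.float64).reshape(100, 100)
assert np.array_equal(tiled_lower_from_upper(x), reference_lower_from_upper(x))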
111 changes: 111 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.cu
@@ -0,0 +1,111 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/batched_cholesky.h"
#include "cunumeric/matrix/potrf.h"
#include "cunumeric/matrix/batched_cholesky_template.inl"

#include "cunumeric/cuda_help.h"

namespace cunumeric {

using namespace legate;

#define TILE_DIM 32
#define BLOCK_ROWS 8

template <>
void CopyBlockImpl<VariantKind::GPU>::operator()(void* dst, const void* src, size_t size)
{
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, get_cached_stream());
}

template <typename VAL>
__global__ static void __launch_bounds__((TILE_DIM * BLOCK_ROWS), MIN_CTAS_PER_SM)
  transpose_2d_lower(VAL* out, int n)
{
  __shared__ VAL tile[TILE_DIM][TILE_DIM + 1 /*avoid bank conflicts*/];

  // The y dim is the fast-moving index for coalescing
  auto r_block = blockIdx.x * TILE_DIM;
  auto c_block = blockIdx.y * TILE_DIM;
  auto r = blockIdx.x * TILE_DIM + threadIdx.x;
  auto c = blockIdx.y * TILE_DIM + threadIdx.y;
  auto stride = BLOCK_ROWS;
  // The tile coordinates
  auto tr = threadIdx.x;
  auto tc = threadIdx.y;
  auto offset = r * n + c;

  // Only execute on blocks on or above the diagonal:
  // a single thread block stores the upper-diagonal block into a
  // temporary shared-memory tile, then sets that block to zeros
  if (c_block >= r_block) {
#pragma unroll
    for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
      if (r < n && (c + i) < n) {
        if (r <= (c + i)) {
          tile[tr][tc + i] = out[offset];
          // Clear the upper-diagonal entry
          out[offset] = 0;
        } else {
          tile[tr][tc + i] = 0;
        }
      }
    }

    // Make sure all the data is in shared memory
    __syncthreads();

    // Transpose the global coordinates, keeping y the fast-moving index
    r = blockIdx.y * TILE_DIM + threadIdx.x;
    c = blockIdx.x * TILE_DIM + threadIdx.y;
    offset = r * n + c;

#pragma unroll
    for (int i = 0; i < TILE_DIM; i += BLOCK_ROWS, offset += stride) {
      if (r < n && (c + i) < n) {
        if (r >= (c + i)) { out[offset] = tile[tc + i][tr]; }
      }
    }
  }
}

template <Type::Code CODE>
struct BatchedTransposeImplBody<VariantKind::GPU, CODE> {
  using VAL = legate_type_of<CODE>;

  void operator()(VAL* out, int n) const
  {
    const dim3 blocks((n + TILE_DIM - 1) / TILE_DIM, (n + TILE_DIM - 1) / TILE_DIM, 1);
    const dim3 threads(TILE_DIM, BLOCK_ROWS, 1);

    auto stream = get_cached_stream();

    // The CUDA potrf produces the full matrix; we only want
    // the lower triangle
    transpose_2d_lower<VAL><<<blocks, threads, 0, stream>>>(out, n);

    CHECK_CUDA_STREAM(stream);
  }
};

/*static*/ void BatchedCholeskyTask::gpu_variant(TaskContext& context)
{
batched_cholesky_task_context_dispatch<VariantKind::GPU>(context);
}

} // namespace cunumeric
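A quick sanity check on the kernel's striding (again just an illustration, not part of the diff): with TILE_DIM = 32 and BLOCK_ROWS = 8, each 32 x 8 thread block loops i = 0, 8, 16, 24, so thread (threadIdx.x, threadIdx.y) touches tile entries (threadIdx.x, threadIdx.y + i), four per thread, and the block as a whole covers every element of a 32 x 32 tile exactly once.

TILE_DIM, BLOCK_ROWS = 32, 8  # mirror the kernel's #defines
visits = [(tx, ty + i)
          for tx in range(TILE_DIM)
          for ty in range(BLOCK_ROWS)
          for i in range(0, TILE_DIM, BLOCK_ROWS)]
# 32 * 8 threads * 4 iterations = 1024 distinct (row, col) pairs = one full tile
assert len(visits) == len(set(visits)) == TILE_DIM * TILE_DIM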
38 changes: 38 additions & 0 deletions src/cunumeric/matrix/batched_cholesky.h
@@ -0,0 +1,38 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/cunumeric.h"
#include "cunumeric/cunumeric_c.h"

namespace cunumeric {

class BatchedCholeskyTask : public CuNumericTask<BatchedCholeskyTask> {
 public:
  static const int TASK_ID = CUNUMERIC_BATCHED_CHOLESKY;

 public:
  static void cpu_variant(legate::TaskContext& context);
#ifdef LEGATE_USE_OPENMP
  static void omp_variant(legate::TaskContext& context);
#endif
#ifdef LEGATE_USE_CUDA
  static void gpu_variant(legate::TaskContext& context);
#endif
};

} // namespace cunumeric