Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jjwilke committed Aug 18, 2023
1 parent 4965489 commit 1967204
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 18 deletions.
24 changes: 14 additions & 10 deletions cunumeric/linalg/cholesky.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2022 NVIDIA Corporation
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -208,13 +208,6 @@ def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None:
# wildly varying memory available depending on the system.
# Just use a fixed cutoff to provide some sensible warning.
# TODO: find a better way to inform the user dims are too big
size = input.base.shape[-1]
if size > 32768:
raise NotImplementedError(
"batched cholesky is only valid"
" when the square submatrices fit"
f" on a single proc, n > {size} is too large"
)
context = output.context
task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY)
task.add_input(input.base)
Expand All @@ -229,16 +222,27 @@ def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None:
def cholesky(
output: DeferredArray, input: DeferredArray, no_tril: bool
) -> None:
runtime = output.runtime
context = output.context
if len(input.base.shape) > 2:
if no_tril:
raise NotImplementedError(
"batched cholesky expects to only "
"produce the lower triangular matrix"
)
size = input.base.shape[-1]
# Choose 32768 as dimension cutoff for warning
# so that for float64 anything larger than
# 8 GiB produces a warning
if size > 32768:
runtime.warn(
"batched cholesky is only valid"
" when the square submatrices fit"
f" on a single proc, n > {size} may be too large",
category=UserWarning,
)
return _batched_cholesky(output, input)

runtime = output.runtime
context = output.context
if runtime.num_procs == 1:
transpose_copy_single(context, input.base, output.base)
potrf_single(context, output.base)
Expand Down
2 changes: 1 addition & 1 deletion src/cunumeric/matrix/batched_cholesky.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021-2022 NVIDIA Corporation
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion src/cunumeric/matrix/batched_cholesky.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021-2022 NVIDIA Corporation
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion src/cunumeric/matrix/batched_cholesky_omp.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021-2022 NVIDIA Corporation
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
8 changes: 4 additions & 4 deletions src/cunumeric/matrix/batched_cholesky_template.inl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021-2022 NVIDIA Corporation
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,16 +64,16 @@ struct BatchedCholeskyImpl {
auto shape = input_array.shape<DIM>();
if (shape != output_array.shape<DIM>()) {
throw legate::TaskException(
"Batched cholesky is not yet supported when input/output types differ");
"Batched cholesky is not supported when input/output shapes differ");
}

if (shape.empty()) return;

size_t strides[DIM];

auto input = input_array.read_accessor<VAL, DIM>(shape).ptr(shape, strides);
auto output = output_array.write_accessor<VAL, DIM>(shape).ptr(shape, strides);

if (shape.empty()) return;

// TODO: we need some sort of check here on the strides
// This should be a dense thing.

Expand Down
11 changes: 10 additions & 1 deletion tests/integration/test_cholesky.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2022 NVIDIA Corporation
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,6 +82,15 @@ def test_batched_3d(n):
assert allclose(correct, test)


def test_batched_empty():
batch = 4
a = _get_real_symm_posdef(8)
a_batched = num.einsum("i,jk->ijk", np.arange(batch) + 1, a)
a_sliced = a_batched[0:0, :, :]
empty = num.linalg.cholesky(a_sliced)
assert empty.shape == a_sliced.shape


@pytest.mark.parametrize("n", SIZES)
def test_batched_4d(n):
batch = 2
Expand Down

0 comments on commit 1967204

Please sign in to comment.