Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for compiling to SoA ABI for CUDA #1103

Open
wants to merge 14 commits into
base: branch-24.03
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*.gcda
*.gcov
core
!core/
*.fluid
*.pyc
*.swp
Expand Down
1 change: 1 addition & 0 deletions conda/conda-build/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ requirements:
- libnvjitlink
- libcusparse
{% endif %}
- numba >=0.57.1
- opt_einsum >=3.3
- scipy
- typing_extensions
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/scripts/test-cunumeric
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env bash

setup_env() {
mamba create -yn legate -c ~/.artifacts/conda-build/legate_core -c ~/.artifacts/conda-build/cunumeric -c conda-forge -c "nvidia/label/cuda-12.0.0" legate-core cunumeric
mamba create -yn legate -c ~/.artifacts/conda-build/legate_core -c ~/.artifacts/conda-build/cunumeric -c conda-forge -c "nvidia/label/cuda-12.0.0" legate-core cunumeric cuda-nvcc
}

setup_test_env() {
Expand Down
4 changes: 2 additions & 2 deletions cunumeric/_sphinxext/_cunumeric_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from __future__ import annotations

from docutils import nodes
from docutils.statemachine import ViewList
from docutils.statemachine import StringList
from sphinx.util.docutils import SphinxDirective
from sphinx.util.nodes import nested_parse_with_titles


class CunumericDirective(SphinxDirective):
def parse(self, rst_text: str, annotation: str) -> list[nodes.Node]:
result = ViewList()
result = StringList()
for line in rst_text.split("\n"):
result.append(line, annotation)
node = nodes.paragraph()
Expand Down
2 changes: 1 addition & 1 deletion cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,7 @@ def check_list_depth(arr: Any, prefix: NdShape = (0,)) -> int:
"List depths are mismatched. First element was at depth "
f"{first_depth}, but there is an element at"
f" depth {other_depth}, "
f"arrays{convert_to_array_form(prefix+(idx+1,))}"
f"arrays{convert_to_array_form(prefix + (idx + 1,))}"
)

return depths[0] + 1
Expand Down
256 changes: 256 additions & 0 deletions cunumeric/numba_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)

from llvmlite import ir
from llvmlite.ir.builder import IRBuilder
from llvmlite.ir.instructions import Ret
from llvmlite.ir.types import FunctionType
from llvmlite.ir.values import Value
from numba import types
from numba.core import sigutils
from numba.core.base import BaseContext
from numba.core.callconv import BaseCallConv
from numba.core.codegen import CodeLibrary
from numba.core.compiler_lock import global_compiler_lock
from numba.core.funcdesc import FunctionDescriptor
from numba.core.typing.templates import Signature
from numba.cuda.codegen import CUDACodeLibrary
from numba.cuda.compiler import compile_cuda


class SoACallConv(BaseCallConv):
    """
    Calling convention in which return values are written through pointer
    arguments rather than returned directly.

    - Scalar return type: the first parameter is a pointer to the return
      type.
    - Tuple return type of length N: the first N parameters are pointers to
      the tuple's element types, one per element.

    In equivalent C, the prototype of a function with this calling convention
    would take the following form:

        void <func_name>(<Tuple item 1>*, ..., <Tuple item N>*,
                         <Python arguments... >);
    """

    def _make_call_helper(self, builder: Any) -> None:
        # Call helpers exist to support Numba's exception machinery; SoA
        # wrapping never raises through the ABI, so this is unsupported.
        raise NotImplementedError(
            "Python exceptions are unsupported when returning in SoA form"
        )

    def return_value(self, builder: IRBuilder, retval: Value) -> Ret:
        # Plain `ret` — no status word is encoded in the return value.
        return builder.ret(retval)

    def return_user_exc(
        self,
        builder: IRBuilder,
        exc: Any,
        exc_args: Any = None,
        loc: Any = None,
        func_name: Any = None,
    ) -> None:
        # No channel exists for propagating a Python exception via SoA.
        raise NotImplementedError(
            "Python exceptions are unsupported when returning in SoA form"
        )

    def return_status_propagate(self, builder: IRBuilder, status: Any) -> None:
        # The SoA ABI carries no status flag to propagate.
        raise NotImplementedError(
            "Return status is unsupported when returning in SoA form"
        )

    def get_function_type(
        self, restype: types.Type, argtypes: Iterable[types.Type]
    ) -> FunctionType:
        """
        Get the LLVM IR Function type for *restype* and *argtypes*.

        The result is a void function whose leading parameters are the
        output pointers (one per tuple element, or one for a scalar),
        followed by the packed Python-level arguments.
        """
        packer = self._get_arg_packer(argtypes)
        if isinstance(restype, types.BaseTuple):
            out_ptr_types = [self.get_return_type(t) for t in restype.types]
        else:
            out_ptr_types = [self.get_return_type(restype)]
        return ir.FunctionType(
            ir.VoidType(), out_ptr_types + list(packer.argument_types)
        )

    def decorate_function(
        self,
        fn: Callable[[Any], Any],
        args: Iterable[str],
        fe_argtypes: List[types.Type],
        noalias: bool = False,
    ) -> None:
        """
        Set names and attributes of function arguments.
        """
        # Argument decoration is only needed when Numba defines the function
        # itself; the SoA path only wraps existing functions.
        raise NotImplementedError("Function decoration not used for SoA ABI")

    def get_arguments(
        self, func: ir.Function, restype: types.Type
    ) -> Tuple[ir.Argument, ...]:
        """
        Get the Python-level arguments of LLVM *func*.

        Skips over the leading output-pointer parameters implied by
        *restype* and returns the remainder.
        """
        n_returns = (
            len(restype.types) if isinstance(restype, types.BaseTuple) else 1
        )
        return func.args[n_returns:]

    def call_function(
        self,
        builder: ir.IRBuilder,
        callee: ir.Function,
        resty: types.Type,
        argtys: Iterable[types.Type],
        args: Iterable[ir.Value],
        attrs: Optional[Tuple[str, ...]] = None,
    ) -> Tuple[ir.Value, ir.Value]:
        """
        Call the Numba-compiled *callee*.
        """
        # SoA functions are entry points for external callers, not targets
        # for Numba-generated calls.
        raise NotImplementedError("Can't call SoA return function directly")


def soa_wrap_function(
    context: BaseContext,
    lib: CodeLibrary,
    fndesc: FunctionDescriptor,
    nvvm_options: Dict[str, Union[int, str, None]],
    wrapper_name: str,
) -> CUDACodeLibrary:
    """
    Wrap a Numba ABI function such that it returns tuple values into SoA
    arguments.

    Parameters
    ----------
    context
        The target context used to compile the wrapped function.
    lib
        The code library holding the Numba-ABI compiled function.
    fndesc
        Function descriptor of the compiled function (provides its argument
        types, return type, and mangled LLVM name).
    nvvm_options
        NVVM compilation options forwarded to the new library.
    wrapper_name
        Name given to the generated SoA-ABI entry point.

    Returns
    -------
    CUDACodeLibrary
        A finalized library containing the SoA wrapper linked against *lib*.
    """
    # New library that will hold the wrapper; the original library is linked
    # in so the callee's definition is available at finalization time.
    new_library = lib.codegen.create_library(
        f"{lib.name}_function_",
        entry_name=wrapper_name,
        nvvm_options=nvvm_options,
    )
    library = cast(CUDACodeLibrary, new_library)
    library.add_linking_library(lib)

    # Determine the caller (C ABI) and wrapper (Numba ABI) function types
    argtypes = fndesc.argtypes
    restype = fndesc.restype
    soa_call_conv = SoACallConv(context)
    wrapperty = soa_call_conv.get_function_type(restype, argtypes)
    calleety = context.call_conv.get_function_type(restype, argtypes)

    # Create a new module and declare the callee
    wrapper_module = context.create_module("cuda.soa.wrapper")
    callee = ir.Function(wrapper_module, calleety, fndesc.llvm_func_name)

    # Define the caller - populate it with a call to the callee and return
    # its return value

    wrapper = ir.Function(wrapper_module, wrapperty, wrapper_name)
    builder = ir.IRBuilder(wrapper.append_basic_block(""))

    # Repack the wrapper's trailing (Python-level) parameters into the form
    # the Numba-ABI callee expects; the leading output pointers are skipped
    # by get_arguments().
    arginfo = context.get_arg_packer(argtypes)
    wrapper_args = soa_call_conv.get_arguments(wrapper, restype)
    callargs = arginfo.as_arguments(builder, wrapper_args)
    # We get (status, return_value), but we ignore the status since we
    # can't propagate it through the SoA ABI anyway
    _, return_value = context.call_conv.call_function(
        builder, callee, restype, argtypes, callargs
    )

    # Scatter the result through the output pointers: one store per tuple
    # element, or a single store for a scalar return.
    if isinstance(restype, types.BaseTuple):
        for i in range(len(restype.types)):
            val = builder.extract_value(return_value, i)
            builder.store(val, wrapper.args[i])
    else:
        builder.store(return_value, wrapper.args[0])
    builder.ret_void()

    library.add_ir_module(wrapper_module)
    library.finalize()
    return library


@global_compiler_lock
def compile_ptx_soa(
    pyfunc: Callable[..., Any],
    sig: Union[Tuple[types.Type], str, Signature],
    debug: bool = False,
    lineinfo: bool = False,
    device: bool = False,
    fastmath: bool = False,
    cc: Optional[Tuple[int, int]] = None,
    opt: bool = True,
    abi_info: Optional[Dict[str, str]] = None,
) -> Tuple[str, types.Type]:
    """
    Compile *pyfunc* to PTX with an SoA (structure-of-arrays) return ABI.

    This mirrors Numba's ``compile_ptx``, simplified, with the compiled
    function wrapped so that tuple/scalar results are written through
    leading pointer arguments instead of being returned directly.

    Returns a ``(ptx_source, return_type)`` pair. Raises
    ``NotImplementedError`` unless *device* is True — only device
    functions are supported.
    """
    if not device:
        raise NotImplementedError(
            "Only device functions can be compiled for the SoA ABI"
        )

    # Prefer an explicitly requested ABI name; otherwise fall back to the
    # Python function's own name for the PTX entry point.
    wrapper_name = abi_info["abi_name"] if abi_info else pyfunc.__name__

    nvvm_options: Dict[str, Union[int, str, None]] = {
        "fastmath": fastmath,
        "opt": 3 if opt else 0,
    }

    args, return_type = sigutils.normalize_signature(sig)

    # Default to Compute Capability 5.0 when the caller does not specify one
    cc = cc or (5, 0)

    cres = compile_cuda(
        pyfunc,
        return_type,
        args,
        debug=debug,
        lineinfo=lineinfo,
        fastmath=fastmath,
        nvvm_options=nvvm_options,
        cc=cc,
    )

    # Wrap the Numba-ABI result so returns go out through SoA pointers.
    wrapped_lib = soa_wrap_function(
        cres.target_context,
        cres.library,
        cres.fndesc,
        nvvm_options,
        wrapper_name,
    )

    return wrapped_lib.get_asm_str(cc=cc), cres.signature.return_type
11 changes: 8 additions & 3 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,15 @@ def validate_path(path):
if debug or verbose:
cmake_flags += ["--log-level=%s" % ("DEBUG" if debug else "VERBOSE")]

if debug:
build_type = "Debug"
elif debug_release:
build_type = "RelWithDebInfo"
else:
build_type = "Release"

cmake_flags += f"""\
-DCMAKE_BUILD_TYPE={(
"Debug" if debug else "RelWithDebInfo" if debug_release else "Release"
)}
-DCMAKE_BUILD_TYPE={build_type}
-DBUILD_SHARED_LIBS=ON
-DCMAKE_CUDA_ARCHITECTURES={str(arch)}
-DLegion_MAX_DIM={str(maxdim)}
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ exclude = '''
_build |
buck-out |
build |
dist
dist |
typings
)/
'''

[tool.mypy]
python_version = "3.10"
mypy_path = "typings/"

pretty = true
show_error_codes = true
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/utils/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def allclose(
inds = islice(zip(*np.where(~close)), diff_limit)
diffs = [f" index {i}: {a[i]} {b[i]}" for i in inds]
N = len(diffs)
print(f"First {N} difference{'s' if N>1 else ''} for allclose:\n")
print(f"First {N} difference{'s' if N > 1 else ''} for allclose:\n")
print("\n".join(diffs))
print(f"\nWith diff_limit={diff_limit}\n")

Expand Down
Loading
Loading