Skip to content

Commit

Permalink
Provide nvJitLinkVersion as a public API.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Jun 27, 2024
1 parent 86329f3 commit 8380c0c
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
__pycache__
build
dist
final_dist
*.so
*.egg-info
.*.swp
Expand Down
8 changes: 3 additions & 5 deletions ci/install_latest_cuda_toolkit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
# Copyright (c) 2024, NVIDIA CORPORATION

# Installs the latest CUDA Toolkit.
# Supports CentOS 7 and Rocky Linux 8.
# Supports Rocky Linux 8.

yum update -y
yum install -y epel-release

OS_ID=$(. /etc/os-release; echo $ID)
if [ "${OS_ID}" == "rocky" ]; then
yum install -y nvidia-driver
else
if [ "${OS_ID}" != "rocky" ]; then
echo "Error: OS not detected as Rocky Linux. Exiting."
exit 1
fi

yum install -y cuda-toolkit-12-5
yum install -y nvidia-driver cuda-nvcc-12-5 cuda-cudart-devel-12-5 cuda-driver-devel-12-5 libnvjitlink-devel-12-5
3 changes: 2 additions & 1 deletion pynvjitlink/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

from pynvjitlink.api import NvJitLinker, NvJitLinkError
from pynvjitlink.api import NvJitLinker, NvJitLinkError, nvjitlink_version
from pynvjitlink._version import __git_commit__, __version__

__all__ = [
"NvJitLinkError",
"NvJitLinker",
"nvjitlink_version",
"__git_commit__",
"__version__",
]
38 changes: 37 additions & 1 deletion pynvjitlink/_nvjitlinklib.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,6 +55,40 @@ static void set_exception(PyObject *exception_type, const char *message_format,
PyErr_SetString(exception_type, exception_message);
}

static PyObject *nvjitlink_version() {
unsigned int major;
unsigned int minor;

nvJitLinkResult res = nvJitLinkVersion(&major, &minor);

if (res != NVJITLINK_SUCCESS) {
set_exception(PyExc_RuntimeError, "%s error when calling nvJitLinkVersion",
res);
return nullptr;
}

PyObject *py_version = PyTuple_New(2);
PyObject *py_major = PyLong_FromUnsignedLong(major);
PyObject *py_minor = PyLong_FromUnsignedLong(minor);
if (!py_version || !py_major || !py_minor) {
PyErr_SetString(PyExc_RuntimeError, "Failed to create version tuple");
if (py_major) {
Py_DecRef(py_major);
}
if (py_minor) {
Py_DecRef(py_minor);
}
if (py_version) {
Py_DecRef(py_version);
}
return nullptr;
}

PyTuple_SetItem(py_version, 0, py_major);
PyTuple_SetItem(py_version, 1, py_minor);
return py_version;
}

static PyObject *create(PyObject *self, PyObject *args) {
PyObject *ret = nullptr;
const char **jitlink_options;
Expand Down Expand Up @@ -302,6 +336,8 @@ static PyObject *get_linked_cubin(PyObject *self, PyObject *args) {
}

static PyMethodDef ext_methods[] = {
{"nvjitlink_version", (PyCFunction)nvjitlink_version, METH_VARARGS,
"Returns the nvJitLink version"},
{"create", (PyCFunction)create, METH_VARARGS,
"Returns a handle to a new nvJitLink object"},
{"destroy", (PyCFunction)destroy, METH_VARARGS,
Expand Down
6 changes: 5 additions & 1 deletion pynvjitlink/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

from enum import Enum
from pynvjitlink import _nvjitlinklib
Expand All @@ -16,6 +16,10 @@ class InputType(Enum):
LIBRARY = 6


def nvjitlink_version():
return _nvjitlinklib.nvjitlink_version()


class NvJitLinkError(RuntimeError):
pass

Expand Down
7 changes: 7 additions & 0 deletions pynvjitlink/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ def test_version_constants_are_populated():
# __version__ should always be non-empty
assert isinstance(pynvjitlink.__version__, str)
assert len(pynvjitlink.__version__) > 0


def test_nvjitlink_version():
nvjitlink_version = pynvjitlink.nvjitlink_version()
assert len(nvjitlink_version) == 2
assert nvjitlink_version[0] >= 12
assert nvjitlink_version[1] >= 0

0 comments on commit 8380c0c

Please sign in to comment.