From c49cebd8d002c48b4cff3c3693cefb7a86619e8c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 16 Oct 2024 00:31:05 +0000 Subject: [PATCH] fix nvrtc error handling --- cuda_core/cuda/core/experimental/_program.py | 8 +++++--- cuda_core/cuda/core/experimental/_utils.py | 12 ++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index b6271544..ae5928ee 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -63,13 +63,15 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None): if name_expressions: for n in name_expressions: symbol_mapping[n] = handle_return(nvrtc.nvrtcGetLoweredName( - self._handle, n.encode())) + self._handle, n.encode()), handle=self._handle) if logs is not None: - logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle)) + logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle), + handle=self._handle) if logsize > 1: log = b" " * logsize - handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log)) + handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log), + handle=self._handle) logs.write(log.decode()) # TODO: handle jit_options for ptx? diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index bd3c5cd6..68571ebc 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -42,12 +42,12 @@ def _check_error(error, handle=None): elif isinstance(error, nvrtc.nvrtcResult): if error == nvrtc.nvrtcResult.NVRTC_SUCCESS: return - assert handle is not None - _, logsize = nvrtc.nvrtcGetProgramLogSize(handle) - log = b" " * logsize - _ = nvrtc.nvrtcGetProgramLog(handle, log) - err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}, " \ - f"compilation log:\n\n{log.decode()}" + err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}" + if handle is not None: + _, logsize = nvrtc.nvrtcGetProgramLogSize(handle) + log = b" " * logsize + _ = nvrtc.nvrtcGetProgramLog(handle, log) + err += f", compilation log:\n\n{log.decode()}" raise NVRTCError(err) else: raise RuntimeError('Unknown error type: {}'.format(error))