Skip to content

Commit

Permalink
Merge pull request #173 from leofang/fix_nvrtc
Browse files Browse the repository at this point in the history
Fix NVRTC error handling
  • Loading branch information
leofang authored Oct 18, 2024
2 parents 39a86fd + c49cebd commit a860436
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
8 changes: 5 additions & 3 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
if name_expressions:
for n in name_expressions:
symbol_mapping[n] = handle_return(nvrtc.nvrtcGetLoweredName(
self._handle, n.encode()))
self._handle, n.encode()), handle=self._handle)

if logs is not None:
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle))
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle),
handle=self._handle)
if logsize > 1:
log = b" " * logsize
handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log))
handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log),
handle=self._handle)
logs.write(log.decode())

# TODO: handle jit_options for ptx?
Expand Down
12 changes: 6 additions & 6 deletions cuda_core/cuda/core/experimental/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def _check_error(error, handle=None):
elif isinstance(error, nvrtc.nvrtcResult):
if error == nvrtc.nvrtcResult.NVRTC_SUCCESS:
return
assert handle is not None
_, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
log = b" " * logsize
_ = nvrtc.nvrtcGetProgramLog(handle, log)
err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}, " \
f"compilation log:\n\n{log.decode()}"
err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}"
if handle is not None:
_, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
log = b" " * logsize
_ = nvrtc.nvrtcGetProgramLog(handle, log)
err += f", compilation log:\n\n{log.decode()}"
raise NVRTCError(err)
else:
raise RuntimeError('Unknown error type: {}'.format(error))
Expand Down

0 comments on commit a860436

Please sign in to comment.