[Build] CUDA Illegal Memory Access error when using a custom Triton kernel #20885

Open

Numeri opened this issue May 31, 2024 · 4 comments

Labels: build (build issues; typically submitted using template), ep:CUDA (issues related to the CUDA execution provider), stale (issues that have not been addressed in a while; categorized by a bot)

Comments


Numeri commented May 31, 2024

Describe the issue

I am trying to add a custom Triton kernel to ONNX Runtime as an operator. This works, but whenever I call the operator, I get the following CUDA error (Illegal Memory Access):

2024-05-31 15:58:30.844121139 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/gpu_data_transfer.cc ; line=73 ; expr=cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle())); 
2024-05-31 15:58:30.844186034 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=446 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 

This occurs with both IOBinding and normal inference (the exact error is slightly different, but it's still an illegal memory access).

I have reduced my code to a minimal working example (the Triton kernel is essentially just lambda x: -x) and put it into this draft PR, which also contains a few small fixes and more extensive documentation.
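For reference, a minimal Triton negation kernel along these lines looks roughly like the sketch below (illustrative only; the kernel name, signature, and block size are assumptions, not the exact code in the PR):

import triton
import triton.language as tl

@triton.jit
def neg_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance negates one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, -x, mask=mask)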

I believe the issue is specifically in how I pass the CUDA stream to onnxruntime::cuda::LaunchTritonKernel [here](https://github.com/microsoft/onnxruntime/pull/20883/files#diff-3ed25bb54cb594743621055f0b541d6b93c6792a49da3a9a1bb5a65d73abf22eR60-R72):

  cudaStream_t stream = Stream(ctx);

  return onnxruntime::cuda::LaunchTritonKernel(stream, function_name, grid_size, 1, 1, &args, sizeof(args));

but I may be wrong about that.
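For what it's worth, forcing synchronous kernel launches and turning on verbose ORT logging is one way to check whether the fault is raised by the Triton launch itself or only surfaces at the later device-to-host copy (a generic debugging sketch, not something specific to this branch):

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # make every CUDA kernel launch synchronous

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.log_severity_level = 0  # verbose ORT logging
session = ort.InferenceSession(
    "test_model.onnx",
    sess_options=sess_options,
    providers=[("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"],
)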

I understand that this is my code, not Microsoft's, but I think my PR would be a good contribution to ORT if I could get a little help fixing this issue.

Urgency

Medium: not super urgent, but I'd love some help :)

Target platform

CUDA

Build script

git clone https://github.com/Numeri/onnxruntime.git
cd onnxruntime
git checkout numeri/minimal_triton_kernel

bash tools/scripts/compile_triton_kernels.sh
sudo docker build -t ort_cuda -f dockerfiles/Dockerfile.cuda .
sudo docker run -it --gpus all \
    -v $PATH_TO_PROVIDED_SCRIPTS:/test_ort \
    ort_cuda \
    bash -c "pip install onnx numpy && python3 /test_ort/make_graph.py && python3 /test_ort/run_graph.py; /bin/bash"

Place these two scripts in $PATH_TO_PROVIDED_SCRIPTS (one creates the ONNX graph with the test operator, one runs it). make_graph.py:

import onnx
from onnx import helper
import numpy as np

input_node = helper.make_tensor_value_info(
    name='X',
    elem_type=helper.np_dtype_to_tensor_dtype(np.dtype('float16')),
    shape=('seq_len',),
)
output_node = helper.make_tensor_value_info(
    name='Y',
    elem_type=helper.np_dtype_to_tensor_dtype(np.dtype('float16')),
    shape=('seq_len',),
)
softmax_node = helper.make_node(
    "MyTritonKernel",
    inputs=["X"],
    outputs=["Y"],
    domain="com.microsoft",
)

graph = helper.make_graph(
    name='test_model',
    nodes=[softmax_node],
    inputs=[input_node],
    outputs=[output_node],
)

model = helper.make_model(graph=graph)
onnx.save(model, 'test_model.onnx')

run_graph.py:

import onnxruntime as ort
import numpy as np

providers = [
    (
        'CUDAExecutionProvider',
        {
           'device_id': 0,
        }
    ),
    'CPUExecutionProvider',
]

session = ort.InferenceSession('test_model.onnx', providers=providers)

X = np.random.rand(256).astype(np.float16)

output_names = ["Y"]

use_iobinding = True

if use_iobinding:
    binding = session.io_binding()
    binding.bind_cpu_input('X', X)

    # binding.bind_output('Y', device_type='cuda', device_id=0)
    binding.bind_output('Y', device_type='cpu')

    binding.synchronize_inputs()
    session.run_with_iobinding(binding)
    binding.synchronize_outputs()

    results = {
        name: output.numpy()
        for name, output
        in zip(output_names, binding.get_outputs())
    }
else:
    inputs = {
        "X": X,
    }
    results = {
        name: output
        for name, output
        in zip(output_names, session.run(output_names, inputs))
    }
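Since the kernel is essentially just a negation, a quick correctness check could be appended to run_graph.py once inference succeeds (a sketch; the tolerance is arbitrary):

# The test kernel is essentially lambda x: -x, so the output should be the
# elementwise negation of the input.
np.testing.assert_allclose(results["Y"], -X, rtol=1e-3)
print("output matches -X")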

Error / output

2024-05-31 15:58:30.844121139 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/gpu_data_transfer.cc ; line=73 ; expr=cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle())); 
2024-05-31 15:58:30.844186034 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=446 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 

Visual Studio Version

No response

GCC / Compiler Version

Using the dockerfile's compiler version

Numeri added the build label on May 31, 2024
github-actions bot added the ep:CUDA label on May 31, 2024

Numeri commented Jun 11, 2024

Perhaps someone here has advice on how to debug this, at least? I'd love to just have some next steps to try.

github-actions bot commented Jul 11, 2024

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions bot added the stale label on Jul 11, 2024

zhangzhen507 commented Sep 6, 2024

Hello, I tried to clone and build your code:

git clone https://github.com/Numeri/onnxruntime.git
cd onnxruntime
git checkout numeri/minimal_triton_kernel
./build.sh --update --build --config RelWithDebInfo --skip_submodule_sync --build_shared_lib --parallel 20 --build_wheel --use_triton_kernel --use_cuda --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME
pip install onnxruntime-1.19.0/build/Linux/RelWithDebInfo/dist/onnxruntime_gpu-1.19.0-cp38-cp38-linux_x86_64.whl
python make_graph.py
python run_graph.py
I get an error like this:

2024-09-06 16:25:38.797108890 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running MyTritonKernel node. Name:'' Status Message: Launching kernel failed. too many resources requested for launch
Traceback (most recent call last):
  File "run_graph.py", line 30, in <module>
    session.run_with_iobinding(binding)
  File "/home/zhen1.zhang/virtual_env/vir_onnx_py3.8/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 331, in run_with_iobinding
    self._sess.run_with_iobinding(iobinding._iobinding, run_options)
RuntimeError: Error in execution: Non-zero status code returned while running MyTritonKernel node. Name:'' Status Message: Launching kernel failed. too many resources requested for launch


Numeri commented Sep 6, 2024

@zhangzhen507 Interesting, I haven't seen that. Just so you're aware, I'm not completely confident I left that branch in a working state. I did manage to fix the error mentioned in this thread, but need to do a little cleanup before I feel confident the Triton kernels are working properly.
