Possible bug in reference counting with shared memory regions #7688

Open
hcho3 opened this issue Oct 8, 2024 · 0 comments
Assignees
Labels
investigating The development team is investigating this issue

Comments


hcho3 commented Oct 8, 2024

Description
#7567 introduced a reference counter for shared memory regions to prevent users from releasing a region while any inference request that uses it is still being processed.

I encountered an issue where the shared memory region could not be freed even when no inference request was ongoing.

Error I got:

tritonclient.utils.InferenceServerException: [StatusCode.INTERNAL] Cannot unregister shared memory region
'input_eec14889775b4c29a99d81458b0feb1a', it is currently in use
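
When the error occurs, the server evidently still considers the region in use even though the client has already received the inference response. A minimal diagnostic sketch (not part of the repro script; it assumes the gRPC endpoint used in the steps below) for inspecting which regions the server still has registered:

# diagnostic sketch: list the CUDA shared memory regions the server still considers registered
import tritonclient.grpc as triton_grpc

client = triton_grpc.InferenceServerClient(url="localhost:8001")
print(client.get_cuda_shared_memory_status())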

Triton Information
Triton Server 2.50.0 from NGC container 24.09

To Reproduce

  1. Extract model_repository.zip into a directory named model_repository/. The directory should have the following layout after extraction:
model_repository/
model_repository/xgboost_json
model_repository/xgboost_json/config.pbtxt
model_repository/xgboost_json/1
model_repository/xgboost_json/1/xgboost.json

The example uses an XGBoost model saved in the xgboost_json format and loads it with the FIL backend.
Notably, the model is configured to use the CPU for inference (instance_group [{ kind: KIND_CPU }]).

  2. Launch the Triton server using the 24.09 NGC container:
docker run -it --rm --network=host --ipc=host --pid=host --shm-size=1g \
  --gpus '"device=0"' \
  -v $(pwd)/model_repository:/model_repository \
  nvcr.io/nvidia/tritonserver:24.09-py3 \
  tritonserver --model-repository=/model_repository
  3. Create the script named test.py with the following content:
## test.py
from argparse import ArgumentParser
from collections import namedtuple
from uuid import uuid4

import numpy as np
import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http
import tritonclient.utils as triton_utils
import tritonclient.utils.cuda_shared_memory as shm

STANDARD_PORTS = {"http": 8000, "grpc": 8001}
TritonInput = namedtuple("TritonInput", ("name", "handle", "input"))
TritonOutput = namedtuple("TritonOutput", ("name", "handle", "output"))


class TritonMessage:
    """Adapter to read output from both GRPC and HTTP responses"""

    def __init__(self, message):
        self.message = message

    def __getattr__(self, attr):
        try:
            return getattr(self.message, attr)
        except AttributeError:
            try:
                return self.message[attr]
            except Exception:  # Re-raise AttributeError
                pass
            raise


def get_triton_client(protocol="grpc", host="localhost", port=None):
    if port is None:
        port = STANDARD_PORTS[protocol]

    if protocol == "grpc":
        client = triton_grpc.InferenceServerClient(
            url=f"{host}:{port}",
            verbose=False,
        )
    elif protocol == "http":
        client = triton_http.InferenceServerClient(
            url=f"{host}:{port}",
            verbose=False,
        )
    else:
        raise RuntimeError('Bad protocol: "{}"'.format(protocol))

    return client


def set_shared_input_data(triton_client, triton_input, data, protocol="grpc"):
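    """Create a CUDA shared memory region sized for `data`, copy the data in,
    register it with the server, and attach it to the InferInput."""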
    input_size = data.size * data.itemsize

    input_name = "input_{}".format(uuid4().hex)
    input_handle = shm.create_shared_memory_region(input_name, input_size, 0)
    print(f"Create CUDA shared mem: name = {input_name}")

    shm.set_shared_memory_region(input_handle, [data])

    triton_client.register_cuda_shared_memory(
        input_name, shm.get_raw_handle(input_handle), 0, input_size
    )

    triton_input.set_shared_memory(input_name, input_size)

    return TritonInput(input_name, input_handle, triton_input)


def create_output_handle(triton_client, triton_output, size, shared_mem=None):
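    """Create and register a CUDA shared memory region for the output and
    attach it to the requested output."""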
    output_name = "output_{}".format(uuid4().hex)
    output_handle = shm.create_shared_memory_region(output_name, size, 0)
    print(f"Create CUDA shared mem: name = {output_name}")

    triton_client.register_cuda_shared_memory(
        output_name, shm.get_raw_handle(output_handle), 0, size
    )

    triton_output.set_shared_memory(output_name, size)

    return output_name, output_handle


def create_triton_input(triton_client, data, name, dtype, protocol="grpc"):
    if protocol == "grpc":
        triton_input = triton_grpc.InferInput(name, data.shape, dtype)
    else:
        triton_input = triton_http.InferInput(name, data.shape, dtype)

    return set_shared_input_data(
        triton_client,
        triton_input,
        data,
        protocol=protocol,
    )


def create_triton_output(triton_client, size, name, protocol="grpc"):
    if protocol == "grpc":
        triton_output = triton_grpc.InferRequestedOutput(name)
    else:
        triton_output = triton_http.InferRequestedOutput(name, binary_data=True)

    output_name, output_handle = create_output_handle(
        triton_client, triton_output, size
    )

    return TritonOutput(name=output_name, handle=output_handle, output=triton_output)


def get_response_data(response, output_handle, output_name):
    if output_handle is None:
        return response.as_numpy(output_name)
    else:
        network_result = TritonMessage(response.get_output(output_name))
        return shm.get_contents_as_numpy(
            output_handle,
            triton_utils.triton_to_np_dtype(network_result.datatype),
            network_result.shape,
        )


def release_shared_memory(triton_client, shm_objs):
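    """Unregister each region from the server, then free the local CUDA allocation."""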
    for io_ in shm_objs:
        if io_.name is not None:
            # if 'name' field is None, no shared mem is used
            print(f"Free CUDA shared mem: name = {io_.name}")
            triton_client.unregister_cuda_shared_memory(name=io_.name)
            shm.destroy_shared_memory_region(io_.handle)


def infer(protocol, host):
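    """Run one inference with input and output in CUDA shared memory, then release the regions."""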
    triton_client = get_triton_client(protocol=protocol, host=host)
    arr = np.zeros((1, 500), dtype=np.float32)

    inputs = [
        create_triton_input(
            triton_client,
            arr,
            name="input__0",
            dtype="FP32",
            protocol=protocol,
        )
    ]
    outputs = {
        "output__0": create_triton_output(
            triton_client,
            size=2 * 4,
            name="output__0",
            protocol=protocol,
        )
    }
    response = triton_client.infer(
        "xgboost_json",
        model_version="1",
        inputs=[input_.input for input_ in inputs],
        outputs=[output_.output for output_ in outputs.values()],
    )
    result = {
        name: get_response_data(response, handle, name)
        for name, (_, handle, _) in outputs.items()
    }
    release_shared_memory(triton_client, inputs)
    release_shared_memory(triton_client, outputs.values())


def main(args):
    for i in range(args.n_infer):
        infer(protocol=args.protocol, host=args.host)
        print("")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--n_infer",
        type=int,
        default=1000,
        help="Number of times to run inference (default 1000)",
    )
    parser.add_argument(
        "--protocol",
        type=str,
        choices=["grpc", "http"],
        required=True,
        help="Protocol to use to connect to the Triton server",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parsed_args = parser.parse_args()
    main(parsed_args)
  4. On the same machine as the server, run the test script from the SDK container:
docker run -it --rm --network=host --ipc=host --pid=host --shm-size=1g \
  --gpus '"device=0"' \
  -v $(pwd):/workspace \
  nvcr.io/nvidia/tritonserver:24.09-py3-sdk \
  python /workspace/test.py --protocol grpc

The script will crash, producing a stack trace that looks like the following:

Create CUDA shared mem: name = input_660453d16483422bb3839a093360fd55
Create CUDA shared mem: name = output_012419c8da8443da9100011d565930cf
Free CUDA shared mem: name = input_660453d16483422bb3839a093360fd55
Traceback (most recent call last):
  File "/workspace/test.py", line 216, in <module>
    main(parsed_args)
  File "/workspace/test.py", line 195, in main
    infer(protocol=args.protocol, host=args.host, shared_mem="cuda")
  File "/workspace/test.py", line 189, in infer
    release_shared_memory(triton_client, inputs)
  File "/workspace/test.py", line 152, in release_shared_memory
    triton_client.unregister_cuda_shared_memory(name=io_.name)
  File "/usr/local/lib/python3.10/dist-packages/tritonclient/grpc/_client.py", line 1443, in unregister_cuda_shared_memory
    raise_error_grpc(rpc_error)
  File "/usr/local/lib/python3.10/dist-packages/tritonclient/grpc/_utils.py", line 77, in raise_error_grpc
    raise get_error_grpc(rpc_error) from None
tritonclient.utils.InferenceServerException: [StatusCode.INTERNAL] Cannot unregister shared memory region 'input_660453d16483422bb3839a093360fd55', it is currently in use.

If the script does not crash, run it a few more times.

Expected behavior
The test script should complete without crashing.

Some observations

  • The test script does not crash if the number of inferences is reduced to under 100. On the other hand, increasing the number of inferences to 1000-5000 reliably triggers the error.
  • The test script does not crash if the HTTP protocol is used instead of gRPC. (Replace argument --protocol grpc with --protocol http.)
  • The test script does not crash if the model is configured to use the GPU for inference (instance_group [{ kind: KIND_GPU }]).
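
A possible client-side mitigation (untested here; the retry_unregister helper below is hypothetical and only works around the symptom, not the underlying reference-count issue) would be to retry the unregister call after a short delay:

# hypothetical mitigation sketch, not part of the repro script above: retry
# unregister_cuda_shared_memory a few times while the server still reports
# the region as in use
import time

import tritonclient.utils as triton_utils


def retry_unregister(triton_client, name, attempts=5, delay=0.1):
    for _ in range(attempts):
        try:
            triton_client.unregister_cuda_shared_memory(name=name)
            return True
        except triton_utils.InferenceServerException as exc:
            if "currently in use" not in str(exc):
                raise
            time.sleep(delay)  # give the server time to finish its bookkeeping
    return False
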
@pskiran1 added the investigating (The development team is investigating this issue) label Oct 9, 2024
@pskiran1 self-assigned this Oct 9, 2024