
[Bug] GPU memory keeps increasing until it overflows when deploying minicpm-v-2_6 with Triton #2642

Open
LinJianping opened this issue Oct 24, 2024 · 1 comment
LinJianping commented Oct 24, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

When I use lmdeploy to deploy minicpm-V-2_6 in Triton, the GPU memory keeps increasing until an out-of-memory exception occurs, but this does not happen when I run loop inference from a plain Python script. I experimented with the official model weights (https://huggingface.co/openbmb/MiniCPM-V-2_6) and still had the same problem. To align with my training environment, I used Torch 2.4.0 instead of Torch 2.3.1, which is the version compatible with lmdeploy; I don't know if this is the reason. When I call torch.cuda.empty_cache() to free unused GPU memory on each request, the problem goes away, but inference time increases.

I have downgraded the Torch version to 2.3.1, and the problem still exists.
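For comparison, the standalone loop-inference script referred to above looks roughly like this (a sketch only; the model path, prompt, and image URL are placeholders):

from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

# Same engine settings as in the Triton backend below.
engine_config = TurbomindEngineConfig(session_len=8192, max_batch_size=1)
pipe = pipeline('/path/to/MiniCPM-V-2_6', backend_config=engine_config)  # placeholder path

image = load_image('https://example.com/sample.jpg')  # placeholder image URL
for i in range(1000):
    # GPU memory stays flat in this loop, unlike when the same pipeline
    # is driven through the Triton python backend below.
    result = pipe(('Describe this image.', image))
    print(i, result.text[:80])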

Reproduction

The triton service startup script is as follows:

#!/bin/bash
cd ~
export OMP_NUM_THREADS=1
tritonserver --log-dir ./serving-logs \
             --model-control-mode "explicit" \
             --load-model minicpmv_2_6_awq_4bit \
             --whale-rpc-port 26381 \
             --grpc-port 26384 \
             --mport 26382 \
             --whale-rpc-use-async 1 \
             --log-info false \
             --allow-grpc true \
             --allow-whale-rpc true \
             --allow-metrics 0 \
             --model-repository /xx/model_repos/minicpmv_2_6_awq_4bit/2 \
             --exit-timeout-secs 0 \
             --log-verbose 0 \
             --server-appkey  xxxx

The python backend inference script is as follows:

import os
import numpy as np
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import triton_python_backend_utils as pb_utils
import json
import time
import torch

from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        # You must parse model_config. JSON string is not parsed here
        # print("args:", args)
        pwd = os.path.abspath(os.path.dirname(__file__))
        self.ckpt_path = f'{pwd}/MiniCPM-V-2_6/'
        print('self.ckpt_path', self.ckpt_path)
        engine_config = TurbomindEngineConfig(session_len=8192, max_batch_size=1)
        self.pipeline = pipeline(self.ckpt_path, backend_config=engine_config)
        print(os.environ['LD_LIBRARY_PATH'])

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []

        # Post-process the model output and create a response for each request
        for request in requests:
            query = pb_utils.get_input_tensor_by_name(request, "query")
            query = query.as_numpy()

            image = pb_utils.get_input_tensor_by_name(request, "image")
            image = image.as_numpy()

            result = self.pipeline((query[0][0].decode('utf-8'), image[0][0].decode('utf-8')))

            output = np.array(result.text.encode('utf-8'),  dtype=np.string_)

            out_tensor_0 = pb_utils.Tensor("response", output.astype(np.string_))
            response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0])
            responses.append(response)


        # print("len req:", len(requests))
        # print("len resp:", len(responses))
        #torch.cuda.empty_cache()

        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
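For reference, the backend above is driven by requests of roughly this shape (a client sketch, not the actual test harness: the gRPC port and model name are taken from the startup script, the input/output names from the backend code, and the prompt and image URL are placeholders):

import numpy as np
import tritonclient.grpc as grpcclient

# gRPC port and model name come from the startup script above.
client = grpcclient.InferenceServerClient(url='localhost:26384')

# Shapes match the backend, which reads query[0][0] and image[0][0] as UTF-8 strings.
query = np.array([[b'Describe this image.']], dtype=np.object_)
image = np.array([[b'https://example.com/sample.jpg']], dtype=np.object_)

inputs = [
    grpcclient.InferInput('query', list(query.shape), 'BYTES'),
    grpcclient.InferInput('image', list(image.shape), 'BYTES'),
]
inputs[0].set_data_from_numpy(query)
inputs[1].set_data_from_numpy(image)

for _ in range(1000):
    # Repeated requests like this are what drive GPU memory up under Triton.
    result = client.infer(model_name='minicpmv_2_6_awq_4bit', inputs=inputs)
    print(result.as_numpy('response'))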

Environment

sys.platform: linux
Python: 3.9.16 (main, Apr  2 2024, 20:40:25) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA L40
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.2, V12.2.91
GCC: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
PyTorch: 2.4.0+cu121
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 90.1  (built against CUDA 12.4)
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.19.0+cu121
LMDeploy: 0.6.1+
transformers: 4.45.2
gradio: Not Found
fastapi: 0.115.0
pydantic: 2.9.2
triton: 3.0.0
NVIDIA Topology:
	GPU0	NIC0	NIC1	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	SYS	SYS				N/A
NIC0	SYS	 X 	SYS
NIC1	SYS	SYS	 X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1

Error traceback

terminate called after throwing an instance of 'std::runtime_error'
  what():  [TM][ERROR] CUDA runtime error: out of memory /lmdeploy/src/turbomind/utils/allocator.h:246
RunningLeon commented Oct 28, 2024

Hi, it seems you are using a quantized model (/xx/model_repos/minicpmv_2_6_awq_4bit), so you should add model_format='awq' to TurbomindEngineConfig, see here. Could you try again? If the issue still exists, could you provide a Dockerfile to reproduce it? Thanks.

--model-repository /xx/model_repos/minicpmv_2_6_awq_4bit/2
...
self.ckpt_path = f'{pwd}/MiniCPM-V-2_6/'
print('self.ckpt_path', self.ckpt_path)
engine_config = TurbomindEngineConfig(session_len=8192, max_batch_size=1)
self.pipeline = pipeline(self.ckpt_path, backend_config=engine_config)
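A minimal sketch of the suggested change applied in initialize (only the model_format argument is new; the other values are copied from the issue):

engine_config = TurbomindEngineConfig(
    session_len=8192,
    max_batch_size=1,
    model_format='awq',  # tell TurboMind the checkpoint is AWQ 4-bit quantized
)
self.pipeline = pipeline(self.ckpt_path, backend_config=engine_config)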
