[Bug] qwen2-vl-72b does not work #2622

Open
bltcn opened this issue Oct 18, 2024 · 2 comments

bltcn commented Oct 18, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

Any call to the API fails immediately with an error, after which the GPU memory is released.

Reproduction

lmdeploy serve api_server /root/hf_model/Qwen/Qwen2-VL-72B-Instruct-AWQ --model-name pkumlm_img --backend pytorch --server-port 8000 --log-level INFO --cache-max-entry-count 0.5 --session-len 32768 --model-format awq --tp 4
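The command above only starts the server. The call that actually crashes it appears in the log further down: a chat completion with a single user message `ping` and `max_new_tokens=5`, answered with `POST /v1/chat/completions ... 500 Internal Server Error`. The following is a minimal sketch of an equivalent client using only the Python standard library. The reporter's real client is not shown, so the host and port, the timeout, and the helper names `build_payload` and `send` are assumptions for illustration. Other sampling parameters are left at the server defaults.

```python
"""Hypothetical client sketch that re-sends the request seen in the log."""
import json
import urllib.request

# Assumption: the api_server from the Reproduction command is reachable locally on port 8000.
BASE_URL = "http://127.0.0.1:8000"


def build_payload(model: str = "pkumlm_img", prompt: str = "ping", max_tokens: int = 5) -> dict:
    """Build an OpenAI-style chat completion body mirroring the logged request."""
    return {
        "model": model,  # must match the --model-name passed to api_server
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,  # the log shows max_new_tokens=5
    }


def send(payload: dict, base_url: str = BASE_URL, timeout: float = 120.0) -> dict:
    """POST the payload to /v1/chat/completions and return the decoded JSON reply."""
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # urlopen raises urllib.error.HTTPError for non-2xx responses such as the 500 in the log.
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

On the reported setup, calling `send(build_payload())` would be expected to raise `urllib.error.HTTPError` with status 500, matching the log. On a working setup it returns the OpenAI-style response as a dict.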

Environment

sys.platform: linux
Python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA GeForce RTX 2080 Ti
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.3.1+cu121
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.6 (Git Hash 86e6af5974177e513fd3fee58425e1063e7f1361)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.3.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 

TorchVision: 0.18.1+cu121
LMDeploy: 0.6.1+2e49fc3
transformers: 4.46.0.dev0
gradio: 4.44.0
fastapi: 0.115.0
pydantic: 2.9.2
triton: 2.3.1
NVIDIA Topology: 
	GPU0	GPU1	GPU2	GPU3	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	PIX	PIX	PIX	0-19,40-59	0		N/A
GPU1	PIX	 X 	PIX	PIX	0-19,40-59	0		N/A
GPU2	PIX	PIX	 X 	PIX	0-19,40-59	0		N/A
GPU3	PIX	PIX	PIX	 X 	0-19,40-59	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
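The engine log in the traceback section warns about two versions from this environment. First, the engine "has not been tested on triton>2.2.0", and triton 2.3.1 is installed. Second, "LMDeploy requires transformers version: [4.33.0 ~ 4.44.1]", but 4.46.0.dev0 is installed. These are only warnings, and the log does not show that either one causes the `map::at` failure. The following is a minimal standard-library sketch for checking installed versions against those ranges. The helper names are hypothetical, and LMDeploy's own check may be implemented differently.

```python
"""Hypothetical helpers for comparing installed versions against the logged ranges."""
from importlib import metadata
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep only the leading numeric release parts, e.g. '4.46.0.dev0' -> (4, 46, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):  # drop local tags such as '+cu121'
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric segment such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def in_range(version: str, low: str, high: str) -> bool:
    """Inclusive range check on the numeric release parts only."""
    return parse_version(low) <= parse_version(version) <= parse_version(high)


def installed(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None
```

For example, `in_range("4.46.0.dev0", "4.33.0", "4.44.1")` is `False`, consistent with the transformers warning. `parse_version("2.3.1") > parse_version("2.2.0")` is `True`, consistent with the triton warning.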

Error traceback

2024-10-18 17:51:42,123 - lmdeploy - INFO - async_engine.py:142 - input backend=pytorch, backend_config=PytorchEngineConfig(dtype='auto', tp=4, session_len=32768, max_batch_size=128, cache_max_entry_count=0.5, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=8192, thread_safe=False, enable_prefix_caching=False, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None)
2024-10-18 17:51:42,123 - lmdeploy - INFO - async_engine.py:144 - input chat_template_config=None
2024-10-18 17:51:42,129 - lmdeploy - INFO - async_engine.py:154 - updated chat_template_onfig=ChatTemplateConfig(model_name='qwen', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, separator=None, capability=None, stop_words=None)
2024-10-18 17:51:42,154 - lmdeploy - INFO - __init__.py:98 - Checking environment for PyTorch Engine.
2024-10-18 17:51:42,305 - lmdeploy - WARNING - __init__.py:73 - Engine has not been tested on triton>2.2.0.
2024-10-18 17:51:42,858 - lmdeploy - INFO - __init__.py:221 - Checking model.
2024-10-18 17:51:42,858 - lmdeploy - WARNING - __init__.py:146 - LMDeploy requires transformers version: [4.33.0 ~ 4.44.1], but found version: 4.46.0.dev0
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
2024-10-18 17:51:43,463 - lmdeploy - INFO - model_agent.py:594 - MASTER_ADDR=127.0.0.1, MASTER_PORT=29500
2024-10-18 17:51:46,797 - lmdeploy - INFO - model_agent.py:358 - build model.
2024-10-18 17:51:47,470 - lmdeploy - INFO - model_agent.py:361 - loading weights.
2024-10-18 17:51:47,472 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00004-of-00011.safetensors"
2024-10-18 17:51:47,819 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00006-of-00011.safetensors"
2024-10-18 17:51:48,234 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00005-of-00011.safetensors"
2024-10-18 17:51:48,651 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00009-of-00011.safetensors"
2024-10-18 17:51:49,067 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00007-of-00011.safetensors"
2024-10-18 17:51:49,487 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00008-of-00011.safetensors"
2024-10-18 17:51:49,509 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00011-of-00011.safetensors"
2024-10-18 17:51:49,901 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00010-of-00011.safetensors"
2024-10-18 17:51:50,217 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00006-of-00011.safetensors"
2024-10-18 17:51:50,247 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00008-of-00011.safetensors"
2024-10-18 17:51:50,567 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00001-of-00011.safetensors"
2024-10-18 17:51:51,035 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00004-of-00011.safetensors"
2024-10-18 17:51:51,739 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00002-of-00011.safetensors"
2024-10-18 17:51:53,503 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00007-of-00011.safetensors"
2024-10-18 17:51:53,683 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00003-of-00011.safetensors"
2024-10-18 17:51:54,866 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00007-of-00011.safetensors"
2024-10-18 17:51:55,795 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00011-of-00011.safetensors"
2024-10-18 17:51:56,999 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00002-of-00011.safetensors"
2024-10-18 17:51:57,419 - lmdeploy - INFO - model_weight_loader.py:152 - rank[0] loading weights - "model-00003-of-00011.safetensors"
2024-10-18 17:51:57,619 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00002-of-00011.safetensors"
2024-10-18 17:51:58,920 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00003-of-00011.safetensors"
2024-10-18 17:52:01,295 - lmdeploy - WARNING - model_agent.py:70 - device<0> No enough memory. update max_prefill_token_num=4096
2024-10-18 17:52:01,303 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00004-of-00011.safetensors"
2024-10-18 17:52:01,398 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00006-of-00011.safetensors"
2024-10-18 17:52:02,536 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00009-of-00011.safetensors"
2024-10-18 17:52:04,191 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00009-of-00011.safetensors"
2024-10-18 17:52:04,307 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00001-of-00011.safetensors"
2024-10-18 17:52:04,506 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00006-of-00011.safetensors"
2024-10-18 17:52:05,287 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00009-of-00011.safetensors"
2024-10-18 17:52:05,308 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00010-of-00011.safetensors"
2024-10-18 17:52:05,463 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00008-of-00011.safetensors"
2024-10-18 17:52:07,823 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00005-of-00011.safetensors"
2024-10-18 17:52:08,220 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00010-of-00011.safetensors"
2024-10-18 17:52:08,338 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00002-of-00011.safetensors"
2024-10-18 17:52:10,323 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00003-of-00011.safetensors"
2024-10-18 17:52:10,680 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00011-of-00011.safetensors"
2024-10-18 17:52:10,950 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00010-of-00011.safetensors"
2024-10-18 17:52:11,591 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00011-of-00011.safetensors"
2024-10-18 17:52:11,743 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00005-of-00011.safetensors"
2024-10-18 17:52:11,752 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00005-of-00011.safetensors"
2024-10-18 17:52:12,274 - lmdeploy - INFO - model_weight_loader.py:152 - rank[2] loading weights - "model-00001-of-00011.safetensors"
2024-10-18 17:52:12,537 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00004-of-00011.safetensors"
2024-10-18 17:52:12,599 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00001-of-00011.safetensors"
2024-10-18 17:52:12,980 - lmdeploy - INFO - model_weight_loader.py:152 - rank[3] loading weights - "model-00007-of-00011.safetensors"
2024-10-18 17:52:13,322 - lmdeploy - WARNING - model_agent.py:70 - device<2> No enough memory. update max_prefill_token_num=4096
2024-10-18 17:52:13,443 - lmdeploy - INFO - model_weight_loader.py:152 - rank[1] loading weights - "model-00008-of-00011.safetensors"
2024-10-18 17:52:13,563 - lmdeploy - WARNING - model_agent.py:70 - device<3> No enough memory. update max_prefill_token_num=4096
2024-10-18 17:52:13,982 - lmdeploy - WARNING - model_agent.py:70 - device<1> No enough memory. update max_prefill_token_num=4096
2024-10-18 17:52:14,511 - lmdeploy - INFO - cache_engine.py:36 - build CacheEngine with config:CacheConfig(max_batches=128, block_size=64, num_cpu_blocks=204, num_gpu_blocks=365, window_size=-1, cache_max_entry_count=0.5, max_prefill_token_num=4096, enable_prefix_caching=False)
2024-10-18 17:52:15,750 - lmdeploy - INFO - async_engine.py:168 - updated backend_config=PytorchEngineConfig(dtype='auto', tp=4, session_len=32768, max_batch_size=128, cache_max_entry_count=0.5, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=8192, thread_safe=False, enable_prefix_caching=False, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None)
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
INFO:     Started server process [973]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
2024-10-18 17:54:59,080 - lmdeploy - INFO - logger.py:27 - session_id=1, prompt=[{'role': 'user', 'content': 'ping'}]
2024-10-18 17:54:59,082 - lmdeploy - INFO - logger.py:41 - session_id=1, prompt='<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nping<|im_end|>\n<|im_start|>assistant\n', gen_config=GenerationConfig(n=1, max_new_tokens=5, do_sample=True, top_p=1.0, top_k=40, min_p=0.0, temperature=0.7, repetition_penalty=1.0, ignore_eos=False, random_seed=2595885313040664692, stop_words=None, bad_words=None, stop_token_ids=[151645], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, logprobs=None, response_format=None, logits_processors=None), prompt_token_id=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 9989, 151645, 198, 151644, 77091, 198], adapter_name=None.
2024-10-18 17:54:59,082 - lmdeploy - INFO - async_engine.py:527 - session_id=1, history_tokens=0, input_tokens=20, max_new_tokens=5, seq_start=True, seq_end=True, step=0, prep=True
2024-10-18 17:55:00,025 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: map::at
Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 946, in async_loop
    await self._async_loop()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 940, in _async_loop
    await __step()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 928, in __step
    raise e
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 922, in __step
    raise out
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 856, in _async_loop_background
    await self._async_step_background(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 738, in _async_step_background
    output = await self._async_model_forward(
  File "/opt/lmdeploy/lmdeploy/utils.py", line 239, in __tmp
    return (await func(*args, **kwargs))
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 629, in _async_model_forward
    ret = await __forward(inputs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 607, in __forward
    return await self.model_agent.async_forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 706, in async_forward
    output = self._forward_impl(inputs,
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 673, in _forward_impl
    output = model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 151, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
    runner.capture(**kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
    self.model(**padded_kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 379, in forward
    hidden_states = self.model(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 318, in forward
    hidden_states, residual = decoder_layer(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 226, in forward
    hidden_states = self.self_attn(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 121, in forward
    attn_output = self.attn_fwd(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/attention.py", line 65, in forward
    return self.impl.forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/attention.py", line 103, in forward
    self.paged_attention_fwd(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 570, in paged_attention_fwd
    _fwd_kernel[grid](q,
  File "<string>", line 16, in __fwd_kernel_launcher
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 193, in compile
    next_module = compile_ir(module, metadata)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 199, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 173, in make_llir
    ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM)
IndexError: map::at
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/lib/python3.10/asyncio/queues.py", line 159, in get
    await getter
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/asyncio/tasks.py", line 456, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 171, in __no_threadsafe_get
    return await asyncio.wait_for(self.resp_que.get(), timeout)
  File "/usr/lib/python3.10/asyncio/tasks.py", line 458, in wait_for
    raise exceptions.TimeoutError() from exc
asyncio.exceptions.TimeoutError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/py3/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 406, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/opt/py3/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/opt/py3/lib/python3.10/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/opt/py3/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/opt/py3/lib/python3.10/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
  File "/opt/py3/lib/python3.10/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
  File "/opt/py3/lib/python3.10/site-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
  File "/opt/lmdeploy/lmdeploy/serve/openai/api_server.py", line 475, in chat_completions_v1
    async for res in result_generator:
  File "/opt/lmdeploy/lmdeploy/serve/async_engine.py", line 565, in generate
    async for outputs in generator.async_stream_infer(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine_instance.py", line 177, in async_stream_infer
    resp = await self.req_sender.async_recv(req_id)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 312, in async_recv
    resp: Response = await self._async_resp_get()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 187, in _async_resp_get
    return await __no_threadsafe_get()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 175, in __no_threadsafe_get
    exit(1)
  File "/usr/lib/python3.10/_sitebuiltins.py", line 26, in __call__
    raise SystemExit(code)
SystemExit: 1
INFO:     192.168.1.254:49272 - "POST /v1/chat/completions HTTP/1.1" 500 Internal Server Error
2024-10-18 17:55:00,436 - lmdeploy - ERROR - model_agent.py:489 - Rank[1] failed.
Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 486, in _start_tp_process
    func(rank, *args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 449, in _tp_model_loop
    model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 151, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
    runner.capture(**kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
    self.model(**padded_kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 379, in forward
    hidden_states = self.model(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 318, in forward
    hidden_states, residual = decoder_layer(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 226, in forward
    hidden_states = self.self_attn(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 121, in forward
    attn_output = self.attn_fwd(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/attention.py", line 65, in forward
    return self.impl.forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/attention.py", line 103, in forward
    self.paged_attention_fwd(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 570, in paged_attention_fwd
    _fwd_kernel[grid](q,
  File "<string>", line 16, in __fwd_kernel_launcher
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 193, in compile
    next_module = compile_ir(module, metadata)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 199, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 173, in make_llir
    ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM)
IndexError: map::at
2024-10-18 17:55:00,456 - lmdeploy - ERROR - model_agent.py:489 - Rank[2] failed.
2024-10-18 17:55:00,457 - lmdeploy - ERROR - model_agent.py:489 - Rank[3] failed.
Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 486, in _start_tp_process
    func(rank, *args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 449, in _tp_model_loop
    model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 151, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
    runner.capture(**kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
    self.model(**padded_kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 379, in forward
    hidden_states = self.model(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 318, in forward
    hidden_states, residual = decoder_layer(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 226, in forward
    hidden_states = self.self_attn(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 121, in forward
    attn_output = self.attn_fwd(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/attention.py", line 65, in forward
    return self.impl.forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/attention.py", line 103, in forward
    self.paged_attention_fwd(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 570, in paged_attention_fwd
    _fwd_kernel[grid](q,
  File "<string>", line 16, in __fwd_kernel_launcher
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 193, in compile
    next_module = compile_ir(module, metadata)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 199, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 173, in make_llir
    ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM)
IndexError: map::at
@grimoire
Collaborator

The attention kernel does not support Triton < 3.0 on the RTX 2080 Ti (compute capability < 8.0). Please upgrade Triton.
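A minimal sketch of the rule stated above, encoding only what the reply says: on GPUs with compute capability below 8.0, the attention kernel needs Triton 3.0 or newer. The helper names `parse_version` and `meets_stated_triton_requirement` are hypothetical and are not part of lmdeploy. The RTX 2080 Ti reports compute capability 7.5, so the restriction applies to all four GPUs in this report.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '3.0.0' or '2.3.1+cu121' into a comparable tuple of leading integers."""
    core = version.split("+")[0]  # drop local build tags such as '+cu121'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like 'rc1' or 'dev'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_stated_triton_requirement(triton_version: str, capability: Tuple[int, int]) -> bool:
    """True unless the GPU is below compute capability 8.0 and Triton is older than 3.0.

    This only checks the restriction quoted in the reply; it does not prove
    the kernel will compile on a given setup.
    """
    if capability >= (8, 0):
        return True  # the stated restriction does not apply
    return parse_version(triton_version) >= (3, 0)


if __name__ == "__main__":
    import torch
    import triton

    for idx in range(torch.cuda.device_count()):
        cap = torch.cuda.get_device_capability(idx)
        ok = meets_stated_triton_requirement(triton.__version__, cap)
        verdict = "ok" if ok else "upgrade triton"
        print(f"GPU {idx} {torch.cuda.get_device_name(idx)} sm_{cap[0]}{cap[1]}: "
              f"triton {triton.__version__} -> {verdict}")
```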

@bltcn
Author

bltcn commented Oct 23, 2024

The attention kernel does not support Triton < 3.0 on the RTX 2080 Ti (compute capability < 8.0). Please upgrade Triton.

I checked: it is already 3.0, and the error still occurs.
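Since Triton is reported as 3.0 while the error persists, one thing worth ruling out is that the interpreter running the server (the traceback resolves packages from `/opt/py3/lib/python3.10/site-packages`) sees a different Triton than the one that was checked. A minimal sketch, using a hypothetical helper `describe_package`, that reports the interpreter, the path a module would be imported from, and the installed distribution version, without actually importing the module:

```python
import importlib.metadata
import importlib.util
import sys
from typing import Optional


def describe_package(module_name: str, dist_name: Optional[str] = None) -> dict:
    """Report where `module_name` resolves on sys.path and which version pip installed."""
    # find_spec locates the module without executing it
    spec = importlib.util.find_spec(module_name)
    try:
        version = importlib.metadata.version(dist_name or module_name)
    except importlib.metadata.PackageNotFoundError:
        version = None  # no installed distribution with that name
    return {
        "interpreter": sys.executable,
        "module": module_name,
        "location": spec.origin if spec is not None else None,
        "version": version,
    }


if __name__ == "__main__":
    for name in ("triton", "torch"):
        print(describe_package(name))
```

Run it with the same interpreter that launches `lmdeploy serve` (assumed here to be `/opt/py3/bin/python`, inferred from the traceback paths). If the reported `location` or `version` differs from what was checked, the server is not using the upgraded Triton.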
