
[Usage]: OOM when using Llama-3.2-11B-Vision-Instruct #8879

Closed
hrson-1203 opened this issue Sep 27, 2024 · 22 comments · Fixed by #8894
Labels
usage How to use vllm

Comments

@hrson-1203

Your current environment

The output of `python collect_env.py`

Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe

Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 57 bits virtual
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz
Stepping: 6
CPU MHz: 806.789
CPU max MHz: 3500.0000
CPU min MHz: 800.0000
BogoMIPS: 5800.00
Virtualization: VT-x
L1d cache: 1.5 MiB
L1i cache: 1 MiB
L2 cache: 40 MiB
L3 cache: 48 MiB
NUMA node0 CPU(s): 0-15,32-47
NUMA node1 CPU(s): 16-31,48-63
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.535.161
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.68
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==25.1.2
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.45.0
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-ml-py 12.535.161 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.68 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pyzmq 26.2.0 pypi_0 pypi
[conda] torch 2.4.0 pypi_0 pypi
[conda] torchvision 0.19.0 pypi_0 pypi
[conda] transformers 4.45.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: N/A
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV12 0-15,32-47 0 N/A
GPU1 NV12 X 0-15,32-47 0 N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

How would you like to use vllm

I want to run inference with meta-llama/Llama-3.2-11B-Vision-Instruct.

I tried to load this multi-modal model into vLLM and run inference.
However, even with two A100s, an OOM error occurred while loading the 11B model.

The error message below shows only one A100 being used, but the same OOM occurs when both are used.

How can I load the Llama-3.2-11B-Vision-Instruct model with vLLM?

- Using vLLM
Args - Namespace(gpu_devices='0', model='meta-llama/Llama-3.2-11B-Vision-Instruct', model_len=4096)
INFO 09-27 11:18:06 llm_engine.py:226] Initializing an LLM engine (v0.6.1.dev238+ge2c6e0a82) with config: model='meta-llama/Llama-3.2-11B-Vision-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.2-11B-Vision-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=meta-llama/Llama-3.2-11B-Vision-Instruct, use_v2_block_manager=False, num_scheduler_steps=1, multi_step_stream_outputs=False, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 09-27 11:18:07 enc_dec_model_runner.py:140] EncoderDecoderModelRunner requires XFormers backend; overriding backend auto-selection and forcing XFormers.
INFO 09-27 11:18:07 selector.py:116] Using XFormers backend.
/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
INFO 09-27 11:18:16 model_runner.py:1014] Starting to load model meta-llama/Llama-3.2-11B-Vision-Instruct...
INFO 09-27 11:18:17 selector.py:116] Using XFormers backend.
INFO 09-27 11:18:18 weight_utils.py:242] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:00,  4.18it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:01,  1.64it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:01<00:01,  1.37it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.28it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00,  1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00,  1.36it/s]

INFO 09-27 11:18:22 model_runner.py:1025] Loading model weights took 19.9073 GB
INFO 09-27 11:18:35 enc_dec_model_runner.py:297] Starting profile run for multi-modal models.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
[rank0]:     cli.main()
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
[rank0]:     run()
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
[rank0]:     runpy.run_path(target, run_name="__main__")
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name,
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals,
[rank0]:   File "/home/heerak/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/home/heerak/workspace/cw-llm-eval/generator.py", line 37, in <module>
[rank0]:     llm = LLM(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 214, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 564, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 339, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 474, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 114, in determine_num_available_blocks
[rank0]:     return self.driver_worker.determine_num_available_blocks()
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/worker/worker.py", line 223, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/worker/enc_dec_model_runner.py", line 348, in profile_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/worker/enc_dec_model_runner.py", line 201, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/mllama.py", line 1084, in forward
[rank0]:     cross_attention_states = self.vision_model(pixel_values,
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/mllama.py", line 556, in forward
[rank0]:     output = self.transformer(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/mllama.py", line 430, in forward
[rank0]:     hidden_states = encoder_layer(
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/mllama.py", line 398, in forward
[rank0]:     hidden_state = self.mlp(hidden_state)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/clip.py", line 278, in forward
[rank0]:     hidden_states, _ = self.fc1(hidden_states)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 367, in forward
[rank0]:     output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]:   File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 135, in apply
[rank0]:     return F.linear(x, layer.weight, bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 15.70 GiB. GPU 0 has a total capacity of 79.15 GiB of which 4.05 GiB is free. Including non-PyTorch memory, this process has 75.09 GiB memory in use. Of the allocated memory 62.54 GiB is allocated by PyTorch, and 12.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

hrson-1203 added the usage (How to use vllm) label Sep 27, 2024
@DarkLight1337
Member

To split the model across GPUs, you should set the tensor_parallel_size argument to the number of GPUs.

@hrson-1203
Author

hrson-1203 commented Sep 27, 2024


@DarkLight1337
Of course; I tried it with both A100s.

If the model is 11B, shouldn't it need only about 22 GB of memory?

@DarkLight1337
Member

You should also consider the memory required for inference, not just the model weights. If you run into OOM issues, you may need to reduce max_model_len and/or max_num_seqs as shown in the example script.
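To put numbers on the 22 GB estimate: a minimal sketch, assuming 2 bytes per parameter (bf16, the dtype shown in the log above) and an even split of the weights across tensor-parallel ranks. It counts weights only; `weight_gib` is a hypothetical helper, not part of vLLM.

```python
# Back-of-the-envelope weight memory only (no KV cache, activations, or CUDA graphs).
# Assumptions: 2 bytes per parameter (bf16) and weights split evenly across
# tensor-parallel ranks. weight_gib is a hypothetical helper, not a vLLM API.

GIB = 2**30  # bytes per GiB


def weight_gib(num_params: float, bytes_per_param: int = 2, tp_size: int = 1) -> float:
    """Approximate per-GPU weight memory in GiB."""
    return num_params * bytes_per_param / tp_size / GIB


single = weight_gib(11e9)                   # about 20.49 GiB (22 GB in decimal units)
per_rank_tp2 = weight_gib(11e9, tp_size=2)  # about 10.24 GiB on each of 2 GPUs

print(f"1 GPU: {single:.2f} GiB")
print(f"TP=2 : {per_rank_tp2:.2f} GiB per GPU")
```

Since "11B" is a rounded parameter count, this lands near the `Loading model weights took 19.9073 GB` line in the log. Roughly speaking, whatever is left of `gpu_memory_utilization` times total GPU memory after the weights must cover the peak activations of the profile run and then the KV cache; the traceback shows the OOM happening inside `profile_run`, before any KV cache was allocated.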

DarkLight1337 changed the title "[Usage]:" to "[Usage]: OOM when using Llama-3.2-11B-Vision-Instruct" Sep 27, 2024
@hrson-1203
Author

hrson-1203 commented Sep 27, 2024


@DarkLight1337

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    tensor_parallel_size=2,
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    trust_remote_code=True,  # !
)

When I run the above code, an error occurs at lines 338 to 339 of llm_engine.py:

if not self.model_config.embedding_mode:
    self._initialize_kv_caches()

@DarkLight1337
Member

How did you install vLLM? I see this in the output of collect_env.py:

ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: N/A
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV12 0-15,32-47 0 N/A
GPU1 NV12 X 0-15,32-47 0 N/A

@hrson-1203
Author


I don't know why the vLLM version is not included.

I'm using vllm==0.6.2.

@DarkLight1337
Member

What is the command you used to install vLLM?

@hrson-1203
Author


I created a conda virtual environment and installed vllm using the pip install vllm command.

@DarkLight1337
Member

@youkaichao @dtrifiro There seems to be something wrong with collect_env.py, can you look into this? I suspect it has something to do with the recent change to using setuptools-scm.
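While collect_env.py reports `vLLM Version: N/A`, one way to confirm which version pip actually installed is to read the package metadata with the standard library's `importlib.metadata` (`pip show vllm` gives the same information). A minimal sketch; `installed_version` is a hypothetical helper name.

```python
# Read an installed distribution's version from package metadata,
# independently of collect_env.py.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints None if vLLM is not installed in the current environment.
print("vllm:", installed_version("vllm"))
```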

@DarkLight1337
Member


Can you show more about the error?

@hrson-1203
Author


I'm still debugging. Apart from the error message in my original report, nothing else is printed, so I don't yet know exactly where the error occurs.

@hrson-1203
Author

hrson-1203 commented Sep 27, 2024

@DarkLight1337

File "/home/heerak/miniconda3/envs/eval/lib/python3.10/site-packages/vllm/model_executor/models/mllama.py", line 1084, in forward
cross_attention_states = self.vision_model(pixel_values,

The OOM occurs here: memory keeps growing as the forward pass proceeds through the vision model.

@ywang96
Member

ywang96 commented Sep 27, 2024

@hrson-1203 You should only need to set max_num_seqs=16 and enforce_eager=True in order to launch the model.
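A minimal sketch of the suggested settings, gathered into a plain dict so they can be inspected before the engine is constructed. `mllama_llm_kwargs` is a hypothetical helper; the keys are the `LLM(...)` arguments already used in this thread.

```python
# Settings suggested in this thread, gathered into a dict.
# mllama_llm_kwargs is a hypothetical helper; each key is an argument
# accepted by vLLM's LLM(...) constructor.


def mllama_llm_kwargs(tensor_parallel_size: int = 1, max_model_len: int = 4096) -> dict:
    return {
        "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "tensor_parallel_size": tensor_parallel_size,  # number of GPUs to shard across
        "max_model_len": max_model_len,
        "max_num_seqs": 16,     # down from the default of 256
        "enforce_eager": True,  # skip CUDA graph capture
    }


kwargs = mllama_llm_kwargs(tensor_parallel_size=2)
print(kwargs)
```

With vLLM installed and a GPU available, the dict would be passed as `LLM(**kwargs)` after `from vllm import LLM`, as in the earlier snippet.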

@hrson-1203
Author


Oh, it finally works.

Could you also explain why we need to use max_num_seqs=16 and enforce_eager=True?

@ywang96
Member

ywang96 commented Sep 27, 2024


Mostly because of its architecture.

max_num_seqs=16: this model has a context length of 128k+ plus the additional block tables for cross-attention layers, so the default setting max_num_seqs=256 won't work.

enforce_eager=True: By default we turn on CUDA graphs for decoder-only language models, but the cross-attention layers in this model are only needed at inference time when there's an image. This dynamic behavior is incompatible with the current CUDA graph implementation, and supporting it is WIP.
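The 15.70 GiB allocation in the original traceback can be reproduced with simple arithmetic, which illustrates another way max_num_seqs drives memory for this model: during the profile run, the vision encoder's MLP (`fc1` in the traceback) produces an activation proportional to the number of sequences. This is a sketch under assumptions not stated in this thread: one dummy image per sequence, 4 tiles of 560x560 px, patch size 14 plus one class token padded to a multiple of 8, MLP intermediate size 5120, bf16, and no tensor parallelism (`tensor_parallel_size=1` in the log). `vision_fc1_gib` is a hypothetical helper.

```python
# Rough size of the vision MLP's fc1 output during the profile run.
# Assumed (not confirmed in this thread): one dummy image per sequence,
# 4 tiles of 560x560 px, patch size 14, one class token, padding to a
# multiple of 8, MLP intermediate size 5120, bf16 (2 bytes), no TP.

GIB = 2**30


def vision_fc1_gib(max_num_seqs: int, tiles: int = 4, image_size: int = 560,
                   patch_size: int = 14, intermediate: int = 5120,
                   bytes_per_elem: int = 2) -> float:
    patches = (image_size // patch_size) ** 2 + 1  # 1600 patches + class token = 1601
    patches += (8 - patches % 8) % 8               # pad 1601 -> 1608
    elems = max_num_seqs * tiles * patches * intermediate
    return elems * bytes_per_elem / GIB


for n in (256, 16):
    print(f"max_num_seqs={n:3d}: fc1 output ~{vision_fc1_gib(n):.2f} GiB")
```

Under these assumptions, the default of 256 sequences gives 15.70 GiB, matching the failed request in the error message, while 16 sequences needs about 0.98 GiB for the same tensor.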

@hrson-1203
Author


Thank you so much for explaining. Thanks to you, I was able to test it.

For other models, what should I look at to figure out whether such settings are necessary?

@ywang96
Member

ywang96 commented Sep 27, 2024


For most models, you don't need to enforce eager mode unless you need extra VRAM (CUDA graphs themselves also consume some memory). If you run into OOM issues, always try lowering max_num_seqs.

@swapnil3597

swapnil3597 commented Sep 27, 2024

I'm trying to run meta-llama/Llama-3.2-11B-Vision-Instruct using vLLM docker:

GPU Server specifications:

  • GPU Count: 4
  • GPU Type: A100 - 80GB

vLLM Docker run command:

docker run  --gpus all \
    -v /data/hf_cache/ \
    --env "HUGGING_FACE_HUB_TOKEN=<token>" \
    -p 8000:8000 \
    --ipc=host \
    vllm/vllm-openai:latest \
    --model meta-llama/Llama-3.2-11B-Vision-Instruct \
    --tensor-parallel-size 4 \
    --max-model-len 4096 \
    --download_dir /data/vllm_cache \
    --enforce-eager

I'm facing a similar issue and have raised a new one: [Usage]: DOCKER - Getting OOM while running meta-llama/Llama-3.2-11B-Vision-Instruct

@DarkLight1337
Member


As mentioned above, you should limit --max-num-seqs to a smaller value, e.g. 16.

@dtrifiro
Contributor

@DarkLight1337 The fix is included in #8900.

@Reichenbachian

Reichenbachian commented Oct 2, 2024

Found this thread very useful. Appreciate the guidance and generally all your work on the vllm multimodal models @DarkLight1337

@rdaiello

rdaiello commented Oct 4, 2024

Thank you @DarkLight1337. This solved my issue as well.

7 participants