Inference Speed is Extremely Slow for 72B Model with Long Contexts #1767

Open
wrench1997 opened this issue Sep 27, 2024 · 0 comments
wrench1997 commented Sep 27, 2024

Description:

When running inference on a 72B model with a long context length (40960 tokens), the process is extremely slow, taking approximately 40 minutes to generate results. The same task takes only about 5 minutes with the standard transformers package.

Details:

  • Model name: qwen2.5:72b-instruct-q8_0
  • Model Size: 72B
  • Context Length: 40960 tokens
  • Issue: Inference time is around 40 minutes, compared to about 5 minutes with the standard transformers package.
  • Expected Behavior: Inference should complete within a reasonable time frame, closer to the 5-minute benchmark seen with the transformers package.
  • Actual Behavior: Severe slowdown, resulting in 40-minute inference times when JSON formatting is enforced.

Steps to Reproduce:

  1. Set up the 72B model with long context inputs and enforce multiple fields in JSON format.
  2. Run the inference.
  3. Compare the time taken against the standard transformers package (a rough timing sketch is shown below).
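
For reference, a rough sketch of how the transformers baseline can be timed; the Hugging Face model id, dtype, and generation settings below are assumptions for illustration, not the exact baseline script:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed transformers baseline; model id, dtype, and max_new_tokens are placeholders.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-72B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

prompt = "<same ~4k-token prompt used in the llama-cpp-python test>"  # placeholder
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

start = time.time()
outputs = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
print(f"transformers generation took {time.time() - start:.1f} s")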

Environment:

  • Hardware:
nvidia-smi 
Fri Sep 27 16:39:49 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A800-SXM4-80GB          On  | 00000000:10:00.0 Off |                    0 |
| N/A   44C    P0              56W / 400W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A800-SXM4-80GB          On  | 00000000:8F:00.0 Off |                    0 |
| N/A   31C    P0              57W / 400W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A800-SXM4-80GB          On  | 00000000:C6:00.0 Off |                    0 |
| N/A   31C    P0              54W / 400W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A800-SXM4-80GB          On  | 00000000:CA:00.0 Off |                    0 |
| N/A   33C    P0              58W / 400W |      2MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
  • Software:
llama_cpp_python  0.3.0
torch  2.4.1+cu121
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:16:58_PDT_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0

Test Code:

from typing import Optional
from llama_index.llms.llama_cpp import LlamaCPP
from llama_cpp import LogitsProcessorList
from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
from lmformatenforcer.integrations.llamacpp import build_llamacpp_logits_processor

llm_llamacpp = LlamaCPP(
    model_path="/root/Qwen/qwen2.5:72b-instruct-q8_0.gguf",
    model_kwargs={
        "n_gpu_layers": -1,  # offload all layers if compiled with GPU support
    },
    max_new_tokens=40960,  # 131072
    context_window=40960,
    temperature=0,
    verbose=True,
)

from pydantic import BaseModel
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int
    # ... about 50 fields in total


question = "<4k length content>"  # placeholder: ~4k-token question text
question_with_schema = f'{question}{AnswerFormat.schema_json()}'
prompt = get_prompt(question_with_schema)



def llamaindex_llamacpp_lm_format_enforcer(llm: LlamaCPP, prompt: str, character_level_parser: Optional[CharacterLevelParser]) -> str:
    logits_processors: Optional[LogitsProcessorList] = None
    if character_level_parser:
        logits_processors = LogitsProcessorList([build_llamacpp_logits_processor(llm._model, character_level_parser)])
    
    # If the character level parser changes on each call, inject it before calling complete.
    # If it's the same format every time, you can set it once after creating the LlamaCPP model.
    llm.generate_kwargs['logits_processor'] = logits_processors
    output = llm.complete(prompt)
    text: str = output.text
    return text

result = llamaindex_llamacpp_lm_format_enforcer(llm_llamacpp, prompt, JsonSchemaParser(AnswerFormat.schema()))
print(result)
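
As the comment in the helper notes, when the output schema never changes the logits processor can be attached once right after constructing the LlamaCPP model rather than on every call; a minimal sketch using the same objects defined above:

# Minimal sketch: the AnswerFormat schema is fixed, so build the processor once
# and reuse it for every complete() call.
parser = JsonSchemaParser(AnswerFormat.schema())
llm_llamacpp.generate_kwargs['logits_processor'] = LogitsProcessorList(
    [build_llamacpp_logits_processor(llm_llamacpp._model, parser)]
)

result = llm_llamacpp.complete(prompt).text
print(result)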