llama3_local.py
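"""Local streaming chat demo: serves a Llama 3 model through vLLM behind a Gradio ChatInterface.

A StreamingLLM wrapper drives the vLLM engine step by step so partial outputs
can be streamed to the browser as they are generated.
"""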
from typing import Iterator, Optional, Union
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
from vllm.outputs import RequestOutput
from vllm import SamplingParams
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
import torch
import gradio as gr


class StreamingLLM:
    """Thin wrapper around the vLLM engine that yields outputs as they are produced."""

    def __init__(
        self,
        model: str,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        **kwargs,
    ) -> None:
        # Build the engine directly so it can be stepped manually and tokens streamed.
        engine_args = EngineArgs(model=model, quantization=quantization, dtype=dtype, enforce_eager=True)
        self.llm_engine = LLMEngine.from_engine_args(engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

    def generate(
        self,
        prompt: Optional[str] = None,
        sampling_params: Optional[SamplingParams] = None,
    ) -> Iterator[RequestOutput]:
        # Submit the request, then step the engine manually, yielding each partial
        # output so the caller can stream text as it is generated.
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params)
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                yield output


class UI:
    """Gradio ChatInterface front end that streams responses from a StreamingLLM."""

    def __init__(
        self,
        llm: StreamingLLM,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        self.llm = llm
        self.tokenizer = tokenizer
        self.sampling_params = sampling_params

    def _generate(self, message, history):
        # Rebuild the conversation in chat-message format from Gradio's (user, assistant) history pairs.
        history_chat_format = []
        for human, assistant in history:
            history_chat_format.append({"role": "user", "content": human})
            history_chat_format.append({"role": "assistant", "content": assistant})
        history_chat_format.append({"role": "user", "content": message})
        # add_generation_prompt appends the assistant header so the model answers
        # the last user turn instead of continuing it.
        prompt = self.tokenizer.apply_chat_template(history_chat_format, tokenize=False, add_generation_prompt=True)
        # Yield the accumulated text of the first candidate so Gradio renders it incrementally.
        for chunk in self.llm.generate(prompt, self.sampling_params):
            yield chunk.outputs[0].text

    def launch(self):
        # Serve the chat UI on all interfaces; the port is hard-coded to 7680.
        gr.ChatInterface(self._generate, title="Pcs Llama3").launch(server_name="0.0.0.0", server_port=7680)


# Module-level singletons so the engine, tokenizer, and sampling parameters are created only once.
llm = None
tokenizer = None
sampling_params = None


def get_llm(model):
    global llm
    if llm is None:
        llm = StreamingLLM(model=model, dtype="float16")
    return llm


def get_tokenizer():
    global tokenizer
    if tokenizer is None:
        # Reuse the Hugging Face tokenizer that vLLM already loaded for the engine.
        tokenizer = llm.llm_engine.tokenizer.tokenizer
    return tokenizer


def get_sampling_params():
    global sampling_params
    if sampling_params is None:
        # Stop on both the tokenizer's EOS token and Llama 3's <|eot_id|> end-of-turn token.
        sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.9,
            max_tokens=4096,
            stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
        )
    return sampling_params
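

# Assumed usage: run this script directly on a machine with a CUDA GPU and the
# vllm, gradio, transformers, and torch packages installed, plus access to the
# gated meta-llama/Llama-3.1-8B-Instruct weights on Hugging Face:
#     python llama3_local.py
# The Gradio chat UI is then reachable on port 7680 of the host.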
if __name__ == "__main__":
    torch.cuda.empty_cache()
    llm = get_llm("meta-llama/Llama-3.1-8B-Instruct")
    tokenizer = get_tokenizer()
    sampling_params = get_sampling_params()
    ui = UI(llm, tokenizer, sampling_params)
    ui.launch()