feat: provide an API access method for 7B:int8 #1

Open · wants to merge 1 commit into `main`
2 changes: 1 addition & 1 deletion README.md
@@ -99,7 +99,7 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 -v `pwd`/models:/llama_data
For **the minimum memory** requirements (the 7B model needs roughly 7.12GB), use the following command to run the docker image:

```bash
-docker run --gpus all --ipc=host --ulimit memlock=-1 -v `pwd`/models:/app/models -p 7860:7860 -it --rm soulteary/llama:int8
+docker run --gpus all --ipc=host --ulimit memlock=-1 -e PORT=7860 -v `pwd`/models:/app/models -p 7860:7860 -it --rm soulteary/llama:int8
```

**For fine-tune**, [read this documentation](https://soulteary.com/2023/03/25/model-finetuning-on-llama-65b-large-model-using-docker-and-alpaca-lora.html).
7 changes: 5 additions & 2 deletions docker/Dockerfile.int8
@@ -15,8 +15,11 @@ RUN git clone https://github.com/tloen/llama-int8.git && \
    python setup.py bdist_wheel && \
    pip install -r requirements.txt && \
    pip install -e . && \
-    pip install gradio
+    pip install gradio fastapi uvicorn

+COPY webui/api.py ./api.py
+COPY webui/int8.py ./webapp.py

-CMD ["python", "webapp.py"]
+ENV PORT=7860
+
+CMD uvicorn webapp:app --host 0.0.0.0 --port $PORT
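
The shell-form `CMD` is what allows `$PORT` to be expanded at container start (an exec-form `CMD ["uvicorn", ...]` would pass the literal string `$PORT`). For reference, a roughly equivalent programmatic entry point (a sketch only; the `serve.py` name is hypothetical and not part of this PR):

```python
# serve.py — hypothetical alternative to the shell-form CMD above.
# Reads PORT from the environment, falling back to the Dockerfile default.
import os

import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "webapp:app",  # same module:attribute import string as the CMD
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
    )
```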
74 changes: 74 additions & 0 deletions webui/api.py
@@ -0,0 +1,74 @@
from typing import Optional, Union, List
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

# Shared state: webapp.py attaches the loaded LLaMA generator here at startup.
app_env = {
    "app": app,
    "generator": None  # LLaMA
}

"""
body: {
    "prompts": ["text1", "text2"],
    "max_gen_len": 1024,     # default
    "temperature": 0.8,      # default
    "top_p": 0.95,           # default
    "repetition_penalty": {  # default
        "range": 1024,
        "slope": 0,
        "value": 1.15
    }
}
"""

class RepetitionPenalty(BaseModel):
    range: Optional[int] = 1024
    slope: Optional[float] = 0
    value: Optional[float] = 1.15

class GenerateParam(BaseModel):
    prompts: Union[List[str], str]
    max_gen_len: Optional[int] = 1024
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.95
    repetition_penalty: Optional[RepetitionPenalty] = RepetitionPenalty()

@app.post('/generate')
async def generate(params: GenerateParam):
    if len(params.prompts) == 0:
        return {
            "error": -1,
            "msg": "At least one prompt must be provided."
        }

    if not isinstance(params.prompts, list):
        params.prompts = [params.prompts]

    # Guard against requests arriving before webapp.py has attached the model.
    if app_env["generator"] is None:
        return {
            "error": -2,
            "msg": "Model is not loaded yet."
        }

    results = app_env["generator"].generate(
        prompts=params.prompts,
        max_gen_len=params.max_gen_len,
        temperature=params.temperature,
        top_p=params.top_p,
        repetition_penalty_range=params.repetition_penalty.range,
        repetition_penalty_slope=params.repetition_penalty.slope,
        repetition_penalty=params.repetition_penalty.value
    )

    return {
        "results": results
    }


def get_args():
    # The original argparse version, kept for reference:
    # import argparse
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--ckpt_dir", type=str, default="./models/7B")
    # parser.add_argument("--tokenizer_path", type=str, default="./models/tokenizer.model")
    # return parser.parse_args()

    return {
        "ckpt_dir": "./models/7B",
        "tokenizer_path": "./models/tokenizer.model"
    }
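
A quick way to exercise the new endpoint once the container from the README is running (a minimal client sketch; `requests` is not part of this PR, and the prompt text is only an example):

```python
# Hypothetical client for the /generate endpoint added in webui/api.py.
# Assumes the int8 container is up and publishing port 7860 on localhost.
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={
        "prompts": ["The capital of France is"],  # a bare string is also accepted
        "max_gen_len": 64,
        "temperature": 0.8,
        "top_p": 0.95,
        "repetition_penalty": {"range": 1024, "slope": 0, "value": 1.15},
    },
    timeout=600,  # generation can take a while on first request
)
print(resp.json()["results"])
```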
59 changes: 26 additions & 33 deletions webui/int8.py
@@ -7,6 +7,8 @@
from pathlib import Path
import gradio as gr

+from api import app_env, get_args, app
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

def load(
@@ -99,36 +99,27 @@ def process(prompt: str):
print("Generated:\n", results[0])
return str(results[0])


-def get_args():
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--ckpt_dir", type=str, default="./models/7B")
-    parser.add_argument("--tokenizer_path", type=str,
-                        default="./models/tokenizer.model")
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = get_args()
-    ckpt_dir = args.ckpt_dir
-    tokenizer_path = args.tokenizer_path
-    temperature: float = 0.8
-    top_p: float = 0.95
-    max_seq_len: int = 512
-    max_batch_size: int = 1
-    use_int8: bool = True
-    repetition_penalty_range: int = 1024
-    repetition_penalty_slope: float = 0
-    repetition_penalty: float = 1.15
-
-    generator = load(ckpt_dir, tokenizer_path,
-                     max_seq_len, max_batch_size, use_int8)
-
-    demo = gr.Interface(
-        fn=process,
-        inputs=gr.Textbox(lines=10, placeholder="Your prompt here..."),
-        outputs="text",
-    )
-
-    demo.launch(server_name="0.0.0.0")
+args = get_args()
+ckpt_dir = args["ckpt_dir"]
+tokenizer_path = args["tokenizer_path"]
+temperature: float = 0.8
+top_p: float = 0.95
+max_seq_len: int = 512
+max_batch_size: int = 1
+use_int8: bool = True
+repetition_penalty_range: int = 1024
+repetition_penalty_slope: float = 0
+repetition_penalty: float = 1.15
+
+generator = load(ckpt_dir, tokenizer_path,
+                 max_seq_len, max_batch_size, use_int8)
+
+demo = gr.Interface(
+    fn=process,
+    inputs=gr.Textbox(lines=10, placeholder="Your prompt here..."),
+    outputs="text",
+)
+
+# demo.launch(server_name="0.0.0.0")
+app_env["generator"] = generator
+app = gr.mount_gradio_app(app_env["app"], demo, path="/")
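
The result is a single process that serves both surfaces: FastAPI keeps `/generate`, and the Gradio UI is mounted at `/`. A stripped-down sketch of the same composition (the `echo` function and `/health` route are hypothetical stand-ins, not code from this PR):

```python
# Minimal reproduction of the FastAPI + Gradio composition used above.
# Requires fastapi, uvicorn, and a gradio release that provides mount_gradio_app.
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health():
    # Routes registered on the FastAPI app keep working after the UI is mounted.
    return {"status": "ok"}

def echo(prompt: str) -> str:
    return prompt  # stand-in for generator.generate(...)

demo = gr.Interface(fn=echo, inputs="text", outputs="text")

# Mount the Gradio app onto the FastAPI app at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

# Serve with: uvicorn this_module:app --host 0.0.0.0 --port 7860
```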