my branch update of llm from livepool
jjassonn committed Jul 31, 2024
1 parent 3de64b3 commit 22cd05d
Showing 13 changed files with 642 additions and 28 deletions.
7 changes: 7 additions & 0 deletions runner/app/main.py
@@ -56,6 +56,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "llm-generate":
from app.pipelines.llm_generate import LLMGeneratePipeline
return LLMGeneratePipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -88,6 +91,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "llm-generate":
from app.routes import llm_generate

return llm_generate.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

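For reference, a minimal sketch of how the new "llm-generate" branch could be exercised directly; this assumes app.main is importable outside the running server and the model weights are already downloaded (the model id is illustrative, taken from dl_checkpoints.sh below):

from app.main import load_pipeline, load_route

# Dispatches to the LLMGeneratePipeline and llm_generate router added in this commit.
pipeline = load_pipeline("llm-generate", "meta-llama/Meta-Llama-3-8B-Instruct")
router = load_route("llm-generate")
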
97 changes: 97 additions & 0 deletions runner/app/pipelines/llm_generate.py
@@ -0,0 +1,97 @@
import logging
import os
from typing import Dict, Any, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, hf_hub_download

logger = logging.getLogger(__name__)


class LLMGeneratePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {
"cache_dir": get_model_dir()
}
self.device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(get_model_dir(), folder_name)

# Check for fp16 variant
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
)
if self.device != "cpu" and has_fp16_variant:
logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

# Load model
self.model = AutoModelForCausalLM.from_pretrained(
model_id, **kwargs).to(self.device)

# Set up generation config
self.generation_config = self.model.generation_config

# Optional: Add optimizations
sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
if sfast_enabled:
logger.info(
"LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
model_id,
)
from app.pipelines.optim.sfast import compile_model
self.model = compile_model(self.model)

def __call__(self, prompt: str, system_msg: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]:
if system_msg:
input_text = f"{system_msg}\n\n{prompt}"
else:
input_text = prompt

input_ids = self.tokenizer.encode(
input_text, return_tensors="pt").to(self.device)

# Update generation config
gen_kwargs = {}
if temperature is not None:
gen_kwargs['temperature'] = temperature
if max_tokens is not None:
gen_kwargs['max_new_tokens'] = max_tokens

# Merge generation config with provided kwargs
gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs}

# Generate response
with torch.no_grad():
output = self.model.generate(
input_ids,
**gen_kwargs
)

# Decode the response
response = self.tokenizer.decode(output[0], skip_special_tokens=True)

# Calculate tokens used
tokens_used = len(output[0])

return {
"response": response.strip(),
"tokens_used": tokens_used
}

def __str__(self) -> str:
return f"LLMPipeline model_id={self.model_id}"
7 changes: 7 additions & 0 deletions runner/app/pipelines/text_to_image.py
@@ -22,6 +22,7 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.models import AutoencoderKL
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

@@ -34,6 +35,7 @@ class ModelName(Enum):

SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
SD3_MEDIUM = "stabilityai/stable-diffusion-3-medium-diffusers"
REALISTIC_VISION_V6 = "SG161222/Realistic_Vision_V6.0_B1_noVAE"

@classmethod
def list(cls):
@@ -71,6 +73,11 @@ def __init__(self, model_id: str):
if os.environ.get("BFLOAT16"):
logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

# Load VAE for specific models.
if ModelName.REALISTIC_VISION_V6.value in model_id:
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
kwargs["vae"] = vae

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if ModelName.SDXL_LIGHTNING.value in model_id:
64 changes: 64 additions & 0 deletions runner/app/routes/llm_generate.py
@@ -0,0 +1,64 @@
import logging
import os
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES)
@router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False)
async def llm_generate(
prompt: Annotated[str, Form()],
model_id: Annotated[str, Form()] = "",
system_msg: Annotated[str, Form()] = None,
temperature: Annotated[float, Form()] = None,
max_tokens: Annotated[int, Form()] = None,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

try:
result = pipeline(
prompt=prompt,
system_msg=system_msg,
temperature=temperature,
max_tokens=max_tokens
)
return JSONResponse(content=result)
except Exception as e:
logger.error(f"LLM processing error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("Internal server error during LLM processing."),
)
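A sketch of a client call against the new route; the host, port, and bearer token below are assumptions, not values defined in this commit:

import requests

resp = requests.post(
    "http://localhost:8000/llm-generate",
    # The Authorization header is only required when the runner sets AUTH_TOKEN.
    headers={"Authorization": "Bearer my-auth-token"},
    # Form-encoded fields, matching the Form() parameters of llm_generate above.
    data={
        "prompt": "Write a haiku about GPUs.",
        "system_msg": "You are a poet.",
        "temperature": 0.7,
        "max_tokens": 64,
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"response": "...", "tokens_used": 123}
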
5 changes: 5 additions & 0 deletions runner/app/routes/util.py
@@ -34,6 +34,11 @@ class TextResponse(BaseModel):
chunks: List[chunk]


class LlmResponse(BaseModel):
response: str
tokens_used: int


class APIError(BaseModel):
msg: str

5 changes: 5 additions & 0 deletions runner/dl_checkpoints.sh
@@ -56,10 +56,15 @@ function download_all_models() {
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download stabilityai/stable-diffusion-3-medium-diffusers --include "*.fp16*.safetensors" "*.model" "*.json" "*.txt" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
huggingface-cli download SG161222/Realistic_Vision_V6.0_B1_noVAE --include "*.fp16.safetensors" "*.json" "*.txt" "*.bin" --exclude ".onnx" ".onnx_data" --cache-dir models

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download LLM models (Warning: large model size)
# NOTE: Meta Llama 3 is a gated model; the weight shards are plain *.safetensors files (no fp16 variant),
# and downloading them requires accepting the license and passing an HF access token.
huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}

# Download frame-interpolation model.
wget -O models/film_net_fp16.pt https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.2/film_net_fp16.pt
}
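Because the Meta Llama 3 checkpoint is gated on Hugging Face, the download above only succeeds after the license has been accepted and the CLI is authenticated, e.g. via huggingface-cli login or programmatically (the token value is a placeholder):

from huggingface_hub import login

# Equivalent to `huggingface-cli login`; the token must belong to an account
# that has accepted the Meta Llama 3 license.
login(token="hf_xxx")
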
2 changes: 2 additions & 0 deletions runner/gen_openapi.py
@@ -13,6 +13,7 @@
text_to_image,
frame_interpolation,
upscale,
llm_generate
)
from fastapi.openapi.utils import get_openapi

@@ -85,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"):
app.include_router(image_to_image.router)
app.include_router(image_to_video.router)
app.include_router(audio_to_text.router)
app.include_router(llm_generate.router)
app.include_router(frame_interpolation.router)
app.include_router(upscale.router)

121 changes: 121 additions & 0 deletions runner/openapi.json
@@ -477,6 +477,79 @@
}
]
}
},
"/llm-generate": {
"post": {
"summary": "Llm Generate",
"operationId": "llm_generate",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_llm_generate_llm_generate_post"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/LlmResponse"
}
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"401": {
"description": "Unauthorized",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
},
"security": [
{
"HTTPBearer": []
}
]
}
}
},
"components": {
@@ -667,6 +740,36 @@
],
"title": "Body_image_to_video_image_to_video_post"
},
"Body_llm_generate_llm_generate_post": {
"properties": {
"prompt": {
"type": "string",
"title": "Prompt"
},
"model_id": {
"type": "string",
"title": "Model Id",
"default": ""
},
"system_msg": {
"type": "string",
"title": "System Msg"
},
"temperature": {
"type": "number",
"title": "Temperature"
},
"max_tokens": {
"type": "integer",
"title": "Max Tokens"
}
},
"type": "object",
"required": [
"prompt"
],
"title": "Body_llm_generate_llm_generate_post"
},
"Body_upscale_upscale_post": {
"properties": {
"prompt": {
Expand Down Expand Up @@ -757,6 +860,24 @@
],
"title": "ImageResponse"
},
"LlmResponse": {
"properties": {
"response": {
"type": "string",
"title": "Response"
},
"tokens_used": {
"type": "integer",
"title": "Tokens Used"
}
},
"type": "object",
"required": [
"response",
"tokens_used"
],
"title": "LlmResponse"
},
"Media": {
"properties": {
"url": {