forked from livepeer/ai-worker
Commit: my branch update of llm from livepool
Showing 13 changed files with 642 additions and 28 deletions.
@@ -0,0 +1,97 @@
import logging
import os
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download

logger = logging.getLogger(__name__)


class LLMGeneratePipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}
        self.device = get_torch_device()

        # Resolve the local cache folder for this model.
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)

        # Load the tokenizer before any dtype/variant kwargs are added;
        # those options only apply to the model weights.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

        # Prefer the fp16 weight variant when one is cached and we are
        # not running on CPU.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if self.device != "cpu" and has_fp16_variant:
            logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        # Load the model and move it to the target device.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, **kwargs
        ).to(self.device)

        # Start from the model's default generation config; per-call
        # overrides are merged in __call__.
        self.generation_config = self.model.generation_config

        # Optional: dynamically compile the model with stable-fast.
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.model = compile_model(self.model)

    def __call__(
        self,
        prompt: str,
        system_msg: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Prepend the system message, if any, to the user prompt.
        if system_msg:
            input_text = f"{system_msg}\n\n{prompt}"
        else:
            input_text = prompt

        input_ids = self.tokenizer.encode(
            input_text, return_tensors="pt"
        ).to(self.device)

        # Collect per-call generation overrides.
        gen_kwargs = {}
        if temperature is not None:
            gen_kwargs["temperature"] = temperature
        if max_tokens is not None:
            gen_kwargs["max_new_tokens"] = max_tokens

        # Merge the model's generation config with the overrides; later
        # entries win, so explicit arguments take precedence.
        gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs}

        # Generate a response without tracking gradients.
        with torch.no_grad():
            output = self.model.generate(input_ids, **gen_kwargs)

        # Decode the full output sequence (prompt plus completion).
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)

        # Note: this counts prompt tokens as well as generated tokens.
        tokens_used = len(output[0])

        return {
            "response": response.strip(),
            "tokens_used": tokens_used,
        }

    def __str__(self) -> str:
        return f"LLMGeneratePipeline model_id={self.model_id}"
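For context, here is a minimal sketch of how the pipeline might be driven directly, outside the FastAPI route. The module path and model id are illustrative assumptions, not shown in this diff:

# Hypothetical usage sketch; the module path and model id below are
# assumptions. Assumes the model has already been downloaded into the
# directory returned by get_model_dir().
from app.pipelines.llm_generate import LLMGeneratePipeline  # path assumed

pipeline = LLMGeneratePipeline(model_id="meta-llama/Meta-Llama-3-8B-Instruct")  # example id
result = pipeline(
    prompt="What is Livepeer?",
    system_msg="You are a helpful assistant.",
    temperature=0.7,
    max_tokens=256,
)
print(result["response"], result["tokens_used"])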
@@ -0,0 +1,64 @@
import logging
import os
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LlmResponse, http_error

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES)
@router.post(
    "/llm-generate/",
    response_model=LlmResponse,
    responses=RESPONSES,
    include_in_schema=False,
)
async def llm_generate(
    prompt: Annotated[str, Form()],
    model_id: Annotated[str, Form()] = "",
    system_msg: Annotated[Optional[str], Form()] = None,
    temperature: Annotated[Optional[float], Form()] = None,
    max_tokens: Annotated[Optional[int], Form()] = None,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    # If AUTH_TOKEN is set, require a matching bearer token.
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    # Reject requests that target a different model than the one loaded.
    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    try:
        result = pipeline(
            prompt=prompt,
            system_msg=system_msg,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return JSONResponse(content=result)
    except Exception as e:
        logger.error(f"LLM processing error: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=http_error("Internal server error during LLM processing."),
        )
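Since the endpoint accepts form-encoded fields, it can be exercised with a small client sketch. The host, port, and token below are assumptions for illustration; the field names come straight from the handler signature:

# Hypothetical client sketch; assumes the worker listens on
# localhost:8000 and AUTH_TOKEN is set to "my-token".
import requests

resp = requests.post(
    "http://localhost:8000/llm-generate",
    headers={"Authorization": "Bearer my-token"},
    data={
        "prompt": "What is Livepeer?",
        "system_msg": "You are a helpful assistant.",
        "temperature": 0.7,
        "max_tokens": 256,
    },
)
print(resp.json())  # e.g. {"response": "...", "tokens_used": 123}

A successful call returns the pipeline's JSON payload; a wrong or missing bearer token yields 401, and a mismatched model_id yields 400.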