-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Kubernetes autoscaler with configuration and state management
- Loading branch information
Sawyer
committed
Nov 4, 2024
1 parent
8e275e4
commit 5114545
Showing
10 changed files
with
443 additions
and
224 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from fastapi import ( | ||
APIRouter, | ||
Request, | ||
Response, | ||
BackgroundTasks, | ||
HTTPException, | ||
status, | ||
Depends, | ||
) | ||
from fastapi.responses import StreamingResponse, JSONResponse | ||
import httpx | ||
import logging | ||
import time | ||
from typing import AsyncGenerator | ||
from ..types import AutoscalerState, PodPhase | ||
from ..config import Settings | ||
from ..kubernetes import KubeCommand | ||
from ..vllm import VLLMManager | ||
from ..dependencies import get_settings, get_state, get_kube, get_vllm_manager | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
router = APIRouter() | ||
|
||
|
||
async def stream_response(response: httpx.Response) -> AsyncGenerator[bytes, None]: | ||
"""Stream response content.""" | ||
try: | ||
async for chunk in response.aiter_bytes(): | ||
yield chunk | ||
except httpx.HTTPError as e: | ||
logger.error(f"Error streaming response: {e}") | ||
raise HTTPException(status_code=502, detail="Error streaming from vLLM service") | ||
|
||
|
||
@router.get("/health") | ||
async def health_check( | ||
kube: KubeCommand = Depends(get_kube), state: AutoscalerState = Depends(get_state) | ||
): | ||
"""Health check endpoint.""" | ||
phase = await kube.get_pod_phase() | ||
current_replicas, desired_replicas = await kube.get_replicas() | ||
return { | ||
"status": "healthy", | ||
"vllm_status": phase, | ||
"vllm_running": phase == PodPhase.RUNNING, | ||
"current_replicas": current_replicas, | ||
"desired_replicas": desired_replicas, | ||
"last_activity": time.strftime( | ||
"%Y-%m-%d %H:%M:%S", time.localtime(state.last_activity) | ||
), | ||
} | ||
|
||
|
||
@router.post("/scale/{replicas}") | ||
async def scale( | ||
replicas: int, | ||
background_tasks: BackgroundTasks, | ||
kube: KubeCommand = Depends(get_kube), | ||
vllm_manager: VLLMManager = Depends(get_vllm_manager), | ||
) -> JSONResponse: | ||
"""Manually scale the vLLM deployment.""" | ||
if replicas < 0: | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail="Replica count must be non-negative", | ||
) | ||
|
||
if not await kube.scale_deployment(replicas): | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Failed to scale deployment", | ||
) | ||
|
||
if replicas > 0: | ||
vllm_manager.reset_inactivity_timer(background_tasks) | ||
|
||
return JSONResponse( | ||
status_code=status.HTTP_200_OK, | ||
content={"message": f"Scaling deployment to {replicas} replicas"}, | ||
) | ||
|
||
|
||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||
async def proxy_request( | ||
request: Request, | ||
path: str, | ||
background_tasks: BackgroundTasks, | ||
settings: Settings = Depends(get_settings), | ||
state: AutoscalerState = Depends(get_state), | ||
vllm_manager: VLLMManager = Depends(get_vllm_manager), | ||
) -> StreamingResponse: | ||
"""Proxy requests to vLLM service, handling activation as needed.""" | ||
try: | ||
if not await vllm_manager.ensure_running(): | ||
raise HTTPException( | ||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, | ||
detail=f"vLLM service activation failed after {settings.activation_timeout}s", | ||
) | ||
|
||
vllm_manager.reset_inactivity_timer(background_tasks) | ||
|
||
if not state.http_client: | ||
raise HTTPException( | ||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, | ||
detail="HTTP client not initialized", | ||
) | ||
|
||
# Forward the request to vLLM | ||
url = f"{settings.vllm_url_base}/{path}" | ||
headers = dict(request.headers) | ||
headers.pop("host", None) # Remove host header to avoid conflicts | ||
|
||
vllm_response = await state.http_client.request( | ||
method=request.method, | ||
url=url, | ||
headers=headers, | ||
content=await request.body(), | ||
params=request.query_params, | ||
) | ||
|
||
return StreamingResponse( | ||
stream_response(vllm_response), | ||
status_code=vllm_response.status_code, | ||
headers=dict(vllm_response.headers), | ||
) | ||
|
||
except HTTPException: | ||
raise | ||
except Exception as e: | ||
logger.error(f"Error processing request: {e}", exc_info=True) | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Internal server error", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
from pydantic import Field | ||
from functools import cached_property | ||
|
||
|
||
class Settings(BaseSettings): | ||
"""Application settings with validation and documentation.""" | ||
|
||
model_config = SettingsConfigDict( | ||
env_file=".env", env_file_encoding="utf-8", extra="ignore" | ||
) | ||
|
||
vllm_service_host: str = Field( | ||
default="vllm-svc", description="Hostname of the vLLM service" | ||
) | ||
vllm_service_port: str = Field( | ||
default="8000", description="Port of the vLLM service" | ||
) | ||
vllm_deployment: str = Field( | ||
default="vllm", description="Name of the vLLM deployment" | ||
) | ||
kubernetes_namespace: str = Field( | ||
default="default", description="Kubernetes namespace for the vLLM deployment" | ||
) | ||
inactivity_timeout: int = Field( | ||
default=900, | ||
description="Timeout in seconds before scaling down due to inactivity", | ||
gt=0, | ||
) | ||
activation_timeout: int = Field( | ||
default=120, | ||
description="Timeout in seconds while waiting for vLLM to become ready", | ||
gt=0, | ||
) | ||
proxy_timeout: float = Field( | ||
default=30.0, description="Timeout in seconds for proxy requests", gt=0 | ||
) | ||
|
||
@cached_property | ||
def vllm_url_base(self) -> str: | ||
"""Base URL for the vLLM service.""" | ||
return f"http://{self.vllm_service_host}:{self.vllm_service_port}" | ||
|
||
@cached_property | ||
def kubectl_base_cmd(self) -> str: | ||
"""Base kubectl command with namespace.""" | ||
return f"kubectl -n {self.kubernetes_namespace}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# app/dependencies.py | ||
from fastapi import Depends | ||
from .config import Settings | ||
from .types import AutoscalerState | ||
from .kubernetes import KubeCommand | ||
from .vllm import VLLMManager | ||
|
||
|
||
def get_settings() -> Settings: | ||
return Settings() | ||
|
||
|
||
def get_state() -> AutoscalerState: | ||
return AutoscalerState() | ||
|
||
|
||
def get_kube(settings: Settings = Depends(get_settings)) -> KubeCommand: | ||
return KubeCommand(settings) | ||
|
||
|
||
def get_vllm_manager( | ||
settings: Settings = Depends(get_settings), | ||
state: AutoscalerState = Depends(get_state), | ||
kube: KubeCommand = Depends(get_kube), | ||
) -> VLLMManager: | ||
return VLLMManager(settings, state, kube) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import asyncio | ||
import subprocess | ||
import logging | ||
from .types import PodPhase | ||
from .config import Settings | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class KubeCommand: | ||
"""Kubectl command builder and executor.""" | ||
|
||
def __init__(self, settings: Settings): | ||
self.settings = settings | ||
|
||
async def execute(self, cmd: str) -> tuple[bool, str]: | ||
"""Execute a kubectl command and return success status and output.""" | ||
full_cmd = f"{self.settings.kubectl_base_cmd} {cmd}" | ||
try: | ||
process = await asyncio.create_subprocess_shell( | ||
full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE | ||
) | ||
stdout, stderr = await process.communicate() | ||
|
||
success = process.returncode == 0 | ||
output = stdout.decode().strip() if success else stderr.decode().strip() | ||
if not success: | ||
logger.error(f"kubectl command failed: {output}") | ||
return success, output | ||
except Exception as e: | ||
logger.error(f"kubectl command failed: {e}") | ||
return False, str(e) | ||
|
||
async def get_pod_phase(self) -> PodPhase: | ||
"""Get the phase of the vLLM pod.""" | ||
success, output = await self.execute( | ||
"get pods -l app=vllm -o jsonpath='{.items[0].status.phase}'" | ||
) | ||
try: | ||
return PodPhase(output) if success and output else PodPhase.UNKNOWN | ||
except ValueError: | ||
logger.warning(f"Unknown pod phase: {output}") | ||
return PodPhase.UNKNOWN | ||
|
||
async def scale_deployment(self, replicas: int) -> bool: | ||
"""Scale vLLM deployment to specified replicas.""" | ||
if replicas < 0: | ||
logger.error(f"Invalid replica count: {replicas}") | ||
return False | ||
|
||
success, output = await self.execute( | ||
f"scale deployment {self.settings.vllm_deployment} --replicas={replicas}" | ||
) | ||
if success: | ||
logger.info(f"Successfully scaled deployment to {replicas} replicas") | ||
return success | ||
|
||
async def get_replicas(self) -> tuple[int, int]: | ||
"""Get current and desired replica counts.""" | ||
cmd = ( | ||
f"get deployment {self.settings.vllm_deployment} " | ||
"-o jsonpath='{.status.replicas} {.spec.replicas}'" | ||
) | ||
success, output = await self.execute(cmd) | ||
if success and output: | ||
try: | ||
current, desired = map(int, output.split()) | ||
return current, desired | ||
except ValueError: | ||
logger.error(f"Failed to parse replica counts: {output}") | ||
return -1, -1 |
Oops, something went wrong.