Skip to content

Commit

Permalink
Add Kubernetes autoscaler with configuration and state management
Browse files Browse the repository at this point in the history
  • Loading branch information
Sawyer committed Nov 4, 2024
1 parent 8e275e4 commit 5114545
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 224 deletions.
Empty file.
135 changes: 135 additions & 0 deletions k8s-autoscaler/k8s_autoscaler/api/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from fastapi import (
APIRouter,
Request,
Response,
BackgroundTasks,
HTTPException,
status,
Depends,
)
from fastapi.responses import StreamingResponse, JSONResponse
import httpx
import logging
import time
from typing import AsyncGenerator
from ..types import AutoscalerState, PodPhase
from ..config import Settings
from ..kubernetes import KubeCommand
from ..vllm import VLLMManager
from ..dependencies import get_settings, get_state, get_kube, get_vllm_manager

logger = logging.getLogger(__name__)

router = APIRouter()


async def stream_response(response: httpx.Response) -> AsyncGenerator[bytes, None]:
    """Yield the upstream vLLM response body chunk by chunk.

    By the time this generator runs, StreamingResponse has already sent the
    status line and headers, so raising HTTPException here could never turn
    the reply into a 502 — instead the error is logged and the stream ends
    early, which aborts the client connection mid-body.
    """
    try:
        async for chunk in response.aiter_bytes():
            yield chunk
    except httpx.HTTPError as e:
        # Too late to change the status code; log and terminate the body.
        logger.error(f"Error streaming response: {e}")
    finally:
        # Always release the upstream connection back to the pool.
        await response.aclose()


@router.get("/health")
async def health_check(
    kube: KubeCommand = Depends(get_kube), state: AutoscalerState = Depends(get_state)
):
    """Report autoscaler liveness plus current vLLM pod phase and replica counts."""
    pod_phase = await kube.get_pod_phase()
    current, desired = await kube.get_replicas()
    # Render the last recorded activity timestamp in local time.
    last_seen = time.strftime(
        "%Y-%m-%d %H:%M:%S", time.localtime(state.last_activity)
    )
    return {
        "status": "healthy",
        "vllm_status": pod_phase,
        "vllm_running": pod_phase == PodPhase.RUNNING,
        "current_replicas": current,
        "desired_replicas": desired,
        "last_activity": last_seen,
    }


@router.post("/scale/{replicas}")
async def scale(
    replicas: int,
    background_tasks: BackgroundTasks,
    kube: KubeCommand = Depends(get_kube),
    vllm_manager: VLLMManager = Depends(get_vllm_manager),
) -> JSONResponse:
    """Manually scale the vLLM deployment to the requested replica count."""
    # Guard clause: reject negative counts before touching the cluster.
    if replicas < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Replica count must be non-negative",
        )

    scaled = await kube.scale_deployment(replicas)
    if not scaled:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to scale deployment",
        )

    # Scaling up (re)arms the inactivity-based scale-down timer.
    if replicas > 0:
        vllm_manager.reset_inactivity_timer(background_tasks)

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"message": f"Scaling deployment to {replicas} replicas"},
    )


@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy_request(
    request: Request,
    path: str,
    background_tasks: BackgroundTasks,
    settings: Settings = Depends(get_settings),
    state: AutoscalerState = Depends(get_state),
    vllm_manager: VLLMManager = Depends(get_vllm_manager),
) -> StreamingResponse:
    """Proxy requests to vLLM service, handling activation as needed.

    Ensures the deployment is running, forwards the request, and relays the
    upstream body to the caller as a true stream (the previous implementation
    used `http_client.request(...)`, which buffers the entire upstream body
    before the "stream" starts).
    """
    try:
        if not await vllm_manager.ensure_running():
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"vLLM service activation failed after {settings.activation_timeout}s",
            )

        # Any proxied traffic counts as activity for the scale-down timer.
        vllm_manager.reset_inactivity_timer(background_tasks)

        if not state.http_client:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="HTTP client not initialized",
            )

        # Forward the request to vLLM.
        url = f"{settings.vllm_url_base}/{path}"
        headers = dict(request.headers)
        headers.pop("host", None)  # Remove host header to avoid conflicts

        # stream=True defers body download so large/SSE responses are relayed
        # chunk by chunk instead of being buffered in memory first.
        upstream_request = state.http_client.build_request(
            method=request.method,
            url=url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )
        vllm_response = await state.http_client.send(upstream_request, stream=True)

        async def relay() -> AsyncGenerator[bytes, None]:
            """Yield upstream chunks, always closing the upstream response."""
            try:
                async for chunk in vllm_response.aiter_bytes():
                    yield chunk
            finally:
                await vllm_response.aclose()

        # aiter_bytes() yields *decoded* content, so the upstream framing
        # headers no longer describe the bytes we send; drop them and let
        # the server re-frame the response.
        response_headers = {
            key: value
            for key, value in vllm_response.headers.items()
            if key.lower()
            not in ("content-length", "content-encoding", "transfer-encoding", "connection")
        }

        return StreamingResponse(
            relay(),
            status_code=vllm_response.status_code,
            headers=response_headers,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing request: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )
47 changes: 47 additions & 0 deletions k8s-autoscaler/k8s_autoscaler/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from functools import cached_property


class Settings(BaseSettings):
    """Runtime configuration for the autoscaler.

    Values are read from the environment (optionally via a ``.env`` file);
    unrecognized environment variables are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    # --- vLLM service / deployment identity -------------------------------
    vllm_service_host: str = Field(
        default="vllm-svc", description="Hostname of the vLLM service"
    )
    vllm_service_port: str = Field(
        default="8000", description="Port of the vLLM service"
    )
    vllm_deployment: str = Field(
        default="vllm", description="Name of the vLLM deployment"
    )
    kubernetes_namespace: str = Field(
        default="default", description="Kubernetes namespace for the vLLM deployment"
    )

    # --- timeouts (all strictly positive) ---------------------------------
    inactivity_timeout: int = Field(
        default=900,
        gt=0,
        description="Timeout in seconds before scaling down due to inactivity",
    )
    activation_timeout: int = Field(
        default=120,
        gt=0,
        description="Timeout in seconds while waiting for vLLM to become ready",
    )
    proxy_timeout: float = Field(
        default=30.0, gt=0, description="Timeout in seconds for proxy requests"
    )

    @cached_property
    def vllm_url_base(self) -> str:
        """Base URL of the vLLM service, computed once per instance."""
        return f"http://{self.vllm_service_host}:{self.vllm_service_port}"

    @cached_property
    def kubectl_base_cmd(self) -> str:
        """``kubectl`` invocation prefix carrying the configured namespace."""
        return f"kubectl -n {self.kubernetes_namespace}"
26 changes: 26 additions & 0 deletions k8s-autoscaler/k8s_autoscaler/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# app/dependencies.py
from functools import lru_cache

from fastapi import Depends

from .config import Settings
from .kubernetes import KubeCommand
from .types import AutoscalerState
from .vllm import VLLMManager


@lru_cache
def get_settings() -> Settings:
    """Return the application Settings.

    Cached so the environment/.env file is parsed once per process instead
    of on every request that depends on settings.
    """
    return Settings()


# Process-wide singleton: request handlers read fields such as
# state.http_client and state.last_activity that are set outside the
# request cycle, so handing each request a fresh AutoscalerState would
# mean those writes are never observed.
_state = AutoscalerState()


def get_state() -> AutoscalerState:
    """Return the shared AutoscalerState singleton for this process."""
    return _state


def get_kube(settings: Settings = Depends(get_settings)) -> KubeCommand:
    """Build a kubectl command executor bound to the current settings."""
    kube = KubeCommand(settings)
    return kube


def get_vllm_manager(
    settings: Settings = Depends(get_settings),
    state: AutoscalerState = Depends(get_state),
    kube: KubeCommand = Depends(get_kube),
) -> VLLMManager:
    """Assemble a VLLMManager from its settings, state, and kubectl deps."""
    manager = VLLMManager(settings, state, kube)
    return manager
71 changes: 71 additions & 0 deletions k8s-autoscaler/k8s_autoscaler/kubernetes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
import subprocess
import logging
from .types import PodPhase
from .config import Settings

logger = logging.getLogger(__name__)


class KubeCommand:
    """Kubectl command builder and executor.

    All operations shell out to ``kubectl`` (prefixed with the configured
    namespace via ``Settings.kubectl_base_cmd``); no Kubernetes client
    library is used.
    """

    def __init__(self, settings: Settings):
        # Settings provide the namespace prefix and the deployment name.
        self.settings = settings

    async def execute(self, cmd: str) -> tuple[bool, str]:
        """Execute a kubectl command and return success status and output.

        Returns ``(success, output)`` where output is stripped stdout on
        success and stripped stderr on failure.

        NOTE(review): the command runs through the shell
        (create_subprocess_shell) with unquoted interpolation of ``cmd``
        and settings values — fine for trusted configuration, but confirm
        no user-controlled text can reach here.
        """
        full_cmd = f"{self.settings.kubectl_base_cmd} {cmd}"
        try:
            process = await asyncio.create_subprocess_shell(
                full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            stdout, stderr = await process.communicate()

            success = process.returncode == 0
            # On failure, kubectl writes its diagnostic to stderr.
            output = stdout.decode().strip() if success else stderr.decode().strip()
            if not success:
                logger.error(f"kubectl command failed: {output}")
            return success, output
        except Exception as e:
            # Spawn failures (e.g. kubectl binary missing) land here.
            logger.error(f"kubectl command failed: {e}")
            return False, str(e)

    async def get_pod_phase(self) -> PodPhase:
        """Get the phase of the vLLM pod.

        Queries the first pod labelled ``app=vllm``; a kubectl failure,
        empty output, or an unrecognized phase string all map to
        PodPhase.UNKNOWN.
        """
        success, output = await self.execute(
            "get pods -l app=vllm -o jsonpath='{.items[0].status.phase}'"
        )
        try:
            return PodPhase(output) if success and output else PodPhase.UNKNOWN
        except ValueError:
            # kubectl returned a phase not covered by the PodPhase enum.
            logger.warning(f"Unknown pod phase: {output}")
            return PodPhase.UNKNOWN

    async def scale_deployment(self, replicas: int) -> bool:
        """Scale vLLM deployment to specified replicas.

        Returns True when kubectl accepted the scale request; negative
        replica counts are rejected up front without calling kubectl.
        """
        if replicas < 0:
            logger.error(f"Invalid replica count: {replicas}")
            return False

        success, output = await self.execute(
            f"scale deployment {self.settings.vllm_deployment} --replicas={replicas}"
        )
        if success:
            logger.info(f"Successfully scaled deployment to {replicas} replicas")
        return success

    async def get_replicas(self) -> tuple[int, int]:
        """Get current and desired replica counts.

        Returns ``(current, desired)``, or ``(-1, -1)`` when kubectl fails
        or its output cannot be parsed as two integers (e.g. when
        ``.status.replicas`` is absent while the deployment is scaled to
        zero, the jsonpath yields a single value).
        """
        cmd = (
            f"get deployment {self.settings.vllm_deployment} "
            "-o jsonpath='{.status.replicas} {.spec.replicas}'"
        )
        success, output = await self.execute(cmd)
        if success and output:
            try:
                current, desired = map(int, output.split())
                return current, desired
            except ValueError:
                logger.error(f"Failed to parse replica counts: {output}")
        return -1, -1
Loading

0 comments on commit 5114545

Please sign in to comment.