feat: add vLLM model provider (#292)
Users can now use vLLM as a model provider. Tested with Llama 3.2 3B and with Rubra models.
sanjay920 authored Dec 20, 2024
1 parent c323717 commit a0c5e92
Showing 6 changed files with 171 additions and 0 deletions.
2 changes: 2 additions & 0 deletions index.yaml
@@ -96,5 +96,7 @@ modelProviders:
    reference: ./groq-model-provider
  voyage-model-provider:
    reference: ./voyage-model-provider
  vllm-model-provider:
    reference: ./vllm-model-provider
  anthropic-bedrock-model-provider:
    reference: ./anthropic-bedrock-model-provider
5 changes: 5 additions & 0 deletions vllm-model-provider/go.mod
@@ -0,0 +1,5 @@
module github.com/obot-platform/tools/vllm-model-provider

go 1.23.4

require github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed
2 changes: 2 additions & 0 deletions vllm-model-provider/go.sum
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed h1:qMHm0IYpKgmw4KHX76RMB/duSICxo7IZuimPCKb0qG4=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
31 changes: 31 additions & 0 deletions vllm-model-provider/main.go
@@ -0,0 +1,31 @@
package main

import (
    "fmt"
    "os"

    "github.com/obot-platform/tools/vllm-model-provider/server"
)

func main() {
    apiKey := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_API_KEY")
    if apiKey == "" {
        fmt.Println("OBOT_VLLM_MODEL_PROVIDER_API_KEY environment variable not set")
        os.Exit(1)
    }

    endpoint := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT")
    if endpoint == "" {
        fmt.Println("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT environment variable not set")
        os.Exit(1)
    }

    port := os.Getenv("PORT")
    if port == "" {
        port = "8000"
    }

    if err := server.Run(apiKey, endpoint, port); err != nil {
        panic(err)
    }
}
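
For a quick local smoke test, a minimal sketch of driving this entry point directly (server.Run, the env var names, and the PORT default come from this diff; the API key, endpoint, and listen port below are placeholders and assume a vLLM server already listening on localhost:8000):

// smoke_test_sketch.go -- illustrative only, not part of this commit
package main

import "github.com/obot-platform/tools/vllm-model-provider/server"

func main() {
    // Placeholder values; in Obot these arrive via the
    // OBOT_VLLM_MODEL_PROVIDER_API_KEY, OBOT_VLLM_MODEL_PROVIDER_ENDPOINT,
    // and PORT environment variables read by the real main.go.
    // Listening on 9090 avoids colliding with vLLM's own default port 8000.
    if err := server.Run("token-abc123", "http://localhost:8000", "9090"); err != nil {
        panic(err)
    }
}
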
120 changes: 120 additions & 0 deletions vllm-model-provider/server/server.go
@@ -0,0 +1,120 @@
package server

import (
    "bytes"
    "compress/gzip"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "net/http/httputil"
    "net/url"
    "path"

    openai "github.com/gptscript-ai/chat-completion-client"
)

func Run(apiKey, endpointStr, port string) error {
    // Parse the endpoint URL
    endpoint, err := url.Parse(endpointStr)
    if err != nil {
        return fmt.Errorf("Invalid endpoint URL %q: %w", endpointStr, err)
    }

    // If no scheme was given, default to http for local endpoints and https otherwise.
    if endpoint.Scheme == "" {
        if endpoint.Hostname() == "localhost" || endpoint.Hostname() == "127.0.0.1" {
            endpoint.Scheme = "http"
        } else {
            endpoint.Scheme = "https"
        }
    }

    mux := http.NewServeMux()

    s := &server{
        apiKey:   apiKey,
        port:     port,
        endpoint: endpoint,
    }

    // Health check on the root path; /v1/models gets its response rewritten,
    // everything else is proxied through unchanged.
    mux.HandleFunc("/{$}", s.healthz)
    mux.Handle("GET /v1/models", &httputil.ReverseProxy{
        Director:       s.proxy,
        ModifyResponse: s.rewriteModelsResponse,
    })
    mux.Handle("/{path...}", &httputil.ReverseProxy{
        Director: s.proxy,
    })

    httpServer := &http.Server{
        Addr:    "127.0.0.1:" + port,
        Handler: mux,
    }

    if err := httpServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
        return err
    }

    return nil
}

type server struct {
    apiKey, port string
    endpoint     *url.URL
}

func (s *server) healthz(w http.ResponseWriter, _ *http.Request) {
    _, _ = w.Write([]byte("http://127.0.0.1:" + s.port))
}

func (s *server) rewriteModelsResponse(resp *http.Response) error {
    if resp.StatusCode != http.StatusOK {
        return nil
    }

    originalBody := resp.Body
    defer originalBody.Close()

    // Transparently decompress gzip-encoded responses before rewriting them.
    if resp.Header.Get("Content-Encoding") == "gzip" {
        var err error
        originalBody, err = gzip.NewReader(originalBody)
        if err != nil {
            return fmt.Errorf("failed to create gzip reader: %w", err)
        }
        defer originalBody.Close()
        resp.Header.Del("Content-Encoding")
    }

    var models openai.ModelsList
    if err := json.NewDecoder(originalBody).Decode(&models); err != nil {
        return fmt.Errorf("failed to decode models response: %w, %d, %v", err, resp.StatusCode, resp.Header)
    }

    // Set all models as LLM
    for i, model := range models.Models {
        if model.Metadata == nil {
            model.Metadata = make(map[string]string)
        }
        model.Metadata["usage"] = "llm"
        models.Models[i] = model
    }

    b, err := json.Marshal(models)
    if err != nil {
        return fmt.Errorf("failed to marshal models response: %w", err)
    }

    resp.Body = io.NopCloser(bytes.NewReader(b))
    resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(b)))
    return nil
}

// proxy rewrites the incoming request so it targets the configured vLLM
// endpoint and carries the provider's API key.
func (s *server) proxy(req *http.Request) {
    u := *s.endpoint
    u.Path = path.Join(u.Path, req.URL.Path)
    req.URL = &u
    req.Host = req.URL.Host

    req.Header.Set("Authorization", "Bearer "+s.apiKey)
}
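
To illustrate the models-list rewrite, a hedged client sketch against a running provider (the usage=llm tagging comes from rewriteModelsResponse above; the 9090 listen port is a placeholder, not part of this commit):

// Illustrative client sketch: list models through the provider and print the
// rewritten response body. Assumes the provider is listening on 127.0.0.1:9090.
package main

import (
    "fmt"
    "io"
    "net/http"
)

func main() {
    // The provider proxies this request to <endpoint>/v1/models on the vLLM
    // server, adding its own Authorization header, then rewriteModelsResponse
    // tags every returned model with usage=llm.
    resp, err := http.Get("http://127.0.0.1:9090/v1/models")
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        panic(err)
    }
    // Each model entry in the JSON body should now carry "usage": "llm" metadata.
    fmt.Println(string(body))
}
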
11 changes: 11 additions & 0 deletions vllm-model-provider/tool.gpt
@@ -0,0 +1,11 @@
Name: vLLM
Description: Model Provider for vLLM
Metadata: envVars: OBOT_VLLM_MODEL_PROVIDER_ENDPOINT,OBOT_VLLM_MODEL_PROVIDER_API_KEY
Model Provider: true
Credential: ../model-provider-credential as vllm-model-provider

#!sys.daemon ${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool

---
!metadata:*:icon
/admin/assets/vllm-logo.svg
