Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add vLLM model provider #292

Merged
merged 11 commits into from
Dec 20, 2024
2 changes: 2 additions & 0 deletions index.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,7 @@ modelProviders:
reference: ./groq-model-provider
voyage-model-provider:
reference: ./voyage-model-provider
vllm-model-provider:
reference: ./vllm-model-provider
anthropic-bedrock-model-provider:
reference: ./anthropic-bedrock-model-provider
5 changes: 5 additions & 0 deletions vllm-model-provider/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module github.com/obot-platform/tools/vllm-model-provider

go 1.23.4

require github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed
2 changes: 2 additions & 0 deletions vllm-model-provider/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed h1:qMHm0IYpKgmw4KHX76RMB/duSICxo7IZuimPCKb0qG4=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
31 changes: 31 additions & 0 deletions vllm-model-provider/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"fmt"
"os"

"github.com/obot-platform/tools/vllm-model-provider/server"
)

func main() {
apiKey := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_API_KEY")
if apiKey == "" {
fmt.Println("OBOT_VLLM_MODEL_PROVIDER_API_KEY environment variable not set")
os.Exit(1)
}

endpoint := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT")
if endpoint == "" {
fmt.Println("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT environment variable not set")
os.Exit(1)
}

port := os.Getenv("PORT")
if port == "" {
port = "8000"
}

if err := server.Run(apiKey, endpoint, port); err != nil {
panic(err)
}
}
120 changes: 120 additions & 0 deletions vllm-model-provider/server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package server

import (
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"path"

openai "github.com/gptscript-ai/chat-completion-client"
)

func Run(apiKey, endpointStr, port string) error {
// Parse the endpoint URL
endpoint, err := url.Parse(endpointStr)
if err != nil {
return fmt.Errorf("Invalid endpoint URL %q: %w", endpointStr, err)
}

if endpoint.Scheme == "" {
if endpoint.Hostname() == "localhost" || endpoint.Hostname() == "127.0.0.1" {
endpoint.Scheme = "http"
} else {
endpoint.Scheme = "https"
}
}

mux := http.NewServeMux()

s := &server{
apiKey: apiKey,
port: port,
endpoint: endpoint,
}

mux.HandleFunc("/{$}", s.healthz)
mux.Handle("GET /v1/models", &httputil.ReverseProxy{
Director: s.proxy,
ModifyResponse: s.rewriteModelsResponse,
})
mux.Handle("/{path...}", &httputil.ReverseProxy{
Director: s.proxy,
})

httpServer := &http.Server{
Addr: "127.0.0.1:" + port,
Handler: mux,
}

if err := httpServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}

return nil
}

type server struct {
apiKey, port string
endpoint *url.URL
}

func (s *server) healthz(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("http://127.0.0.1:" + s.port))
}

func (s *server) rewriteModelsResponse(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
return nil
}

originalBody := resp.Body
defer originalBody.Close()

if resp.Header.Get("Content-Encoding") == "gzip" {
var err error
originalBody, err = gzip.NewReader(originalBody)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
}
defer originalBody.Close()
resp.Header.Del("Content-Encoding")
}

var models openai.ModelsList
if err := json.NewDecoder(originalBody).Decode(&models); err != nil {
return fmt.Errorf("failed to decode models response: %w, %d, %v", err, resp.StatusCode, resp.Header)
}

// Set all models as LLM
for i, model := range models.Models {
if model.Metadata == nil {
model.Metadata = make(map[string]string)
}
model.Metadata["usage"] = "llm"
models.Models[i] = model
}

b, err := json.Marshal(models)
if err != nil {
return fmt.Errorf("failed to marshal models response: %w", err)
}

resp.Body = io.NopCloser(bytes.NewReader(b))
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(b)))
return nil
}

func (s *server) proxy(req *http.Request) {
u := *s.endpoint
u.Path = path.Join(u.Path, req.URL.Path)
req.URL = &u
req.Host = req.URL.Host

req.Header.Set("Authorization", "Bearer "+s.apiKey)
}
11 changes: 11 additions & 0 deletions vllm-model-provider/tool.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Name: vLLM
Description: Model Provider for vLLM
Metadata: envVars: OBOT_VLLM_MODEL_PROVIDER_ENDPOINT,OBOT_VLLM_MODEL_PROVIDER_API_KEY
Model Provider: true
Credential: ../model-provider-credential as vllm-model-provider

#!sys.daemon ${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool

---
!metadata:*:icon
/admin/assets/vllm-logo.svg