diff --git a/index.yaml b/index.yaml index aae8e4c2..65fb7c20 100644 --- a/index.yaml +++ b/index.yaml @@ -96,5 +96,7 @@ modelProviders: reference: ./groq-model-provider voyage-model-provider: reference: ./voyage-model-provider + vllm-model-provider: + reference: ./vllm-model-provider anthropic-bedrock-model-provider: reference: ./anthropic-bedrock-model-provider diff --git a/vllm-model-provider/go.mod b/vllm-model-provider/go.mod new file mode 100644 index 00000000..04ea51d8 --- /dev/null +++ b/vllm-model-provider/go.mod @@ -0,0 +1,5 @@ +module github.com/obot-platform/tools/vllm-model-provider + +go 1.23.4 + +require github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed diff --git a/vllm-model-provider/go.sum b/vllm-model-provider/go.sum new file mode 100644 index 00000000..1d96f708 --- /dev/null +++ b/vllm-model-provider/go.sum @@ -0,0 +1,2 @@ +github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed h1:qMHm0IYpKgmw4KHX76RMB/duSICxo7IZuimPCKb0qG4= +github.com/gptscript-ai/chat-completion-client v0.0.0-20241216203633-5c0178fb89ed/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= diff --git a/vllm-model-provider/main.go b/vllm-model-provider/main.go new file mode 100644 index 00000000..e13fd1c8 --- /dev/null +++ b/vllm-model-provider/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "os" + + "github.com/obot-platform/tools/vllm-model-provider/server" +) + +func main() { + apiKey := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_API_KEY") + if apiKey == "" { + fmt.Println("OBOT_VLLM_MODEL_PROVIDER_API_KEY environment variable not set") + os.Exit(1) + } + + endpoint := os.Getenv("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT") + if endpoint == "" { + fmt.Println("OBOT_VLLM_MODEL_PROVIDER_ENDPOINT environment variable not set") + os.Exit(1) + } + + port := os.Getenv("PORT") + if port == "" { + port = "8000" + } + + if err := server.Run(apiKey, endpoint, port); err != nil { + panic(err) + } +} diff --git a/vllm-model-provider/server/server.go b/vllm-model-provider/server/server.go new file mode 100644 index 00000000..802d406c --- /dev/null +++ b/vllm-model-provider/server/server.go @@ -0,0 +1,120 @@ +package server + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httputil" + "net/url" + "path" + + openai "github.com/gptscript-ai/chat-completion-client" +) + +func Run(apiKey, endpointStr, port string) error { + // Parse the endpoint URL + endpoint, err := url.Parse(endpointStr) + if err != nil { + return fmt.Errorf("Invalid endpoint URL %q: %w", endpointStr, err) + } + + if endpoint.Scheme == "" { + if endpoint.Hostname() == "localhost" || endpoint.Hostname() == "127.0.0.1" { + endpoint.Scheme = "http" + } else { + endpoint.Scheme = "https" + } + } + + mux := http.NewServeMux() + + s := &server{ + apiKey: apiKey, + port: port, + endpoint: endpoint, + } + + mux.HandleFunc("/{$}", s.healthz) + mux.Handle("GET /v1/models", &httputil.ReverseProxy{ + Director: s.proxy, + ModifyResponse: s.rewriteModelsResponse, + }) + mux.Handle("/{path...}", &httputil.ReverseProxy{ + Director: s.proxy, + }) + + httpServer := &http.Server{ + Addr: "127.0.0.1:" + port, + Handler: mux, + } + + if err := httpServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return err + } + + return nil +} + +type server struct { + apiKey, port string + endpoint *url.URL +} + +func (s *server) healthz(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("http://127.0.0.1:" + s.port)) +} + +func (s *server) rewriteModelsResponse(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return nil + } + + originalBody := resp.Body + defer originalBody.Close() + + if resp.Header.Get("Content-Encoding") == "gzip" { + var err error + originalBody, err = gzip.NewReader(originalBody) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer originalBody.Close() + resp.Header.Del("Content-Encoding") + } + + var models openai.ModelsList + if err := json.NewDecoder(originalBody).Decode(&models); err != nil { + return fmt.Errorf("failed to decode models response: %w, %d, %v", err, resp.StatusCode, resp.Header) + } + + // Set all models as LLM + for i, model := range models.Models { + if model.Metadata == nil { + model.Metadata = make(map[string]string) + } + model.Metadata["usage"] = "llm" + models.Models[i] = model + } + + b, err := json.Marshal(models) + if err != nil { + return fmt.Errorf("failed to marshal models response: %w", err) + } + + resp.Body = io.NopCloser(bytes.NewReader(b)) + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(b))) + return nil +} + +func (s *server) proxy(req *http.Request) { + u := *s.endpoint + u.Path = path.Join(u.Path, req.URL.Path) + req.URL = &u + req.Host = req.URL.Host + + req.Header.Set("Authorization", "Bearer "+s.apiKey) +} diff --git a/vllm-model-provider/tool.gpt b/vllm-model-provider/tool.gpt new file mode 100644 index 00000000..5f2f9fce --- /dev/null +++ b/vllm-model-provider/tool.gpt @@ -0,0 +1,11 @@ +Name: vLLM +Description: Model Provider for vLLM +Metadata: envVars: OBOT_VLLM_MODEL_PROVIDER_ENDPOINT,OBOT_VLLM_MODEL_PROVIDER_API_KEY +Model Provider: true +Credential: ../model-provider-credential as vllm-model-provider + +#!sys.daemon ${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool + +--- +!metadata:*:icon +/admin/assets/vllm-logo.svg \ No newline at end of file