From b33bb03df6c93034d27db2d3c247c058c2acf059 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Thu, 18 Jul 2024 22:56:56 +0200 Subject: [PATCH] feat: OpenAI Compat Options - to allow setting a different embeddings endpoint, e.g. for Ollama --- embed_compat.go | 2 +- embed_openai.go | 61 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/embed_compat.go b/embed_compat.go index 18cada5..234011a 100644 --- a/embed_compat.go +++ b/embed_compat.go @@ -83,5 +83,5 @@ func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion if apiVersion == "" { apiVersion = azureDefaultAPIVersion } - return newEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion}) + return NewEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, WithOpenAICompatHeaders(map[string]string{"api-key": apiKey}), WithOpenAICompatQueryParams(map[string]string{"api-version": apiVersion})) } diff --git a/embed_openai.go b/embed_openai.go index de3de82..007040f 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "sync" ) @@ -52,16 +53,6 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md // - etc. // -// The `normalized` parameter indicates whether the vectors returned by the embedding -// model are already normalized, as is the case for OpenAI's and Mistral's models. -// The flag is optional. If it's nil, it will be autodetected on the first request -// (which bears a small risk that the vector just happens to have a length of 1). -func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { - return newEmbeddingFuncOpenAICompat(baseURL, apiKey, model, normalized, nil, nil) -} - -// newEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text -// using an OpenAI compatible API. // It offers options to set request headers and query parameters // e.g. to pass the `api-key` header and the `api-version` query parameter for Azure OpenAI. // @@ -69,12 +60,17 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo // model are already normalized, as is the case for OpenAI's and Mistral's models. // The flag is optional. If it's nil, it will be autodetected on the first request // (which bears a small risk that the vector just happens to have a length of 1). -func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc { +func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, opts ...OpenAICompatOption) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, // and it might have to be a long timeout, depending on the text length. client := &http.Client{} + cfg := DefaultOpenAICompatOptions() + for _, opt := range opts { + opt(cfg) + } + var checkedNormalized bool checkNormalized := sync.Once{} @@ -88,9 +84,14 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo return nil, fmt.Errorf("couldn't marshal request body: %w", err) } + fullURL, err := url.JoinPath(baseURL, cfg.EmbeddingsEndpoint) + if err != nil { + return nil, fmt.Errorf("couldn't join base URL and endpoint: %w", err) + } + // Create the request. Creating it with context is important for a timeout // to be possible, because the client is configured without a timeout. - req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody)) + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewBuffer(reqBody)) if err != nil { return nil, fmt.Errorf("couldn't create request: %w", err) } @@ -98,13 +99,13 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo req.Header.Set("Authorization", "Bearer "+apiKey) // Add headers - for k, v := range headers { + for k, v := range cfg.Headers { req.Header.Add(k, v) } // Add query parameters q := req.URL.Query() - for k, v := range queryParams { + for k, v := range cfg.QueryParams { q.Add(k, v) } req.URL.RawQuery = q.Encode() @@ -158,3 +159,35 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo return v, nil } } + +type OpenAICompatOptions struct { + EmbeddingsEndpoint string + Headers map[string]string + QueryParams map[string]string +} + +type OpenAICompatOption func(*OpenAICompatOptions) + +func WithOpenAICompatEmbeddingsEndpointOverride(endpoint string) OpenAICompatOption { + return func(o *OpenAICompatOptions) { + o.EmbeddingsEndpoint = endpoint + } +} + +func WithOpenAICompatHeaders(headers map[string]string) OpenAICompatOption { + return func(o *OpenAICompatOptions) { + o.Headers = headers + } +} + +func WithOpenAICompatQueryParams(queryParams map[string]string) OpenAICompatOption { + return func(o *OpenAICompatOptions) { + o.QueryParams = queryParams + } +} + +func DefaultOpenAICompatOptions() *OpenAICompatOptions { + return &OpenAICompatOptions{ + EmbeddingsEndpoint: "/embeddings", + } +}