From fcafe49697816c3677a6b8c91386df9860b49aed Mon Sep 17 00:00:00 2001
From: Luke Hinds
Date: Mon, 14 Oct 2024 07:59:51 +0100
Subject: [PATCH 1/2] Standard Prompt Format

This introduces a more customizable prompt format. It is now possible to
set a system context and other role-based messages, along with generation
parameters.

Closes: #7
---
 README.md                          | 81 ++++++++++++++++++++++++++++++
 examples/ollama/main.go            | 27 +++++-----
 examples/openai/main.go            | 17 +++++--
 pkg/backend/backend.go             | 43 ++++++++++++++--
 pkg/backend/ollama_backend.go      | 45 ++++++++++++-----
 pkg/backend/ollama_backend_test.go | 31 +++++++-----
 pkg/backend/openai_backend.go      | 59 +++++++++++++++-------
 pkg/backend/openai_backend_test.go | 13 ++++-
 8 files changed, 254 insertions(+), 62 deletions(-)

diff --git a/README.md b/README.md
index 8246ed7..763e433 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,87 @@ Should you wish, the docker-compose will automate the setup of the database.
 Best bet is to see `/examples/*` for reference, this explains how to use the
 library with examples for generation, embeddings and implementing RAG.
 
+There are currently two supported backends, Ollama and OpenAI, both with
+the ability to generate embeddings for RAG.
+
+## Ollama
+
+First, create a Backend instance:
+
+```go
+generationBackend := backend.NewOllamaBackend("http://localhost:11434", "llama3", time.Duration(10*time.Second))
+```
+
+Create a prompt:
+
+```go
+prompt := backend.NewPrompt().
+	AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
+	AddMessage("user", "What is love?").
+	SetParameters(backend.Parameters{
+		MaxTokens:   150,
+		Temperature: 0.7,
+		TopP:        0.9,
+	})
+```
+
+Generate a response:
+
+```go
+response, err := generationBackend.Generate(ctx, prompt)
+if err != nil {
+	log.Fatalf("Failed to generate response: %v", err)
+}
+```
+
+## OpenAI
+
+First, create a Backend instance:
+
+```go
+generationBackend := backend.NewOpenAIBackend("API_KEY", "gpt-3.5-turbo", 10*time.Second)
+```
+
+Create a prompt:
+
+```go
+prompt := backend.NewPrompt().
+	AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
+	AddMessage("user", "How much is too much?").
+	SetParameters(backend.Parameters{
+		MaxTokens:        150,
+		Temperature:      0.7,
+		TopP:             0.9,
+		FrequencyPenalty: 0.5,
+		PresencePenalty:  0.6,
+	})
+```
+
+Generate a response:
+
+```go
+response, err := generationBackend.Generate(ctx, prompt)
+if err != nil {
+	log.Fatalf("Failed to generate response: %v", err)
+}
+```
+
+## RAG
+
+To generate embeddings for RAG, use the `Embed` method available on both the
+Ollama and OpenAI backends.
+
+```go
+embedding, err := embeddingBackend.Embed(ctx, "Mickey mouse is a real human being")
+if err != nil {
+	log.Fatalf("Error generating embedding: %v", err)
+}
+log.Println("Embedding generated")
+```
+
+A database is also required; PostgreSQL with pgvector is supported. See `/examples/*`
+for reference.
+
 # 📝 Contributing
 
 We welcome contributions!
Please submit a pull request or raise an issue if diff --git a/examples/ollama/main.go b/examples/ollama/main.go index 0190683..377fc32 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -18,21 +18,12 @@ var ( ) func main() { - // Initialize Config - - // Select backends based on config - var embeddingBackend backend.Backend - var generationBackend backend.Backend - - // Choose the backend for embeddings based on the config - - embeddingBackend = backend.NewOllamaBackend(ollamaHost, ollamaEmbModel) + // Configure the Ollama backend for both embedding and generation + embeddingBackend := backend.NewOllamaBackend(ollamaHost, ollamaEmbModel, time.Duration(10*time.Second)) log.Printf("Embedding backend LLM: %s", ollamaEmbModel) - // Choose the backend for generation based on the config - generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel) - + generationBackend := backend.NewOllamaBackend(ollamaHost, ollamaGenModel, time.Duration(10*time.Second)) log.Printf("Generation backend: %s", ollamaGenModel) // Initialize the vector database @@ -84,10 +75,18 @@ func main() { // Augment the query with retrieved context augmentedQuery := db.CombineQueryWithContext(query, retrievedDocs) - log.Printf("LLM Prompt: %s", query) + + prompt := backend.NewPrompt(). + AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible."). + AddMessage("user", augmentedQuery). + SetParameters(backend.Parameters{ + MaxTokens: 150, // Supported by LLaMa + Temperature: 0.7, // Supported by LLaMa + TopP: 0.9, // Supported by LLaMa + }) // Generate response with the specified generation backend - response, err := generationBackend.Generate(ctx, augmentedQuery) + response, err := generationBackend.Generate(ctx, prompt) if err != nil { log.Fatalf("Failed to generate response: %v", err) } diff --git a/examples/openai/main.go b/examples/openai/main.go index 9d41bcc..173cc28 100644 --- a/examples/openai/main.go +++ b/examples/openai/main.go @@ -30,12 +30,12 @@ func main() { // Choose the backend for embeddings based on the config - embeddingBackend = backend.NewOpenAIBackend(apiKey, openAIEmbModel) + embeddingBackend = backend.NewOpenAIBackend(apiKey, openAIEmbModel, 10*time.Second) log.Printf("Embedding backend LLM: %s", openAIEmbModel) // Choose the backend for generation based on the config - generationBackend = backend.NewOpenAIBackend(apiKey, openAIGenModel) + generationBackend = backend.NewOpenAIBackend(apiKey, openAIGenModel, 10*time.Second) log.Printf("Generation backend: %s", openAIGenModel) @@ -94,8 +94,19 @@ func main() { log.Printf("Augmented Query: %s", augmentedQuery) + prompt := backend.NewPrompt(). + AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible."). + AddMessage("user", augmentedQuery). + SetParameters(backend.Parameters{ + MaxTokens: 150, + Temperature: 0.7, + TopP: 0.9, + FrequencyPenalty: 0.5, + PresencePenalty: 0.6, + }) + // Generate response with the specified generation backend - response, err := generationBackend.Generate(ctx, augmentedQuery) + response, err := generationBackend.Generate(ctx, prompt) if err != nil { log.Fatalf("Failed to generate response: %v", err) } diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index a47e8d1..0cf7e15 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -4,20 +4,57 @@ // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - package backend import "context" // Backend defines the interface for interacting with various LLM backends. type Backend interface { - Generate(ctx context.Context, prompt string) (string, error) + Generate(ctx context.Context, prompt *Prompt) (string, error) Embed(ctx context.Context, input string) ([]float32, error) } + +// Message represents a single role-based message in the conversation. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Parameters defines generation settings for LLM completions. +type Parameters struct { + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + FrequencyPenalty float64 `json:"frequency_penalty"` + PresencePenalty float64 `json:"presence_penalty"` +} + +// Prompt represents a structured prompt with role-based messages and parameters. +type Prompt struct { + Messages []Message `json:"messages"` + Parameters Parameters `json:"parameters"` +} + +// NewPrompt creates and returns a new Prompt. +func NewPrompt() *Prompt { + return &Prompt{} +} + +// AddMessage adds a message with a specific role to the prompt. +func (p *Prompt) AddMessage(role, content string) *Prompt { + p.Messages = append(p.Messages, Message{Role: role, Content: content}) + return p +} + +// SetParameters sets the generation parameters for the prompt. +func (p *Prompt) SetParameters(params Parameters) *Prompt { + p.Parameters = params + return p +} diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 36d8b4b..15a7428 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -4,14 +4,13 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - package backend import ( @@ -59,23 +58,45 @@ type OllamaEmbeddingResponse struct { } // NewOllamaBackend creates a new OllamaBackend instance. -func NewOllamaBackend(baseURL, model string) *OllamaBackend { +func NewOllamaBackend(baseURL, model string, timeout time.Duration) *OllamaBackend { return &OllamaBackend{ BaseURL: baseURL, Model: model, Client: &http.Client{ - Timeout: defaultTimeout, + Timeout: timeout, }, } } -// Generate produces a response from the Ollama API based on the given prompt. -func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, error) { +// Generate produces a response from the Ollama API based on the given structured prompt. +// +// Parameters: +// - ctx: The context for the API request, which can be used for cancellation. +// - prompt: A structured prompt containing messages and parameters. 
+// +// Returns: +// - A string containing the generated response from the Ollama model. +// - An error if the API request fails or if there's an issue processing the response. +func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, error) { url := o.BaseURL + generateEndpoint + + // Concatenate the messages into a single prompt string + var promptText string + for _, message := range prompt.Messages { + // Append role and content into one string (adjust formatting as needed) + promptText += message.Role + ": " + message.Content + "\n" + } + + // Construct the request body with concatenated prompt reqBody := map[string]interface{}{ - "model": o.Model, - "prompt": prompt, - "stream": false, + "model": o.Model, + "prompt": promptText, // Use concatenated string + "max_tokens": prompt.Parameters.MaxTokens, + "temperature": prompt.Parameters.Temperature, + "top_p": prompt.Parameters.TopP, + "frequency_penalty": prompt.Parameters.FrequencyPenalty, + "presence_penalty": prompt.Parameters.PresencePenalty, + "stream": false, // Explicitly set stream to false } reqBodyBytes, err := json.Marshal(reqBody) @@ -95,8 +116,9 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er } defer resp.Body.Close() + bodyBytes, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) return "", fmt.Errorf( "failed to generate response from Ollama: "+ "status code %d, response: %s", @@ -105,7 +127,7 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er } var result Response - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.NewDecoder(bytes.NewBuffer(bodyBytes)).Decode(&result); err != nil { return "", fmt.Errorf("failed to decode response: %w", err) } @@ -114,7 +136,6 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er // Embed generates embeddings for the given input text using the Ollama API. func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) { - url := o.BaseURL + embedEndpoint reqBody := map[string]interface{}{ "model": o.Model, diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index e24cbe6..eec5bbf 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -43,7 +43,6 @@ func TestOllamaGenerate(t *testing.T) { } // Check Content-Type header - if r.Header.Get("Content-Type") != contentTypeJSON { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } @@ -54,15 +53,10 @@ func TestOllamaGenerate(t *testing.T) { t.Errorf("Failed to decode request body: %v", err) } - // Optional: Validate request body contents - if reqBody["model"] != "test-model" { - t.Errorf("Expected model 'test-model', got '%v'", reqBody["model"]) - } - if reqBody["prompt"] != "Hello, Ollama!" { - t.Errorf("Expected prompt 'Hello, Ollama!', got '%v'", reqBody["prompt"]) - } - if reqBody["stream"] != false { - t.Errorf("Expected stream false, got '%v'", reqBody["stream"]) + // Check that the "prompt" field is correctly passed + promptText, ok := reqBody["prompt"].(string) + if !ok || promptText == "" { + t.Errorf("Expected a valid prompt, got: %v", reqBody["prompt"]) } // Write the mock response @@ -81,8 +75,21 @@ func TestOllamaGenerate(t *testing.T) { } ctx := context.Background() - prompt := "Hello, Ollama!" - + promptMsg := "Hello, Ollama!" + + // Construct the prompt + prompt := NewPrompt(). 
+ AddMessage("system", "You are an AI assistant."). + AddMessage("user", promptMsg). + SetParameters(Parameters{ + MaxTokens: 150, + Temperature: 0.7, + TopP: 0.9, + FrequencyPenalty: 0.5, + PresencePenalty: 0.6, + }) + + // Call the Generate method response, err := backend.Generate(ctx, prompt) if err != nil { t.Fatalf("Generate returned error: %v", err) diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index f726a58..828a431 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -4,14 +4,13 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - package backend import ( @@ -21,6 +20,7 @@ import ( "fmt" "io" "net/http" + "time" ) // OpenAIBackend represents a backend for interacting with the OpenAI API. @@ -49,20 +49,28 @@ type OpenAIEmbeddingResponse struct { } `json:"usage"` } -// NewOpenAIBackend creates and returns a new OpenAIBackend instance. +// NewOpenAIBackend creates and returns a new OpenAIBackend instance with a custom timeout. // // Parameters: // - apiKey: The API key for authenticating with the OpenAI API. // - model: The name of the OpenAI model to use for generating responses. +// - timeout: The duration for the HTTP client timeout. If zero, the default timeout is used. // // Returns: -// - A pointer to a new OpenAIBackend instance configured with the provided API key and model. -func NewOpenAIBackend(apiKey, model string) *OpenAIBackend { +// - A pointer to a new OpenAIBackend instance configured with the provided API key, model, and timeout. +func NewOpenAIBackend(apiKey, model string, timeout time.Duration) *OpenAIBackend { + // Use defaultTimeout if the user passes 0 as the timeout value + if timeout == 0 { + timeout = defaultTimeout + } + return &OpenAIBackend{ - APIKey: apiKey, - Model: model, - HTTPClient: http.DefaultClient, - BaseURL: "https://api.openai.com", + APIKey: apiKey, + Model: model, + HTTPClient: &http.Client{ + Timeout: timeout, // Use the user-specified or default timeout here + }, + BaseURL: "https://api.openai.com", } } @@ -98,22 +106,37 @@ type OpenAIResponse struct { // Returns: // - A string containing the generated response from the OpenAI model. // - An error if the API request fails or if there's an issue processing the response. -func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, error) { + +// Generate sends a structured prompt to the OpenAI API and returns the generated response. +// +// Parameters: +// - ctx: The context for the API request, which can be used for cancellation. +// - prompt: A structured prompt containing messages and parameters. +// +// Returns: +// - A string containing the generated response from the OpenAI model. +// - An error if the API request fails or if there's an issue processing the response. 
+func (o *OpenAIBackend) Generate(ctx context.Context, prompt *Prompt) (string, error) {
+	timeoutCtx, cancel := context.WithTimeout(ctx, defaultTimeout)
+	defer cancel()
+
 	url := o.BaseURL + "/v1/chat/completions"
 	reqBody := map[string]interface{}{
-		"model": o.Model,
-		"messages": []map[string]string{
-			{"role": "user", "content": prompt},
-		},
+		"model":             o.Model,
+		"messages":          prompt.Messages,
+		"max_tokens":        prompt.Parameters.MaxTokens,
+		"temperature":       prompt.Parameters.Temperature,
+		"top_p":             prompt.Parameters.TopP,
+		"frequency_penalty": prompt.Parameters.FrequencyPenalty,
+		"presence_penalty":  prompt.Parameters.PresencePenalty,
 	}
 
 	reqBodyBytes, err := json.Marshal(reqBody)
 	if err != nil {
 		return "", fmt.Errorf("failed to marshal request body: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes))
+	req, err := http.NewRequestWithContext(timeoutCtx, "POST", url, bytes.NewBuffer(reqBodyBytes))
 	if err != nil {
 		return "", fmt.Errorf("failed to create request: %w", err)
 	}
@@ -151,6 +174,8 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, er
 //   - A slice of float32 values representing the embedding vector.
 //   - An error if the API request fails or if there's an issue processing the response.
 func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, error) {
+	timeoutCtx, cancel := context.WithTimeout(ctx, defaultTimeout)
+	defer cancel()
 	url := o.BaseURL + "/v1/embeddings"
 	reqBody := map[string]interface{}{
 		"model": o.Model,
@@ -162,7 +187,7 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro
 		return nil, fmt.Errorf("failed to marshal request body: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes))
+	req, err := http.NewRequestWithContext(timeoutCtx, "POST", url, bytes.NewBuffer(reqBodyBytes))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %w", err)
 	}
diff --git a/pkg/backend/openai_backend_test.go b/pkg/backend/openai_backend_test.go
index 93a5965..e088a2d 100644
--- a/pkg/backend/openai_backend_test.go
+++ b/pkg/backend/openai_backend_test.go
@@ -95,7 +95,18 @@ func TestGenerate(t *testing.T) {
 	}
 
 	ctx := context.Background()
-	prompt := "Hello, world!"
+	msgPrompt := "Hello, OpenAI!"
+
+	prompt := NewPrompt().
+		AddMessage("system", "You are an AI assistant.").
+		AddMessage("user", msgPrompt).
+ SetParameters(Parameters{ + MaxTokens: 150, + Temperature: 0.7, + TopP: 0.9, + FrequencyPenalty: 0.5, + PresencePenalty: 0.6, + }) response, err := backend.Generate(ctx, prompt) if err != nil { From e97d0407677530093babd73491dcbaec03e3def8 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Mon, 14 Oct 2024 20:19:33 +0100 Subject: [PATCH 2/2] Capture errors --- pkg/backend/ollama_backend.go | 10 ++++++++-- pkg/backend/openai_backend.go | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 15a7428..434b4db 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -116,7 +116,10 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, e } defer resp.Body.Close() - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf( @@ -160,7 +163,10 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "failed to generate embeddings from Ollama: "+ "status code %d, response: %s", diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index 828a431..aea6bcf 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -151,7 +151,10 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt *Prompt) (string, e defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } return "", fmt.Errorf("failed to generate response from OpenAI: "+ "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) } @@ -202,7 +205,10 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf("failed to generate embedding from OpenAI: "+ "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) }