Commit

Merge pull request #9 from StacklokLabs/std-prompt
Standard Prompt Format
lukehinds authored Oct 14, 2024
2 parents ce86e31 + e97d040 commit 25bad2e
Showing 8 changed files with 269 additions and 65 deletions.
81 changes: 81 additions & 0 deletions README.md
@@ -56,6 +56,87 @@ Should you wish, the docker-compose will automate the setup of the database.
Your best bet is to see `/examples/*` for reference; it explains how to use
the library, with examples for generation, embeddings, and implementing RAG.

Two backends are currently supported, Ollama and OpenAI, both of which can
generate embeddings for RAG.

## Ollama

First, create a Backend object:

```go
generationBackend := backend.NewOllamaBackend("http://localhost:11434", "llama3", time.Duration(10*time.Second))
```

Create a prompt:

```go
prompt := backend.NewPrompt().
AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
AddMessage("user", "What is love?").
SetParameters(backend.Parameters{
MaxTokens: 150,
Temperature: 0.7,
TopP: 0.9,
})
```

Generate a response:

```go
response, err := generationBackend.Generate(ctx, prompt)
if err != nil {
log.Fatalf("Failed to generate response: %v", err)
}
```
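
For completeness, the pieces above fit together as follows. This is a minimal sketch: the import path is a placeholder for this repository's module path, and the URL and model are just the example values used above.

```go
package main

import (
	"context"
	"log"
	"time"

	// Placeholder import path; substitute this repository's module path.
	"example.com/yourmodule/pkg/backend"
)

func main() {
	ctx := context.Background()

	// Ollama backend pointed at a local server and the llama3 model.
	generationBackend := backend.NewOllamaBackend("http://localhost:11434", "llama3", 10*time.Second)

	// Build a role-based prompt with generation parameters.
	prompt := backend.NewPrompt().
		AddMessage("system", "You are an AI assistant.").
		AddMessage("user", "What is love?").
		SetParameters(backend.Parameters{
			MaxTokens:   150,
			Temperature: 0.7,
			TopP:        0.9,
		})

	// Generate and print the response.
	response, err := generationBackend.Generate(ctx, prompt)
	if err != nil {
		log.Fatalf("Failed to generate response: %v", err)
	}
	log.Println(response)
}
```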

## OpenAI

First, create a Backend object:

```go
generationBackend := backend.NewOpenAIBackend("API_KEY", "gpt-3.5-turbo", 10*time.Second)
```
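
In practice you will usually read the key from the environment rather than hard-coding it. A small sketch, assuming `os` and `log` are imported; the `OPENAI_API_KEY` variable name is only a common convention, not something the library requires:

```go
// Read the OpenAI key from the environment (variable name is illustrative).
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
	log.Fatal("OPENAI_API_KEY is not set")
}
generationBackend := backend.NewOpenAIBackend(apiKey, "gpt-3.5-turbo", 10*time.Second)
```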

Create a prompt:

```go
prompt := backend.NewPrompt().
AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
AddMessage("user", "How much is too much?").
SetParameters(backend.Parameters{
MaxTokens: 150,
Temperature: 0.7,
TopP: 0.9,
FrequencyPenalty: 0.5,
PresencePenalty: 0.6,
})
```

Generate a response:

```go
response, err := generationBackend.Generate(ctx, prompt)
if err != nil {
log.Fatalf("Failed to generate response: %v", err)
}
```

## RAG

To generate embeddings for RAG, use the `Embed` method, which both the Ollama
and OpenAI backends implement.

```go
embedding, err := embeddingBackend.Embed(ctx, "Mickey mouse is a real human being")
if err != nil {
log.Fatalf("Error generating embedding: %v", err)
}
log.Println("Embedding generated")
```

A database is also required; PostgreSQL with pgvector is supported. See `/examples/*`
for reference.
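
Putting the pieces together, a typical RAG flow embeds the user's query, retrieves similar documents from the pgvector-backed store, combines them with the query, and generates a response. The sketch below follows `/examples/*`; `db.CombineQueryWithContext` appears in the examples, while the retrieval helper name is illustrative.

```go
// Embed the user's query.
queryEmbedding, err := embeddingBackend.Embed(ctx, query)
if err != nil {
	log.Fatalf("Error generating query embedding: %v", err)
}

// Retrieve the most relevant documents from the vector store
// (hypothetical helper; see /examples/* for the actual retrieval call).
retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding)
if err != nil {
	log.Fatalf("Error retrieving documents: %v", err)
}

// Combine the retrieved context with the original question.
augmentedQuery := db.CombineQueryWithContext(query, retrievedDocs)

// Ask the generation backend to answer using the augmented query.
prompt := backend.NewPrompt().
	AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
	AddMessage("user", augmentedQuery).
	SetParameters(backend.Parameters{
		MaxTokens:   150,
		Temperature: 0.7,
		TopP:        0.9,
	})

response, err := generationBackend.Generate(ctx, prompt)
if err != nil {
	log.Fatalf("Failed to generate response: %v", err)
}
log.Println(response)
```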

# 📝 Contributing

We welcome contributions! Please submit a pull request or raise an issue if
27 changes: 13 additions & 14 deletions examples/ollama/main.go
@@ -18,21 +18,12 @@ var (
)

func main() {
// Initialize Config

// Select backends based on config
var embeddingBackend backend.Backend
var generationBackend backend.Backend

// Choose the backend for embeddings based on the config

embeddingBackend = backend.NewOllamaBackend(ollamaHost, ollamaEmbModel)

// Configure the Ollama backend for both embedding and generation
embeddingBackend := backend.NewOllamaBackend(ollamaHost, ollamaEmbModel, time.Duration(10*time.Second))
log.Printf("Embedding backend LLM: %s", ollamaEmbModel)

// Choose the backend for generation based on the config
generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel)

generationBackend := backend.NewOllamaBackend(ollamaHost, ollamaGenModel, time.Duration(10*time.Second))
log.Printf("Generation backend: %s", ollamaGenModel)

// Initialize the vector database
@@ -84,10 +75,18 @@ func main() {

// Augment the query with retrieved context
augmentedQuery := db.CombineQueryWithContext(query, retrievedDocs)
log.Printf("LLM Prompt: %s", query)

prompt := backend.NewPrompt().
AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
AddMessage("user", augmentedQuery).
SetParameters(backend.Parameters{
MaxTokens: 150, // Supported by LLaMa
Temperature: 0.7, // Supported by LLaMa
TopP: 0.9, // Supported by LLaMa
})

// Generate response with the specified generation backend
response, err := generationBackend.Generate(ctx, augmentedQuery)
response, err := generationBackend.Generate(ctx, prompt)
if err != nil {
log.Fatalf("Failed to generate response: %v", err)
}
17 changes: 14 additions & 3 deletions examples/openai/main.go
@@ -30,12 +30,12 @@ func main() {

// Choose the backend for embeddings based on the config

embeddingBackend = backend.NewOpenAIBackend(apiKey, openAIEmbModel)
embeddingBackend = backend.NewOpenAIBackend(apiKey, openAIEmbModel, 10*time.Second)

log.Printf("Embedding backend LLM: %s", openAIEmbModel)

// Choose the backend for generation based on the config
generationBackend = backend.NewOpenAIBackend(apiKey, openAIGenModel)
generationBackend = backend.NewOpenAIBackend(apiKey, openAIGenModel, 10*time.Second)

log.Printf("Generation backend: %s", openAIGenModel)

@@ -94,8 +94,19 @@

log.Printf("Augmented Query: %s", augmentedQuery)

prompt := backend.NewPrompt().
AddMessage("system", "You are an AI assistant. Use the provided context to answer the user's question as accurately as possible.").
AddMessage("user", augmentedQuery).
SetParameters(backend.Parameters{
MaxTokens: 150,
Temperature: 0.7,
TopP: 0.9,
FrequencyPenalty: 0.5,
PresencePenalty: 0.6,
})

// Generate response with the specified generation backend
response, err := generationBackend.Generate(ctx, augmentedQuery)
response, err := generationBackend.Generate(ctx, prompt)
if err != nil {
log.Fatalf("Failed to generate response: %v", err)
}
43 changes: 40 additions & 3 deletions pkg/backend/backend.go
@@ -4,20 +4,57 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import "context"

// Backend defines the interface for interacting with various LLM backends.
type Backend interface {
Generate(ctx context.Context, prompt string) (string, error)
Generate(ctx context.Context, prompt *Prompt) (string, error)
Embed(ctx context.Context, input string) ([]float32, error)
}

// Message represents a single role-based message in the conversation.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

// Parameters defines generation settings for LLM completions.
type Parameters struct {
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
}

// Prompt represents a structured prompt with role-based messages and parameters.
type Prompt struct {
Messages []Message `json:"messages"`
Parameters Parameters `json:"parameters"`
}

// NewPrompt creates and returns a new Prompt.
func NewPrompt() *Prompt {
return &Prompt{}
}

// AddMessage adds a message with a specific role to the prompt.
func (p *Prompt) AddMessage(role, content string) *Prompt {
p.Messages = append(p.Messages, Message{Role: role, Content: content})
return p
}

// SetParameters sets the generation parameters for the prompt.
func (p *Prompt) SetParameters(params Parameters) *Prompt {
p.Parameters = params
return p
}
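
As a reference for what actually goes over the wire (or into logs), a prompt built with this fluent API marshals to JSON per the struct tags above. A short sketch, assuming `encoding/json`, `fmt`, and `log` are imported alongside this package:

```go
// Build a prompt and inspect its JSON form (shape follows the struct tags above).
p := NewPrompt().
	AddMessage("system", "You are an AI assistant.").
	AddMessage("user", "Hello").
	SetParameters(Parameters{MaxTokens: 100, Temperature: 0.7, TopP: 0.9})

out, err := json.Marshal(p)
if err != nil {
	log.Fatalf("failed to marshal prompt: %v", err)
}
fmt.Println(string(out))
// Yields roughly:
// {"messages":[{"role":"system","content":"You are an AI assistant."},{"role":"user","content":"Hello"}],
//  "parameters":{"max_tokens":100,"temperature":0.7,"top_p":0.9,"frequency_penalty":0,"presence_penalty":0}}
```

Note that the Ollama backend's `Generate` flattens these messages into a single prompt string, as shown in `ollama_backend.go` below.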
53 changes: 40 additions & 13 deletions pkg/backend/ollama_backend.go
@@ -4,14 +4,13 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
@@ -59,23 +58,45 @@ type OllamaEmbeddingResponse struct {
}

// NewOllamaBackend creates a new OllamaBackend instance.
func NewOllamaBackend(baseURL, model string) *OllamaBackend {
func NewOllamaBackend(baseURL, model string, timeout time.Duration) *OllamaBackend {
return &OllamaBackend{
BaseURL: baseURL,
Model: model,
Client: &http.Client{
Timeout: defaultTimeout,
Timeout: timeout,
},
}
}

// Generate produces a response from the Ollama API based on the given prompt.
func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, error) {
// Generate produces a response from the Ollama API based on the given structured prompt.
//
// Parameters:
// - ctx: The context for the API request, which can be used for cancellation.
// - prompt: A structured prompt containing messages and parameters.
//
// Returns:
// - A string containing the generated response from the Ollama model.
// - An error if the API request fails or if there's an issue processing the response.
func (o *OllamaBackend) Generate(ctx context.Context, prompt *Prompt) (string, error) {
url := o.BaseURL + generateEndpoint

// Concatenate the messages into a single prompt string
var promptText string
for _, message := range prompt.Messages {
// Append role and content into one string (adjust formatting as needed)
promptText += message.Role + ": " + message.Content + "\n"
}

// Construct the request body with concatenated prompt
reqBody := map[string]interface{}{
"model": o.Model,
"prompt": prompt,
"stream": false,
"model": o.Model,
"prompt": promptText, // Use concatenated string
"max_tokens": prompt.Parameters.MaxTokens,
"temperature": prompt.Parameters.Temperature,
"top_p": prompt.Parameters.TopP,
"frequency_penalty": prompt.Parameters.FrequencyPenalty,
"presence_penalty": prompt.Parameters.PresencePenalty,
"stream": false, // Explicitly set stream to false
}

reqBodyBytes, err := json.Marshal(reqBody)
@@ -95,8 +116,12 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er
}
defer resp.Body.Close()

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}

if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf(
"failed to generate response from Ollama: "+
"status code %d, response: %s",
@@ -105,7 +130,7 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er
}

var result Response
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
if err := json.NewDecoder(bytes.NewBuffer(bodyBytes)).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}

@@ -114,7 +139,6 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er

// Embed generates embeddings for the given input text using the Ollama API.
func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) {

url := o.BaseURL + embedEndpoint
reqBody := map[string]interface{}{
"model": o.Model,
@@ -139,7 +163,10 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return nil, fmt.Errorf(
"failed to generate embeddings from Ollama: "+
"status code %d, response: %s",
31 changes: 19 additions & 12 deletions pkg/backend/ollama_backend_test.go
@@ -43,7 +43,6 @@ func TestOllamaGenerate(t *testing.T) {
}

// Check Content-Type header

if r.Header.Get("Content-Type") != contentTypeJSON {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
@@ -54,15 +53,10 @@
t.Errorf("Failed to decode request body: %v", err)
}

// Optional: Validate request body contents
if reqBody["model"] != "test-model" {
t.Errorf("Expected model 'test-model', got '%v'", reqBody["model"])
}
if reqBody["prompt"] != "Hello, Ollama!" {
t.Errorf("Expected prompt 'Hello, Ollama!', got '%v'", reqBody["prompt"])
}
if reqBody["stream"] != false {
t.Errorf("Expected stream false, got '%v'", reqBody["stream"])
// Check that the "prompt" field is correctly passed
promptText, ok := reqBody["prompt"].(string)
if !ok || promptText == "" {
t.Errorf("Expected a valid prompt, got: %v", reqBody["prompt"])
}

// Write the mock response
@@ -81,8 +75,21 @@
}

ctx := context.Background()
prompt := "Hello, Ollama!"

promptMsg := "Hello, Ollama!"

// Construct the prompt
prompt := NewPrompt().
AddMessage("system", "You are an AI assistant.").
AddMessage("user", promptMsg).
SetParameters(Parameters{
MaxTokens: 150,
Temperature: 0.7,
TopP: 0.9,
FrequencyPenalty: 0.5,
PresencePenalty: 0.6,
})

// Call the Generate method
response, err := backend.Generate(ctx, prompt)
if err != nil {
t.Fatalf("Generate returned error: %v", err)