feat: add custom RESTful backend for complex scenarios (e.g., RAG) #1228

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions README.md
@@ -343,6 +343,7 @@ Unused:
> noopai
> googlevertexai
> watsonxai
> customrest
```

For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
147 changes: 147 additions & 0 deletions pkg/ai/customrest.go
@@ -0,0 +1,147 @@
package ai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)

const CustomRestClientName = "customrest"

type CustomRestClient struct {
nopCloser
client *http.Client
base *url.URL
token string
model string
temperature float32
topP float32
topK int32
}

type CustomRestRequest struct {
Model string `json:"model"`

// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`

// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
}

type CustomRestResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"`

// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`

// Response is the textual response itself.
Response string `json:"response"`
}

func (c *CustomRestClient) Configure(config IAIConfig) error {
baseURL := config.GetBaseURL()
if baseURL == "" {
baseURL = defaultBaseURL
}
c.token = config.GetPassword()
baseClientURL, err := url.Parse(baseURL)
if err != nil {
return err
}
c.base = baseClientURL

proxyEndpoint := config.GetProxyEndpoint()
c.client = http.DefaultClient
if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}

c.client = &http.Client{
Transport: transport,
}
}

c.model = config.GetModel()
if c.model == "" {
c.model = defaultModel
}
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.topK = config.GetTopK()
return nil
}

func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
var promptDetail struct {
Language string `json:"language,omitempty"`
Message string `json:"message"`
Prompt string `json:"prompt,omitempty"`
}
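// The analyzer hands over the raw JSON envelope built from PromptMap["raw"];
// escape literal newlines and tabs so the embedded text stays parseable JSON.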
prompt = strings.NewReplacer("\n", "\\n", "\t", "\\t").Replace(prompt)
if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil {
return "", err
}
generateRequest := &CustomRestRequest{
Model: c.model,
Prompt: promptDetail.Prompt,
Options: map[string]interface{}{
"temperature": c.temperature,
"top_p": c.topP,
"top_k": c.topK,
"message": promptDetail.Message,
"language": promptDetail.Language,
},
}
requestBody, err := json.Marshal(generateRequest)
if err != nil {
return "", err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(requestBody))
if err != nil {
return "", err
}
if c.token != "" {
request.Header.Set("Authorization", "Bearer "+c.token)
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/x-ndjson")

response, err := c.client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()

responseBody, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("could not read response body: %w", err)
}

if response.StatusCode >= http.StatusBadRequest {
return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, responseBody)
}

var result CustomRestResponse
if err := json.Unmarshal(responseBody, &result); err != nil {
return "", err
}
return result.Response, nil
}

func (c *CustomRestClient) GetName() string {
return CustomRestClientName
}
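
For reviewers: a compatible backend only needs to accept the `CustomRestRequest` JSON POSTed to the configured base URL and reply with a `CustomRestResponse`-shaped JSON body. The sketch below illustrates that contract and is not part of this PR; the listen address, handler path, and echo logic are assumptions.

```go
package main

import (
	"encoding/json"
	"log"
	"net/http"
	"time"
)

// Request/response shapes mirror CustomRestRequest and CustomRestResponse above.
type generateRequest struct {
	Model   string                 `json:"model"`
	Prompt  string                 `json:"prompt"`
	Options map[string]interface{} `json:"options"`
}

type generateResponse struct {
	Model     string    `json:"model"`
	CreatedAt time.Time `json:"created_at"`
	Response  string    `json:"response"`
}

func main() {
	// The client POSTs directly to the configured base URL, so serve whatever
	// path that URL points at ("/" here, purely for illustration).
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		var req generateRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// A real RAG backend would retrieve context for req.Options["message"]
		// and call a model here; this stub simply echoes the prompt.
		resp := generateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			Response:  "echo: " + req.Prompt,
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Printf("encode response: %v", err)
		}
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```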
4 changes: 3 additions & 1 deletion pkg/ai/iai.go
@@ -33,6 +33,7 @@ var (
&GoogleVertexAIClient{},
&OCIGenAIClient{},
&WatsonxAIClient{},
&CustomRestClient{},
}
Backends = []string{
openAIClientName,
@@ -48,6 +49,7 @@ var (
googleVertexAIClientName,
ociClientName,
watsonxAIClientName,
CustomRestClientName,
}
)

@@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "customrest"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
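
Because `customrest` joins the passwordless providers, no password is required when adding the backend, although a value supplied via the password field is still sent as a bearer token (see `Configure` above). A small illustrative check, assuming the module path `github.com/k8sgpt-ai/k8sgpt`:

```go
package main

import (
	"fmt"

	"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
)

func main() {
	// customrest is listed in passwordlessProviders, so no password is demanded.
	fmt.Println(ai.NeedPassword("customrest")) // false
	fmt.Println(ai.NeedPassword("openai"))     // true
}
```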
2 changes: 2 additions & 0 deletions pkg/ai/prompts.go
@@ -58,9 +58,11 @@ const (

Solution: {kubectl command}
`
raw_prompt = `{"language": "%s","message": "%s","prompt": "%s"}`
)

var PromptMap = map[string]string{
"raw": raw_promt,
"default": default_prompt,
"VulnerabilityReport": trivy_vuln_prompt, // for Trivy integration, the key should match `Result.Kind` in pkg/common/types.go
"ConfigAuditReport": trivy_conf_prompt,
3 changes: 3 additions & 0 deletions pkg/analysis/analysis.go
@@ -397,6 +397,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st

// Process template.
prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
if a.AIClient.GetName() == ai.CustomRestClientName {
prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt)
}
response, err := a.AIClient.GetCompletion(a.Context, prompt)
if err != nil {
return "", err
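
To make the new `customrest` branch concrete, here is a standalone sketch of how the raw envelope ends up looking. The template string mirrors the one added to prompts.go; the language, message, and default-prompt stand-in are invented for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// rawPrompt mirrors the raw_prompt template added in pkg/ai/prompts.go.
const rawPrompt = `{"language": "%s","message": "%s","prompt": "%s"}`

func main() {
	// Hypothetical values; in analysis.go these come from a.Language and inputKey.
	language := "English"
	inputKey := "Pod my-app is in CrashLoopBackOff"
	promptTmpl := "Explain the following Kubernetes error in %s: %s" // stand-in for default_prompt

	// Step 1: fill the regular prompt template, as analysis.go always does.
	prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), language, inputKey)

	// Step 2: for the customrest backend only, wrap everything in the raw JSON
	// envelope so GetCompletion can split it back into language/message/prompt.
	prompt = fmt.Sprintf(rawPrompt, language, inputKey, prompt)
	fmt.Println(prompt)
	// {"language": "English","message": "Pod my-app is in CrashLoopBackOff","prompt": "Explain the following Kubernetes error in English: Pod my-app is in CrashLoopBackOff"}
}
```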