Skip to content

Commit

Permalink
feat: add custom restful backend for complex scenarios (e.g, rag)
Browse files Browse the repository at this point in the history
Signed-off-by: popsiclexu <[email protected]>
Signed-off-by: popsiclexu <[email protected]>
Signed-off-by: popsiclexu <[email protected]>
  • Loading branch information
popsiclexu authored and popsiclexu committed Aug 16, 2024
1 parent 7019d0b commit a03b89e
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ Unused:
> noopai
> googlevertexai
> watsonxai
> customrest
```

For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/).
Expand Down
147 changes: 147 additions & 0 deletions pkg/ai/customrest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package ai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)

const CustomRestClientName = "customrest"

type CustomRestClient struct {
nopCloser
client *http.Client
base *url.URL
token string
model string
temperature float32
topP float32
topK int32
}

type CustomRestRequest struct {
Model string `json:"model"`

// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`

// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
}

type CustomRestResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"`

// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`

// Response is the textual response itself.
Response string `json:"response"`
}

func (c *CustomRestClient) Configure(config IAIConfig) error {
baseURL := config.GetBaseURL()
if baseURL == "" {
baseURL = defaultBaseURL
}
c.token = config.GetPassword()
baseClientURL, err := url.Parse(baseURL)
if err != nil {
return err
}
c.base = baseClientURL

proxyEndpoint := config.GetProxyEndpoint()
c.client = http.DefaultClient
if proxyEndpoint != "" {
proxyUrl, err := url.Parse(proxyEndpoint)
if err != nil {
return err
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}

c.client = &http.Client{
Transport: transport,
}
}

c.model = config.GetModel()
if c.model == "" {
c.model = defaultModel
}
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.topK = config.GetTopK()
return nil
}

func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
var promptDetail struct {
Language string `json:"language,omitempty"`
Message string `json:"message"`
Prompt string `json:"prompt,omitempty"`
}
prompt = strings.ReplaceAll(prompt, "\n", "\\n")
prompt = strings.ReplaceAll(prompt, "\t", "\\t")
if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil {
return "", err
}
generateRequest := &CustomRestRequest{
Model: c.model,
Prompt: promptDetail.Prompt,
Options: map[string]interface{}{
"temperature": c.temperature,
"top_p": c.topP,
"top_k": c.topK,
"message": promptDetail.Message,
"language": promptDetail.Language,
},
}
bts, err := json.Marshal(generateRequest)
if err != nil {
return "", err
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(bts))
if err != nil {
return "", err
}
if c.token != "" {
request.Header.Set("Authorization", "Bearer "+c.token)
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/x-ndjson")
response, err := c.client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()

resBody, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("could not read response body: %w", err)
}

if response.StatusCode >= http.StatusBadRequest {
return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, resBody)
}

var resp CustomRestResponse
if err := json.Unmarshal(resBody, &resp); err != nil {
return "", err
}
return resp.Response, nil
}

func (c *CustomRestClient) GetName() string {
return CustomRestClientName
}
4 changes: 3 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
&GoogleVertexAIClient{},
&OCIGenAIClient{},
&WatsonxAIClient{},
&CustomRestClient{},
}
Backends = []string{
openAIClientName,
Expand All @@ -48,6 +49,7 @@ var (
googleVertexAIClientName,
ociClientName,
watsonxAIClientName,
CustomRestClientName,
}
)

Expand Down Expand Up @@ -181,7 +183,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai", "customrest"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down
2 changes: 2 additions & 0 deletions pkg/ai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ const (
Solution: {kubectl command}
`
raw_promt = `{"language": "%s","message": "%s","prompt": "%s"}`
)

var PromptMap = map[string]string{
"raw": raw_promt,
"default": default_prompt,
"VulnerabilityReport": trivy_vuln_prompt, // for Trivy integration, the key should match `Result.Kind` in pkg/common/types.go
"ConfigAuditReport": trivy_conf_prompt,
Expand Down
3 changes: 3 additions & 0 deletions pkg/analysis/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ func (a *Analysis) getAIResultForSanitizedFailures(texts []string, promptTmpl st

// Process template.
prompt := fmt.Sprintf(strings.TrimSpace(promptTmpl), a.Language, inputKey)
if a.AIClient.GetName() == ai.CustomRestClientName {
prompt = fmt.Sprintf(ai.PromptMap["raw"], a.Language, inputKey, prompt)
}
response, err := a.AIClient.GetCompletion(a.Context, prompt)
if err != nil {
return "", err
Expand Down

0 comments on commit a03b89e

Please sign in to comment.