-
Notifications
You must be signed in to change notification settings - Fork 693
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add custom restful backend for complex scenarios (e.g, rag)
Signed-off-by: popsiclexu <[email protected]> Signed-off-by: popsiclexu <[email protected]> Signed-off-by: popsiclexu <[email protected]>
- Loading branch information
1 parent
7019d0b
commit a03b89e
Showing
5 changed files
with
156 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package ai | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"time" | ||
) | ||
|
||
const CustomRestClientName = "customrest" | ||
|
||
type CustomRestClient struct { | ||
nopCloser | ||
client *http.Client | ||
base *url.URL | ||
token string | ||
model string | ||
temperature float32 | ||
topP float32 | ||
topK int32 | ||
} | ||
|
||
type CustomRestRequest struct { | ||
Model string `json:"model"` | ||
|
||
// Prompt is the textual prompt to send to the model. | ||
Prompt string `json:"prompt"` | ||
|
||
// Options lists model-specific options. For example, temperature can be | ||
// set through this field, if the model supports it. | ||
Options map[string]interface{} `json:"options"` | ||
} | ||
|
||
type CustomRestResponse struct { | ||
// Model is the model name that generated the response. | ||
Model string `json:"model"` | ||
|
||
// CreatedAt is the timestamp of the response. | ||
CreatedAt time.Time `json:"created_at"` | ||
|
||
// Response is the textual response itself. | ||
Response string `json:"response"` | ||
} | ||
|
||
func (c *CustomRestClient) Configure(config IAIConfig) error { | ||
baseURL := config.GetBaseURL() | ||
if baseURL == "" { | ||
baseURL = defaultBaseURL | ||
} | ||
c.token = config.GetPassword() | ||
baseClientURL, err := url.Parse(baseURL) | ||
if err != nil { | ||
return err | ||
} | ||
c.base = baseClientURL | ||
|
||
proxyEndpoint := config.GetProxyEndpoint() | ||
c.client = http.DefaultClient | ||
if proxyEndpoint != "" { | ||
proxyUrl, err := url.Parse(proxyEndpoint) | ||
if err != nil { | ||
return err | ||
} | ||
transport := &http.Transport{ | ||
Proxy: http.ProxyURL(proxyUrl), | ||
} | ||
|
||
c.client = &http.Client{ | ||
Transport: transport, | ||
} | ||
} | ||
|
||
c.model = config.GetModel() | ||
if c.model == "" { | ||
c.model = defaultModel | ||
} | ||
c.temperature = config.GetTemperature() | ||
c.topP = config.GetTopP() | ||
c.topK = config.GetTopK() | ||
return nil | ||
} | ||
|
||
func (c *CustomRestClient) GetCompletion(ctx context.Context, prompt string) (string, error) { | ||
var promptDetail struct { | ||
Language string `json:"language,omitempty"` | ||
Message string `json:"message"` | ||
Prompt string `json:"prompt,omitempty"` | ||
} | ||
prompt = strings.ReplaceAll(prompt, "\n", "\\n") | ||
prompt = strings.ReplaceAll(prompt, "\t", "\\t") | ||
if err := json.Unmarshal([]byte(prompt), &promptDetail); err != nil { | ||
return "", err | ||
} | ||
generateRequest := &CustomRestRequest{ | ||
Model: c.model, | ||
Prompt: promptDetail.Prompt, | ||
Options: map[string]interface{}{ | ||
"temperature": c.temperature, | ||
"top_p": c.topP, | ||
"top_k": c.topK, | ||
"message": promptDetail.Message, | ||
"language": promptDetail.Language, | ||
}, | ||
} | ||
bts, err := json.Marshal(generateRequest) | ||
if err != nil { | ||
return "", err | ||
} | ||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base.String(), bytes.NewBuffer(bts)) | ||
if err != nil { | ||
return "", err | ||
} | ||
if c.token != "" { | ||
request.Header.Set("Authorization", "Bearer "+c.token) | ||
} | ||
request.Header.Set("Content-Type", "application/json") | ||
request.Header.Set("Accept", "application/x-ndjson") | ||
response, err := c.client.Do(request) | ||
if err != nil { | ||
return "", err | ||
} | ||
defer response.Body.Close() | ||
|
||
resBody, err := io.ReadAll(response.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("could not read response body: %w", err) | ||
} | ||
|
||
if response.StatusCode >= http.StatusBadRequest { | ||
return "", fmt.Errorf("Request Error, StatusCode: %d, ErrorMessage: %s", response.StatusCode, resBody) | ||
} | ||
|
||
var resp CustomRestResponse | ||
if err := json.Unmarshal(resBody, &resp); err != nil { | ||
return "", err | ||
} | ||
return resp.Response, nil | ||
} | ||
|
||
func (c *CustomRestClient) GetName() string { | ||
return CustomRestClientName | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters