Skip to content

Commit

Permalink
feat: Add FetchModels endpoint and refactor FetchUpstreamModels
Browse files Browse the repository at this point in the history
- Introduced a new `FetchModels` endpoint to retrieve model IDs from a specified base URL and API key, enhancing flexibility for different channel types.
- Refactored `FetchUpstreamModels` to simplify base URL handling and improve error messages during response parsing.
- Updated API routes to include the new endpoint and adjusted the frontend to utilize the new fetch mechanism for model lists.
- Removed outdated checks for channel type in the frontend, streamlining the model fetching process.
  • Loading branch information
Calcium-Ion committed Dec 24, 2024
1 parent 2ec5eaf commit 93cda60
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 33 deletions.
114 changes: 96 additions & 18 deletions controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
})
return
}

channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
Expand All @@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)

//if channel.Type != common.ChannelTypeOpenAI {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": "仅支持 OpenAI 类型渠道",
// })
// return
//}
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
url := fmt.Sprintf("%s/v1/models", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {

var result OpenAIModelsResponse
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}

var ids []string
Expand Down Expand Up @@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
})
return
}

func FetchModels(c *gin.Context) {
var req struct {
BaseURL string `json:"base_url"`
Key string `json:"key"`
}

if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request",
})
return
}

baseURL := req.BaseURL
if baseURL == "" {
baseURL = "https://api.openai.com"
}

client := &http.Client{}
url := fmt.Sprintf("%s/v1/models", baseURL)

request, err := http.NewRequest("GET", url, nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}

request.Header.Set("Authorization", "Bearer "+req.Key)

response, err := client.Do(request)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
//check status code
if response.StatusCode != http.StatusOK {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "Failed to fetch models",
})
return
}
defer response.Body.Close()

var result struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}

if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}

var models []string
for _, model := range result.Data {
models = append(models, model.ID)
}

c.JSON(http.StatusOK, gin.H{
"success": true,
"data": models,
})
}
1 change: 1 addition & 0 deletions router/api-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)

}
tokenRoute := apiRouter.Group("/token")
Expand Down
31 changes: 16 additions & 15 deletions web/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -193,45 +193,46 @@ const EditChannel = (props) => {


const fetchUpstreamModelList = async (name) => {
if (inputs['type'] !== 1) {
showError(t('仅支持 OpenAI 接口格式'));
return;
}
// if (inputs['type'] !== 1) {
// showError(t('仅支持 OpenAI 接口格式'));
// return;
// }
setLoading(true);
const models = inputs['models'] || [];
let err = false;

if (isEdit) {
// 如果是编辑模式,使用已有的channel id获取模型列表
const res = await API.get('/api/channel/fetch_models/' + channelId);
if (res.data && res.data?.success) {
models.push(...res.data.data);
} else {
err = true;
}
} else {
// 如果是新建模式,通过后端代理获取模型列表
if (!inputs?.['key']) {
showError(t('请填写密钥'));
err = true;
} else {
try {
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));

const url = `https://${host.hostname}/v1/models`;
const key = inputs['key'];
const res = await axios.get(url, {
headers: {
'Authorization': `Bearer ${key}`
}
const res = await API.post('/api/channel/fetch_models', {
base_url: inputs['base_url'],
key: inputs['key']
});
if (res.data) {
models.push(...res.data.data.map((model) => model.id));

if (res.data && res.data.success) {
models.push(...res.data.data);
} else {
err = true;
}
} catch (error) {
console.error('Error fetching models:', error);
err = true;
}
}
}

if (!err) {
handleInputChange(name, Array.from(new Set(models)));
showSuccess(t('获取模型列表成功'));
Expand Down Expand Up @@ -638,7 +639,7 @@ const EditChannel = (props) => {
{inputs.type === 21 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>知识库 ID:</Typography.Text>
<Typography.Text strong>��识库 ID:</Typography.Text>
</div>
<Input
label="知识库 ID"
Expand Down

0 comments on commit 93cda60

Please sign in to comment.