Skip to content

Commit

Permalink
✨ feat: add channel tag
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Jun 21, 2024
1 parent 1425366 commit 1eff202
Show file tree
Hide file tree
Showing 16 changed files with 1,281 additions and 137 deletions.
27 changes: 27 additions & 0 deletions common/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,33 @@ func Contains[T comparable](value T, slice []T) bool {
return false
}

func SliceToMap[T comparable](slice []T) map[T]bool {
res := make(map[T]bool)
for _, item := range slice {
res[item] = true
}
return res
}

func DifferenceSets[T comparable](set1, set2 map[T]bool) (diff1, diff2 []T) {
diff1 = make([]T, 0)
diff2 = make([]T, 0)

for key := range set1 {
if !set2[key] {
diff1 = append(diff1, key)
}
}

for key := range set2 {
if !set1[key] {
diff2 = append(diff2, key)
}
}

return diff1, diff2
}

func Filter[T any](arr []T, f func(T) bool) []T {
var res []T
for _, v := range arr {
Expand Down
16 changes: 16 additions & 0 deletions controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ func DeleteChannel(c *gin.Context) {
})
}

func DeleteChannelTag(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id"))
err := model.DeleteChannelTag(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}

func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
Expand Down
100 changes: 100 additions & 0 deletions controller/channel_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package controller

import (
"net/http"
"one-api/common"
"one-api/model"

"github.com/gin-gonic/gin"
)

func GetChannelsTagList(c *gin.Context) {
var params model.SearchChannelsTagParams
if err := c.ShouldBindQuery(&params); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}

channelsTag, err := model.GetChannelsTagList(&params)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelsTag,
})
}

func GetChannelsTagAllList(c *gin.Context) {
channelTags, err := model.GetChannelsTagAllList()
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelTags,
})
}

func GetChannelsTag(c *gin.Context) {
tag := c.Param("tag")
if tag == "" {
common.AbortWithMessage(c, http.StatusOK, "tag is required")
return
}
channel, err := model.GetChannelsTag(tag)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channel,
})
}

func UpdateChannelsTag(c *gin.Context) {
tag := c.Param("tag")
if tag == "" {
common.AbortWithMessage(c, http.StatusOK, "tag is required")
return
}
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}

err = model.UpdateChannelsTag(tag, &channel)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}

func DeleteChannelsTag(c *gin.Context) {
tag := c.Param("tag")
if tag == "" {
common.AbortWithMessage(c, http.StatusOK, "tag is required")
return
}
err := model.DeleteChannelsTag(tag)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
22 changes: 21 additions & 1 deletion model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var allowedChannelOrderFields = map[string]bool{
type SearchChannelsParams struct {
Channel
PaginationParams
FilterTag bool `json:"filter_tag" form:"filter_tag"`
}

func GetChannelsList(params *SearchChannelsParams) (*DataResult[Channel], error) {
Expand Down Expand Up @@ -92,7 +93,15 @@ func GetChannelsList(params *SearchChannelsParams) (*DataResult[Channel], error)
db = db.Where("test_model LIKE ?", params.TestModel+"%")
}

return PaginateAndOrder[Channel](db, &params.PaginationParams, &channels, allowedChannelOrderFields)
if params.Tag != "" {
db = db.Where("tag = ?", params.Tag)
}

if params.FilterTag {
db = db.Where("tag = ''")
}

return PaginateAndOrder(db, &params.PaginationParams, &channels, allowedChannelOrderFields)
}

func GetAllChannels() ([]*Channel, error) {
Expand All @@ -109,6 +118,17 @@ func GetChannelById(id int) (*Channel, error) {
return &channel, err
}

func GetChannelsByTag(tag string) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("tag = ?", tag).Find(&channels).Error
return channels, err
}

func DeleteChannelTag(channelId int) error {
err := DB.Model(&Channel{}).Where("id = ?", channelId).Update("tag", "").Error
return err
}

func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Omit("UsedQuota").Create(&channels).Error
Expand Down
173 changes: 173 additions & 0 deletions model/channel_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package model

import (
"one-api/common/config"
"strings"
)

type SearchChannelsTagParams struct {
Tag string `json:"tag" form:"tag"`
PaginationParams
}

type ChannelTag struct {
ID int `json:"id" gorm:"column:id"`
Tag string `json:"tag" gorm:"column:tag"`
}

func GetChannelsTagList(params *SearchChannelsTagParams) (*DataResult[Channel], error) {
var channels []*Channel
// 子查询:为每个tag选择最小的id
subQuery := DB.Model(&Channel{}).
Select("MIN(id) as id").
Where("tag != ''").
Group("tag")

db := DB.Select("tag, type, models, " + quotePostgresField("group"))
if params.Tag != "" {
subQuery = subQuery.Where("tag = ?", params.Tag)
}

db = db.Model(&Channel{}).Where("id IN (?)", subQuery)

return PaginateAndOrder(db, &params.PaginationParams, &channels, allowedChannelOrderFields)
}

func GetChannelsTagAllList() ([]*ChannelTag, error) {
var channelTags []*ChannelTag
err := DB.Model(&Channel{}).
Select("tag").
Where("tag != ''").
Group("tag").
Find(&channelTags).Error

return channelTags, err
}

func GetChannelsTag(tag string) (*Channel, error) {
var channel Channel
err := DB.Where("tag = ?", tag).First(&channel).Error
return &channel, err
}

func UpdateChannelsTag(tag string, channel *Channel) error {
channelTag, err := GetChannelsTag(tag)
if err != nil {
return err
}

tx := DB.Begin()
err = tx.Model(Channel{}).Where("tag = ?", tag).Updates(
Channel{
BaseURL: channel.BaseURL,
Other: channel.Other,
Models: channel.Models,
Group: channel.Group,
Tag: channel.Tag,
ModelMapping: channel.ModelMapping,
Proxy: channel.Proxy,
TestModel: channel.TestModel,
OnlyChat: channel.OnlyChat,
Plugin: channel.Plugin,
}).Error

if err != nil {
tx.Rollback()
return err
}

// 判断模型和分组是否有变化
if channelTag.Models == channel.Models && channelTag.Group == channel.Group {
tx.Commit()
return nil
}

channelList, err := GetChannelsByTag(tag)
if err != nil {
tx.Rollback()
return err
}

channelIds := make([]int, 0, len(channelList))
for _, c := range channelList {
channelIds = append(channelIds, c.Id)
}

models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")

// 如果模型有变化,更新
abilities := make([]*Ability, 0)
for _, c := range channelList {
enabled := c.Status == config.ChannelStatusEnabled
priority := c.Priority
weight := c.Weight
for _, model := range models_ {
for _, group := range groups_ {
ability := &Ability{
Group: group,
Model: model,
ChannelId: c.Id,
Enabled: enabled,
Priority: priority,
Weight: weight,
}
abilities = append(abilities, ability)
}
}
}

// 删除旧的
err = tx.Where("channel_id IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
tx.Rollback()
return err
}

// 添加新的
err = BatchInsert(tx, abilities)
if err != nil {
tx.Rollback()
return err
}

tx.Commit()

go ChannelGroup.Load()

return err
}

func DeleteChannelsTag(tag string) error {
if tag == "" {
return nil
}

tx := DB.Begin()
channelList, err := GetChannelsByTag(tag)
if err != nil {
return err
}

channelIds := make([]int, 0, len(channelList))
for _, c := range channelList {
channelIds = append(channelIds, c.Id)
}

err = tx.Where("channel_id IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
tx.Rollback()
return err
}

err = tx.Where("tag = ?", tag).Delete(&Channel{}).Error
if err != nil {
tx.Rollback()
return err
}

tx.Commit()
go ChannelGroup.Load()

return err
}
Loading

0 comments on commit 1eff202

Please sign in to comment.