From f505afdc1025082a96153d6e66215d3028d90b07 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 17 Sep 2024 20:49:51 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BB=A4=E7=89=8Cip?= =?UTF-8?q?=E7=99=BD=E5=90=8D=E5=8D=95=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/utils.go | 5 +++++ controller/token.go | 2 ++ middleware/auth.go | 1 + middleware/distributor.go | 8 ++++++++ model/token.go | 25 ++++++++++++++++++++++++- web/src/pages/Token/EditToken.js | 19 +++++++++++++++++-- 6 files changed, 57 insertions(+), 3 deletions(-) diff --git a/common/utils.go b/common/utils.go index 3d95508cb..3d0cb6a00 100644 --- a/common/utils.go +++ b/common/utils.go @@ -128,6 +128,11 @@ func IntMax(a int, b int) int { } } +func IsIP(s string) bool { + ip := net.ParseIP(s) + return ip != nil +} + func GetUUID() string { code := uuid.New().String() code = strings.Replace(code, "-", "", -1) diff --git a/controller/token.go b/controller/token.go index 39e602463..50a368f6f 100644 --- a/controller/token.go +++ b/controller/token.go @@ -134,6 +134,7 @@ func AddToken(c *gin.Context) { UnlimitedQuota: token.UnlimitedQuota, ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimits: token.ModelLimits, + AllowIps: token.AllowIps, } err = cleanToken.Insert() if err != nil { @@ -221,6 +222,7 @@ func UpdateToken(c *gin.Context) { cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimits = token.ModelLimits + cleanToken.AllowIps = token.AllowIps } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index f9a590017..481960efa 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -175,6 +175,7 @@ func TokenAuth() func(c *gin.Context) { } else { c.Set("token_model_limit_enabled", false) } + c.Set("allow_ips", token.GetIpLimitsMap()) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("specific_channel_id", parts[1]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 3ca5b8f7f..9b55cc2d2 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -22,6 +22,14 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { + allowIpsMap := c.GetStringMap("allow_ips") + if len(allowIpsMap) != 0 { + clientIp := c.ClientIP() + if _, ok := allowIpsMap[clientIp]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中") + return + } + } userId := c.GetInt("id") var channel *model.Channel channelId, ok := c.Get("specific_channel_id") diff --git a/model/token.go b/model/token.go index 272c5734f..18aa2979e 100644 --- a/model/token.go +++ b/model/token.go @@ -23,10 +23,33 @@ type Token struct { UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"` ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` + AllowIps *string `json:"allow_ips" gorm:"default:''"` UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota DeletedAt gorm.DeletedAt `gorm:"index"` } +func (token *Token) GetIpLimitsMap() map[string]any { + // delete empty spaces + //split with \n + ipLimitsMap := make(map[string]any) + if token.AllowIps == nil { + return ipLimitsMap + } + cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "") + if cleanIps == "" { + return ipLimitsMap + } + ips := strings.Split(cleanIps, "\n") + for _, ip := range ips { + ip = strings.TrimSpace(ip) + ip = strings.ReplaceAll(ip, ",", "") + if common.IsIP(ip) { + ipLimitsMap[ip] = true + } + } + return ipLimitsMap +} + func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { var tokens []*Token var err error @@ -130,7 +153,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits", "allow_ips").Updates(token).Error return err } diff --git a/web/src/pages/Token/EditToken.js b/web/src/pages/Token/EditToken.js index 2af406f8a..64aa71982 100644 --- a/web/src/pages/Token/EditToken.js +++ b/web/src/pages/Token/EditToken.js @@ -18,8 +18,8 @@ import { Select, SideSheet, Space, - Spin, - Typography, + Spin, TextArea, + Typography } from '@douyinfe/semi-ui'; import Title from '@douyinfe/semi-ui/lib/es/typography/title'; import { Divider } from 'semantic-ui-react'; @@ -34,6 +34,7 @@ const EditToken = (props) => { unlimited_quota: false, model_limits_enabled: false, model_limits: [], + allow_ips: '', }; const [inputs, setInputs] = useState(originInputs); const { @@ -43,6 +44,7 @@ const EditToken = (props) => { unlimited_quota, model_limits_enabled, model_limits, + allow_ips } = inputs; // const [visible, setVisible] = useState(false); const [models, setModels] = useState({}); @@ -374,6 +376,19 @@ const EditToken = (props) => { +
+ IP白名单(请勿过度信任此功能) +
+