Skip to content

Commit

Permalink
feat: support IP rate limiting (#59)
Browse files Browse the repository at this point in the history
* feat: support ip rate limiting in the backend

* feat: add IpRateTable

* update: rename to  IP Rate Limiting and add i18n support
  • Loading branch information
love98ooo authored Aug 10, 2024
1 parent 31fc868 commit c085f09
Show file tree
Hide file tree
Showing 9 changed files with 430 additions and 9 deletions.
18 changes: 18 additions & 0 deletions controllers/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func checkExpressions(expressions []*object.Expression, ruleType string) error {
return checkWafRule(values)
case "IP":
return checkIpRule(values)
case "IP Rate Limiting":
return checkIpRateRule(expressions)
}
return nil
}
Expand All @@ -182,3 +184,19 @@ func checkIpRule(ipLists []string) error {
}
return nil
}

func checkIpRateRule(expressions []*object.Expression) error {
if len(expressions) != 1 {
return errors.New("IP Rate Limiting rule must have exactly one expression")
}
expression := expressions[0]
_, err := util.ParseIntWithError(expression.Operator)
if err != nil {
return err
}
_, err = util.ParseIntWithError(expression.Value)
if err != nil {
return err
}
return nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ require (
github.com/xorm-io/core v0.7.4
github.com/xorm-io/xorm v1.1.6
golang.org/x/net v0.21.0
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
modernc.org/sqlite v1.11.2
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
4 changes: 4 additions & 0 deletions rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func CheckRules(ruleIds []string, r *http.Request) (string, string, error) {
ruleObj = &IpRule{}
case "WAF":
ruleObj = &WafRule{}
case "IP Rate Limiting":
ruleObj = &IpRateRule{
ruleName: rule.GetId(),
}
default:
return "", "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
}
Expand Down
128 changes: 128 additions & 0 deletions rule/rule_ip_rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rule

import (
"net/http"
"sync"
"time"

"github.com/casbin/caswaf/object"
"github.com/casbin/caswaf/util"
"golang.org/x/time/rate"
)

type IpRateRule struct {
ruleName string
}

type IpRateLimiter struct {
ips map[string]*rate.Limiter
mu *sync.RWMutex
r rate.Limit
b int
}

var blackList = map[string]map[string]time.Time{}

var ipRateLimiters = map[string]*IpRateLimiter{}

// NewIpRateLimiter .
func NewIpRateLimiter(r rate.Limit, b int) *IpRateLimiter {
i := &IpRateLimiter{
ips: make(map[string]*rate.Limiter),
mu: &sync.RWMutex{},
r: r,
b: b,
}

return i
}

// AddIP creates a new rate limiter and adds it to the ips map,
// using the IP address as the key
func (i *IpRateLimiter) AddIP(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()

limiter := rate.NewLimiter(i.r, i.b)

i.ips[ip] = limiter

return limiter
}

// GetLimiter returns the rate limiter for the provided IP address if it exists.
// Otherwise, calls AddIP to add IP address to the map
func (i *IpRateLimiter) GetLimiter(ip string) *rate.Limiter {
i.mu.Lock()
limiter, exists := i.ips[ip]

if !exists {
i.mu.Unlock()
return i.AddIP(ip)
}

i.mu.Unlock()

return limiter
}

func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
expression := expressions[0] // IpRate rule should have only one expression
clientIp := util.GetClientIp(req)

// If the client IP is in the blacklist, check the block time
createAt, ok := blackList[r.ruleName][clientIp]
if ok {
blockTime := util.ParseInt(expression.Value)
if time.Now().Sub(createAt) < time.Duration(blockTime)*time.Second {
return true, "Block", "Rate limit exceeded", nil
} else {
delete(blackList, clientIp)
}
}

// If the client IP is not in the blacklist, check the rate limit
ipRateLimiter := ipRateLimiters[r.ruleName]
parseInt := util.ParseInt(expression.Operator)
if ipRateLimiter == nil {
ipRateLimiter = NewIpRateLimiter(rate.Limit(parseInt), parseInt)
ipRateLimiters[r.ruleName] = ipRateLimiter
}

// If the rate limit has changed, update the rate limiter
limiter := ipRateLimiter.GetLimiter(clientIp)
if ipRateLimiter.r != rate.Limit(parseInt) {
ipRateLimiter.r = rate.Limit(parseInt)
ipRateLimiter.b = parseInt
limiter.SetLimit(ipRateLimiter.r)
limiter.SetBurst(ipRateLimiter.b)
err := limiter.Wait(req.Context())
if err != nil {
return false, "", "", err
}
} else {
// If the rate limit is exceeded, add the client IP to the blacklist
allow := limiter.Allow()
if !allow {
blackList[r.ruleName] = map[string]time.Time{}
blackList[r.ruleName][clientIp] = time.Now()
return true, "Block", "Rate limit exceeded", nil
}
}

return false, "", "", nil
}
133 changes: 133 additions & 0 deletions rule/rule_ip_rate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package rule

import (
"net/http"
"testing"

"github.com/casbin/caswaf/object"
)

func TestIpRateRule_checkRule(t *testing.T) {
type fields struct {
ruleName string
}
type args struct {
args []struct {
expressions []*object.Expression
req *http.Request
}
}

tests := []struct {
name string
fields fields
args args
want []bool
want1 []string
want2 []string
wantErr []bool
}{
{
name: "Test 1",
fields: fields{
ruleName: "rule1",
},
args: args{
args: []struct {
expressions []*object.Expression
req *http.Request
}{
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
},
},
want: []bool{false, true},
want1: []string{"", "Block"},
want2: []string{"", "Rate limit exceeded"},
wantErr: []bool{false, false},
},
{
name: "Test 2",
fields: fields{
ruleName: "rule2",
},
args: args{
args: []struct {
expressions []*object.Expression
req *http.Request
}{
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
{
expressions: []*object.Expression{
{
Operator: "10",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
},
},
want: []bool{false, false},
want1: []string{"", ""},
want2: []string{"", ""},
wantErr: []bool{false, false},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &IpRateRule{
ruleName: tt.fields.ruleName,
}
for i, arg := range tt.args.args {
got, got1, got2, err := r.checkRule(arg.expressions, arg.req)
if (err != nil) != tt.wantErr[i] {
t.Errorf("checkRule() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want[i] {
t.Errorf("checkRule() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1[i] {
t.Errorf("checkRule() got1 = %v, want %v", got1, tt.want1)
}
if got2 != tt.want2[i] {
t.Errorf("checkRule() got2 = %v, want %v", got2, tt.want2)
}
}
})
}
}
27 changes: 24 additions & 3 deletions web/src/RuleEditPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import i18next from "i18next";
import WafRuleTable from "./components/WafRuleTable";
import IpRuleTable from "./components/IpRuleTable";
import UaRuleTable from "./components/UaRuleTable";
import IpRateRuleTable from "./components/IpRateRuleTable";

const {Option} = Select;

Expand Down Expand Up @@ -57,12 +58,21 @@ class RuleEditPage extends React.Component {
});
}

updateRuleFieldInExpressions(index, key, value) {
const rule = Setting.deepCopy(this.state.rule);
rule.expressions[index][key] = value;
this.updateRuleField("expressions", rule.expressions);
this.setState({
rule: rule,
});
}

renderRule() {
return (
<Card size="small" title={
<div>
Edit Rule&nbsp;&nbsp;&nbsp;&nbsp;
<Button type="primary" onClick={this.submitRuleEdit.bind(this)}>Save</Button>
{i18next.t("rule:Edit Rule")}&nbsp;&nbsp;&nbsp;&nbsp;
<Button type="primary" onClick={this.submitRuleEdit.bind(this)}>{i18next.t("general:Save")}</Button>
</div>
} style={{marginTop: 10}} type="inner">
<Row style={{marginTop: "20px"}}>
Expand All @@ -86,7 +96,7 @@ class RuleEditPage extends React.Component {
{value: "WAF", text: "WAF"},
{value: "IP", text: "IP"},
{value: "User-Agent", text: "User-Agent"},
// {value: "frequency", text: "Frequency"},
{value: "IP Rate Limiting", text: "IP Rate Limiting"},
// {value: "complex", text: "Complex"},
].map((item, index) => <Option key={index} value={item.value}>{item.text}</Option>)
}
Expand Down Expand Up @@ -131,6 +141,17 @@ class RuleEditPage extends React.Component {
/>
) : null
}
{
this.state.rule.type === "IP Rate Limiting" ? (
<IpRateRuleTable
title={"IP Rate Limiting"}
table={this.state.rule.expressions}
ruleName={this.state.rule.name}
account={this.props.account}
onUpdateTable={(value) => {this.updateRuleField("expressions", value);}}
/>
) : null
}
</Col>
</Row>
{
Expand Down
Loading

0 comments on commit c085f09

Please sign in to comment.