Skip to content

Commit

Permalink
feat: PRT - Websocket limited per ip (#1738)
Browse files Browse the repository at this point in the history
* feat: PRT - websocket limited per ip

* feature complete.

* adding websocket limiter

* remove logs

* lintush

* add rate limit header allowing us to set rate limit using the first connection request from ngnix without limiting everyone the same.

* fix lintushiush

* feat: added new websocket connection limit from ngnix headers

* v4 lint

* fixed

* add user agent

* setting user agent and better disconnection limit

* increasing protocol version

* adding unitest

* max idle duration for ws connections

* Fix lint

* CR Fix: Close the go routine on connection close

* CR Fix: Update the idleFor after we get a message from subscription

* CR Fix: Rename function

* combine the go routines in websocket checks

---------

Co-authored-by: Elad Gildnur <[email protected]>
Co-authored-by: omerlavanet <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent 34d7c4e commit 098c7a7
Show file tree
Hide file tree
Showing 10 changed files with 425 additions and 13 deletions.
48 changes: 38 additions & 10 deletions protocol/chainlib/consumer_websocket_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chainlib

import (
"context"
"fmt"
"strconv"
"sync/atomic"
"time"
Expand All @@ -20,6 +21,12 @@ import (
var (
WebSocketRateLimit = -1 // rate limit requests per second on websocket connection
WebSocketBanDuration = time.Duration(0) // once rate limit is reached, will not allow new incoming message for a duration
MaxIdleTimeInSeconds = int64(20 * 60) // 20 minutes of idle time will disconnect the websocket connection
)

const (
WebSocketRateLimitHeader = "x-lava-websocket-rate-limit"
WebSocketOpenConnectionsLimitHeader = "x-lava-websocket-open-connections-limit"
)

type ConsumerWebsocketManager struct {
Expand All @@ -35,6 +42,7 @@ type ConsumerWebsocketManager struct {
relaySender RelaySender
consumerWsSubscriptionManager *ConsumerWSSubscriptionManager
WebsocketConnectionUID string
headerRateLimit uint64
}

type ConsumerWebsocketManagerOptions struct {
Expand All @@ -50,6 +58,7 @@ type ConsumerWebsocketManagerOptions struct {
RelaySender RelaySender
ConsumerWsSubscriptionManager *ConsumerWSSubscriptionManager
WebsocketConnectionUID string
headerRateLimit uint64
}

func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *ConsumerWebsocketManager {
Expand All @@ -66,6 +75,7 @@ func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *Consu
refererData: options.RefererData,
consumerWsSubscriptionManager: options.ConsumerWsSubscriptionManager,
WebsocketConnectionUID: options.WebsocketConnectionUID,
headerRateLimit: options.headerRateLimit,
}
return cwm
}
Expand Down Expand Up @@ -142,34 +152,49 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
}
}()

// rate limit routine
// set up a routine to check for rate limits or idle time
idleFor := atomic.Int64{}
idleFor.Store(time.Now().Unix())
requestsPerSecond := &atomic.Uint64{}
go func() {
if WebSocketRateLimit <= 0 {
if WebSocketRateLimit <= 0 && cwm.headerRateLimit <= 0 && MaxIdleTimeInSeconds <= 0 {
return
}
ticker := time.NewTicker(time.Second) // rate limit per second.
defer ticker.Stop()
for {
select {
case <-webSocketCtx.Done():
utils.LavaFormatDebug("ctx done in time checker")
return
case <-ticker.C:
// check if rate limit reached, and ban is required
if WebSocketBanDuration > 0 && requestsPerSecond.Load() > uint64(WebSocketRateLimit) {
// wait the ban duration before resetting the store.
select {
case <-webSocketCtx.Done():
if MaxIdleTimeInSeconds > 0 {
utils.LavaFormatDebug("checking idle time", utils.LogAttr("idleFor", idleFor.Load()), utils.LogAttr("maxIdleTime", MaxIdleTimeInSeconds), utils.LogAttr("now", time.Now().Unix()))
idleDuration := idleFor.Load() + MaxIdleTimeInSeconds
if time.Now().Unix() > idleDuration {
websocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("Connection idle for too long, closing connection. Idle time: %d", idleDuration)))
return
case <-time.After(WebSocketBanDuration): // just continue
}
}
requestsPerSecond.Store(0)
if cwm.headerRateLimit > 0 || WebSocketRateLimit > 0 {
// check if rate limit reached, and ban is required
currentRequestsPerSecondLoad := requestsPerSecond.Load()
if WebSocketBanDuration > 0 && (currentRequestsPerSecondLoad > cwm.headerRateLimit || currentRequestsPerSecondLoad > uint64(WebSocketRateLimit)) {
// wait the ban duration before resetting the store.
select {
case <-webSocketCtx.Done():
return
case <-time.After(WebSocketBanDuration): // just continue
}
}
requestsPerSecond.Store(0)
}
}
}
}()

for {
idleFor.Store(time.Now().Unix())
startTime := time.Now()
msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int

Expand All @@ -185,7 +210,9 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
}

// Check rate limit is met
if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) {
currentRequestsPerSecond := requestsPerSecond.Add(1)
if (cwm.headerRateLimit > 0 && currentRequestsPerSecond > cwm.headerRateLimit) ||
(WebSocketRateLimit > 0 && currentRequestsPerSecond > uint64(WebSocketRateLimit)) {
rateLimitResponse, err := cwm.handleRateLimitReached(msg)
if err == nil {
websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse}
Expand Down Expand Up @@ -313,6 +340,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
)

for subscriptionMsgReply := range subscriptionMsgsChan {
idleFor.Store(time.Now().Unix())
websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: outputFormatter(subscriptionMsgReply.Data)}
}

Expand Down
117 changes: 117 additions & 0 deletions protocol/chainlib/consumer_websocket_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package chainlib

import (
"net"
"testing"

"github.com/golang/mock/gomock"
"github.com/lavanet/lava/v4/protocol/common"
"github.com/stretchr/testify/assert"
)

func TestWebsocketConnectionLimiter(t *testing.T) {
tests := []struct {
name string
connectionLimit int64
headerLimit int64
ipAddress string
forwardedIP string
userAgent string
expectSuccess []bool
}{
{
name: "Single connection allowed",
connectionLimit: 1,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true},
},
{
name: "Single connection allowed",
connectionLimit: 1,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, false},
},
{
name: "Multiple connections allowed",
connectionLimit: 2,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true},
},
{
name: "Multiple connections allowed",
connectionLimit: 2,
headerLimit: 0,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true, false},
},
{
name: "Header limit overrides global limit succeed",
connectionLimit: 3,
headerLimit: 2,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true},
},
{
name: "Header limit overrides global limit fail",
connectionLimit: 0,
headerLimit: 2,
ipAddress: "127.0.0.1",
forwardedIP: "",
userAgent: "test-agent",
expectSuccess: []bool{true, true, false},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

// Create a new connection limiter
wcl := &WebsocketConnectionLimiter{
ipToNumberOfActiveConnections: make(map[string]int64),
}

// Set global connection limit for testing
MaximumNumberOfParallelWebsocketConnectionsPerIp = tt.connectionLimit

// Create mock websocket connection
mockWsConn := NewMockWebsocketConnection(ctrl)

// Set up expectations
mockWsConn.EXPECT().Locals(WebSocketOpenConnectionsLimitHeader).Return(tt.headerLimit).AnyTimes()
mockWsConn.EXPECT().Locals(common.IP_FORWARDING_HEADER_NAME).Return(tt.forwardedIP).AnyTimes()
mockWsConn.EXPECT().Locals("User-Agent").Return(tt.userAgent).AnyTimes()
mockWsConn.EXPECT().RemoteAddr().Return(&net.TCPAddr{
IP: net.ParseIP(tt.ipAddress),
Port: 8080,
}).AnyTimes()
mockWsConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Do(func(messageType int, data []byte) {
t.Logf("WriteMessage called with messageType: %d, data: %s", messageType, string(data))
}).AnyTimes()

// Test the connection
for _, expectSuccess := range tt.expectSuccess {
canOpen, _ := wcl.CanOpenConnection(mockWsConn)
if expectSuccess {
assert.True(t, canOpen, "Expected connection to be allowed")
} else {
assert.False(t, canOpen, "Expected connection to be denied")
}
}
})
}
}
22 changes: 21 additions & 1 deletion protocol/chainlib/jsonRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ import (
spectypes "github.com/lavanet/lava/v4/x/spec/types"
)

const SEP = "&"
const (
SEP = "&"
)

var MaximumNumberOfParallelWebsocketConnectionsPerIp int64 = 0

type JsonRPCChainParser struct {
BaseChainParser
Expand Down Expand Up @@ -321,6 +325,7 @@ type JsonRPCChainListener struct {
refererData *RefererData
consumerWsSubscriptionManager *ConsumerWSSubscriptionManager
listeningAddress string
websocketConnectionLimiter *WebsocketConnectionLimiter
}

// NewJrpcChainListener creates a new instance of JsonRPCChainListener
Expand All @@ -338,6 +343,7 @@ func NewJrpcChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEn
logger: rpcConsumerLogs,
refererData: refererData,
consumerWsSubscriptionManager: consumerWsSubscriptionManager,
websocketConnectionLimiter: &WebsocketConnectionLimiter{ipToNumberOfActiveConnections: make(map[string]int64)},
}

return chainListener
Expand All @@ -354,6 +360,8 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
app := createAndSetupBaseAppListener(cmdFlags, apil.endpoint.HealthCheckPath, apil.healthReporter)

app.Use("/ws", func(c *fiber.Ctx) error {
apil.websocketConnectionLimiter.HandleFiberRateLimitFlags(c)

// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
if websocket.IsWebSocketUpgrade(c) {
Expand All @@ -367,6 +375,17 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
apiInterface := apil.endpoint.ApiInterface

webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) {
canOpenConnection, decreaseIpConnection := apil.websocketConnectionLimiter.CanOpenConnection(websocketConn)
defer decreaseIpConnection()
if !canOpenConnection {
return
}
rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader)
rateLimit, assertionSuccessful := rateLimitInf.(int64)
if !assertionSuccessful || rateLimit < 0 {
rateLimit = 0
}

utils.LavaFormatDebug("jsonrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))
defer utils.LavaFormatDebug("jsonrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String()))

Expand All @@ -383,6 +402,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con
RelaySender: apil.relaySender,
ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager,
WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10),
headerRateLimit: uint64(rateLimit),
})

consumerWebsocketManager.ListenToMessages()
Expand Down
77 changes: 77 additions & 0 deletions protocol/chainlib/mock_websocket.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 098c7a7

Please sign in to comment.