From 098c7a793766f2b8bc998f8d2f464a7634f5cc88 Mon Sep 17 00:00:00 2001 From: Ran Mishael <106548467+ranlavanet@users.noreply.github.com> Date: Thu, 5 Dec 2024 14:00:36 +0100 Subject: [PATCH] feat: PRT - Websocket limited per ip (#1738) * feat: PRT - websocket limited per ip * feature complete. * adding websocket limiter * remove logs * lintush * add rate limit header allowing us to set rate limit using the first connection request from ngnix without limiting everyone the same. * fix lintushiush * feat: added new websocket connection limit from ngnix headers * v4 lint * fixed * add user agent * setting user agent and better disconnection limit * increasing protocol version * adding unitest * max idle duration for ws connections * Fix lint * CR Fix: Close the go routine on connection close * CR Fix: Update the idleFor after we get a message from subscription * CR Fix: Rename function * combine the go routines in websocket checks --------- Co-authored-by: Elad Gildnur Co-authored-by: omerlavanet --- .../chainlib/consumer_websocket_manager.go | 48 ++++-- .../consumer_websocket_manager_test.go | 117 ++++++++++++++ protocol/chainlib/jsonRPC.go | 22 ++- protocol/chainlib/mock_websocket.go | 77 +++++++++ protocol/chainlib/tendermintRPC.go | 16 ++ .../chainlib/websocket_connection_limiter.go | 150 ++++++++++++++++++ protocol/common/cobra_common.go | 2 + protocol/rpcconsumer/rpcconsumer.go | 2 + .../pre_setups/init_lava_only_with_node.sh | 2 +- x/protocol/types/params.go | 2 +- 10 files changed, 425 insertions(+), 13 deletions(-) create mode 100644 protocol/chainlib/consumer_websocket_manager_test.go create mode 100644 protocol/chainlib/mock_websocket.go create mode 100644 protocol/chainlib/websocket_connection_limiter.go diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index ee8ab6319a..83ff9f08f3 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -2,6 +2,7 @@ package chainlib import ( "context" + "fmt" "strconv" "sync/atomic" "time" @@ -20,6 +21,12 @@ import ( var ( WebSocketRateLimit = -1 // rate limit requests per second on websocket connection WebSocketBanDuration = time.Duration(0) // once rate limit is reached, will not allow new incoming message for a duration + MaxIdleTimeInSeconds = int64(20 * 60) // 20 minutes of idle time will disconnect the websocket connection +) + +const ( + WebSocketRateLimitHeader = "x-lava-websocket-rate-limit" + WebSocketOpenConnectionsLimitHeader = "x-lava-websocket-open-connections-limit" ) type ConsumerWebsocketManager struct { @@ -35,6 +42,7 @@ type ConsumerWebsocketManager struct { relaySender RelaySender consumerWsSubscriptionManager *ConsumerWSSubscriptionManager WebsocketConnectionUID string + headerRateLimit uint64 } type ConsumerWebsocketManagerOptions struct { @@ -50,6 +58,7 @@ type ConsumerWebsocketManagerOptions struct { RelaySender RelaySender ConsumerWsSubscriptionManager *ConsumerWSSubscriptionManager WebsocketConnectionUID string + headerRateLimit uint64 } func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *ConsumerWebsocketManager { @@ -66,6 +75,7 @@ func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *Consu refererData: options.RefererData, consumerWsSubscriptionManager: options.ConsumerWsSubscriptionManager, WebsocketConnectionUID: options.WebsocketConnectionUID, + headerRateLimit: options.headerRateLimit, } return cwm } @@ -142,10 +152,12 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } }() - // rate limit routine + // set up a routine to check for rate limits or idle time + idleFor := atomic.Int64{} + idleFor.Store(time.Now().Unix()) requestsPerSecond := &atomic.Uint64{} go func() { - if WebSocketRateLimit <= 0 { + if WebSocketRateLimit <= 0 && cwm.headerRateLimit <= 0 && MaxIdleTimeInSeconds <= 0 { return } ticker := time.NewTicker(time.Second) // rate limit per second. @@ -153,23 +165,36 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { for { select { case <-webSocketCtx.Done(): + utils.LavaFormatDebug("ctx done in time checker") return case <-ticker.C: - // check if rate limit reached, and ban is required - if WebSocketBanDuration > 0 && requestsPerSecond.Load() > uint64(WebSocketRateLimit) { - // wait the ban duration before resetting the store. - select { - case <-webSocketCtx.Done(): + if MaxIdleTimeInSeconds > 0 { + utils.LavaFormatDebug("checking idle time", utils.LogAttr("idleFor", idleFor.Load()), utils.LogAttr("maxIdleTime", MaxIdleTimeInSeconds), utils.LogAttr("now", time.Now().Unix())) + idleDuration := idleFor.Load() + MaxIdleTimeInSeconds + if time.Now().Unix() > idleDuration { + websocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("Connection idle for too long, closing connection. Idle time: %d", idleDuration))) return - case <-time.After(WebSocketBanDuration): // just continue } } - requestsPerSecond.Store(0) + if cwm.headerRateLimit > 0 || WebSocketRateLimit > 0 { + // check if rate limit reached, and ban is required + currentRequestsPerSecondLoad := requestsPerSecond.Load() + if WebSocketBanDuration > 0 && (currentRequestsPerSecondLoad > cwm.headerRateLimit || currentRequestsPerSecondLoad > uint64(WebSocketRateLimit)) { + // wait the ban duration before resetting the store. + select { + case <-webSocketCtx.Done(): + return + case <-time.After(WebSocketBanDuration): // just continue + } + } + requestsPerSecond.Store(0) + } } } }() for { + idleFor.Store(time.Now().Unix()) startTime := time.Now() msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int @@ -185,7 +210,9 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } // Check rate limit is met - if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) { + currentRequestsPerSecond := requestsPerSecond.Add(1) + if (cwm.headerRateLimit > 0 && currentRequestsPerSecond > cwm.headerRateLimit) || + (WebSocketRateLimit > 0 && currentRequestsPerSecond > uint64(WebSocketRateLimit)) { rateLimitResponse, err := cwm.handleRateLimitReached(msg) if err == nil { websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse} @@ -313,6 +340,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { ) for subscriptionMsgReply := range subscriptionMsgsChan { + idleFor.Store(time.Now().Unix()) websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: outputFormatter(subscriptionMsgReply.Data)} } diff --git a/protocol/chainlib/consumer_websocket_manager_test.go b/protocol/chainlib/consumer_websocket_manager_test.go new file mode 100644 index 0000000000..c501a663c9 --- /dev/null +++ b/protocol/chainlib/consumer_websocket_manager_test.go @@ -0,0 +1,117 @@ +package chainlib + +import ( + "net" + "testing" + + "github.com/golang/mock/gomock" + "github.com/lavanet/lava/v4/protocol/common" + "github.com/stretchr/testify/assert" +) + +func TestWebsocketConnectionLimiter(t *testing.T) { + tests := []struct { + name string + connectionLimit int64 + headerLimit int64 + ipAddress string + forwardedIP string + userAgent string + expectSuccess []bool + }{ + { + name: "Single connection allowed", + connectionLimit: 1, + headerLimit: 0, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true}, + }, + { + name: "Single connection allowed", + connectionLimit: 1, + headerLimit: 0, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true, false}, + }, + { + name: "Multiple connections allowed", + connectionLimit: 2, + headerLimit: 0, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true, true}, + }, + { + name: "Multiple connections allowed", + connectionLimit: 2, + headerLimit: 0, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true, true, false}, + }, + { + name: "Header limit overrides global limit succeed", + connectionLimit: 3, + headerLimit: 2, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true, true}, + }, + { + name: "Header limit overrides global limit fail", + connectionLimit: 0, + headerLimit: 2, + ipAddress: "127.0.0.1", + forwardedIP: "", + userAgent: "test-agent", + expectSuccess: []bool{true, true, false}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create a new connection limiter + wcl := &WebsocketConnectionLimiter{ + ipToNumberOfActiveConnections: make(map[string]int64), + } + + // Set global connection limit for testing + MaximumNumberOfParallelWebsocketConnectionsPerIp = tt.connectionLimit + + // Create mock websocket connection + mockWsConn := NewMockWebsocketConnection(ctrl) + + // Set up expectations + mockWsConn.EXPECT().Locals(WebSocketOpenConnectionsLimitHeader).Return(tt.headerLimit).AnyTimes() + mockWsConn.EXPECT().Locals(common.IP_FORWARDING_HEADER_NAME).Return(tt.forwardedIP).AnyTimes() + mockWsConn.EXPECT().Locals("User-Agent").Return(tt.userAgent).AnyTimes() + mockWsConn.EXPECT().RemoteAddr().Return(&net.TCPAddr{ + IP: net.ParseIP(tt.ipAddress), + Port: 8080, + }).AnyTimes() + mockWsConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Do(func(messageType int, data []byte) { + t.Logf("WriteMessage called with messageType: %d, data: %s", messageType, string(data)) + }).AnyTimes() + + // Test the connection + for _, expectSuccess := range tt.expectSuccess { + canOpen, _ := wcl.CanOpenConnection(mockWsConn) + if expectSuccess { + assert.True(t, canOpen, "Expected connection to be allowed") + } else { + assert.False(t, canOpen, "Expected connection to be denied") + } + } + }) + } +} diff --git a/protocol/chainlib/jsonRPC.go b/protocol/chainlib/jsonRPC.go index a8c18e2db7..de48263375 100644 --- a/protocol/chainlib/jsonRPC.go +++ b/protocol/chainlib/jsonRPC.go @@ -28,7 +28,11 @@ import ( spectypes "github.com/lavanet/lava/v4/x/spec/types" ) -const SEP = "&" +const ( + SEP = "&" +) + +var MaximumNumberOfParallelWebsocketConnectionsPerIp int64 = 0 type JsonRPCChainParser struct { BaseChainParser @@ -321,6 +325,7 @@ type JsonRPCChainListener struct { refererData *RefererData consumerWsSubscriptionManager *ConsumerWSSubscriptionManager listeningAddress string + websocketConnectionLimiter *WebsocketConnectionLimiter } // NewJrpcChainListener creates a new instance of JsonRPCChainListener @@ -338,6 +343,7 @@ func NewJrpcChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEn logger: rpcConsumerLogs, refererData: refererData, consumerWsSubscriptionManager: consumerWsSubscriptionManager, + websocketConnectionLimiter: &WebsocketConnectionLimiter{ipToNumberOfActiveConnections: make(map[string]int64)}, } return chainListener @@ -354,6 +360,8 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con app := createAndSetupBaseAppListener(cmdFlags, apil.endpoint.HealthCheckPath, apil.healthReporter) app.Use("/ws", func(c *fiber.Ctx) error { + apil.websocketConnectionLimiter.HandleFiberRateLimitFlags(c) + // IsWebSocketUpgrade returns true if the client // requested upgrade to the WebSocket protocol. if websocket.IsWebSocketUpgrade(c) { @@ -367,6 +375,17 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con apiInterface := apil.endpoint.ApiInterface webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { + canOpenConnection, decreaseIpConnection := apil.websocketConnectionLimiter.CanOpenConnection(websocketConn) + defer decreaseIpConnection() + if !canOpenConnection { + return + } + rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader) + rateLimit, assertionSuccessful := rateLimitInf.(int64) + if !assertionSuccessful || rateLimit < 0 { + rateLimit = 0 + } + utils.LavaFormatDebug("jsonrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) defer utils.LavaFormatDebug("jsonrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) @@ -383,6 +402,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con RelaySender: apil.relaySender, ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager, WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10), + headerRateLimit: uint64(rateLimit), }) consumerWebsocketManager.ListenToMessages() diff --git a/protocol/chainlib/mock_websocket.go b/protocol/chainlib/mock_websocket.go new file mode 100644 index 0000000000..a87f28f0ab --- /dev/null +++ b/protocol/chainlib/mock_websocket.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: consumer_websocket_manager_test.go + +// Package chainlib is a generated GoMock package. +package chainlib + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockWebsocketConnection is a mock of WebsocketConnection interface. +type MockWebsocketConnection struct { + ctrl *gomock.Controller + recorder *MockWebsocketConnectionMockRecorder +} + +// MockWebsocketConnectionMockRecorder is the mock recorder for MockWebsocketConnection. +type MockWebsocketConnectionMockRecorder struct { + mock *MockWebsocketConnection +} + +// NewMockWebsocketConnection creates a new mock instance. +func NewMockWebsocketConnection(ctrl *gomock.Controller) *MockWebsocketConnection { + mock := &MockWebsocketConnection{ctrl: ctrl} + mock.recorder = &MockWebsocketConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebsocketConnection) EXPECT() *MockWebsocketConnectionMockRecorder { + return m.recorder +} + +// Locals mocks base method. +func (m *MockWebsocketConnection) Locals(key string) interface{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Locals", key) + ret0, _ := ret[0].(interface{}) + return ret0 +} + +// Locals indicates an expected call of Locals. +func (mr *MockWebsocketConnectionMockRecorder) Locals(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Locals", reflect.TypeOf((*MockWebsocketConnection)(nil).Locals), key) +} + +// RemoteAddr mocks base method. +func (m *MockWebsocketConnection) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockWebsocketConnectionMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockWebsocketConnection)(nil).RemoteAddr)) +} + +// WriteMessage mocks base method. +func (m *MockWebsocketConnection) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockWebsocketConnectionMockRecorder) WriteMessage(messageType, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockWebsocketConnection)(nil).WriteMessage), messageType, data) +} diff --git a/protocol/chainlib/tendermintRPC.go b/protocol/chainlib/tendermintRPC.go index ad35dbc517..3bb8867095 100644 --- a/protocol/chainlib/tendermintRPC.go +++ b/protocol/chainlib/tendermintRPC.go @@ -353,6 +353,7 @@ type TendermintRpcChainListener struct { refererData *RefererData consumerWsSubscriptionManager *ConsumerWSSubscriptionManager listeningAddress string + websocketConnectionLimiter *WebsocketConnectionLimiter } // NewTendermintRpcChainListener creates a new instance of TendermintRpcChainListener @@ -370,6 +371,7 @@ func NewTendermintRpcChainListener(ctx context.Context, listenEndpoint *lavasess logger: rpcConsumerLogs, refererData: refererData, consumerWsSubscriptionManager: consumerWsSubscriptionManager, + websocketConnectionLimiter: &WebsocketConnectionLimiter{ipToNumberOfActiveConnections: make(map[string]int64)}, } return chainListener @@ -388,6 +390,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm apiInterface := apil.endpoint.ApiInterface app.Use("/ws", func(c *fiber.Ctx) error { + apil.websocketConnectionLimiter.HandleFiberRateLimitFlags(c) // IsWebSocketUpgrade returns true if the client // requested upgrade to the WebSocket protocol. if websocket.IsWebSocketUpgrade(c) { @@ -397,6 +400,18 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm return fiber.ErrUpgradeRequired }) webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { + canOpenConnection, decreaseIpConnection := apil.websocketConnectionLimiter.CanOpenConnection(websocketConn) + defer decreaseIpConnection() + if !canOpenConnection { + return + } + + rateLimitInf := websocketConn.Locals(WebSocketRateLimitHeader) + rateLimit, assertionSuccessful := rateLimitInf.(int64) + if !assertionSuccessful || rateLimit < 0 { + rateLimit = 0 + } + utils.LavaFormatDebug("tendermintrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) defer utils.LavaFormatDebug("tendermintrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) @@ -413,6 +428,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm RelaySender: apil.relaySender, ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager, WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10), + headerRateLimit: uint64(rateLimit), }) consumerWebsocketManager.ListenToMessages() diff --git a/protocol/chainlib/websocket_connection_limiter.go b/protocol/chainlib/websocket_connection_limiter.go new file mode 100644 index 0000000000..9be6a27e73 --- /dev/null +++ b/protocol/chainlib/websocket_connection_limiter.go @@ -0,0 +1,150 @@ +package chainlib + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/websocket/v2" + "github.com/lavanet/lava/v4/protocol/common" + "github.com/lavanet/lava/v4/utils" +) + +// WebsocketConnection defines the interface for websocket connections +type WebsocketConnection interface { + // Add only the methods you need to mock + RemoteAddr() net.Addr + Locals(key string) interface{} + WriteMessage(messageType int, data []byte) error +} + +// Will limit a certain amount of connections per IP +type WebsocketConnectionLimiter struct { + ipToNumberOfActiveConnections map[string]int64 + lock sync.RWMutex +} + +func (wcl *WebsocketConnectionLimiter) HandleFiberRateLimitFlags(c *fiber.Ctx) { + userAgent := c.Get(fiber.HeaderUserAgent) + // Store the User-Agent in locals for later use + c.Locals(fiber.HeaderUserAgent, userAgent) + + forwardedFor := c.Get(common.IP_FORWARDING_HEADER_NAME) + if forwardedFor == "" { + // If not present, fallback to c.IP() which retrieves the real IP + forwardedFor = c.IP() + } + // Store the X-Forwarded-For or real IP in the context + c.Locals(common.IP_FORWARDING_HEADER_NAME, forwardedFor) + + rateLimitString := c.Get(WebSocketRateLimitHeader) + rateLimit, err := strconv.ParseInt(rateLimitString, 10, 64) + if err != nil { + rateLimit = 0 + } + c.Locals(WebSocketRateLimitHeader, rateLimit) + + connectionLimitString := c.Get(WebSocketOpenConnectionsLimitHeader) + connectionLimit, err := strconv.ParseInt(connectionLimitString, 10, 64) + if err != nil { + connectionLimit = 0 + } + c.Locals(WebSocketOpenConnectionsLimitHeader, connectionLimit) +} + +func (wcl *WebsocketConnectionLimiter) getConnectionLimit(websocketConn WebsocketConnection) int64 { + connectionLimitHeaderValue, ok := websocketConn.Locals(WebSocketOpenConnectionsLimitHeader).(int64) + if !ok || connectionLimitHeaderValue < 0 { + connectionLimitHeaderValue = 0 + } + // Do not allow header to overwrite flag value if its set. + if MaximumNumberOfParallelWebsocketConnectionsPerIp > 0 && connectionLimitHeaderValue > MaximumNumberOfParallelWebsocketConnectionsPerIp { + return MaximumNumberOfParallelWebsocketConnectionsPerIp + } + // Return the larger of the global limit (if set) or the header value + return utils.Max(MaximumNumberOfParallelWebsocketConnectionsPerIp, connectionLimitHeaderValue) +} + +func (wcl *WebsocketConnectionLimiter) CanOpenConnection(websocketConn WebsocketConnection) (bool, func()) { + // Check which connection limit is higher and use that. + connectionLimit := wcl.getConnectionLimit(websocketConn) + decreaseIpConnectionCallback := func() {} + if connectionLimit > 0 { // 0 is disabled. + ipForwardedInterface := websocketConn.Locals(common.IP_FORWARDING_HEADER_NAME) + ipForwarded, assertionSuccessful := ipForwardedInterface.(string) + if !assertionSuccessful { + ipForwarded = "" + } + ip := websocketConn.RemoteAddr().String() + userAgent, assertionSuccessful := websocketConn.Locals("User-Agent").(string) + if !assertionSuccessful { + userAgent = "" + } + key := wcl.getKey(ip, ipForwarded, userAgent) + + // Check current connections before incrementing + currentConnections := wcl.getCurrentAmountOfConnections(key) + // If already at or exceeding limit, deny the connection + if currentConnections >= connectionLimit { + websocketConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("Too Many Open Connections, limited to %d", connectionLimit))) + return false, decreaseIpConnectionCallback + } + // If under limit, increment and return cleanup function + wcl.addIpConnection(key) + decreaseIpConnectionCallback = func() { wcl.decreaseIpConnection(key) } + } + return true, decreaseIpConnectionCallback +} + +func (wcl *WebsocketConnectionLimiter) getCurrentAmountOfConnections(key string) int64 { + wcl.lock.RLock() + defer wcl.lock.RUnlock() + return wcl.ipToNumberOfActiveConnections[key] +} + +func (wcl *WebsocketConnectionLimiter) addIpConnection(key string) { + wcl.lock.Lock() + defer wcl.lock.Unlock() + // wether it exists or not we add 1. + wcl.ipToNumberOfActiveConnections[key] += 1 +} + +func (wcl *WebsocketConnectionLimiter) decreaseIpConnection(key string) { + wcl.lock.Lock() + defer wcl.lock.Unlock() + // it must exist as we dont get here without adding it prior + wcl.ipToNumberOfActiveConnections[key] -= 1 + if wcl.ipToNumberOfActiveConnections[key] == 0 { + delete(wcl.ipToNumberOfActiveConnections, key) + } +} + +func (wcl *WebsocketConnectionLimiter) getKey(ip string, forwardedIp string, userAgent string) string { + returnedKey := "" + ipOriginal := net.ParseIP(ip) + if ipOriginal != nil { + returnedKey = ipOriginal.String() + } else { + ipPart, _, err := net.SplitHostPort(ip) + if err == nil { + returnedKey = ipPart + } + } + ips := strings.Split(forwardedIp, ",") + for _, ipStr := range ips { + ipParsed := net.ParseIP(strings.TrimSpace(ipStr)) + if ipParsed != nil { + returnedKey += SEP + ipParsed.String() + } else { + ipPart, _, err := net.SplitHostPort(ipStr) + if err == nil { + returnedKey += SEP + ipPart + } + } + } + returnedKey += SEP + userAgent + return returnedKey +} diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index 74c8211b37..05e6259ec6 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -47,6 +47,8 @@ const ( // websocket flags RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection" BanDurationForWebsocketRateLimitExceededFlag = "ban-duration-for-websocket-rate-limit-exceeded" + LimitParallelWebsocketConnectionsPerIpFlag = "limit-parallel-websocket-connections-per-ip" + LimitWebsocketIdleTimeFlag = "limit-websocket-connection-idle-time" RateLimitRequestPerSecondFlag = "rate-limit-requests-per-second" ) diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index dc92b3e117..c69a7a069f 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -766,6 +766,8 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 cmdRPCConsumer.Flags().DurationVar(&metrics.OptimizerQosServerPushInterval, common.OptimizerQosServerPushIntervalFlag, time.Minute*5, "interval to push optimizer qos reports") cmdRPCConsumer.Flags().DurationVar(&metrics.OptimizerQosServerSamplingInterval, common.OptimizerQosServerSamplingIntervalFlag, time.Second*1, "interval to sample optimizer qos reports") cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited") + cmdRPCConsumer.Flags().Int64Var(&chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, common.LimitParallelWebsocketConnectionsPerIpFlag, chainlib.MaximumNumberOfParallelWebsocketConnectionsPerIp, "limit number of parallel connections to websocket, per ip, default is unlimited (0)") + cmdRPCConsumer.Flags().Int64Var(&chainlib.MaxIdleTimeInSeconds, common.LimitWebsocketIdleTimeFlag, chainlib.MaxIdleTimeInSeconds, "limit the idle time in seconds for a websocket connection, default is 20 minutes ( 20 * 60 )") cmdRPCConsumer.Flags().DurationVar(&chainlib.WebSocketBanDuration, common.BanDurationForWebsocketRateLimitExceededFlag, chainlib.WebSocketBanDuration, "once websocket rate limit is reached, user will be banned Xfor a duration, default no ban") cmdRPCConsumer.Flags().Bool(LavaOverLavaBackupFlagName, true, "enable lava over lava backup to regular rpc calls") common.AddRollingLogConfig(cmdRPCConsumer) diff --git a/scripts/pre_setups/init_lava_only_with_node.sh b/scripts/pre_setups/init_lava_only_with_node.sh index c47abc0ada..61f814b263 100755 --- a/scripts/pre_setups/init_lava_only_with_node.sh +++ b/scripts/pre_setups/init_lava_only_with_node.sh @@ -57,7 +57,7 @@ wait_next_block screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer \ 127.0.0.1:3360 LAV1 rest 127.0.0.1:3361 LAV1 tendermintrpc 127.0.0.1:3362 LAV1 grpc \ -$EXTRA_PORTAL_FLAGS --geolocation 1 --optimizer-qos-listen --log_level trace --from user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 +$EXTRA_PORTAL_FLAGS --geolocation 1 --optimizer-qos-listen --log_level trace --from user1 --chain-id lava --add-api-method-metrics --limit-parallel-websocket-connections-per-ip 1 --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 echo "--- setting up screens done ---" screen -ls \ No newline at end of file diff --git a/x/protocol/types/params.go b/x/protocol/types/params.go index 2848f2aeb3..f10f41d29e 100644 --- a/x/protocol/types/params.go +++ b/x/protocol/types/params.go @@ -12,7 +12,7 @@ import ( var _ paramtypes.ParamSet = (*Params)(nil) const ( - TARGET_VERSION = "4.1.4" + TARGET_VERSION = "4.1.6" MIN_VERSION = "3.1.0" )