Skip to content

Commit

Permalink
fix(rpc): Added Timeout for RPC handler (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
zale144 authored Apr 16, 2024
1 parent c4379ff commit cefca7a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 44 deletions.
70 changes: 32 additions & 38 deletions rpc/json/service.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
package json

import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strconv"
"time"

"cosmossdk.io/errors"
"github.com/gorilla/rpc/v2/json2"
"github.com/tendermint/tendermint/libs/pubsub"
tmquery "github.com/tendermint/tendermint/libs/pubsub/query"
rpcclient "github.com/tendermint/tendermint/rpc/client"
ctypes "github.com/tendermint/tendermint/rpc/core/types"
rpctypes "github.com/tendermint/tendermint/rpc/jsonrpc/types"
Expand All @@ -21,9 +17,22 @@ import (
"github.com/dymensionxyz/dymint/types"
)

const (
// DefaultSubscribeBufferSize is the default buffer size for a subscription.
defaultSubscribeBufferSize = 100
)

// GetHTTPHandler returns handler configured to serve Tendermint-compatible RPC.
func GetHTTPHandler(l *client.Client, logger types.Logger) (http.Handler, error) {
return newHandler(newService(l, logger), json2.NewCodec(), logger), nil
func GetHTTPHandler(l *client.Client, logger types.Logger, opts ...option) (http.Handler, error) {
return newHandler(newService(l, logger, opts...), json2.NewCodec(), logger), nil
}

type option func(*service)

func WithSubscribeBufferSize(size int) option {
return func(s *service) {
s.subscribeBufferSize = size
}
}

type method struct {
Expand All @@ -48,12 +57,15 @@ type service struct {
client *client.Client
methods map[string]*method
logger types.Logger

subscribeBufferSize int
}

func newService(c *client.Client, l types.Logger) *service {
func newService(c *client.Client, l types.Logger, opts ...option) *service {
s := service{
client: c,
logger: l,
client: c,
logger: l,
subscribeBufferSize: defaultSubscribeBufferSize,
}
s.methods = map[string]*method{
"subscribe": newMethod(s.Subscribe),
Expand Down Expand Up @@ -86,6 +98,11 @@ func newService(c *client.Client, l types.Logger) *service {
"abci_info": newMethod(s.ABCIInfo),
"broadcast_evidence": newMethod(s.BroadcastEvidence),
}

for _, opt := range opts {
opt(&s)
}

return &s
}

Expand All @@ -96,37 +113,25 @@ func (s *service) Subscribe(req *http.Request, args *subscribeArgs, wsConn *wsCo
return nil, errors.Wrap(err, "subscription not allowed")
}

q, err := tmquery.New(args.Query)
if err != nil {
return nil, fmt.Errorf("parse query: %w", err)
}

s.logger.Debug("subscribe to query", "remote", addr, "query", args.Query)

// TODO(tzdybal): extract consts or configs
const SubscribeTimeout = 5 * time.Second
const subBufferSize = 100
ctx, cancel := context.WithTimeout(req.Context(), SubscribeTimeout)
defer cancel()

sub, err := s.client.EventBus.Subscribe(ctx, addr, q, subBufferSize)
out, err := s.client.Subscribe(req.Context(), addr, args.Query, s.subscribeBufferSize)
if err != nil {
return nil, fmt.Errorf("subscribe: %w", err)
}
go func(subscriptionID []byte) {
for {
select {
case msg := <-sub.Out():
case msg := <-out:
// build the base response
resultEvent := &ctypes.ResultEvent{Query: args.Query, Data: msg.Data(), Events: msg.Events()}
var resp rpctypes.RPCResponse
// Check if subscriptionID is string or int and generate the rest of the response accordingly
subscriptionIDInt, err := strconv.Atoi(string(subscriptionID))
if err != nil {
s.logger.Info("Failed to convert subscriptionID to int")
resp = rpctypes.NewRPCSuccessResponse(rpctypes.JSONRPCStringID(subscriptionID), resultEvent)
resp = rpctypes.NewRPCSuccessResponse(rpctypes.JSONRPCStringID(subscriptionID), msg)
} else {
resp = rpctypes.NewRPCSuccessResponse(rpctypes.JSONRPCIntID(subscriptionIDInt), resultEvent)
resp = rpctypes.NewRPCSuccessResponse(rpctypes.JSONRPCIntID(subscriptionIDInt), msg)
}
// Marshal response to JSON and send it to the websocket queue
jsonBytes, err := json.MarshalIndent(resp, "", " ")
Expand All @@ -137,17 +142,6 @@ func (s *service) Subscribe(req *http.Request, args *subscribeArgs, wsConn *wsCo
if wsConn != nil {
wsConn.queue <- jsonBytes
}
case <-sub.Cancelled():
if sub.Err() != pubsub.ErrUnsubscribed {
var reason string
if sub.Err() == nil {
reason = "unknown failure"
} else {
reason = sub.Err().Error()
}
s.logger.Error("subscription was cancelled", "reason", reason)
}
return
}
}
}(subscriptionID)
Expand All @@ -157,7 +151,7 @@ func (s *service) Subscribe(req *http.Request, args *subscribeArgs, wsConn *wsCo

func (s *service) Unsubscribe(req *http.Request, args *unsubscribeArgs) (*emptyResult, error) {
s.logger.Debug("unsubscribe from query", "remote", req.RemoteAddr, "query", args.Query)
err := s.client.Unsubscribe(context.Background(), req.RemoteAddr, args.Query)
err := s.client.Unsubscribe(req.Context(), req.RemoteAddr, args.Query)
if err != nil {
return nil, fmt.Errorf("unsubscribe: %w", err)
}
Expand All @@ -166,7 +160,7 @@ func (s *service) Unsubscribe(req *http.Request, args *unsubscribeArgs) (*emptyR

func (s *service) UnsubscribeAll(req *http.Request, args *unsubscribeAllArgs) (*emptyResult, error) {
s.logger.Debug("unsubscribe from all queries", "remote", req.RemoteAddr)
err := s.client.UnsubscribeAll(context.Background(), req.RemoteAddr)
err := s.client.UnsubscribeAll(req.Context(), req.RemoteAddr)
if err != nil {
return nil, fmt.Errorf("unsubscribe all: %w", err)
}
Expand Down
33 changes: 27 additions & 6 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@ type Server struct {
node *node.Node
healthStatus sharedtypes.HealthStatus
listener net.Listener
ctx context.Context
timeout time.Duration

server http.Server
}

const (
// DefaultServerTimeout is the default global timeout for the server.
defaultServerTimeout = 15 * time.Second
)

// Option is a function that configures the Server.
type Option func(*Server)

Expand All @@ -49,6 +54,13 @@ func WithListener(listener net.Listener) Option {
}
}

// WithTimeout is an option that sets the global timeout for the server.
func WithTimeout(timeout time.Duration) Option {
return func(d *Server) {
d.timeout = timeout
}
}

// NewServer creates new instance of Server with given configuration.
func NewServer(node *node.Node, config *config.RPCConfig, logger log.Logger, options ...Option) *Server {
srv := &Server{
Expand All @@ -59,7 +71,7 @@ func NewServer(node *node.Node, config *config.RPCConfig, logger log.Logger, opt
IsHealthy: true,
Error: nil,
},
ctx: context.Background(),
timeout: defaultServerTimeout,
}
srv.BaseService = service.NewBaseService(logger, "RPC", srv)

Expand Down Expand Up @@ -89,16 +101,23 @@ func (s *Server) OnStart() error {

// OnStop is called when Server is stopped (see service.BaseService for details).
func (s *Server) OnStop() {
ctx, cancel := context.WithTimeout(s.ctx, 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
s.Logger.Error("while shuting down RPC server", "error", err)
s.Logger.Error("while shutting down RPC server", "error", err)
}
}

// EventListener registers events to callbacks.
func (s *Server) eventListener() {
go utils.SubscribeAndHandleEvents(s.ctx, s.PubSubServer(), "RPCNodeHealthStatusHandler", events.EventQueryHealthStatus, s.healthStatusEventCallback, s.Logger)
go utils.SubscribeAndHandleEvents(
context.Background(),
s.PubSubServer(),
"RPCNodeHealthStatusHandler",
events.EventQueryHealthStatus,
s.healthStatusEventCallback,
s.Logger,
)
}

// healthStatusEventCallback is a callback function that handles health status events.
Expand Down Expand Up @@ -162,10 +181,12 @@ func (s *Server) startRPC() error {
)
middlewareClient := middleware.NewClient(*reg, s.Logger.With("module", "rpc/middleware"))
handler = middlewareClient.Handle(handler)
// Set a global timeout
handlerWithTimeout := http.TimeoutHandler(handler, s.timeout, "Server Timeout")

// Start HTTP server
go func() {
err := s.serve(listener, handler)
err := s.serve(listener, handlerWithTimeout)
if !errors.Is(err, http.ErrServerClosed) {
s.Logger.Error("while serving HTTP", "error", err)
}
Expand Down

0 comments on commit cefca7a

Please sign in to comment.