diff --git a/core/scripts/gateway/connector/run_connector.go b/core/scripts/gateway/connector/run_connector.go index d3908e5821c..248162ef822 100644 --- a/core/scripts/gateway/connector/run_connector.go +++ b/core/scripts/gateway/connector/run_connector.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + "github.com/smartcontractkit/chainlink/v2/core/utils" ) // Script to run Connector outside of the core node. @@ -28,7 +29,7 @@ type client struct { lggr logger.Logger } -func (h *client) HandleGatewayMessage(gatewayId string, msg *api.Message) { +func (h *client) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { h.lggr.Infof("received message from gateway %s. Echoing back.", gatewayId) h.connector.SendToGateway(context.Background(), gatewayId, msg) } @@ -65,7 +66,7 @@ func main() { sampleKey, _ := crypto.HexToECDSA("cd47d3fafdbd652dd2b66c6104fa79b372c13cb01f4a4fbfc36107cce913ac1d") lggr, _ := logger.NewLogger() client := &client{privateKey: sampleKey, lggr: lggr} - connector, _ := connector.NewGatewayConnector(&cfg, client, client, common.NewRealClock(), lggr) + connector, _ := connector.NewGatewayConnector(&cfg, client, client, utils.NewRealClock(), lggr) client.connector = connector ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt) diff --git a/core/scripts/gateway/run_gateway.go b/core/scripts/gateway/run_gateway.go index f92b45a605b..1f0230bb8ad 100644 --- a/core/scripts/gateway/run_gateway.go +++ b/core/scripts/gateway/run_gateway.go @@ -48,7 +48,7 @@ func main() { lggr, _ := logger.NewLogger() - handlerFactory := gateway.NewHandlerFactory(lggr) + handlerFactory := gateway.NewHandlerFactory(nil, lggr) gw, err := gateway.NewGatewayFromConfig(&cfg, handlerFactory, lggr) if err != nil { fmt.Println("error creating Gateway object:", err) diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index b80d1c82871..7e6ddbdd980 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -56,7 +56,8 @@ func (h *functionsConnectorHandler) Sign(data ...[]byte) ([]byte, error) { return common.SignData(h.signerKey, data...) } -func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) { +func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { + body := &msg.Body fromAddr := ethCommon.HexToAddress(body.Sender) if !h.allowlist.Allow(fromAddr) { h.lggr.Errorw("allowlist prevented the request from this address", "id", gatewayId, "address", fromAddr) diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index f6dacf144e3..42713413681 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -46,7 +46,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { signature, err := handler.Sign([]byte("test")) require.NoError(t, err) - signer, err := common.ValidateSignature(signature, []byte("test")) + signer, err := common.ExtractSigner(signature, []byte("test")) require.NoError(t, err) require.Equal(t, addr.Bytes(), signer) }) @@ -77,7 +77,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) t.Run("orm error", func(t *testing.T) { storage.On("List", ctx, addr).Return(nil, errors.New("boom")).Once() @@ -89,12 +89,12 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("not allowed", func(t *testing.T) { allowlist.On("Allow", addr).Return(false).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) }) @@ -133,7 +133,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) t.Run("orm error", func(t *testing.T) { storage.On("Put", ctx, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("boom")).Once() @@ -145,7 +145,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("missing signature", func(t *testing.T) { @@ -160,7 +160,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("malformed request", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) }) @@ -191,7 +191,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { require.NoError(t, msg.Sign(privateKey)) allowlist.On("Allow", addr).Return(true).Once() - handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg.Body) + handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg) }) }) } diff --git a/core/services/gateway/api/jsonrpccodec_test.go b/core/services/gateway/api/jsonrpccodec_test.go index 656949ad00f..a92bccf36a1 100644 --- a/core/services/gateway/api/jsonrpccodec_test.go +++ b/core/services/gateway/api/jsonrpccodec_test.go @@ -44,7 +44,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) { var msg api.Message msg.Body = api.MessageBody{ MessageId: "aA-bB", - Sender: "0x1234", + Receiver: "0x1234", Method: "upload", } codec := api.JsonRPCCodec{} @@ -54,7 +54,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) { decoded, err := codec.DecodeRequest(bytes) require.NoError(t, err) require.Equal(t, "aA-bB", decoded.Body.MessageId) - require.Equal(t, "0x1234", decoded.Body.Sender) + require.Equal(t, "0x1234", decoded.Body.Receiver) require.Equal(t, "upload", decoded.Body.Method) } @@ -76,7 +76,7 @@ func TestJsonRPCResponse_Encode(t *testing.T) { var msg api.Message msg.Body = api.MessageBody{ MessageId: "aA-bB", - Sender: "0x1234", + Receiver: "0x1234", Method: "upload", } codec := api.JsonRPCCodec{} @@ -86,6 +86,6 @@ func TestJsonRPCResponse_Encode(t *testing.T) { decoded, err := codec.DecodeResponse(bytes) require.NoError(t, err) require.Equal(t, "aA-bB", decoded.Body.MessageId) - require.Equal(t, "0x1234", decoded.Body.Sender) + require.Equal(t, "0x1234", decoded.Body.Receiver) require.Equal(t, "upload", decoded.Body.Method) } diff --git a/core/services/gateway/api/message.go b/core/services/gateway/api/message.go index 56f7162d368..defd3407e66 100644 --- a/core/services/gateway/api/message.go +++ b/core/services/gateway/api/message.go @@ -4,7 +4,6 @@ import ( "crypto/ecdsa" "encoding/json" "errors" - "strings" gw_common "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -16,8 +15,7 @@ const ( MessageIdMaxLen = 128 MessageMethodMaxLen = 64 MessageDonIdMaxLen = 64 - MessageSenderLen = 20 - MessageSenderHexEncodedLen = 2 + 2*MessageSenderLen + MessageReceiverLen = 2 + 2*20 ) /* @@ -36,10 +34,12 @@ type MessageBody struct { MessageId string `json:"message_id"` Method string `json:"method"` DonId string `json:"don_id"` - Sender string `json:"sender"` - + Receiver string `json:"receiver"` // Service-specific payload, decoded inside the Handler. Payload json.RawMessage `json:"payload,omitempty"` + + // Fields only used locally for convenience. Not serialized. + Sender string `json:"-"` } func (m *Message) Validate() error { @@ -58,15 +58,14 @@ func (m *Message) Validate() error { if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen { return errors.New("invalid DON ID length") } - signerBytes, err := m.ValidateSignature() + if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen { + return errors.New("invalid Receiver length") + } + signerBytes, err := m.ExtractSigner() if err != nil { return err } - hexSigner := utils.StringToHex(string(signerBytes)) - if m.Body.Sender != "" && !strings.EqualFold(m.Body.Sender, hexSigner) { - return errors.New("sender doesn't match signer") - } - m.Body.Sender = hexSigner + m.Body.Sender = utils.StringToHex(string(signerBytes)) return nil } @@ -74,12 +73,13 @@ func (m *Message) Validate() error { // 1. MessageId aligned to 128 bytes // 2. Method aligned to 64 bytes // 3. DonId aligned to 64 bytes -// 4. Payload (before parsing) +// 4. Receiver (in hex) aligned to 42 bytes +// 5. Payload (raw bytes before parsing) func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error { - rawData, err := getRawMessageBody(&m.Body) - if err != nil { - return err + if m == nil { + return errors.New("nil message") } + rawData := getRawMessageBody(&m.Body) signature, err := gw_common.SignData(privateKey, rawData...) if err != nil { return err @@ -88,27 +88,26 @@ func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error { return nil } -func (m *Message) ValidateSignature() (signerAddress []byte, err error) { - rawData, err := getRawMessageBody(&m.Body) - if err != nil { - return +func (m *Message) ExtractSigner() (signerAddress []byte, err error) { + if m == nil { + return nil, errors.New("nil message") } + rawData := getRawMessageBody(&m.Body) signatureBytes, err := utils.TryParseHex(m.Signature) if err != nil { - return + return nil, err } - return gw_common.ValidateSignature(signatureBytes, rawData...) + return gw_common.ExtractSigner(signatureBytes, rawData...) } -func getRawMessageBody(msgBody *MessageBody) ([][]byte, error) { - if msgBody == nil { - return nil, errors.New("nil message") - } +func getRawMessageBody(msgBody *MessageBody) [][]byte { alignedMessageId := make([]byte, MessageIdMaxLen) copy(alignedMessageId, msgBody.MessageId) alignedMethod := make([]byte, MessageMethodMaxLen) copy(alignedMethod, msgBody.Method) alignedDonId := make([]byte, MessageDonIdMaxLen) copy(alignedDonId, msgBody.DonId) - return [][]byte{alignedMessageId, alignedMethod, alignedDonId, msgBody.Payload}, nil + alignedReceiver := make([]byte, MessageReceiverLen) + copy(alignedReceiver, msgBody.Receiver) + return [][]byte{alignedMessageId, alignedMethod, alignedDonId, alignedReceiver, msgBody.Payload} } diff --git a/core/services/gateway/api/message_test.go b/core/services/gateway/api/message_test.go index 9c89ca22f1f..a0835ea24bb 100644 --- a/core/services/gateway/api/message_test.go +++ b/core/services/gateway/api/message_test.go @@ -16,6 +16,7 @@ func TestMessage_Validate(t *testing.T) { MessageId: "abcd", Method: "request", DonId: "donA", + Receiver: "0x0000000000000000000000000000000000000000", Payload: []byte("datadata"), }, } @@ -42,6 +43,11 @@ func TestMessage_Validate(t *testing.T) { require.Error(t, msg.Validate()) msg.Body.Method = "request" + // incorrect receiver + msg.Body.Receiver = "blah" + require.Error(t, msg.Validate()) + msg.Body.Receiver = "0x0000000000000000000000000000000000000000" + // invalid signature msg.Signature = "0x00" require.Error(t, msg.Validate()) @@ -55,6 +61,7 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) { MessageId: "abcd", Method: "request", DonId: "donA", + Receiver: "0x33", Payload: []byte("datadata"), }, } @@ -67,7 +74,14 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) { require.NoError(t, err) require.Equal(t, api.MessageSignatureHexEncodedLen, len(msg.Signature)) - signer, err := msg.ValidateSignature() + // valid + signer, err := msg.ExtractSigner() require.NoError(t, err) require.True(t, bytes.Equal(address, signer)) + + // invalid + msg.Body.MessageId = "dbca" + signer, err = msg.ExtractSigner() + require.NoError(t, err) + require.False(t, bytes.Equal(address, signer)) } diff --git a/core/services/gateway/common/utils.go b/core/services/gateway/common/utils.go index 94e403243db..5c5ab037a7d 100644 --- a/core/services/gateway/common/utils.go +++ b/core/services/gateway/common/utils.go @@ -3,7 +3,6 @@ package common import ( "crypto/ecdsa" "encoding/binary" - "errors" "github.com/ethereum/go-ethereum/crypto" "golang.org/x/exp/slices" @@ -36,18 +35,11 @@ func SignData(privateKey *ecdsa.PrivateKey, data ...[]byte) ([]byte, error) { return crypto.Sign(hash.Bytes(), privateKey) } -func ValidateSignature(signature []byte, data ...[]byte) (signerAddress []byte, err error) { +func ExtractSigner(signature []byte, data ...[]byte) (signerAddress []byte, err error) { hash := crypto.Keccak256Hash(data...) - sigPublicKey, err := crypto.Ecrecover(hash.Bytes(), signature) + ecdsaPubKey, err := crypto.SigToPub(hash.Bytes(), signature) if err != nil { - return + return nil, err } - ecdsaPubKey, _ := crypto.UnmarshalPubkey(sigPublicKey) - signerAddress = crypto.PubkeyToAddress(*ecdsaPubKey).Bytes() - - signatureNoRecoverID := signature[:len(signature)-1] - if !crypto.VerifySignature(sigPublicKey, hash.Bytes(), signatureNoRecoverID) { - return nil, errors.New("invalid signature") - } - return + return crypto.PubkeyToAddress(*ecdsaPubKey).Bytes(), nil } diff --git a/core/services/gateway/common/utils_test.go b/core/services/gateway/common/utils_test.go index 796c76d6eb7..e223ba1e9fc 100644 --- a/core/services/gateway/common/utils_test.go +++ b/core/services/gateway/common/utils_test.go @@ -35,6 +35,7 @@ func TestUtils_BytesSignAndValidate(t *testing.T) { t.Parallel() data := []byte("data_data") + incorrectData := []byte("some_other_data") privateKey, err := crypto.GenerateKey() require.NoError(t, err) @@ -44,7 +45,17 @@ func TestUtils_BytesSignAndValidate(t *testing.T) { require.NoError(t, err) require.Equal(t, 65, len(signature)) - signer, err := common.ValidateSignature(signature, data) + // valid + signer, err := common.ExtractSigner(signature, data) require.NoError(t, err) require.True(t, bytes.Equal(signer, address)) + + // invalid + signer, err = common.ExtractSigner(signature, incorrectData) + require.NoError(t, err) + require.False(t, bytes.Equal(signer, address)) + + // invalid format + _, err = common.ExtractSigner([]byte{0xaa, 0xbb}, data) + require.Error(t, err) } diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index f15a7c48f16..664696fe614 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -183,7 +183,7 @@ func (m *connectionManager) parseAuthHeader(authHeader []byte) (nodeAddress stri return "", nil, errors.New("unable to parse auth header") } signature := authHeader[n-network.HandshakeSignatureLen:] - signer, err := common.ValidateSignature(signature, authHeader[:n-network.HandshakeSignatureLen]) + signer, err := common.ExtractSigner(signature, authHeader[:n-network.HandshakeSignatureLen]) nodeAddress = "0x" + hex.EncodeToString(signer) return } @@ -210,7 +210,7 @@ func (m *connectionManager) FinalizeHandshake(attemptId string, response []byte, if !ok { return errors.New("connection attempt not found") } - signer, err := common.ValidateSignature(response, attempt.challenge) + signer, err := common.ExtractSigner(response, attempt.challenge) if err != nil { return errors.New("invalid challenge response") } @@ -256,7 +256,11 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState case item := <-nodeState.conn.ReadChannel(): msg, err := m.codec.DecodeResponse(item.Data) if err != nil { - m.lggr.Error("parse error when reading from node ", nodeAddress, err) + m.lggr.Errorw("parse error when reading from node", "nodeAddress", nodeAddress, "err", err) + break + } + if err = msg.Validate(); err != nil { + m.lggr.Errorw("message validation error when reading from node", "nodeAddress", nodeAddress, "err", err) break } err = m.handler.HandleNodeMessage(ctx, msg, nodeAddress) diff --git a/core/services/gateway/connector/connector.go b/core/services/gateway/connector/connector.go index a0854797e79..1db791fda9b 100644 --- a/core/services/gateway/connector/connector.go +++ b/core/services/gateway/connector/connector.go @@ -40,7 +40,7 @@ type Signer interface { type GatewayConnectorHandler interface { job.ServiceCtx - HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) + HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) } type gatewayConnector struct { @@ -142,7 +142,7 @@ func (c *gatewayConnector) readLoop(gatewayState *gatewayState) { c.lggr.Errorw("failed to validate message signature", "id", gatewayState.config.Id, "error", err) break } - c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, &msg.Body) + c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg) } } } diff --git a/core/services/gateway/connector/mocks/gateway_connector_handler.go b/core/services/gateway/connector/mocks/gateway_connector_handler.go index 05d6c6ed97f..1db0f45fa12 100644 --- a/core/services/gateway/connector/mocks/gateway_connector_handler.go +++ b/core/services/gateway/connector/mocks/gateway_connector_handler.go @@ -29,9 +29,9 @@ func (_m *GatewayConnectorHandler) Close() error { return r0 } -// HandleGatewayMessage provides a mock function with given fields: ctx, gatewayId, body -func (_m *GatewayConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) { - _m.Called(ctx, gatewayId, body) +// HandleGatewayMessage provides a mock function with given fields: ctx, gatewayId, msg +func (_m *GatewayConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { + _m.Called(ctx, gatewayId, msg) } // Start provides a mock function with given fields: _a0 diff --git a/core/services/gateway/gateway.go b/core/services/gateway/gateway.go index dff03354e42..4003dc6c7dd 100644 --- a/core/services/gateway/gateway.go +++ b/core/services/gateway/gateway.go @@ -112,6 +112,9 @@ func (g *gateway) ProcessRequest(ctx context.Context, rawRequest []byte) (rawRes if err != nil { return newError(g.codec, "", api.UserMessageParseError, err.Error()) } + if err = msg.Validate(); err != nil { + return newError(g.codec, msg.Body.MessageId, api.UserMessageParseError, err.Error()) + } // find correct handler handler, ok := g.handlers[msg.Body.DonId] if !ok { diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index 1490a9fccd4..143ee4da0d0 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/crypto" "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -120,6 +121,24 @@ func newGatewayWithMockHandler(t *testing.T) (gateway.Gateway, *handler_mocks.Ha return gw, handler } +func newSignedRequest(t *testing.T, messageId string, method string, donID string, payload []byte) []byte { + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: messageId, + Method: method, + DonId: donID, + Payload: payload, + }, + } + privateKey, err := crypto.GenerateKey() + require.NoError(t, err) + require.NoError(t, msg.Sign(privateKey)) + codec := api.JsonRPCCodec{} + rawRequest, err := codec.EncodeRequest(msg) + require.NoError(t, err) + return rawRequest +} + func TestGateway_ProcessRequest_ParseError(t *testing.T) { t.Parallel() @@ -129,15 +148,22 @@ func TestGateway_ProcessRequest_ParseError(t *testing.T) { require.Equal(t, 400, statusCode) } -func TestGateway_ProcessRequest_IncorrectDonId(t *testing.T) { +func TestGateway_ProcessRequest_MessageValidationError(t *testing.T) { t.Parallel() gw, _ := newGatewayWithMockHandler(t) - response, statusCode := gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "id": "abc", "method": "request", "params": {}}`)) - requireJsonRPCError(t, response, "abc", -32602, "unsupported DON ID") + req := newSignedRequest(t, "abc", "request", "", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) + requireJsonRPCError(t, response, "abc", -32700, "invalid DON ID length") require.Equal(t, 400, statusCode) +} + +func TestGateway_ProcessRequest_IncorrectDonId(t *testing.T) { + t.Parallel() - response, statusCode = gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "id": "abc", "method": "request", "params": {"body": {"don_id": "bad"}}}`)) + gw, _ := newGatewayWithMockHandler(t) + req := newSignedRequest(t, "abc", "request", "unknownDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCError(t, response, "abc", -32602, "unsupported DON ID") require.Equal(t, 400, statusCode) } @@ -151,13 +177,14 @@ func TestGateway_ProcessRequest_HandlerResponse(t *testing.T) { callbackCh := args.Get(2).(chan<- handlers.UserCallbackPayload) // echo back to sender with attached payload msg.Body.Payload = []byte(`{"result":"OK"}`) + msg.Signature = "" callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} }) - response, statusCode := gw.ProcessRequest(testutils.Context(t), - []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCResult(t, response, "abcd", - `{"signature":"","body":{"message_id":"abcd","method":"request","don_id":"testDON","sender":"","payload":{"result":"OK"}}}`) + `{"signature":"","body":{"message_id":"abcd","method":"request","don_id":"testDON","receiver":"","payload":{"result":"OK"}}}`) require.Equal(t, 200, statusCode) } @@ -169,7 +196,8 @@ func TestGateway_ProcessRequest_HandlerTimeout(t *testing.T) { timeoutCtx, cancel := context.WithTimeout(testutils.Context(t), time.Duration(time.Millisecond*10)) defer cancel() - response, statusCode := gw.ProcessRequest(timeoutCtx, []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(timeoutCtx, req) requireJsonRPCError(t, response, "abcd", -32000, "handler timeout") require.Equal(t, 504, statusCode) } @@ -180,7 +208,8 @@ func TestGateway_ProcessRequest_HandlerError(t *testing.T) { gw, handler := newGatewayWithMockHandler(t) handler.On("HandleUserMessage", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failure")) - response, statusCode := gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCError(t, response, "abcd", -32000, "failure") require.Equal(t, 500, statusCode) } diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index d94a975b8ab..2817ffdce98 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -69,10 +69,6 @@ func ParseConfig(handlerConfig json.RawMessage) (*FunctionsHandlerConfig, error) } func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { - if err := msg.Validate(); err != nil { - h.lggr.Debugw("received invalid message", "err", err) - return err - } sender := common.HexToAddress(msg.Body.Sender) if h.allowlist != nil && !h.allowlist.Allow(sender) { h.lggr.Debugw("received a message from a non-allowlisted address", "sender", msg.Body.Sender) diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index c4d2f5eb999..3f412fdd011 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -8,17 +8,19 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" ) -func TestFunctionsHandler_Basic(t *testing.T) { +func TestFunctionsHandler_Minimal(t *testing.T) { t.Parallel() handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, logger.TestLogger(t)) require.NoError(t, err) - // nil message - err = handler.HandleUserMessage(testutils.Context(t), nil, nil) - require.Error(t, err) + // empty message + msg := &api.Message{} + err = handler.HandleUserMessage(testutils.Context(t), msg, nil) + require.NoError(t, err) } diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index a645a9c5c66..5cd2f3a61eb 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -169,9 +169,9 @@ func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, c return nil, errors.New("key for configured node address not found") } signerKey := enabledKeys[idx].ToEcdsaPrivKey() - nodeAddreess := enabledKeys[idx].ID() + nodeAddress := enabledKeys[idx].ID() - handler := functions.NewFunctionsConnectorHandler(nodeAddreess, signerKey, s4Storage, allowlist, lggr) + handler := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, lggr) connector, err := connector.NewGatewayConnector(gwcCfg, handler, handler, utils.NewRealClock(), lggr) if err != nil { return nil, err