From ee5ce3304d6b44cce13e460b814d43585aefb491 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Mon, 10 Jul 2023 08:05:27 -0700 Subject: [PATCH] [Gateway] Misc fixes to Message parsing, signing and validation (#9752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Set Message.Sender as a non-serializable field. 2. Add Message.Receiver field. Because messageID is scoped to a user, nodes need to set receiver to that user’s address to make sure responses can be uniquely attributed to users and also to prove to the user that it is in fact a response to their message. 3. Fix ValidateSignature(). I mistakenly assumed that previous logic was enough to fully validate a signature but we always need to validate the expected sender. Renamed the function to ExtractSigner() to avoid any confusion. 4. Perform message validation before calling handler methods, as suggested by @pinebit. Even though there are theoretical use cases that might not need it, I don’t think they are going to be implemented in the near future. 5. Changing HandleGatewayMessage() attribute back to full message instead of just body. This will be needed for a very real use case, where nodes want to share received messages among each other, within the OCR round. It’s also consistent with the Gateway Handler interface. 6. Fix broken scripts. --- .../gateway/connector/run_connector.go | 5 +- core/scripts/gateway/run_gateway.go | 2 +- core/services/functions/connector_handler.go | 3 +- .../functions/connector_handler_test.go | 18 +++---- .../services/gateway/api/jsonrpccodec_test.go | 8 +-- core/services/gateway/api/message.go | 51 +++++++++---------- core/services/gateway/api/message_test.go | 16 +++++- core/services/gateway/common/utils.go | 16 ++---- core/services/gateway/common/utils_test.go | 13 ++++- core/services/gateway/connectionmanager.go | 10 ++-- core/services/gateway/connector/connector.go | 4 +- .../mocks/gateway_connector_handler.go | 6 +-- core/services/gateway/gateway.go | 3 ++ core/services/gateway/gateway_test.go | 47 +++++++++++++---- .../handlers/functions/handler.functions.go | 4 -- .../functions/handler.functions_test.go | 10 ++-- .../services/ocr2/plugins/functions/plugin.go | 4 +- 17 files changed, 136 insertions(+), 84 deletions(-) diff --git a/core/scripts/gateway/connector/run_connector.go b/core/scripts/gateway/connector/run_connector.go index d3908e5821c..248162ef822 100644 --- a/core/scripts/gateway/connector/run_connector.go +++ b/core/scripts/gateway/connector/run_connector.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + "github.com/smartcontractkit/chainlink/v2/core/utils" ) // Script to run Connector outside of the core node. @@ -28,7 +29,7 @@ type client struct { lggr logger.Logger } -func (h *client) HandleGatewayMessage(gatewayId string, msg *api.Message) { +func (h *client) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { h.lggr.Infof("received message from gateway %s. Echoing back.", gatewayId) h.connector.SendToGateway(context.Background(), gatewayId, msg) } @@ -65,7 +66,7 @@ func main() { sampleKey, _ := crypto.HexToECDSA("cd47d3fafdbd652dd2b66c6104fa79b372c13cb01f4a4fbfc36107cce913ac1d") lggr, _ := logger.NewLogger() client := &client{privateKey: sampleKey, lggr: lggr} - connector, _ := connector.NewGatewayConnector(&cfg, client, client, common.NewRealClock(), lggr) + connector, _ := connector.NewGatewayConnector(&cfg, client, client, utils.NewRealClock(), lggr) client.connector = connector ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt) diff --git a/core/scripts/gateway/run_gateway.go b/core/scripts/gateway/run_gateway.go index f92b45a605b..1f0230bb8ad 100644 --- a/core/scripts/gateway/run_gateway.go +++ b/core/scripts/gateway/run_gateway.go @@ -48,7 +48,7 @@ func main() { lggr, _ := logger.NewLogger() - handlerFactory := gateway.NewHandlerFactory(lggr) + handlerFactory := gateway.NewHandlerFactory(nil, lggr) gw, err := gateway.NewGatewayFromConfig(&cfg, handlerFactory, lggr) if err != nil { fmt.Println("error creating Gateway object:", err) diff --git a/core/services/functions/connector_handler.go b/core/services/functions/connector_handler.go index b80d1c82871..7e6ddbdd980 100644 --- a/core/services/functions/connector_handler.go +++ b/core/services/functions/connector_handler.go @@ -56,7 +56,8 @@ func (h *functionsConnectorHandler) Sign(data ...[]byte) ([]byte, error) { return common.SignData(h.signerKey, data...) } -func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) { +func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { + body := &msg.Body fromAddr := ethCommon.HexToAddress(body.Sender) if !h.allowlist.Allow(fromAddr) { h.lggr.Errorw("allowlist prevented the request from this address", "id", gatewayId, "address", fromAddr) diff --git a/core/services/functions/connector_handler_test.go b/core/services/functions/connector_handler_test.go index f6dacf144e3..42713413681 100644 --- a/core/services/functions/connector_handler_test.go +++ b/core/services/functions/connector_handler_test.go @@ -46,7 +46,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { signature, err := handler.Sign([]byte("test")) require.NoError(t, err) - signer, err := common.ValidateSignature(signature, []byte("test")) + signer, err := common.ExtractSigner(signature, []byte("test")) require.NoError(t, err) require.Equal(t, addr.Bytes(), signer) }) @@ -77,7 +77,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) t.Run("orm error", func(t *testing.T) { storage.On("List", ctx, addr).Return(nil, errors.New("boom")).Once() @@ -89,12 +89,12 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("not allowed", func(t *testing.T) { allowlist.On("Allow", addr).Return(false).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) }) @@ -133,7 +133,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) t.Run("orm error", func(t *testing.T) { storage.On("Put", ctx, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("boom")).Once() @@ -145,7 +145,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("missing signature", func(t *testing.T) { @@ -160,7 +160,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) t.Run("malformed request", func(t *testing.T) { @@ -174,7 +174,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { }).Return(nil).Once() - handler.HandleGatewayMessage(ctx, "gw1", &msg.Body) + handler.HandleGatewayMessage(ctx, "gw1", &msg) }) }) @@ -191,7 +191,7 @@ func TestFunctionsConnectorHandler(t *testing.T) { require.NoError(t, msg.Sign(privateKey)) allowlist.On("Allow", addr).Return(true).Once() - handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg.Body) + handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg) }) }) } diff --git a/core/services/gateway/api/jsonrpccodec_test.go b/core/services/gateway/api/jsonrpccodec_test.go index 656949ad00f..a92bccf36a1 100644 --- a/core/services/gateway/api/jsonrpccodec_test.go +++ b/core/services/gateway/api/jsonrpccodec_test.go @@ -44,7 +44,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) { var msg api.Message msg.Body = api.MessageBody{ MessageId: "aA-bB", - Sender: "0x1234", + Receiver: "0x1234", Method: "upload", } codec := api.JsonRPCCodec{} @@ -54,7 +54,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) { decoded, err := codec.DecodeRequest(bytes) require.NoError(t, err) require.Equal(t, "aA-bB", decoded.Body.MessageId) - require.Equal(t, "0x1234", decoded.Body.Sender) + require.Equal(t, "0x1234", decoded.Body.Receiver) require.Equal(t, "upload", decoded.Body.Method) } @@ -76,7 +76,7 @@ func TestJsonRPCResponse_Encode(t *testing.T) { var msg api.Message msg.Body = api.MessageBody{ MessageId: "aA-bB", - Sender: "0x1234", + Receiver: "0x1234", Method: "upload", } codec := api.JsonRPCCodec{} @@ -86,6 +86,6 @@ func TestJsonRPCResponse_Encode(t *testing.T) { decoded, err := codec.DecodeResponse(bytes) require.NoError(t, err) require.Equal(t, "aA-bB", decoded.Body.MessageId) - require.Equal(t, "0x1234", decoded.Body.Sender) + require.Equal(t, "0x1234", decoded.Body.Receiver) require.Equal(t, "upload", decoded.Body.Method) } diff --git a/core/services/gateway/api/message.go b/core/services/gateway/api/message.go index 56f7162d368..defd3407e66 100644 --- a/core/services/gateway/api/message.go +++ b/core/services/gateway/api/message.go @@ -4,7 +4,6 @@ import ( "crypto/ecdsa" "encoding/json" "errors" - "strings" gw_common "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -16,8 +15,7 @@ const ( MessageIdMaxLen = 128 MessageMethodMaxLen = 64 MessageDonIdMaxLen = 64 - MessageSenderLen = 20 - MessageSenderHexEncodedLen = 2 + 2*MessageSenderLen + MessageReceiverLen = 2 + 2*20 ) /* @@ -36,10 +34,12 @@ type MessageBody struct { MessageId string `json:"message_id"` Method string `json:"method"` DonId string `json:"don_id"` - Sender string `json:"sender"` - + Receiver string `json:"receiver"` // Service-specific payload, decoded inside the Handler. Payload json.RawMessage `json:"payload,omitempty"` + + // Fields only used locally for convenience. Not serialized. + Sender string `json:"-"` } func (m *Message) Validate() error { @@ -58,15 +58,14 @@ func (m *Message) Validate() error { if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen { return errors.New("invalid DON ID length") } - signerBytes, err := m.ValidateSignature() + if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen { + return errors.New("invalid Receiver length") + } + signerBytes, err := m.ExtractSigner() if err != nil { return err } - hexSigner := utils.StringToHex(string(signerBytes)) - if m.Body.Sender != "" && !strings.EqualFold(m.Body.Sender, hexSigner) { - return errors.New("sender doesn't match signer") - } - m.Body.Sender = hexSigner + m.Body.Sender = utils.StringToHex(string(signerBytes)) return nil } @@ -74,12 +73,13 @@ func (m *Message) Validate() error { // 1. MessageId aligned to 128 bytes // 2. Method aligned to 64 bytes // 3. DonId aligned to 64 bytes -// 4. Payload (before parsing) +// 4. Receiver (in hex) aligned to 42 bytes +// 5. Payload (raw bytes before parsing) func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error { - rawData, err := getRawMessageBody(&m.Body) - if err != nil { - return err + if m == nil { + return errors.New("nil message") } + rawData := getRawMessageBody(&m.Body) signature, err := gw_common.SignData(privateKey, rawData...) if err != nil { return err @@ -88,27 +88,26 @@ func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error { return nil } -func (m *Message) ValidateSignature() (signerAddress []byte, err error) { - rawData, err := getRawMessageBody(&m.Body) - if err != nil { - return +func (m *Message) ExtractSigner() (signerAddress []byte, err error) { + if m == nil { + return nil, errors.New("nil message") } + rawData := getRawMessageBody(&m.Body) signatureBytes, err := utils.TryParseHex(m.Signature) if err != nil { - return + return nil, err } - return gw_common.ValidateSignature(signatureBytes, rawData...) + return gw_common.ExtractSigner(signatureBytes, rawData...) } -func getRawMessageBody(msgBody *MessageBody) ([][]byte, error) { - if msgBody == nil { - return nil, errors.New("nil message") - } +func getRawMessageBody(msgBody *MessageBody) [][]byte { alignedMessageId := make([]byte, MessageIdMaxLen) copy(alignedMessageId, msgBody.MessageId) alignedMethod := make([]byte, MessageMethodMaxLen) copy(alignedMethod, msgBody.Method) alignedDonId := make([]byte, MessageDonIdMaxLen) copy(alignedDonId, msgBody.DonId) - return [][]byte{alignedMessageId, alignedMethod, alignedDonId, msgBody.Payload}, nil + alignedReceiver := make([]byte, MessageReceiverLen) + copy(alignedReceiver, msgBody.Receiver) + return [][]byte{alignedMessageId, alignedMethod, alignedDonId, alignedReceiver, msgBody.Payload} } diff --git a/core/services/gateway/api/message_test.go b/core/services/gateway/api/message_test.go index 9c89ca22f1f..a0835ea24bb 100644 --- a/core/services/gateway/api/message_test.go +++ b/core/services/gateway/api/message_test.go @@ -16,6 +16,7 @@ func TestMessage_Validate(t *testing.T) { MessageId: "abcd", Method: "request", DonId: "donA", + Receiver: "0x0000000000000000000000000000000000000000", Payload: []byte("datadata"), }, } @@ -42,6 +43,11 @@ func TestMessage_Validate(t *testing.T) { require.Error(t, msg.Validate()) msg.Body.Method = "request" + // incorrect receiver + msg.Body.Receiver = "blah" + require.Error(t, msg.Validate()) + msg.Body.Receiver = "0x0000000000000000000000000000000000000000" + // invalid signature msg.Signature = "0x00" require.Error(t, msg.Validate()) @@ -55,6 +61,7 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) { MessageId: "abcd", Method: "request", DonId: "donA", + Receiver: "0x33", Payload: []byte("datadata"), }, } @@ -67,7 +74,14 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) { require.NoError(t, err) require.Equal(t, api.MessageSignatureHexEncodedLen, len(msg.Signature)) - signer, err := msg.ValidateSignature() + // valid + signer, err := msg.ExtractSigner() require.NoError(t, err) require.True(t, bytes.Equal(address, signer)) + + // invalid + msg.Body.MessageId = "dbca" + signer, err = msg.ExtractSigner() + require.NoError(t, err) + require.False(t, bytes.Equal(address, signer)) } diff --git a/core/services/gateway/common/utils.go b/core/services/gateway/common/utils.go index 94e403243db..5c5ab037a7d 100644 --- a/core/services/gateway/common/utils.go +++ b/core/services/gateway/common/utils.go @@ -3,7 +3,6 @@ package common import ( "crypto/ecdsa" "encoding/binary" - "errors" "github.com/ethereum/go-ethereum/crypto" "golang.org/x/exp/slices" @@ -36,18 +35,11 @@ func SignData(privateKey *ecdsa.PrivateKey, data ...[]byte) ([]byte, error) { return crypto.Sign(hash.Bytes(), privateKey) } -func ValidateSignature(signature []byte, data ...[]byte) (signerAddress []byte, err error) { +func ExtractSigner(signature []byte, data ...[]byte) (signerAddress []byte, err error) { hash := crypto.Keccak256Hash(data...) - sigPublicKey, err := crypto.Ecrecover(hash.Bytes(), signature) + ecdsaPubKey, err := crypto.SigToPub(hash.Bytes(), signature) if err != nil { - return + return nil, err } - ecdsaPubKey, _ := crypto.UnmarshalPubkey(sigPublicKey) - signerAddress = crypto.PubkeyToAddress(*ecdsaPubKey).Bytes() - - signatureNoRecoverID := signature[:len(signature)-1] - if !crypto.VerifySignature(sigPublicKey, hash.Bytes(), signatureNoRecoverID) { - return nil, errors.New("invalid signature") - } - return + return crypto.PubkeyToAddress(*ecdsaPubKey).Bytes(), nil } diff --git a/core/services/gateway/common/utils_test.go b/core/services/gateway/common/utils_test.go index 796c76d6eb7..e223ba1e9fc 100644 --- a/core/services/gateway/common/utils_test.go +++ b/core/services/gateway/common/utils_test.go @@ -35,6 +35,7 @@ func TestUtils_BytesSignAndValidate(t *testing.T) { t.Parallel() data := []byte("data_data") + incorrectData := []byte("some_other_data") privateKey, err := crypto.GenerateKey() require.NoError(t, err) @@ -44,7 +45,17 @@ func TestUtils_BytesSignAndValidate(t *testing.T) { require.NoError(t, err) require.Equal(t, 65, len(signature)) - signer, err := common.ValidateSignature(signature, data) + // valid + signer, err := common.ExtractSigner(signature, data) require.NoError(t, err) require.True(t, bytes.Equal(signer, address)) + + // invalid + signer, err = common.ExtractSigner(signature, incorrectData) + require.NoError(t, err) + require.False(t, bytes.Equal(signer, address)) + + // invalid format + _, err = common.ExtractSigner([]byte{0xaa, 0xbb}, data) + require.Error(t, err) } diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index f15a7c48f16..664696fe614 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -183,7 +183,7 @@ func (m *connectionManager) parseAuthHeader(authHeader []byte) (nodeAddress stri return "", nil, errors.New("unable to parse auth header") } signature := authHeader[n-network.HandshakeSignatureLen:] - signer, err := common.ValidateSignature(signature, authHeader[:n-network.HandshakeSignatureLen]) + signer, err := common.ExtractSigner(signature, authHeader[:n-network.HandshakeSignatureLen]) nodeAddress = "0x" + hex.EncodeToString(signer) return } @@ -210,7 +210,7 @@ func (m *connectionManager) FinalizeHandshake(attemptId string, response []byte, if !ok { return errors.New("connection attempt not found") } - signer, err := common.ValidateSignature(response, attempt.challenge) + signer, err := common.ExtractSigner(response, attempt.challenge) if err != nil { return errors.New("invalid challenge response") } @@ -256,7 +256,11 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState case item := <-nodeState.conn.ReadChannel(): msg, err := m.codec.DecodeResponse(item.Data) if err != nil { - m.lggr.Error("parse error when reading from node ", nodeAddress, err) + m.lggr.Errorw("parse error when reading from node", "nodeAddress", nodeAddress, "err", err) + break + } + if err = msg.Validate(); err != nil { + m.lggr.Errorw("message validation error when reading from node", "nodeAddress", nodeAddress, "err", err) break } err = m.handler.HandleNodeMessage(ctx, msg, nodeAddress) diff --git a/core/services/gateway/connector/connector.go b/core/services/gateway/connector/connector.go index a0854797e79..1db791fda9b 100644 --- a/core/services/gateway/connector/connector.go +++ b/core/services/gateway/connector/connector.go @@ -40,7 +40,7 @@ type Signer interface { type GatewayConnectorHandler interface { job.ServiceCtx - HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) + HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) } type gatewayConnector struct { @@ -142,7 +142,7 @@ func (c *gatewayConnector) readLoop(gatewayState *gatewayState) { c.lggr.Errorw("failed to validate message signature", "id", gatewayState.config.Id, "error", err) break } - c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, &msg.Body) + c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg) } } } diff --git a/core/services/gateway/connector/mocks/gateway_connector_handler.go b/core/services/gateway/connector/mocks/gateway_connector_handler.go index 05d6c6ed97f..1db0f45fa12 100644 --- a/core/services/gateway/connector/mocks/gateway_connector_handler.go +++ b/core/services/gateway/connector/mocks/gateway_connector_handler.go @@ -29,9 +29,9 @@ func (_m *GatewayConnectorHandler) Close() error { return r0 } -// HandleGatewayMessage provides a mock function with given fields: ctx, gatewayId, body -func (_m *GatewayConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) { - _m.Called(ctx, gatewayId, body) +// HandleGatewayMessage provides a mock function with given fields: ctx, gatewayId, msg +func (_m *GatewayConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) { + _m.Called(ctx, gatewayId, msg) } // Start provides a mock function with given fields: _a0 diff --git a/core/services/gateway/gateway.go b/core/services/gateway/gateway.go index dff03354e42..4003dc6c7dd 100644 --- a/core/services/gateway/gateway.go +++ b/core/services/gateway/gateway.go @@ -112,6 +112,9 @@ func (g *gateway) ProcessRequest(ctx context.Context, rawRequest []byte) (rawRes if err != nil { return newError(g.codec, "", api.UserMessageParseError, err.Error()) } + if err = msg.Validate(); err != nil { + return newError(g.codec, msg.Body.MessageId, api.UserMessageParseError, err.Error()) + } // find correct handler handler, ok := g.handlers[msg.Body.DonId] if !ok { diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index 1490a9fccd4..143ee4da0d0 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/crypto" "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -120,6 +121,24 @@ func newGatewayWithMockHandler(t *testing.T) (gateway.Gateway, *handler_mocks.Ha return gw, handler } +func newSignedRequest(t *testing.T, messageId string, method string, donID string, payload []byte) []byte { + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: messageId, + Method: method, + DonId: donID, + Payload: payload, + }, + } + privateKey, err := crypto.GenerateKey() + require.NoError(t, err) + require.NoError(t, msg.Sign(privateKey)) + codec := api.JsonRPCCodec{} + rawRequest, err := codec.EncodeRequest(msg) + require.NoError(t, err) + return rawRequest +} + func TestGateway_ProcessRequest_ParseError(t *testing.T) { t.Parallel() @@ -129,15 +148,22 @@ func TestGateway_ProcessRequest_ParseError(t *testing.T) { require.Equal(t, 400, statusCode) } -func TestGateway_ProcessRequest_IncorrectDonId(t *testing.T) { +func TestGateway_ProcessRequest_MessageValidationError(t *testing.T) { t.Parallel() gw, _ := newGatewayWithMockHandler(t) - response, statusCode := gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "id": "abc", "method": "request", "params": {}}`)) - requireJsonRPCError(t, response, "abc", -32602, "unsupported DON ID") + req := newSignedRequest(t, "abc", "request", "", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) + requireJsonRPCError(t, response, "abc", -32700, "invalid DON ID length") require.Equal(t, 400, statusCode) +} + +func TestGateway_ProcessRequest_IncorrectDonId(t *testing.T) { + t.Parallel() - response, statusCode = gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "id": "abc", "method": "request", "params": {"body": {"don_id": "bad"}}}`)) + gw, _ := newGatewayWithMockHandler(t) + req := newSignedRequest(t, "abc", "request", "unknownDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCError(t, response, "abc", -32602, "unsupported DON ID") require.Equal(t, 400, statusCode) } @@ -151,13 +177,14 @@ func TestGateway_ProcessRequest_HandlerResponse(t *testing.T) { callbackCh := args.Get(2).(chan<- handlers.UserCallbackPayload) // echo back to sender with attached payload msg.Body.Payload = []byte(`{"result":"OK"}`) + msg.Signature = "" callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} }) - response, statusCode := gw.ProcessRequest(testutils.Context(t), - []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCResult(t, response, "abcd", - `{"signature":"","body":{"message_id":"abcd","method":"request","don_id":"testDON","sender":"","payload":{"result":"OK"}}}`) + `{"signature":"","body":{"message_id":"abcd","method":"request","don_id":"testDON","receiver":"","payload":{"result":"OK"}}}`) require.Equal(t, 200, statusCode) } @@ -169,7 +196,8 @@ func TestGateway_ProcessRequest_HandlerTimeout(t *testing.T) { timeoutCtx, cancel := context.WithTimeout(testutils.Context(t), time.Duration(time.Millisecond*10)) defer cancel() - response, statusCode := gw.ProcessRequest(timeoutCtx, []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(timeoutCtx, req) requireJsonRPCError(t, response, "abcd", -32000, "handler timeout") require.Equal(t, 504, statusCode) } @@ -180,7 +208,8 @@ func TestGateway_ProcessRequest_HandlerError(t *testing.T) { gw, handler := newGatewayWithMockHandler(t) handler.On("HandleUserMessage", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failure")) - response, statusCode := gw.ProcessRequest(testutils.Context(t), []byte(`{"jsonrpc":"2.0", "method": "request", "id": "abcd", "params": {"body":{"don_id": "testDON"}}}`)) + req := newSignedRequest(t, "abcd", "request", "testDON", []byte{}) + response, statusCode := gw.ProcessRequest(testutils.Context(t), req) requireJsonRPCError(t, response, "abcd", -32000, "failure") require.Equal(t, 500, statusCode) } diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index d94a975b8ab..2817ffdce98 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -69,10 +69,6 @@ func ParseConfig(handlerConfig json.RawMessage) (*FunctionsHandlerConfig, error) } func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { - if err := msg.Validate(); err != nil { - h.lggr.Debugw("received invalid message", "err", err) - return err - } sender := common.HexToAddress(msg.Body.Sender) if h.allowlist != nil && !h.allowlist.Allow(sender) { h.lggr.Debugw("received a message from a non-allowlisted address", "sender", msg.Body.Sender) diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index c4d2f5eb999..3f412fdd011 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -8,17 +8,19 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" ) -func TestFunctionsHandler_Basic(t *testing.T) { +func TestFunctionsHandler_Minimal(t *testing.T) { t.Parallel() handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, logger.TestLogger(t)) require.NoError(t, err) - // nil message - err = handler.HandleUserMessage(testutils.Context(t), nil, nil) - require.Error(t, err) + // empty message + msg := &api.Message{} + err = handler.HandleUserMessage(testutils.Context(t), msg, nil) + require.NoError(t, err) } diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index a645a9c5c66..5cd2f3a61eb 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -169,9 +169,9 @@ func NewConnector(gwcCfg *connector.ConnectorConfig, ethKeystore keystore.Eth, c return nil, errors.New("key for configured node address not found") } signerKey := enabledKeys[idx].ToEcdsaPrivKey() - nodeAddreess := enabledKeys[idx].ID() + nodeAddress := enabledKeys[idx].ID() - handler := functions.NewFunctionsConnectorHandler(nodeAddreess, signerKey, s4Storage, allowlist, lggr) + handler := functions.NewFunctionsConnectorHandler(nodeAddress, signerKey, s4Storage, allowlist, lggr) connector, err := connector.NewGatewayConnector(gwcCfg, handler, handler, utils.NewRealClock(), lggr) if err != nil { return nil, err