Skip to content

Commit

Permalink
[Gateway] Misc fixes to Message parsing, signing and validation (#9752)
Browse files Browse the repository at this point in the history
1. Set Message.Sender as a non-serializable field.
2. Add Message.Receiver field. Because messageID is scoped to a user, nodes need to set receiver to that user’s address to make sure responses can be uniquely attributed to users and also to prove to the user that it is in fact a response to their message.
3. Fix ValidateSignature(). I mistakenly assumed that previous logic was enough to fully validate a signature but we always need to validate the expected sender. Renamed the function to ExtractSigner() to avoid any confusion.
4. Perform message validation before calling handler methods, as suggested by @pinebit. Even though there are theoretical use cases that might not need it, I don’t think they are going to be implemented in the near future.
5. Changing HandleGatewayMessage() attribute back to full message instead of just body. This will be needed for a very real use case, where nodes want to share received messages among each other, within the OCR round. It’s also consistent with the Gateway Handler interface.
6. Fix broken scripts.
  • Loading branch information
bolekk authored Jul 10, 2023
1 parent ea29e8f commit ee5ce33
Show file tree
Hide file tree
Showing 17 changed files with 136 additions and 84 deletions.
5 changes: 3 additions & 2 deletions core/scripts/gateway/connector/run_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/api"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/common"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
"github.com/smartcontractkit/chainlink/v2/core/utils"
)

// Script to run Connector outside of the core node.
Expand All @@ -28,7 +29,7 @@ type client struct {
lggr logger.Logger
}

func (h *client) HandleGatewayMessage(gatewayId string, msg *api.Message) {
func (h *client) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) {
h.lggr.Infof("received message from gateway %s. Echoing back.", gatewayId)
h.connector.SendToGateway(context.Background(), gatewayId, msg)
}
Expand Down Expand Up @@ -65,7 +66,7 @@ func main() {
sampleKey, _ := crypto.HexToECDSA("cd47d3fafdbd652dd2b66c6104fa79b372c13cb01f4a4fbfc36107cce913ac1d")
lggr, _ := logger.NewLogger()
client := &client{privateKey: sampleKey, lggr: lggr}
connector, _ := connector.NewGatewayConnector(&cfg, client, client, common.NewRealClock(), lggr)
connector, _ := connector.NewGatewayConnector(&cfg, client, client, utils.NewRealClock(), lggr)
client.connector = connector

ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
Expand Down
2 changes: 1 addition & 1 deletion core/scripts/gateway/run_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func main() {

lggr, _ := logger.NewLogger()

handlerFactory := gateway.NewHandlerFactory(lggr)
handlerFactory := gateway.NewHandlerFactory(nil, lggr)
gw, err := gateway.NewGatewayFromConfig(&cfg, handlerFactory, lggr)
if err != nil {
fmt.Println("error creating Gateway object:", err)
Expand Down
3 changes: 2 additions & 1 deletion core/services/functions/connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func (h *functionsConnectorHandler) Sign(data ...[]byte) ([]byte, error) {
return common.SignData(h.signerKey, data...)
}

func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, body *api.MessageBody) {
func (h *functionsConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayId string, msg *api.Message) {
body := &msg.Body
fromAddr := ethCommon.HexToAddress(body.Sender)
if !h.allowlist.Allow(fromAddr) {
h.lggr.Errorw("allowlist prevented the request from this address", "id", gatewayId, "address", fromAddr)
Expand Down
18 changes: 9 additions & 9 deletions core/services/functions/connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {
signature, err := handler.Sign([]byte("test"))
require.NoError(t, err)

signer, err := common.ValidateSignature(signature, []byte("test"))
signer, err := common.ExtractSigner(signature, []byte("test"))
require.NoError(t, err)
require.Equal(t, addr.Bytes(), signer)
})
Expand Down Expand Up @@ -77,7 +77,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)

t.Run("orm error", func(t *testing.T) {
storage.On("List", ctx, addr).Return(nil, errors.New("boom")).Once()
Expand All @@ -89,12 +89,12 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)
})

t.Run("not allowed", func(t *testing.T) {
allowlist.On("Allow", addr).Return(false).Once()
handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)
})
})

Expand Down Expand Up @@ -133,7 +133,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)

t.Run("orm error", func(t *testing.T) {
storage.On("Put", ctx, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("boom")).Once()
Expand All @@ -145,7 +145,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)
})

t.Run("missing signature", func(t *testing.T) {
Expand All @@ -160,7 +160,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)
})

t.Run("malformed request", func(t *testing.T) {
Expand All @@ -174,7 +174,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {

}).Return(nil).Once()

handler.HandleGatewayMessage(ctx, "gw1", &msg.Body)
handler.HandleGatewayMessage(ctx, "gw1", &msg)
})
})

Expand All @@ -191,7 +191,7 @@ func TestFunctionsConnectorHandler(t *testing.T) {
require.NoError(t, msg.Sign(privateKey))

allowlist.On("Allow", addr).Return(true).Once()
handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg.Body)
handler.HandleGatewayMessage(testutils.Context(t), "gw1", &msg)
})
})
}
8 changes: 4 additions & 4 deletions core/services/gateway/api/jsonrpccodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) {
var msg api.Message
msg.Body = api.MessageBody{
MessageId: "aA-bB",
Sender: "0x1234",
Receiver: "0x1234",
Method: "upload",
}
codec := api.JsonRPCCodec{}
Expand All @@ -54,7 +54,7 @@ func TestJsonRPCRequest_Encode(t *testing.T) {
decoded, err := codec.DecodeRequest(bytes)
require.NoError(t, err)
require.Equal(t, "aA-bB", decoded.Body.MessageId)
require.Equal(t, "0x1234", decoded.Body.Sender)
require.Equal(t, "0x1234", decoded.Body.Receiver)
require.Equal(t, "upload", decoded.Body.Method)
}

Expand All @@ -76,7 +76,7 @@ func TestJsonRPCResponse_Encode(t *testing.T) {
var msg api.Message
msg.Body = api.MessageBody{
MessageId: "aA-bB",
Sender: "0x1234",
Receiver: "0x1234",
Method: "upload",
}
codec := api.JsonRPCCodec{}
Expand All @@ -86,6 +86,6 @@ func TestJsonRPCResponse_Encode(t *testing.T) {
decoded, err := codec.DecodeResponse(bytes)
require.NoError(t, err)
require.Equal(t, "aA-bB", decoded.Body.MessageId)
require.Equal(t, "0x1234", decoded.Body.Sender)
require.Equal(t, "0x1234", decoded.Body.Receiver)
require.Equal(t, "upload", decoded.Body.Method)
}
51 changes: 25 additions & 26 deletions core/services/gateway/api/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/ecdsa"
"encoding/json"
"errors"
"strings"

gw_common "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common"
"github.com/smartcontractkit/chainlink/v2/core/utils"
Expand All @@ -16,8 +15,7 @@ const (
MessageIdMaxLen = 128
MessageMethodMaxLen = 64
MessageDonIdMaxLen = 64
MessageSenderLen = 20
MessageSenderHexEncodedLen = 2 + 2*MessageSenderLen
MessageReceiverLen = 2 + 2*20
)

/*
Expand All @@ -36,10 +34,12 @@ type MessageBody struct {
MessageId string `json:"message_id"`
Method string `json:"method"`
DonId string `json:"don_id"`
Sender string `json:"sender"`

Receiver string `json:"receiver"`
// Service-specific payload, decoded inside the Handler.
Payload json.RawMessage `json:"payload,omitempty"`

// Fields only used locally for convenience. Not serialized.
Sender string `json:"-"`
}

func (m *Message) Validate() error {
Expand All @@ -58,28 +58,28 @@ func (m *Message) Validate() error {
if len(m.Body.DonId) == 0 || len(m.Body.DonId) > MessageDonIdMaxLen {
return errors.New("invalid DON ID length")
}
signerBytes, err := m.ValidateSignature()
if len(m.Body.Receiver) != 0 && len(m.Body.Receiver) != MessageReceiverLen {
return errors.New("invalid Receiver length")
}
signerBytes, err := m.ExtractSigner()
if err != nil {
return err
}
hexSigner := utils.StringToHex(string(signerBytes))
if m.Body.Sender != "" && !strings.EqualFold(m.Body.Sender, hexSigner) {
return errors.New("sender doesn't match signer")
}
m.Body.Sender = hexSigner
m.Body.Sender = utils.StringToHex(string(signerBytes))
return nil
}

// Message signatures are over the following data:
// 1. MessageId aligned to 128 bytes
// 2. Method aligned to 64 bytes
// 3. DonId aligned to 64 bytes
// 4. Payload (before parsing)
// 4. Receiver (in hex) aligned to 42 bytes
// 5. Payload (raw bytes before parsing)
func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error {
rawData, err := getRawMessageBody(&m.Body)
if err != nil {
return err
if m == nil {
return errors.New("nil message")
}
rawData := getRawMessageBody(&m.Body)
signature, err := gw_common.SignData(privateKey, rawData...)
if err != nil {
return err
Expand All @@ -88,27 +88,26 @@ func (m *Message) Sign(privateKey *ecdsa.PrivateKey) error {
return nil
}

func (m *Message) ValidateSignature() (signerAddress []byte, err error) {
rawData, err := getRawMessageBody(&m.Body)
if err != nil {
return
func (m *Message) ExtractSigner() (signerAddress []byte, err error) {
if m == nil {
return nil, errors.New("nil message")
}
rawData := getRawMessageBody(&m.Body)
signatureBytes, err := utils.TryParseHex(m.Signature)
if err != nil {
return
return nil, err
}
return gw_common.ValidateSignature(signatureBytes, rawData...)
return gw_common.ExtractSigner(signatureBytes, rawData...)
}

func getRawMessageBody(msgBody *MessageBody) ([][]byte, error) {
if msgBody == nil {
return nil, errors.New("nil message")
}
func getRawMessageBody(msgBody *MessageBody) [][]byte {
alignedMessageId := make([]byte, MessageIdMaxLen)
copy(alignedMessageId, msgBody.MessageId)
alignedMethod := make([]byte, MessageMethodMaxLen)
copy(alignedMethod, msgBody.Method)
alignedDonId := make([]byte, MessageDonIdMaxLen)
copy(alignedDonId, msgBody.DonId)
return [][]byte{alignedMessageId, alignedMethod, alignedDonId, msgBody.Payload}, nil
alignedReceiver := make([]byte, MessageReceiverLen)
copy(alignedReceiver, msgBody.Receiver)
return [][]byte{alignedMessageId, alignedMethod, alignedDonId, alignedReceiver, msgBody.Payload}
}
16 changes: 15 additions & 1 deletion core/services/gateway/api/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func TestMessage_Validate(t *testing.T) {
MessageId: "abcd",
Method: "request",
DonId: "donA",
Receiver: "0x0000000000000000000000000000000000000000",
Payload: []byte("datadata"),
},
}
Expand All @@ -42,6 +43,11 @@ func TestMessage_Validate(t *testing.T) {
require.Error(t, msg.Validate())
msg.Body.Method = "request"

// incorrect receiver
msg.Body.Receiver = "blah"
require.Error(t, msg.Validate())
msg.Body.Receiver = "0x0000000000000000000000000000000000000000"

// invalid signature
msg.Signature = "0x00"
require.Error(t, msg.Validate())
Expand All @@ -55,6 +61,7 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) {
MessageId: "abcd",
Method: "request",
DonId: "donA",
Receiver: "0x33",
Payload: []byte("datadata"),
},
}
Expand All @@ -67,7 +74,14 @@ func TestMessage_MessageSignAndValidateSignature(t *testing.T) {
require.NoError(t, err)
require.Equal(t, api.MessageSignatureHexEncodedLen, len(msg.Signature))

signer, err := msg.ValidateSignature()
// valid
signer, err := msg.ExtractSigner()
require.NoError(t, err)
require.True(t, bytes.Equal(address, signer))

// invalid
msg.Body.MessageId = "dbca"
signer, err = msg.ExtractSigner()
require.NoError(t, err)
require.False(t, bytes.Equal(address, signer))
}
16 changes: 4 additions & 12 deletions core/services/gateway/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package common
import (
"crypto/ecdsa"
"encoding/binary"
"errors"

"github.com/ethereum/go-ethereum/crypto"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -36,18 +35,11 @@ func SignData(privateKey *ecdsa.PrivateKey, data ...[]byte) ([]byte, error) {
return crypto.Sign(hash.Bytes(), privateKey)
}

func ValidateSignature(signature []byte, data ...[]byte) (signerAddress []byte, err error) {
func ExtractSigner(signature []byte, data ...[]byte) (signerAddress []byte, err error) {
hash := crypto.Keccak256Hash(data...)
sigPublicKey, err := crypto.Ecrecover(hash.Bytes(), signature)
ecdsaPubKey, err := crypto.SigToPub(hash.Bytes(), signature)
if err != nil {
return
return nil, err
}
ecdsaPubKey, _ := crypto.UnmarshalPubkey(sigPublicKey)
signerAddress = crypto.PubkeyToAddress(*ecdsaPubKey).Bytes()

signatureNoRecoverID := signature[:len(signature)-1]
if !crypto.VerifySignature(sigPublicKey, hash.Bytes(), signatureNoRecoverID) {
return nil, errors.New("invalid signature")
}
return
return crypto.PubkeyToAddress(*ecdsaPubKey).Bytes(), nil
}
13 changes: 12 additions & 1 deletion core/services/gateway/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestUtils_BytesSignAndValidate(t *testing.T) {
t.Parallel()

data := []byte("data_data")
incorrectData := []byte("some_other_data")

privateKey, err := crypto.GenerateKey()
require.NoError(t, err)
Expand All @@ -44,7 +45,17 @@ func TestUtils_BytesSignAndValidate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 65, len(signature))

signer, err := common.ValidateSignature(signature, data)
// valid
signer, err := common.ExtractSigner(signature, data)
require.NoError(t, err)
require.True(t, bytes.Equal(signer, address))

// invalid
signer, err = common.ExtractSigner(signature, incorrectData)
require.NoError(t, err)
require.False(t, bytes.Equal(signer, address))

// invalid format
_, err = common.ExtractSigner([]byte{0xaa, 0xbb}, data)
require.Error(t, err)
}
Loading

0 comments on commit ee5ce33

Please sign in to comment.