Skip to content

Commit

Permalink
[CAPPL-20] Support per-method handlers in GatewayConnector (#14367)
Browse files Browse the repository at this point in the history
Making GatewayConnector compatible with the new design, where each capability is able to add its own handler independently.
  • Loading branch information
bolekk authored Sep 9, 2024
1 parent 8f46d81 commit cd8be70
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .changeset/thick-jobs-kneel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chainlink": patch
---

Support per-method handlers in GatewayConnector
8 changes: 7 additions & 1 deletion core/scripts/gateway/connector/run_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ func main() {
sampleKey, _ := crypto.HexToECDSA("cd47d3fafdbd652dd2b66c6104fa79b372c13cb01f4a4fbfc36107cce913ac1d")
lggr, _ := logger.NewLogger()
client := &client{privateKey: sampleKey, lggr: lggr}
connector, _ := connector.NewGatewayConnector(&cfg, client, client, clockwork.NewRealClock(), lggr)
// client acts as a signer here
connector, _ := connector.NewGatewayConnector(&cfg, client, clockwork.NewRealClock(), lggr)
err = connector.AddHandler([]string{"test_method"}, client)
if err != nil {
fmt.Println("error adding handler:", err)
return
}
client.connector = connector

ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt)
Expand Down
40 changes: 30 additions & 10 deletions core/services/gateway/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type GatewayConnector interface {
job.ServiceCtx
network.ConnectionInitiator

AddHandler(methods []string, handler GatewayConnectorHandler) error
SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error
}

Expand All @@ -51,7 +52,7 @@ type gatewayConnector struct {
clock clockwork.Clock
nodeAddress []byte
signer Signer
handler GatewayConnectorHandler
handlers map[string]GatewayConnectorHandler
gateways map[string]*gatewayState
urlToId map[string]string
closeWait sync.WaitGroup
Expand All @@ -76,8 +77,8 @@ type gatewayState struct {
wsClient network.WebSocketClient
}

func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler GatewayConnectorHandler, clock clockwork.Clock, lggr logger.Logger) (GatewayConnector, error) {
if config == nil || signer == nil || handler == nil || clock == nil || lggr == nil {
func NewGatewayConnector(config *ConnectorConfig, signer Signer, clock clockwork.Clock, lggr logger.Logger) (GatewayConnector, error) {
if config == nil || signer == nil || clock == nil || lggr == nil {
return nil, errors.New("nil dependency")
}
if len(config.DonId) == 0 || len(config.DonId) > network.HandshakeDonIdLen {
Expand All @@ -93,7 +94,7 @@ func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler Gateway
clock: clock,
nodeAddress: addressBytes,
signer: signer,
handler: handler,
handlers: make(map[string]GatewayConnectorHandler),
shutdownCh: make(chan struct{}),
lggr: lggr.Named("GatewayConnector"),
}
Expand Down Expand Up @@ -125,6 +126,22 @@ func NewGatewayConnector(config *ConnectorConfig, signer Signer, handler Gateway
return connector, nil
}

func (c *gatewayConnector) AddHandler(methods []string, handler GatewayConnectorHandler) error {
if handler == nil {
return errors.New("cannot add a nil handler")
}
for _, method := range methods {
if _, exists := c.handlers[method]; exists {
return fmt.Errorf("handler for method %s already exists", method)
}
}
// add all or nothing
for _, method := range methods {
c.handlers[method] = handler
}
return nil
}

func (c *gatewayConnector) SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error {
data, err := c.codec.EncodeResponse(msg)
if err != nil {
Expand Down Expand Up @@ -159,7 +176,12 @@ func (c *gatewayConnector) readLoop(gatewayState *gatewayState) {
c.lggr.Errorw("failed to validate message signature", "id", gatewayState.config.Id, "err", err)
break
}
c.handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg)
handler, exists := c.handlers[msg.Body.Method]
if !exists {
c.lggr.Errorw("no handler for method", "id", gatewayState.config.Id, "method", msg.Body.Method)
break
}
handler.HandleGatewayMessage(ctx, gatewayState.config.Id, msg)
}
}
}
Expand Down Expand Up @@ -194,9 +216,6 @@ func (c *gatewayConnector) reconnectLoop(gatewayState *gatewayState) {
func (c *gatewayConnector) Start(ctx context.Context) error {
return c.StartOnce("GatewayConnector", func() error {
c.lggr.Info("starting gateway connector")
if err := c.handler.Start(ctx); err != nil {
return err
}
for _, gatewayState := range c.gateways {
gatewayState := gatewayState
if err := gatewayState.conn.Start(ctx); err != nil {
Expand All @@ -214,11 +233,12 @@ func (c *gatewayConnector) Close() error {
return c.StopOnce("GatewayConnector", func() (err error) {
c.lggr.Info("closing gateway connector")
close(c.shutdownCh)
var errs error
for _, gatewayState := range c.gateways {
gatewayState.conn.Close()
errs = errors.Join(errs, gatewayState.conn.Close())
}
c.closeWait.Wait()
return c.handler.Close()
return errs
})
}

Expand Down
35 changes: 23 additions & 12 deletions core/services/gateway/connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/network"
)

const defaultConfig = `
const (
defaultConfig = `
NodeAddress = "0x68902d681c28119f9b2531473a417088bf008e59"
DonId = "example_don"
AuthMinChallengeLen = 10
Expand All @@ -32,6 +33,9 @@ URL = "ws://localhost:8081/node"
Id = "another_one"
URL = "wss://example.com:8090/node_endpoint"
`
testMethod1 = "test_method_1"
testMethod2 = "test_method_2"
)

func parseTOMLConfig(t *testing.T, tomlConfig string) *connector.ConnectorConfig {
var cfg connector.ConnectorConfig
Expand All @@ -40,12 +44,13 @@ func parseTOMLConfig(t *testing.T, tomlConfig string) *connector.ConnectorConfig
return &cfg
}

func newTestConnector(t *testing.T, config *connector.ConnectorConfig, now time.Time) (connector.GatewayConnector, *mocks.Signer, *mocks.GatewayConnectorHandler) {
func newTestConnector(t *testing.T, config *connector.ConnectorConfig) (connector.GatewayConnector, *mocks.Signer, *mocks.GatewayConnectorHandler) {
signer := mocks.NewSigner(t)
handler := mocks.NewGatewayConnectorHandler(t)
clock := clockwork.NewFakeClock()
connector, err := connector.NewGatewayConnector(config, signer, handler, clock, logger.TestLogger(t))
connector, err := connector.NewGatewayConnector(config, signer, clock, logger.TestLogger(t))
require.NoError(t, err)
require.NoError(t, connector.AddHandler([]string{testMethod1}, handler))
return connector, signer, handler
}

Expand All @@ -61,7 +66,7 @@ Id = "example_gateway"
URL = "ws://localhost:8081/node"
`)

newTestConnector(t, tomlConfig, time.Now())
newTestConnector(t, tomlConfig)
}

func TestGatewayConnector_NewGatewayConnector_InvalidConfig(t *testing.T) {
Expand Down Expand Up @@ -103,12 +108,11 @@ URL = "ws://localhost:8081/node"
}

signer := mocks.NewSigner(t)
handler := mocks.NewGatewayConnectorHandler(t)
clock := clockwork.NewFakeClock()
for name, config := range invalidCases {
config := config
t.Run(name, func(t *testing.T) {
_, err := connector.NewGatewayConnector(parseTOMLConfig(t, config), signer, handler, clock, logger.TestLogger(t))
_, err := connector.NewGatewayConnector(parseTOMLConfig(t, config), signer, clock, logger.TestLogger(t))
require.Error(t, err)
})
}
Expand All @@ -117,17 +121,15 @@ URL = "ws://localhost:8081/node"
func TestGatewayConnector_CleanStartAndClose(t *testing.T) {
t.Parallel()

connector, signer, handler := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
handler.On("Start", mock.Anything).Return(nil)
handler.On("Close").Return(nil)
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(nil, errors.New("cannot sign"))
servicetest.Run(t, connector)
}

func TestGatewayConnector_NewAuthHeader_SignerError(t *testing.T) {
t.Parallel()

connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(nil, errors.New("cannot sign"))

url, err := url.Parse("ws://localhost:8081/node")
Expand All @@ -141,7 +143,7 @@ func TestGatewayConnector_NewAuthHeader_Success(t *testing.T) {

testSignature := make([]byte, network.HandshakeSignatureLen)
testSignature[1] = 0xfa
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), time.Now())
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(testSignature, nil)
url, err := url.Parse("ws://localhost:8081/node")
require.NoError(t, err)
Expand All @@ -157,7 +159,7 @@ func TestGatewayConnector_ChallengeResponse(t *testing.T) {
testSignature := make([]byte, network.HandshakeSignatureLen)
testSignature[1] = 0xfa
now := time.Now()
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig), now)
connector, signer, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
signer.On("Sign", mock.Anything).Return(testSignature, nil)
url, err := url.Parse("ws://localhost:8081/node")
require.NoError(t, err)
Expand Down Expand Up @@ -191,3 +193,12 @@ func TestGatewayConnector_ChallengeResponse(t *testing.T) {
_, err = connector.ChallengeResponse(url, network.PackChallenge(&badChallenge))
require.Equal(t, network.ErrAuthInvalidGateway, err)
}

func TestGatewayConnector_AddHandler(t *testing.T) {
t.Parallel()

connector, _, _ := newTestConnector(t, parseTOMLConfig(t, defaultConfig))
// testMethod1 already exists
require.Error(t, connector.AddHandler([]string{testMethod1}, mocks.NewGatewayConnectorHandler(t)))
require.NoError(t, connector.AddHandler([]string{testMethod2}, mocks.NewGatewayConnectorHandler(t)))
}
48 changes: 48 additions & 0 deletions core/services/gateway/connector/mocks/gateway_connector.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T)

// Launch Connector
client := &client{privateKey: nodeKeys.PrivateKey}
connector, err := connector.NewGatewayConnector(parseConnectorConfig(t, nodeConfigTemplate, nodeKeys.Address, nodeUrl), client, client, clockwork.NewRealClock(), lggr)
// client acts as a signer here
connector, err := connector.NewGatewayConnector(parseConnectorConfig(t, nodeConfigTemplate, nodeKeys.Address, nodeUrl), client, clockwork.NewRealClock(), lggr)
require.NoError(t, err)
require.NoError(t, connector.AddHandler([]string{"test"}, client))
client.connector = connector
servicetest.Run(t, connector)

Expand Down
25 changes: 16 additions & 9 deletions core/services/ocr2/plugins/functions/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/functions"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
hc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
hf "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions"
gwAllowlist "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/allowlist"
gwSubscriptions "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions"
"github.com/smartcontractkit/chainlink/v2/core/services/job"
Expand Down Expand Up @@ -174,11 +175,12 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra
return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions")
}
connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName)
connector, err2 := NewConnector(ctx, &pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger)
connector, handler, err2 := NewConnector(ctx, &pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger)
if err2 != nil {
return nil, errors.Wrap(err, "failed to create a GatewayConnector")
}
allServices = append(allServices, connector)
allServices = append(allServices, handler)
} else {
listenerLogger.Warn("Insufficient config, GatewayConnector will not be enabled")
}
Expand All @@ -201,29 +203,34 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra
return allServices, nil
}

func NewConnector(ctx context.Context, pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) {
func NewConnector(ctx context.Context, pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, connector.GatewayConnectorHandler, error) {
enabledKeys, err := ethKeystore.EnabledKeysForChain(ctx, chainID)
if err != nil {
return nil, err
return nil, nil, err
}
configuredNodeAddress := common.HexToAddress(pluginConfig.GatewayConnectorConfig.NodeAddress)
idx := slices.IndexFunc(enabledKeys, func(key ethkey.KeyV2) bool { return key.Address == configuredNodeAddress })
if idx == -1 {
return nil, errors.New("key for configured node address not found")
return nil, nil, errors.New("key for configured node address not found")
}
signerKey := enabledKeys[idx].ToEcdsaPrivKey()
if enabledKeys[idx].ID() != pluginConfig.GatewayConnectorConfig.NodeAddress {
return nil, errors.New("node address mismatch")
return nil, nil, errors.New("node address mismatch")
}

handler, err := functions.NewFunctionsConnectorHandler(pluginConfig, signerKey, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, lggr)
if err != nil {
return nil, err
return nil, nil, err
}
connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, handler, clockwork.NewRealClock(), lggr)
// handler acts as a signer here
connector, err := connector.NewGatewayConnector(pluginConfig.GatewayConnectorConfig, handler, clockwork.NewRealClock(), lggr)
if err != nil {
return nil, err
return nil, nil, err
}
err = connector.AddHandler([]string{hf.MethodSecretsSet, hf.MethodSecretsList, hf.MethodHeartbeat}, handler)
if err != nil {
return nil, nil, err
}
handler.SetConnector(connector)
return connector, nil
return connector, handler, nil
}
4 changes: 2 additions & 2 deletions core/services/ocr2/plugins/functions/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestNewConnector_Success(t *testing.T) {
config := &config.PluginConfig{
GatewayConnectorConfig: gwcCfg,
}
_, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
_, _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
require.NoError(t, err)
}

Expand Down Expand Up @@ -78,6 +78,6 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) {
config := &config.PluginConfig{
GatewayConnectorConfig: gwcCfg,
}
_, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
_, _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t))
require.Error(t, err)
}

0 comments on commit cd8be70

Please sign in to comment.