From ae67a9eddc98f792df02890f659be136bb3f37db Mon Sep 17 00:00:00 2001 From: Michael Street <5597260+MStreet3@users.noreply.github.com> Date: Thu, 19 Dec 2024 05:15:15 -0500 Subject: [PATCH] fix(deployment): enables multiple ocr3 contracts per chain --- deployment/address_book.go | 12 +++++-- deployment/address_book_test.go | 28 +++++++++++++++ deployment/keystone/changeset/deploy_ocr3.go | 2 ++ deployment/keystone/deploy.go | 36 ++++++++++++++++++-- deployment/keystone/state.go | 10 +++++- 5 files changed, 81 insertions(+), 7 deletions(-) diff --git a/deployment/address_book.go b/deployment/address_book.go index 3ce0332a4c3..46935a894a3 100644 --- a/deployment/address_book.go +++ b/deployment/address_book.go @@ -2,6 +2,7 @@ package deployment import ( "fmt" + "log" "strings" "sync" @@ -83,7 +84,7 @@ func NewTypeAndVersion(t ContractType, v semver.Version) TypeAndVersion { type AddressBook interface { Save(chainSelector uint64, address string, tv TypeAndVersion) error Addresses() (map[uint64]map[string]TypeAndVersion, error) - AddressesForChain(chain uint64) (map[string]TypeAndVersion, error) + AddressesForChain(chain uint64, opts ...func(map[string]TypeAndVersion)) (map[string]TypeAndVersion, error) // Allows for merging address books (e.g. new deployments with existing ones) Merge(other AddressBook) error Remove(ab AddressBook) error @@ -149,7 +150,7 @@ func (m *AddressBookMap) Addresses() (map[uint64]map[string]TypeAndVersion, erro return m.cloneAddresses(m.addressesByChain), nil } -func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]TypeAndVersion, error) { +func (m *AddressBookMap) AddressesForChain(chainSelector uint64, opts ...func(map[string]TypeAndVersion)) (map[string]TypeAndVersion, error) { _, err := chainsel.GetChainIDFromSelector(chainSelector) if err != nil { return nil, errors.Wrapf(ErrInvalidChainSelector, "chain selector %d", chainSelector) @@ -165,7 +166,12 @@ func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]Typ // maps are mutable and pass via a pointer // creating a copy of the map to prevent concurrency // read and changes outside object-bound - return maps.Clone(m.addressesByChain[chainSelector]), nil + cloned := maps.Clone(m.addressesByChain[chainSelector]) + for _, opt := range opts { + log.Printf("Applying option %v", opt) + opt(cloned) + } + return cloned, nil } // Merge will merge the addresses from another address book into this one. diff --git a/deployment/address_book_test.go b/deployment/address_book_test.go index e022e89a9ab..3277a111af8 100644 --- a/deployment/address_book_test.go +++ b/deployment/address_book_test.go @@ -72,6 +72,34 @@ func TestAddressBook_Save(t *testing.T) { }) } +func TestAddressBook_AddressesForChain(t *testing.T) { + ab := NewMemoryAddressBook() + ocr3Cap100 := NewTypeAndVersion("OCR3Capability", Version1_0_0) + copyOCR3Cap100 := NewTypeAndVersion("OCR3Capability", Version1_0_0) + + addr1 := common.HexToAddress("0x1").String() + addr2 := common.HexToAddress("0x2").String() + + err := ab.Save(chainsel.TEST_90000001.Selector, addr1, ocr3Cap100) + require.NoError(t, err) + + err = ab.Save(chainsel.TEST_90000001.Selector, addr2, copyOCR3Cap100) + require.NoError(t, err) + + addresses, err := ab.AddressesForChain(chainsel.TEST_90000001.Selector, + func(m map[string]TypeAndVersion) { + for k, v := range m { + if v.Type == "OCR3Capability" { + if k != addr1 { + delete(m, k) + } + } + } + }, + ) + require.Len(t, addresses, 1) +} + func TestAddressBook_Merge(t *testing.T) { onRamp100 := NewTypeAndVersion("OnRamp", Version1_0_0) onRamp110 := NewTypeAndVersion("OnRamp", Version1_1_0) diff --git a/deployment/keystone/changeset/deploy_ocr3.go b/deployment/keystone/changeset/deploy_ocr3.go index 057bba4c12d..b9f74110bc9 100644 --- a/deployment/keystone/changeset/deploy_ocr3.go +++ b/deployment/keystone/changeset/deploy_ocr3.go @@ -37,6 +37,7 @@ var _ deployment.ChangeSet[ConfigureOCR3Config] = ConfigureOCR3Contract type ConfigureOCR3Config struct { ChainSel uint64 NodeIDs []string + OCR3ContractAddr *string OCR3Config *kslib.OracleConfig DryRun bool WriteGeneratedConfig io.Writer // if not nil, write the generated config to this writer as JSON [OCR2OracleConfig] @@ -55,6 +56,7 @@ func ConfigureOCR3Contract(env deployment.Environment, cfg ConfigureOCR3Config) NodeIDs: cfg.NodeIDs, OCR3Config: cfg.OCR3Config, DryRun: cfg.DryRun, + OCR3Addr: cfg.OCR3ContractAddr, UseMCMS: cfg.UseMCMS(), }) if err != nil { diff --git a/deployment/keystone/deploy.go b/deployment/keystone/deploy.go index 3d415c5d2da..5c386f4d668 100644 --- a/deployment/keystone/deploy.go +++ b/deployment/keystone/deploy.go @@ -351,7 +351,11 @@ type ConfigureOCR3Config struct { ChainSel uint64 NodeIDs []string OCR3Config *OracleConfig - DryRun bool + + // OCR3Addr is the address of the OCR3 contract to configure. If nil, and there are more than + // one deployed contracts on the chain, then the configured contract is non-deterministic. + OCR3Addr *string + DryRun bool UseMCMS bool } @@ -367,9 +371,16 @@ func ConfigureOCR3ContractFromJD(env *deployment.Environment, cfg ConfigureOCR3C if !ok { return nil, fmt.Errorf("chain %d not found in environment", cfg.ChainSel) } + + filterFunc := identityFilterer + if cfg.OCR3Addr != nil { + filterFunc = makeOCR3CapabilityFilterer(*cfg.OCR3Addr) + } + contractSetsResp, err := GetContractSets(env.Logger, &GetContractSetsRequest{ - Chains: env.Chains, - AddressBook: env.ExistingAddresses, + Chains: env.Chains, + AddressBook: env.ExistingAddresses, + AddressFilterer: filterFunc, }) if err != nil { return nil, fmt.Errorf("failed to get contract sets: %w", err) @@ -971,3 +982,22 @@ func configureForwarder(lggr logger.Logger, chain deployment.Chain, contractSet } return opMap, nil } + +// makeOCR3CapabilityFilterer returns a filter func that deletes any OCR3Capability contract entries +// that do not match a given OCR3 Address. If no address is given, no filtering is done. +func makeOCR3CapabilityFilterer(ocr3Addr string) func(map[string]deployment.TypeAndVersion) { + return func(m map[string]deployment.TypeAndVersion) { + for k, v := range m { + if v.Type == OCR3Capability { + if k != ocr3Addr { + delete(m, k) + } + } + } + } +} + +// identityFilterer is a filter func that does nothing +func identityFilterer(m map[string]deployment.TypeAndVersion) { + // no-op +} diff --git a/deployment/keystone/state.go b/deployment/keystone/state.go index 0ac7cdc89ed..851dea650a1 100644 --- a/deployment/keystone/state.go +++ b/deployment/keystone/state.go @@ -19,6 +19,10 @@ import ( type GetContractSetsRequest struct { Chains map[uint64]deployment.Chain AddressBook deployment.AddressBook + + // AddressFilterer is a function that filters the addresses a given address book. Filtering + // mutates the input map so the input map should be a copy of the original map. + AddressFilterer func(map[string]deployment.TypeAndVersion) } type GetContractSetsResponse struct { @@ -62,8 +66,12 @@ func GetContractSets(lggr logger.Logger, req *GetContractSetsRequest) (*GetContr resp := &GetContractSetsResponse{ ContractSets: make(map[uint64]ContractSet), } + var opts []func(map[string]deployment.TypeAndVersion) + if req.AddressFilterer != nil { + opts = append(opts, req.AddressFilterer) + } for id, chain := range req.Chains { - addrs, err := req.AddressBook.AddressesForChain(id) + addrs, err := req.AddressBook.AddressesForChain(id, opts...) if err != nil { return nil, fmt.Errorf("failed to get addresses for chain %d: %w", id, err) }