diff --git a/integration-tests/smoke/ccip/ccip_rmn_test.go b/integration-tests/smoke/ccip/ccip_rmn_test.go index 16213cd81cd..ea6c1b63c30 100644 --- a/integration-tests/smoke/ccip/ccip_rmn_test.go +++ b/integration-tests/smoke/ccip/ccip_rmn_test.go @@ -6,6 +6,7 @@ import ( "errors" "math/big" "os" + "slices" "strconv" "strings" "testing" @@ -178,9 +179,11 @@ func TestRMN_DifferentRmnNodesForDifferentChains(t *testing.T) { func TestRMN_TwoMessagesOneSourceChainCursed(t *testing.T) { runRmnTestCase(t, rmnTestCase{ - name: "two messages, one source chain is cursed", - passIfNoCommitAfter: 15 * time.Second, - cursedSourceChainIdxs: []int{chain0}, // <---- chain0 is cursed (only as source) + name: "two messages, one source chain is cursed", + passIfNoCommitAfter: 15 * time.Second, + cursedSubjectsPerChain: map[int][]int{ + chain1: {chain0}, + }, homeChainConfig: homeChainConfig{ f: map[int]int{chain0: 1, chain1: 1}, }, @@ -220,14 +223,18 @@ func TestRMN_GlobalCurseTwoMessagesOnTwoLanes(t *testing.T) { {fromChainIdx: chain0, toChainIdx: chain1, count: 1}, {fromChainIdx: chain1, toChainIdx: chain0, count: 5}, }, - globalCurse: true, + cursedSubjectsPerChain: map[int][]int{ + chain1: {globalCurse}, + chain0: {globalCurse}, + }, passIfNoCommitAfter: 15 * time.Second, }) } const ( - chain0 = 0 - chain1 = 1 + chain0 = 0 + chain1 = 1 + globalCurse = 1000 ) func runRmnTestCase(t *testing.T, tc rmnTestCase) { @@ -305,7 +312,11 @@ func runRmnTestCase(t *testing.T, tc rmnTestCase) { expectedSeqNum := make(map[changeset.SourceDestPair]uint64) for k, v := range seqNumCommit { - if !tc.pf.cursedSourceChains.Contains(k.SourceChainSelector) { + cursedSubjectsOfDest, exists := tc.pf.cursedSubjectsPerChainSel[k.DestChainSelector] + shouldSkip := exists && (slices.Contains(cursedSubjectsOfDest, globalCurse) || + slices.Contains(cursedSubjectsOfDest, k.SourceChainSelector)) + + if !shouldSkip { expectedSeqNum[k] = v } } @@ -313,16 +324,19 @@ func runRmnTestCase(t *testing.T, tc rmnTestCase) { t.Logf("expectedSeqNums: %v", expectedSeqNum) t.Logf("expectedSeqNums including cursed chains: %v", seqNumCommit) - if len(tc.cursedSourceChainIdxs) > 0 && len(seqNumCommit) == len(expectedSeqNum) { - t.Fatalf("test case is wrong: no message was sent to non-cursed chains") + if len(tc.cursedSubjectsPerChain) > 0 && len(seqNumCommit) == len(expectedSeqNum) { + t.Fatalf("test case is wrong: no message was sent to non-cursed chains when you " + + "define curse subjects, your test case should have at least one message not expected to be delivered") } commitReportReceived := make(chan struct{}) go func() { - changeset.ConfirmCommitForAllWithExpectedSeqNums(t, envWithRMN.Env, onChainState, expectedSeqNum, startBlocks) - commitReportReceived <- struct{}{} + if len(expectedSeqNum) > 0 { + changeset.ConfirmCommitForAllWithExpectedSeqNums(t, envWithRMN.Env, onChainState, expectedSeqNum, startBlocks) + commitReportReceived <- struct{}{} + } - if tc.pf.cursedSourceChains.Cardinality() > 0 { + if len(seqNumCommit) > 0 && len(seqNumCommit) > len(expectedSeqNum) { // wait for a duration and assert that commit reports were not delivered for cursed source chains changeset.ConfirmCommitForAllWithExpectedSeqNums(t, envWithRMN.Env, onChainState, seqNumCommit, startBlocks) commitReportReceived <- struct{}{} @@ -330,7 +344,7 @@ func runRmnTestCase(t *testing.T, tc rmnTestCase) { }() if tc.passIfNoCommitAfter > 0 { // wait for a duration and assert that commit reports were not delivered - if len(tc.cursedSourceChainIdxs) > 0 && len(expectedSeqNum) > 0 { + if len(expectedSeqNum) > 0 && len(seqNumCommit) > len(expectedSeqNum) { t.Logf("⌛ Waiting for commit reports of non-cursed chains...") <-commitReportReceived t.Logf("✅ Commit reports of non-cursed chains received") @@ -403,29 +417,24 @@ type rmnTestCase struct { name string // If set to 0, the test will wait for commit reports. // If set to a positive value, the test will wait for that duration and will assert that commit report was not delivered. - passIfNoCommitAfter time.Duration - // If set to true, the test will only wait for non-cursed chain msgs. - // And then wait for passIfNoCommitAfter (must be set) to assert that msgs from cursed sources are not transmitted. - // At the moment, it does not support waitForExec=true since only commit plugin has cursing checks. - cursedSourceChainIdxs []int - // globalCurse marks every chain as cursed by setting the global curse subject on each rmnRemote - globalCurse bool - waitForExec bool - homeChainConfig homeChainConfig - remoteChainsConfig []remoteChainConfig - rmnNodes []rmnNode - messagesToSend []messageToSend + passIfNoCommitAfter time.Duration + cursedSubjectsPerChain map[int][]int + waitForExec bool + homeChainConfig homeChainConfig + remoteChainsConfig []remoteChainConfig + rmnNodes []rmnNode + messagesToSend []messageToSend // populated fields after environment setup pf testCasePopulatedFields } type testCasePopulatedFields struct { - chainSelectors []uint64 - rmnHomeNodes []rmn_home.RMNHomeNode - rmnRemoteSigners []rmn_remote.RMNRemoteSigner - rmnHomeSourceChains []rmn_home.RMNHomeSourceChain - cursedSourceChains mapset.Set[uint64] + chainSelectors []uint64 + rmnHomeNodes []rmn_home.RMNHomeNode + rmnRemoteSigners []rmn_remote.RMNRemoteSigner + rmnHomeSourceChains []rmn_home.RMNHomeSourceChain + cursedSubjectsPerChainSel map[uint64][]uint64 } func (tc *rmnTestCase) populateFields(t *testing.T, envWithRMN changeset.DeployedEnv, rmnCluster devenv.RMNCluster) { @@ -468,31 +477,25 @@ func (tc *rmnTestCase) populateFields(t *testing.T, envWithRMN changeset.Deploye }) } - tc.pf.cursedSourceChains = mapset.NewSet[uint64]() - for _, chainIdx := range tc.cursedSourceChainIdxs { - tc.pf.cursedSourceChains.Add(tc.pf.chainSelectors[chainIdx]) + // populate cursed subjects with actual chain selectors + tc.pf.cursedSubjectsPerChainSel = make(map[uint64][]uint64) + for chainIdx, subjects := range tc.cursedSubjectsPerChain { + chainSel := tc.pf.chainSelectors[chainIdx] + for _, subject := range subjects { + subjSel := uint64(globalCurse) + if subject != globalCurse { + subjSel = tc.pf.chainSelectors[subject] + } + tc.pf.cursedSubjectsPerChainSel[chainSel] = append(tc.pf.cursedSubjectsPerChainSel[chainSel], subjSel) + } } } func (tc rmnTestCase) validate() error { - if len(tc.cursedSourceChainIdxs) > 0 { - if tc.waitForExec { - return errors.New("cursedSourceChainIdxs is set but waitForExec is true which is not supported") - } - if tc.passIfNoCommitAfter == 0 { - return errors.New("cursedSourceChainIdxs is set but passIfNoCommitAfter is not set") - } + if len(tc.cursedSubjectsPerChain) > 0 && tc.passIfNoCommitAfter == 0 { + return errors.New("when you define cursed subjects you also need to define the duration that the " + + "test will wait for non-transmitted roots") } - - if tc.globalCurse { - if tc.passIfNoCommitAfter == 0 { - return errors.New("globalCurse is set but passIfNoCommitAfter is not set") - } - if len(tc.cursedSourceChainIdxs) > 0 { - return errors.New("globalCurse is set but cursedSourceChainIdxs is not empty, this is not supported") - } - } - return nil } @@ -550,7 +553,7 @@ func (tc rmnTestCase) killMarkedRmnNodes(t *testing.T, rmnCluster devenv.RMNClus func (tc rmnTestCase) disableOraclesIfThisIsACursingTestCase(ctx context.Context, t *testing.T, envWithRMN changeset.DeployedEnv) []string { disabledNodes := make([]string, 0) - if len(tc.cursedSourceChainIdxs) > 0 || tc.globalCurse { + if len(tc.cursedSubjectsPerChain) > 0 { listNodesResp, err := envWithRMN.Env.Offchain.ListNodes(ctx, &node.ListNodesRequest{}) require.NoError(t, err) @@ -610,15 +613,22 @@ func (tc rmnTestCase) callContractsToCurseChains(ctx context.Context, t *testing require.True(t, ok) chain, ok := envWithRMN.Env.Chains[remoteSel] require.True(t, ok) - for _, chainSel := range tc.pf.cursedSourceChains.ToSlice() { - txCurse, errCurse := chState.RMNRemote.Curse(chain.DeployerKey, chainSelectorToBytes16(chainSel)) - _, errConfirm := deployment.ConfirmIfNoError(chain, txCurse, errCurse) - require.NoError(t, errConfirm) + + cursedSubjects, ok := tc.cursedSubjectsPerChain[remoteCfg.chainIdx] + if !ok { + continue // nothing to curse on this chain } - if tc.globalCurse { - txCurseGlobal, errCurseGlobal := chState.RMNRemote.Curse(chain.DeployerKey, types.GlobalCurseSubject) - _, errConfirm := deployment.ConfirmIfNoError(chain, txCurseGlobal, errCurseGlobal) + for _, subjectDescription := range cursedSubjects { + subj := [16]byte{} + if subjectDescription == globalCurse { + subj = types.GlobalCurseSubject + } else { + subj = chainSelectorToBytes16(tc.pf.chainSelectors[subjectDescription]) + } + t.Logf("cursing subject %d (%d)", subj, subjectDescription) + txCurse, errCurse := chState.RMNRemote.Curse(chain.DeployerKey, subj) + _, errConfirm := deployment.ConfirmIfNoError(chain, txCurse, errCurse) require.NoError(t, errConfirm) }