diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 4ed9c8a60..e02148d89 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -11,7 +11,6 @@ import ( "github.com/gagliardetto/solana-go/rpc" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" - "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -31,7 +30,8 @@ type SolanaChainWriterService struct { ge fees.Estimator config ChainWriterConfig - codecs map[string]types.Codec + parsed *codec.ParsedTypes + encoder types.Encoder services.StateMachine } @@ -62,48 +62,54 @@ type MethodConfig struct { } func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { - codecs, err := parseIDLCodecs(config) - if err != nil { - return nil, fmt.Errorf("failed to parse IDL codecs: %w", err) - } - - return &SolanaChainWriterService{ + w := SolanaChainWriterService{ lggr: logger, reader: reader, txm: txm, ge: ge, config: config, - codecs: codecs, - }, nil + parsed: &codec.ParsedTypes{EncoderDefs: map[string]codec.Entry{}, DecoderDefs: map[string]codec.Entry{}}, + } + + if err := w.parsePrograms(config); err != nil { + return nil, fmt.Errorf("failed to parse programs: %w", err) + } + + var err error + if w.encoder, err = w.parsed.ToCodec(); err != nil { + return nil, fmt.Errorf("%w: failed to create codec", err) + } + + return &w, nil } -func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) { - codecs := make(map[string]types.Codec) +func (s *SolanaChainWriterService) parsePrograms(config ChainWriterConfig) error { for program, programConfig := range config.Programs { var idl codec.IDL if err := json.Unmarshal([]byte(programConfig.IDL), &idl); err != nil { - return nil, fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err) - } - idlCodec, err := codec.NewIDLInstructionsCodec(idl, binary.LittleEndian()) - if err != nil { - return nil, fmt.Errorf("failed to create codec from IDL for program: %s, error: %w", program, err) + return fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err) } for method, methodConfig := range programConfig.Methods { - if methodConfig.InputModifications != nil { - modConfig, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...) - if err != nil { - return nil, fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err) - } - // add mods to codec - idlCodec, err = codec.NewNamedModifierCodec(idlCodec, method, modConfig) - if err != nil { - return nil, fmt.Errorf("failed to create named codec for method %s.%s, error: %w", program, method, err) - } + idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeInstructionDef, methodConfig.ChainSpecificName, idl) + if err != nil { + return err } + + inputMod, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...) + if err != nil { + return fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err) + } + + input, err := codec.CreateCodecEntry(idlDef, methodConfig.ChainSpecificName, idl, inputMod) + if err != nil { + return fmt.Errorf("failed to create codec entry for method %s.%s, error: %w", program, method, err) + } + + s.parsed.EncoderDefs[codec.WrapItemType(true, program, method, "")] = input } - codecs[program] = idlCodec } - return codecs, nil + + return nil } /* @@ -250,16 +256,15 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra } } - codec := s.codecs[contractName] - encodedPayload, err := codec.Encode(ctx, args, method) - - discriminator := GetDiscriminator(methodConfig.ChainSpecificName) - encodedPayload = append(discriminator[:], encodedPayload...) + encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, "")) if err != nil { return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) } + discriminator := GetDiscriminator(methodConfig.ChainSpecificName) + encodedPayload = append(discriminator[:], encodedPayload...) + // Fetch derived and static table maps derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables) if err != nil { diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go index 03428d080..cd215971e 100644 --- a/pkg/solana/chainwriter/chain_writer_test.go +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -3,7 +3,7 @@ package chainwriter_test import ( "bytes" "errors" - "io/ioutil" + "fmt" "math/big" "os" "reflect" @@ -27,6 +27,12 @@ import ( txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" ) +type Arguments struct { + LookupTable solana.PublicKey + Seed1 []byte + Seed2 []byte +} + func TestChainWriter_GetAddresses(t *testing.T) { ctx := tests.Context(t) @@ -87,7 +93,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, Seeds: []chainwriter.Seed{ // extract seed2 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}}, }, IsSigner: derivedTablePdaLookupMeta.IsSigner, IsWritable: derivedTablePdaLookupMeta.IsWritable, @@ -107,10 +113,10 @@ func TestChainWriter_GetAddresses(t *testing.T) { // correlates to DerivedTable index in account lookup config derivedTablePdaLookupMeta.PublicKey = storedPubKeys[0] - args := map[string]interface{}{ - "lookup_table": accountLookupMeta.PublicKey.Bytes(), - "seed1": seed1, - "seed2": seed2, + args := Arguments{ + LookupTable: accountLookupMeta.PublicKey, + Seed1: seed1, + Seed2: seed2, } accountLookupConfig := []chainwriter.Lookup{ @@ -122,7 +128,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { }, chainwriter.AccountLookup{ Name: "LookupTable", - Location: "lookup_table", + Location: "LookupTable", IsSigner: accountLookupMeta.IsSigner, IsWritable: accountLookupMeta.IsWritable, }, @@ -131,7 +137,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, Seeds: []chainwriter.Seed{ // extract seed1 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}}, }, IsSigner: pdaLookupMeta.IsSigner, IsWritable: pdaLookupMeta.IsWritable, @@ -177,8 +183,8 @@ func TestChainWriter_GetAddresses(t *testing.T) { }) t.Run("resolve addresses for multiple indices from derived lookup table", func(t *testing.T) { - args := map[string]interface{}{ - "seed2": seed2, + args := Arguments{ + Seed2: seed2, } accountLookupConfig := []chainwriter.Lookup{ @@ -202,8 +208,8 @@ func TestChainWriter_GetAddresses(t *testing.T) { }) t.Run("resolve all addresses from derived lookup table if indices not specified", func(t *testing.T) { - args := map[string]interface{}{ - "seed2": seed2, + args := Arguments{ + Seed2: seed2, } accountLookupConfig := []chainwriter.Lookup{ @@ -274,7 +280,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, Seeds: []chainwriter.Seed{ // extract seed1 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}}, }, IsSigner: true, IsWritable: true, @@ -291,7 +297,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "UnusedAccount", Address: unusedProgramID.String()}, Seeds: []chainwriter.Seed{ // extract seed2 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}}, }, IsSigner: true, IsWritable: true, @@ -305,9 +311,9 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { StaticLookupTables: []solana.PublicKey{staticLookupTablePubkey1, staticLookupTablePubkey2}, } - args := map[string]interface{}{ - "seed1": seed1, - "seed2": seed2, + args := Arguments{ + Seed1: seed1, + Seed2: seed2, } t.Run("returns filtered map with only relevant addresses required by account lookup config", func(t *testing.T) { @@ -403,6 +409,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { seed2 := []byte("seed2") programID := solana.MustPublicKeyFromBase58("6AfuXF6HapDUhQfE4nQG9C1SGtA1YjP3icaJyRfU4RyE") derivedTablePda := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID) + fmt.Println("pda:", derivedTablePda) // mock data account response from program derivedLookupTablePubkey := mockDataAccountLookupTable(t, rw, derivedTablePda) // mock fetch lookup table addresses call @@ -414,19 +421,14 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { staticLookupKeys := chainwriter.CreateTestPubKeys(t, 2) mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey, staticLookupKeys) - jsonFile, err := os.Open("testContractIDL.json") - require.NoError(t, err) - - defer jsonFile.Close() - - data, err := ioutil.ReadAll(jsonFile) + data, err := os.ReadFile("testContractIDL.json") require.NoError(t, err) testContractIDLJson := string(data) cwConfig := chainwriter.ChainWriterConfig{ Programs: map[string]chainwriter.ProgramConfig{ - "contractReaderInterface": { + "contract_reader_interface": { Methods: map[string]chainwriter.MethodConfig{ "initializeLookupTable": { FromAddress: admin.String(), @@ -440,7 +442,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, Seeds: []chainwriter.Seed{ // extract seed2 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}}, }, IsSigner: false, IsWritable: false, @@ -462,7 +464,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { }, chainwriter.AccountLookup{ Name: "LookupTable", - Location: "lookup_table", + Location: "LookupTable", IsSigner: false, IsWritable: false, }, @@ -471,7 +473,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, Seeds: []chainwriter.Seed{ // extract seed1 for PDA lookup - {Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}}, + {Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}}, }, IsSigner: false, IsWritable: false, @@ -514,22 +516,24 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { t.Run("fails to encode payload if args with missing values provided", func(t *testing.T) { txID := uuid.NewString() - args := map[string]interface{}{} - submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil) + type InvalidArgs struct{} + args := InvalidArgs{} + submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil) require.Error(t, submitErr) }) t.Run("fails if invalid contract name provided", func(t *testing.T) { txID := uuid.NewString() - args := map[string]interface{}{} + args := Arguments{} submitErr := cw.SubmitTransaction(ctx, "badContract", "initializeLookupTable", args, txID, programID.String(), nil, nil) require.Error(t, submitErr) }) t.Run("fails if invalid method provided", func(t *testing.T) { txID := uuid.NewString() - args := map[string]interface{}{} - submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "badMethod", args, txID, programID.String(), nil, nil) + + args := Arguments{} + submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "badMethod", args, txID, programID.String(), nil, nil) require.Error(t, submitErr) }) @@ -555,13 +559,13 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { return true }), &txID, mock.Anything).Return(nil).Once() - args := map[string]interface{}{ - "lookupTable": chainwriter.GetRandomPubKey(t).Bytes(), - "lookup_table": account2.Bytes(), - "seed1": seed1, - "seed2": seed2, + args := Arguments{ + LookupTable: account2, + Seed1: seed1, + Seed2: seed2, } - submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil) + + submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil) require.NoError(t, submitErr) }) } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index b082b2162..a098f2400 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -1306,7 +1306,7 @@ func TestPendingTxContext_ListAllExpiredBroadcastedTxs(t *testing.T) { } for _, tt := range tests { - tt := tt // capture range variable + tt := tt t.Run(tt.name, func(t *testing.T) { // Initialize a new PendingTxContext ctx := newPendingTxContext()