Skip to content

Commit

Permalink
Updated codec usage
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Dec 31, 2024
1 parent 2e41287 commit e21738a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 73 deletions.
73 changes: 39 additions & 34 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/gagliardetto/solana-go/rpc"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/types"
Expand All @@ -31,7 +30,8 @@ type SolanaChainWriterService struct {
ge fees.Estimator
config ChainWriterConfig

codecs map[string]types.Codec
parsed *codec.ParsedTypes
encoder types.Encoder

services.StateMachine
}
Expand Down Expand Up @@ -62,48 +62,54 @@ type MethodConfig struct {
}

func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
codecs, err := parseIDLCodecs(config)
if err != nil {
return nil, fmt.Errorf("failed to parse IDL codecs: %w", err)
}

return &SolanaChainWriterService{
w := SolanaChainWriterService{
lggr: logger,
reader: reader,
txm: txm,
ge: ge,
config: config,
codecs: codecs,
}, nil
parsed: &codec.ParsedTypes{EncoderDefs: map[string]codec.Entry{}, DecoderDefs: map[string]codec.Entry{}},
}

if err := w.parsePrograms(config); err != nil {
return nil, fmt.Errorf("failed to parse programs: %w", err)
}

var err error
if w.encoder, err = w.parsed.ToCodec(); err != nil {
return nil, fmt.Errorf("%w: failed to create codec", err)
}

return &w, nil
}

func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) {
codecs := make(map[string]types.Codec)
func (s *SolanaChainWriterService) parsePrograms(config ChainWriterConfig) error {
for program, programConfig := range config.Programs {
var idl codec.IDL
if err := json.Unmarshal([]byte(programConfig.IDL), &idl); err != nil {
return nil, fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err)
}
idlCodec, err := codec.NewIDLInstructionsCodec(idl, binary.LittleEndian())
if err != nil {
return nil, fmt.Errorf("failed to create codec from IDL for program: %s, error: %w", program, err)
return fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err)
}
for method, methodConfig := range programConfig.Methods {
if methodConfig.InputModifications != nil {
modConfig, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
return nil, fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err)
}
// add mods to codec
idlCodec, err = codec.NewNamedModifierCodec(idlCodec, method, modConfig)
if err != nil {
return nil, fmt.Errorf("failed to create named codec for method %s.%s, error: %w", program, method, err)
}
idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeInstructionDef, methodConfig.ChainSpecificName, idl)
if err != nil {
return err
}

inputMod, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
return fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err)
}

input, err := codec.CreateCodecEntry(idlDef, methodConfig.ChainSpecificName, idl, inputMod)
if err != nil {
return fmt.Errorf("failed to create codec entry for method %s.%s, error: %w", program, method, err)
}

s.parsed.EncoderDefs[codec.WrapItemType(true, program, method, "")] = input
}
codecs[program] = idlCodec
}
return codecs, nil

return nil
}

/*
Expand Down Expand Up @@ -250,16 +256,15 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
}
}

codec := s.codecs[contractName]
encodedPayload, err := codec.Encode(ctx, args, method)

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)
encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, ""))

if err != nil {
return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID)
}

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Fetch derived and static table maps
derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables)
if err != nil {
Expand Down
80 changes: 42 additions & 38 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package chainwriter_test
import (
"bytes"
"errors"
"io/ioutil"
"fmt"
"math/big"
"os"
"reflect"
Expand All @@ -27,6 +27,12 @@ import (
txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks"
)

type Arguments struct {
LookupTable solana.PublicKey
Seed1 []byte
Seed2 []byte
}

func TestChainWriter_GetAddresses(t *testing.T) {
ctx := tests.Context(t)

Expand Down Expand Up @@ -87,7 +93,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: derivedTablePdaLookupMeta.IsSigner,
IsWritable: derivedTablePdaLookupMeta.IsWritable,
Expand All @@ -107,10 +113,10 @@ func TestChainWriter_GetAddresses(t *testing.T) {
// correlates to DerivedTable index in account lookup config
derivedTablePdaLookupMeta.PublicKey = storedPubKeys[0]

args := map[string]interface{}{
"lookup_table": accountLookupMeta.PublicKey.Bytes(),
"seed1": seed1,
"seed2": seed2,
args := Arguments{
LookupTable: accountLookupMeta.PublicKey,
Seed1: seed1,
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand All @@ -122,7 +128,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
},
chainwriter.AccountLookup{
Name: "LookupTable",
Location: "lookup_table",
Location: "LookupTable",
IsSigner: accountLookupMeta.IsSigner,
IsWritable: accountLookupMeta.IsWritable,
},
Expand All @@ -131,7 +137,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: pdaLookupMeta.IsSigner,
IsWritable: pdaLookupMeta.IsWritable,
Expand Down Expand Up @@ -177,8 +183,8 @@ func TestChainWriter_GetAddresses(t *testing.T) {
})

t.Run("resolve addresses for multiple indices from derived lookup table", func(t *testing.T) {
args := map[string]interface{}{
"seed2": seed2,
args := Arguments{
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand All @@ -202,8 +208,8 @@ func TestChainWriter_GetAddresses(t *testing.T) {
})

t.Run("resolve all addresses from derived lookup table if indices not specified", func(t *testing.T) {
args := map[string]interface{}{
"seed2": seed2,
args := Arguments{
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand Down Expand Up @@ -274,7 +280,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: true,
IsWritable: true,
Expand All @@ -291,7 +297,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "UnusedAccount", Address: unusedProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: true,
IsWritable: true,
Expand All @@ -305,9 +311,9 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
StaticLookupTables: []solana.PublicKey{staticLookupTablePubkey1, staticLookupTablePubkey2},
}

args := map[string]interface{}{
"seed1": seed1,
"seed2": seed2,
args := Arguments{
Seed1: seed1,
Seed2: seed2,
}

t.Run("returns filtered map with only relevant addresses required by account lookup config", func(t *testing.T) {
Expand Down Expand Up @@ -403,6 +409,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
seed2 := []byte("seed2")
programID := solana.MustPublicKeyFromBase58("6AfuXF6HapDUhQfE4nQG9C1SGtA1YjP3icaJyRfU4RyE")
derivedTablePda := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID)
fmt.Println("pda:", derivedTablePda)
// mock data account response from program
derivedLookupTablePubkey := mockDataAccountLookupTable(t, rw, derivedTablePda)
// mock fetch lookup table addresses call
Expand All @@ -414,19 +421,14 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
staticLookupKeys := chainwriter.CreateTestPubKeys(t, 2)
mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey, staticLookupKeys)

jsonFile, err := os.Open("testContractIDL.json")
require.NoError(t, err)

defer jsonFile.Close()

data, err := ioutil.ReadAll(jsonFile)
data, err := os.ReadFile("testContractIDL.json")
require.NoError(t, err)

testContractIDLJson := string(data)

cwConfig := chainwriter.ChainWriterConfig{
Programs: map[string]chainwriter.ProgramConfig{
"contractReaderInterface": {
"contract_reader_interface": {
Methods: map[string]chainwriter.MethodConfig{
"initializeLookupTable": {
FromAddress: admin.String(),
Expand All @@ -440,7 +442,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: false,
IsWritable: false,
Expand All @@ -462,7 +464,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
},
chainwriter.AccountLookup{
Name: "LookupTable",
Location: "lookup_table",
Location: "LookupTable",
IsSigner: false,
IsWritable: false,
},
Expand All @@ -471,7 +473,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: false,
IsWritable: false,
Expand Down Expand Up @@ -514,22 +516,24 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {

t.Run("fails to encode payload if args with missing values provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
type InvalidArgs struct{}
args := InvalidArgs{}
submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

t.Run("fails if invalid contract name provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
args := Arguments{}
submitErr := cw.SubmitTransaction(ctx, "badContract", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

t.Run("fails if invalid method provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "badMethod", args, txID, programID.String(), nil, nil)

args := Arguments{}
submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "badMethod", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

Expand All @@ -555,13 +559,13 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
return true
}), &txID, mock.Anything).Return(nil).Once()

args := map[string]interface{}{
"lookupTable": chainwriter.GetRandomPubKey(t).Bytes(),
"lookup_table": account2.Bytes(),
"seed1": seed1,
"seed2": seed2,
args := Arguments{
LookupTable: account2,
Seed1: seed1,
Seed2: seed2,
}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil)

submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.NoError(t, submitErr)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/solana/txm/pendingtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ func TestPendingTxContext_ListAllExpiredBroadcastedTxs(t *testing.T) {
}

for _, tt := range tests {
tt := tt // capture range variable
tt := tt
t.Run(tt.name, func(t *testing.T) {
// Initialize a new PendingTxContext
ctx := newPendingTxContext()
Expand Down

0 comments on commit e21738a

Please sign in to comment.