Skip to content

Commit

Permalink
Updated codec usage
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Jan 2, 2025
1 parent dcb53a2 commit fbcb658
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pkg/solana/chainwriter/ccip_example_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestConfig() {
Name: "RegistryTokenState",
// In this case, the user configured the lookup table accounts to use a PDALookup, which
// generates a list of one of more PDA accounts based on the input parameters. Specifically,
// there will be multple PDA accounts if there are multiple addresses in the message, otherwise,
// there will be multiple PDA accounts if there are multiple addresses in the message, otherwise,
// there will only be one PDA account to read from. The PDA account corresponds to the lookup table.
Accounts: PDALookups{
Name: "RegistryTokenState",
Expand Down
73 changes: 39 additions & 34 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/gagliardetto/solana-go/rpc"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/types"
Expand All @@ -31,7 +30,8 @@ type SolanaChainWriterService struct {
ge fees.Estimator
config ChainWriterConfig

codecs map[string]types.Codec
parsed *codec.ParsedTypes
encoder types.Encoder

services.StateMachine
}
Expand Down Expand Up @@ -62,48 +62,54 @@ type MethodConfig struct {
}

func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
codecs, err := parseIDLCodecs(config)
if err != nil {
return nil, fmt.Errorf("failed to parse IDL codecs: %w", err)
}

return &SolanaChainWriterService{
w := SolanaChainWriterService{
lggr: logger,
reader: reader,
txm: txm,
ge: ge,
config: config,
codecs: codecs,
}, nil
parsed: &codec.ParsedTypes{EncoderDefs: map[string]codec.Entry{}, DecoderDefs: map[string]codec.Entry{}},
}

if err := w.parsePrograms(config); err != nil {
return nil, fmt.Errorf("failed to parse programs: %w", err)
}

var err error
if w.encoder, err = w.parsed.ToCodec(); err != nil {
return nil, fmt.Errorf("%w: failed to create codec", err)
}

return &w, nil
}

func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) {
codecs := make(map[string]types.Codec)
func (s *SolanaChainWriterService) parsePrograms(config ChainWriterConfig) error {
for program, programConfig := range config.Programs {
var idl codec.IDL
if err := json.Unmarshal([]byte(programConfig.IDL), &idl); err != nil {
return nil, fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err)
}
idlCodec, err := codec.NewIDLInstructionsCodec(idl, binary.LittleEndian())
if err != nil {
return nil, fmt.Errorf("failed to create codec from IDL for program: %s, error: %w", program, err)
return fmt.Errorf("failed to unmarshal IDL for program: %s, error: %w", program, err)
}
for method, methodConfig := range programConfig.Methods {
if methodConfig.InputModifications != nil {
modConfig, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
return nil, fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err)
}
// add mods to codec
idlCodec, err = codec.NewNamedModifierCodec(idlCodec, method, modConfig)
if err != nil {
return nil, fmt.Errorf("failed to create named codec for method %s.%s, error: %w", program, method, err)
}
idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeInstructionDef, methodConfig.ChainSpecificName, idl)
if err != nil {
return err
}

inputMod, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
return fmt.Errorf("failed to create input modifications for method %s.%s, error: %w", program, method, err)
}

input, err := codec.CreateCodecEntry(idlDef, methodConfig.ChainSpecificName, idl, inputMod)
if err != nil {
return fmt.Errorf("failed to create codec entry for method %s.%s, error: %w", program, method, err)
}

s.parsed.EncoderDefs[codec.WrapItemType(true, program, method, "")] = input
}
codecs[program] = idlCodec
}
return codecs, nil

return nil
}

/*
Expand Down Expand Up @@ -250,16 +256,15 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
}
}

codec := s.codecs[contractName]
encodedPayload, err := codec.Encode(ctx, args, method)

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)
encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, ""))

if err != nil {
return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID)
}

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Fetch derived and static table maps
derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables)
if err != nil {
Expand Down
78 changes: 40 additions & 38 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package chainwriter_test
import (
"bytes"
"errors"
"io/ioutil"
"math/big"
"os"
"reflect"
Expand All @@ -27,6 +26,12 @@ import (
txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks"
)

type Arguments struct {
LookupTable solana.PublicKey
Seed1 []byte
Seed2 []byte
}

func TestChainWriter_GetAddresses(t *testing.T) {
ctx := tests.Context(t)

Expand Down Expand Up @@ -87,7 +92,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: derivedTablePdaLookupMeta.IsSigner,
IsWritable: derivedTablePdaLookupMeta.IsWritable,
Expand All @@ -107,10 +112,10 @@ func TestChainWriter_GetAddresses(t *testing.T) {
// correlates to DerivedTable index in account lookup config
derivedTablePdaLookupMeta.PublicKey = storedPubKeys[0]

args := map[string]interface{}{
"lookup_table": accountLookupMeta.PublicKey.Bytes(),
"seed1": seed1,
"seed2": seed2,
args := Arguments{
LookupTable: accountLookupMeta.PublicKey,
Seed1: seed1,
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand All @@ -122,7 +127,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
},
chainwriter.AccountLookup{
Name: "LookupTable",
Location: "lookup_table",
Location: "LookupTable",
IsSigner: accountLookupMeta.IsSigner,
IsWritable: accountLookupMeta.IsWritable,
},
Expand All @@ -131,7 +136,7 @@ func TestChainWriter_GetAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: pdaLookupMeta.IsSigner,
IsWritable: pdaLookupMeta.IsWritable,
Expand Down Expand Up @@ -177,8 +182,8 @@ func TestChainWriter_GetAddresses(t *testing.T) {
})

t.Run("resolve addresses for multiple indices from derived lookup table", func(t *testing.T) {
args := map[string]interface{}{
"seed2": seed2,
args := Arguments{
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand All @@ -202,8 +207,8 @@ func TestChainWriter_GetAddresses(t *testing.T) {
})

t.Run("resolve all addresses from derived lookup table if indices not specified", func(t *testing.T) {
args := map[string]interface{}{
"seed2": seed2,
args := Arguments{
Seed2: seed2,
}

accountLookupConfig := []chainwriter.Lookup{
Expand Down Expand Up @@ -274,7 +279,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: true,
IsWritable: true,
Expand All @@ -291,7 +296,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "UnusedAccount", Address: unusedProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: true,
IsWritable: true,
Expand All @@ -305,9 +310,9 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {
StaticLookupTables: []solana.PublicKey{staticLookupTablePubkey1, staticLookupTablePubkey2},
}

args := map[string]interface{}{
"seed1": seed1,
"seed2": seed2,
args := Arguments{
Seed1: seed1,
Seed2: seed2,
}

t.Run("returns filtered map with only relevant addresses required by account lookup config", func(t *testing.T) {
Expand Down Expand Up @@ -414,19 +419,14 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
staticLookupKeys := chainwriter.CreateTestPubKeys(t, 2)
mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey, staticLookupKeys)

jsonFile, err := os.Open("testContractIDL.json")
require.NoError(t, err)

defer jsonFile.Close()

data, err := ioutil.ReadAll(jsonFile)
data, err := os.ReadFile("testContractIDL.json")
require.NoError(t, err)

testContractIDLJson := string(data)

cwConfig := chainwriter.ChainWriterConfig{
Programs: map[string]chainwriter.ProgramConfig{
"contractReaderInterface": {
"contract_reader_interface": {
Methods: map[string]chainwriter.MethodConfig{
"initializeLookupTable": {
FromAddress: admin.String(),
Expand All @@ -440,7 +440,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()},
Seeds: []chainwriter.Seed{
// extract seed2 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed2", Location: "Seed2"}},
},
IsSigner: false,
IsWritable: false,
Expand All @@ -462,7 +462,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
},
chainwriter.AccountLookup{
Name: "LookupTable",
Location: "lookup_table",
Location: "LookupTable",
IsSigner: false,
IsWritable: false,
},
Expand All @@ -471,7 +471,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()},
Seeds: []chainwriter.Seed{
// extract seed1 for PDA lookup
{Dynamic: chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}},
{Dynamic: chainwriter.AccountLookup{Name: "Seed1", Location: "Seed1"}},
},
IsSigner: false,
IsWritable: false,
Expand Down Expand Up @@ -514,22 +514,24 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {

t.Run("fails to encode payload if args with missing values provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
type InvalidArgs struct{}
args := InvalidArgs{}
submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

t.Run("fails if invalid contract name provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
args := Arguments{}
submitErr := cw.SubmitTransaction(ctx, "badContract", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

t.Run("fails if invalid method provided", func(t *testing.T) {
txID := uuid.NewString()
args := map[string]interface{}{}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "badMethod", args, txID, programID.String(), nil, nil)

args := Arguments{}
submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "badMethod", args, txID, programID.String(), nil, nil)
require.Error(t, submitErr)
})

Expand All @@ -555,13 +557,13 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
return true
}), &txID, mock.Anything).Return(nil).Once()

args := map[string]interface{}{
"lookupTable": chainwriter.GetRandomPubKey(t).Bytes(),
"lookup_table": account2.Bytes(),
"seed1": seed1,
"seed2": seed2,
args := Arguments{
LookupTable: account2,
Seed1: seed1,
Seed2: seed2,
}
submitErr := cw.SubmitTransaction(ctx, "contractReaderInterface", "initializeLookupTable", args, txID, programID.String(), nil, nil)

submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.NoError(t, submitErr)
})
}
Expand Down
1 change: 1 addition & 0 deletions pkg/solana/chainwriter/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
)

// Lookup is an interface that defines a method to resolve an address (or multiple addresses) from a given definition.
type Lookup interface {
Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error)
}
Expand Down
1 change: 0 additions & 1 deletion pkg/solana/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ func FundAccounts(t *testing.T, accounts []solana.PrivateKey, solanaGoClient *rp
}
}
remaining = unconfirmedTxCount
fmt.Printf("Waiting for finalized funding on %d addresses\n", remaining)

time.Sleep(500 * time.Millisecond)
if count > 60 {
Expand Down

0 comments on commit fbcb658

Please sign in to comment.