Skip to content

Commit

Permalink
BCF-3016 Convert Multiple Binding Calls to Parallel (#617)
Browse files Browse the repository at this point in the history
* Convert Multiple Binding Calls to Parallel

The binding calls are expected to be run in sequence as they map over
the returnVal, but the associated RPC calls can be run in parallel.
This commit preloads all data over RPC in parallel before mapping over
the returnVal in sequence.

* remove unnecessary constant
  • Loading branch information
EasterTheBunny authored Mar 15, 2024
1 parent 786829f commit 5b3b9fb
Show file tree
Hide file tree
Showing 6 changed files with 473 additions and 44 deletions.
53 changes: 51 additions & 2 deletions pkg/solana/chainreader/account_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,61 @@ type accountReadBinding struct {
reader BinaryDataReader
}

func newAccountReadBinding(acct string, codec types.RemoteCodec, reader BinaryDataReader) *accountReadBinding {
return &accountReadBinding{
idlAccount: acct,
codec: codec,
reader: reader,
}
}

var _ readBinding = &accountReadBinding{}

func (b *accountReadBinding) GetLatestValue(ctx context.Context, _ any, outVal any) error {
func (b *accountReadBinding) PreLoad(ctx context.Context, result *loadedResult) {
if result == nil {
return
}

bts, err := b.reader.ReadAll(ctx, b.account)
if err != nil {
return fmt.Errorf("%w: failed to get binary data", err)
result.err <- fmt.Errorf("%w: failed to get binary data", err)

return
}

select {
case <-ctx.Done():
result.err <- ctx.Err()
default:
result.value <- bts
}
}

func (b *accountReadBinding) GetLatestValue(ctx context.Context, _ any, outVal any, result *loadedResult) error {
var (
bts []byte
err error
)

if result != nil {
// when preloading, the process will wait for one of three conditions:
// 1. the context ends and returns an error
// 2. bytes were loaded in the bytes channel
// 3. an error was loaded in the err channel
select {
case <-ctx.Done():
err = ctx.Err()
case bts = <-result.value:
case err = <-result.err:
}

if err != nil {
return err
}
} else {
if bts, err = b.reader.ReadAll(ctx, b.account); err != nil {
return fmt.Errorf("%w: failed to get binary data", err)
}
}

return b.codec.Decode(ctx, bts, outVal, b.idlAccount)
Expand Down
159 changes: 159 additions & 0 deletions pkg/solana/chainreader/account_read_binding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package chainreader

import (
"context"
"errors"
"testing"
"time"

"github.com/gagliardetto/solana-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

func TestPreload(t *testing.T) {
t.Parallel()

testCodec := makeTestCodec(t)

t.Run("get latest value waits for preload", func(t *testing.T) {
t.Parallel()

reader := new(mockReader)
binding := newAccountReadBinding(testCodecKey, testCodec, reader)

expected := testStruct{A: true, B: 42}
bts, err := testCodec.Encode(context.Background(), expected, testCodecKey)

require.NoError(t, err)

reader.On("ReadAll", mock.Anything, mock.Anything).Return(bts, nil).After(time.Second)

ctx := context.Background()
start := time.Now()
loaded := &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}

binding.PreLoad(ctx, loaded)

var result testStruct

err = binding.GetLatestValue(ctx, nil, &result, loaded)
elapsed := time.Since(start)

require.NoError(t, err)
assert.GreaterOrEqual(t, elapsed, time.Second)
assert.Less(t, elapsed, 1100*time.Millisecond)
assert.Equal(t, expected, result)
})

t.Run("cancelled context exits preload and returns error on get latest value", func(t *testing.T) {
t.Parallel()

reader := new(mockReader)
binding := newAccountReadBinding(testCodecKey, testCodec, reader)

ctx, cancel := context.WithCancelCause(context.Background())

// make the readall pause until after the context is cancelled
reader.On("ReadAll", mock.Anything, mock.Anything).
Return([]byte{}, nil).
After(600 * time.Millisecond)

expectedErr := errors.New("test error")
go func() {
time.Sleep(500 * time.Millisecond)
cancel(expectedErr)
}()

loaded := &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}
start := time.Now()
binding.PreLoad(ctx, loaded)

var result testStruct
err := binding.GetLatestValue(ctx, nil, &result, loaded)
elapsed := time.Since(start)

assert.ErrorIs(t, err, ctx.Err())
assert.ErrorIs(t, context.Cause(ctx), expectedErr)
assert.GreaterOrEqual(t, elapsed, 600*time.Millisecond)
assert.Less(t, elapsed, 700*time.Millisecond)
})

t.Run("error from preload is returned in get latest value", func(t *testing.T) {
t.Parallel()

reader := new(mockReader)
binding := newAccountReadBinding(testCodecKey, testCodec, reader)
ctx := context.Background()
expectedErr := errors.New("test error")

reader.On("ReadAll", mock.Anything, mock.Anything).
Return([]byte{}, expectedErr)

loaded := &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}
binding.PreLoad(ctx, loaded)

var result testStruct
err := binding.GetLatestValue(ctx, nil, &result, loaded)

assert.ErrorIs(t, err, expectedErr)
})
}

type mockReader struct {
mock.Mock
}

func (_m *mockReader) ReadAll(ctx context.Context, pk solana.PublicKey) ([]byte, error) {
ret := _m.Called(ctx, pk)

var r0 []byte
if val, ok := ret.Get(0).([]byte); ok {
r0 = val
}

var r1 error
if fn, ok := ret.Get(1).(func() error); ok {
r1 = fn()
} else {
r1 = ret.Error(1)
}

return r0, r1
}

type testStruct struct {
A bool
B int64
}

const testCodecKey = "TEST"

func makeTestCodec(t *testing.T) types.RemoteCodec {
t.Helper()

builder := binary.LittleEndian()

structCodec, err := encodings.NewStructCodec([]encodings.NamedTypeCodec{
{Name: "A", Codec: builder.Bool()},
{Name: "B", Codec: builder.Int64()},
})

require.NoError(t, err)

return encodings.CodecFromTypeCodec(map[string]encodings.TypeCodec{testCodecKey: structCodec})
}
11 changes: 10 additions & 1 deletion pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
)

type readBinding interface {
GetLatestValue(ctx context.Context, params, returnVal any) error
PreLoad(context.Context, *loadedResult)
GetLatestValue(ctx context.Context, params, returnVal any, preload *loadedResult) error
Bind(types.BoundContract) error
CreateType(bool) (any, error)
}
Expand Down Expand Up @@ -77,6 +78,9 @@ func (b namespaceBindings) CreateType(namespace, methodName string, forEncoding
}

tBinding := reflect.TypeOf(bindingType)
if tBinding.Kind() == reflect.Pointer {
tBinding = tBinding.Elem()
}

// all bindings must be structs to allow multiple bindings
if tBinding.Kind() != reflect.Struct {
Expand Down Expand Up @@ -140,3 +144,8 @@ func (b namespaceBindings) Bind(boundContracts []types.BoundContract) error {

return nil
}

type loadedResult struct {
value chan []byte
err chan error
}
4 changes: 3 additions & 1 deletion pkg/solana/chainreader/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ type mockBinding struct {
mock.Mock
}

func (_m *mockBinding) GetLatestValue(ctx context.Context, params, returnVal any) error {
func (_m *mockBinding) PreLoad(context.Context, *loadedResult) {}

func (_m *mockBinding) GetLatestValue(ctx context.Context, params, returnVal any, _ *loadedResult) error {
return nil
}

Expand Down
81 changes: 74 additions & 7 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package chainreader
import (
"context"
"encoding/json"
"fmt"
"reflect"
"sync"

ag_solana "github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/rpc"
Expand All @@ -26,6 +29,7 @@ type SolanaChainReaderService struct {
bindings namespaceBindings

// service state management
wg sync.WaitGroup
services.StateMachine
}

Expand Down Expand Up @@ -67,6 +71,8 @@ func (s *SolanaChainReaderService) Start(_ context.Context) error {
// up used resources. Subsequent calls to Close return an error.
func (s *SolanaChainReaderService) Close() error {
return s.StopOnce(ServiceName, func() error {
s.wg.Wait()

return nil
})
}
Expand All @@ -86,17 +92,78 @@ func (s *SolanaChainReaderService) HealthReport() map[string]error {
// GetLatestValue implements the types.ChainReader interface and requests and parses on-chain
// data named by the provided contract, method, and params.
func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, contractName, method string, params any, returnVal any) error {
if err := s.Ready(); err != nil {
return err
}

s.wg.Add(1)
defer s.wg.Done()

bindings, err := s.bindings.GetReadBindings(contractName, method)
if err != nil {
return err
}

for _, binding := range bindings {
if err := binding.GetLatestValue(ctx, params, returnVal); err != nil {
localCtx, localCancel := context.WithCancel(ctx)

// the wait group ensures GetLatestValue returns only after all go-routines have completed
var wg sync.WaitGroup

results := make(map[int]*loadedResult)

if len(bindings) > 1 {
// might go for some guardrails when dealing with multiple bindings
// the returnVal should be compatible with multiple passes by the codec decoder
// this should only apply to types struct{} and map[any]any
tReturnVal := reflect.TypeOf(returnVal)
if tReturnVal.Kind() == reflect.Pointer {
tReturnVal = reflect.Indirect(reflect.ValueOf(returnVal)).Type()
}

switch tReturnVal.Kind() {
case reflect.Struct, reflect.Map:
default:
localCancel()

wg.Wait()

return fmt.Errorf("%w: multiple bindings is only supported for struct and map", types.ErrInvalidType)
}

// for multiple bindings, preload the remote data in parallel
for idx, binding := range bindings {
results[idx] = &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}

wg.Add(1)
go func(ctx context.Context, rb readBinding, res *loadedResult) {
defer wg.Done()

rb.PreLoad(ctx, res)
}(localCtx, binding, results[idx])
}
}

// in the case of parallel preloading, GetLatestValue will still run in
// sequence because the function will block until the data is loaded.
// in the case of no preloading, GetLatestValue will load and decode in
// sequence.
for idx, binding := range bindings {
if err := binding.GetLatestValue(ctx, params, returnVal, results[idx]); err != nil {
localCancel()

wg.Wait()

return err
}
}

localCancel()

wg.Wait()

return nil
}

Expand Down Expand Up @@ -136,11 +203,11 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader
return err
}

s.bindings.AddReadBinding(namespace, methodName, &accountReadBinding{
idlAccount: procedure.IDLAccount,
codec: codecWithModifiers,
reader: s.client,
})
s.bindings.AddReadBinding(namespace, methodName, newAccountReadBinding(
procedure.IDLAccount,
codecWithModifiers,
s.client,
))
}
}
}
Expand Down
Loading

0 comments on commit 5b3b9fb

Please sign in to comment.