Skip to content

Commit

Permalink
Relay timeouts (Layr-Labs#918)
Browse files (browse the repository at this point in the history)
Signed-off-by: Cody Littley <[email protected]>
  • Loading branch information
cody-littley authored Nov 25, 2024
1 parent 0a4e852 commit ac7ffdd
Show file tree
Hide file tree
Showing 18 changed files with 547 additions and 120 deletions.
10 changes: 0 additions & 10 deletions common/aws/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ type ClientConfig struct {
// FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files.
// A non-zero value for this parameter adds a constant number of workers. Default is 0.
FragmentParallelismConstant int
// FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read.
// Default is 30 seconds.
FragmentReadTimeout time.Duration
// FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write.
// Default is 30 seconds.
FragmentWriteTimeout time.Duration
}

func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
Expand Down Expand Up @@ -120,8 +114,6 @@ func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
}
}

Expand All @@ -131,7 +123,5 @@ func DefaultClientConfig() *ClientConfig {
Region: "us-east-2",
FragmentParallelismFactor: 8,
FragmentParallelismConstant: 0,
FragmentReadTimeout: 30 * time.Second,
FragmentWriteTimeout: 30 * time.Second,
}
}
6 changes: 0 additions & 6 deletions common/aws/s3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,6 @@ func (s *client) FragmentedUploadObject(
}
resultChannel := make(chan error, len(fragments))

ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
defer cancel()

for _, fragment := range fragments {
fragmentCapture := fragment
s.concurrencyLimiter <- struct{}{}
Expand Down Expand Up @@ -301,9 +298,6 @@ func (s *client) FragmentedDownloadObject(
}
resultChannel := make(chan *readResult, len(fragmentKeys))

ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
defer cancel()

for i, fragmentKey := range fragmentKeys {
boundFragmentKey := fragmentKey
boundI := i
Expand Down
15 changes: 9 additions & 6 deletions relay/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type RequestAuthenticator interface {
// The origin is the address of the peer that sent the request. This may be used to cache auth results
// in order to save server resources.
AuthenticateGetChunksRequest(
ctx context.Context,
origin string,
request *pb.GetChunksRequest,
now time.Time) error
Expand Down Expand Up @@ -53,6 +54,7 @@ type requestAuthenticator struct {

// NewRequestAuthenticator creates a new RequestAuthenticator.
func NewRequestAuthenticator(
ctx context.Context,
ics core.IndexedChainState,
keyCacheSize int,
authenticationTimeoutDuration time.Duration) (RequestAuthenticator, error) {
Expand All @@ -70,20 +72,20 @@ func NewRequestAuthenticator(
keyCache: keyCache,
}

err = authenticator.preloadCache()
err = authenticator.preloadCache(ctx)
if err != nil {
return nil, fmt.Errorf("failed to preload cache: %w", err)
}

return authenticator, nil
}

func (a *requestAuthenticator) preloadCache() error {
func (a *requestAuthenticator) preloadCache(ctx context.Context) error {
blockNumber, err := a.ics.GetCurrentBlockNumber()
if err != nil {
return fmt.Errorf("failed to get current block number: %w", err)
}
operators, err := a.ics.GetIndexedOperators(context.Background(), blockNumber)
operators, err := a.ics.GetIndexedOperators(ctx, blockNumber)
if err != nil {
return fmt.Errorf("failed to get operators: %w", err)
}
Expand All @@ -96,6 +98,7 @@ func (a *requestAuthenticator) preloadCache() error {
}

func (a *requestAuthenticator) AuthenticateGetChunksRequest(
ctx context.Context,
origin string,
request *pb.GetChunksRequest,
now time.Time) error {
Expand All @@ -105,7 +108,7 @@ func (a *requestAuthenticator) AuthenticateGetChunksRequest(
return nil
}

key, err := a.getOperatorKey(core.OperatorID(request.OperatorId))
key, err := a.getOperatorKey(ctx, core.OperatorID(request.OperatorId))
if err != nil {
return fmt.Errorf("failed to get operator key: %w", err)
}
Expand All @@ -131,7 +134,7 @@ func (a *requestAuthenticator) AuthenticateGetChunksRequest(
}

// getOperatorKey returns the public key of the operator with the given ID, caching the result.
func (a *requestAuthenticator) getOperatorKey(operatorID core.OperatorID) (*core.G2Point, error) {
func (a *requestAuthenticator) getOperatorKey(ctx context.Context, operatorID core.OperatorID) (*core.G2Point, error) {
key, ok := a.keyCache.Get(operatorID)
if ok {
return key, nil
Expand All @@ -141,7 +144,7 @@ func (a *requestAuthenticator) getOperatorKey(operatorID core.OperatorID) (*core
if err != nil {
return nil, fmt.Errorf("failed to get current block number: %w", err)
}
operators, err := a.ics.GetIndexedOperators(context.Background(), blockNumber)
operators, err := a.ics.GetIndexedOperators(ctx, blockNumber)
if err != nil {
return nil, fmt.Errorf("failed to get operators: %w", err)
}
Expand Down
29 changes: 24 additions & 5 deletions relay/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
func TestMockSigning(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -24,7 +26,7 @@ func TestMockSigning(t *testing.T) {
ics, err := mock.NewChainDataMock(stakes)
require.NoError(t, err)

operators, err := ics.GetIndexedOperators(context.Background(), 0)
operators, err := ics.GetIndexedOperators(ctx, 0)
require.NoError(t, err)

operator, ok := operators[operatorID]
Expand All @@ -46,6 +48,8 @@ func TestMockSigning(t *testing.T) {
func TestValidRequest(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -58,7 +62,7 @@ func TestValidRequest(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -69,6 +73,7 @@ func TestValidRequest(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
request,
now)
Expand All @@ -83,12 +88,14 @@ func TestValidRequest(t *testing.T) {
start := now
for now.Before(start.Add(timeout)) {
err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
invalidRequest,
now)
require.NoError(t, err)

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"baz",
invalidRequest,
now)
Expand All @@ -99,6 +106,7 @@ func TestValidRequest(t *testing.T) {

// After the timeout elapses, new requests should trigger authentication.
err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
invalidRequest,
now)
Expand All @@ -108,6 +116,8 @@ func TestValidRequest(t *testing.T) {
func TestAuthenticationSavingDisabled(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -121,7 +131,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
// This disables saving of authentication results.
timeout := time.Duration(0)

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -132,6 +142,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
request,
now)
Expand All @@ -144,6 +155,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
invalidRequest.OperatorSignature = signature // the previous signature is invalid here

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
invalidRequest,
now)
Expand All @@ -153,6 +165,8 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
func TestNonExistingClient(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -165,7 +179,7 @@ func TestNonExistingClient(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

invalidOperatorID := tu.RandomBytes(32)
Expand All @@ -174,6 +188,7 @@ func TestNonExistingClient(t *testing.T) {
request.OperatorId = invalidOperatorID

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
request,
time.Now())
Expand All @@ -183,6 +198,8 @@ func TestNonExistingClient(t *testing.T) {
func TestBadSignature(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -195,7 +212,7 @@ func TestBadSignature(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -205,6 +222,7 @@ func TestBadSignature(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
request,
now)
Expand All @@ -217,6 +235,7 @@ func TestBadSignature(t *testing.T) {
request.OperatorSignature[0] = request.OperatorSignature[0] ^ 1

err = authenticator.AuthenticateGetChunksRequest(
ctx,
"foobar",
request,
now)
Expand Down
24 changes: 16 additions & 8 deletions relay/blob_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore"
"github.com/Layr-Labs/eigenda/relay/cache"
"github.com/Layr-Labs/eigensdk-go/logging"
"time"
)

// blobProvider encapsulates logic for fetching blobs. Utilized by the relay Server.
Expand All @@ -20,6 +21,9 @@ type blobProvider struct {

// blobCache is an LRU cache of blobs.
blobCache cache.CachedAccessor[v2.BlobKey, []byte]

// fetchTimeout is the maximum time to wait for a blob fetch operation to complete.
fetchTimeout time.Duration
}

// newBlobProvider creates a new blobProvider.
Expand All @@ -28,12 +32,14 @@ func newBlobProvider(
logger logging.Logger,
blobStore *blobstore.BlobStore,
blobCacheSize int,
maxIOConcurrency int) (*blobProvider, error) {
maxIOConcurrency int,
fetchTimeout time.Duration) (*blobProvider, error) {

server := &blobProvider{
ctx: ctx,
logger: logger,
blobStore: blobStore,
ctx: ctx,
logger: logger,
blobStore: blobStore,
fetchTimeout: fetchTimeout,
}

c, err := cache.NewCachedAccessor[v2.BlobKey, []byte](blobCacheSize, maxIOConcurrency, server.fetchBlob)
Expand All @@ -46,9 +52,8 @@ func newBlobProvider(
}

// GetBlob retrieves a blob from the blob store.
func (s *blobProvider) GetBlob(blobKey v2.BlobKey) ([]byte, error) {

data, err := s.blobCache.Get(blobKey)
func (s *blobProvider) GetBlob(ctx context.Context, blobKey v2.BlobKey) ([]byte, error) {
data, err := s.blobCache.Get(ctx, blobKey)

if err != nil {
// It should not be possible for external users to force an error here since we won't
Expand All @@ -62,7 +67,10 @@ func (s *blobProvider) GetBlob(blobKey v2.BlobKey) ([]byte, error) {

// fetchBlob retrieves a single blob from the blob store.
func (s *blobProvider) fetchBlob(blobKey v2.BlobKey) ([]byte, error) {
data, err := s.blobStore.GetBlob(s.ctx, blobKey)
ctx, cancel := context.WithTimeout(s.ctx, s.fetchTimeout)
defer cancel()

data, err := s.blobStore.GetBlob(ctx, blobKey)
if err != nil {
s.logger.Errorf("Failed to fetch blob: %v", err)
return nil, err
Expand Down
Loading

0 comments on commit ac7ffdd

Please sign in to comment.