Skip to content

Commit

Permalink
Made suggested changes.
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Littley <[email protected]>
  • Loading branch information
cody-littley committed Nov 22, 2024
1 parent 43d8c91 commit 9969859
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 36 deletions.
7 changes: 4 additions & 3 deletions relay/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type requestAuthenticator struct {

// NewRequestAuthenticator creates a new RequestAuthenticator.
func NewRequestAuthenticator(
ctx context.Context,
ics core.IndexedChainState,
keyCacheSize int,
authenticationTimeoutDuration time.Duration) (RequestAuthenticator, error) {
Expand All @@ -71,20 +72,20 @@ func NewRequestAuthenticator(
keyCache: keyCache,
}

err = authenticator.preloadCache()
err = authenticator.preloadCache(ctx)
if err != nil {
return nil, fmt.Errorf("failed to preload cache: %w", err)
}

return authenticator, nil
}

func (a *requestAuthenticator) preloadCache() error {
func (a *requestAuthenticator) preloadCache(ctx context.Context) error {
blockNumber, err := a.ics.GetCurrentBlockNumber()
if err != nil {
return fmt.Errorf("failed to get current block number: %w", err)
}
operators, err := a.ics.GetIndexedOperators(context.Background(), blockNumber)
operators, err := a.ics.GetIndexedOperators(ctx, blockNumber)
if err != nil {
return fmt.Errorf("failed to get operators: %w", err)
}
Expand Down
38 changes: 24 additions & 14 deletions relay/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
func TestMockSigning(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -24,7 +26,7 @@ func TestMockSigning(t *testing.T) {
ics, err := mock.NewChainDataMock(stakes)
require.NoError(t, err)

operators, err := ics.GetIndexedOperators(context.Background(), 0)
operators, err := ics.GetIndexedOperators(ctx, 0)
require.NoError(t, err)

operator, ok := operators[operatorID]
Expand All @@ -46,6 +48,8 @@ func TestMockSigning(t *testing.T) {
func TestValidRequest(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -58,7 +62,7 @@ func TestValidRequest(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -69,7 +73,7 @@ func TestValidRequest(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
request,
now)
Expand All @@ -84,14 +88,14 @@ func TestValidRequest(t *testing.T) {
start := now
for now.Before(start.Add(timeout)) {
err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
invalidRequest,
now)
require.NoError(t, err)

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"baz",
invalidRequest,
now)
Expand All @@ -102,7 +106,7 @@ func TestValidRequest(t *testing.T) {

// After the timeout elapses, new requests should trigger authentication.
err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
invalidRequest,
now)
Expand All @@ -112,6 +116,8 @@ func TestValidRequest(t *testing.T) {
func TestAuthenticationSavingDisabled(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -125,7 +131,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
// This disables saving of authentication results.
timeout := time.Duration(0)

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -136,7 +142,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
request,
now)
Expand All @@ -149,7 +155,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
invalidRequest.OperatorSignature = signature // the previous signature is invalid here

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
invalidRequest,
now)
Expand All @@ -159,6 +165,8 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
func TestNonExistingClient(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -171,7 +179,7 @@ func TestNonExistingClient(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

invalidOperatorID := tu.RandomBytes(32)
Expand All @@ -180,7 +188,7 @@ func TestNonExistingClient(t *testing.T) {
request.OperatorId = invalidOperatorID

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
request,
time.Now())
Expand All @@ -190,6 +198,8 @@ func TestNonExistingClient(t *testing.T) {
func TestBadSignature(t *testing.T) {
tu.InitializeRandom()

ctx := context.Background()

operatorID := mock.MakeOperatorId(0)
stakes := map[core.QuorumID]map[core.OperatorID]int{
core.QuorumID(0): {
Expand All @@ -202,7 +212,7 @@ func TestBadSignature(t *testing.T) {

timeout := 10 * time.Second

authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
require.NoError(t, err)

request := randomGetChunksRequest()
Expand All @@ -212,7 +222,7 @@ func TestBadSignature(t *testing.T) {
now := time.Now()

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
request,
now)
Expand All @@ -225,7 +235,7 @@ func TestBadSignature(t *testing.T) {
request.OperatorSignature[0] = request.OperatorSignature[0] ^ 1

err = authenticator.AuthenticateGetChunksRequest(
context.Background(),
ctx,
"foobar",
request,
now)
Expand Down
31 changes: 12 additions & 19 deletions relay/cache/cached_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cache
import (
"context"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/sync/semaphore"
"sync"
)

Expand All @@ -20,8 +21,8 @@ type Accessor[K comparable, V any] func(key K) (V, error)

// accessResult is a struct that holds the result of an Accessor call.
type accessResult[V any] struct {
// wg.Wait() will block until the value is fetched.
wg sync.WaitGroup
// sem is a semaphore used to signal that the value has been fetched.
sem *semaphore.Weighted
// value is the value fetched by the Accessor, or nil if there was an error.
value V
// err is the error returned by the Accessor, or nil if the fetch was successful.
Expand All @@ -37,7 +38,6 @@ var _ CachedAccessor[string, string] = &cachedAccessor[string, string]{}

// cachedAccessor is an implementation of CachedAccessor.
type cachedAccessor[K comparable, V any] struct {

// lookupsInProgress has an entry for each key that is currently being looked up via the accessor. The value
// is written into the channel when it is eventually fetched. If a key is requested more than once while a
// lookup in progress, the second (and following) requests will wait for the result of the first lookup
Expand Down Expand Up @@ -89,9 +89,9 @@ func NewCachedAccessor[K comparable, V any](

func newAccessResult[V any]() *accessResult[V] {
result := &accessResult[V]{
wg: sync.WaitGroup{},
sem: semaphore.NewWeighted(1),
}
result.wg.Add(1)
result.sem.TryAcquire(1)
return result
}

Expand Down Expand Up @@ -127,21 +127,14 @@ func (c *cachedAccessor[K, V]) Get(ctx context.Context, key K) (V, error) {
// when it becomes is available. This method will return quickly if the provided context is cancelled.
// Doing so does not disrupt the other requesters that are also waiting for this result.
func (c *cachedAccessor[K, V]) waitForResult(ctx context.Context, result *accessResult[V]) (V, error) {
wgChan := make(chan struct{}, 1)
go func() {
// Wait inside this goroutine for select statement compatibility.
result.wg.Wait()
wgChan <- struct{}{}
}()

select {
case <-ctx.Done():
// The context was cancelled before the value was fetched, possibly due to a timeout.
err := result.sem.Acquire(ctx, 1)
if err != nil {
var zeroValue V
return zeroValue, ctx.Err()
case <-wgChan:
return result.value, result.err
return zeroValue, err
}

result.sem.Release(1)
return result.value, result.err
}

// fetchResult fetches the value for the given key and returns it. If the context is cancelled before the value
Expand Down Expand Up @@ -172,7 +165,7 @@ func (c *cachedAccessor[K, V]) fetchResult(ctx context.Context, key K, result *a
// Provide the result to all other goroutines that may be waiting for it.
result.err = err
result.value = value
result.wg.Done()
result.sem.Release(1)

// Clean up the lookupInProgress map.
delete(c.lookupsInProgress, key)
Expand Down
1 change: 1 addition & 0 deletions relay/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ func NewServer(
var authenticator auth.RequestAuthenticator
if !config.AuthenticationDisabled {
authenticator, err = auth.NewRequestAuthenticator(
ctx,
ics,
config.AuthenticationKeyCacheSize,
config.AuthenticationTimeout)
Expand Down

0 comments on commit 9969859

Please sign in to comment.