diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3e84ea6..8bfd068 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -111,7 +111,20 @@ func (p *Proxy) next() *Server { return p.next() } -func (p *Proxy) update(providers []seed.Provider) error { +func (p *Proxy) update(seed seed.Seed) { + var err error + switch p.kind { + case RPC: + err = p.doUpdate(seed.APIs.RPC) + case Rest: + err = p.doUpdate(seed.APIs.Rest) + } + if err != nil { + slog.Error("could not update seed", "err", err) + } +} + +func (p *Proxy) doUpdate(providers []seed.Provider) error { p.mu.Lock() defer p.mu.Unlock() @@ -153,12 +166,7 @@ func (p *Proxy) Start(ctx context.Context) { for { select { case seed := <-p.ch: - switch p.kind { - case RPC: - p.update(seed.APIs.RPC) - case Rest: - p.update(seed.APIs.Rest) - } + p.update(seed) case <-ctx.Done(): p.shuttingDown.Store(true) return diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 63298fb..b9e0067 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -17,53 +16,33 @@ import ( ) func TestProxy(t *testing.T) { - const chainID = "unittest" + for name, kind := range map[string]ProxyKind{ + "rpc": RPC, + "rest": Rest, + } { + t.Run(name, func(t *testing.T) { + testProxy(t, kind) + }) + } +} + +func testProxy(tb testing.TB, kind ProxyKind) { srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "srv1 replied") })) - t.Cleanup(srv1.Close) + tb.Cleanup(srv1.Close) srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(time.Millisecond * 500) _, _ = io.WriteString(w, "srv2 replied") })) - t.Cleanup(srv2.Close) + tb.Cleanup(srv2.Close) srv3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) })) - t.Cleanup(srv2.Close) - - seed := seed.Seed{ - ChainID: chainID, - APIs: seed.Apis{ - RPC: []seed.Provider{ - { - Address: srv1.URL, - Provider: "srv1", - }, - { - Address: srv2.URL, - Provider: "srv2", - }, - { - Address: srv3.URL, - Provider: "srv3", - }, - }, - }, - } - - t.Logf("%+v", seed) - - seedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - bts, _ := json.Marshal(seed) - _, _ = w.Write(bts) - })) - t.Cleanup(seedSrv.Close) + tb.Cleanup(srv2.Close) - proxy := New(config.Config{ - SeedURL: seedSrv.URL, - SeedRefreshInterval: 500 * time.Millisecond, - ChainID: chainID, + ch := make(chan seed.Seed, 1) + proxy := New(kind, ch, config.Config{ HealthyThreshold: 10 * time.Millisecond, ProxyRequestTimeout: time.Second, UnhealthyServerRecoverChancePct: 1, @@ -72,19 +51,43 @@ func TestProxy(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) + tb.Cleanup(cancel) proxy.Start(ctx) - require.Len(t, proxy.servers, 3) + serverList := []seed.Provider{ + { + Address: srv1.URL, + Provider: "srv1", + }, + { + Address: srv2.URL, + Provider: "srv2", + }, + { + Address: srv3.URL, + Provider: "srv3", + }, + } + + ch <- seed.Seed{ + APIs: seed.Apis{ + Rest: serverList, + RPC: serverList, + }, + } + + require.Eventually(tb, func() bool { return proxy.initialized.Load() }, time.Second, time.Millisecond) + + require.Len(tb, proxy.servers, 3) proxySrv := httptest.NewServer(proxy) - t.Cleanup(proxySrv.Close) + tb.Cleanup(proxySrv.Close) var wg errgroup.Group wg.SetLimit(20) for i := 0; i < 100; i++ { wg.Go(func() error { - t.Log("go") + tb.Log("go") req, err := http.NewRequest(http.MethodGet, proxySrv.URL, nil) if err != nil { return err @@ -102,13 +105,13 @@ func TestProxy(t *testing.T) { return nil }) } - require.NoError(t, wg.Wait()) + require.NoError(tb, wg.Wait()) // stop the proxy cancel() stats := proxy.Stats() - require.Len(t, stats, 3) + require.Len(tb, stats, 3) var srv1Stats ServerStat var srv2Stats ServerStat @@ -124,13 +127,13 @@ func TestProxy(t *testing.T) { srv3Stats = st } } - require.Zero(t, srv1Stats.ErrorRate) - require.Zero(t, srv2Stats.ErrorRate) - require.Equal(t, float64(100), srv3Stats.ErrorRate) - require.Greater(t, srv1Stats.Requests, srv2Stats.Requests) - require.Greater(t, srv2Stats.Avg, srv1Stats.Avg) - require.False(t, srv1Stats.Degraded) - require.True(t, srv2Stats.Degraded) - require.True(t, srv1Stats.Initialized) - require.True(t, srv2Stats.Initialized) + require.Zero(tb, srv1Stats.ErrorRate) + require.Zero(tb, srv2Stats.ErrorRate) + require.Equal(tb, float64(100), srv3Stats.ErrorRate) + require.Greater(tb, srv1Stats.Requests, srv2Stats.Requests) + require.Greater(tb, srv2Stats.Avg, srv1Stats.Avg) + require.False(tb, srv1Stats.Degraded) + require.True(tb, srv2Stats.Degraded) + require.True(tb, srv1Stats.Initialized) + require.True(tb, srv2Stats.Initialized) } diff --git a/internal/seed/updater_test.go b/internal/seed/updater_test.go new file mode 100644 index 0000000..21feca9 --- /dev/null +++ b/internal/seed/updater_test.go @@ -0,0 +1,78 @@ +package seed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/akash-network/rpc-proxy/internal/config" + "github.com/stretchr/testify/require" +) + +func TestUpdater(t *testing.T) { + chainID := "test" + seed := Seed{ + ChainID: chainID, + APIs: Apis{ + RPC: []Provider{ + { + Address: "http://rpc.local", + Provider: "rpc-provider", + }, + }, + Rest: []Provider{ + { + Address: "http://rest.local", + Provider: "rest-provider", + }, + }, + }, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bts, _ := json.Marshal(seed) + _, _ = w.Write(bts) + })) + t.Cleanup(srv.Close) + + rpc := make(chan Seed, 1) + rest := make(chan Seed, 1) + + up := New(config.Config{ + SeedRefreshInterval: time.Millisecond, + SeedURL: srv.URL, + ChainID: chainID, + }, rpc, rest) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + up.Start(ctx) + + go func() { + time.Sleep(time.Millisecond * 500) + cancel() + }() + + var rpcUpdates, restUpdates atomic.Uint32 + +outer: + for { + select { + case got := <-rpc: + rpcUpdates.Add(1) + require.Equal(t, seed, got) + case got := <-rest: + restUpdates.Add(1) + require.Equal(t, seed, got) + case <-ctx.Done(): + break outer + } + } + + require.NotZero(t, rpcUpdates.Load()) + require.NotZero(t, restUpdates.Load()) +}