diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 6d8f6a3..3e84ea6 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -10,7 +10,6 @@ import ( "strings" "sync" "sync/atomic" - "time" "github.com/akash-network/rpc-proxy/internal/config" "github.com/akash-network/rpc-proxy/internal/seed" @@ -23,9 +22,14 @@ const ( Rest ProxyKind = iota ) -func New(kind ProxyKind, cfg config.Config) *Proxy { +func New( + kind ProxyKind, + ch chan seed.Seed, + cfg config.Config, +) *Proxy { return &Proxy{ cfg: cfg, + ch: ch, kind: kind, } } @@ -34,6 +38,7 @@ type Proxy struct { cfg config.Config kind ProxyKind init sync.Once + ch chan seed.Seed round int mu sync.Mutex @@ -106,7 +111,6 @@ func (p *Proxy) next() *Server { return p.next() } -// TODO: move this to another thing, share it with multiple proxies func (p *Proxy) update(providers []seed.Provider) error { p.mu.Lock() defer p.mu.Unlock() @@ -146,40 +150,20 @@ func (p *Proxy) update(providers []seed.Provider) error { func (p *Proxy) Start(ctx context.Context) { p.init.Do(func() { go func() { - t := time.NewTicker(p.cfg.SeedRefreshInterval) - defer t.Stop() for { select { - case <-t.C: - p.fetchAndUpdate() + case seed := <-p.ch: + switch p.kind { + case RPC: + p.update(seed.APIs.RPC) + case Rest: + p.update(seed.APIs.Rest) + } case <-ctx.Done(): p.shuttingDown.Store(true) return } } }() - p.fetchAndUpdate() }) } - -func (p *Proxy) fetchAndUpdate() { - result, err := seed.Fetch(p.cfg.SeedURL) - if err != nil { - slog.Error("could not get initial seed list", "err", err) - return - } - if result.ChainID != p.cfg.ChainID { - slog.Error("chain ID is different than expected", "got", result.ChainID, "expected", p.cfg.ChainID) - return - } - switch p.kind { - case RPC: - if err := p.update(result.APIs.RPC); err != nil { - slog.Error("could not update servers", "err", err) - } - case Rest: - if err := p.update(result.APIs.Rest); err != nil { - slog.Error("could not update servers", "err", err) - } - } -} diff --git a/internal/seed/seed.go b/internal/seed/seed.go index e75ec15..93ce094 100644 --- a/internal/seed/seed.go +++ b/internal/seed/seed.go @@ -23,7 +23,7 @@ type Apis struct { Rest []Provider `json:"rest"` } -func Fetch(url string) (Seed, error) { +func fetch(url string) (Seed, error) { var seed Seed resp, err := http.Get(url) if err != nil { diff --git a/internal/seed/updater.go b/internal/seed/updater.go new file mode 100644 index 0000000..49b59ab --- /dev/null +++ b/internal/seed/updater.go @@ -0,0 +1,56 @@ +package seed + +import ( + "context" + "log/slog" + "sync" + "time" + + "github.com/akash-network/rpc-proxy/internal/config" +) + +type Updater struct { + cfg config.Config + listeners []chan<- Seed + init sync.Once +} + +func New(cfg config.Config, listeners ...chan<- Seed) *Updater { + return &Updater{ + cfg: cfg, + listeners: listeners, + } +} + +func (u *Updater) Start(ctx context.Context) { + u.init.Do(func() { + go func() { + t := time.NewTicker(u.cfg.SeedRefreshInterval) + defer t.Stop() + for { + select { + case <-t.C: + u.fetchAndUpdate() + case <-ctx.Done(): + return + } + } + }() + u.fetchAndUpdate() + }) +} + +func (u *Updater) fetchAndUpdate() { + result, err := fetch(u.cfg.SeedURL) + if err != nil { + slog.Error("could not get initial seed list", "err", err) + return + } + if result.ChainID != u.cfg.ChainID { + slog.Error("chain ID is different than expected", "got", result.ChainID, "expected", u.cfg.ChainID) + return + } + for _, ch := range u.listeners { + ch <- result + } +} diff --git a/main.go b/main.go index 01d61e8..e68ccba 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "github.com/akash-network/rpc-proxy/internal/config" "github.com/akash-network/rpc-proxy/internal/proxy" + "github.com/akash-network/rpc-proxy/internal/seed" "golang.org/x/crypto/acme/autocert" ) @@ -34,13 +35,18 @@ func main() { am.HostPolicy = autocert.HostWhitelist(hosts...) } - rpcProxyHandler := proxy.New(proxy.RPC, cfg) - restProxyHandler := proxy.New(proxy.Rest, cfg) + rpcListener := make(chan seed.Seed, 1) + restListener := make(chan seed.Seed, 1) - proxyCtx, proxyCtxCancel := context.WithCancel(context.Background()) + updater := seed.New(cfg, rpcListener, restListener) + rpcProxyHandler := proxy.New(proxy.RPC, rpcListener, cfg) + restProxyHandler := proxy.New(proxy.Rest, restListener, cfg) + + ctx, proxyCtxCancel := context.WithCancel(context.Background()) defer proxyCtxCancel() - rpcProxyHandler.Start(proxyCtx) - restProxyHandler.Start(proxyCtx) + updater.Start(ctx) + rpcProxyHandler.Start(ctx) + restProxyHandler.Start(ctx) indexTpl := template.Must(template.New("stats").Parse(string(index))) @@ -103,9 +109,9 @@ func main() { proxyCtxCancel() - proxyCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err := srv.Shutdown(proxyCtx); err != nil { + if err := srv.Shutdown(ctx); err != nil { slog.Error("could not close server", "err", err) os.Exit(1) }