From 02b76df27fb4369ee017ce5512df1f8041f1afce Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 4 Oct 2023 11:47:00 -0500 Subject: [PATCH] BCF-2684: copy types from core to support relayer implementations (#187) --- .github/workflows/golangci_lint.yml | 2 +- .tool-versions | 2 +- go.mod | 2 +- ops/go.mod | 2 +- pkg/chains/nodes.go | 97 +++++++++ pkg/chains/nodes_test.go | 166 ++++++++++++++ pkg/config/error.go | 47 +++- pkg/loop/README.md | 2 +- pkg/loop/internal/service.go | 6 +- pkg/loop/internal/test/median.go | 6 +- pkg/loop/internal/test/relayer.go | 6 +- pkg/loop/median_service_test.go | 6 +- pkg/loop/plugin_service.go | 7 +- pkg/loop/plugin_service_test.go | 4 +- pkg/loop/plugin_test.go | 4 +- pkg/loop/relayer_service_test.go | 6 +- pkg/monitoring/exporter_prometheus_test.go | 14 +- pkg/services/health.go | 24 +++ pkg/services/multi.go | 95 ++++++++ pkg/services/multi_example_test.go | 90 ++++++++ pkg/services/state.go | 239 +++++++++++++++++++++ pkg/services/types.go | 93 ++++++++ pkg/types/types.go | 1 + pkg/utils/start_stop_once.go | 1 + pkg/utils/tests/tests.go | 35 +++ pkg/utils/testutils.go | 15 +- 26 files changed, 920 insertions(+), 52 deletions(-) create mode 100644 pkg/chains/nodes.go create mode 100644 pkg/chains/nodes_test.go create mode 100644 pkg/services/health.go create mode 100644 pkg/services/multi.go create mode 100644 pkg/services/multi_example_test.go create mode 100644 pkg/services/state.go create mode 100644 pkg/services/types.go create mode 100644 pkg/utils/tests/tests.go diff --git a/.github/workflows/golangci_lint.yml b/.github/workflows/golangci_lint.yml index ecb23aba3..5d4472e55 100644 --- a/.github/workflows/golangci_lint.yml +++ b/.github/workflows/golangci_lint.yml @@ -12,7 +12,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: '1.20' + go-version: '1.21' - name: Install golangci-lint run: curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.53.3 diff --git a/.tool-versions b/.tool-versions index 2bbf20a04..12ed2585d 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1,2 +1,2 @@ -golang 1.20.4 +golang 1.21.1 golangci-lint 1.51.1 diff --git a/go.mod b/go.mod index 645d41029..28cfeb406 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/smartcontractkit/chainlink-relay -go 1.20 +go 1.21 require ( github.com/confluentinc/confluent-kafka-go v1.9.2 diff --git a/ops/go.mod b/ops/go.mod index a2e79d9dc..0b756168d 100644 --- a/ops/go.mod +++ b/ops/go.mod @@ -1,6 +1,6 @@ module github.com/smartcontractkit/chainlink-relay/ops -go 1.20 +go 1.21 require ( github.com/lib/pq v1.10.4 diff --git a/pkg/chains/nodes.go b/pkg/chains/nodes.go new file mode 100644 index 000000000..0e2eff2bd --- /dev/null +++ b/pkg/chains/nodes.go @@ -0,0 +1,97 @@ +package chains + +import ( + "encoding/base64" + "errors" + "fmt" + "net/url" + "strconv" + + "github.com/smartcontractkit/chainlink-relay/pkg/types" +) + +// pageToken is simple internal representation for coordination requests and responses in a paginated API +// It is inspired by the Google API Design patterns +// https://cloud.google.com/apis/design/design_patterns#list_pagination +// https://google.aip.dev/158 +type pageToken struct { + Page int + Size int +} + +var ( + ErrInvalidToken = errors.New("invalid page token") + ErrOutOfRange = errors.New("out of range") + defaultSize = 100 +) + +// Encode the token in base64 for transmission for the wire +func (pr *pageToken) Encode() string { + if pr.Size == 0 { + pr.Size = defaultSize + } + // this is a simple minded implementation and may benefit from something fancier + // note that this is a valid url.Query string, which we leverage in decoding + s := fmt.Sprintf("page=%d&size=%d", pr.Page, pr.Size) + return base64.RawStdEncoding.EncodeToString([]byte(s)) +} + +// b64enc must be the base64 encoded token string, corresponding to [pageToken.Encode()] +func NewPageToken(b64enc string) (*pageToken, error) { + // empty is valid + if b64enc == "" { + return &pageToken{Page: 0, Size: defaultSize}, nil + } + + b, err := base64.RawStdEncoding.DecodeString(b64enc) + if err != nil { + return nil, err + } + // here too, this is simple minded and could be fancier + + vals, err := url.ParseQuery(string(b)) + if err != nil { + return nil, err + } + if !(vals.Has("page") && vals.Has("size")) { + return nil, ErrInvalidToken + } + page, err := strconv.Atoi(vals.Get("page")) + if err != nil { + return nil, fmt.Errorf("%w: bad page", ErrInvalidToken) + } + size, err := strconv.Atoi(vals.Get("size")) + if err != nil { + return nil, fmt.Errorf("%w: bad size", ErrInvalidToken) + } + return &pageToken{ + Page: page, + Size: size, + }, err +} + +// if start is out of range, must return ErrOutOfRange +type ListNodeStatusFn = func(start, end int) (stats []types.NodeStatus, total int, err error) + +func ListNodeStatuses(pageSize int, pageTokenStr string, listFn ListNodeStatusFn) (stats []types.NodeStatus, nextPageToken string, total int, err error) { + if pageSize == 0 { + pageSize = defaultSize + } + t := &pageToken{Page: 0, Size: pageSize} + if pageTokenStr != "" { + t, err = NewPageToken(pageTokenStr) + if err != nil { + return nil, "", -1, err + } + } + start, end := t.Page*t.Size, (t.Page+1)*t.Size + stats, total, err = listFn(start, end) + if err != nil { + return stats, "", -1, err + } + if total > end { + next_token := &pageToken{Page: t.Page + 1, Size: t.Size} + nextPageToken = next_token.Encode() + } + return stats, nextPageToken, total, nil +} diff --git a/pkg/chains/nodes_test.go b/pkg/chains/nodes_test.go new file mode 100644 index 000000000..38084d7fb --- /dev/null +++ b/pkg/chains/nodes_test.go @@ -0,0 +1,166 @@ +package chains + +import ( + "encoding/base64" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink-relay/pkg/types" +) + +func TestNewPageToken(t *testing.T) { + type args struct { + t *pageToken + } + tests := []struct { + name string + args args + want *pageToken + wantErr bool + }{ + { + name: "empty", + args: args{t: &pageToken{}}, + want: &pageToken{Page: 0, Size: defaultSize}, + }, + { + name: "page set, size unset", + args: args{t: &pageToken{Page: 1}}, + want: &pageToken{Page: 1, Size: defaultSize}, + }, + { + name: "page set, size set", + args: args{t: &pageToken{Page: 3, Size: 10}}, + want: &pageToken{Page: 3, Size: 10}, + }, + { + name: "page unset, size set", + args: args{t: &pageToken{Size: 17}}, + want: &pageToken{Page: 0, Size: 17}, + }, + } + for _, tt := range tests { + enc := tt.args.t.Encode() + t.Run(tt.name, func(t *testing.T) { + got, err := NewPageToken(enc) + if (err != nil) != tt.wantErr { + t.Errorf("NewPageToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewPageToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestListNodeStatuses(t *testing.T) { + testStats := []types.NodeStatus{ + { + ChainID: "chain-1", + Name: "name-1", + }, + { + ChainID: "chain-2", + Name: "name-2", + }, + { + ChainID: "chain-3", + Name: "name-3", + }, + } + + type args struct { + pageSize int + pageToken string + listFn ListNodeStatusFn + } + tests := []struct { + name string + args args + wantStats []types.NodeStatus + wantNextPageToken string + wantTotal int + wantErr bool + }{ + { + name: "all on first page", + args: args{ + pageSize: 10, // > length of test stats + pageToken: "", + listFn: func(start, end int) ([]types.NodeStatus, int, error) { + return testStats, len(testStats), nil + }, + }, + wantNextPageToken: "", + wantTotal: len(testStats), + wantStats: testStats, + }, + { + name: "small first page", + args: args{ + pageSize: len(testStats) - 1, + pageToken: "", + listFn: func(start, end int) ([]types.NodeStatus, int, error) { + return testStats[start:end], len(testStats), nil + }, + }, + wantNextPageToken: base64.RawStdEncoding.EncodeToString([]byte("page=1&size=2")), // hard coded 2 is len(testStats)-1 + wantTotal: len(testStats), + wantStats: testStats[0 : len(testStats)-1], + }, + { + name: "second page", + args: args{ + pageSize: len(testStats) - 1, + pageToken: base64.RawStdEncoding.EncodeToString([]byte("page=1&size=2")), // hard coded 2 is len(testStats)-1 + listFn: func(start, end int) ([]types.NodeStatus, int, error) { + // note list function must do the start, end bound checking. here we are making it simple + if end > len(testStats) { + end = len(testStats) + } + return testStats[start:end], len(testStats), nil + }, + }, + wantNextPageToken: "", + wantTotal: len(testStats), + wantStats: testStats[len(testStats)-1:], + }, + { + name: "bad list fn", + args: args{ + listFn: func(start, end int) ([]types.NodeStatus, int, error) { + return nil, 0, fmt.Errorf("i'm a bad list fn") + }, + }, + wantTotal: -1, + wantErr: true, + }, + { + name: "invalid token", + args: args{ + pageToken: "invalid token", + listFn: func(start, end int) ([]types.NodeStatus, int, error) { + return testStats[start:end], len(testStats), nil + }, + }, + wantTotal: -1, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotStats, gotNext_pageToken, gotTotal, err := ListNodeStatuses(tt.args.pageSize, tt.args.pageToken, tt.args.listFn) + if (err != nil) != tt.wantErr { + t.Errorf("ListNodeStatuses() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.wantStats, gotStats) + assert.Equal(t, tt.wantNextPageToken, gotNext_pageToken) + assert.Equal(t, tt.wantTotal, gotTotal) + }) + } +} diff --git a/pkg/config/error.go b/pkg/config/error.go index a504654f2..8cd4741b0 100644 --- a/pkg/config/error.go +++ b/pkg/config/error.go @@ -1,6 +1,9 @@ package config -import "fmt" +import ( + "fmt" + "reflect" +) // lightweight error types copied from core @@ -11,7 +14,12 @@ type ErrInvalid struct { } func (e ErrInvalid) Error() string { - return fmt.Sprintf("%s: invalid value %v: %s", e.Name, e.Value, e.Msg) + return fmt.Sprintf("%s: invalid value (%v): %s", e.Name, e.Value, e.Msg) +} + +// NewErrDuplicate returns an ErrInvalid with a standard duplicate message. +func NewErrDuplicate(name string, value any) ErrInvalid { + return ErrInvalid{Name: name, Value: value, Msg: "duplicate - must be unique"} } type ErrMissing struct { @@ -40,3 +48,38 @@ type KeyNotFoundError struct { func (e KeyNotFoundError) Error() string { return fmt.Sprintf("unable to find %s key with id %s", e.KeyType, e.ID) } + +// UniqueStrings is a helper for tracking unique values in string form. +type UniqueStrings map[string]struct{} + +// IsDupeFmt is like IsDupe, but calls String(). +func (u UniqueStrings) IsDupeFmt(t fmt.Stringer) bool { + if t == nil { + return false + } + if reflect.ValueOf(t).IsNil() { + // interface holds a typed-nil value + return false + } + return u.isDupe(t.String()) +} + +// IsDupe returns true if the set already contains the string, otherwise false. +// Non-nil/empty strings are added to the set. +func (u UniqueStrings) IsDupe(s *string) bool { + if s == nil { + return false + } + return u.isDupe(*s) +} + +func (u UniqueStrings) isDupe(s string) bool { + if s == "" { + return false + } + _, ok := u[s] + if !ok { + u[s] = struct{}{} + } + return ok +} diff --git a/pkg/loop/README.md b/pkg/loop/README.md index 67b1ca60f..6fec980aa 100644 --- a/pkg/loop/README.md +++ b/pkg/loop/README.md @@ -120,7 +120,7 @@ sequenceDiagram The `pluginService` type contains reusable automatic recovery code. -`type pluginService[P grpcPlugin, S types.Service] struct` +`type pluginService[P grpcPlugin, S services.Service] struct` Each plugin implements their own interface (Relayer, Median, etc.) with a new type that also embeds a `pluginService`. This new **service** type implements the original interface, but internally manages re-starting and re-connecting to the plugin diff --git a/pkg/loop/internal/service.go b/pkg/loop/internal/service.go index 61d0e6d37..8003ca3d3 100644 --- a/pkg/loop/internal/service.go +++ b/pkg/loop/internal/service.go @@ -8,10 +8,10 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/pb" - "github.com/smartcontractkit/chainlink-relay/pkg/types" + "github.com/smartcontractkit/chainlink-relay/pkg/services" ) -var _ types.Service = (*serviceClient)(nil) +var _ services.Service = (*serviceClient)(nil) type serviceClient struct { b *brokerExt @@ -66,7 +66,7 @@ var _ pb.ServiceServer = (*serviceServer)(nil) type serviceServer struct { pb.UnimplementedServiceServer - srv types.Service + srv services.Service } func (s *serviceServer) Close(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) { diff --git a/pkg/loop/internal/test/median.go b/pkg/loop/internal/test/median.go index 9ea530d56..925309a2c 100644 --- a/pkg/loop/internal/test/median.go +++ b/pkg/loop/internal/test/median.go @@ -16,7 +16,7 @@ import ( libocr "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-relay/pkg/types" - "github.com/smartcontractkit/chainlink-relay/pkg/utils" + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) func TestPluginMedian(t *testing.T, p types.PluginMedian) { @@ -29,7 +29,7 @@ type PluginMedianTest struct { func (m PluginMedianTest) TestPluginMedian(t *testing.T, p types.PluginMedian) { t.Run("PluginMedian", func(t *testing.T) { - ctx := utils.Context(t) + ctx := tests.Context(t) factory, err := p.NewMedianFactory(ctx, m.MedianProvider, &staticDataSource{value}, &staticDataSource{juelsPerFeeCoin}, &StaticErrorLog{}) require.NoError(t, err) @@ -44,7 +44,7 @@ func TestReportingPluginFactory(t *testing.T, factory types.ReportingPluginFacto assert.Equal(t, rpi, gotRPI) t.Cleanup(func() { assert.NoError(t, rp.Close()) }) t.Run("ReportingPlugin", func(t *testing.T) { - ctx := utils.Context(t) + ctx := tests.Context(t) gotQuery, err := rp.Query(ctx, reportContext.ReportTimestamp) require.NoError(t, err) assert.Equal(t, query, []byte(gotQuery)) diff --git a/pkg/loop/internal/test/relayer.go b/pkg/loop/internal/test/relayer.go index 12357e04e..ca4dca5bb 100644 --- a/pkg/loop/internal/test/relayer.go +++ b/pkg/loop/internal/test/relayer.go @@ -14,7 +14,7 @@ import ( "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal" "github.com/smartcontractkit/chainlink-relay/pkg/types" - "github.com/smartcontractkit/chainlink-relay/pkg/utils" + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) type StaticKeystore struct{} @@ -149,7 +149,7 @@ func newRelayArgsWithProviderType(_type types.OCR2PluginType) types.RelayArgs { } func TestPluginRelayer(t *testing.T, p internal.PluginRelayer) { - ctx := utils.Context(t) + ctx := tests.Context(t) t.Run("Relayer", func(t *testing.T) { relayer, err := p.NewRelayer(ctx, ConfigTOML, StaticKeystore{}) @@ -161,7 +161,7 @@ func TestPluginRelayer(t *testing.T, p internal.PluginRelayer) { } func TestRelayer(t *testing.T, relayer internal.Relayer) { - ctx := utils.Context(t) + ctx := tests.Context(t) t.Run("ConfigProvider", func(t *testing.T) { t.Parallel() diff --git a/pkg/loop/median_service_test.go b/pkg/loop/median_service_test.go index 2774212f9..c0e81ccfa 100644 --- a/pkg/loop/median_service_test.go +++ b/pkg/loop/median_service_test.go @@ -13,7 +13,7 @@ import ( "github.com/smartcontractkit/chainlink-relay/pkg/logger" "github.com/smartcontractkit/chainlink-relay/pkg/loop" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/test" - "github.com/smartcontractkit/chainlink-relay/pkg/utils" + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) func TestMedianService(t *testing.T) { @@ -22,7 +22,7 @@ func TestMedianService(t *testing.T) { return helperProcess(loop.PluginMedianName) }, test.StaticMedianProvider{}, test.StaticDataSource(), test.StaticJuelsPerFeeCoinDataSource(), &test.StaticErrorLog{}) hook := median.TestHook() - require.NoError(t, median.Start(utils.Context(t))) + require.NoError(t, median.Start(tests.Context(t))) t.Cleanup(func() { assert.NoError(t, median.Close()) }) t.Run("control", func(t *testing.T) { @@ -54,7 +54,7 @@ func TestMedianService_recovery(t *testing.T) { median := loop.NewMedianService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { return helperProcess(loop.PluginMedianName, strconv.Itoa(int(limit.Add(1)))) }, test.StaticMedianProvider{}, test.StaticDataSource(), test.StaticJuelsPerFeeCoinDataSource(), &test.StaticErrorLog{}) - require.NoError(t, median.Start(utils.Context(t))) + require.NoError(t, median.Start(tests.Context(t))) t.Cleanup(func() { assert.NoError(t, median.Close()) }) test.TestReportingPluginFactory(t, median) diff --git a/pkg/loop/plugin_service.go b/pkg/loop/plugin_service.go index 7a40fa2ec..bbfe41772 100644 --- a/pkg/loop/plugin_service.go +++ b/pkg/loop/plugin_service.go @@ -9,13 +9,12 @@ import ( "time" "github.com/hashicorp/go-plugin" - "golang.org/x/exp/maps" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/smartcontractkit/chainlink-relay/pkg/logger" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal" - "github.com/smartcontractkit/chainlink-relay/pkg/types" + "github.com/smartcontractkit/chainlink-relay/pkg/services" "github.com/smartcontractkit/chainlink-relay/pkg/utils" ) @@ -31,7 +30,7 @@ type grpcPlugin interface { // pluginService is a [types.Service] wrapper that maintains an internal [types.Service] created from a [grpcPlugin] // client instance by launching and re-launching as necessary. -type pluginService[P grpcPlugin, S types.Service] struct { +type pluginService[P grpcPlugin, S services.Service] struct { utils.StartStopOnce pluginName string @@ -171,7 +170,7 @@ func (s *pluginService[P, S]) HealthReport() map[string]error { select { case <-s.serviceCh: hr := map[string]error{s.Name(): s.Healthy()} - maps.Copy(hr, s.service.HealthReport()) + services.CopyHealth(hr, s.service.HealthReport()) return hr default: return map[string]error{s.Name(): ErrPluginUnavailable} diff --git a/pkg/loop/plugin_service_test.go b/pkg/loop/plugin_service_test.go index d5d62b9f1..7f13d4842 100644 --- a/pkg/loop/plugin_service_test.go +++ b/pkg/loop/plugin_service_test.go @@ -1,7 +1,7 @@ package loop import ( - "github.com/smartcontractkit/chainlink-relay/pkg/types" + "github.com/smartcontractkit/chainlink-relay/pkg/services" ) const KeepAliveTickDuration = keepAliveTickDuration @@ -14,7 +14,7 @@ func (s *pluginService[P, S]) TestHook() TestPluginService[P, S] { } // TestPluginService supports Killing & Resetting a running *pluginService. -type TestPluginService[P grpcPlugin, S types.Service] chan<- func(*pluginService[P, S]) +type TestPluginService[P grpcPlugin, S services.Service] chan<- func(*pluginService[P, S]) func (ch TestPluginService[P, S]) Kill() { done := make(chan struct{}) diff --git a/pkg/loop/plugin_test.go b/pkg/loop/plugin_test.go index 09391a87f..643a32adb 100644 --- a/pkg/loop/plugin_test.go +++ b/pkg/loop/plugin_test.go @@ -16,11 +16,11 @@ import ( "github.com/smartcontractkit/chainlink-relay/pkg/logger" "github.com/smartcontractkit/chainlink-relay/pkg/loop" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/test" - "github.com/smartcontractkit/chainlink-relay/pkg/utils" + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) func testPlugin[I any](t *testing.T, name string, p plugin.Plugin, testFn func(*testing.T, I)) { - ctx, cancel := context.WithCancel(utils.Context(t)) + ctx, cancel := context.WithCancel(tests.Context(t)) defer cancel() ch := make(chan *plugin.ReattachConfig, 1) diff --git a/pkg/loop/relayer_service_test.go b/pkg/loop/relayer_service_test.go index 63289b674..620657734 100644 --- a/pkg/loop/relayer_service_test.go +++ b/pkg/loop/relayer_service_test.go @@ -13,7 +13,7 @@ import ( "github.com/smartcontractkit/chainlink-relay/pkg/logger" "github.com/smartcontractkit/chainlink-relay/pkg/loop" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/test" - "github.com/smartcontractkit/chainlink-relay/pkg/utils" + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) func TestRelayerService(t *testing.T) { @@ -22,7 +22,7 @@ func TestRelayerService(t *testing.T) { return helperProcess(loop.PluginRelayerName) }, test.ConfigTOML, test.StaticKeystore{}) hook := relayer.TestHook() - require.NoError(t, relayer.Start(utils.Context(t))) + require.NoError(t, relayer.Start(tests.Context(t))) t.Cleanup(func() { assert.NoError(t, relayer.Close()) }) t.Run("control", func(t *testing.T) { @@ -54,7 +54,7 @@ func TestRelayerService_recovery(t *testing.T) { relayer := loop.NewRelayerService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { return helperProcess(loop.PluginRelayerName, strconv.Itoa(int(limit.Add(1)))) }, test.ConfigTOML, test.StaticKeystore{}) - require.NoError(t, relayer.Start(utils.Context(t))) + require.NoError(t, relayer.Start(tests.Context(t))) t.Cleanup(func() { assert.NoError(t, relayer.Close()) }) test.TestRelayer(t, relayer) diff --git a/pkg/monitoring/exporter_prometheus_test.go b/pkg/monitoring/exporter_prometheus_test.go index e59ee50da..a2813fa05 100644 --- a/pkg/monitoring/exporter_prometheus_test.go +++ b/pkg/monitoring/exporter_prometheus_test.go @@ -14,8 +14,7 @@ func TestPrometheusExporter(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() log := newNullLogger() - metrics := new(MetricsMock) - metrics.Test(t) + metrics := NewMetricsMock(t) factory := NewPrometheusExporterFactory(log, metrics) chainConfig := generateChainConfig() @@ -371,15 +370,12 @@ func TestPrometheusExporter(t *testing.T) { feedConfig.GetID(), // feedID ).Once() exporter.Cleanup(ctx) - - mock.AssertExpectationsForObjects(t, metrics) }) t.Run("should not emit metrics for stale transmissions", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() log := newNullLogger() - metrics := new(MetricsMock) - metrics.Test(t) + metrics := NewMetricsMock(t) factory := NewPrometheusExporterFactory(log, metrics) chainConfig := generateChainConfig() @@ -629,14 +625,12 @@ func TestPrometheusExporter(t *testing.T) { metrics.AssertNumberOfCalls(t, "SetOffchainAggregatorAnswersRaw", 1) metrics.AssertNumberOfCalls(t, "IncOffchainAggregatorAnswersTotal", 1) metrics.AssertNumberOfCalls(t, "SetOffchainAggregatorSubmissionReceivedValues", 1) - mock.AssertExpectationsForObjects(t, metrics) }) t.Run("should emit transaction results metrics", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() log := newNullLogger() - metrics := new(MetricsMock) - metrics.Test(t) + metrics := NewMetricsMock(t) factory := NewPrometheusExporterFactory(log, metrics) chainConfig := generateChainConfig() @@ -684,7 +678,5 @@ func TestPrometheusExporter(t *testing.T) { chainConfig.GetNetworkName(), // networkName ).Once() exporter.Export(ctx, txResults) - - mock.AssertExpectationsForObjects(t, metrics) }) } diff --git a/pkg/services/health.go b/pkg/services/health.go new file mode 100644 index 000000000..3596c8d51 --- /dev/null +++ b/pkg/services/health.go @@ -0,0 +1,24 @@ +package services + +import ( + "errors" + "testing" +) + +// CopyHealth copies health statuses from src to dest. Useful when implementing Service.HealthReport. +// If duplicate names are encountered, the errors are joined, unless testing in which case a panic is thrown. +func CopyHealth(dest, src map[string]error) { + for name, err := range src { + errOrig, ok := dest[name] + if ok { + if testing.Testing() { + panic("service names must be unique: duplicate name: " + name) + } + if errOrig != nil { + dest[name] = errors.Join(errOrig, err) + continue + } + } + dest[name] = err + } +} diff --git a/pkg/services/multi.go b/pkg/services/multi.go new file mode 100644 index 000000000..9d5620bd4 --- /dev/null +++ b/pkg/services/multi.go @@ -0,0 +1,95 @@ +package services + +import ( + "context" + "errors" + "io" + "sync" +) + +// StartClose is a subset of the ServiceCtx interface. +type StartClose interface { + Start(context.Context) error + Close() error +} + +// MultiStart is a utility for starting multiple services together. +// The set of started services is tracked internally, so that they can be closed if any single service fails to start. +type MultiStart struct { + started []StartClose +} + +// Start attempts to Start all services. If any service fails to start, the previously started services will be +// Closed, and an error returned. +func (m *MultiStart) Start(ctx context.Context, srvcs ...StartClose) (err error) { + for _, s := range srvcs { + err = m.start(ctx, s) + if err != nil { + return err + } + } + return +} + +func (m *MultiStart) start(ctx context.Context, s StartClose) (err error) { + err = s.Start(ctx) + if err != nil { + err = errors.Join(err, m.Close()) + } else { + m.started = append(m.started, s) + } + return +} + +// Close closes all started services, in reverse order. +func (m *MultiStart) Close() (err error) { + for i := len(m.started) - 1; i >= 0; i-- { + s := m.started[i] + err = errors.Join(err, s.Close()) + } + return +} + +// CloseBecause calls Close and returns reason along with any additional errors. +func (m *MultiStart) CloseBecause(reason error) (err error) { + return errors.Join(reason, m.Close()) +} + +// CloseAll closes all elements concurrently. +// Use this when you have various different types of io.Closer. +func CloseAll(cs ...io.Closer) error { + return multiCloser[io.Closer](cs).Close() +} + +// MultiCloser returns an io.Closer which closes all elements concurrently. +// Use this when you have a slice of a type which implements io.Closer. +// []io.Closer can be cast directly to MultiCloser. +func MultiCloser[C io.Closer](cs []C) io.Closer { + return multiCloser[C](cs) +} + +type multiCloser[C io.Closer] []C + +// Close closes all elements concurrently and joins any returned errors as one. +func (m multiCloser[C]) Close() (err error) { + if len(m) == 0 { + return nil + } + var wg sync.WaitGroup + wg.Add(len(m)) + errs := make(chan error, len(m)) + for _, s := range m { + go func(c io.Closer) { + defer wg.Done() + if e := c.Close(); e != nil { + errs <- e + } + }(s) + } + wg.Wait() + close(errs) + for e := range errs { + err = errors.Join(err, e) + } + return +} diff --git a/pkg/services/multi_example_test.go b/pkg/services/multi_example_test.go new file mode 100644 index 000000000..3ba342324 --- /dev/null +++ b/pkg/services/multi_example_test.go @@ -0,0 +1,90 @@ +package services + +import ( + "context" + "fmt" +) + +type Healthy string + +func (h Healthy) Start(ctx context.Context) error { + fmt.Println(h, "started") + return nil +} + +func (h Healthy) Close() error { + fmt.Println(h, "closed") + return nil +} + +type CloseFailure string + +func (c CloseFailure) Start(ctx context.Context) error { + fmt.Println(c, "started") + return nil +} + +func (c CloseFailure) Close() error { + fmt.Println(c, "close failure") + return fmt.Errorf("failed to close: %s", c) +} + +type WontStart string + +func (f WontStart) Start(ctx context.Context) error { + fmt.Println(f, "start failure") + return fmt.Errorf("failed to start: %s", f) +} + +func (f WontStart) Close() error { + fmt.Println(f, "close failure") + return fmt.Errorf("cannot call Close after failed Start: %s", f) +} + +func ExampleMultiStart() { + ctx := context.Background() + + a := Healthy("a") + b := CloseFailure("b") + c := WontStart("c") + + var ms MultiStart + if err := ms.Start(ctx, a, b, c); err != nil { + fmt.Println(err) + } + + // Output: + // a started + // b started + // c start failure + // b close failure + // a closed + // failed to start: c + // failed to close: b +} + +func ExampleMultiCloser() { + ctx := context.Background() + + f1 := CloseFailure("f") + f2 := CloseFailure("f") + cs := []CloseFailure{f1, f2} + + var ms MultiStart + if err := ms.Start(ctx, f1, f2); err != nil { + fmt.Println(err) + return + } + mc := MultiCloser(cs) + if err := mc.Close(); err != nil { + fmt.Println(err) + } + + // Output: + // f started + // f started + // f close failure + // f close failure + // failed to close: f + // failed to close: f +} diff --git a/pkg/services/state.go b/pkg/services/state.go new file mode 100644 index 000000000..abb01eae3 --- /dev/null +++ b/pkg/services/state.go @@ -0,0 +1,239 @@ +package services + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + + pkgerrors "github.com/pkg/errors" +) + +// defaultErrorBufferCap is the default cap on the errors an error buffer can store at any time +const defaultErrorBufferCap = 50 + +type errNotStarted struct { + state State +} + +func (e *errNotStarted) Error() string { + return fmt.Sprintf("service is %q, not started", e.state) +} + +var ( + ErrAlreadyStopped = errors.New("already stopped") + ErrCannotStopUnstarted = errors.New("cannot stop unstarted service") +) + +// StateMachine contains a State integer +type StateMachine struct { + state atomic.Int32 + sync.RWMutex // lock is held during startup/shutdown, RLock is held while executing functions dependent on a particular state + + // SvcErrBuffer is an ErrorBuffer that let service owners track critical errors happening in the service. + // + // SvcErrBuffer.SetCap(int) Overrides buffer limit from defaultErrorBufferCap + // SvcErrBuffer.Append(error) Appends an error to the buffer + // SvcErrBuffer.Flush() error returns all tracked errors as a single joined error + SvcErrBuffer ErrorBuffer +} + +// State holds the state for StateMachine +type State int32 + +// nolint +const ( + stateUnstarted State = iota + stateStarted + stateStarting + stateStartFailed + stateStopping + stateStopped + stateStopFailed +) + +func (s State) String() string { + switch s { + case stateUnstarted: + return "Unstarted" + case stateStarted: + return "Started" + case stateStarting: + return "Starting" + case stateStartFailed: + return "StartFailed" + case stateStopping: + return "Stopping" + case stateStopped: + return "Stopped" + case stateStopFailed: + return "StopFailed" + default: + return fmt.Sprintf("unrecognized state: %d", s) + } +} + +// StartOnce sets the state to Started +func (once *StateMachine) StartOnce(name string, fn func() error) error { + // SAFETY: We do this compare-and-swap outside of the lock so that + // concurrent StartOnce() calls return immediately. + success := once.state.CompareAndSwap(int32(stateUnstarted), int32(stateStarting)) + + if !success { + return pkgerrors.Errorf("%v has already been started once; state=%v", name, State(once.state.Load())) + } + + once.Lock() + defer once.Unlock() + + // Setting cap before calling startup fn in case of crits in startup + once.SvcErrBuffer.SetCap(defaultErrorBufferCap) + err := fn() + + if err == nil { + success = once.state.CompareAndSwap(int32(stateStarting), int32(stateStarted)) + } else { + success = once.state.CompareAndSwap(int32(stateStarting), int32(stateStartFailed)) + } + + if !success { + // SAFETY: If this is reached, something must be very wrong: once.state + // was tampered with outside of the lock. + panic(fmt.Sprintf("%v entered unreachable state, unable to set state to started", name)) + } + + return err +} + +// StopOnce sets the state to Stopped +func (once *StateMachine) StopOnce(name string, fn func() error) error { + // SAFETY: We hold the lock here so that Stop blocks until StartOnce + // executes. This ensures that a very fast call to Stop will wait for the + // code to finish starting up before teardown. + once.Lock() + defer once.Unlock() + + success := once.state.CompareAndSwap(int32(stateStarted), int32(stateStopping)) + + if !success { + state := once.state.Load() + switch state { + case int32(stateStopped): + return pkgerrors.Wrapf(ErrAlreadyStopped, "%s has already been stopped", name) + case int32(stateUnstarted): + return pkgerrors.Wrapf(ErrCannotStopUnstarted, "%s has not been started", name) + default: + return pkgerrors.Errorf("%v cannot be stopped from this state; state=%v", name, State(state)) + } + } + + err := fn() + + if err == nil { + success = once.state.CompareAndSwap(int32(stateStopping), int32(stateStopped)) + } else { + success = once.state.CompareAndSwap(int32(stateStopping), int32(stateStopFailed)) + } + + if !success { + // SAFETY: If this is reached, something must be very wrong: once.state + // was tampered with outside of the lock. + panic(fmt.Sprintf("%v entered unreachable state, unable to set state to stopped", name)) + } + + return err +} + +// State retrieves the current state +func (once *StateMachine) State() State { + state := once.state.Load() + return State(state) +} + +// IfStarted runs the func and returns true only if started, otherwise returns false +func (once *StateMachine) IfStarted(f func()) (ok bool) { + once.RLock() + defer once.RUnlock() + + state := once.state.Load() + + if State(state) == stateStarted { + f() + return true + } + return false +} + +// IfNotStopped runs the func and returns true if in any state other than Stopped +func (once *StateMachine) IfNotStopped(f func()) (ok bool) { + once.RLock() + defer once.RUnlock() + + state := once.state.Load() + + if State(state) == stateStopped { + return false + } + f() + return true +} + +// Ready returns ErrNotStarted if the state is not started. +func (once *StateMachine) Ready() error { + state := once.State() + if state == stateStarted { + return nil + } + return &errNotStarted{state: state} +} + +// Healthy returns ErrNotStarted if the state is not started. +// Override this per-service with more specific implementations. +func (once *StateMachine) Healthy() error { + state := once.State() + if state == stateStarted { + return once.SvcErrBuffer.Flush() + } + return &errNotStarted{state: state} +} + +// ErrorBuffer uses joinedErrors interface to join multiple errors into a single error. +// This is useful to track the most recent N errors in a service and flush them as a single error. +type ErrorBuffer struct { + // buffer is a slice of errors + buffer []error + + // cap is the maximum number of errors that the buffer can hold. + // Exceeding the cap results in discarding the oldest error + cap int + + mu sync.RWMutex +} + +func (eb *ErrorBuffer) Flush() (err error) { + eb.mu.RLock() + defer eb.mu.RUnlock() + err = errors.Join(eb.buffer...) + eb.buffer = nil + return +} + +func (eb *ErrorBuffer) Append(incoming error) { + eb.mu.Lock() + defer eb.mu.Unlock() + + if len(eb.buffer) == eb.cap && eb.cap != 0 { + eb.buffer = append(eb.buffer[1:], incoming) + return + } + eb.buffer = append(eb.buffer, incoming) +} + +func (eb *ErrorBuffer) SetCap(cap int) { + eb.mu.Lock() + defer eb.mu.Unlock() + if len(eb.buffer) > cap { + eb.buffer = eb.buffer[len(eb.buffer)-cap:] + } + eb.cap = cap +} diff --git a/pkg/services/types.go b/pkg/services/types.go new file mode 100644 index 000000000..604a5cb51 --- /dev/null +++ b/pkg/services/types.go @@ -0,0 +1,93 @@ +package services + +import "context" + +// Service represents a long-running service inside the Application. +// +// Typically, a Service will leverage utils.StateMachine to implement these +// calls in a safe manner. +// +// # Template +// +// Mockable Foo service with a run loop +// +// //go:generate mockery --quiet --name Foo --output ../internal/mocks/ --case=underscore +// type ( +// // Expose a public interface so we can mock the service. +// Foo interface { +// service.Service +// +// // ... +// } +// +// foo struct { +// // ... +// +// stop chan struct{} +// done chan struct{} +// +// utils.StartStopOnce +// } +// ) +// +// var _ Foo = (*foo)(nil) +// +// func NewFoo() Foo { +// f := &foo{ +// // ... +// } +// +// return f +// } +// +// func (f *foo) Start(ctx context.Context) error { +// return f.StartOnce("Foo", func() error { +// go f.run() +// +// return nil +// }) +// } +// +// func (f *foo) Close() error { +// return f.StopOnce("Foo", func() error { +// // trigger goroutine cleanup +// close(f.stop) +// // wait for cleanup to complete +// <-f.done +// return nil +// }) +// } +// +// func (f *foo) run() { +// // signal cleanup completion +// defer close(f.done) +// +// for { +// select { +// // ... +// case <-f.stop: +// // stop the routine +// return +// } +// } +// +// } +type Service interface { + // Start the service. Must quit immediately if the context is cancelled. + // The given context applies to Start function only and must not be retained. + Start(context.Context) error + // Close stops the Service. + // Invariants: Usually after this call the Service cannot be started + // again, you need to build a new Service to do so. + Close() error + + // Ready should return nil if ready, or an error message otherwise. From the k8s docs: + // > ready means it’s initialized and healthy means that it can accept traffic in kubernetes + // See: https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/ + Ready() error + // HealthReport returns a full health report of the callee including it's dependencies. + // key is the dep name, value is nil if healthy, or error message otherwise. + HealthReport() map[string]error + // Name returns the fully qualified name of the component. Usually the logger name. + Name() string +} diff --git a/pkg/types/types.go b/pkg/types/types.go index cd5f6a628..98bc7d6a3 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -4,6 +4,7 @@ import ( "context" ) +// Deprecated: use services.Service type Service interface { Name() string Start(context.Context) error diff --git a/pkg/utils/start_stop_once.go b/pkg/utils/start_stop_once.go index 68a4f75c8..8b035d737 100644 --- a/pkg/utils/start_stop_once.go +++ b/pkg/utils/start_stop_once.go @@ -45,6 +45,7 @@ func (s startStopOnceState) String() string { } // StartStopOnce can be embedded in a struct to help implement types.Service. +// Deprecated: use services.StateMachine type StartStopOnce struct { state atomic.Int32 sync.RWMutex // lock is held during startup/shutdown, RLock is held while executing functions dependent on a particular state diff --git a/pkg/utils/tests/tests.go b/pkg/utils/tests/tests.go new file mode 100644 index 000000000..f6078d18e --- /dev/null +++ b/pkg/utils/tests/tests.go @@ -0,0 +1,35 @@ +package tests + +import ( + "context" + "testing" + "time" +) + +func Context(t *testing.T) context.Context { + ctx := context.Background() + var cancel func() + + if d, ok := t.Deadline(); ok { + ctx, cancel = context.WithDeadline(ctx, d) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + t.Cleanup(cancel) + return ctx +} + +// DefaultWaitTimeout is the default wait timeout. If you have a *testing.T, use WaitTimeout instead. +const DefaultWaitTimeout = 30 * time.Second + +// WaitTimeout returns a timeout based on the test's Deadline, if available. +// Especially important to use in parallel tests, as their individual execution +// can get paused for arbitrary amounts of time. +func WaitTimeout(t *testing.T) time.Duration { + if d, ok := t.Deadline(); ok { + // 10% buffer for cleanup and scheduling delay + return time.Until(d) * 9 / 10 + } + return DefaultWaitTimeout +} diff --git a/pkg/utils/testutils.go b/pkg/utils/testutils.go index 5c129b18a..573412673 100644 --- a/pkg/utils/testutils.go +++ b/pkg/utils/testutils.go @@ -3,18 +3,11 @@ package utils import ( "context" "testing" + + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" ) +// Deprecated: use tests.Context func Context(t *testing.T) context.Context { - ctx := context.Background() - var cancel func() - - if d, ok := t.Deadline(); ok { - ctx, cancel = context.WithDeadline(ctx, d) - } else { - ctx, cancel = context.WithCancel(ctx) - } - - t.Cleanup(cancel) - return ctx + return tests.Context(t) }