diff --git a/core/services/llo/observation_context.go b/core/services/llo/observation_context.go index ab022452629..5bf82fa5a79 100644 --- a/core/services/llo/observation_context.go +++ b/core/services/llo/observation_context.go @@ -157,7 +157,7 @@ func (e MissingStreamError) Error() string { } func (oc *observationContext) run(ctx context.Context, streamID streams.StreamID) (*pipeline.Run, pipeline.TaskRunResults, error) { - strm, exists := oc.r.Get(streamID) + p, exists := oc.r.Get(streamID) if !exists { return nil, nil, MissingStreamError{StreamID: streamID} } @@ -165,7 +165,7 @@ func (oc *observationContext) run(ctx context.Context, streamID streams.StreamID // In case of multiple streamIDs per pipeline then the // first call executes and the others wait for result oc.executionsMu.Lock() - ex, isExecuting := oc.executions[strm] + ex, isExecuting := oc.executions[p] if isExecuting { oc.executionsMu.Unlock() // wait for it to finish @@ -180,10 +180,10 @@ func (oc *observationContext) run(ctx context.Context, streamID streams.StreamID // execute here ch := make(chan struct{}) ex = &execution{done: ch} - oc.executions[strm] = ex + oc.executions[p] = ex oc.executionsMu.Unlock() - run, trrs, err := strm.Run(ctx) + run, trrs, err := p.Run(ctx) ex.run = run ex.trrs = trrs ex.err = err diff --git a/core/services/llo/observation_context_test.go b/core/services/llo/observation_context_test.go index 67af24c2a7b..1efe3ec7ee9 100644 --- a/core/services/llo/observation_context_test.go +++ b/core/services/llo/observation_context_test.go @@ -14,6 +14,7 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" @@ -119,6 +120,31 @@ func TestObservationContext_Observe(t *testing.T) { }) } +func TestObservationContext_Observe_concurrencyStressTest(t *testing.T) { + ctx := tests.Context(t) + r := &mockRegistry{} + telem := &mockTelemeter{} + oc := newObservationContext(r, telem) + opts := llo.DSOpts(nil) + + streamID := streams.StreamID(1) + val := decimal.NewFromFloat(123.456) + + // observes the same pipeline 1000 times to try and detect races etc + r.pipelines = make(map[streams.StreamID]*mockPipeline) + r.pipelines[streamID] = makePipelineWithSingleResult[decimal.Decimal](0, val, nil) + g, ctx := errgroup.WithContext(ctx) + for i := 0; i < 1000; i++ { + g.Go(func() error { + _, err := oc.Observe(ctx, streamID, opts) + return err + }) + } + if err := g.Wait(); err != nil { + t.Fatalf("Observation failed: %v", err) + } +} + type mockPipelineConfig struct{} func (m *mockPipelineConfig) DefaultHTTPLimit() int64 { return 10000 } @@ -139,12 +165,12 @@ func (m *mockBridgeConfig) BridgeCacheTTL() time.Duration { return 0 } -func createBridge(t *testing.T, name string, val string, borm bridges.ORM) { +func createBridge(t *testing.T, name string, val string, borm bridges.ORM, maxCalls int) { callcount := 0 bridge := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { callcount++ - if callcount > 1 { - t.Fatal("expected only one call to the bridge") + if callcount > maxCalls { + panic(fmt.Sprintf("too many calls to bridge %s", name)) } _, herr := io.ReadAll(req.Body) if herr != nil { @@ -172,8 +198,8 @@ func TestObservationContext_Observe_integrationRealPipeline(t *testing.T) { db := pgtest.NewSqlxDB(t) bridgesORM := bridges.NewORM(db) - createBridge(t, "foo-bridge", `123.456`, bridgesORM) - createBridge(t, "bar-bridge", `"124.456"`, bridgesORM) + createBridge(t, "foo-bridge", `123.456`, bridgesORM, 1) + createBridge(t, "bar-bridge", `"124.456"`, bridgesORM, 1) c := clhttptest.NewTestLocalOnlyHTTPClient() runner := pipeline.NewRunner( @@ -242,3 +268,74 @@ result3 -> result3_parse -> multiply3; }, val.(*llo.Quote)) }) } + +func TestObservationContext_Observe_integrationRealPipeline_concurrencyStressTest(t *testing.T) { + ctx := tests.Context(t) + lggr := logger.TestLogger(t) + db := pgtest.NewSqlxDB(t) + bridgesORM := bridges.NewORM(db) + + createBridge(t, "foo-bridge", `123.456`, bridgesORM, 1) + createBridge(t, "bar-bridge", `"124.456"`, bridgesORM, 1) + + c := clhttptest.NewTestLocalOnlyHTTPClient() + runner := pipeline.NewRunner( + nil, + bridgesORM, + &mockPipelineConfig{}, + &mockBridgeConfig{}, + nil, + nil, + nil, + lggr, + c, + c, + ) + + r := streams.NewRegistry(lggr, runner) + + jobStreamID := streams.StreamID(5) + + jb := job.Job{ + Type: job.Stream, + StreamID: &jobStreamID, + PipelineSpec: &pipeline.Spec{ + DotDagSource: ` +// Benchmark Price +result1 [type=memo value="900.0022"]; +multiply2 [type=multiply times=1 streamID=1 index=0]; // force conversion to decimal + +result2 [type=bridge name="foo-bridge" requestData="{\"data\":{\"data\":\"foo\"}}"]; +result2_parse [type=jsonparse path="result" streamID=2 index=1]; + +result3 [type=bridge name="bar-bridge" requestData="{\"data\":{\"data\":\"bar\"}}"]; +result3_parse [type=jsonparse path="result"]; +multiply3 [type=multiply times=1 streamID=3 index=2]; // force conversion to decimal + +result1 -> multiply2; +result2 -> result2_parse; +result3 -> result3_parse -> multiply3; +`, + }, + } + err := r.Register(jb, nil) + require.NoError(t, err) + + telem := &mockTelemeter{} + oc := newObservationContext(r, telem) + opts := llo.DSOpts(nil) + + // concurrency stress test + oc = newObservationContext(r, telem) + g, ctx := errgroup.WithContext(ctx) + for i := 0; i < 1000; i++ { + strmID := streams.StreamID(1 + i%3) + g.Go(func() error { + _, err := oc.Observe(ctx, strmID, opts) + return err + }) + } + if err := g.Wait(); err != nil { + t.Fatalf("Observation failed: %v", err) + } +}