From c41cb14ac9cd29ca2e253d0dd58c570fe8643db7 Mon Sep 17 00:00:00 2001 From: dave vader <48764154+plyr4@users.noreply.github.com> Date: Thu, 17 Aug 2023 10:24:02 -0500 Subject: [PATCH] chore: add context to pipeline functions (#923) --- api/build/create.go | 4 +- api/build/restart.go | 4 +- api/pipeline/create.go | 3 +- api/pipeline/delete.go | 3 +- api/pipeline/list.go | 3 +- api/pipeline/update.go | 3 +- api/webhook/post.go | 4 +- cmd/vela-server/schedule.go | 4 +- database/integration_test.go | 18 ++++---- database/pipeline/count.go | 4 +- database/pipeline/count_repo.go | 4 +- database/pipeline/count_repo_test.go | 7 +-- database/pipeline/count_test.go | 7 +-- database/pipeline/create.go | 4 +- database/pipeline/create_test.go | 3 +- database/pipeline/delete.go | 4 +- database/pipeline/delete_test.go | 5 ++- database/pipeline/get.go | 4 +- database/pipeline/get_repo.go | 4 +- database/pipeline/get_repo_test.go | 5 ++- database/pipeline/get_test.go | 5 ++- database/pipeline/index.go | 4 +- database/pipeline/index_test.go | 3 +- database/pipeline/interface.go | 24 +++++----- database/pipeline/list.go | 6 ++- database/pipeline/list_repo.go | 6 ++- database/pipeline/list_repo_test.go | 7 +-- database/pipeline/list_test.go | 7 +-- database/pipeline/opts.go | 11 +++++ database/pipeline/opts_test.go | 50 +++++++++++++++++++++ database/pipeline/pipeline.go | 8 +++- database/pipeline/pipeline_test.go | 4 ++ database/pipeline/table.go | 8 +++- database/pipeline/table_test.go | 3 +- database/pipeline/update.go | 4 +- database/pipeline/update_test.go | 5 ++- database/resource.go | 1 + router/middleware/pipeline/pipeline.go | 3 +- router/middleware/pipeline/pipeline_test.go | 4 +- 39 files changed, 186 insertions(+), 74 deletions(-) diff --git a/api/build/create.go b/api/build/create.go index 0591945e5..ee0d0f64b 100644 --- a/api/build/create.go +++ b/api/build/create.go @@ -232,7 +232,7 @@ func CreateBuild(c *gin.Context) { ) // send API call to attempt to capture the pipeline - pipeline, err = database.FromContext(c).GetPipelineForRepo(input.GetCommit(), r) + pipeline, err = database.FromContext(c).GetPipelineForRepo(ctx, input.GetCommit(), r) if err != nil { // assume the pipeline doesn't exist in the database yet // send API call to capture the pipeline configuration file config, err = scm.FromContext(c).ConfigBackoff(u, r, input.GetCommit()) @@ -309,7 +309,7 @@ func CreateBuild(c *gin.Context) { pipeline.SetRef(input.GetRef()) // send API call to create the pipeline - pipeline, err = database.FromContext(c).CreatePipeline(pipeline) + pipeline, err = database.FromContext(c).CreatePipeline(ctx, pipeline) if err != nil { retErr := fmt.Errorf("unable to create new build: failed to create pipeline for %s: %w", r.GetFullName(), err) diff --git a/api/build/restart.go b/api/build/restart.go index 63847d2d3..804ad88bd 100644 --- a/api/build/restart.go +++ b/api/build/restart.go @@ -223,7 +223,7 @@ func RestartBuild(c *gin.Context) { ) // send API call to attempt to capture the pipeline - pipeline, err = database.FromContext(c).GetPipelineForRepo(b.GetCommit(), r) + pipeline, err = database.FromContext(c).GetPipelineForRepo(ctx, b.GetCommit(), r) if err != nil { // assume the pipeline doesn't exist in the database yet (before pipeline support was added) // send API call to capture the pipeline configuration file config, err = scm.FromContext(c).ConfigBackoff(u, r, b.GetCommit()) @@ -300,7 +300,7 @@ func RestartBuild(c *gin.Context) { pipeline.SetRef(b.GetRef()) // send API call to create the pipeline - pipeline, err = database.FromContext(c).CreatePipeline(pipeline) + pipeline, err = database.FromContext(c).CreatePipeline(ctx, pipeline) if err != nil { retErr := fmt.Errorf("unable to create pipeline for %s: %w", r.GetFullName(), err) diff --git a/api/pipeline/create.go b/api/pipeline/create.go index f6a4938bc..87bfb6634 100644 --- a/api/pipeline/create.go +++ b/api/pipeline/create.go @@ -70,6 +70,7 @@ func CreatePipeline(c *gin.Context) { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -98,7 +99,7 @@ func CreatePipeline(c *gin.Context) { input.SetRepoID(r.GetID()) // send API call to create the pipeline - p, err := database.FromContext(c).CreatePipeline(input) + p, err := database.FromContext(c).CreatePipeline(ctx, input) if err != nil { retErr := fmt.Errorf("unable to create pipeline %s/%s: %w", r.GetFullName(), input.GetCommit(), err) diff --git a/api/pipeline/delete.go b/api/pipeline/delete.go index da8b0ecf1..1012b3094 100644 --- a/api/pipeline/delete.go +++ b/api/pipeline/delete.go @@ -65,6 +65,7 @@ func DeletePipeline(c *gin.Context) { p := pipeline.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), p.GetCommit()) @@ -79,7 +80,7 @@ func DeletePipeline(c *gin.Context) { }).Infof("deleting pipeline %s", entry) // send API call to remove the build - err := database.FromContext(c).DeletePipeline(p) + err := database.FromContext(c).DeletePipeline(ctx, p) if err != nil { retErr := fmt.Errorf("unable to delete pipeline %s: %w", entry, err) diff --git a/api/pipeline/list.go b/api/pipeline/list.go index a2582eb08..a43f33788 100644 --- a/api/pipeline/list.go +++ b/api/pipeline/list.go @@ -80,6 +80,7 @@ func ListPipelines(c *gin.Context) { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -117,7 +118,7 @@ func ListPipelines(c *gin.Context) { //nolint:gomnd // ignore magic number perPage = util.MaxInt(1, util.MinInt(100, perPage)) - p, t, err := database.FromContext(c).ListPipelinesForRepo(r, page, perPage) + p, t, err := database.FromContext(c).ListPipelinesForRepo(ctx, r, page, perPage) if err != nil { retErr := fmt.Errorf("unable to list pipelines for repo %s: %w", r.GetFullName(), err) diff --git a/api/pipeline/update.go b/api/pipeline/update.go index 9c10d40b6..8c349f5bb 100644 --- a/api/pipeline/update.go +++ b/api/pipeline/update.go @@ -72,6 +72,7 @@ func UpdatePipeline(c *gin.Context) { p := pipeline.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), p.GetCommit()) @@ -170,7 +171,7 @@ func UpdatePipeline(c *gin.Context) { } // send API call to update the pipeline - p, err = database.FromContext(c).UpdatePipeline(p) + p, err = database.FromContext(c).UpdatePipeline(ctx, p) if err != nil { retErr := fmt.Errorf("unable to update pipeline %s: %w", entry, err) diff --git a/api/webhook/post.go b/api/webhook/post.go index 8d5ca369b..ce11a614f 100644 --- a/api/webhook/post.go +++ b/api/webhook/post.go @@ -427,7 +427,7 @@ func PostWebhook(c *gin.Context) { } // send API call to attempt to capture the pipeline - pipeline, err = database.FromContext(c).GetPipelineForRepo(b.GetCommit(), repo) + pipeline, err = database.FromContext(c).GetPipelineForRepo(ctx, b.GetCommit(), repo) if err != nil { // assume the pipeline doesn't exist in the database yet // send API call to capture the pipeline configuration file config, err = scm.FromContext(c).ConfigBackoff(u, repo, b.GetCommit()) @@ -562,7 +562,7 @@ func PostWebhook(c *gin.Context) { pipeline.SetRef(b.GetRef()) // send API call to create the pipeline - pipeline, err = database.FromContext(c).CreatePipeline(pipeline) + pipeline, err = database.FromContext(c).CreatePipeline(ctx, pipeline) if err != nil { retErr := fmt.Errorf("%s: failed to create pipeline for %s: %w", baseErr, repo.GetFullName(), err) diff --git a/cmd/vela-server/schedule.go b/cmd/vela-server/schedule.go index 804da21f3..a8d927a1b 100644 --- a/cmd/vela-server/schedule.go +++ b/cmd/vela-server/schedule.go @@ -237,7 +237,7 @@ func processSchedule(ctx context.Context, s *library.Schedule, compiler compiler } // send API call to attempt to capture the pipeline - pipeline, err = database.GetPipelineForRepo(b.GetCommit(), r) + pipeline, err = database.GetPipelineForRepo(context.TODO(), b.GetCommit(), r) if err != nil { // assume the pipeline doesn't exist in the database yet // send API call to capture the pipeline configuration file config, err = scm.ConfigBackoff(u, r, b.GetCommit()) @@ -326,7 +326,7 @@ func processSchedule(ctx context.Context, s *library.Schedule, compiler compiler pipeline.SetRef(b.GetRef()) // send API call to create the pipeline - pipeline, err = database.CreatePipeline(pipeline) + pipeline, err = database.CreatePipeline(context.TODO(), pipeline) if err != nil { err = fmt.Errorf("failed to create pipeline for %s: %w", r.GetFullName(), err) diff --git a/database/integration_test.go b/database/integration_test.go index d8f8a6e67..d8dede18b 100644 --- a/database/integration_test.go +++ b/database/integration_test.go @@ -664,7 +664,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { // create the pipelines for _, pipeline := range resources.Pipelines { - _, err := db.CreatePipeline(pipeline) + _, err := db.CreatePipeline(context.TODO(), pipeline) if err != nil { t.Errorf("unable to create pipeline %d: %v", pipeline.GetID(), err) } @@ -672,7 +672,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { methods["CreatePipeline"] = true // count the pipelines - count, err := db.CountPipelines() + count, err := db.CountPipelines(context.TODO()) if err != nil { t.Errorf("unable to count pipelines: %v", err) } @@ -682,7 +682,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { methods["CountPipelines"] = true // count the pipelines for a repo - count, err = db.CountPipelinesForRepo(resources.Repos[0]) + count, err = db.CountPipelinesForRepo(context.TODO(), resources.Repos[0]) if err != nil { t.Errorf("unable to count pipelines for repo %d: %v", resources.Repos[0].GetID(), err) } @@ -692,7 +692,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { methods["CountPipelinesForRepo"] = true // list the pipelines - list, err := db.ListPipelines() + list, err := db.ListPipelines(context.TODO()) if err != nil { t.Errorf("unable to list pipelines: %v", err) } @@ -702,7 +702,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { methods["ListPipelines"] = true // list the pipelines for a repo - list, count, err = db.ListPipelinesForRepo(resources.Repos[0], 1, 10) + list, count, err = db.ListPipelinesForRepo(context.TODO(), resources.Repos[0], 1, 10) if err != nil { t.Errorf("unable to list pipelines for repo %d: %v", resources.Repos[0].GetID(), err) } @@ -717,7 +717,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { // lookup the pipelines by name for _, pipeline := range resources.Pipelines { repo := resources.Repos[pipeline.GetRepoID()-1] - got, err := db.GetPipelineForRepo(pipeline.GetCommit(), repo) + got, err := db.GetPipelineForRepo(context.TODO(), pipeline.GetCommit(), repo) if err != nil { t.Errorf("unable to get pipeline %d for repo %d: %v", pipeline.GetID(), repo.GetID(), err) } @@ -730,13 +730,13 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { // update the pipelines for _, pipeline := range resources.Pipelines { pipeline.SetVersion("2") - _, err = db.UpdatePipeline(pipeline) + _, err = db.UpdatePipeline(context.TODO(), pipeline) if err != nil { t.Errorf("unable to update pipeline %d: %v", pipeline.GetID(), err) } // lookup the pipeline by ID - got, err := db.GetPipeline(pipeline.GetID()) + got, err := db.GetPipeline(context.TODO(), pipeline.GetID()) if err != nil { t.Errorf("unable to get pipeline %d by ID: %v", pipeline.GetID(), err) } @@ -749,7 +749,7 @@ func testPipelines(t *testing.T, db Interface, resources *Resources) { // delete the pipelines for _, pipeline := range resources.Pipelines { - err = db.DeletePipeline(pipeline) + err = db.DeletePipeline(context.TODO(), pipeline) if err != nil { t.Errorf("unable to delete pipeline %d: %v", pipeline.GetID(), err) } diff --git a/database/pipeline/count.go b/database/pipeline/count.go index 67845adff..33377d105 100644 --- a/database/pipeline/count.go +++ b/database/pipeline/count.go @@ -5,11 +5,13 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" ) // CountPipelines gets the count of all pipelines from the database. -func (e *engine) CountPipelines() (int64, error) { +func (e *engine) CountPipelines(ctx context.Context) (int64, error) { e.logger.Tracef("getting count of all pipelines from the database") // variable to store query results diff --git a/database/pipeline/count_repo.go b/database/pipeline/count_repo.go index 50de5cae7..a3318a4e6 100644 --- a/database/pipeline/count_repo.go +++ b/database/pipeline/count_repo.go @@ -5,13 +5,15 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/library" "github.com/sirupsen/logrus" ) // CountPipelinesForRepo gets the count of pipelines by repo ID from the database. -func (e *engine) CountPipelinesForRepo(r *library.Repo) (int64, error) { +func (e *engine) CountPipelinesForRepo(ctx context.Context, r *library.Repo) (int64, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "repo": r.GetName(), diff --git a/database/pipeline/count_repo_test.go b/database/pipeline/count_repo_test.go index cc3650d12..0d2469e77 100644 --- a/database/pipeline/count_repo_test.go +++ b/database/pipeline/count_repo_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -42,12 +43,12 @@ func TestPipeline_Engine_CountPipelinesForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipelineOne) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipelineOne) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } - _, err = _sqlite.CreatePipeline(_pipelineTwo) + _, err = _sqlite.CreatePipeline(context.TODO(), _pipelineTwo) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -76,7 +77,7 @@ func TestPipeline_Engine_CountPipelinesForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountPipelinesForRepo(&library.Repo{ID: _pipelineOne.RepoID}) + got, err := test.database.CountPipelinesForRepo(context.TODO(), &library.Repo{ID: _pipelineOne.RepoID}) if test.failure { if err == nil { diff --git a/database/pipeline/count_test.go b/database/pipeline/count_test.go index bbc654fb5..6b638420a 100644 --- a/database/pipeline/count_test.go +++ b/database/pipeline/count_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -41,12 +42,12 @@ func TestPipeline_Engine_CountPipelines(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipelineOne) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipelineOne) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } - _, err = _sqlite.CreatePipeline(_pipelineTwo) + _, err = _sqlite.CreatePipeline(context.TODO(), _pipelineTwo) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -75,7 +76,7 @@ func TestPipeline_Engine_CountPipelines(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountPipelines() + got, err := test.database.CountPipelines(context.TODO()) if test.failure { if err == nil { diff --git a/database/pipeline/create.go b/database/pipeline/create.go index 424c24c23..04b344c95 100644 --- a/database/pipeline/create.go +++ b/database/pipeline/create.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // CreatePipeline creates a new pipeline in the database. -func (e *engine) CreatePipeline(p *library.Pipeline) (*library.Pipeline, error) { +func (e *engine) CreatePipeline(ctx context.Context, p *library.Pipeline) (*library.Pipeline, error) { e.logger.WithFields(logrus.Fields{ "pipeline": p.GetCommit(), }).Tracef("creating pipeline %s in the database", p.GetCommit()) diff --git a/database/pipeline/create_test.go b/database/pipeline/create_test.go index 5a19c7e84..9077dad5e 100644 --- a/database/pipeline/create_test.go +++ b/database/pipeline/create_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -59,7 +60,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15) RETURNING "id"`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CreatePipeline(_pipeline) + got, err := test.database.CreatePipeline(context.TODO(), _pipeline) if test.failure { if err == nil { diff --git a/database/pipeline/delete.go b/database/pipeline/delete.go index eb03db88c..d473d5d4f 100644 --- a/database/pipeline/delete.go +++ b/database/pipeline/delete.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // DeletePipeline deletes an existing pipeline from the database. -func (e *engine) DeletePipeline(p *library.Pipeline) error { +func (e *engine) DeletePipeline(ctx context.Context, p *library.Pipeline) error { e.logger.WithFields(logrus.Fields{ "pipeline": p.GetCommit(), }).Tracef("deleting pipeline %s from the database", p.GetCommit()) diff --git a/database/pipeline/delete_test.go b/database/pipeline/delete_test.go index 6bef174c7..39e47a7dd 100644 --- a/database/pipeline/delete_test.go +++ b/database/pipeline/delete_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -31,7 +32,7 @@ func TestPipeline_Engine_DeletePipeline(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipeline) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipeline) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -57,7 +58,7 @@ func TestPipeline_Engine_DeletePipeline(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err = test.database.DeletePipeline(_pipeline) + err = test.database.DeletePipeline(context.TODO(), _pipeline) if test.failure { if err == nil { diff --git a/database/pipeline/get.go b/database/pipeline/get.go index e8eb4b23a..17ec6ab78 100644 --- a/database/pipeline/get.go +++ b/database/pipeline/get.go @@ -5,13 +5,15 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // GetPipeline gets a pipeline by ID from the database. -func (e *engine) GetPipeline(id int64) (*library.Pipeline, error) { +func (e *engine) GetPipeline(ctx context.Context, id int64) (*library.Pipeline, error) { e.logger.Tracef("getting pipeline %d from the database", id) // variable to store query results diff --git a/database/pipeline/get_repo.go b/database/pipeline/get_repo.go index 1d431ca10..f165c18eb 100644 --- a/database/pipeline/get_repo.go +++ b/database/pipeline/get_repo.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // GetPipelineForRepo gets a pipeline by number and repo ID from the database. -func (e *engine) GetPipelineForRepo(commit string, r *library.Repo) (*library.Pipeline, error) { +func (e *engine) GetPipelineForRepo(ctx context.Context, commit string, r *library.Repo) (*library.Pipeline, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "pipeline": commit, diff --git a/database/pipeline/get_repo_test.go b/database/pipeline/get_repo_test.go index ab810ba5f..67161931d 100644 --- a/database/pipeline/get_repo_test.go +++ b/database/pipeline/get_repo_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -37,7 +38,7 @@ func TestPipeline_Engine_GetPipelineForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipeline) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipeline) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -66,7 +67,7 @@ func TestPipeline_Engine_GetPipelineForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetPipelineForRepo("48afb5bdc41ad69bf22588491333f7cf71135163", &library.Repo{ID: _pipeline.RepoID}) + got, err := test.database.GetPipelineForRepo(context.TODO(), "48afb5bdc41ad69bf22588491333f7cf71135163", &library.Repo{ID: _pipeline.RepoID}) if test.failure { if err == nil { diff --git a/database/pipeline/get_test.go b/database/pipeline/get_test.go index f7c3565d3..ca6d789f2 100644 --- a/database/pipeline/get_test.go +++ b/database/pipeline/get_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -37,7 +38,7 @@ func TestPipeline_Engine_GetPipeline(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipeline) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipeline) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -66,7 +67,7 @@ func TestPipeline_Engine_GetPipeline(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetPipeline(1) + got, err := test.database.GetPipeline(context.TODO(), 1) if test.failure { if err == nil { diff --git a/database/pipeline/index.go b/database/pipeline/index.go index 506fddfa8..3189b728d 100644 --- a/database/pipeline/index.go +++ b/database/pipeline/index.go @@ -4,6 +4,8 @@ package pipeline +import "context" + const ( // CreateRepoIDIndex represents a query to create an // index on the pipelines table for the repo_id column. @@ -16,7 +18,7 @@ ON pipelines (repo_id); ) // CreatePipelineIndexes creates the indexes for the pipelines table in the database. -func (e *engine) CreatePipelineIndexes() error { +func (e *engine) CreatePipelineIndexes(ctx context.Context) error { e.logger.Tracef("creating indexes for pipelines table in the database") // create the repo_id column index for the pipelines table diff --git a/database/pipeline/index_test.go b/database/pipeline/index_test.go index 1fa77b7b0..e72b5a593 100644 --- a/database/pipeline/index_test.go +++ b/database/pipeline/index_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestPipeline_Engine_CreatePipelineIndexes(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreatePipelineIndexes() + err := test.database.CreatePipelineIndexes(context.TODO()) if test.failure { if err == nil { diff --git a/database/pipeline/interface.go b/database/pipeline/interface.go index c28c30da8..ae75ebd54 100644 --- a/database/pipeline/interface.go +++ b/database/pipeline/interface.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/library" ) @@ -18,30 +20,30 @@ type PipelineInterface interface { // https://en.wikipedia.org/wiki/Data_definition_language // CreatePipelineIndexes defines a function that creates the indexes for the pipelines table. - CreatePipelineIndexes() error + CreatePipelineIndexes(context.Context) error // CreatePipelineTable defines a function that creates the pipelines table. - CreatePipelineTable(string) error + CreatePipelineTable(context.Context, string) error // Pipeline Data Manipulation Language Functions // // https://en.wikipedia.org/wiki/Data_manipulation_language // CountPipelines defines a function that gets the count of all pipelines. - CountPipelines() (int64, error) + CountPipelines(context.Context) (int64, error) // CountPipelinesForRepo defines a function that gets the count of pipelines by repo ID. - CountPipelinesForRepo(*library.Repo) (int64, error) + CountPipelinesForRepo(context.Context, *library.Repo) (int64, error) // CreatePipeline defines a function that creates a new pipeline. - CreatePipeline(*library.Pipeline) (*library.Pipeline, error) + CreatePipeline(context.Context, *library.Pipeline) (*library.Pipeline, error) // DeletePipeline defines a function that deletes an existing pipeline. - DeletePipeline(*library.Pipeline) error + DeletePipeline(context.Context, *library.Pipeline) error // GetPipeline defines a function that gets a pipeline by ID. - GetPipeline(int64) (*library.Pipeline, error) + GetPipeline(context.Context, int64) (*library.Pipeline, error) // GetPipelineForRepo defines a function that gets a pipeline by commit SHA and repo ID. - GetPipelineForRepo(string, *library.Repo) (*library.Pipeline, error) + GetPipelineForRepo(context.Context, string, *library.Repo) (*library.Pipeline, error) // ListPipelines defines a function that gets a list of all pipelines. - ListPipelines() ([]*library.Pipeline, error) + ListPipelines(context.Context) ([]*library.Pipeline, error) // ListPipelinesForRepo defines a function that gets a list of pipelines by repo ID. - ListPipelinesForRepo(*library.Repo, int, int) ([]*library.Pipeline, int64, error) + ListPipelinesForRepo(context.Context, *library.Repo, int, int) ([]*library.Pipeline, int64, error) // UpdatePipeline defines a function that updates an existing pipeline. - UpdatePipeline(*library.Pipeline) (*library.Pipeline, error) + UpdatePipeline(context.Context, *library.Pipeline) (*library.Pipeline, error) } diff --git a/database/pipeline/list.go b/database/pipeline/list.go index 9159a87aa..54f8015c4 100644 --- a/database/pipeline/list.go +++ b/database/pipeline/list.go @@ -5,13 +5,15 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // ListPipelines gets a list of all pipelines from the database. -func (e *engine) ListPipelines() ([]*library.Pipeline, error) { +func (e *engine) ListPipelines(ctx context.Context) ([]*library.Pipeline, error) { e.logger.Trace("listing all pipelines from the database") // variables to store query results and return value @@ -20,7 +22,7 @@ func (e *engine) ListPipelines() ([]*library.Pipeline, error) { pipelines := []*library.Pipeline{} // count the results - count, err := e.CountPipelines() + count, err := e.CountPipelines(ctx) if err != nil { return nil, err } diff --git a/database/pipeline/list_repo.go b/database/pipeline/list_repo.go index 609dcc340..e8c38afab 100644 --- a/database/pipeline/list_repo.go +++ b/database/pipeline/list_repo.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -14,7 +16,7 @@ import ( // ListPipelinesForRepo gets a list of pipelines by repo ID from the database. // //nolint:lll // ignore long line length due to variable names -func (e *engine) ListPipelinesForRepo(r *library.Repo, page, perPage int) ([]*library.Pipeline, int64, error) { +func (e *engine) ListPipelinesForRepo(ctx context.Context, r *library.Repo, page, perPage int) ([]*library.Pipeline, int64, error) { e.logger.WithFields(logrus.Fields{ "org": r.GetOrg(), "repo": r.GetName(), @@ -26,7 +28,7 @@ func (e *engine) ListPipelinesForRepo(r *library.Repo, page, perPage int) ([]*li pipelines := []*library.Pipeline{} // count the results - count, err := e.CountPipelinesForRepo(r) + count, err := e.CountPipelinesForRepo(context.TODO(), r) if err != nil { return pipelines, 0, err } diff --git a/database/pipeline/list_repo_test.go b/database/pipeline/list_repo_test.go index dc64e8a0f..35cc2e769 100644 --- a/database/pipeline/list_repo_test.go +++ b/database/pipeline/list_repo_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -53,12 +54,12 @@ func TestPipeline_Engine_ListPipelinesForRepo(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipelineOne) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipelineOne) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } - _, err = _sqlite.CreatePipeline(_pipelineTwo) + _, err = _sqlite.CreatePipeline(context.TODO(), _pipelineTwo) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -87,7 +88,7 @@ func TestPipeline_Engine_ListPipelinesForRepo(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, _, err := test.database.ListPipelinesForRepo(&library.Repo{ID: _pipelineOne.RepoID}, 1, 10) + got, _, err := test.database.ListPipelinesForRepo(context.TODO(), &library.Repo{ID: _pipelineOne.RepoID}, 1, 10) if test.failure { if err == nil { diff --git a/database/pipeline/list_test.go b/database/pipeline/list_test.go index 36f65199d..e1d14166e 100644 --- a/database/pipeline/list_test.go +++ b/database/pipeline/list_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -53,12 +54,12 @@ func TestPipeline_Engine_ListPipelines(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipelineOne) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipelineOne) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } - _, err = _sqlite.CreatePipeline(_pipelineTwo) + _, err = _sqlite.CreatePipeline(context.TODO(), _pipelineTwo) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -87,7 +88,7 @@ func TestPipeline_Engine_ListPipelines(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.ListPipelines() + got, err := test.database.ListPipelines(context.TODO()) if test.failure { if err == nil { diff --git a/database/pipeline/opts.go b/database/pipeline/opts.go index f04796ff7..e4d1e6d46 100644 --- a/database/pipeline/opts.go +++ b/database/pipeline/opts.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -52,3 +54,12 @@ func WithSkipCreation(skipCreation bool) EngineOpt { return nil } } + +// WithContext sets the context in the database engine for Pipelines. +func WithContext(ctx context.Context) EngineOpt { + return func(e *engine) error { + e.ctx = ctx + + return nil + } +} diff --git a/database/pipeline/opts_test.go b/database/pipeline/opts_test.go index 94822e33b..2380bdda7 100644 --- a/database/pipeline/opts_test.go +++ b/database/pipeline/opts_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -214,3 +215,52 @@ func TestPipeline_EngineOpt_WithSkipCreation(t *testing.T) { }) } } + +func TestPipeline_EngineOpt_WithContext(t *testing.T) { + // setup types + e := &engine{config: new(config)} + + // setup tests + tests := []struct { + failure bool + name string + ctx context.Context + want context.Context + }{ + { + failure: false, + name: "context set to TODO", + ctx: context.TODO(), + want: context.TODO(), + }, + { + failure: false, + name: "context set to nil", + ctx: nil, + want: nil, + }, + } + + // run tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithContext(test.ctx)(e) + + if test.failure { + if err == nil { + t.Errorf("WithContext for %s should have returned err", test.name) + } + + return + } + + if err != nil { + t.Errorf("WithContext returned err: %v", err) + } + + if !reflect.DeepEqual(e.ctx, test.want) { + t.Errorf("WithContext is %v, want %v", e.ctx, test.want) + } + }) + } +} diff --git a/database/pipeline/pipeline.go b/database/pipeline/pipeline.go index a48cc6e07..7c9f29d8f 100644 --- a/database/pipeline/pipeline.go +++ b/database/pipeline/pipeline.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -27,6 +28,8 @@ type ( // engine configuration settings used in pipeline functions config *config + ctx context.Context + // gorm.io/gorm database client used in pipeline functions // // https://pkg.go.dev/gorm.io/gorm#DB @@ -50,6 +53,7 @@ func New(opts ...EngineOpt) (*engine, error) { e.client = new(gorm.DB) e.config = new(config) e.logger = new(logrus.Entry) + e.ctx = context.TODO() // apply all provided configuration options for _, opt := range opts { @@ -67,13 +71,13 @@ func New(opts ...EngineOpt) (*engine, error) { } // create the pipelines table - err := e.CreatePipelineTable(e.client.Config.Dialector.Name()) + err := e.CreatePipelineTable(e.ctx, e.client.Config.Dialector.Name()) if err != nil { return nil, fmt.Errorf("unable to create %s table: %w", constants.TablePipeline, err) } // create the indexes for the pipelines table - err = e.CreatePipelineIndexes() + err = e.CreatePipelineIndexes(e.ctx) if err != nil { return nil, fmt.Errorf("unable to create indexes for %s table: %w", constants.TablePipeline, err) } diff --git a/database/pipeline/pipeline_test.go b/database/pipeline/pipeline_test.go index f478fe37f..13d01393a 100644 --- a/database/pipeline/pipeline_test.go +++ b/database/pipeline/pipeline_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "database/sql/driver" "reflect" "testing" @@ -65,6 +66,7 @@ func TestPipeline_New(t *testing.T) { want: &engine{ client: _postgres, config: &config{CompressionLevel: 1, SkipCreation: false}, + ctx: context.TODO(), logger: logger, }, }, @@ -78,6 +80,7 @@ func TestPipeline_New(t *testing.T) { want: &engine{ client: _sqlite, config: &config{CompressionLevel: 1, SkipCreation: false}, + ctx: context.TODO(), logger: logger, }, }, @@ -87,6 +90,7 @@ func TestPipeline_New(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got, err := New( + WithContext(context.TODO()), WithClient(test.client), WithCompressionLevel(test.level), WithLogger(test.logger), diff --git a/database/pipeline/table.go b/database/pipeline/table.go index bc463da68..eef8b885a 100644 --- a/database/pipeline/table.go +++ b/database/pipeline/table.go @@ -4,7 +4,11 @@ package pipeline -import "github.com/go-vela/types/constants" +import ( + "context" + + "github.com/go-vela/types/constants" +) const ( // CreatePostgresTable represents a query to create the Postgres pipelines table. @@ -57,7 +61,7 @@ pipelines ( ) // CreatePipelineTable creates the pipelines table in the database. -func (e *engine) CreatePipelineTable(driver string) error { +func (e *engine) CreatePipelineTable(ctx context.Context, driver string) error { e.logger.Tracef("creating pipelines table in the database") // handle the driver provided to create the table diff --git a/database/pipeline/table_test.go b/database/pipeline/table_test.go index 5b05e4313..72add7aee 100644 --- a/database/pipeline/table_test.go +++ b/database/pipeline/table_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestPipeline_Engine_CreatePipelineTable(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreatePipelineTable(test.name) + err := test.database.CreatePipelineTable(context.TODO(), test.name) if test.failure { if err == nil { diff --git a/database/pipeline/update.go b/database/pipeline/update.go index 59b552164..64d9dd561 100644 --- a/database/pipeline/update.go +++ b/database/pipeline/update.go @@ -5,6 +5,8 @@ package pipeline import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // UpdatePipeline updates an existing pipeline in the database. -func (e *engine) UpdatePipeline(p *library.Pipeline) (*library.Pipeline, error) { +func (e *engine) UpdatePipeline(ctx context.Context, p *library.Pipeline) (*library.Pipeline, error) { e.logger.WithFields(logrus.Fields{ "pipeline": p.GetCommit(), }).Tracef("updating pipeline %s in the database", p.GetCommit()) diff --git a/database/pipeline/update_test.go b/database/pipeline/update_test.go index d321d63b5..f8ab6b3d3 100644 --- a/database/pipeline/update_test.go +++ b/database/pipeline/update_test.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "reflect" "testing" @@ -35,7 +36,7 @@ WHERE "id" = $15`). _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreatePipeline(_pipeline) + _, err := _sqlite.CreatePipeline(context.TODO(), _pipeline) if err != nil { t.Errorf("unable to create test pipeline for sqlite: %v", err) } @@ -61,7 +62,7 @@ WHERE "id" = $15`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.UpdatePipeline(_pipeline) + got, err := test.database.UpdatePipeline(context.TODO(), _pipeline) if test.failure { if err == nil { diff --git a/database/resource.go b/database/resource.go index a2e177353..e6784b2e1 100644 --- a/database/resource.go +++ b/database/resource.go @@ -58,6 +58,7 @@ func (e *engine) NewResources(ctx context.Context) error { // create the database agnostic engine for pipelines e.PipelineInterface, err = pipeline.New( + pipeline.WithContext(e.ctx), pipeline.WithClient(e.client), pipeline.WithCompressionLevel(e.config.CompressionLevel), pipeline.WithLogger(e.logger), diff --git a/router/middleware/pipeline/pipeline.go b/router/middleware/pipeline/pipeline.go index 1910e2c0f..b3f2e57af 100644 --- a/router/middleware/pipeline/pipeline.go +++ b/router/middleware/pipeline/pipeline.go @@ -32,6 +32,7 @@ func Establish() gin.HandlerFunc { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() if r == nil { retErr := fmt.Errorf("repo %s/%s not found", util.PathParameter(c, "org"), util.PathParameter(c, "repo")) @@ -62,7 +63,7 @@ func Establish() gin.HandlerFunc { "user": u.GetName(), }).Debugf("reading pipeline %s", entry) - pipeline, err := database.FromContext(c).GetPipelineForRepo(p, r) + pipeline, err := database.FromContext(c).GetPipelineForRepo(ctx, p, r) if err != nil { // assume the pipeline doesn't exist in the database yet (before pipeline support was added) // send API call to capture the pipeline configuration file config, err := scm.FromContext(c).ConfigBackoff(u, r, p) diff --git a/router/middleware/pipeline/pipeline_test.go b/router/middleware/pipeline/pipeline_test.go index ff410c1da..0fc1e8b29 100644 --- a/router/middleware/pipeline/pipeline_test.go +++ b/router/middleware/pipeline/pipeline_test.go @@ -103,13 +103,13 @@ func TestPipeline_Establish(t *testing.T) { } defer func() { - db.DeletePipeline(want) + db.DeletePipeline(context.TODO(), want) db.DeleteRepo(context.TODO(), r) db.Close() }() _, _ = db.CreateRepo(context.TODO(), r) - _, _ = db.CreatePipeline(want) + _, _ = db.CreatePipeline(context.TODO(), want) // setup context gin.SetMode(gin.TestMode)