From dd53c1aeff6d35fc51f3218f9c26dd256a596f62 Mon Sep 17 00:00:00 2001 From: davidvader Date: Mon, 28 Aug 2023 09:51:13 -0500 Subject: [PATCH] enhance: add context to Workers --- api/build/cancel.go | 2 +- api/metrics.go | 2 +- api/worker/create.go | 3 +- api/worker/delete.go | 3 +- api/worker/get.go | 3 +- api/worker/list.go | 3 +- api/worker/refresh.go | 3 +- api/worker/update.go | 5 ++- database/integration_test.go | 14 +++---- database/resource.go | 1 + database/worker/count.go | 4 +- database/worker/count_test.go | 7 ++-- database/worker/create.go | 4 +- database/worker/create_test.go | 3 +- database/worker/delete.go | 4 +- database/worker/delete_test.go | 5 ++- database/worker/get.go | 4 +- database/worker/get_hostname.go | 4 +- database/worker/get_hostname_test.go | 5 ++- database/worker/get_test.go | 5 ++- database/worker/index.go | 4 +- database/worker/index_test.go | 3 +- database/worker/interface.go | 20 +++++----- database/worker/list.go | 6 ++- database/worker/list_test.go | 7 ++-- database/worker/opts.go | 11 ++++++ database/worker/opts_test.go | 50 ++++++++++++++++++++++++ database/worker/table.go | 4 +- database/worker/table_test.go | 3 +- database/worker/update.go | 4 +- database/worker/update_test.go | 5 ++- database/worker/worker.go | 7 +++- router/middleware/executors/executors.go | 3 +- router/middleware/worker/worker.go | 4 +- router/middleware/worker/worker_test.go | 5 ++- 35 files changed, 162 insertions(+), 58 deletions(-) diff --git a/api/build/cancel.go b/api/build/cancel.go index 55dc7cecf..8b75dc470 100644 --- a/api/build/cancel.go +++ b/api/build/cancel.go @@ -97,7 +97,7 @@ func CancelBuild(c *gin.Context) { switch b.GetStatus() { case constants.StatusRunning: // retrieve the worker info - w, err := database.FromContext(c).GetWorkerForHostname(b.GetHost()) + w, err := database.FromContext(c).GetWorkerForHostname(ctx, b.GetHost()) if err != nil { retErr := fmt.Errorf("unable to get worker for build %s: %w", entry, err) util.HandleError(c, http.StatusNotFound, retErr) diff --git a/api/metrics.go b/api/metrics.go index 08994e65c..cb3832b9b 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -424,7 +424,7 @@ func recordGauges(c *gin.Context) { // worker_build_limit, active_worker_count, inactive_worker_count, idle_worker_count, available_worker_count, busy_worker_count, error_worker_count if q.WorkerBuildLimit || q.ActiveWorkerCount || q.InactiveWorkerCount || q.IdleWorkerCount || q.AvailableWorkerCount || q.BusyWorkerCount || q.ErrorWorkerCount { // send API call to capture the workers - workers, err := database.FromContext(c).ListWorkers() + workers, err := database.FromContext(c).ListWorkers(ctx) if err != nil { logrus.Errorf("unable to get workers: %v", err) } diff --git a/api/worker/create.go b/api/worker/create.go index b7d1bf286..12a08897b 100644 --- a/api/worker/create.go +++ b/api/worker/create.go @@ -57,6 +57,7 @@ func CreateWorker(c *gin.Context) { // capture middleware values u := user.Retrieve(c) cl := claims.Retrieve(c) + ctx := c.Request.Context() // capture body from API request input := new(library.Worker) @@ -89,7 +90,7 @@ func CreateWorker(c *gin.Context) { "worker": input.GetHostname(), }).Infof("creating new worker %s", input.GetHostname()) - err = database.FromContext(c).CreateWorker(input) + err = database.FromContext(c).CreateWorker(ctx, input) if err != nil { retErr := fmt.Errorf("unable to create worker: %w", err) diff --git a/api/worker/delete.go b/api/worker/delete.go index fa0ad5e65..4857e119a 100644 --- a/api/worker/delete.go +++ b/api/worker/delete.go @@ -47,6 +47,7 @@ func DeleteWorker(c *gin.Context) { // capture middleware values u := user.Retrieve(c) w := worker.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -57,7 +58,7 @@ func DeleteWorker(c *gin.Context) { }).Infof("deleting worker %s", w.GetHostname()) // send API call to remove the step - err := database.FromContext(c).DeleteWorker(w) + err := database.FromContext(c).DeleteWorker(ctx, w) if err != nil { retErr := fmt.Errorf("unable to delete worker %s: %w", w.GetHostname(), err) diff --git a/api/worker/get.go b/api/worker/get.go index 88b323532..b16cbb9cd 100644 --- a/api/worker/get.go +++ b/api/worker/get.go @@ -47,6 +47,7 @@ func GetWorker(c *gin.Context) { // capture middleware values u := user.Retrieve(c) w := worker.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -56,7 +57,7 @@ func GetWorker(c *gin.Context) { "worker": w.GetHostname(), }).Infof("reading worker %s", w.GetHostname()) - w, err := database.FromContext(c).GetWorkerForHostname(w.GetHostname()) + w, err := database.FromContext(c).GetWorkerForHostname(ctx, w.GetHostname()) if err != nil { retErr := fmt.Errorf("unable to get workers: %w", err) diff --git a/api/worker/list.go b/api/worker/list.go index 2587ba07f..07b6d136a 100644 --- a/api/worker/list.go +++ b/api/worker/list.go @@ -41,6 +41,7 @@ import ( func ListWorkers(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -49,7 +50,7 @@ func ListWorkers(c *gin.Context) { "user": u.GetName(), }).Info("reading workers") - w, err := database.FromContext(c).ListWorkers() + w, err := database.FromContext(c).ListWorkers(ctx) if err != nil { retErr := fmt.Errorf("unable to get workers: %w", err) diff --git a/api/worker/refresh.go b/api/worker/refresh.go index cd4aa7ef3..6695d5c25 100644 --- a/api/worker/refresh.go +++ b/api/worker/refresh.go @@ -60,6 +60,7 @@ func Refresh(c *gin.Context) { // capture middleware values w := worker.Retrieve(c) cl := claims.Retrieve(c) + ctx := c.Request.Context() // if we are not using a symmetric token, and the subject does not match the input, request should be denied if !strings.EqualFold(cl.TokenType, constants.ServerWorkerTokenType) && !strings.EqualFold(cl.Subject, w.GetHostname()) { @@ -79,7 +80,7 @@ func Refresh(c *gin.Context) { w.SetLastCheckedIn(time.Now().Unix()) // send API call to update the worker - err := database.FromContext(c).UpdateWorker(w) + err := database.FromContext(c).UpdateWorker(ctx, w) if err != nil { retErr := fmt.Errorf("unable to update worker %s: %w", w.GetHostname(), err) diff --git a/api/worker/update.go b/api/worker/update.go index b3a8d5130..5431aca89 100644 --- a/api/worker/update.go +++ b/api/worker/update.go @@ -62,6 +62,7 @@ func UpdateWorker(c *gin.Context) { // capture middleware values u := user.Retrieve(c) w := worker.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -124,7 +125,7 @@ func UpdateWorker(c *gin.Context) { } // send API call to update the worker - err = database.FromContext(c).UpdateWorker(w) + err = database.FromContext(c).UpdateWorker(ctx, w) if err != nil { retErr := fmt.Errorf("unable to update worker %s: %w", w.GetHostname(), err) @@ -134,7 +135,7 @@ func UpdateWorker(c *gin.Context) { } // send API call to capture the updated worker - w, _ = database.FromContext(c).GetWorkerForHostname(w.GetHostname()) + w, _ = database.FromContext(c).GetWorkerForHostname(ctx, w.GetHostname()) c.JSON(http.StatusOK, w) } diff --git a/database/integration_test.go b/database/integration_test.go index 30d6ff696..5edeabbfd 100644 --- a/database/integration_test.go +++ b/database/integration_test.go @@ -1765,7 +1765,7 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { // create the workers for _, worker := range resources.Workers { - err := db.CreateWorker(worker) + err := db.CreateWorker(context.TODO(), worker) if err != nil { t.Errorf("unable to create worker %d: %v", worker.GetID(), err) } @@ -1773,7 +1773,7 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { methods["CreateWorker"] = true // count the workers - count, err := db.CountWorkers() + count, err := db.CountWorkers(context.TODO()) if err != nil { t.Errorf("unable to count workers: %v", err) } @@ -1783,7 +1783,7 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { methods["CountWorkers"] = true // list the workers - list, err := db.ListWorkers() + list, err := db.ListWorkers(context.TODO()) if err != nil { t.Errorf("unable to list workers: %v", err) } @@ -1794,7 +1794,7 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { // lookup the workers by hostname for _, worker := range resources.Workers { - got, err := db.GetWorkerForHostname(worker.GetHostname()) + got, err := db.GetWorkerForHostname(context.TODO(), worker.GetHostname()) if err != nil { t.Errorf("unable to get worker %d by hostname: %v", worker.GetID(), err) } @@ -1807,13 +1807,13 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { // update the workers for _, worker := range resources.Workers { worker.SetActive(false) - err = db.UpdateWorker(worker) + err = db.UpdateWorker(context.TODO(), worker) if err != nil { t.Errorf("unable to update worker %d: %v", worker.GetID(), err) } // lookup the worker by ID - got, err := db.GetWorker(worker.GetID()) + got, err := db.GetWorker(context.TODO(), worker.GetID()) if err != nil { t.Errorf("unable to get worker %d by ID: %v", worker.GetID(), err) } @@ -1826,7 +1826,7 @@ func testWorkers(t *testing.T, db Interface, resources *Resources) { // delete the workers for _, worker := range resources.Workers { - err = db.DeleteWorker(worker) + err = db.DeleteWorker(context.TODO(), worker) if err != nil { t.Errorf("unable to delete worker %d: %v", worker.GetID(), err) } diff --git a/database/resource.go b/database/resource.go index 74a4b8fe5..7d3e9373c 100644 --- a/database/resource.go +++ b/database/resource.go @@ -151,6 +151,7 @@ func (e *engine) NewResources(ctx context.Context) error { // create the database agnostic engine for workers e.WorkerInterface, err = worker.New( + worker.WithContext(e.ctx), worker.WithClient(e.client), worker.WithLogger(e.logger), worker.WithSkipCreation(e.config.SkipCreation), diff --git a/database/worker/count.go b/database/worker/count.go index 8ac0f3eb5..f85227840 100644 --- a/database/worker/count.go +++ b/database/worker/count.go @@ -5,11 +5,13 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" ) // CountWorkers gets the count of all workers from the database. -func (e *engine) CountWorkers() (int64, error) { +func (e *engine) CountWorkers(ctx context.Context) (int64, error) { e.logger.Tracef("getting count of all workers from the database") // variable to store query results diff --git a/database/worker/count_test.go b/database/worker/count_test.go index bd9d4c4ac..de5509db8 100644 --- a/database/worker/count_test.go +++ b/database/worker/count_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "reflect" "testing" @@ -37,12 +38,12 @@ func TestWorker_Engine_CountWorkers(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_workerOne) + err := _sqlite.CreateWorker(context.TODO(), _workerOne) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } - err = _sqlite.CreateWorker(_workerTwo) + err = _sqlite.CreateWorker(context.TODO(), _workerTwo) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -71,7 +72,7 @@ func TestWorker_Engine_CountWorkers(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountWorkers() + got, err := test.database.CountWorkers(context.TODO()) if test.failure { if err == nil { diff --git a/database/worker/create.go b/database/worker/create.go index 6c62b30b6..a7e33f4ae 100644 --- a/database/worker/create.go +++ b/database/worker/create.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // CreateWorker creates a new worker in the database. -func (e *engine) CreateWorker(w *library.Worker) error { +func (e *engine) CreateWorker(ctx context.Context, w *library.Worker) error { e.logger.WithFields(logrus.Fields{ "worker": w.GetHostname(), }).Tracef("creating worker %s in the database", w.GetHostname()) diff --git a/database/worker/create_test.go b/database/worker/create_test.go index e4c4dc9cb..152ecb8fb 100644 --- a/database/worker/create_test.go +++ b/database/worker/create_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -55,7 +56,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12) RETURNING "id"`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateWorker(_worker) + err := test.database.CreateWorker(context.TODO(), _worker) if test.failure { if err == nil { diff --git a/database/worker/delete.go b/database/worker/delete.go index a04ebb13e..c3c29fd5b 100644 --- a/database/worker/delete.go +++ b/database/worker/delete.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // DeleteWorker deletes an existing worker from the database. -func (e *engine) DeleteWorker(w *library.Worker) error { +func (e *engine) DeleteWorker(ctx context.Context, w *library.Worker) error { e.logger.WithFields(logrus.Fields{ "worker": w.GetHostname(), }).Tracef("deleting worker %s from the database", w.GetHostname()) diff --git a/database/worker/delete_test.go b/database/worker/delete_test.go index c8a9bd1be..8125a3105 100644 --- a/database/worker/delete_test.go +++ b/database/worker/delete_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -29,7 +30,7 @@ func TestWorker_Engine_DeleteWorker(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_worker) + err := _sqlite.CreateWorker(context.TODO(), _worker) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -55,7 +56,7 @@ func TestWorker_Engine_DeleteWorker(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err = test.database.DeleteWorker(_worker) + err = test.database.DeleteWorker(context.TODO(), _worker) if test.failure { if err == nil { diff --git a/database/worker/get.go b/database/worker/get.go index dd2b07ecc..b2a033983 100644 --- a/database/worker/get.go +++ b/database/worker/get.go @@ -5,13 +5,15 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // GetWorker gets a worker by ID from the database. -func (e *engine) GetWorker(id int64) (*library.Worker, error) { +func (e *engine) GetWorker(ctx context.Context, id int64) (*library.Worker, error) { e.logger.Tracef("getting worker %d from the database", id) // variable to store query results diff --git a/database/worker/get_hostname.go b/database/worker/get_hostname.go index 6bcf42a2b..6a2a89796 100644 --- a/database/worker/get_hostname.go +++ b/database/worker/get_hostname.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // GetWorkerForHostname gets a worker by hostname from the database. -func (e *engine) GetWorkerForHostname(hostname string) (*library.Worker, error) { +func (e *engine) GetWorkerForHostname(ctx context.Context, hostname string) (*library.Worker, error) { e.logger.WithFields(logrus.Fields{ "worker": hostname, }).Tracef("getting worker %s from the database", hostname) diff --git a/database/worker/get_hostname_test.go b/database/worker/get_hostname_test.go index 3dd1d4fe6..e6ac3d198 100644 --- a/database/worker/get_hostname_test.go +++ b/database/worker/get_hostname_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "reflect" "testing" @@ -34,7 +35,7 @@ func TestWorker_Engine_GetWorkerForName(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_worker) + err := _sqlite.CreateWorker(context.TODO(), _worker) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -63,7 +64,7 @@ func TestWorker_Engine_GetWorkerForName(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetWorkerForHostname("worker_0") + got, err := test.database.GetWorkerForHostname(context.TODO(), "worker_0") if test.failure { if err == nil { diff --git a/database/worker/get_test.go b/database/worker/get_test.go index 17fd03739..d2d350f2f 100644 --- a/database/worker/get_test.go +++ b/database/worker/get_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "reflect" "testing" @@ -34,7 +35,7 @@ func TestWorker_Engine_GetWorker(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_worker) + err := _sqlite.CreateWorker(context.TODO(), _worker) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -63,7 +64,7 @@ func TestWorker_Engine_GetWorker(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetWorker(1) + got, err := test.database.GetWorker(context.TODO(), 1) if test.failure { if err == nil { diff --git a/database/worker/index.go b/database/worker/index.go index f8f01a4b6..8f6dd23c5 100644 --- a/database/worker/index.go +++ b/database/worker/index.go @@ -4,6 +4,8 @@ package worker +import "context" + const ( // CreateHostnameAddressIndex represents a query to create an // index on the workers table for the hostname and address columns. @@ -16,7 +18,7 @@ ON workers (hostname, address); ) // CreateWorkerIndexes creates the indexes for the workers table in the database. -func (e *engine) CreateWorkerIndexes() error { +func (e *engine) CreateWorkerIndexes(ctx context.Context) error { e.logger.Tracef("creating indexes for workers table in the database") // create the hostname and address columns index for the workers table diff --git a/database/worker/index_test.go b/database/worker/index_test.go index ead204e5c..7a55ebe57 100644 --- a/database/worker/index_test.go +++ b/database/worker/index_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestWorker_Engine_CreateWorkerIndexes(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateWorkerIndexes() + err := test.database.CreateWorkerIndexes(context.TODO()) if test.failure { if err == nil { diff --git a/database/worker/interface.go b/database/worker/interface.go index 9e7fe1169..9394b2e7c 100644 --- a/database/worker/interface.go +++ b/database/worker/interface.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/library" ) @@ -18,26 +20,26 @@ type WorkerInterface interface { // https://en.wikipedia.org/wiki/Data_definition_language // CreateWorkerIndexes defines a function that creates the indexes for the workers table. - CreateWorkerIndexes() error + CreateWorkerIndexes(context.Context) error // CreateWorkerTable defines a function that creates the workers table. - CreateWorkerTable(string) error + CreateWorkerTable(context.Context, string) error // Worker Data Manipulation Language Functions // // https://en.wikipedia.org/wiki/Data_manipulation_language // CountWorkers defines a function that gets the count of all workers. - CountWorkers() (int64, error) + CountWorkers(context.Context) (int64, error) // CreateWorker defines a function that creates a new worker. - CreateWorker(*library.Worker) error + CreateWorker(context.Context, *library.Worker) error // DeleteWorker defines a function that deletes an existing worker. - DeleteWorker(*library.Worker) error + DeleteWorker(context.Context, *library.Worker) error // GetWorker defines a function that gets a worker by ID. - GetWorker(int64) (*library.Worker, error) + GetWorker(context.Context, int64) (*library.Worker, error) // GetWorkerForHostname defines a function that gets a worker by hostname. - GetWorkerForHostname(string) (*library.Worker, error) + GetWorkerForHostname(context.Context, string) (*library.Worker, error) // ListWorkers defines a function that gets a list of all workers. - ListWorkers() ([]*library.Worker, error) + ListWorkers(context.Context) ([]*library.Worker, error) // UpdateWorker defines a function that updates an existing worker. - UpdateWorker(*library.Worker) error + UpdateWorker(context.Context, *library.Worker) error } diff --git a/database/worker/list.go b/database/worker/list.go index 4ec11ef3d..4ab1a4c5b 100644 --- a/database/worker/list.go +++ b/database/worker/list.go @@ -5,13 +5,15 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // ListWorkers gets a list of all workers from the database. -func (e *engine) ListWorkers() ([]*library.Worker, error) { +func (e *engine) ListWorkers(ctx context.Context) ([]*library.Worker, error) { e.logger.Trace("listing all workers from the database") // variables to store query results and return value @@ -20,7 +22,7 @@ func (e *engine) ListWorkers() ([]*library.Worker, error) { workers := []*library.Worker{} // count the results - count, err := e.CountWorkers() + count, err := e.CountWorkers(ctx) if err != nil { return nil, err } diff --git a/database/worker/list_test.go b/database/worker/list_test.go index 5eed3f94f..962c801b2 100644 --- a/database/worker/list_test.go +++ b/database/worker/list_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "reflect" "testing" @@ -47,12 +48,12 @@ func TestWorker_Engine_ListWorkers(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_workerOne) + err := _sqlite.CreateWorker(context.TODO(), _workerOne) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } - err = _sqlite.CreateWorker(_workerTwo) + err = _sqlite.CreateWorker(context.TODO(), _workerTwo) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -81,7 +82,7 @@ func TestWorker_Engine_ListWorkers(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.ListWorkers() + got, err := test.database.ListWorkers(context.TODO()) if test.failure { if err == nil { diff --git a/database/worker/opts.go b/database/worker/opts.go index c9891ba94..2d07f41f9 100644 --- a/database/worker/opts.go +++ b/database/worker/opts.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -42,3 +44,12 @@ func WithSkipCreation(skipCreation bool) EngineOpt { return nil } } + +// WithContext sets the context in the database engine for Workers. +func WithContext(ctx context.Context) EngineOpt { + return func(e *engine) error { + e.ctx = ctx + + return nil + } +} diff --git a/database/worker/opts_test.go b/database/worker/opts_test.go index a0ebf6aa5..e2efb2997 100644 --- a/database/worker/opts_test.go +++ b/database/worker/opts_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "reflect" "testing" @@ -159,3 +160,52 @@ func TestWorker_EngineOpt_WithSkipCreation(t *testing.T) { }) } } + +func TestWorker_EngineOpt_WithContext(t *testing.T) { + // setup types + e := &engine{config: new(config)} + + // setup tests + tests := []struct { + failure bool + name string + ctx context.Context + want context.Context + }{ + { + failure: false, + name: "context set to TODO", + ctx: context.TODO(), + want: context.TODO(), + }, + { + failure: false, + name: "context set to nil", + ctx: nil, + want: nil, + }, + } + + // run tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithContext(test.ctx)(e) + + if test.failure { + if err == nil { + t.Errorf("WithContext for %s should have returned err", test.name) + } + + return + } + + if err != nil { + t.Errorf("WithContext returned err: %v", err) + } + + if !reflect.DeepEqual(e.ctx, test.want) { + t.Errorf("WithContext is %v, want %v", e.ctx, test.want) + } + }) + } +} diff --git a/database/worker/table.go b/database/worker/table.go index 1d704674a..304b1dddf 100644 --- a/database/worker/table.go +++ b/database/worker/table.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" ) @@ -52,7 +54,7 @@ workers ( ) // CreateWorkerTable creates the workers table in the database. -func (e *engine) CreateWorkerTable(driver string) error { +func (e *engine) CreateWorkerTable(ctx context.Context, driver string) error { e.logger.Tracef("creating workers table in the database") // handle the driver provided to create the table diff --git a/database/worker/table_test.go b/database/worker/table_test.go index 681a267f2..8e37f22de 100644 --- a/database/worker/table_test.go +++ b/database/worker/table_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestWorker_Engine_CreateWorkerTable(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateWorkerTable(test.name) + err := test.database.CreateWorkerTable(context.TODO(), test.name) if test.failure { if err == nil { diff --git a/database/worker/update.go b/database/worker/update.go index b0e475273..338384669 100644 --- a/database/worker/update.go +++ b/database/worker/update.go @@ -5,6 +5,8 @@ package worker import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // UpdateWorker updates an existing worker in the database. -func (e *engine) UpdateWorker(w *library.Worker) error { +func (e *engine) UpdateWorker(ctx context.Context, w *library.Worker) error { e.logger.WithFields(logrus.Fields{ "worker": w.GetHostname(), }).Tracef("updating worker %s in the database", w.GetHostname()) diff --git a/database/worker/update_test.go b/database/worker/update_test.go index 0beeafa47..7c339a881 100644 --- a/database/worker/update_test.go +++ b/database/worker/update_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -31,7 +32,7 @@ WHERE "id" = $12`). _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - err := _sqlite.CreateWorker(_worker) + err := _sqlite.CreateWorker(context.TODO(), _worker) if err != nil { t.Errorf("unable to create test worker for sqlite: %v", err) } @@ -57,7 +58,7 @@ WHERE "id" = $12`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err = test.database.UpdateWorker(_worker) + err = test.database.UpdateWorker(context.TODO(), _worker) if test.failure { if err == nil { diff --git a/database/worker/worker.go b/database/worker/worker.go index d18aa6408..7f9277dc1 100644 --- a/database/worker/worker.go +++ b/database/worker/worker.go @@ -5,6 +5,7 @@ package worker import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -25,6 +26,8 @@ type ( // engine configuration settings used in worker functions config *config + ctx context.Context + // gorm.io/gorm database client used in worker functions // // https://pkg.go.dev/gorm.io/gorm#DB @@ -65,13 +68,13 @@ func New(opts ...EngineOpt) (*engine, error) { } // create the workers table - err := e.CreateWorkerTable(e.client.Config.Dialector.Name()) + err := e.CreateWorkerTable(e.ctx, e.client.Config.Dialector.Name()) if err != nil { return nil, fmt.Errorf("unable to create %s table: %w", constants.TableWorker, err) } // create the indexes for the workers table - err = e.CreateWorkerIndexes() + err = e.CreateWorkerIndexes(e.ctx) if err != nil { return nil, fmt.Errorf("unable to create indexes for %s table: %w", constants.TableWorker, err) } diff --git a/router/middleware/executors/executors.go b/router/middleware/executors/executors.go index 777da7c37..9bd779223 100644 --- a/router/middleware/executors/executors.go +++ b/router/middleware/executors/executors.go @@ -31,6 +31,7 @@ func Establish() gin.HandlerFunc { return func(c *gin.Context) { e := new([]library.Executor) b := build.Retrieve(c) + ctx := c.Request.Context() // if build has no host, we cannot establish executors if len(b.GetHost()) == 0 { @@ -41,7 +42,7 @@ func Establish() gin.HandlerFunc { } // retrieve the worker - w, err := database.FromContext(c).GetWorkerForHostname(b.GetHost()) + w, err := database.FromContext(c).GetWorkerForHostname(ctx, b.GetHost()) if err != nil { retErr := fmt.Errorf("unable to get worker: %w", err) util.HandleError(c, http.StatusNotFound, retErr) diff --git a/router/middleware/worker/worker.go b/router/middleware/worker/worker.go index 5afc8b3bf..2d5dc71cb 100644 --- a/router/middleware/worker/worker.go +++ b/router/middleware/worker/worker.go @@ -23,6 +23,8 @@ func Retrieve(c *gin.Context) *library.Worker { // Establish sets the worker in the given context. func Establish() gin.HandlerFunc { return func(c *gin.Context) { + ctx := c.Request.Context() + wParam := util.PathParameter(c, "worker") if len(wParam) == 0 { retErr := fmt.Errorf("no worker parameter provided") @@ -33,7 +35,7 @@ func Establish() gin.HandlerFunc { logrus.Debugf("Reading worker %s", wParam) - w, err := database.FromContext(c).GetWorkerForHostname(wParam) + w, err := database.FromContext(c).GetWorkerForHostname(ctx, wParam) if err != nil { retErr := fmt.Errorf("unable to read worker %s: %w", wParam, err) util.HandleError(c, http.StatusNotFound, retErr) diff --git a/router/middleware/worker/worker_test.go b/router/middleware/worker/worker_test.go index 58d090825..a9d1ffcf1 100644 --- a/router/middleware/worker/worker_test.go +++ b/router/middleware/worker/worker_test.go @@ -5,6 +5,7 @@ package worker import ( + "context" "net/http" "net/http/httptest" "reflect" @@ -58,11 +59,11 @@ func TestWorker_Establish(t *testing.T) { } defer func() { - db.DeleteWorker(want) + db.DeleteWorker(context.TODO(), want) db.Close() }() - _ = db.CreateWorker(want) + _ = db.CreateWorker(context.TODO(), want) // setup context gin.SetMode(gin.TestMode)