Skip to content

Commit

Permalink
go/cmd: Audit and fix context.Background() usage
Browse files Browse the repository at this point in the history
This removes a bunch of the context.Background() usage for the command
line entrypoints. In general, we should use the command context (which
normally is context.Background(), but it's more semantically accurate).

There's a few cases where we need more fixes. Specifically in vtgate
where we want to setup a cancellable context and cancel it when we shut
down. This is the one that ends up running things like the topo watcher
and this ensures things are closed appropriately.

Similarly in vtcombo we apply similar fixes so that we always correctly
cancel the context on shutdown and the same for vttablet.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed May 14, 2024
1 parent 473c49a commit 69b352d
Show file tree
Hide file tree
Showing 20 changed files with 71 additions and 74 deletions.
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func commandInit(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), initArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), initArgs.WaitTime)
defer cancel()
if err := mysqld.Init(ctx, cnf, initArgs.InitDbSQLFile); err != nil {
return fmt.Errorf("failed init mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func commandShutdown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), shutdownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownArgs.WaitTime); err != nil {
return fmt.Errorf("failed shutdown mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func commandStart(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), startArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), startArgs.WaitTime)
defer cancel()
if err := mysqld.Start(ctx, cnf, startArgs.MySQLdArgs...); err != nil {
return fmt.Errorf("failed start mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/teardown.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func commandTeardown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), teardownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), teardownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Teardown(ctx, cnf, teardownArgs.Force, teardownArgs.WaitTime); err != nil {
return fmt.Errorf("failed teardown mysql (forced? %v): %v", teardownArgs.Force, err)
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/mysqlctld/cli/mysqlctld.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func run(cmd *cobra.Command, args []string) error {
}

// Start or Init mysqld as needed.
ctx, cancel := context.WithTimeout(context.Background(), waitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), waitTime)
mycnfFile := mysqlctl.MycnfFile(tabletUID)
if _, statErr := os.Stat(mycnfFile); os.IsNotExist(statErr) {
// Generate my.cnf from scratch and use it to find mysqld.
Expand Down Expand Up @@ -167,7 +167,7 @@ func run(cmd *cobra.Command, args []string) error {
// Take mysqld down with us on SIGTERM before entering lame duck.
servenv.OnTermSync(func() {
log.Infof("mysqlctl received SIGTERM, shutting down mysqld first")
ctx, cancel := context.WithTimeout(context.Background(), shutdownWaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownWaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownWaitTime); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/topo2topo/cli/topo2topo.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func run(cmd *cobra.Command, args []string) error {
return fmt.Errorf("Cannot open 'to' topo %v: %w", toImplementation, err)
}

ctx := context.Background()
ctx := cmd.Context()

if compare {
return compareTopos(ctx, fromTS, toTS)
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtadmin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package main

import (
"context"
"flag"
"io"
"time"
Expand Down Expand Up @@ -97,7 +96,7 @@ func startTracing(cmd *cobra.Command) {
}

func run(cmd *cobra.Command, args []string) {
bootSpan, ctx := trace.NewSpan(context.Background(), "vtadmin.boot")
bootSpan, ctx := trace.NewSpan(cmd.Context(), "vtadmin.boot")
defer bootSpan.Finish()

configs := clusterFileConfig.Combine(defaultClusterConfig, clusterConfigs)
Expand Down
14 changes: 7 additions & 7 deletions go/cmd/vtbackup/cli/vtbackup.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ func init() {
collationEnv = collations.NewEnvironment(servenv.MySQLServerVersion())
}

func run(_ *cobra.Command, args []string) error {
func run(cc *cobra.Command, args []string) error {
servenv.Init()

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cc.Context())
servenv.OnClose(func() {
cancel()
})
Expand Down Expand Up @@ -282,7 +282,7 @@ func run(_ *cobra.Command, args []string) error {
return fmt.Errorf("Can't take backup: %w", err)
}
if doBackup {
if err := takeBackup(ctx, topoServer, backupStorage); err != nil {
if err := takeBackup(ctx, cc.Context(), topoServer, backupStorage); err != nil {
return fmt.Errorf("Failed to take backup: %w", err)
}
}
Expand All @@ -304,7 +304,7 @@ func run(_ *cobra.Command, args []string) error {
return nil
}

func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
func takeBackup(ctx, backgroundCtx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
// This is an imaginary tablet alias. The value doesn't matter for anything,
// except that we generate a random UID to ensure the target backup
// directory is unique if multiple vtbackup instances are launched for the
Expand Down Expand Up @@ -344,9 +344,9 @@ func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage back
deprecatedDurationByPhase.Set("InitMySQLd", int64(time.Since(initMysqldAt).Seconds()))
// Shut down mysqld when we're done.
defer func() {
// Be careful not to use the original context, because we don't want to
// skip shutdown just because we timed out waiting for other things.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(context.Background(), mysqlShutdownTimeout+10*time.Second)
// Be careful use the background context, not the init one, because we don't want to
// skip shutdown just because we timed out waiting for init.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(backgroundCtx, mysqlShutdownTimeout+10*time.Second)
defer mysqlShutdownCancel()
if err := mysqld.Shutdown(mysqlShutdownCtx, mycnf, false, mysqlShutdownTimeout); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtbench/cli/vtbench.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func run(cmd *cobra.Command, args []string) error {

b := vtbench.NewBench(threads, count, connParams, sql)

ctx, cancel := context.WithTimeout(context.Background(), deadline)
ctx, cancel := context.WithTimeout(cmd.Context(), deadline)
defer cancel()

fmt.Printf("Initializing test with %s protocol / %d threads / %d iterations\n",
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtclient/cli/vtclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func _run(cmd *cobra.Command, args []string) (*results, error) {

log.Infof("Sending the query...")

ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(cmd.Context(), timeout)
defer cancel()
return execMulti(ctx, db, cmd.Flags().Arg(0))
}
Expand Down
2 changes: 2 additions & 0 deletions go/cmd/vtclient/cli/vtclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cli

import (
"context"
"fmt"
"os"
"strings"
Expand Down Expand Up @@ -129,6 +130,7 @@ func TestVtclient(t *testing.T) {
err := Main.ParseFlags(args)
require.NoError(t, err)

Main.SetContext(context.Background())
results, err := _run(Main, args)
if q.errMsg != "" {
if got, want := err.Error(), q.errMsg; !strings.Contains(got, want) {
Expand Down
51 changes: 23 additions & 28 deletions go/cmd/vtcombo/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func init() {
srvTopoCounts = stats.NewCountersWithSingleLabel("ResilientSrvTopoServer", "Resilient srvtopo server operations", "type")
}

func startMysqld(uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
func startMysqld(ctx context.Context, uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

mycnfFile := mysqlctl.MycnfFile(uid)
Expand Down Expand Up @@ -189,17 +189,20 @@ func run(cmd *cobra.Command, args []string) (err error) {
cmd.Flags().Set("log_dir", "$VTDATAROOT/tmp")
}

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
if externalTopoServer {
// Open topo server based on the command line flags defined at topo/server.go
// do not create cell info as it should be done by whoever sets up the external topo server
ts = topo.Open()
} else {
// Create topo server. We use a 'memorytopo' implementation.
ts = memorytopo.NewServer(context.Background(), tpb.Cells...)
ts = memorytopo.NewServer(ctx, tpb.Cells...)
}
defer ts.Close()

// attempt to load any routing rules specified by tpb
if err := vtcombo.InitRoutingRules(context.Background(), ts, tpb.GetRoutingRules()); err != nil {
if err := vtcombo.InitRoutingRules(ctx, ts, tpb.GetRoutingRules()); err != nil {
return fmt.Errorf("Failed to load routing rules: %w", err)
}

Expand All @@ -212,17 +215,17 @@ func run(cmd *cobra.Command, args []string) (err error) {
)

if startMysql {
mysqld.Mysqld, cnf, err = startMysqld(1)
mysqld.Mysqld, cnf, err = startMysqld(ctx, 1)
if err != nil {
return err
}
servenv.OnClose(func() {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
})
// We want to ensure we can write to this database
mysqld.SetReadOnly(cmd.Context(), false)
mysqld.SetReadOnly(ctx, false)

} else {
dbconfigs.GlobalDBConfigs.InitWithSocket("", env.CollationEnv())
Expand All @@ -241,9 +244,9 @@ func run(cmd *cobra.Command, args []string) (err error) {
if err != nil {
// ensure we start mysql in the event we fail here
if startMysql {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
startCtx, startCancel := context.WithTimeout(ctx, mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer startCancel()
mysqld.Shutdown(startCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("initTabletMapProto failed: %w", err)
Expand Down Expand Up @@ -287,20 +290,21 @@ func run(cmd *cobra.Command, args []string) (err error) {

// Now that we have fully initialized the tablets, rebuild the keyspace graph.
for _, ks := range tpb.Keyspaces {
err := topotools.RebuildKeyspace(context.Background(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
err := topotools.RebuildKeyspace(cmd.Context(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
if err != nil {
if startMysql {
ctx, cancel := context.WithTimeout(context.Background(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("Couldn't build srv keyspace for (%v: %v). Got error: %w", ks, tpb.Cells, err)
}
}

// vtgate configuration and init
resilientServer = srvtopo.NewResilientServer(context.Background(), ts, srvTopoCounts)

resilientServer = srvtopo.NewResilientServer(ctx, ts, srvTopoCounts)

tabletTypes := make([]topodatapb.TabletType, 0, 1)
if len(tabletTypesToWait) != 0 {
Expand All @@ -324,7 +328,7 @@ func run(cmd *cobra.Command, args []string) (err error) {
vtgate.QueryzHandler = "/debug/vtgate/queryz"

// pass nil for healthcheck, it will get created
vtg := vtgate.Init(context.Background(), env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)
vtg := vtgate.Init(ctx, env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)

// vtctld configuration and init
err = vtctld.InitVtctld(env, ts)
Expand All @@ -333,22 +337,13 @@ func run(cmd *cobra.Command, args []string) (err error) {
}

if vschemaPersistenceDir != "" && !externalTopoServer {
startVschemaWatcher(vschemaPersistenceDir, tpb.Keyspaces, ts)
startVschemaWatcher(ctx, vschemaPersistenceDir, ts)
}

servenv.OnRun(func() {
addStatusParts(vtg)
})

servenv.OnTerm(func() {
log.Error("Terminating")
// FIXME(alainjobart): stop vtgate
})
servenv.OnClose(func() {
// We will still use the topo server during lameduck period
// to update our state, so closing it in OnClose()
ts.Close()
})
servenv.RunDefault()

return nil
Expand Down
15 changes: 7 additions & 8 deletions go/cmd/vtcombo/cli/vschema_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,27 @@ import (
"vitess.io/vitess/go/vt/vtgate/vindexes"

vschemapb "vitess.io/vitess/go/vt/proto/vschema"
vttestpb "vitess.io/vitess/go/vt/proto/vttest"
)

func startVschemaWatcher(vschemaPersistenceDir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func startVschemaWatcher(ctx context.Context, vschemaPersistenceDir string, ts *topo.Server) {
// Create the directory if it doesn't exist.
if err := createDirectoryIfNotExists(vschemaPersistenceDir); err != nil {
log.Fatalf("Unable to create vschema persistence directory %v: %v", vschemaPersistenceDir, err)
}

// If there are keyspace files, load them.
loadKeyspacesFromDir(vschemaPersistenceDir, keyspaces, ts)
loadKeyspacesFromDir(ctx, vschemaPersistenceDir, ts)

// Rebuild the SrvVSchema object in case we loaded vschema from file
if err := ts.RebuildSrvVSchema(context.Background(), tpb.Cells); err != nil {
if err := ts.RebuildSrvVSchema(ctx, tpb.Cells); err != nil {
log.Fatalf("RebuildSrvVSchema failed: %v", err)
}

// Now watch for changes in the SrvVSchema object and persist them to disk.
go watchSrvVSchema(context.Background(), ts, tpb.Cells[0])
go watchSrvVSchema(ctx, ts, tpb.Cells[0])
}

func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func loadKeyspacesFromDir(ctx context.Context, dir string, ts *topo.Server) {
for _, ks := range tpb.Keyspaces {
ksFile := path.Join(dir, ks.Name+".json")
if _, err := os.Stat(ksFile); err == nil {
Expand All @@ -67,14 +66,14 @@ func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.S
if err != nil {
log.Fatalf("Invalid keyspace definition: %v", err)
}
ts.SaveVSchema(context.Background(), ks.Name, keyspace)
ts.SaveVSchema(ctx, ks.Name, keyspace)
log.Infof("Loaded keyspace %v from %v\n", ks.Name, ksFile)
}
}
}

func watchSrvVSchema(ctx context.Context, ts *topo.Server, cell string) {
data, ch, err := ts.WatchSrvVSchema(context.Background(), tpb.Cells[0])
data, ch, err := ts.WatchSrvVSchema(ctx, tpb.Cells[0])
if err != nil {
log.Fatalf("WatchSrvVSchema failed: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctld/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func run(cmd *cobra.Command, args []string) error {
vtctld.RegisterDebugHealthHandler(ts)

// Start schema manager service.
initSchema()
initSchema(cmd.Context())

// And run the server.
servenv.RunDefault()
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtctld/cli/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func init() {
Main.Flags().DurationVar(&schemaChangeReplicasTimeout, "schema_change_replicas_timeout", schemaChangeReplicasTimeout, "How long to wait for replicas to receive a schema change.")
}

func initSchema() {
func initSchema(ctx context.Context) {
// Start schema manager service if needed.
if schemaChangeDir != "" {
interval := schemaChangeCheckInterval
Expand All @@ -70,7 +70,6 @@ func initSchema() {
log.Errorf("failed to get controller, error: %v", err)
return
}
ctx := context.Background()
wr := wrangler.New(env, logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient())
_, err = schemamanager.Run(
ctx,
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/vtctldclient/command/legacy_shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var (
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
cli.FinishedParsing(cmd)
return runLegacyCommand(args)
return runLegacyCommand(cmd.Context(), args)
},
Long: strings.TrimSpace(`
LegacyVtctlCommand uses the legacy vtctl grpc client to make an ExecuteVtctlCommand
Expand Down Expand Up @@ -76,11 +76,11 @@ LegacyVtctlCommand -- AddCellInfo --server_address "localhost:5678" --root "/vit
}
)

func runLegacyCommand(args []string) error {
func runLegacyCommand(ctx context.Context, args []string) error {
// Duplicated (mostly) from go/cmd/vtctlclient/main.go.
logger := logutil.NewConsoleLogger()

ctx, cancel := context.WithTimeout(context.Background(), actionTimeout)
ctx, cancel := context.WithTimeout(ctx, actionTimeout)
defer cancel()

err := vtctlclient.RunCommandAndWait(ctx, server, args, func(e *logutilpb.Event) {
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctldclient/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ connect directly to the topo server(s).`, useInternalVtctld),
client, err = getClientForCommand(cmd)
ctx := cmd.Context()
if ctx == nil {
ctx = context.Background()
ctx = cmd.Context()
}
commandCtx, commandCancel = context.WithTimeout(ctx, actionTimeout)
if compactOutput {
Expand Down
Loading

0 comments on commit 69b352d

Please sign in to comment.