Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go/cmd: Audit and fix context.Background() usage #15928

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func commandInit(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), initArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), initArgs.WaitTime)
defer cancel()
if err := mysqld.Init(ctx, cnf, initArgs.InitDbSQLFile); err != nil {
return fmt.Errorf("failed init mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func commandShutdown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), shutdownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownArgs.WaitTime); err != nil {
return fmt.Errorf("failed shutdown mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func commandStart(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), startArgs.WaitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), startArgs.WaitTime)
defer cancel()
if err := mysqld.Start(ctx, cnf, startArgs.MySQLdArgs...); err != nil {
return fmt.Errorf("failed start mysql: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/mysqlctl/command/teardown.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func commandTeardown(cmd *cobra.Command, args []string) error {
}
defer mysqld.Close()

ctx, cancel := context.WithTimeout(context.Background(), teardownArgs.WaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), teardownArgs.WaitTime+10*time.Second)
defer cancel()
if err := mysqld.Teardown(ctx, cnf, teardownArgs.Force, teardownArgs.WaitTime); err != nil {
return fmt.Errorf("failed teardown mysql (forced? %v): %v", teardownArgs.Force, err)
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/mysqlctld/cli/mysqlctld.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func run(cmd *cobra.Command, args []string) error {
}

// Start or Init mysqld as needed.
ctx, cancel := context.WithTimeout(context.Background(), waitTime)
ctx, cancel := context.WithTimeout(cmd.Context(), waitTime)
mycnfFile := mysqlctl.MycnfFile(tabletUID)
if _, statErr := os.Stat(mycnfFile); os.IsNotExist(statErr) {
// Generate my.cnf from scratch and use it to find mysqld.
Expand Down Expand Up @@ -167,7 +167,7 @@ func run(cmd *cobra.Command, args []string) error {
// Take mysqld down with us on SIGTERM before entering lame duck.
servenv.OnTermSync(func() {
log.Infof("mysqlctl received SIGTERM, shutting down mysqld first")
ctx, cancel := context.WithTimeout(context.Background(), shutdownWaitTime+10*time.Second)
ctx, cancel := context.WithTimeout(cmd.Context(), shutdownWaitTime+10*time.Second)
defer cancel()
if err := mysqld.Shutdown(ctx, cnf, true, shutdownWaitTime); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/topo2topo/cli/topo2topo.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func run(cmd *cobra.Command, args []string) error {
return fmt.Errorf("Cannot open 'to' topo %v: %w", toImplementation, err)
}

ctx := context.Background()
ctx := cmd.Context()

if compare {
return compareTopos(ctx, fromTS, toTS)
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtadmin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package main

import (
"context"
"flag"
"io"
"time"
Expand Down Expand Up @@ -97,7 +96,7 @@ func startTracing(cmd *cobra.Command) {
}

func run(cmd *cobra.Command, args []string) {
bootSpan, ctx := trace.NewSpan(context.Background(), "vtadmin.boot")
bootSpan, ctx := trace.NewSpan(cmd.Context(), "vtadmin.boot")
defer bootSpan.Finish()

configs := clusterFileConfig.Combine(defaultClusterConfig, clusterConfigs)
Expand Down
14 changes: 7 additions & 7 deletions go/cmd/vtbackup/cli/vtbackup.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ func init() {
collationEnv = collations.NewEnvironment(servenv.MySQLServerVersion())
}

func run(_ *cobra.Command, args []string) error {
func run(cc *cobra.Command, args []string) error {
servenv.Init()

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cc.Context())
servenv.OnClose(func() {
cancel()
})
Expand Down Expand Up @@ -282,7 +282,7 @@ func run(_ *cobra.Command, args []string) error {
return fmt.Errorf("Can't take backup: %w", err)
}
if doBackup {
if err := takeBackup(ctx, topoServer, backupStorage); err != nil {
if err := takeBackup(ctx, cc.Context(), topoServer, backupStorage); err != nil {
return fmt.Errorf("Failed to take backup: %w", err)
}
}
Expand All @@ -304,7 +304,7 @@ func run(_ *cobra.Command, args []string) error {
return nil
}

func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
func takeBackup(ctx, backgroundCtx context.Context, topoServer *topo.Server, backupStorage backupstorage.BackupStorage) error {
// This is an imaginary tablet alias. The value doesn't matter for anything,
// except that we generate a random UID to ensure the target backup
// directory is unique if multiple vtbackup instances are launched for the
Expand Down Expand Up @@ -344,9 +344,9 @@ func takeBackup(ctx context.Context, topoServer *topo.Server, backupStorage back
deprecatedDurationByPhase.Set("InitMySQLd", int64(time.Since(initMysqldAt).Seconds()))
// Shut down mysqld when we're done.
defer func() {
// Be careful not to use the original context, because we don't want to
// skip shutdown just because we timed out waiting for other things.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(context.Background(), mysqlShutdownTimeout+10*time.Second)
// Be careful use the background context, not the init one, because we don't want to
// skip shutdown just because we timed out waiting for init.
mysqlShutdownCtx, mysqlShutdownCancel := context.WithTimeout(backgroundCtx, mysqlShutdownTimeout+10*time.Second)
defer mysqlShutdownCancel()
if err := mysqld.Shutdown(mysqlShutdownCtx, mycnf, false, mysqlShutdownTimeout); err != nil {
log.Errorf("failed to shutdown mysqld: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtbench/cli/vtbench.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func run(cmd *cobra.Command, args []string) error {

b := vtbench.NewBench(threads, count, connParams, sql)

ctx, cancel := context.WithTimeout(context.Background(), deadline)
ctx, cancel := context.WithTimeout(cmd.Context(), deadline)
defer cancel()

fmt.Printf("Initializing test with %s protocol / %d threads / %d iterations\n",
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtclient/cli/vtclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func _run(cmd *cobra.Command, args []string) (*results, error) {

log.Infof("Sending the query...")

ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(cmd.Context(), timeout)
defer cancel()
return execMulti(ctx, db, cmd.Flags().Arg(0))
}
Expand Down
2 changes: 2 additions & 0 deletions go/cmd/vtclient/cli/vtclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cli

import (
"context"
"fmt"
"os"
"strings"
Expand Down Expand Up @@ -129,6 +130,7 @@ func TestVtclient(t *testing.T) {
err := Main.ParseFlags(args)
require.NoError(t, err)

Main.SetContext(context.Background())
results, err := _run(Main, args)
if q.errMsg != "" {
if got, want := err.Error(), q.errMsg; !strings.Contains(got, want) {
Expand Down
51 changes: 23 additions & 28 deletions go/cmd/vtcombo/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func init() {
srvTopoCounts = stats.NewCountersWithSingleLabel("ResilientSrvTopoServer", "Resilient srvtopo server operations", "type")
}

func startMysqld(uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
func startMysqld(ctx context.Context, uid uint32) (mysqld *mysqlctl.Mysqld, cnf *mysqlctl.Mycnf, err error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

mycnfFile := mysqlctl.MycnfFile(uid)
Expand Down Expand Up @@ -189,17 +189,20 @@ func run(cmd *cobra.Command, args []string) (err error) {
cmd.Flags().Set("log_dir", "$VTDATAROOT/tmp")
}

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates the toplevel context here and then uses it for example for creating the memory topo and other operations too.

Once this cancels, it ensures that things like the topo and watchers etc. are closed in the right order.

if externalTopoServer {
// Open topo server based on the command line flags defined at topo/server.go
// do not create cell info as it should be done by whoever sets up the external topo server
ts = topo.Open()
} else {
// Create topo server. We use a 'memorytopo' implementation.
ts = memorytopo.NewServer(context.Background(), tpb.Cells...)
ts = memorytopo.NewServer(ctx, tpb.Cells...)
}
defer ts.Close()

// attempt to load any routing rules specified by tpb
if err := vtcombo.InitRoutingRules(context.Background(), ts, tpb.GetRoutingRules()); err != nil {
if err := vtcombo.InitRoutingRules(ctx, ts, tpb.GetRoutingRules()); err != nil {
return fmt.Errorf("Failed to load routing rules: %w", err)
}

Expand All @@ -212,17 +215,17 @@ func run(cmd *cobra.Command, args []string) (err error) {
)

if startMysql {
mysqld.Mysqld, cnf, err = startMysqld(1)
mysqld.Mysqld, cnf, err = startMysqld(ctx, 1)
if err != nil {
return err
}
servenv.OnClose(func() {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
})
// We want to ensure we can write to this database
mysqld.SetReadOnly(cmd.Context(), false)
mysqld.SetReadOnly(ctx, false)

} else {
dbconfigs.GlobalDBConfigs.InitWithSocket("", env.CollationEnv())
Expand All @@ -241,9 +244,9 @@ func run(cmd *cobra.Command, args []string) (err error) {
if err != nil {
// ensure we start mysql in the event we fail here
if startMysql {
ctx, cancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
startCtx, startCancel := context.WithTimeout(ctx, mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer startCancel()
mysqld.Shutdown(startCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("initTabletMapProto failed: %w", err)
Expand Down Expand Up @@ -287,20 +290,21 @@ func run(cmd *cobra.Command, args []string) (err error) {

// Now that we have fully initialized the tablets, rebuild the keyspace graph.
for _, ks := range tpb.Keyspaces {
err := topotools.RebuildKeyspace(context.Background(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
err := topotools.RebuildKeyspace(cmd.Context(), logutil.NewConsoleLogger(), ts, ks.GetName(), tpb.Cells, false)
if err != nil {
if startMysql {
ctx, cancel := context.WithTimeout(context.Background(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer cancel()
mysqld.Shutdown(ctx, cnf, true, mysqlctl.DefaultShutdownTimeout)
shutdownCtx, shutdownCancel := context.WithTimeout(cmd.Context(), mysqlctl.DefaultShutdownTimeout+10*time.Second)
defer shutdownCancel()
mysqld.Shutdown(shutdownCtx, cnf, true, mysqlctl.DefaultShutdownTimeout)
}

return fmt.Errorf("Couldn't build srv keyspace for (%v: %v). Got error: %w", ks, tpb.Cells, err)
}
}

// vtgate configuration and init
resilientServer = srvtopo.NewResilientServer(context.Background(), ts, srvTopoCounts)

resilientServer = srvtopo.NewResilientServer(ctx, ts, srvTopoCounts)

tabletTypes := make([]topodatapb.TabletType, 0, 1)
if len(tabletTypesToWait) != 0 {
Expand All @@ -324,7 +328,7 @@ func run(cmd *cobra.Command, args []string) (err error) {
vtgate.QueryzHandler = "/debug/vtgate/queryz"

// pass nil for healthcheck, it will get created
vtg := vtgate.Init(context.Background(), env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)
vtg := vtgate.Init(ctx, env, nil, resilientServer, tpb.Cells[0], tabletTypes, plannerVersion)

// vtctld configuration and init
err = vtctld.InitVtctld(env, ts)
Expand All @@ -333,22 +337,13 @@ func run(cmd *cobra.Command, args []string) (err error) {
}

if vschemaPersistenceDir != "" && !externalTopoServer {
startVschemaWatcher(vschemaPersistenceDir, tpb.Keyspaces, ts)
startVschemaWatcher(ctx, vschemaPersistenceDir, ts)
}

servenv.OnRun(func() {
addStatusParts(vtg)
})

servenv.OnTerm(func() {
log.Error("Terminating")
// FIXME(alainjobart): stop vtgate
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually fixed now that the context is cancelled that is used to setup vtgate, so we can remove the TODO and the whole block here as it does nothing.

})
servenv.OnClose(func() {
// We will still use the topo server during lameduck period
// to update our state, so closing it in OnClose()
ts.Close()
})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closing the topo isn't a long running operation or anything, so we can do it with a defer when we create it, removing the need for this callback to be used here at all.

It doesn't change the order of things, since right after RunDefault below we return, and then we'd run the defer now to close ts as well.

servenv.RunDefault()

return nil
Expand Down
15 changes: 7 additions & 8 deletions go/cmd/vtcombo/cli/vschema_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,27 @@ import (
"vitess.io/vitess/go/vt/vtgate/vindexes"

vschemapb "vitess.io/vitess/go/vt/proto/vschema"
vttestpb "vitess.io/vitess/go/vt/proto/vttest"
)

func startVschemaWatcher(vschemaPersistenceDir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func startVschemaWatcher(ctx context.Context, vschemaPersistenceDir string, ts *topo.Server) {
// Create the directory if it doesn't exist.
if err := createDirectoryIfNotExists(vschemaPersistenceDir); err != nil {
log.Fatalf("Unable to create vschema persistence directory %v: %v", vschemaPersistenceDir, err)
}

// If there are keyspace files, load them.
loadKeyspacesFromDir(vschemaPersistenceDir, keyspaces, ts)
loadKeyspacesFromDir(ctx, vschemaPersistenceDir, ts)

// Rebuild the SrvVSchema object in case we loaded vschema from file
if err := ts.RebuildSrvVSchema(context.Background(), tpb.Cells); err != nil {
if err := ts.RebuildSrvVSchema(ctx, tpb.Cells); err != nil {
log.Fatalf("RebuildSrvVSchema failed: %v", err)
}

// Now watch for changes in the SrvVSchema object and persist them to disk.
go watchSrvVSchema(context.Background(), ts, tpb.Cells[0])
go watchSrvVSchema(ctx, ts, tpb.Cells[0])
}

func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.Server) {
func loadKeyspacesFromDir(ctx context.Context, dir string, ts *topo.Server) {
for _, ks := range tpb.Keyspaces {
ksFile := path.Join(dir, ks.Name+".json")
if _, err := os.Stat(ksFile); err == nil {
Expand All @@ -67,14 +66,14 @@ func loadKeyspacesFromDir(dir string, keyspaces []*vttestpb.Keyspace, ts *topo.S
if err != nil {
log.Fatalf("Invalid keyspace definition: %v", err)
}
ts.SaveVSchema(context.Background(), ks.Name, keyspace)
ts.SaveVSchema(ctx, ks.Name, keyspace)
log.Infof("Loaded keyspace %v from %v\n", ks.Name, ksFile)
}
}
}

func watchSrvVSchema(ctx context.Context, ts *topo.Server, cell string) {
data, ch, err := ts.WatchSrvVSchema(context.Background(), tpb.Cells[0])
data, ch, err := ts.WatchSrvVSchema(ctx, tpb.Cells[0])
if err != nil {
log.Fatalf("WatchSrvVSchema failed: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctld/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func run(cmd *cobra.Command, args []string) error {
vtctld.RegisterDebugHealthHandler(ts)

// Start schema manager service.
initSchema()
initSchema(cmd.Context())

// And run the server.
servenv.RunDefault()
Expand Down
3 changes: 1 addition & 2 deletions go/cmd/vtctld/cli/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func init() {
Main.Flags().DurationVar(&schemaChangeReplicasTimeout, "schema_change_replicas_timeout", schemaChangeReplicasTimeout, "How long to wait for replicas to receive a schema change.")
}

func initSchema() {
func initSchema(ctx context.Context) {
// Start schema manager service if needed.
if schemaChangeDir != "" {
interval := schemaChangeCheckInterval
Expand All @@ -70,7 +70,6 @@ func initSchema() {
log.Errorf("failed to get controller, error: %v", err)
return
}
ctx := context.Background()
wr := wrangler.New(env, logutil.NewConsoleLogger(), ts, tmclient.NewTabletManagerClient())
_, err = schemamanager.Run(
ctx,
Expand Down
6 changes: 3 additions & 3 deletions go/cmd/vtctldclient/command/legacy_shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var (
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
cli.FinishedParsing(cmd)
return runLegacyCommand(args)
return runLegacyCommand(cmd.Context(), args)
},
Long: strings.TrimSpace(`
LegacyVtctlCommand uses the legacy vtctl grpc client to make an ExecuteVtctlCommand
Expand Down Expand Up @@ -76,11 +76,11 @@ LegacyVtctlCommand -- AddCellInfo --server_address "localhost:5678" --root "/vit
}
)

func runLegacyCommand(args []string) error {
func runLegacyCommand(ctx context.Context, args []string) error {
// Duplicated (mostly) from go/cmd/vtctlclient/main.go.
logger := logutil.NewConsoleLogger()

ctx, cancel := context.WithTimeout(context.Background(), actionTimeout)
ctx, cancel := context.WithTimeout(ctx, actionTimeout)
defer cancel()

err := vtctlclient.RunCommandAndWait(ctx, server, args, func(e *logutilpb.Event) {
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtctldclient/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ connect directly to the topo server(s).`, useInternalVtctld),
client, err = getClientForCommand(cmd)
ctx := cmd.Context()
if ctx == nil {
ctx = context.Background()
ctx = cmd.Context()
}
commandCtx, commandCancel = context.WithTimeout(ctx, actionTimeout)
if compactOutput {
Expand Down
Loading
Loading