diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index ede88e2d9b8..2b37c865187 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -19,7 +19,9 @@ package vtgate import ( "context" "io" + "runtime/debug" "sync" + "sync/atomic" "time" "vitess.io/vitess/go/mysql/sqlerror" @@ -603,6 +605,12 @@ func (stc *ScatterConn) multiGo( return allErrors } +// panicData is used to capture panics during parallel execution. +type panicData struct { + p any + trace []byte +} + // multiGoTransaction performs the requested 'action' on the specified // ResolvedShards in parallel. For each shard, if the requested // session is in a transaction, it opens a new transactions on the connection, @@ -660,15 +668,28 @@ func (stc *ScatterConn) multiGoTransaction( oneShard(rs, i) } } else { + var panicRecord atomic.Value var wg sync.WaitGroup for i, rs := range rss { wg.Add(1) go func(rs *srvtopo.ResolvedShard, i int) { defer wg.Done() + defer func() { + if r := recover(); r != nil { + panicRecord.Store(&panicData{ + p: r, + trace: debug.Stack(), + }) + } + }() oneShard(rs, i) }(rs, i) } wg.Wait() + if pr, ok := panicRecord.Load().(*panicData); ok { + log.Errorf("caught a panic during parallel execution:\n%s", string(pr.trace)) + panic(pr.p) // rethrow the captured panic in the main thread + } } if session.MustRollback() { diff --git a/go/vt/vtgate/scatter_conn_test.go b/go/vt/vtgate/scatter_conn_test.go index 6e57c10bbbd..0e863805d9c 100644 --- a/go/vt/vtgate/scatter_conn_test.go +++ b/go/vt/vtgate/scatter_conn_test.go @@ -17,8 +17,11 @@ limitations under the License. package vtgate import ( + "fmt" "testing" + "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/mysql/sqlerror" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -105,6 +108,85 @@ func TestExecuteFailOnAutocommit(t *testing.T) { utils.MustMatch(t, []*querypb.BoundQuery{queries[1]}, sbc1.Queries, "") } +func TestExecutePanic(t *testing.T) { + ctx := utils.LeakCheckContext(t) + + createSandbox("TestExecutePanic") + hc := discovery.NewFakeHealthCheck(nil) + sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa") + sbc0 := hc.AddTestTablet("aa", "0", 1, "TestExecutePanic", "0", topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc1 := hc.AddTestTablet("aa", "1", 1, "TestExecutePanic", "1", topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc0.SetPanic(42) + sbc1.SetPanic(42) + rss := []*srvtopo.ResolvedShard{ + { + Target: &querypb.Target{ + Keyspace: "TestExecutePanic", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc0, + }, + { + Target: &querypb.Target{ + Keyspace: "TestExecutePanic", + Shard: "1", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc1, + }, + } + queries := []*querypb.BoundQuery{ + { + // This will fail to go to shard. It will be rejected at vtgate. + Sql: "query1", + BindVariables: map[string]*querypb.BindVariable{ + "bv0": sqltypes.Int64BindVariable(0), + }, + }, + { + // This will go to shard. + Sql: "query2", + BindVariables: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + }, + }, + } + // shard 0 - has transaction + // shard 1 - does not have transaction. + session := &vtgatepb.Session{ + InTransaction: true, + ShardSessions: []*vtgatepb.Session_ShardSession{ + { + Target: &querypb.Target{Keyspace: "TestExecutePanic", Shard: "0", TabletType: topodatapb.TabletType_PRIMARY, Cell: "aa"}, + TransactionId: 123, + TabletAlias: nil, + }, + }, + Autocommit: false, + } + + original := log.Errorf + defer func() { + log.Errorf = original + }() + + var logMessage string + log.Errorf = func(format string, args ...any) { + logMessage = fmt.Sprintf(format, args...) + } + + defer func() { + r := recover() + require.NotNil(t, r, "The code did not panic") + // assert we are seeing the stack trace + require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction") + }() + + _, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false) + +} + func TestReservedOnMultiReplica(t *testing.T) { ctx := utils.LeakCheckContext(t) diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index ad9c1b3702f..2d0f5d9fff1 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -125,6 +125,9 @@ type SandboxConn struct { // this error will only happen once EphemeralShardErr error + // if this is not nil, any calls will panic the tablet + panicThis interface{} + NotServing bool getSchemaResult []map[string]string @@ -206,6 +209,7 @@ func (sbc *SandboxConn) SetSchemaResult(r []map[string]string) { // Execute is part of the QueryService interface. func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { + sbc.panicIfNeeded() sbc.execMu.Lock() defer sbc.execMu.Unlock() sbc.ExecCount.Add(1) @@ -238,6 +242,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que // StreamExecute is part of the QueryService interface. func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { + sbc.panicIfNeeded() sbc.sExecMu.Lock() sbc.ExecCount.Add(1) bv := make(map[string]*querypb.BindVariable) @@ -278,6 +283,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe // Begin is part of the QueryService interface. func (sbc *SandboxConn) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { + sbc.panicIfNeeded() return sbc.begin(ctx, target, nil, 0, options) } @@ -303,6 +309,7 @@ func (sbc *SandboxConn) begin(ctx context.Context, target *querypb.Target, preQu // Commit is part of the QueryService interface. func (sbc *SandboxConn) Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { + sbc.panicIfNeeded() sbc.CommitCount.Add(1) reservedID := sbc.getTxReservedID(transactionID) if reservedID != 0 { @@ -323,6 +330,7 @@ func (sbc *SandboxConn) Rollback(ctx context.Context, target *querypb.Target, tr // Prepare prepares the specified transaction. func (sbc *SandboxConn) Prepare(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (err error) { + sbc.panicIfNeeded() sbc.PrepareCount.Add(1) if sbc.MustFailPrepare > 0 { sbc.MustFailPrepare-- @@ -333,6 +341,7 @@ func (sbc *SandboxConn) Prepare(ctx context.Context, target *querypb.Target, tra // CommitPrepared commits the prepared transaction. func (sbc *SandboxConn) CommitPrepared(ctx context.Context, target *querypb.Target, dtid string) (err error) { + sbc.panicIfNeeded() sbc.CommitPreparedCount.Add(1) if sbc.MustFailCommitPrepared > 0 { sbc.MustFailCommitPrepared-- @@ -343,6 +352,7 @@ func (sbc *SandboxConn) CommitPrepared(ctx context.Context, target *querypb.Targ // RollbackPrepared rolls back the prepared transaction. func (sbc *SandboxConn) RollbackPrepared(ctx context.Context, target *querypb.Target, dtid string, originalID int64) (err error) { + sbc.panicIfNeeded() sbc.RollbackPreparedCount.Add(1) if sbc.MustFailRollbackPrepared > 0 { sbc.MustFailRollbackPrepared-- @@ -364,6 +374,7 @@ func (sbc *SandboxConn) CreateTransaction(ctx context.Context, target *querypb.T // StartCommit atomically commits the transaction along with the // decision to commit the associated 2pc transaction. func (sbc *SandboxConn) StartCommit(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (err error) { + sbc.panicIfNeeded() sbc.StartCommitCount.Add(1) if sbc.MustFailStartCommit > 0 { sbc.MustFailStartCommit-- @@ -375,6 +386,7 @@ func (sbc *SandboxConn) StartCommit(ctx context.Context, target *querypb.Target, // SetRollback transitions the 2pc transaction to the Rollback state. // If a transaction id is provided, that transaction is also rolled back. func (sbc *SandboxConn) SetRollback(ctx context.Context, target *querypb.Target, dtid string, transactionID int64) (err error) { + sbc.panicIfNeeded() sbc.SetRollbackCount.Add(1) if sbc.MustFailSetRollback > 0 { sbc.MustFailSetRollback-- @@ -410,6 +422,7 @@ func (sbc *SandboxConn) ReadTransaction(ctx context.Context, target *querypb.Tar // BeginExecute is part of the QueryService interface. func (sbc *SandboxConn) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { + sbc.panicIfNeeded() state, err := sbc.begin(ctx, target, preQueries, reservedID, options) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) @@ -423,6 +436,7 @@ func (sbc *SandboxConn) BeginExecute(ctx context.Context, target *querypb.Target // BeginStreamExecute is part of the QueryService interface. func (sbc *SandboxConn) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { + sbc.panicIfNeeded() state, err := sbc.begin(ctx, target, preQueries, reservedID, options) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) @@ -567,6 +581,7 @@ func (sbc *SandboxConn) HandlePanic(err *error) { // ReserveBeginExecute implements the QueryService interface func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { + sbc.panicIfNeeded() reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options) state, result, err := sbc.BeginExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options) if state.TransactionID != 0 { @@ -581,6 +596,7 @@ func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb // ReserveBeginStreamExecute is part of the QueryService interface. func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { + sbc.panicIfNeeded() reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options) state, err := sbc.BeginStreamExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options, callback) if state.TransactionID != 0 { @@ -595,6 +611,7 @@ func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *q // ReserveExecute implements the QueryService interface func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (queryservice.ReservedState, *sqltypes.Result, error) { + sbc.panicIfNeeded() reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options) result, err := sbc.Execute(ctx, target, sql, bindVariables, transactionID, reservedID, options) if transactionID != 0 { @@ -608,6 +625,7 @@ func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Targ // ReserveStreamExecute is part of the QueryService interface. func (sbc *SandboxConn) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { + sbc.panicIfNeeded() reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options) err := sbc.StreamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, options, callback) if transactionID != 0 { @@ -769,3 +787,13 @@ var StreamRowResult = &sqltypes.Result{ sqltypes.NewVarChar("foo"), }}, } + +func (sbc *SandboxConn) SetPanic(i interface{}) { + sbc.panicThis = i +} + +func (sbc *SandboxConn) panicIfNeeded() { + if sbc.panicThis != nil { + panic(sbc.panicThis) + } +}